mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
feat[Go]: implement searches/<search_id>/completions POST (#16440)
### Summary
Implements the Go handler for `POST /searches/<search_id>/completions`, as described in the title.

### Type of change
- [x] New Feature (non-breaking change which adds functionality)
- [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
@@ -229,6 +229,7 @@ func startServer(config *server.Config) {
|
||||
systemService := service.NewSystemService()
|
||||
connectorService := service.NewConnectorService()
|
||||
searchService := service.NewSearchService()
|
||||
searchService.SetTenantService(tenantService)
|
||||
fileService := service.NewFileService()
|
||||
memoryService := service.NewMemoryService()
|
||||
mcpService := service.NewMCPService()
|
||||
@@ -303,7 +304,9 @@ func startServer(config *server.Config) {
|
||||
chunkService,
|
||||
)
|
||||
searchBotHandler.SetStreamLLM(searchBotLLM)
|
||||
searchBotHandler.SetAskService(service.NewAskService(chunkService, nil, 0, 0))
|
||||
askService := service.NewAskService(chunkService, nil, 0, 0)
|
||||
searchBotHandler.SetAskService(askService)
|
||||
searchHandler.SetCompletionDependencies(searchBotLLM, askService)
|
||||
pluginHandler := handler.NewPluginHandler(service.NewPluginService())
|
||||
modelHandler := handler.NewModelHandler(service.NewModelProviderService())
|
||||
fileCommitHandler := handler.NewFileCommitHandler(service.NewFileCommitService())
|
||||
|
||||
@@ -30,6 +30,9 @@ import (
|
||||
type SearchHandler struct {
|
||||
searchService *service.SearchService
|
||||
userService *service.UserService
|
||||
streamLLM streamingLLM
|
||||
askService *service.AskService
|
||||
sseWriter SSEWriter
|
||||
}
|
||||
|
||||
// NewSearchHandler create search handler
|
||||
@@ -37,9 +40,16 @@ func NewSearchHandler(searchService *service.SearchService, userService *service
|
||||
return &SearchHandler{
|
||||
searchService: searchService,
|
||||
userService: userService,
|
||||
sseWriter: &ginSSEWriter{},
|
||||
}
|
||||
}
|
||||
|
||||
// SetCompletionDependencies wires the streaming search completion runtime.
|
||||
func (h *SearchHandler) SetCompletionDependencies(streamLLM streamingLLM, askService *service.AskService) {
|
||||
h.streamLLM = streamLLM
|
||||
h.askService = askService
|
||||
}
|
||||
|
||||
// ListSearches list search apps
|
||||
// @Summary List Search Apps
|
||||
// @Description Get list of search apps for the current user with filtering, pagination and sorting
|
||||
@@ -421,3 +431,87 @@ func (h *SearchHandler) UpdateSearch(c *gin.Context) {
|
||||
"message": "success",
|
||||
})
|
||||
}
|
||||
|
||||
func (h *SearchHandler) Completion(c *gin.Context) {
|
||||
user, errorCode, errorMessage := GetUser(c)
|
||||
if errorCode != common.CodeSuccess {
|
||||
jsonError(c, errorCode, errorMessage)
|
||||
return
|
||||
}
|
||||
|
||||
var req service.SearchCompletionsRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
jsonError(c, common.CodeArgumentError, "question is required")
|
||||
return
|
||||
}
|
||||
|
||||
searchSvc := h.searchService
|
||||
if searchSvc == nil {
|
||||
searchSvc = service.NewSearchService()
|
||||
}
|
||||
|
||||
plan, code, err := searchSvc.PrepareCompletion(user.ID, c.Param("search_id"), &req)
|
||||
if err != nil {
|
||||
if code == common.CodeAuthenticationError {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": code,
|
||||
"data": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if code == common.CodeServerError {
|
||||
jsonInternalError(c, err)
|
||||
return
|
||||
}
|
||||
jsonError(c, code, err.Error())
|
||||
return
|
||||
}
|
||||
if plan == nil {
|
||||
jsonError(c, common.CodeServerError, "completion plan is nil")
|
||||
return
|
||||
}
|
||||
|
||||
disableWriteDeadlineForSSE(c)
|
||||
c.Header("Content-Type", "text/event-stream; charset=utf-8")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("X-Accel-Buffering", "no")
|
||||
|
||||
writer := h.sseWriter
|
||||
if writer == nil {
|
||||
writer = &ginSSEWriter{}
|
||||
}
|
||||
if plan.ModelID == "" {
|
||||
writer.Write(c, sseError("chat model not configured"))
|
||||
return
|
||||
}
|
||||
if h.askService == nil {
|
||||
writer.Write(c, sseError("ask service not configured"))
|
||||
return
|
||||
}
|
||||
if h.streamLLM == nil {
|
||||
writer.Write(c, sseError("streaming LLM not configured"))
|
||||
return
|
||||
}
|
||||
|
||||
adapter := &askStreamAdapter{llm: h.streamLLM, tenantID: plan.UserID, modelID: plan.ModelID}
|
||||
|
||||
hadError := false
|
||||
for delta := range h.askService.StreamWithOptions(c.Request.Context(), adapter, plan.UserID, plan.Question, plan.KBIDs, plan.Options) {
|
||||
switch delta.Kind {
|
||||
case service.AskDeltaAnswer:
|
||||
writer.Write(c, sseAnswer(delta.Value, nil, false))
|
||||
case service.AskDeltaMarker:
|
||||
writer.Write(c, sseMarker(delta.Value))
|
||||
case service.AskDeltaError:
|
||||
hadError = true
|
||||
writer.Write(c, sseError(delta.Value))
|
||||
case service.AskDeltaFinal:
|
||||
writer.Write(c, sseAnswer(delta.Value, delta.Refs, true))
|
||||
}
|
||||
}
|
||||
if !hadError {
|
||||
writer.Write(c, "data: {\"code\": 0, \"message\": \"\", \"data\": true}\n\n")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -364,6 +364,8 @@ func (r *Router) Setup(engine *gin.Engine) {
|
||||
searches.GET("/:search_id", r.searchHandler.GetSearch)
|
||||
searches.PUT("/:search_id", r.searchHandler.UpdateSearch)
|
||||
searches.DELETE("/:search_id", r.searchHandler.DeleteSearch)
|
||||
searches.POST("/:search_id/completion", r.searchHandler.Completion)
|
||||
searches.POST("/:search_id/completions", r.searchHandler.Completion)
|
||||
}
|
||||
|
||||
file := v1.Group("/files")
|
||||
|
||||
@@ -53,6 +53,22 @@ type AskDelta struct {
|
||||
Refs interface{} // populated on AskDeltaFinal: {chunks, doc_aggs}
|
||||
}
|
||||
|
||||
// AskStreamOptions holds the optional retrieval settings that a saved
// search_config can supply to StreamWithOptions. The zero value reproduces the
// behaviour of Stream: nil pointers and empty slices/maps mean "not set".
type AskStreamOptions struct {
	// SearchID is forwarded to the retrieval request only when non-empty.
	SearchID string
	// DocIDs, UseKG, CrossLanguages, Filter, TenantRerankID, RerankID and
	// Keyword are copied onto the retrieval request unchanged.
	DocIDs []string
	UseKG  *bool
	// TopK replaces DefaultAskTopK when non-nil.
	TopK           *int
	CrossLanguages []string
	Filter         map[string]interface{}
	TenantRerankID *string
	RerankID       *string
	Keyword        *bool
	// SimilarityThreshold replaces DefaultAskSimilarityThreshold when non-nil.
	SimilarityThreshold *float64
	// VectorSimilarityWeight replaces DefaultAskVectorSimilarityWeight when non-nil.
	VectorSimilarityWeight *float64
}
|
||||
|
||||
// Retriever abstracts chunk retrieval for AskService.
|
||||
type Retriever interface {
|
||||
RetrievalTest(req *RetrievalTestRequest, userID string) (*RetrievalTestResponse, error)
|
||||
@@ -91,22 +107,51 @@ func NewAskService(retriever Retriever, embedder Embedder, tokenBudget, minStrea
|
||||
// Stream runs the full ask pipeline. llm must not be nil. The returned
|
||||
// channel is closed when the pipeline completes or ctx is cancelled.
|
||||
func (s *AskService) Stream(ctx context.Context, llm StreamingLLM, userID, question string, kbIDs []string) <-chan AskDelta {
|
||||
return s.StreamWithOptions(ctx, llm, userID, question, kbIDs, AskStreamOptions{})
|
||||
}
|
||||
|
||||
// StreamWithOptions runs Stream while allowing callers such as saved Search
|
||||
// apps to pass search_config retrieval options through to RetrievalTest.
|
||||
func (s *AskService) StreamWithOptions(ctx context.Context, llm StreamingLLM, userID, question string, kbIDs []string, opts AskStreamOptions) <-chan AskDelta {
|
||||
out := make(chan AskDelta, 32)
|
||||
go func() {
|
||||
defer close(out)
|
||||
s.run(ctx, llm, userID, question, kbIDs, out)
|
||||
s.run(ctx, llm, userID, question, kbIDs, opts, out)
|
||||
}()
|
||||
return out
|
||||
}
|
||||
|
||||
func (s *AskService) run(ctx context.Context, llm StreamingLLM, userID, question string, kbIDs []string, out chan<- AskDelta) {
|
||||
func (s *AskService) run(ctx context.Context, llm StreamingLLM, userID, question string, kbIDs []string, opts AskStreamOptions, out chan<- AskDelta) {
|
||||
// Phase 1: Retrieval.
|
||||
topK := DefaultAskTopK
|
||||
if opts.TopK != nil {
|
||||
topK = *opts.TopK
|
||||
}
|
||||
similarityThreshold := DefaultAskSimilarityThreshold
|
||||
if opts.SimilarityThreshold != nil {
|
||||
similarityThreshold = *opts.SimilarityThreshold
|
||||
}
|
||||
vectorSimilarityWeight := DefaultAskVectorSimilarityWeight
|
||||
if opts.VectorSimilarityWeight != nil {
|
||||
vectorSimilarityWeight = *opts.VectorSimilarityWeight
|
||||
}
|
||||
|
||||
req := &RetrievalTestRequest{
|
||||
Datasets: common.StringSlice(kbIDs),
|
||||
Question: question,
|
||||
TopK: ptrInt(DefaultAskTopK),
|
||||
SimilarityThreshold: ptrFloat64(DefaultAskSimilarityThreshold),
|
||||
VectorSimilarityWeight: ptrFloat64(DefaultAskVectorSimilarityWeight),
|
||||
DocIDs: opts.DocIDs,
|
||||
UseKG: opts.UseKG,
|
||||
TopK: ptrInt(topK),
|
||||
CrossLanguages: opts.CrossLanguages,
|
||||
Filter: opts.Filter,
|
||||
TenantRerankID: opts.TenantRerankID,
|
||||
RerankID: opts.RerankID,
|
||||
Keyword: opts.Keyword,
|
||||
SimilarityThreshold: ptrFloat64(similarityThreshold),
|
||||
VectorSimilarityWeight: ptrFloat64(vectorSimilarityWeight),
|
||||
}
|
||||
if opts.SearchID != "" {
|
||||
req.SearchID = &opts.SearchID
|
||||
}
|
||||
page := DefaultAskPage
|
||||
ps := DefaultAskPageSize
|
||||
|
||||
@@ -872,16 +872,24 @@ func parseMessages(raw json.RawMessage) []map[string]interface{} {
|
||||
return messages
|
||||
}
|
||||
|
||||
var wrapped struct {
|
||||
Messages []map[string]interface{} `json:"messages"`
|
||||
}
|
||||
var wrapped map[string]json.RawMessage
|
||||
if err := json.Unmarshal(raw, &wrapped); err != nil {
|
||||
return nil
|
||||
}
|
||||
wrappedMessages, ok := wrapped["messages"]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
if len(wrappedMessages) == 0 || string(wrappedMessages) == "null" {
|
||||
return make([]map[string]interface{}, 0)
|
||||
}
|
||||
if wrapped.Messages == nil {
|
||||
if err := json.Unmarshal(wrappedMessages, &messages); err != nil {
|
||||
return nil
|
||||
}
|
||||
if messages == nil {
|
||||
return make([]map[string]interface{}, 0)
|
||||
}
|
||||
return wrapped.Messages
|
||||
return messages
|
||||
}
|
||||
|
||||
func parseReferenceList(raw json.RawMessage) []interface{} {
|
||||
@@ -891,7 +899,7 @@ func parseReferenceList(raw json.RawMessage) []interface{} {
|
||||
}
|
||||
err := json.Unmarshal(raw, &references)
|
||||
if err != nil {
|
||||
return make([]interface{}, 0)
|
||||
return nil
|
||||
}
|
||||
if references == nil {
|
||||
return make([]interface{}, 0)
|
||||
|
||||
@@ -1558,12 +1558,11 @@ func TestBuildSessionPayload_EmptyCollectionsEncodeAsEmptyArrays(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseCollections_ReturnEmptySlicesForMissingNullOrInvalid(t *testing.T) {
|
||||
func TestParseCollections_ReturnEmptySlicesForMissingOrNull(t *testing.T) {
|
||||
messageInputs := []json.RawMessage{
|
||||
nil,
|
||||
json.RawMessage(`null`),
|
||||
json.RawMessage(`{"messages":null}`),
|
||||
json.RawMessage(`not-json`),
|
||||
}
|
||||
for _, input := range messageInputs {
|
||||
got := parseMessages(input)
|
||||
@@ -1575,7 +1574,6 @@ func TestParseCollections_ReturnEmptySlicesForMissingNullOrInvalid(t *testing.T)
|
||||
referenceInputs := []json.RawMessage{
|
||||
nil,
|
||||
json.RawMessage(`null`),
|
||||
json.RawMessage(`not-json`),
|
||||
}
|
||||
for _, input := range referenceInputs {
|
||||
got := parseReferenceList(input)
|
||||
@@ -1585,6 +1583,28 @@ func TestParseCollections_ReturnEmptySlicesForMissingNullOrInvalid(t *testing.T)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseCollections_ReturnNilForMalformedData(t *testing.T) {
|
||||
messageInputs := []json.RawMessage{
|
||||
json.RawMessage(`not-json`),
|
||||
json.RawMessage(`{"unexpected":[]}`),
|
||||
}
|
||||
for _, input := range messageInputs {
|
||||
if got := parseMessages(input); got != nil {
|
||||
t.Fatalf("parseMessages(%s)=%#v, want nil", string(input), got)
|
||||
}
|
||||
}
|
||||
|
||||
referenceInputs := []json.RawMessage{
|
||||
json.RawMessage(`not-json`),
|
||||
json.RawMessage(`{"unexpected":[]}`),
|
||||
}
|
||||
for _, input := range referenceInputs {
|
||||
if got := parseReferenceList(input); got != nil {
|
||||
t.Fatalf("parseReferenceList(%s)=%#v, want nil", string(input), got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompletionStream_EmptyMessages(t *testing.T) {
|
||||
svc := &ChatSessionService{
|
||||
chatSessionDAO: &fakeSessionStore{},
|
||||
|
||||
@@ -77,6 +77,17 @@ type chunkImageMergeLock struct {
|
||||
refs int
|
||||
}
|
||||
|
||||
func searchConfigMap(value interface{}) (map[string]interface{}, bool) {
|
||||
switch typed := value.(type) {
|
||||
case entity.JSONMap:
|
||||
return map[string]interface{}(typed), true
|
||||
case map[string]interface{}:
|
||||
return typed, true
|
||||
default:
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
||||
// ChunkService chunk service
|
||||
type ChunkService struct {
|
||||
docEngine engine.DocEngine
|
||||
@@ -200,9 +211,9 @@ func (s *ChunkService) RetrievalTest(req *service.RetrievalTestRequest, userID s
|
||||
|
||||
// Check if all kbs have the same embedding model
|
||||
if len(kbRecords) > 1 {
|
||||
firstEmbdID := kbRecords[0].EmbdID
|
||||
firstEmbeddingKey := knowledgebaseEmbeddingKey(kbRecords[0], tenantIDs[0])
|
||||
for i := 1; i < len(kbRecords); i++ {
|
||||
if kbRecords[i].EmbdID != firstEmbdID {
|
||||
if knowledgebaseEmbeddingKey(kbRecords[i], tenantIDs[i]) != firstEmbeddingKey {
|
||||
return nil, fmt.Errorf("cannot retrieve across datasets with different embedding models")
|
||||
}
|
||||
}
|
||||
@@ -218,8 +229,8 @@ func (s *ChunkService) RetrievalTest(req *service.RetrievalTestRequest, userID s
|
||||
searchDetail, err := s.searchService.GetDetail(*req.SearchID)
|
||||
if err != nil {
|
||||
common.Warn("Failed to get search detail for search_id, proceeding without it", zap.String("searchID", *req.SearchID), zap.Error(err))
|
||||
} else if searchConfig, ok := searchDetail["search_config"].(entity.JSONMap); ok && searchConfig != nil {
|
||||
if searchMetaFilter, ok := searchConfig["meta_data_filter"].(map[string]interface{}); ok {
|
||||
} else if searchConfig, ok := searchConfigMap(searchDetail["search_config"]); ok && searchConfig != nil {
|
||||
if searchMetaFilter, ok := searchConfigMap(searchConfig["meta_data_filter"]); ok {
|
||||
filter = searchMetaFilter
|
||||
}
|
||||
chatID, _ = searchConfig["chat_id"].(string)
|
||||
@@ -345,48 +356,47 @@ func (s *ChunkService) RetrievalTest(req *service.RetrievalTestRequest, userID s
|
||||
labels := metadataSvc.LabelQuestion(modifiedQuestion, kbRecords)
|
||||
common.Debug("LabelQuestion result", zap.Any("labels", labels))
|
||||
|
||||
// Determine embedding model
|
||||
// Determine embedding model.
|
||||
modelProviderSvc := service.NewModelProviderService()
|
||||
var embdID string
|
||||
var tenantLLM *entity.TenantLLM
|
||||
var embeddingModel *models.EmbeddingModel
|
||||
var embdID string
|
||||
if kbRecords[0].TenantEmbdID != nil && *kbRecords[0].TenantEmbdID > 0 {
|
||||
tenantLLM, embdID, err = dao.LookupTenantLLMByID(dao.NewTenantLLMDAO(), *kbRecords[0].TenantEmbdID)
|
||||
_, embdID, err = dao.LookupTenantLLMByID(dao.NewTenantLLMDAO(), *kbRecords[0].TenantEmbdID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get embedding model by tenant_embd_id: %w", err)
|
||||
}
|
||||
driver, modelName, apiConfig, maxTokens, getErr := modelProviderSvc.GetModelConfigFromProviderInstance(tenantIDs[0], entity.ModelTypeEmbedding, embdID)
|
||||
if getErr != nil {
|
||||
return nil, fmt.Errorf("failed to get embedding model by tenant_embd_id: %w", getErr)
|
||||
}
|
||||
embeddingModel = models.NewEmbeddingModel(driver, &modelName, apiConfig, maxTokens)
|
||||
} else if kbRecords[0].EmbdID != "" {
|
||||
if strings.Contains(kbRecords[0].EmbdID, "@") {
|
||||
driver, modelName, apiConfig, maxTokens, embErr := modelProviderSvc.GetModelConfigFromProviderInstance(tenantIDs[0], entity.ModelTypeEmbedding, kbRecords[0].EmbdID)
|
||||
if embErr != nil {
|
||||
return nil, fmt.Errorf("failed to get embedding model by embd_id: %w", embErr)
|
||||
}
|
||||
embeddingModel = models.NewEmbeddingModel(driver, &modelName, apiConfig, maxTokens)
|
||||
embdID = kbRecords[0].EmbdID
|
||||
} else {
|
||||
tenantLLM, embdID, err = dao.LookupTenantLLMByName(dao.NewTenantLLMDAO(), tenantIDs[0], kbRecords[0].EmbdID, entity.ModelTypeEmbedding)
|
||||
embdID = kbRecords[0].EmbdID
|
||||
driver, modelName, apiConfig, maxTokens, getErr := modelProviderSvc.GetModelConfigFromProviderInstance(tenantIDs[0], entity.ModelTypeEmbedding, embdID)
|
||||
if getErr != nil {
|
||||
_, embdID, err = dao.LookupTenantLLMByName(dao.NewTenantLLMDAO(), tenantIDs[0], kbRecords[0].EmbdID, entity.ModelTypeEmbedding)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get embedding model by embd_id: %w", err)
|
||||
return nil, fmt.Errorf("failed to get embedding model by embd_id: %w", getErr)
|
||||
}
|
||||
driver, modelName, apiConfig, maxTokens, getErr = modelProviderSvc.GetModelConfigFromProviderInstance(tenantIDs[0], entity.ModelTypeEmbedding, embdID)
|
||||
if getErr != nil {
|
||||
return nil, fmt.Errorf("failed to get embedding model by embd_id: %w", getErr)
|
||||
}
|
||||
}
|
||||
embeddingModel = models.NewEmbeddingModel(driver, &modelName, apiConfig, maxTokens)
|
||||
} else {
|
||||
tenantLLM, err = dao.NewTenantLLMDAO().GetByTenantAndType(tenantIDs[0], entity.ModelTypeEmbedding)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get tenant default embedding model: %w", err)
|
||||
driver, modelName, apiConfig, maxTokens, getErr := modelProviderSvc.GetTenantDefaultModelByType(tenantIDs[0], entity.ModelTypeEmbedding)
|
||||
if getErr != nil {
|
||||
return nil, fmt.Errorf("failed to get tenant default embedding model: %w", getErr)
|
||||
}
|
||||
if tenantLLM == nil || tenantLLM.LLMName == nil || *tenantLLM.LLMName == "" {
|
||||
return nil, fmt.Errorf("no default embedding model found for tenant %s", tenantIDs[0])
|
||||
}
|
||||
embdID = fmt.Sprintf("%s@%s", *tenantLLM.LLMName, tenantLLM.LLMFactory)
|
||||
embeddingModel = models.NewEmbeddingModel(driver, &modelName, apiConfig, maxTokens)
|
||||
embdID = fmt.Sprintf("%s@default", modelName)
|
||||
}
|
||||
|
||||
// Get embedding model for the tenant
|
||||
if embeddingModel == nil {
|
||||
embeddingModel, err = modelProviderSvc.GetEmbeddingModel(tenantIDs[0], embdID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get embedding model: %w", err)
|
||||
}
|
||||
return nil, fmt.Errorf("no embedding model found for tenant %s", tenantIDs[0])
|
||||
}
|
||||
|
||||
common.Info("Fetched embedding model for retrieval",
|
||||
zap.String("tenantID", tenantIDs[0]),
|
||||
zap.String("embdID", embdID))
|
||||
@@ -405,6 +415,12 @@ func (s *ChunkService) RetrievalTest(req *service.RetrievalTestRequest, userID s
|
||||
}
|
||||
} else if req.RerankID != nil && *req.RerankID != "" {
|
||||
rerankCompositeName = *req.RerankID
|
||||
if _, _, _, _, getErr := modelProviderSvc.GetModelConfigFromProviderInstance(tenantIDs[0], entity.ModelTypeRerank, rerankCompositeName); getErr != nil {
|
||||
_, rerankCompositeName, err = dao.LookupTenantLLMByName(dao.NewTenantLLMDAO(), tenantIDs[0], *req.RerankID, entity.ModelTypeRerank)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get rerank model by rerank_id: %w", getErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
if rerankCompositeName != "" {
|
||||
driver, mdlName, apiConfig, _, getErr := modelProviderSvc.GetModelConfigFromProviderInstance(tenantIDs[0], entity.ModelTypeRerank, rerankCompositeName)
|
||||
@@ -466,6 +482,16 @@ func (s *ChunkService) RetrievalTest(req *service.RetrievalTestRequest, userID s
|
||||
}, nil
|
||||
}
|
||||
|
||||
func knowledgebaseEmbeddingKey(kb *entity.Knowledgebase, tenantID string) string {
|
||||
if kb.TenantEmbdID != nil && *kb.TenantEmbdID > 0 {
|
||||
return fmt.Sprintf("tenant:%d", *kb.TenantEmbdID)
|
||||
}
|
||||
if kb.EmbdID == "" {
|
||||
return fmt.Sprintf("default:%s", tenantID)
|
||||
}
|
||||
return fmt.Sprintf("embd:%s", kb.EmbdID)
|
||||
}
|
||||
|
||||
// hydrateChunkVectors replaces zero (placeholder) vectors in chunks with real
|
||||
// vectors fetched from the engine. Infinity and OceanBase already ship real
|
||||
// vectors with chunks, so this is a no-op for those engines; for ES it queries
|
||||
|
||||
@@ -77,6 +77,55 @@ func TestHydrateChunkVectors_NoDim(t *testing.T) {
|
||||
// Empty vectors have dim=0 → early return. No crash.
|
||||
}
|
||||
|
||||
func TestKnowledgebaseEmbeddingKey(t *testing.T) {
|
||||
tenantEmbdID := int64(42)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
kb *entity.Knowledgebase
|
||||
tenantID string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "uses tenant embedding id before embd id",
|
||||
kb: &entity.Knowledgebase{
|
||||
EmbdID: "shared-model",
|
||||
TenantEmbdID: &tenantEmbdID,
|
||||
},
|
||||
want: "tenant:42",
|
||||
},
|
||||
{
|
||||
name: "uses embd id without tenant embedding id",
|
||||
kb: &entity.Knowledgebase{
|
||||
EmbdID: "shared-model",
|
||||
},
|
||||
want: "embd:shared-model",
|
||||
},
|
||||
{
|
||||
name: "uses tenant default when embedding id is empty",
|
||||
kb: &entity.Knowledgebase{},
|
||||
tenantID: "tenant-1",
|
||||
want: "default:tenant-1",
|
||||
},
|
||||
{
|
||||
name: "ignores non-positive tenant embedding id",
|
||||
kb: &entity.Knowledgebase{
|
||||
EmbdID: "shared-model",
|
||||
TenantEmbdID: new(int64),
|
||||
},
|
||||
want: "embd:shared-model",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := knowledgebaseEmbeddingKey(tt.kb, tt.tenantID); got != tt.want {
|
||||
t.Fatalf("knowledgebaseEmbeddingKey() = %q, want %q", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParsePrevalidatesDocumentsBeforeMutating(t *testing.T) {
|
||||
db := setupChunkTestDB(t)
|
||||
pushChunkTestDB(t, db)
|
||||
|
||||
@@ -17,16 +17,21 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/dao"
|
||||
"ragflow/internal/entity"
|
||||
"strings"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// SearchService search service
|
||||
type SearchService struct {
|
||||
searchDAO *dao.SearchDAO
|
||||
userTenantDAO *dao.UserTenantDAO
|
||||
tenantService *TenantService
|
||||
}
|
||||
|
||||
// NewSearchService create search service
|
||||
@@ -34,6 +39,13 @@ func NewSearchService() *SearchService {
|
||||
return &SearchService{
|
||||
searchDAO: dao.NewSearchDAO(),
|
||||
userTenantDAO: dao.NewUserTenantDAO(),
|
||||
tenantService: NewTenantService(),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SearchService) SetTenantService(tenantService *TenantService) {
|
||||
if tenantService != nil {
|
||||
s.tenantService = tenantService
|
||||
}
|
||||
}
|
||||
|
||||
@@ -285,6 +297,210 @@ func (s *SearchService) DeleteSearch(userID string, searchID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// AccessibleForCompletion check if it is accessible
|
||||
func (s *SearchService) AccessibleForCompletion(userID string, searchID string) (bool, error) {
|
||||
ok, err := s.searchDAO.Accessible4Deletion(searchID, userID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
return ok, nil
|
||||
}
|
||||
|
||||
type SearchCompletionPlan struct {
|
||||
UserID string
|
||||
SearchID string
|
||||
Question string
|
||||
KBIDs []string
|
||||
ModelID string
|
||||
Options AskStreamOptions
|
||||
}
|
||||
|
||||
func (s *SearchService) PrepareCompletion(userID, searchID string, req *SearchCompletionsRequest) (*SearchCompletionPlan, common.ErrorCode, error) {
|
||||
userID = strings.TrimSpace(userID)
|
||||
if userID == "" {
|
||||
return nil, common.CodeBadRequest, fmt.Errorf("user id is required")
|
||||
}
|
||||
searchID = strings.TrimSpace(searchID)
|
||||
if searchID == "" {
|
||||
return nil, common.CodeBadRequest, fmt.Errorf("search_id is required")
|
||||
}
|
||||
if req == nil {
|
||||
return nil, common.CodeArgumentError, fmt.Errorf("question is required")
|
||||
}
|
||||
question := strings.TrimSpace(req.Question)
|
||||
if question == "" {
|
||||
return nil, common.CodeArgumentError, fmt.Errorf("question is required")
|
||||
}
|
||||
|
||||
accessible, err := s.AccessibleForCompletion(userID, searchID)
|
||||
if err != nil {
|
||||
return nil, common.CodeServerError, err
|
||||
}
|
||||
if !accessible {
|
||||
return nil, common.CodeAuthenticationError, fmt.Errorf("No authorization.")
|
||||
}
|
||||
|
||||
searchDetail, err := s.GetDetail(searchID)
|
||||
if err != nil || searchDetail == nil {
|
||||
return nil, common.CodeDataError, fmt.Errorf("Cannot find search %s", searchID)
|
||||
}
|
||||
searchConfig := searchConfigMapFromValue(searchDetail["search_config"])
|
||||
|
||||
kbIDs := stringSliceFromSearchConfig(searchConfig["kb_ids"])
|
||||
if len(kbIDs) == 0 {
|
||||
kbIDs = stringSliceFromSearchConfig(req.KBIDs)
|
||||
}
|
||||
if len(kbIDs) == 0 {
|
||||
return nil, common.CodeDataError, fmt.Errorf("`kb_ids` is required.")
|
||||
}
|
||||
|
||||
modelID, _ := stringFromSearchConfig(searchConfig["chat_id"])
|
||||
if modelID == "" {
|
||||
tenantSvc := s.tenantService
|
||||
if tenantSvc == nil {
|
||||
tenantSvc = NewTenantService()
|
||||
}
|
||||
defaultModel, err := tenantSvc.GetDefaultModelName(userID, entity.ModelTypeChat)
|
||||
if err == nil {
|
||||
modelID = strings.TrimSpace(defaultModel)
|
||||
}
|
||||
}
|
||||
|
||||
return &SearchCompletionPlan{
|
||||
UserID: userID,
|
||||
SearchID: searchID,
|
||||
Question: question,
|
||||
KBIDs: kbIDs,
|
||||
ModelID: modelID,
|
||||
Options: askOptionsFromSearchConfig(searchID, searchConfig),
|
||||
}, common.CodeSuccess, nil
|
||||
}
|
||||
|
||||
func askOptionsFromSearchConfig(searchID string, searchConfig map[string]interface{}) AskStreamOptions {
|
||||
opts := AskStreamOptions{
|
||||
SearchID: searchID,
|
||||
DocIDs: stringSliceFromSearchConfig(searchConfig["doc_ids"]),
|
||||
CrossLanguages: stringSliceFromSearchConfig(searchConfig["cross_languages"]),
|
||||
}
|
||||
if value, ok := searchConfig["use_kg"].(bool); ok {
|
||||
opts.UseKG = &value
|
||||
}
|
||||
if value, ok := intFromSearchConfig(searchConfig["top_k"]); ok {
|
||||
opts.TopK = &value
|
||||
}
|
||||
if value, ok := searchConfigMapValue(searchConfig["meta_data_filter"]); ok {
|
||||
opts.Filter = value
|
||||
}
|
||||
if value, ok := stringFromSearchConfig(searchConfig["tenant_rerank_id"]); ok {
|
||||
opts.TenantRerankID = &value
|
||||
}
|
||||
if value, ok := stringFromSearchConfig(searchConfig["rerank_id"]); ok {
|
||||
opts.RerankID = &value
|
||||
}
|
||||
if value, ok := searchConfig["keyword"].(bool); ok {
|
||||
opts.Keyword = &value
|
||||
}
|
||||
if value, ok := floatFromSearchConfig(searchConfig["similarity_threshold"]); ok {
|
||||
opts.SimilarityThreshold = &value
|
||||
}
|
||||
if value, ok := floatFromSearchConfig(searchConfig["vector_similarity_weight"]); ok {
|
||||
opts.VectorSimilarityWeight = &value
|
||||
}
|
||||
return opts
|
||||
}
|
||||
|
||||
func searchConfigMapFromValue(value interface{}) map[string]interface{} {
|
||||
if result, ok := searchConfigMapValue(value); ok {
|
||||
return result
|
||||
}
|
||||
return map[string]interface{}{}
|
||||
}
|
||||
|
||||
func searchConfigMapValue(value interface{}) (map[string]interface{}, bool) {
|
||||
switch typed := value.(type) {
|
||||
case nil:
|
||||
return nil, false
|
||||
case map[string]interface{}:
|
||||
return typed, true
|
||||
case entity.JSONMap:
|
||||
return map[string]interface{}(typed), true
|
||||
default:
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
||||
func stringSliceFromSearchConfig(value interface{}) []string {
|
||||
switch typed := value.(type) {
|
||||
case nil:
|
||||
return nil
|
||||
case []string:
|
||||
result := make([]string, 0, len(typed))
|
||||
for _, item := range typed {
|
||||
if item = strings.TrimSpace(item); item != "" {
|
||||
result = append(result, item)
|
||||
}
|
||||
}
|
||||
return result
|
||||
case common.StringSlice:
|
||||
return stringSliceFromSearchConfig([]string(typed))
|
||||
case []interface{}:
|
||||
result := make([]string, 0, len(typed))
|
||||
for _, item := range typed {
|
||||
if value, ok := stringFromSearchConfig(item); ok {
|
||||
result = append(result, value)
|
||||
}
|
||||
}
|
||||
return result
|
||||
default:
|
||||
if value, ok := stringFromSearchConfig(value); ok {
|
||||
return []string{value}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// stringFromSearchConfig extracts a trimmed string from a loosely typed config
// value. It returns ("", false) when value is not a string or is blank after
// trimming surrounding whitespace.
func stringFromSearchConfig(value interface{}) (string, bool) {
	if s, isString := value.(string); isString {
		if trimmed := strings.TrimSpace(s); trimmed != "" {
			return trimmed, true
		}
	}
	return "", false
}
|
||||
|
||||
// intFromSearchConfig extracts an int from a loosely typed config value.
//
// Accepted dynamic types are int, int64, float64 and float32 (JSON numbers
// decode to float64). Floats are truncated toward zero. The boolean is false
// when the value has another type or cannot be represented as an int: NaN,
// ±Inf, floats outside the int range, and int64 values that do not fit in int
// on 32-bit targets. Rejecting those avoids reporting an
// implementation-dependent conversion result as a valid setting.
func intFromSearchConfig(value interface{}) (int, bool) {
	const (
		maxInt = int(^uint(0) >> 1)
		minInt = -maxInt - 1
		// Exclusive upper bound as a float64. On 64-bit targets maxInt itself is
		// not representable and rounds up to 2^63, which is why the comparison
		// below is strict.
		upperBound = float64(maxInt) + 1
		lowerBound = float64(minInt)
	)

	fromFloat := func(f float64) (int, bool) {
		// Written as a negated conjunction so NaN, which fails every
		// comparison, is rejected as well.
		if !(f >= lowerBound && f < upperBound) {
			return 0, false
		}
		return int(f), true
	}

	switch typed := value.(type) {
	case int:
		return typed, true
	case int64:
		converted := int(typed)
		if int64(converted) != typed {
			return 0, false
		}
		return converted, true
	case float64:
		return fromFloat(typed)
	case float32:
		return fromFloat(float64(typed))
	default:
		return 0, false
	}
}
|
||||
|
||||
// floatFromSearchConfig extracts a float64 from a loosely typed config value.
// Accepted dynamic types are float64, float32, int and int64; every other type
// (including nil) yields (0, false).
func floatFromSearchConfig(value interface{}) (float64, bool) {
	if f, ok := value.(float64); ok {
		return f, true
	}
	if f, ok := value.(float32); ok {
		return float64(f), true
	}
	if n, ok := value.(int); ok {
		return float64(n), true
	}
	if n, ok := value.(int64); ok {
		return float64(n), true
	}
	return 0, false
}
|
||||
|
||||
// UpdateSearchRequest update search request
|
||||
// Reference: api/apps/restful_apis/search_api.py::update
|
||||
// Required fields: name, search_config
|
||||
@@ -399,3 +615,8 @@ func (s *SearchService) GetDetail(searchID string) (map[string]interface{}, erro
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// SearchCompletionsRequest is the JSON body of a search completion request.
type SearchCompletionsRequest struct {
	// Question is the user's query; the binding tag marks it as required.
	Question string `json:"question" binding:"required"`
	// KBIDs optionally lists knowledge base IDs. PrepareCompletion uses them
	// only when the saved search config does not define kb_ids.
	KBIDs []string `json:"kb_ids,omitempty"`
}
|
||||
|
||||
Reference in New Issue
Block a user