From 1c0cdd84cea23c23907e211c1ec370ff04f649d7 Mon Sep 17 00:00:00 2001 From: Haruko386 Date: Mon, 29 Jun 2026 20:07:12 +0800 Subject: [PATCH] feat[Go]: implement searches/:search_id/completions POST (#16440) ### Summary Add POST `/:search_id/completion` and `/:search_id/completions` routes for saved search apps. The handler streams the answer over SSE through `AskService.StreamWithOptions`, passing the saved `search_config` retrieval options to `RetrievalTest`. Also reworks embedding/rerank model resolution in `RetrievalTest` and makes `parseMessages`/`parseReferenceList` return nil for malformed data. ### Type of change - [x] New Feature (non-breaking change which adds functionality) - [x] Bug Fix (non-breaking change which fixes an issue) --- cmd/server_main.go | 5 +- internal/handler/search.go | 94 +++++++++++ internal/router/router.go | 2 + internal/service/ask_service.go | 55 ++++++- internal/service/chat_session.go | 20 ++- internal/service/chat_session_test.go | 26 ++- internal/service/chunk/chunk.go | 86 ++++++---- internal/service/chunk/chunk_test.go | 49 ++++++ internal/service/search.go | 221 ++++++++++++++++++++++++++ 9 files changed, 513 insertions(+), 45 deletions(-) diff --git a/cmd/server_main.go b/cmd/server_main.go index cde8f69c11..aa7db3d33f 100644 --- a/cmd/server_main.go +++ b/cmd/server_main.go @@ -229,6 +229,7 @@ func startServer(config *server.Config) { systemService := service.NewSystemService() connectorService := service.NewConnectorService() searchService := service.NewSearchService() + searchService.SetTenantService(tenantService) fileService := service.NewFileService() memoryService := service.NewMemoryService() mcpService := service.NewMCPService() @@ -303,7 +304,9 @@ func startServer(config *server.Config) { chunkService, ) searchBotHandler.SetStreamLLM(searchBotLLM) - searchBotHandler.SetAskService(service.NewAskService(chunkService, nil, 0, 0)) + askService := service.NewAskService(chunkService, nil, 0, 0) + searchBotHandler.SetAskService(askService) + searchHandler.SetCompletionDependencies(searchBotLLM, askService) pluginHandler := handler.NewPluginHandler(service.NewPluginService()) modelHandler := handler.NewModelHandler(service.NewModelProviderService()) fileCommitHandler := handler.NewFileCommitHandler(service.NewFileCommitService()) diff --git a/internal/handler/search.go b/internal/handler/search.go index 
19d505a9c9..c7db252ecf 100644 --- a/internal/handler/search.go +++ b/internal/handler/search.go @@ -30,6 +30,9 @@ import ( type SearchHandler struct { searchService *service.SearchService userService *service.UserService + streamLLM streamingLLM + askService *service.AskService + sseWriter SSEWriter } // NewSearchHandler create search handler @@ -37,9 +40,16 @@ func NewSearchHandler(searchService *service.SearchService, userService *service return &SearchHandler{ searchService: searchService, userService: userService, + sseWriter: &ginSSEWriter{}, } } +// SetCompletionDependencies wires the streaming search completion runtime. +func (h *SearchHandler) SetCompletionDependencies(streamLLM streamingLLM, askService *service.AskService) { + h.streamLLM = streamLLM + h.askService = askService +} + // ListSearches list search apps // @Summary List Search Apps // @Description Get list of search apps for the current user with filtering, pagination and sorting @@ -421,3 +431,87 @@ func (h *SearchHandler) UpdateSearch(c *gin.Context) { "message": "success", }) } + +func (h *SearchHandler) Completion(c *gin.Context) { + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) + return + } + + var req service.SearchCompletionsRequest + if err := c.ShouldBindJSON(&req); err != nil { + jsonError(c, common.CodeArgumentError, "question is required") + return + } + + searchSvc := h.searchService + if searchSvc == nil { + searchSvc = service.NewSearchService() + } + + plan, code, err := searchSvc.PrepareCompletion(user.ID, c.Param("search_id"), &req) + if err != nil { + if code == common.CodeAuthenticationError { + c.JSON(http.StatusOK, gin.H{ + "code": code, + "data": false, + "message": err.Error(), + }) + return + } + if code == common.CodeServerError { + jsonInternalError(c, err) + return + } + jsonError(c, code, err.Error()) + return + } + if plan == nil { + jsonError(c, common.CodeServerError, "completion plan is 
nil") + return + } + + disableWriteDeadlineForSSE(c) + c.Header("Content-Type", "text/event-stream; charset=utf-8") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Header("X-Accel-Buffering", "no") + + writer := h.sseWriter + if writer == nil { + writer = &ginSSEWriter{} + } + if plan.ModelID == "" { + writer.Write(c, sseError("chat model not configured")) + return + } + if h.askService == nil { + writer.Write(c, sseError("ask service not configured")) + return + } + if h.streamLLM == nil { + writer.Write(c, sseError("streaming LLM not configured")) + return + } + + adapter := &askStreamAdapter{llm: h.streamLLM, tenantID: plan.UserID, modelID: plan.ModelID} + + hadError := false + for delta := range h.askService.StreamWithOptions(c.Request.Context(), adapter, plan.UserID, plan.Question, plan.KBIDs, plan.Options) { + switch delta.Kind { + case service.AskDeltaAnswer: + writer.Write(c, sseAnswer(delta.Value, nil, false)) + case service.AskDeltaMarker: + writer.Write(c, sseMarker(delta.Value)) + case service.AskDeltaError: + hadError = true + writer.Write(c, sseError(delta.Value)) + case service.AskDeltaFinal: + writer.Write(c, sseAnswer(delta.Value, delta.Refs, true)) + } + } + if !hadError { + writer.Write(c, "data: {\"code\": 0, \"message\": \"\", \"data\": true}\n\n") + } +} diff --git a/internal/router/router.go b/internal/router/router.go index 9a51844936..50f4a108ad 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -364,6 +364,8 @@ func (r *Router) Setup(engine *gin.Engine) { searches.GET("/:search_id", r.searchHandler.GetSearch) searches.PUT("/:search_id", r.searchHandler.UpdateSearch) searches.DELETE("/:search_id", r.searchHandler.DeleteSearch) + searches.POST("/:search_id/completion", r.searchHandler.Completion) + searches.POST("/:search_id/completions", r.searchHandler.Completion) } file := v1.Group("/files") diff --git a/internal/service/ask_service.go b/internal/service/ask_service.go index 
87e037c9fa..5e987d02c1 100644 --- a/internal/service/ask_service.go +++ b/internal/service/ask_service.go @@ -53,6 +53,22 @@ type AskDelta struct { Refs interface{} // populated on AskDeltaFinal: {chunks, doc_aggs} } +// AskStreamOptions carries optional retrieval settings supplied by saved +// search_config. Zero values keep the same defaults as Stream. +type AskStreamOptions struct { + SearchID string + DocIDs []string + UseKG *bool + TopK *int + CrossLanguages []string + Filter map[string]interface{} + TenantRerankID *string + RerankID *string + Keyword *bool + SimilarityThreshold *float64 + VectorSimilarityWeight *float64 +} + // Retriever abstracts chunk retrieval for AskService. type Retriever interface { RetrievalTest(req *RetrievalTestRequest, userID string) (*RetrievalTestResponse, error) @@ -91,22 +107,51 @@ func NewAskService(retriever Retriever, embedder Embedder, tokenBudget, minStrea // Stream runs the full ask pipeline. llm must not be nil. The returned // channel is closed when the pipeline completes or ctx is cancelled. func (s *AskService) Stream(ctx context.Context, llm StreamingLLM, userID, question string, kbIDs []string) <-chan AskDelta { + return s.StreamWithOptions(ctx, llm, userID, question, kbIDs, AskStreamOptions{}) +} + +// StreamWithOptions runs Stream while allowing callers such as saved Search +// apps to pass search_config retrieval options through to RetrievalTest. 
+func (s *AskService) StreamWithOptions(ctx context.Context, llm StreamingLLM, userID, question string, kbIDs []string, opts AskStreamOptions) <-chan AskDelta { out := make(chan AskDelta, 32) go func() { defer close(out) - s.run(ctx, llm, userID, question, kbIDs, out) + s.run(ctx, llm, userID, question, kbIDs, opts, out) }() return out } -func (s *AskService) run(ctx context.Context, llm StreamingLLM, userID, question string, kbIDs []string, out chan<- AskDelta) { +func (s *AskService) run(ctx context.Context, llm StreamingLLM, userID, question string, kbIDs []string, opts AskStreamOptions, out chan<- AskDelta) { // Phase 1: Retrieval. + topK := DefaultAskTopK + if opts.TopK != nil { + topK = *opts.TopK + } + similarityThreshold := DefaultAskSimilarityThreshold + if opts.SimilarityThreshold != nil { + similarityThreshold = *opts.SimilarityThreshold + } + vectorSimilarityWeight := DefaultAskVectorSimilarityWeight + if opts.VectorSimilarityWeight != nil { + vectorSimilarityWeight = *opts.VectorSimilarityWeight + } + req := &RetrievalTestRequest{ Datasets: common.StringSlice(kbIDs), Question: question, - TopK: ptrInt(DefaultAskTopK), - SimilarityThreshold: ptrFloat64(DefaultAskSimilarityThreshold), - VectorSimilarityWeight: ptrFloat64(DefaultAskVectorSimilarityWeight), + DocIDs: opts.DocIDs, + UseKG: opts.UseKG, + TopK: ptrInt(topK), + CrossLanguages: opts.CrossLanguages, + Filter: opts.Filter, + TenantRerankID: opts.TenantRerankID, + RerankID: opts.RerankID, + Keyword: opts.Keyword, + SimilarityThreshold: ptrFloat64(similarityThreshold), + VectorSimilarityWeight: ptrFloat64(vectorSimilarityWeight), + } + if opts.SearchID != "" { + req.SearchID = &opts.SearchID } page := DefaultAskPage ps := DefaultAskPageSize diff --git a/internal/service/chat_session.go b/internal/service/chat_session.go index f41afef6b7..84705b58eb 100644 --- a/internal/service/chat_session.go +++ b/internal/service/chat_session.go @@ -872,16 +872,24 @@ func parseMessages(raw json.RawMessage) 
[]map[string]interface{} { return messages } - var wrapped struct { - Messages []map[string]interface{} `json:"messages"` - } + var wrapped map[string]json.RawMessage if err := json.Unmarshal(raw, &wrapped); err != nil { + return nil + } + wrappedMessages, ok := wrapped["messages"] + if !ok { + return nil + } + if len(wrappedMessages) == 0 || string(wrappedMessages) == "null" { return make([]map[string]interface{}, 0) } - if wrapped.Messages == nil { + if err := json.Unmarshal(wrappedMessages, &messages); err != nil { + return nil + } + if messages == nil { return make([]map[string]interface{}, 0) } - return wrapped.Messages + return messages } func parseReferenceList(raw json.RawMessage) []interface{} { @@ -891,7 +899,7 @@ func parseReferenceList(raw json.RawMessage) []interface{} { } err := json.Unmarshal(raw, &references) if err != nil { - return make([]interface{}, 0) + return nil } if references == nil { return make([]interface{}, 0) diff --git a/internal/service/chat_session_test.go b/internal/service/chat_session_test.go index c9e5acf21a..d988474481 100644 --- a/internal/service/chat_session_test.go +++ b/internal/service/chat_session_test.go @@ -1558,12 +1558,11 @@ func TestBuildSessionPayload_EmptyCollectionsEncodeAsEmptyArrays(t *testing.T) { } } -func TestParseCollections_ReturnEmptySlicesForMissingNullOrInvalid(t *testing.T) { +func TestParseCollections_ReturnEmptySlicesForMissingOrNull(t *testing.T) { messageInputs := []json.RawMessage{ nil, json.RawMessage(`null`), json.RawMessage(`{"messages":null}`), - json.RawMessage(`not-json`), } for _, input := range messageInputs { got := parseMessages(input) @@ -1575,7 +1574,6 @@ func TestParseCollections_ReturnEmptySlicesForMissingNullOrInvalid(t *testing.T) referenceInputs := []json.RawMessage{ nil, json.RawMessage(`null`), - json.RawMessage(`not-json`), } for _, input := range referenceInputs { got := parseReferenceList(input) @@ -1585,6 +1583,28 @@ func 
TestParseCollections_ReturnEmptySlicesForMissingNullOrInvalid(t *testing.T) } } +func TestParseCollections_ReturnNilForMalformedData(t *testing.T) { + messageInputs := []json.RawMessage{ + json.RawMessage(`not-json`), + json.RawMessage(`{"unexpected":[]}`), + } + for _, input := range messageInputs { + if got := parseMessages(input); got != nil { + t.Fatalf("parseMessages(%s)=%#v, want nil", string(input), got) + } + } + + referenceInputs := []json.RawMessage{ + json.RawMessage(`not-json`), + json.RawMessage(`{"unexpected":[]}`), + } + for _, input := range referenceInputs { + if got := parseReferenceList(input); got != nil { + t.Fatalf("parseReferenceList(%s)=%#v, want nil", string(input), got) + } + } +} + func TestCompletionStream_EmptyMessages(t *testing.T) { svc := &ChatSessionService{ chatSessionDAO: &fakeSessionStore{}, diff --git a/internal/service/chunk/chunk.go b/internal/service/chunk/chunk.go index ecda520cd9..b165439137 100644 --- a/internal/service/chunk/chunk.go +++ b/internal/service/chunk/chunk.go @@ -77,6 +77,17 @@ type chunkImageMergeLock struct { refs int } +func searchConfigMap(value interface{}) (map[string]interface{}, bool) { + switch typed := value.(type) { + case entity.JSONMap: + return map[string]interface{}(typed), true + case map[string]interface{}: + return typed, true + default: + return nil, false + } +} + // ChunkService chunk service type ChunkService struct { docEngine engine.DocEngine @@ -200,9 +211,9 @@ func (s *ChunkService) RetrievalTest(req *service.RetrievalTestRequest, userID s // Check if all kbs have the same embedding model if len(kbRecords) > 1 { - firstEmbdID := kbRecords[0].EmbdID + firstEmbeddingKey := knowledgebaseEmbeddingKey(kbRecords[0], tenantIDs[0]) for i := 1; i < len(kbRecords); i++ { - if kbRecords[i].EmbdID != firstEmbdID { + if knowledgebaseEmbeddingKey(kbRecords[i], tenantIDs[i]) != firstEmbeddingKey { return nil, fmt.Errorf("cannot retrieve across datasets with different embedding models") } } @@ -218,8 
+229,8 @@ func (s *ChunkService) RetrievalTest(req *service.RetrievalTestRequest, userID s searchDetail, err := s.searchService.GetDetail(*req.SearchID) if err != nil { common.Warn("Failed to get search detail for search_id, proceeding without it", zap.String("searchID", *req.SearchID), zap.Error(err)) - } else if searchConfig, ok := searchDetail["search_config"].(entity.JSONMap); ok && searchConfig != nil { - if searchMetaFilter, ok := searchConfig["meta_data_filter"].(map[string]interface{}); ok { + } else if searchConfig, ok := searchConfigMap(searchDetail["search_config"]); ok && searchConfig != nil { + if searchMetaFilter, ok := searchConfigMap(searchConfig["meta_data_filter"]); ok { filter = searchMetaFilter } chatID, _ = searchConfig["chat_id"].(string) @@ -345,48 +356,47 @@ func (s *ChunkService) RetrievalTest(req *service.RetrievalTestRequest, userID s labels := metadataSvc.LabelQuestion(modifiedQuestion, kbRecords) common.Debug("LabelQuestion result", zap.Any("labels", labels)) - // Determine embedding model + // Determine embedding model. 
modelProviderSvc := service.NewModelProviderService() - var embdID string - var tenantLLM *entity.TenantLLM var embeddingModel *models.EmbeddingModel + var embdID string if kbRecords[0].TenantEmbdID != nil && *kbRecords[0].TenantEmbdID > 0 { - tenantLLM, embdID, err = dao.LookupTenantLLMByID(dao.NewTenantLLMDAO(), *kbRecords[0].TenantEmbdID) + _, embdID, err = dao.LookupTenantLLMByID(dao.NewTenantLLMDAO(), *kbRecords[0].TenantEmbdID) if err != nil { return nil, fmt.Errorf("failed to get embedding model by tenant_embd_id: %w", err) } + driver, modelName, apiConfig, maxTokens, getErr := modelProviderSvc.GetModelConfigFromProviderInstance(tenantIDs[0], entity.ModelTypeEmbedding, embdID) + if getErr != nil { + return nil, fmt.Errorf("failed to get embedding model by tenant_embd_id: %w", getErr) + } + embeddingModel = models.NewEmbeddingModel(driver, &modelName, apiConfig, maxTokens) } else if kbRecords[0].EmbdID != "" { - if strings.Contains(kbRecords[0].EmbdID, "@") { - driver, modelName, apiConfig, maxTokens, embErr := modelProviderSvc.GetModelConfigFromProviderInstance(tenantIDs[0], entity.ModelTypeEmbedding, kbRecords[0].EmbdID) - if embErr != nil { - return nil, fmt.Errorf("failed to get embedding model by embd_id: %w", embErr) - } - embeddingModel = models.NewEmbeddingModel(driver, &modelName, apiConfig, maxTokens) - embdID = kbRecords[0].EmbdID - } else { - tenantLLM, embdID, err = dao.LookupTenantLLMByName(dao.NewTenantLLMDAO(), tenantIDs[0], kbRecords[0].EmbdID, entity.ModelTypeEmbedding) + embdID = kbRecords[0].EmbdID + driver, modelName, apiConfig, maxTokens, getErr := modelProviderSvc.GetModelConfigFromProviderInstance(tenantIDs[0], entity.ModelTypeEmbedding, embdID) + if getErr != nil { + _, embdID, err = dao.LookupTenantLLMByName(dao.NewTenantLLMDAO(), tenantIDs[0], kbRecords[0].EmbdID, entity.ModelTypeEmbedding) if err != nil { - return nil, fmt.Errorf("failed to get embedding model by embd_id: %w", err) + return nil, fmt.Errorf("failed to get embedding 
model by embd_id: %w", getErr) + } + driver, modelName, apiConfig, maxTokens, getErr = modelProviderSvc.GetModelConfigFromProviderInstance(tenantIDs[0], entity.ModelTypeEmbedding, embdID) + if getErr != nil { + return nil, fmt.Errorf("failed to get embedding model by embd_id: %w", getErr) } } + embeddingModel = models.NewEmbeddingModel(driver, &modelName, apiConfig, maxTokens) } else { - tenantLLM, err = dao.NewTenantLLMDAO().GetByTenantAndType(tenantIDs[0], entity.ModelTypeEmbedding) - if err != nil { - return nil, fmt.Errorf("failed to get tenant default embedding model: %w", err) + driver, modelName, apiConfig, maxTokens, getErr := modelProviderSvc.GetTenantDefaultModelByType(tenantIDs[0], entity.ModelTypeEmbedding) + if getErr != nil { + return nil, fmt.Errorf("failed to get tenant default embedding model: %w", getErr) } - if tenantLLM == nil || tenantLLM.LLMName == nil || *tenantLLM.LLMName == "" { - return nil, fmt.Errorf("no default embedding model found for tenant %s", tenantIDs[0]) - } - embdID = fmt.Sprintf("%s@%s", *tenantLLM.LLMName, tenantLLM.LLMFactory) + embeddingModel = models.NewEmbeddingModel(driver, &modelName, apiConfig, maxTokens) + embdID = fmt.Sprintf("%s@default", modelName) } - // Get embedding model for the tenant if embeddingModel == nil { - embeddingModel, err = modelProviderSvc.GetEmbeddingModel(tenantIDs[0], embdID) - if err != nil { - return nil, fmt.Errorf("failed to get embedding model: %w", err) - } + return nil, fmt.Errorf("no embedding model found for tenant %s", tenantIDs[0]) } + common.Info("Fetched embedding model for retrieval", zap.String("tenantID", tenantIDs[0]), zap.String("embdID", embdID)) @@ -405,6 +415,12 @@ func (s *ChunkService) RetrievalTest(req *service.RetrievalTestRequest, userID s } } else if req.RerankID != nil && *req.RerankID != "" { rerankCompositeName = *req.RerankID + if _, _, _, _, getErr := modelProviderSvc.GetModelConfigFromProviderInstance(tenantIDs[0], entity.ModelTypeRerank, rerankCompositeName); 
getErr != nil { + _, rerankCompositeName, err = dao.LookupTenantLLMByName(dao.NewTenantLLMDAO(), tenantIDs[0], *req.RerankID, entity.ModelTypeRerank) + if err != nil { + return nil, fmt.Errorf("failed to get rerank model by rerank_id: %w", getErr) + } + } } if rerankCompositeName != "" { driver, mdlName, apiConfig, _, getErr := modelProviderSvc.GetModelConfigFromProviderInstance(tenantIDs[0], entity.ModelTypeRerank, rerankCompositeName) @@ -466,6 +482,16 @@ func (s *ChunkService) RetrievalTest(req *service.RetrievalTestRequest, userID s }, nil } +func knowledgebaseEmbeddingKey(kb *entity.Knowledgebase, tenantID string) string { + if kb.TenantEmbdID != nil && *kb.TenantEmbdID > 0 { + return fmt.Sprintf("tenant:%d", *kb.TenantEmbdID) + } + if kb.EmbdID == "" { + return fmt.Sprintf("default:%s", tenantID) + } + return fmt.Sprintf("embd:%s", kb.EmbdID) +} + // hydrateChunkVectors replaces zero (placeholder) vectors in chunks with real // vectors fetched from the engine. Infinity and OceanBase already ship real // vectors with chunks, so this is a no-op for those engines; for ES it queries diff --git a/internal/service/chunk/chunk_test.go b/internal/service/chunk/chunk_test.go index 04cf534567..cf77c990fe 100644 --- a/internal/service/chunk/chunk_test.go +++ b/internal/service/chunk/chunk_test.go @@ -77,6 +77,55 @@ func TestHydrateChunkVectors_NoDim(t *testing.T) { // Empty vectors have dim=0 → early return. No crash. 
} +func TestKnowledgebaseEmbeddingKey(t *testing.T) { + tenantEmbdID := int64(42) + + tests := []struct { + name string + kb *entity.Knowledgebase + tenantID string + want string + }{ + { + name: "uses tenant embedding id before embd id", + kb: &entity.Knowledgebase{ + EmbdID: "shared-model", + TenantEmbdID: &tenantEmbdID, + }, + want: "tenant:42", + }, + { + name: "uses embd id without tenant embedding id", + kb: &entity.Knowledgebase{ + EmbdID: "shared-model", + }, + want: "embd:shared-model", + }, + { + name: "uses tenant default when embedding id is empty", + kb: &entity.Knowledgebase{}, + tenantID: "tenant-1", + want: "default:tenant-1", + }, + { + name: "ignores non-positive tenant embedding id", + kb: &entity.Knowledgebase{ + EmbdID: "shared-model", + TenantEmbdID: new(int64), + }, + want: "embd:shared-model", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := knowledgebaseEmbeddingKey(tt.kb, tt.tenantID); got != tt.want { + t.Fatalf("knowledgebaseEmbeddingKey() = %q, want %q", got, tt.want) + } + }) + } +} + func TestParsePrevalidatesDocumentsBeforeMutating(t *testing.T) { db := setupChunkTestDB(t) pushChunkTestDB(t, db) diff --git a/internal/service/search.go b/internal/service/search.go index 21e9275615..8ee26a616f 100644 --- a/internal/service/search.go +++ b/internal/service/search.go @@ -17,16 +17,21 @@ package service import ( + "errors" "fmt" "ragflow/internal/common" "ragflow/internal/dao" "ragflow/internal/entity" + "strings" + + "gorm.io/gorm" ) // SearchService search service type SearchService struct { searchDAO *dao.SearchDAO userTenantDAO *dao.UserTenantDAO + tenantService *TenantService } // NewSearchService create search service @@ -34,6 +39,13 @@ func NewSearchService() *SearchService { return &SearchService{ searchDAO: dao.NewSearchDAO(), userTenantDAO: dao.NewUserTenantDAO(), + tenantService: NewTenantService(), + } +} + +func (s *SearchService) SetTenantService(tenantService *TenantService) { + if 
tenantService != nil { + s.tenantService = tenantService } } @@ -285,6 +297,210 @@ func (s *SearchService) DeleteSearch(userID string, searchID string) error { return nil } +// AccessibleForCompletion check if it is accessible +func (s *SearchService) AccessibleForCompletion(userID string, searchID string) (bool, error) { + ok, err := s.searchDAO.Accessible4Deletion(searchID, userID) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return false, nil + } + return false, err + } + return ok, nil +} + +type SearchCompletionPlan struct { + UserID string + SearchID string + Question string + KBIDs []string + ModelID string + Options AskStreamOptions +} + +func (s *SearchService) PrepareCompletion(userID, searchID string, req *SearchCompletionsRequest) (*SearchCompletionPlan, common.ErrorCode, error) { + userID = strings.TrimSpace(userID) + if userID == "" { + return nil, common.CodeBadRequest, fmt.Errorf("user id is required") + } + searchID = strings.TrimSpace(searchID) + if searchID == "" { + return nil, common.CodeBadRequest, fmt.Errorf("search_id is required") + } + if req == nil { + return nil, common.CodeArgumentError, fmt.Errorf("question is required") + } + question := strings.TrimSpace(req.Question) + if question == "" { + return nil, common.CodeArgumentError, fmt.Errorf("question is required") + } + + accessible, err := s.AccessibleForCompletion(userID, searchID) + if err != nil { + return nil, common.CodeServerError, err + } + if !accessible { + return nil, common.CodeAuthenticationError, fmt.Errorf("No authorization.") + } + + searchDetail, err := s.GetDetail(searchID) + if err != nil || searchDetail == nil { + return nil, common.CodeDataError, fmt.Errorf("Cannot find search %s", searchID) + } + searchConfig := searchConfigMapFromValue(searchDetail["search_config"]) + + kbIDs := stringSliceFromSearchConfig(searchConfig["kb_ids"]) + if len(kbIDs) == 0 { + kbIDs = stringSliceFromSearchConfig(req.KBIDs) + } + if len(kbIDs) == 0 { + return nil, 
common.CodeDataError, fmt.Errorf("`kb_ids` is required.") + } + + modelID, _ := stringFromSearchConfig(searchConfig["chat_id"]) + if modelID == "" { + tenantSvc := s.tenantService + if tenantSvc == nil { + tenantSvc = NewTenantService() + } + defaultModel, err := tenantSvc.GetDefaultModelName(userID, entity.ModelTypeChat) + if err == nil { + modelID = strings.TrimSpace(defaultModel) + } + } + + return &SearchCompletionPlan{ + UserID: userID, + SearchID: searchID, + Question: question, + KBIDs: kbIDs, + ModelID: modelID, + Options: askOptionsFromSearchConfig(searchID, searchConfig), + }, common.CodeSuccess, nil +} + +func askOptionsFromSearchConfig(searchID string, searchConfig map[string]interface{}) AskStreamOptions { + opts := AskStreamOptions{ + SearchID: searchID, + DocIDs: stringSliceFromSearchConfig(searchConfig["doc_ids"]), + CrossLanguages: stringSliceFromSearchConfig(searchConfig["cross_languages"]), + } + if value, ok := searchConfig["use_kg"].(bool); ok { + opts.UseKG = &value + } + if value, ok := intFromSearchConfig(searchConfig["top_k"]); ok { + opts.TopK = &value + } + if value, ok := searchConfigMapValue(searchConfig["meta_data_filter"]); ok { + opts.Filter = value + } + if value, ok := stringFromSearchConfig(searchConfig["tenant_rerank_id"]); ok { + opts.TenantRerankID = &value + } + if value, ok := stringFromSearchConfig(searchConfig["rerank_id"]); ok { + opts.RerankID = &value + } + if value, ok := searchConfig["keyword"].(bool); ok { + opts.Keyword = &value + } + if value, ok := floatFromSearchConfig(searchConfig["similarity_threshold"]); ok { + opts.SimilarityThreshold = &value + } + if value, ok := floatFromSearchConfig(searchConfig["vector_similarity_weight"]); ok { + opts.VectorSimilarityWeight = &value + } + return opts +} + +func searchConfigMapFromValue(value interface{}) map[string]interface{} { + if result, ok := searchConfigMapValue(value); ok { + return result + } + return map[string]interface{}{} +} + +func searchConfigMapValue(value 
interface{}) (map[string]interface{}, bool) { + switch typed := value.(type) { + case nil: + return nil, false + case map[string]interface{}: + return typed, true + case entity.JSONMap: + return map[string]interface{}(typed), true + default: + return nil, false + } +} + +func stringSliceFromSearchConfig(value interface{}) []string { + switch typed := value.(type) { + case nil: + return nil + case []string: + result := make([]string, 0, len(typed)) + for _, item := range typed { + if item = strings.TrimSpace(item); item != "" { + result = append(result, item) + } + } + return result + case common.StringSlice: + return stringSliceFromSearchConfig([]string(typed)) + case []interface{}: + result := make([]string, 0, len(typed)) + for _, item := range typed { + if value, ok := stringFromSearchConfig(item); ok { + result = append(result, value) + } + } + return result + default: + if value, ok := stringFromSearchConfig(value); ok { + return []string{value} + } + return nil + } +} + +func stringFromSearchConfig(value interface{}) (string, bool) { + typed, ok := value.(string) + if !ok { + return "", false + } + typed = strings.TrimSpace(typed) + return typed, typed != "" +} + +func intFromSearchConfig(value interface{}) (int, bool) { + switch typed := value.(type) { + case int: + return typed, true + case int64: + return int(typed), true + case float64: + return int(typed), true + case float32: + return int(typed), true + default: + return 0, false + } +} + +func floatFromSearchConfig(value interface{}) (float64, bool) { + switch typed := value.(type) { + case float64: + return typed, true + case float32: + return float64(typed), true + case int: + return float64(typed), true + case int64: + return float64(typed), true + default: + return 0, false + } +} + // UpdateSearchRequest update search request // Reference: api/apps/restful_apis/search_api.py::update // Required fields: name, search_config @@ -399,3 +615,8 @@ func (s *SearchService) GetDetail(searchID string) 
(map[string]interface{}, erro return result, nil } + +type SearchCompletionsRequest struct { + Question string `json:"question" binding:"required"` + KBIDs []string `json:"kb_ids,omitempty"` +}