feat[Go]: implement searches/<search_id>/completions POST (#16440)

### Summary

As title

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
- [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
Haruko386
2026-06-29 20:07:12 +08:00
committed by GitHub
parent 7c1edca15e
commit 1c0cdd84ce
9 changed files with 513 additions and 45 deletions

View File

@@ -229,6 +229,7 @@ func startServer(config *server.Config) {
systemService := service.NewSystemService()
connectorService := service.NewConnectorService()
searchService := service.NewSearchService()
searchService.SetTenantService(tenantService)
fileService := service.NewFileService()
memoryService := service.NewMemoryService()
mcpService := service.NewMCPService()
@@ -303,7 +304,9 @@ func startServer(config *server.Config) {
chunkService,
)
searchBotHandler.SetStreamLLM(searchBotLLM)
searchBotHandler.SetAskService(service.NewAskService(chunkService, nil, 0, 0))
askService := service.NewAskService(chunkService, nil, 0, 0)
searchBotHandler.SetAskService(askService)
searchHandler.SetCompletionDependencies(searchBotLLM, askService)
pluginHandler := handler.NewPluginHandler(service.NewPluginService())
modelHandler := handler.NewModelHandler(service.NewModelProviderService())
fileCommitHandler := handler.NewFileCommitHandler(service.NewFileCommitService())

View File

@@ -30,6 +30,9 @@ import (
type SearchHandler struct {
searchService *service.SearchService
userService *service.UserService
streamLLM streamingLLM
askService *service.AskService
sseWriter SSEWriter
}
// NewSearchHandler create search handler
@@ -37,9 +40,16 @@ func NewSearchHandler(searchService *service.SearchService, userService *service
return &SearchHandler{
searchService: searchService,
userService: userService,
sseWriter: &ginSSEWriter{},
}
}
// SetCompletionDependencies wires the streaming search completion runtime.
func (h *SearchHandler) SetCompletionDependencies(streamLLM streamingLLM, askService *service.AskService) {
h.streamLLM = streamLLM
h.askService = askService
}
// ListSearches list search apps
// @Summary List Search Apps
// @Description Get list of search apps for the current user with filtering, pagination and sorting
@@ -421,3 +431,87 @@ func (h *SearchHandler) UpdateSearch(c *gin.Context) {
"message": "success",
})
}
func (h *SearchHandler) Completion(c *gin.Context) {
user, errorCode, errorMessage := GetUser(c)
if errorCode != common.CodeSuccess {
jsonError(c, errorCode, errorMessage)
return
}
var req service.SearchCompletionsRequest
if err := c.ShouldBindJSON(&req); err != nil {
jsonError(c, common.CodeArgumentError, "question is required")
return
}
searchSvc := h.searchService
if searchSvc == nil {
searchSvc = service.NewSearchService()
}
plan, code, err := searchSvc.PrepareCompletion(user.ID, c.Param("search_id"), &req)
if err != nil {
if code == common.CodeAuthenticationError {
c.JSON(http.StatusOK, gin.H{
"code": code,
"data": false,
"message": err.Error(),
})
return
}
if code == common.CodeServerError {
jsonInternalError(c, err)
return
}
jsonError(c, code, err.Error())
return
}
if plan == nil {
jsonError(c, common.CodeServerError, "completion plan is nil")
return
}
disableWriteDeadlineForSSE(c)
c.Header("Content-Type", "text/event-stream; charset=utf-8")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
writer := h.sseWriter
if writer == nil {
writer = &ginSSEWriter{}
}
if plan.ModelID == "" {
writer.Write(c, sseError("chat model not configured"))
return
}
if h.askService == nil {
writer.Write(c, sseError("ask service not configured"))
return
}
if h.streamLLM == nil {
writer.Write(c, sseError("streaming LLM not configured"))
return
}
adapter := &askStreamAdapter{llm: h.streamLLM, tenantID: plan.UserID, modelID: plan.ModelID}
hadError := false
for delta := range h.askService.StreamWithOptions(c.Request.Context(), adapter, plan.UserID, plan.Question, plan.KBIDs, plan.Options) {
switch delta.Kind {
case service.AskDeltaAnswer:
writer.Write(c, sseAnswer(delta.Value, nil, false))
case service.AskDeltaMarker:
writer.Write(c, sseMarker(delta.Value))
case service.AskDeltaError:
hadError = true
writer.Write(c, sseError(delta.Value))
case service.AskDeltaFinal:
writer.Write(c, sseAnswer(delta.Value, delta.Refs, true))
}
}
if !hadError {
writer.Write(c, "data: {\"code\": 0, \"message\": \"\", \"data\": true}\n\n")
}
}

View File

@@ -364,6 +364,8 @@ func (r *Router) Setup(engine *gin.Engine) {
searches.GET("/:search_id", r.searchHandler.GetSearch)
searches.PUT("/:search_id", r.searchHandler.UpdateSearch)
searches.DELETE("/:search_id", r.searchHandler.DeleteSearch)
searches.POST("/:search_id/completion", r.searchHandler.Completion)
searches.POST("/:search_id/completions", r.searchHandler.Completion)
}
file := v1.Group("/files")

View File

@@ -53,6 +53,22 @@ type AskDelta struct {
Refs interface{} // populated on AskDeltaFinal: {chunks, doc_aggs}
}
// AskStreamOptions carries optional retrieval settings supplied by saved
// search_config. Zero values keep the same defaults as Stream.
type AskStreamOptions struct {
SearchID string
DocIDs []string
UseKG *bool
TopK *int
CrossLanguages []string
Filter map[string]interface{}
TenantRerankID *string
RerankID *string
Keyword *bool
SimilarityThreshold *float64
VectorSimilarityWeight *float64
}
// Retriever abstracts chunk retrieval for AskService.
type Retriever interface {
RetrievalTest(req *RetrievalTestRequest, userID string) (*RetrievalTestResponse, error)
@@ -91,22 +107,51 @@ func NewAskService(retriever Retriever, embedder Embedder, tokenBudget, minStrea
// Stream runs the full ask pipeline. llm must not be nil. The returned
// channel is closed when the pipeline completes or ctx is cancelled.
func (s *AskService) Stream(ctx context.Context, llm StreamingLLM, userID, question string, kbIDs []string) <-chan AskDelta {
return s.StreamWithOptions(ctx, llm, userID, question, kbIDs, AskStreamOptions{})
}
// StreamWithOptions runs Stream while allowing callers such as saved Search
// apps to pass search_config retrieval options through to RetrievalTest.
func (s *AskService) StreamWithOptions(ctx context.Context, llm StreamingLLM, userID, question string, kbIDs []string, opts AskStreamOptions) <-chan AskDelta {
out := make(chan AskDelta, 32)
go func() {
defer close(out)
s.run(ctx, llm, userID, question, kbIDs, out)
s.run(ctx, llm, userID, question, kbIDs, opts, out)
}()
return out
}
func (s *AskService) run(ctx context.Context, llm StreamingLLM, userID, question string, kbIDs []string, out chan<- AskDelta) {
func (s *AskService) run(ctx context.Context, llm StreamingLLM, userID, question string, kbIDs []string, opts AskStreamOptions, out chan<- AskDelta) {
// Phase 1: Retrieval.
topK := DefaultAskTopK
if opts.TopK != nil {
topK = *opts.TopK
}
similarityThreshold := DefaultAskSimilarityThreshold
if opts.SimilarityThreshold != nil {
similarityThreshold = *opts.SimilarityThreshold
}
vectorSimilarityWeight := DefaultAskVectorSimilarityWeight
if opts.VectorSimilarityWeight != nil {
vectorSimilarityWeight = *opts.VectorSimilarityWeight
}
req := &RetrievalTestRequest{
Datasets: common.StringSlice(kbIDs),
Question: question,
TopK: ptrInt(DefaultAskTopK),
SimilarityThreshold: ptrFloat64(DefaultAskSimilarityThreshold),
VectorSimilarityWeight: ptrFloat64(DefaultAskVectorSimilarityWeight),
DocIDs: opts.DocIDs,
UseKG: opts.UseKG,
TopK: ptrInt(topK),
CrossLanguages: opts.CrossLanguages,
Filter: opts.Filter,
TenantRerankID: opts.TenantRerankID,
RerankID: opts.RerankID,
Keyword: opts.Keyword,
SimilarityThreshold: ptrFloat64(similarityThreshold),
VectorSimilarityWeight: ptrFloat64(vectorSimilarityWeight),
}
if opts.SearchID != "" {
req.SearchID = &opts.SearchID
}
page := DefaultAskPage
ps := DefaultAskPageSize

View File

@@ -872,16 +872,24 @@ func parseMessages(raw json.RawMessage) []map[string]interface{} {
return messages
}
var wrapped struct {
Messages []map[string]interface{} `json:"messages"`
}
var wrapped map[string]json.RawMessage
if err := json.Unmarshal(raw, &wrapped); err != nil {
return nil
}
wrappedMessages, ok := wrapped["messages"]
if !ok {
return nil
}
if len(wrappedMessages) == 0 || string(wrappedMessages) == "null" {
return make([]map[string]interface{}, 0)
}
if wrapped.Messages == nil {
if err := json.Unmarshal(wrappedMessages, &messages); err != nil {
return nil
}
if messages == nil {
return make([]map[string]interface{}, 0)
}
return wrapped.Messages
return messages
}
func parseReferenceList(raw json.RawMessage) []interface{} {
@@ -891,7 +899,7 @@ func parseReferenceList(raw json.RawMessage) []interface{} {
}
err := json.Unmarshal(raw, &references)
if err != nil {
return make([]interface{}, 0)
return nil
}
if references == nil {
return make([]interface{}, 0)

View File

@@ -1558,12 +1558,11 @@ func TestBuildSessionPayload_EmptyCollectionsEncodeAsEmptyArrays(t *testing.T) {
}
}
func TestParseCollections_ReturnEmptySlicesForMissingNullOrInvalid(t *testing.T) {
func TestParseCollections_ReturnEmptySlicesForMissingOrNull(t *testing.T) {
messageInputs := []json.RawMessage{
nil,
json.RawMessage(`null`),
json.RawMessage(`{"messages":null}`),
json.RawMessage(`not-json`),
}
for _, input := range messageInputs {
got := parseMessages(input)
@@ -1575,7 +1574,6 @@ func TestParseCollections_ReturnEmptySlicesForMissingNullOrInvalid(t *testing.T)
referenceInputs := []json.RawMessage{
nil,
json.RawMessage(`null`),
json.RawMessage(`not-json`),
}
for _, input := range referenceInputs {
got := parseReferenceList(input)
@@ -1585,6 +1583,28 @@ func TestParseCollections_ReturnEmptySlicesForMissingNullOrInvalid(t *testing.T)
}
}
func TestParseCollections_ReturnNilForMalformedData(t *testing.T) {
messageInputs := []json.RawMessage{
json.RawMessage(`not-json`),
json.RawMessage(`{"unexpected":[]}`),
}
for _, input := range messageInputs {
if got := parseMessages(input); got != nil {
t.Fatalf("parseMessages(%s)=%#v, want nil", string(input), got)
}
}
referenceInputs := []json.RawMessage{
json.RawMessage(`not-json`),
json.RawMessage(`{"unexpected":[]}`),
}
for _, input := range referenceInputs {
if got := parseReferenceList(input); got != nil {
t.Fatalf("parseReferenceList(%s)=%#v, want nil", string(input), got)
}
}
}
func TestCompletionStream_EmptyMessages(t *testing.T) {
svc := &ChatSessionService{
chatSessionDAO: &fakeSessionStore{},

View File

@@ -77,6 +77,17 @@ type chunkImageMergeLock struct {
refs int
}
func searchConfigMap(value interface{}) (map[string]interface{}, bool) {
switch typed := value.(type) {
case entity.JSONMap:
return map[string]interface{}(typed), true
case map[string]interface{}:
return typed, true
default:
return nil, false
}
}
// ChunkService chunk service
type ChunkService struct {
docEngine engine.DocEngine
@@ -200,9 +211,9 @@ func (s *ChunkService) RetrievalTest(req *service.RetrievalTestRequest, userID s
// Check if all kbs have the same embedding model
if len(kbRecords) > 1 {
firstEmbdID := kbRecords[0].EmbdID
firstEmbeddingKey := knowledgebaseEmbeddingKey(kbRecords[0], tenantIDs[0])
for i := 1; i < len(kbRecords); i++ {
if kbRecords[i].EmbdID != firstEmbdID {
if knowledgebaseEmbeddingKey(kbRecords[i], tenantIDs[i]) != firstEmbeddingKey {
return nil, fmt.Errorf("cannot retrieve across datasets with different embedding models")
}
}
@@ -218,8 +229,8 @@ func (s *ChunkService) RetrievalTest(req *service.RetrievalTestRequest, userID s
searchDetail, err := s.searchService.GetDetail(*req.SearchID)
if err != nil {
common.Warn("Failed to get search detail for search_id, proceeding without it", zap.String("searchID", *req.SearchID), zap.Error(err))
} else if searchConfig, ok := searchDetail["search_config"].(entity.JSONMap); ok && searchConfig != nil {
if searchMetaFilter, ok := searchConfig["meta_data_filter"].(map[string]interface{}); ok {
} else if searchConfig, ok := searchConfigMap(searchDetail["search_config"]); ok && searchConfig != nil {
if searchMetaFilter, ok := searchConfigMap(searchConfig["meta_data_filter"]); ok {
filter = searchMetaFilter
}
chatID, _ = searchConfig["chat_id"].(string)
@@ -345,48 +356,47 @@ func (s *ChunkService) RetrievalTest(req *service.RetrievalTestRequest, userID s
labels := metadataSvc.LabelQuestion(modifiedQuestion, kbRecords)
common.Debug("LabelQuestion result", zap.Any("labels", labels))
// Determine embedding model
// Determine embedding model.
modelProviderSvc := service.NewModelProviderService()
var embdID string
var tenantLLM *entity.TenantLLM
var embeddingModel *models.EmbeddingModel
var embdID string
if kbRecords[0].TenantEmbdID != nil && *kbRecords[0].TenantEmbdID > 0 {
tenantLLM, embdID, err = dao.LookupTenantLLMByID(dao.NewTenantLLMDAO(), *kbRecords[0].TenantEmbdID)
_, embdID, err = dao.LookupTenantLLMByID(dao.NewTenantLLMDAO(), *kbRecords[0].TenantEmbdID)
if err != nil {
return nil, fmt.Errorf("failed to get embedding model by tenant_embd_id: %w", err)
}
driver, modelName, apiConfig, maxTokens, getErr := modelProviderSvc.GetModelConfigFromProviderInstance(tenantIDs[0], entity.ModelTypeEmbedding, embdID)
if getErr != nil {
return nil, fmt.Errorf("failed to get embedding model by tenant_embd_id: %w", getErr)
}
embeddingModel = models.NewEmbeddingModel(driver, &modelName, apiConfig, maxTokens)
} else if kbRecords[0].EmbdID != "" {
if strings.Contains(kbRecords[0].EmbdID, "@") {
driver, modelName, apiConfig, maxTokens, embErr := modelProviderSvc.GetModelConfigFromProviderInstance(tenantIDs[0], entity.ModelTypeEmbedding, kbRecords[0].EmbdID)
if embErr != nil {
return nil, fmt.Errorf("failed to get embedding model by embd_id: %w", embErr)
}
embeddingModel = models.NewEmbeddingModel(driver, &modelName, apiConfig, maxTokens)
embdID = kbRecords[0].EmbdID
} else {
tenantLLM, embdID, err = dao.LookupTenantLLMByName(dao.NewTenantLLMDAO(), tenantIDs[0], kbRecords[0].EmbdID, entity.ModelTypeEmbedding)
embdID = kbRecords[0].EmbdID
driver, modelName, apiConfig, maxTokens, getErr := modelProviderSvc.GetModelConfigFromProviderInstance(tenantIDs[0], entity.ModelTypeEmbedding, embdID)
if getErr != nil {
_, embdID, err = dao.LookupTenantLLMByName(dao.NewTenantLLMDAO(), tenantIDs[0], kbRecords[0].EmbdID, entity.ModelTypeEmbedding)
if err != nil {
return nil, fmt.Errorf("failed to get embedding model by embd_id: %w", err)
return nil, fmt.Errorf("failed to get embedding model by embd_id: %w", getErr)
}
driver, modelName, apiConfig, maxTokens, getErr = modelProviderSvc.GetModelConfigFromProviderInstance(tenantIDs[0], entity.ModelTypeEmbedding, embdID)
if getErr != nil {
return nil, fmt.Errorf("failed to get embedding model by embd_id: %w", getErr)
}
}
embeddingModel = models.NewEmbeddingModel(driver, &modelName, apiConfig, maxTokens)
} else {
tenantLLM, err = dao.NewTenantLLMDAO().GetByTenantAndType(tenantIDs[0], entity.ModelTypeEmbedding)
if err != nil {
return nil, fmt.Errorf("failed to get tenant default embedding model: %w", err)
driver, modelName, apiConfig, maxTokens, getErr := modelProviderSvc.GetTenantDefaultModelByType(tenantIDs[0], entity.ModelTypeEmbedding)
if getErr != nil {
return nil, fmt.Errorf("failed to get tenant default embedding model: %w", getErr)
}
if tenantLLM == nil || tenantLLM.LLMName == nil || *tenantLLM.LLMName == "" {
return nil, fmt.Errorf("no default embedding model found for tenant %s", tenantIDs[0])
}
embdID = fmt.Sprintf("%s@%s", *tenantLLM.LLMName, tenantLLM.LLMFactory)
embeddingModel = models.NewEmbeddingModel(driver, &modelName, apiConfig, maxTokens)
embdID = fmt.Sprintf("%s@default", modelName)
}
// Get embedding model for the tenant
if embeddingModel == nil {
embeddingModel, err = modelProviderSvc.GetEmbeddingModel(tenantIDs[0], embdID)
if err != nil {
return nil, fmt.Errorf("failed to get embedding model: %w", err)
}
return nil, fmt.Errorf("no embedding model found for tenant %s", tenantIDs[0])
}
common.Info("Fetched embedding model for retrieval",
zap.String("tenantID", tenantIDs[0]),
zap.String("embdID", embdID))
@@ -405,6 +415,12 @@ func (s *ChunkService) RetrievalTest(req *service.RetrievalTestRequest, userID s
}
} else if req.RerankID != nil && *req.RerankID != "" {
rerankCompositeName = *req.RerankID
if _, _, _, _, getErr := modelProviderSvc.GetModelConfigFromProviderInstance(tenantIDs[0], entity.ModelTypeRerank, rerankCompositeName); getErr != nil {
_, rerankCompositeName, err = dao.LookupTenantLLMByName(dao.NewTenantLLMDAO(), tenantIDs[0], *req.RerankID, entity.ModelTypeRerank)
if err != nil {
return nil, fmt.Errorf("failed to get rerank model by rerank_id: %w", getErr)
}
}
}
if rerankCompositeName != "" {
driver, mdlName, apiConfig, _, getErr := modelProviderSvc.GetModelConfigFromProviderInstance(tenantIDs[0], entity.ModelTypeRerank, rerankCompositeName)
@@ -466,6 +482,16 @@ func (s *ChunkService) RetrievalTest(req *service.RetrievalTestRequest, userID s
}, nil
}
func knowledgebaseEmbeddingKey(kb *entity.Knowledgebase, tenantID string) string {
if kb.TenantEmbdID != nil && *kb.TenantEmbdID > 0 {
return fmt.Sprintf("tenant:%d", *kb.TenantEmbdID)
}
if kb.EmbdID == "" {
return fmt.Sprintf("default:%s", tenantID)
}
return fmt.Sprintf("embd:%s", kb.EmbdID)
}
// hydrateChunkVectors replaces zero (placeholder) vectors in chunks with real
// vectors fetched from the engine. Infinity and OceanBase already ship real
// vectors with chunks, so this is a no-op for those engines; for ES it queries

View File

@@ -77,6 +77,55 @@ func TestHydrateChunkVectors_NoDim(t *testing.T) {
// Empty vectors have dim=0 → early return. No crash.
}
func TestKnowledgebaseEmbeddingKey(t *testing.T) {
tenantEmbdID := int64(42)
tests := []struct {
name string
kb *entity.Knowledgebase
tenantID string
want string
}{
{
name: "uses tenant embedding id before embd id",
kb: &entity.Knowledgebase{
EmbdID: "shared-model",
TenantEmbdID: &tenantEmbdID,
},
want: "tenant:42",
},
{
name: "uses embd id without tenant embedding id",
kb: &entity.Knowledgebase{
EmbdID: "shared-model",
},
want: "embd:shared-model",
},
{
name: "uses tenant default when embedding id is empty",
kb: &entity.Knowledgebase{},
tenantID: "tenant-1",
want: "default:tenant-1",
},
{
name: "ignores non-positive tenant embedding id",
kb: &entity.Knowledgebase{
EmbdID: "shared-model",
TenantEmbdID: new(int64),
},
want: "embd:shared-model",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := knowledgebaseEmbeddingKey(tt.kb, tt.tenantID); got != tt.want {
t.Fatalf("knowledgebaseEmbeddingKey() = %q, want %q", got, tt.want)
}
})
}
}
func TestParsePrevalidatesDocumentsBeforeMutating(t *testing.T) {
db := setupChunkTestDB(t)
pushChunkTestDB(t, db)

View File

@@ -17,16 +17,21 @@
package service
import (
"errors"
"fmt"
"ragflow/internal/common"
"ragflow/internal/dao"
"ragflow/internal/entity"
"strings"
"gorm.io/gorm"
)
// SearchService search service
type SearchService struct {
searchDAO *dao.SearchDAO
userTenantDAO *dao.UserTenantDAO
tenantService *TenantService
}
// NewSearchService create search service
@@ -34,6 +39,13 @@ func NewSearchService() *SearchService {
return &SearchService{
searchDAO: dao.NewSearchDAO(),
userTenantDAO: dao.NewUserTenantDAO(),
tenantService: NewTenantService(),
}
}
func (s *SearchService) SetTenantService(tenantService *TenantService) {
if tenantService != nil {
s.tenantService = tenantService
}
}
@@ -285,6 +297,210 @@ func (s *SearchService) DeleteSearch(userID string, searchID string) error {
return nil
}
// AccessibleForCompletion check if it is accessible
func (s *SearchService) AccessibleForCompletion(userID string, searchID string) (bool, error) {
ok, err := s.searchDAO.Accessible4Deletion(searchID, userID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return false, nil
}
return false, err
}
return ok, nil
}
type SearchCompletionPlan struct {
UserID string
SearchID string
Question string
KBIDs []string
ModelID string
Options AskStreamOptions
}
func (s *SearchService) PrepareCompletion(userID, searchID string, req *SearchCompletionsRequest) (*SearchCompletionPlan, common.ErrorCode, error) {
userID = strings.TrimSpace(userID)
if userID == "" {
return nil, common.CodeBadRequest, fmt.Errorf("user id is required")
}
searchID = strings.TrimSpace(searchID)
if searchID == "" {
return nil, common.CodeBadRequest, fmt.Errorf("search_id is required")
}
if req == nil {
return nil, common.CodeArgumentError, fmt.Errorf("question is required")
}
question := strings.TrimSpace(req.Question)
if question == "" {
return nil, common.CodeArgumentError, fmt.Errorf("question is required")
}
accessible, err := s.AccessibleForCompletion(userID, searchID)
if err != nil {
return nil, common.CodeServerError, err
}
if !accessible {
return nil, common.CodeAuthenticationError, fmt.Errorf("No authorization.")
}
searchDetail, err := s.GetDetail(searchID)
if err != nil || searchDetail == nil {
return nil, common.CodeDataError, fmt.Errorf("Cannot find search %s", searchID)
}
searchConfig := searchConfigMapFromValue(searchDetail["search_config"])
kbIDs := stringSliceFromSearchConfig(searchConfig["kb_ids"])
if len(kbIDs) == 0 {
kbIDs = stringSliceFromSearchConfig(req.KBIDs)
}
if len(kbIDs) == 0 {
return nil, common.CodeDataError, fmt.Errorf("`kb_ids` is required.")
}
modelID, _ := stringFromSearchConfig(searchConfig["chat_id"])
if modelID == "" {
tenantSvc := s.tenantService
if tenantSvc == nil {
tenantSvc = NewTenantService()
}
defaultModel, err := tenantSvc.GetDefaultModelName(userID, entity.ModelTypeChat)
if err == nil {
modelID = strings.TrimSpace(defaultModel)
}
}
return &SearchCompletionPlan{
UserID: userID,
SearchID: searchID,
Question: question,
KBIDs: kbIDs,
ModelID: modelID,
Options: askOptionsFromSearchConfig(searchID, searchConfig),
}, common.CodeSuccess, nil
}
func askOptionsFromSearchConfig(searchID string, searchConfig map[string]interface{}) AskStreamOptions {
opts := AskStreamOptions{
SearchID: searchID,
DocIDs: stringSliceFromSearchConfig(searchConfig["doc_ids"]),
CrossLanguages: stringSliceFromSearchConfig(searchConfig["cross_languages"]),
}
if value, ok := searchConfig["use_kg"].(bool); ok {
opts.UseKG = &value
}
if value, ok := intFromSearchConfig(searchConfig["top_k"]); ok {
opts.TopK = &value
}
if value, ok := searchConfigMapValue(searchConfig["meta_data_filter"]); ok {
opts.Filter = value
}
if value, ok := stringFromSearchConfig(searchConfig["tenant_rerank_id"]); ok {
opts.TenantRerankID = &value
}
if value, ok := stringFromSearchConfig(searchConfig["rerank_id"]); ok {
opts.RerankID = &value
}
if value, ok := searchConfig["keyword"].(bool); ok {
opts.Keyword = &value
}
if value, ok := floatFromSearchConfig(searchConfig["similarity_threshold"]); ok {
opts.SimilarityThreshold = &value
}
if value, ok := floatFromSearchConfig(searchConfig["vector_similarity_weight"]); ok {
opts.VectorSimilarityWeight = &value
}
return opts
}
func searchConfigMapFromValue(value interface{}) map[string]interface{} {
if result, ok := searchConfigMapValue(value); ok {
return result
}
return map[string]interface{}{}
}
func searchConfigMapValue(value interface{}) (map[string]interface{}, bool) {
switch typed := value.(type) {
case nil:
return nil, false
case map[string]interface{}:
return typed, true
case entity.JSONMap:
return map[string]interface{}(typed), true
default:
return nil, false
}
}
func stringSliceFromSearchConfig(value interface{}) []string {
switch typed := value.(type) {
case nil:
return nil
case []string:
result := make([]string, 0, len(typed))
for _, item := range typed {
if item = strings.TrimSpace(item); item != "" {
result = append(result, item)
}
}
return result
case common.StringSlice:
return stringSliceFromSearchConfig([]string(typed))
case []interface{}:
result := make([]string, 0, len(typed))
for _, item := range typed {
if value, ok := stringFromSearchConfig(item); ok {
result = append(result, value)
}
}
return result
default:
if value, ok := stringFromSearchConfig(value); ok {
return []string{value}
}
return nil
}
}
func stringFromSearchConfig(value interface{}) (string, bool) {
typed, ok := value.(string)
if !ok {
return "", false
}
typed = strings.TrimSpace(typed)
return typed, typed != ""
}
func intFromSearchConfig(value interface{}) (int, bool) {
switch typed := value.(type) {
case int:
return typed, true
case int64:
return int(typed), true
case float64:
return int(typed), true
case float32:
return int(typed), true
default:
return 0, false
}
}
func floatFromSearchConfig(value interface{}) (float64, bool) {
switch typed := value.(type) {
case float64:
return typed, true
case float32:
return float64(typed), true
case int:
return float64(typed), true
case int64:
return float64(typed), true
default:
return 0, false
}
}
// UpdateSearchRequest update search request
// Reference: api/apps/restful_apis/search_api.py::update
// Required fields: name, search_config
@@ -399,3 +615,8 @@ func (s *SearchService) GetDetail(searchID string) (map[string]interface{}, erro
return result, nil
}
type SearchCompletionsRequest struct {
Question string `json:"question" binding:"required"`
KBIDs []string `json:"kb_ids,omitempty"`
}