diff --git a/internal/handler/chat_session.go b/internal/handler/chat_session.go index c3489d70e5..882d3da87b 100644 --- a/internal/handler/chat_session.go +++ b/internal/handler/chat_session.go @@ -292,9 +292,10 @@ func (h *ChatSessionHandler) Completion(c *gin.Context) { // Create a channel for streaming data streamChan := make(chan string) + reqCtx := c.Request.Context() go func() { defer close(streamChan) - err := h.chatSessionService.CompletionStream(userID, req.ConversationID, processedMessages, req.LLMID, chatModelConfig, messageID, streamChan) + err := h.chatSessionService.CompletionStream(reqCtx, userID, req.ConversationID, processedMessages, req.LLMID, chatModelConfig, messageID, streamChan) if err != nil { streamChan <- fmt.Sprintf("data: %s\n\n", err.Error()) } diff --git a/internal/handler/dify_retrieval_handler.go b/internal/handler/dify_retrieval_handler.go index 51cb929d1d..366e8b9b92 100644 --- a/internal/handler/dify_retrieval_handler.go +++ b/internal/handler/dify_retrieval_handler.go @@ -24,9 +24,9 @@ import ( "strconv" "strings" - "ragflow/internal/common" - "gorm.io/gorm" "go.uber.org/zap" + "gorm.io/gorm" + "ragflow/internal/common" "ragflow/internal/engine" "ragflow/internal/entity" modelModule "ragflow/internal/entity/models" @@ -71,16 +71,16 @@ type DocumentDAOIface interface { // difyRetrievalRequest is the JSON body / query params for the Dify retrieval endpoint. type difyRetrievalRequest struct { - KnowledgeID string `json:"knowledge_id" form:"knowledge_id"` - Query string `json:"query" form:"query"` - UseKG bool `json:"use_kg" form:"use_kg"` - RetrievalSetting *difyRetrievalSetting `json:"retrieval_setting"` - MetadataCondition *difyMetadataCondition `json:"metadata_condition"` + KnowledgeID string `json:"knowledge_id" form:"knowledge_id"` + Query string `json:"query" form:"query"` + UseKG bool `json:"use_kg" form:"use_kg"` + RetrievalSetting *difyRetrievalSetting `json:"retrieval_setting"` + MetadataCondition *difyMetadataCondition `json:"metadata_condition"` } type difyRetrievalSetting struct { - TopK *int `json:"top_k" form:"top_k"` - ScoreThreshold *float64 `json:"score_threshold" form:"score_threshold"` + TopK *int `json:"top_k" form:"top_k"` + ScoreThreshold *float64 `json:"score_threshold" form:"score_threshold"` } // difyCondition is a Dify-format metadata filter condition. @@ -108,8 +108,8 @@ func (c difyMetadataCondition) toMetaFilterConditions() []service.MetaFilterCond v = fmt.Sprint(dc.Value) } result[i] = service.MetaFilterCondition{ - Key: dc.Name, - Op: dc.ComparisonOperator, + Key: dc.Name, + Op: dc.ComparisonOperator, Value: v, } } @@ -241,15 +241,15 @@ func (h *DifyRetrievalHandler) Retrieval(c *gin.Context) { metas, metaErr := h.metadataSvc.GetFlattedMetaByKBs([]string{req.KnowledgeID}) docIDs := make([]string, 0) if metaErr == nil && req.MetadataCondition != nil { - logic := req.MetadataCondition.Logic - if logic == "" { - logic = "and" - } - filteredIDs := service.ApplyMetaFilter(metas, req.MetadataCondition.toMetaFilterConditions(), logic) + logic := req.MetadataCondition.Logic + if logic == "" { + logic = "and" + } + filteredIDs := service.ApplyMetaFilter(metas, req.MetadataCondition.toMetaFilterConditions(), logic) docIDs = append(docIDs, filteredIDs...) } if len(docIDs) == 0 && req.MetadataCondition != nil { - docIDs = []string{"-999"} + docIDs = []string{service.NoMatchDocIDSentinel} } // Label question for rank features @@ -258,15 +258,15 @@ func (h *DifyRetrievalHandler) Retrieval(c *gin.Context) { // Chunk retrieval sr := &nlp.RetrievalRequest{ - Question: req.Query, - TenantIDs: []string{kb.TenantID}, - KbIDs: []string{req.KnowledgeID}, - DocIDs: docIDs, - Page: 1, - PageSize: pageSize, - Top: topK, - SimilarityThreshold: scoreThreshold, - EmbeddingModel: embModel, + Question: req.Query, + TenantIDs: []string{kb.TenantID}, + KbIDs: []string{req.KnowledgeID}, + DocIDs: docIDs, + Page: 1, + PageSize: pageSize, + Top: topK, + SimilarityThreshold: scoreThreshold, + EmbeddingModel: embModel, } if rankFeature != nil { sr.RankFeature = &rankFeature diff --git a/internal/service/chat_session.go b/internal/service/chat_session.go index fe4de1e5f9..1c1aa36d54 100644 --- a/internal/service/chat_session.go +++ b/internal/service/chat_session.go @@ -17,10 +17,13 @@ package service import ( + "context" "encoding/json" "errors" "fmt" "ragflow/internal/common" + "ragflow/internal/engine" + "ragflow/internal/service/nlp" "strings" "time" @@ -31,21 +34,61 @@ import ( modelModule "ragflow/internal/entity/models" ) +type chatKnowledgebaseStore interface { + Accessible(kbID, userID string) bool + GetByIDs(ids []string) ([]*entity.Knowledgebase, error) +} + +type chatModelProvider interface { + GetChatModel(tenantID, compositeModelName string) (*modelModule.ChatModel, error) + GetEmbeddingModel(tenantID, compositeModelName string) (*modelModule.EmbeddingModel, error) + GetRerankModel(tenantID, compositeModelName string) (*modelModule.RerankModel, error) + GetModelConfigFromProviderInstance(tenantID string, modelType entity.ModelType, modelName string) (modelModule.ModelDriver, string, *modelModule.APIConfig, int, error) + GetTenantDefaultModelByType(tenantID string, modelType entity.ModelType) (modelModule.ModelDriver, string, *modelModule.APIConfig, int, error) +} + +type chatMetadataService interface { + LabelQuestion(question string, kbs []*entity.Knowledgebase) map[string]float64 + GetFlattedMetaByKBs(kbIDs []string) (common.MetaData, error) +} + +type chatRetrievalService interface { + Retrieval(ctx context.Context, req *nlp.RetrievalRequest) (*nlp.RetrievalResult, error) +} + // ChatSessionService chat session (conversation) service type ChatSessionService struct { chatSessionDAO *dao.ChatSessionDAO chatDAO *dao.ChatDAO userTenantDAO *dao.UserTenantDAO - modelProviderSvc *ModelProviderService + kbDAO chatKnowledgebaseStore + docEngine engine.DocEngine + modelProviderSvc chatModelProvider + metadataSvc chatMetadataService + retrievalSvc chatRetrievalService } // NewChatSessionService create chat session service func NewChatSessionService() *ChatSessionService { + docEngine := engine.Get() + return newChatSessionServiceWithRetrieval(docEngine, nlp.NewRetrievalService(docEngine, dao.NewDocumentDAO())) +} + +// NewChatSessionServiceWithRetrieval creates a chat session service with a retrieval service. +func NewChatSessionServiceWithRetrieval(retrievalSvc chatRetrievalService) *ChatSessionService { + return newChatSessionServiceWithRetrieval(engine.Get(), retrievalSvc) +} + +func newChatSessionServiceWithRetrieval(docEngine engine.DocEngine, retrievalSvc chatRetrievalService) *ChatSessionService { return &ChatSessionService{ chatSessionDAO: dao.NewChatSessionDAO(), chatDAO: dao.NewChatDAO(), userTenantDAO: dao.NewUserTenantDAO(), + kbDAO: dao.NewKnowledgebaseDAO(), + docEngine: docEngine, modelProviderSvc: NewModelProviderService(), + metadataSvc: NewMetadataService(), + retrievalSvc: retrievalSvc, } } @@ -294,7 +337,7 @@ func (s *ChatSessionService) Completion(userID string, conversationID string, me } // Perform chat completion with RAG - result, err := s.asyncChat(dialog, session, messages, chatModelConfig, messageID, reference, false) + result, err := s.asyncChat(userID, dialog, session, messages, chatModelConfig, messageID, reference, false) if err != nil { return nil, err } @@ -308,7 +351,11 @@ func (s *ChatSessionService) Completion(userID string, conversationID string, me } // CompletionStream performs streaming chat completion with full RAG support -func (s *ChatSessionService) CompletionStream(userID string, conversationID string, messages []map[string]interface{}, llmID string, chatModelConfig map[string]interface{}, messageID string, streamChan chan<- string) error { +func (s *ChatSessionService) CompletionStream(ctx context.Context, userID string, conversationID string, messages []map[string]interface{}, llmID string, chatModelConfig map[string]interface{}, messageID string, streamChan chan<- string) error { + if ctx == nil { + ctx = context.Background() + } + // Validate the last message is from user if len(messages) == 0 { streamChan <- fmt.Sprintf("data: %s\n\n", `{"code": 500, "message": "messages cannot be empty", "data": {"answer": "**ERROR**: messages cannot be empty", "reference": []}}`) @@ -356,7 +403,7 @@ func (s *ChatSessionService) CompletionStream(userID string, conversationID stri } // Perform streaming chat completion with RAG - resultChan, err := s.asyncChatStream(dialog, session, messages, chatModelConfig, messageID, reference) + resultChan, err := s.asyncChatStream(ctx, userID, dialog, session, messages, chatModelConfig, messageID, reference) if err != nil { streamChan <- fmt.Sprintf("data: %s\n\n", fmt.Sprintf(`{"code": 500, "message": "%s", "data": {"answer": "**ERROR**: %s", "reference": []}}`, err.Error(), err.Error())) return err @@ -450,7 +497,7 @@ func (s *ChatSessionService) updateSessionMessages(session *entity.ChatSession, } // asyncChat performs chat with RAG support (non-streaming) -func (s *ChatSessionService) asyncChat(dialog *entity.Chat, session *entity.ChatSession, messages []map[string]interface{}, config map[string]interface{}, messageID string, reference []interface{}, stream bool) (map[string]interface{}, error) { +func (s *ChatSessionService) asyncChat(userID string, dialog *entity.Chat, session *entity.ChatSession, messages []map[string]interface{}, config map[string]interface{}, messageID string, reference []interface{}, stream bool) (map[string]interface{}, error) { // Check if we need RAG (knowledge base or tavily) hasKB := len(dialog.KBIDs) > 0 hasTavily := false @@ -465,21 +512,20 @@ func (s *ChatSessionService) asyncChat(dialog *entity.Chat, session *entity.Chat return s.asyncChatSolo(dialog, session, messages, config, messageID, reference, stream) } - // TODO: Full RAG implementation with knowledge base retrieval - // This would include: - // 1. Get embedding model and rerank model - // 2. Extract questions from messages - // 3. Retrieve chunks from knowledge bases - // 4. Rerank chunks - // 5. Build prompt with context - // 6. Call LLM + if hasKB { + return s.asyncChatWithRetrieval(context.Background(), userID, dialog, session, messages, config, messageID, reference, stream) + } - // For now, fall back to solo chat + common.Warn("Tavily-backed chat retrieval is not implemented in Go; falling back to solo chat", + zap.String("dialog_id", dialog.ID)) return s.asyncChatSolo(dialog, session, messages, config, messageID, reference, stream) } // asyncChatStream performs streaming chat with RAG support -func (s *ChatSessionService) asyncChatStream(dialog *entity.Chat, session *entity.ChatSession, messages []map[string]interface{}, config map[string]interface{}, messageID string, reference []interface{}) (<-chan map[string]interface{}, error) { +func (s *ChatSessionService) asyncChatStream(ctx context.Context, userID string, dialog *entity.Chat, session *entity.ChatSession, messages []map[string]interface{}, config map[string]interface{}, messageID string, reference []interface{}) (<-chan map[string]interface{}, error) { + if ctx == nil { + ctx = context.Background() + } resultChan := make(chan map[string]interface{}) go func() { @@ -500,14 +546,599 @@ func (s *ChatSessionService) asyncChatStream(dialog *entity.Chat, session *entit return } - // TODO: Full RAG streaming implementation - // For now, fall back to solo chat + if hasKB { + ragMessages, ragDialog, emptyResponse, err := s.messagesWithRetrievedKnowledge(ctx, userID, dialog, messages, reference) + if err != nil { + resultChan <- s.structureAnswer(session, "**ERROR**: "+err.Error(), messageID, session.ID, reference) + return + } + if emptyResponse != nil { + resultChan <- s.structureAnswer(session, *emptyResponse, messageID, session.ID, reference) + return + } + s.asyncChatSoloStream(ragDialog, session, ragMessages, config, messageID, reference, resultChan) + return + } + + common.Warn("Tavily-backed streaming chat retrieval is not implemented in Go; falling back to solo chat", + zap.String("dialog_id", dialog.ID)) s.asyncChatSoloStream(dialog, session, messages, config, messageID, reference, resultChan) }() return resultChan, nil } +func (s *ChatSessionService) asyncChatWithRetrieval(ctx context.Context, userID string, dialog *entity.Chat, session *entity.ChatSession, messages []map[string]interface{}, config map[string]interface{}, messageID string, reference []interface{}, stream bool) (map[string]interface{}, error) { + ragMessages, ragDialog, emptyResponse, err := s.messagesWithRetrievedKnowledge(ctx, userID, dialog, messages, reference) + if err != nil { + return nil, err + } + if emptyResponse != nil { + var lastRef interface{} + if len(reference) > 0 { + lastRef = reference[len(reference)-1] + } + ans := map[string]interface{}{ + "answer": *emptyResponse, + "reference": lastRef, + "final": true, + } + return s.structureAnswerWithConv(session, ans, messageID, session.ID, reference), nil + } + return s.asyncChatSolo(ragDialog, session, ragMessages, config, messageID, reference, stream) +} + +func (s *ChatSessionService) messagesWithRetrievedKnowledge(ctx context.Context, userID string, dialog *entity.Chat, messages []map[string]interface{}, reference []interface{}) ([]map[string]interface{}, *entity.Chat, *string, error) { + kbIDs := stringSliceFromJSON(dialog.KBIDs) + if len(kbIDs) == 0 { + return messages, dialog, nil, nil + } + if s.retrievalSvc == nil { + return nil, nil, nil, errors.New("retrieval service is not configured") + } + + question := latestUserQuestion(messages) + if question == "" { + return messages, dialog, nil, nil + } + + kbs, err := s.kbDAO.GetByIDs(kbIDs) + if err != nil { + return nil, nil, nil, fmt.Errorf("failed to load knowledge bases: %w", err) + } + kbs, err = s.knowledgebasesForDialog(userID, dialog, kbIDs, kbs) + if err != nil { + return nil, nil, nil, err + } + embeddingTenantID, embeddingModelName, err := validateKnowledgebaseEmbeddingModels(kbs, dialog.TenantID, resolveEmbeddingModelName) + if err != nil { + return nil, nil, nil, err + } + + embeddingModel, err := s.modelProviderSvc.GetEmbeddingModel(embeddingTenantID, embeddingModelName) + if err != nil { + return nil, nil, nil, fmt.Errorf("failed to get embedding model: %w", err) + } + rerankModel, err := s.rerankModelForDialog(dialog) + if err != nil { + return nil, nil, nil, err + } + + top := int(dialog.TopK) + pageSize := int(dialog.TopN) + if pageSize <= 0 { + pageSize = 6 + } + similarityThreshold := dialog.SimilarityThreshold + vectorSimilarityWeight := dialog.VectorSimilarityWeight + var rankFeature map[string]float64 + if s.metadataSvc != nil { + rankFeature = s.metadataSvc.LabelQuestion(question, kbs) + } + baseDocIDs := docIDsFromMessages(messages) + docIDs, err := s.filteredDocIDsForDialog(ctx, dialog, kbIDs, question, baseDocIDs) + if err != nil { + return nil, nil, nil, err + } + tenantIDs := tenantIDsFromKnowledgebases(kbs, dialog.TenantID) + + retrievalResult, err := s.retrievalSvc.Retrieval(ctx, &nlp.RetrievalRequest{ + Question: question, + TenantIDs: tenantIDs, + KbIDs: kbIDs, + DocIDs: docIDs, + Page: 1, + PageSize: pageSize, + Top: &top, + SimilarityThreshold: &similarityThreshold, + VectorSimilarityWeight: &vectorSimilarityWeight, + RankFeature: &rankFeature, + EmbeddingModel: embeddingModel, + RerankModel: rerankModel, + }) + if err != nil { + return nil, nil, nil, fmt.Errorf("retrieval search failed: %w", err) + } + if retrievalResult == nil { + retrievalResult = &nlp.RetrievalResult{} + } + + chunks := retrievalResult.Chunks + if s.docEngine != nil { + chunks = nlp.RetrievalByChildren(chunks, tenantIDs, s.docEngine, ctx) + } + setLatestReference(reference, chunks, retrievalResult.DocAggs) + knowledge := buildKnowledgeBlock(chunks) + if knowledge == "" { + return messages, dialog, emptyResponseForDialog(dialog), nil + } + if ragDialog, ok := dialogWithInjectedKnowledgePrompt(dialog, knowledge); ok { + return copyMessages(messages), ragDialog, nil, nil + } + + return injectKnowledge(messages, knowledge), dialog, nil, nil +} + +type embeddingModelNameResolver func(tenantID string, kb *entity.Knowledgebase) (string, error) + +func validateKnowledgebaseEmbeddingModels(kbs []*entity.Knowledgebase, fallbackTenantID string, resolve embeddingModelNameResolver) (string, string, error) { + if len(kbs) == 0 { + return fallbackTenantID, "", nil + } + + expected := "" + expectedKBID := "" + expectedTenantID := fallbackTenantID + for _, kb := range kbs { + if kb == nil { + return "", "", errors.New("knowledge base is nil") + } + tenantID := kb.TenantID + if tenantID == "" { + tenantID = fallbackTenantID + } + modelName, err := resolve(tenantID, kb) + if err != nil { + return "", "", err + } + modelName = strings.TrimSpace(modelName) + if modelName == "" { + return "", "", fmt.Errorf("knowledge base %s has no embedding model", kb.ID) + } + if expected == "" { + expected = modelName + expectedKBID = kb.ID + expectedTenantID = tenantID + continue + } + if modelName != expected { + return "", "", fmt.Errorf("knowledge bases must use the same embedding model: %s resolves to %q, expected %q from %s", kb.ID, modelName, expected, expectedKBID) + } + } + return expectedTenantID, expected, nil +} + +func (s *ChatSessionService) rerankModelForDialog(dialog *entity.Chat) (*modelModule.RerankModel, error) { + compositeName, err := resolveRerankModelName(dialog) + if err != nil { + return nil, err + } + if compositeName == "" { + return nil, nil + } + rerankModel, err := s.modelProviderSvc.GetRerankModel(dialog.TenantID, compositeName) + if err != nil { + return nil, fmt.Errorf("failed to get rerank model: %w", err) + } + return rerankModel, nil +} + +func (s *ChatSessionService) filteredDocIDsForDialog(ctx context.Context, dialog *entity.Chat, kbIDs []string, question string, baseDocIDs []string) ([]string, error) { + if dialog.MetaDataFilter == nil || len(*dialog.MetaDataFilter) == 0 { + return baseDocIDs, nil + } + if s.metadataSvc == nil { + return nil, errors.New("metadata service is not configured") + } + + filter := make(map[string]interface{}, len(*dialog.MetaDataFilter)) + for key, value := range *dialog.MetaDataFilter { + filter[key] = value + } + + metaData, err := s.metadataSvc.GetFlattedMetaByKBs(kbIDs) + if err != nil { + return nil, fmt.Errorf("failed to get flattened metadata for chat retrieval: %w", err) + } + + var filterChatModel *modelModule.ChatModel + method, _ := filter["method"].(string) + if method == "auto" || method == "semi_auto" { + filterChatModel, err = s.modelProviderSvc.GetChatModel(dialog.TenantID, dialog.LLMID) + if err != nil { + common.Warn("Failed to get chat model for chat metadata filter", zap.Error(err)) + } + } + + docIDs, empty := ApplyMetaDataFilter(ctx, filter, metaData, question, filterChatModel, baseDocIDs, kbIDs) + if empty { + return []string{NoMatchDocIDSentinel}, nil + } + return docIDs, nil +} + +func resolveEmbeddingModelName(tenantID string, kb *entity.Knowledgebase) (string, error) { + if kb.TenantEmbdID != nil && *kb.TenantEmbdID > 0 { + _, compositeName, err := dao.LookupTenantLLMByID(dao.NewTenantLLMDAO(), *kb.TenantEmbdID) + if err != nil { + return "", fmt.Errorf("failed to get embedding model by tenant_embd_id: %w", err) + } + return compositeName, nil + } + if kb.EmbdID != "" { + if strings.Contains(kb.EmbdID, "@") { + return kb.EmbdID, nil + } + _, compositeName, err := dao.LookupTenantLLMByName(dao.NewTenantLLMDAO(), tenantID, kb.EmbdID, entity.ModelTypeEmbedding) + if err != nil { + return "", fmt.Errorf("failed to get embedding model by embd_id: %w", err) + } + return compositeName, nil + } + + tenantLLM, err := dao.NewTenantLLMDAO().GetByTenantAndType(tenantID, entity.ModelTypeEmbedding) + if err != nil { + return "", fmt.Errorf("failed to get tenant default embedding model: %w", err) + } + if tenantLLM == nil || tenantLLM.LLMName == nil || *tenantLLM.LLMName == "" { + return "", fmt.Errorf("no default embedding model found for tenant %s", tenantID) + } + return fmt.Sprintf("%s@%s", *tenantLLM.LLMName, tenantLLM.LLMFactory), nil +} + +func resolveRerankModelName(dialog *entity.Chat) (string, error) { + if dialog.TenantRerankID != nil && *dialog.TenantRerankID > 0 { + _, compositeName, err := dao.LookupTenantLLMByID(dao.NewTenantLLMDAO(), *dialog.TenantRerankID) + if err != nil { + return "", fmt.Errorf("failed to get rerank model by tenant_rerank_id: %w", err) + } + return compositeName, nil + } + if dialog.RerankID == "" { + return "", nil + } + if strings.Contains(dialog.RerankID, "@") { + return dialog.RerankID, nil + } + _, compositeName, err := dao.LookupTenantLLMByName(dao.NewTenantLLMDAO(), dialog.TenantID, dialog.RerankID, entity.ModelTypeRerank) + if err != nil { + return "", fmt.Errorf("failed to get rerank model by rerank_id: %w", err) + } + return compositeName, nil +} + +func stringSliceFromJSON(values entity.JSONSlice) []string { + result := make([]string, 0, len(values)) + seen := make(map[string]struct{}, len(values)) + for _, value := range values { + str, ok := value.(string) + if !ok || str == "" { + continue + } + if _, exists := seen[str]; exists { + continue + } + seen[str] = struct{}{} + result = append(result, str) + } + return result +} + +func tenantIDsFromKnowledgebases(kbs []*entity.Knowledgebase, fallback string) []string { + seen := make(map[string]struct{}, len(kbs)+1) + var tenantIDs []string + for _, kb := range kbs { + if kb == nil || kb.TenantID == "" { + continue + } + if _, exists := seen[kb.TenantID]; exists { + continue + } + seen[kb.TenantID] = struct{}{} + tenantIDs = append(tenantIDs, kb.TenantID) + } + if len(tenantIDs) == 0 && fallback != "" { + tenantIDs = append(tenantIDs, fallback) + } + return tenantIDs +} + +func (s *ChatSessionService) knowledgebasesForDialog(userID string, dialog *entity.Chat, kbIDs []string, loaded []*entity.Knowledgebase) ([]*entity.Knowledgebase, error) { + byID := make(map[string]*entity.Knowledgebase, len(loaded)) + for _, kb := range loaded { + if kb != nil { + byID[kb.ID] = kb + } + } + + kbs := make([]*entity.Knowledgebase, 0, len(kbIDs)) + for _, kbID := range kbIDs { + kb := byID[kbID] + if kb == nil { + return nil, fmt.Errorf("knowledge base %s not found", kbID) + } + if userID != "" && !s.kbDAO.Accessible(kbID, userID) { + return nil, fmt.Errorf("knowledge base %s is not authorized for user", kbID) + } + if userID == "" && kb.TenantID != dialog.TenantID { + return nil, fmt.Errorf("knowledge base %s is not authorized for dialog tenant", kbID) + } + kbs = append(kbs, kb) + } + if len(kbs) == 0 { + return nil, errors.New("no valid knowledge bases found") + } + return kbs, nil +} + +func docIDsFromMessages(messages []map[string]interface{}) []string { + for i := len(messages) - 1; i >= 0; i-- { + if role, _ := messages[i]["role"].(string); role != "user" { + continue + } + return stringSliceFromValue(messages[i]["doc_ids"]) + } + return nil +} + +func latestUserQuestion(messages []map[string]interface{}) string { + for i := len(messages) - 1; i >= 0; i-- { + if role, _ := messages[i]["role"].(string); role != "user" { + continue + } + return textFromMessageContent(messages[i]["content"]) + } + return "" +} + +func stringSliceFromValue(value interface{}) []string { + switch typed := value.(type) { + case nil: + return nil + case []string: + return uniqueNonEmptyStrings(typed) + case []interface{}: + values := make([]string, 0, len(typed)) + for _, item := range typed { + if str, ok := item.(string); ok { + values = append(values, str) + } + } + return uniqueNonEmptyStrings(values) + default: + return nil + } +} + +func uniqueNonEmptyStrings(values []string) []string { + result := make([]string, 0, len(values)) + seen := make(map[string]struct{}, len(values)) + for _, value := range values { + value = strings.TrimSpace(value) + if value == "" { + continue + } + if _, exists := seen[value]; exists { + continue + } + seen[value] = struct{}{} + result = append(result, value) + } + if len(result) == 0 { + return nil + } + return result +} + +func emptyResponseForDialog(dialog *entity.Chat) *string { + if dialog.PromptConfig == nil { + return nil + } + emptyResponse, ok := dialog.PromptConfig["empty_response"].(string) + if !ok || emptyResponse == "" { + return nil + } + return &emptyResponse +} + +func buildKnowledgeBlock(chunks []map[string]interface{}) string { + var builder strings.Builder + for i, chunk := range chunks { + content := chunkText(chunk) + if content == "" { + continue + } + if builder.Len() > 0 { + builder.WriteString("\n\n") + } + builder.WriteString(fmt.Sprintf("[%d]", i+1)) + if docName, ok := chunk["docnm_kwd"].(string); ok && docName != "" { + builder.WriteString(" ") + builder.WriteString(docName) + } + builder.WriteString("\n") + builder.WriteString(content) + } + return builder.String() +} + +func chunkText(chunk map[string]interface{}) string { + for _, key := range []string{"content_with_weight", "content_ltks", "content"} { + if value, ok := chunk[key].(string); ok && strings.TrimSpace(value) != "" { + return strings.TrimSpace(value) + } + } + return "" +} + +func injectKnowledge(messages []map[string]interface{}, knowledge string) []map[string]interface{} { + copied := copyMessages(messages) + if len(copied) == 0 { + return copied + } + + knowledgePrompt := fmt.Sprintf("Use the following knowledge snippets to answer the user's question. If the snippets do not contain the answer, say that the knowledge base does not provide enough information.\n\n%s", knowledge) + for i := len(copied) - 1; i >= 0; i-- { + if role, _ := copied[i]["role"].(string); role != "user" { + continue + } + copied[i]["content"] = injectKnowledgeIntoContent(copied[i]["content"], knowledgePrompt) + return copied + } + + copied = append(copied, map[string]interface{}{ + "role": "system", + "content": knowledgePrompt, + }) + return copied +} + +func injectKnowledgeIntoContent(content interface{}, knowledgePrompt string) interface{} { + switch typed := content.(type) { + case []interface{}: + injected := make([]interface{}, 0, len(typed)+1) + injected = append(injected, knowledgeTextBlock(knowledgePrompt)) + injected = append(injected, typed...) + return injected + case []map[string]interface{}: + injected := make([]interface{}, 0, len(typed)+1) + injected = append(injected, knowledgeTextBlock(knowledgePrompt)) + for _, block := range typed { + injected = append(injected, block) + } + return injected + default: + contentText := "" + if content != nil { + contentText = fmt.Sprint(content) + } + return strings.TrimSpace(knowledgePrompt + "\n\nQuestion:\n" + contentText) + } +} + +func knowledgeTextBlock(knowledgePrompt string) map[string]interface{} { + return map[string]interface{}{ + "type": "text", + "text": knowledgePrompt + "\n\nQuestion:", + } +} + +func textFromMessageContent(content interface{}) string { + switch typed := content.(type) { + case string: + return strings.TrimSpace(typed) + case []interface{}: + return strings.TrimSpace(strings.Join(textsFromContentBlocks(typed), "\n")) + case []map[string]interface{}: + blocks := make([]interface{}, 0, len(typed)) + for _, block := range typed { + blocks = append(blocks, block) + } + return strings.TrimSpace(strings.Join(textsFromContentBlocks(blocks), "\n")) + default: + if content == nil { + return "" + } + return strings.TrimSpace(fmt.Sprint(content)) + } +} + +func textsFromContentBlocks(blocks []interface{}) []string { + texts := make([]string, 0, len(blocks)) + for _, block := range blocks { + switch typed := block.(type) { + case string: + if text := strings.TrimSpace(typed); text != "" { + texts = append(texts, text) + } + case map[string]interface{}: + if text, ok := typed["text"].(string); ok && strings.TrimSpace(text) != "" { + texts = append(texts, strings.TrimSpace(text)) + } + } + } + return texts +} + +func dialogWithInjectedKnowledgePrompt(dialog *entity.Chat, knowledge string) (*entity.Chat, bool) { + if dialog.PromptConfig == nil { + return dialog, false + } + systemPrompt, ok := dialog.PromptConfig["system"].(string) + if !ok || !strings.Contains(systemPrompt, "{knowledge}") { + return dialog, false + } + + copied := cloneJSONMap(dialog.PromptConfig) + copied["system"] = strings.ReplaceAll(systemPrompt, "{knowledge}", knowledge) + dialogCopy := *dialog + dialogCopy.PromptConfig = copied + return &dialogCopy, true +} + +func cloneJSONMap(values entity.JSONMap) entity.JSONMap { + copied := make(entity.JSONMap, len(values)) + for key, value := range values { + copied[key] = value + } + return copied +} + +func copyMessages(messages []map[string]interface{}) []map[string]interface{} { + copied := make([]map[string]interface{}, len(messages)) + for i, msg := range messages { + copied[i] = make(map[string]interface{}, len(msg)) + for key, value := range msg { + copied[i][key] = value + } + } + return copied +} + +func setLatestReference(reference []interface{}, chunks []map[string]interface{}, docAggs []map[string]interface{}) { + ref := map[string]interface{}{ + "chunks": chunksForReference(chunks), + "doc_aggs": mapsForReference(docAggs), + } + if len(reference) == 0 { + return + } + reference[len(reference)-1] = ref +} + +func chunksForReference(chunks []map[string]interface{}) []interface{} { + result := make([]interface{}, 0, len(chunks)) + for _, chunk := range chunks { + copied := make(map[string]interface{}, len(chunk)) + for key, value := range chunk { + if key == "vector" { + continue + } + copied[key] = value + } + result = append(result, copied) + } + return result +} + +func mapsForReference(values []map[string]interface{}) []interface{} { + result := make([]interface{}, 0, len(values)) + for _, value := range values { + result = append(result, value) + } + return result +} + // asyncChatSolo performs simple chat without RAG (non-streaming) func (s *ChatSessionService) asyncChatSolo(dialog *entity.Chat, session *entity.ChatSession, messages []map[string]interface{}, config map[string]interface{}, messageID string, reference []interface{}, stream bool) (map[string]interface{}, error) { common.Info("asyncChatSolo started", @@ -765,7 +1396,8 @@ func (s *ChatSessionService) structureAnswerWithConv(session *entity.ChatSession if ans["final"] == true && ans["answer"] != nil { lastMsg["content"] = ans["answer"] } else { - lastMsg["content"] = (lastMsg["content"].(string)) + content + existing, _ := lastMsg["content"].(string) + lastMsg["content"] = existing + content } lastMsg["created_at"] = float64(time.Now().Unix()) lastMsg["id"] = messageID diff --git a/internal/service/chat_session_test.go b/internal/service/chat_session_test.go new file mode 100644 index 0000000000..98d0c25c24 --- /dev/null +++ b/internal/service/chat_session_test.go @@ -0,0 +1,949 @@ +package service + +import ( + "context" + "encoding/json" + "errors" + "strings" + "testing" + + "ragflow/internal/common" + "ragflow/internal/engine/types" + "ragflow/internal/entity" + modelModule "ragflow/internal/entity/models" + "ragflow/internal/service/nlp" +) + +type fakeChatKBStore struct { + kbs []*entity.Knowledgebase + accessible map[string]bool +} + +func (f fakeChatKBStore) Accessible(kbID, userID string) bool { + if f.accessible == nil { + return true + } + return f.accessible[kbID] +} + +func (f fakeChatKBStore) GetByIDs(ids []string) ([]*entity.Knowledgebase, error) { + return f.kbs, nil +} + +type fakeChatMetadataService struct{} + +func (fakeChatMetadataService) LabelQuestion(question string, kbs []*entity.Knowledgebase) map[string]float64 { + return map[string]float64{"pagerank_fea": 10} +} + +func (fakeChatMetadataService) GetFlattedMetaByKBs(kbIDs []string) (common.MetaData, error) { + return common.MetaData{ + "category": common.MetaValueDocs{ + "policy": []string{"doc-policy"}, + }, + }, nil +} + +type failingChatMetadataService struct{} + +func (failingChatMetadataService) LabelQuestion(question string, kbs []*entity.Knowledgebase) map[string]float64 { + return nil +} + +func (failingChatMetadataService) GetFlattedMetaByKBs(kbIDs []string) (common.MetaData, error) { + return nil, errors.New("metadata unavailable") +} + +type fakeChatDocEngine struct { + chunk map[string]interface{} +} + +func (f fakeChatDocEngine) CreateChunkStore(ctx context.Context, baseName, datasetID string, vectorSize int, parserID string) error { + return nil +} + +func (f fakeChatDocEngine) InsertChunks(ctx context.Context, chunks []map[string]interface{}, baseName string, datasetID string) ([]string, error) { + return nil, nil +} + +func (f fakeChatDocEngine) UpdateChunks(ctx context.Context, condition map[string]interface{}, newValue map[string]interface{}, baseName string, datasetID string) error { + return nil +} + +func (f fakeChatDocEngine) DeleteChunks(ctx context.Context, condition map[string]interface{}, baseName string, datasetID string) (int64, error) { + return 0, nil +} + +func (f fakeChatDocEngine) Search(ctx context.Context, req *types.SearchRequest) (*types.SearchResult, error) { + return nil, nil +} + +func (f fakeChatDocEngine) GetChunk(ctx context.Context, baseName, chunkID string, datasetIDs []string) (interface{}, error) { + return f.chunk, nil +} + +func (f fakeChatDocEngine) DropChunkStore(ctx context.Context, baseName, datasetID string) error { + return nil +} + +func (f fakeChatDocEngine) ChunkStoreExists(ctx context.Context, baseName, datasetID string) (bool, error) { + return true, nil +} + +func (f fakeChatDocEngine) CreateMetadataStore(ctx context.Context, tenantID string) error { + return nil +} + +func (f fakeChatDocEngine) InsertMetadata(ctx context.Context, metadata []map[string]interface{}, tenantID string) ([]string, error) { + return nil, nil +} + +func (f fakeChatDocEngine) UpdateMetadata(ctx context.Context, docID string, datasetID string, metaFields map[string]interface{}, tenantID string) error { + return nil +} + +func (f fakeChatDocEngine) DeleteMetadata(ctx context.Context, condition map[string]interface{}, tenantID string) (int64, error) { + return 0, nil +} + +func (f fakeChatDocEngine) DeleteMetadataKeys(ctx context.Context, docID string, datasetID string, keys []string, tenantID string) error { + return nil +} + +func (f fakeChatDocEngine) DropMetadataStore(ctx context.Context, tenantID string) error { + return nil +} + +func (f fakeChatDocEngine) MetadataStoreExists(ctx context.Context, tenantID string) (bool, error) { + return true, nil +} + +func (f fakeChatDocEngine) SearchMetadata(ctx context.Context, req *types.SearchMetadataRequest) (*types.SearchMetadataResult, error) { + return nil, nil +} + +func (f fakeChatDocEngine) IndexDocument(ctx context.Context, indexName, docID string, doc interface{}) error { + return nil +} + +func (f fakeChatDocEngine) DeleteDocument(ctx context.Context, indexName, docID string) error { + return nil +} + +func (f fakeChatDocEngine) BulkIndex(ctx context.Context, indexName string, docs []interface{}) (interface{}, error) { + return nil, nil +} + +func (f fakeChatDocEngine) GetFields(chunks []map[string]interface{}, fields []string) map[string]map[string]interface{} { + return nil +} + +func (f fakeChatDocEngine) GetAggregation(chunks []map[string]interface{}, fieldName string) []map[string]interface{} { + return nil +} + +func (f fakeChatDocEngine) GetHighlight(chunks []map[string]interface{}, keywords []string, fieldName string) map[string]string { + return nil +} + +func (f fakeChatDocEngine) GetChunkIDs(chunks []map[string]interface{}) []string { + return nil +} + +func (f fakeChatDocEngine) KNNScores(ctx context.Context, chunks []map[string]interface{}, queryVector []float64, topK int) (map[string]interface{}, error) { + return nil, nil +} + +func (f fakeChatDocEngine) GetScores(searchResult map[string]interface{}) map[string]float64 { + return nil +} + +func (f fakeChatDocEngine) FilterDocIdsByMetaPushdown(ctx context.Context, kbIDs []string, conditions []map[string]interface{}, logic string) []string { + return nil +} + +func (f fakeChatDocEngine) Ping(ctx context.Context) error { + return nil +} + +func (f fakeChatDocEngine) Close() error { + return nil +} + +func (f fakeChatDocEngine) GetType() string { + return "fake" +} + +type fakeChatRetrievalService struct { + req *nlp.RetrievalRequest + result *nlp.RetrievalResult +} + +func (f *fakeChatRetrievalService) Retrieval(ctx context.Context, req *nlp.RetrievalRequest) (*nlp.RetrievalResult, error) { + f.req = req + return f.result, nil +} + +type fakeChatModelProvider struct { + driver *fakeChatModelDriver +} + +func (f fakeChatModelProvider) GetChatModel(tenantID, compositeModelName string) (*modelModule.ChatModel, error) { + modelName := compositeModelName + return modelModule.NewChatModel(f.driver, &modelName, &modelModule.APIConfig{}), nil +} + +func (f fakeChatModelProvider) GetEmbeddingModel(tenantID, compositeModelName string) (*modelModule.EmbeddingModel, error) { + modelName := compositeModelName + return modelModule.NewEmbeddingModel(f.driver, &modelName, &modelModule.APIConfig{}, 512), nil +} + +func (f fakeChatModelProvider) GetRerankModel(tenantID, compositeModelName string) (*modelModule.RerankModel, error) { + modelName := compositeModelName + return modelModule.NewRerankModel(f.driver, &modelName, &modelModule.APIConfig{}), nil +} + +func (f fakeChatModelProvider) GetModelConfigFromProviderInstance(tenantID string, modelType entity.ModelType, modelName string) (modelModule.ModelDriver, string, *modelModule.APIConfig, int, error) { + return f.driver, modelName, &modelModule.APIConfig{}, 0, nil +} + +func (f fakeChatModelProvider) GetTenantDefaultModelByType(tenantID string, modelType entity.ModelType) (modelModule.ModelDriver, string, *modelModule.APIConfig, int, error) { + modelName := "default@factory" + return f.driver, modelName, &modelModule.APIConfig{}, 0, nil +} + +type fakeChatModelDriver struct { + messages []modelModule.Message +} + +func (f *fakeChatModelDriver) NewInstance(baseURL map[string]string) modelModule.ModelDriver { + return f +} + +func (f *fakeChatModelDriver) Name() string { + return "fake" +} + +func (f *fakeChatModelDriver) ChatWithMessages(modelName string, messages []modelModule.Message, apiConfig *modelModule.APIConfig, chatModelConfig *modelModule.ChatConfig) (*modelModule.ChatResponse, error) { + f.messages = messages + answer := "answer from knowledge" + return &modelModule.ChatResponse{Answer: &answer}, nil +} + +func (f *fakeChatModelDriver) ChatStreamlyWithSender(modelName string, messages []modelModule.Message, apiConfig *modelModule.APIConfig, modelConfig *modelModule.ChatConfig, sender func(*string, *string) error) error { + f.messages = messages + answer := "stream answer from knowledge" + return sender(&answer, nil) +} + +func (f *fakeChatModelDriver) Embed(modelName *string, texts []string, apiConfig *modelModule.APIConfig, embeddingConfig *modelModule.EmbeddingConfig) ([]modelModule.EmbeddingData, error) { + return nil, nil +} + +func (f *fakeChatModelDriver) Rerank(modelName *string, query string, documents []string, apiConfig *modelModule.APIConfig, rerankConfig *modelModule.RerankConfig) (*modelModule.RerankResponse, error) { + return nil, nil +} + +func (f *fakeChatModelDriver) TranscribeAudio(modelName *string, file *string, apiConfig *modelModule.APIConfig, asrConfig *modelModule.ASRConfig) (*modelModule.ASRResponse, error) { + return nil, nil +} + +func (f *fakeChatModelDriver) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *modelModule.APIConfig, asrConfig *modelModule.ASRConfig, sender func(*string, *string) error) error { + return nil +} + +func (f *fakeChatModelDriver) AudioSpeech(modelName *string, audioContent *string, apiConfig *modelModule.APIConfig, ttsConfig *modelModule.TTSConfig) (*modelModule.TTSResponse, error) { + return nil, nil +} + +func (f *fakeChatModelDriver) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *modelModule.APIConfig, ttsConfig *modelModule.TTSConfig, sender func(*string, *string) error) error { + return nil +} + +func (f *fakeChatModelDriver) OCRFile(modelName *string, content []byte, url *string, apiConfig *modelModule.APIConfig, ocrConfig *modelModule.OCRConfig) (*modelModule.OCRFileResponse, error) { + return nil, nil +} + +func (f *fakeChatModelDriver) ParseFile(modelName *string, content []byte, url *string, apiConfig *modelModule.APIConfig, parseFileConfig *modelModule.ParseFileConfig) (*modelModule.ParseFileResponse, error) { + return nil, nil +} + +func (f *fakeChatModelDriver) ListModels(apiConfig *modelModule.APIConfig) ([]modelModule.ListModelResponse, error) { + return nil, nil +} + +func (f *fakeChatModelDriver) Balance(apiConfig *modelModule.APIConfig) (map[string]interface{}, error) { + return nil, nil +} + +func (f *fakeChatModelDriver) CheckConnection(apiConfig *modelModule.APIConfig) error { + return nil +} + +func (f *fakeChatModelDriver) ListTasks(apiConfig *modelModule.APIConfig) ([]modelModule.ListTaskStatus, error) { + return nil, nil +} + +func (f *fakeChatModelDriver) ShowTask(taskID string, apiConfig *modelModule.APIConfig) (*modelModule.TaskResponse, error) { + return nil, nil +} + +func TestAsyncChatUsesRetrievedKnowledgeForKBDialog(t *testing.T) { + driver := &fakeChatModelDriver{} + retrieval := &fakeChatRetrievalService{ + result: &nlp.RetrievalResult{ + Chunks: []map[string]interface{}{ + { + "chunk_id": "chunk-1", + "content_with_weight": "RAGFlow stores conversation references alongside the session.", + "doc_id": "doc-1", + "docnm_kwd": "manual.md", + "vector": []float64{0.1, 0.2}, + }, + }, + DocAggs: []map[string]interface{}{ + {"doc_id": "doc-1", "doc_name": "manual.md", "count": 1}, + }, + }, + } + svc := &ChatSessionService{ + kbDAO: fakeChatKBStore{kbs: []*entity.Knowledgebase{ + {ID: "kb-1", TenantID: "tenant-1", Name: "Manual", EmbdID: "embed@factory"}, + }}, + modelProviderSvc: fakeChatModelProvider{driver: driver}, + metadataSvc: fakeChatMetadataService{}, + retrievalSvc: retrieval, + } + + reference := []interface{}{map[string]interface{}{"chunks": []interface{}{}, "doc_aggs": []interface{}{}}} + sessionMessage, err := json.Marshal(map[string]interface{}{"messages": []interface{}{}}) + if err != nil { + t.Fatalf("failed to marshal session message: %v", err) + } + session := &entity.ChatSession{ID: "session-1", Message: sessionMessage} + dialog := &entity.Chat{ + ID: "dialog-1", + TenantID: "tenant-1", + LLMID: "chat@factory", + PromptConfig: entity.JSONMap{"system": "You are helpful."}, + LLMSetting: entity.JSONMap{}, + KBIDs: entity.JSONSlice{"kb-1"}, + TopN: 3, + TopK: 32, + SimilarityThreshold: 0.2, + VectorSimilarityWeight: 0.3, + } + + result, err := svc.asyncChat("user-1", dialog, session, []map[string]interface{}{ + {"role": "user", "content": "Where are references stored?"}, + }, nil, "message-1", reference, false) + if err != nil { + t.Fatalf("asyncChat returned error: %v", err) + } + + if retrieval.req == nil { + t.Fatal("expected retrieval service to be called") + } + if retrieval.req.Question != "Where are references stored?" { + t.Fatalf("unexpected retrieval question: %q", retrieval.req.Question) + } + if retrieval.req.PageSize != 3 || retrieval.req.Top == nil || *retrieval.req.Top != 32 { + t.Fatalf("unexpected retrieval paging: page_size=%d top=%v", retrieval.req.PageSize, retrieval.req.Top) + } + if len(driver.messages) == 0 { + t.Fatal("expected chat model to receive messages") + } + last := driver.messages[len(driver.messages)-1] + content, ok := last.Content.(string) + if !ok { + t.Fatalf("expected string content, got %T", last.Content) + } + if !strings.Contains(content, "RAGFlow stores conversation references") { + t.Fatalf("expected retrieved content in prompt, got %q", content) + } + + ref, ok := result["reference"].(map[string]interface{}) + if !ok { + t.Fatalf("expected reference map, got %T", result["reference"]) + } + chunks, ok := ref["chunks"].([]interface{}) + if !ok || len(chunks) != 1 { + t.Fatalf("expected one reference chunk, got %#v", ref["chunks"]) + } + chunk, ok := chunks[0].(map[string]interface{}) + if !ok { + t.Fatalf("expected chunk map, got %T", chunks[0]) + } + if _, exists := chunk["vector"]; exists { + t.Fatal("reference chunk should not expose vector") + } + if result["answer"] != "answer from knowledge" { + t.Fatalf("unexpected answer: %#v", result["answer"]) + } +} + +func TestAsyncChatPropagatesRetrievalErrors(t *testing.T) { + retrievalErr := errors.New("search unavailable") + retrieval := &failingChatRetrievalService{err: retrievalErr} + svc := &ChatSessionService{ + kbDAO: fakeChatKBStore{kbs: []*entity.Knowledgebase{ + {ID: "kb-1", TenantID: "tenant-1", Name: "Manual", EmbdID: "embed@factory"}, + }}, + modelProviderSvc: fakeChatModelProvider{driver: &fakeChatModelDriver{}}, + metadataSvc: fakeChatMetadataService{}, + retrievalSvc: retrieval, + } + + _, err := svc.asyncChat("user-1", &entity.Chat{ + ID: "dialog-1", + TenantID: "tenant-1", + LLMID: "chat@factory", + PromptConfig: entity.JSONMap{}, + LLMSetting: entity.JSONMap{}, + KBIDs: entity.JSONSlice{"kb-1"}, + TopN: 3, + TopK: 32, + SimilarityThreshold: 0.2, + VectorSimilarityWeight: 0.3, + }, &entity.ChatSession{ID: "session-1"}, []map[string]interface{}{ + {"role": "user", "content": "question"}, + }, nil, "message-1", []interface{}{map[string]interface{}{"chunks": []interface{}{}, "doc_aggs": []interface{}{}}}, false) + if err == nil || !strings.Contains(err.Error(), "retrieval search failed") { + t.Fatalf("expected retrieval error, got %v", err) + } +} + +func TestMessagesWithRetrievedKnowledgeFillsSystemPlaceholder(t *testing.T) { + retrieval := &fakeChatRetrievalService{ + result: &nlp.RetrievalResult{ + Chunks: []map[string]interface{}{ + {"content_with_weight": "Knowledge inserted into the system prompt."}, + }, + }, + } + svc := &ChatSessionService{ + kbDAO: fakeChatKBStore{kbs: []*entity.Knowledgebase{ + {ID: "kb-1", TenantID: "tenant-1", Name: "Manual", EmbdID: "embed@factory"}, + }}, + modelProviderSvc: fakeChatModelProvider{driver: &fakeChatModelDriver{}}, + metadataSvc: fakeChatMetadataService{}, + retrievalSvc: retrieval, + } + dialog := &entity.Chat{ + ID: "dialog-1", + TenantID: "tenant-1", + PromptConfig: entity.JSONMap{"system": "Answer from this context: {knowledge}"}, + KBIDs: entity.JSONSlice{"kb-1"}, + TopN: 3, + TopK: 32, + } + messages := []map[string]interface{}{ + {"role": "user", "content": "What context is available?"}, + } + + got, ragDialog, emptyResponse, err := svc.messagesWithRetrievedKnowledge(context.Background(), "user-1", dialog, messages, []interface{}{ + map[string]interface{}{"chunks": []interface{}{}, "doc_aggs": []interface{}{}}, + }) + if err != nil { + t.Fatalf("messagesWithRetrievedKnowledge returned error: %v", err) + } + if emptyResponse != nil { + t.Fatalf("expected no empty response, got %q", *emptyResponse) + } + if got[0]["content"] != "What context is available?" { + t.Fatalf("expected user content to stay unchanged, got %q", got[0]["content"]) + } + originalPrompt, _ := dialog.PromptConfig["system"].(string) + if !strings.Contains(originalPrompt, "{knowledge}") { + t.Fatalf("expected original dialog prompt to remain unchanged, got %q", originalPrompt) + } + systemPrompt, _ := ragDialog.PromptConfig["system"].(string) + if strings.Contains(systemPrompt, "{knowledge}") { + t.Fatalf("expected knowledge placeholder to be replaced, got %q", systemPrompt) + } + if !strings.Contains(systemPrompt, "Knowledge inserted into the system prompt.") { + t.Fatalf("expected retrieved knowledge in system prompt, got %q", systemPrompt) + } +} + +func TestAsyncChatReturnsEmptyResponseWhenRetrievalHasNoKnowledge(t *testing.T) { + driver := &fakeChatModelDriver{} + svc := &ChatSessionService{ + kbDAO: fakeChatKBStore{kbs: []*entity.Knowledgebase{ + {ID: "kb-1", TenantID: "tenant-1", Name: "Manual", EmbdID: "embed@factory"}, + }}, + modelProviderSvc: fakeChatModelProvider{driver: driver}, + metadataSvc: fakeChatMetadataService{}, + retrievalSvc: &fakeChatRetrievalService{result: &nlp.RetrievalResult{}}, + } + reference := []interface{}{map[string]interface{}{"chunks": []interface{}{}, "doc_aggs": []interface{}{}}} + sessionMessage, err := json.Marshal(map[string]interface{}{"messages": []interface{}{}}) + if err != nil { + t.Fatalf("failed to marshal session message: %v", err) + } + result, err := svc.asyncChat("user-1", &entity.Chat{ + ID: "dialog-1", + TenantID: "tenant-1", + LLMID: "chat@factory", + PromptConfig: entity.JSONMap{"empty_response": "No relevant content."}, + LLMSetting: entity.JSONMap{}, + KBIDs: entity.JSONSlice{"kb-1"}, + TopN: 3, + TopK: 32, + }, &entity.ChatSession{ID: "session-1", Message: sessionMessage}, []map[string]interface{}{ + {"role": "user", "content": "question"}, + }, nil, "message-1", reference, false) + if err != nil { + t.Fatalf("asyncChat returned error: %v", err) + } + if result["answer"] != "No relevant content." { + t.Fatalf("unexpected empty response answer: %#v", result["answer"]) + } + if len(driver.messages) != 0 { + t.Fatal("chat model should not be called when empty_response is returned") + } +} + +func TestMessagesWithRetrievedKnowledgeAppliesMetadataFilter(t *testing.T) { + retrieval := &fakeChatRetrievalService{result: &nlp.RetrievalResult{}} + svc := &ChatSessionService{ + kbDAO: fakeChatKBStore{kbs: []*entity.Knowledgebase{ + {ID: "kb-1", TenantID: "tenant-1", Name: "Manual", EmbdID: "embed@factory"}, + }}, + modelProviderSvc: fakeChatModelProvider{driver: &fakeChatModelDriver{}}, + metadataSvc: fakeChatMetadataService{}, + retrievalSvc: retrieval, + } + filter := entity.JSONMap{ + "method": "manual", + "manual": []interface{}{ + map[string]interface{}{"key": "category", "op": "=", "value": "policy"}, + }, + "logic": "and", + } + _, _, _, err := svc.messagesWithRetrievedKnowledge(context.Background(), "user-1", &entity.Chat{ + ID: "dialog-1", + TenantID: "tenant-1", + PromptConfig: entity.JSONMap{}, + MetaDataFilter: &filter, + KBIDs: entity.JSONSlice{"kb-1"}, + TopN: 3, + TopK: 32, + }, []map[string]interface{}{ + {"role": "user", "content": "question"}, + }, []interface{}{map[string]interface{}{"chunks": []interface{}{}, "doc_aggs": []interface{}{}}}) + if err != nil { + t.Fatalf("messagesWithRetrievedKnowledge returned error: %v", err) + } + if retrieval.req == nil { + t.Fatal("expected retrieval to be called") + } + if len(retrieval.req.DocIDs) != 1 || retrieval.req.DocIDs[0] != "doc-policy" { + t.Fatalf("expected metadata-filtered doc id, got %#v", retrieval.req.DocIDs) + } +} + +func TestMessagesWithRetrievedKnowledgeIntersectsDocIDsWithMetadataFilter(t *testing.T) { + retrieval := &fakeChatRetrievalService{result: &nlp.RetrievalResult{}} + svc := &ChatSessionService{ + kbDAO: fakeChatKBStore{kbs: []*entity.Knowledgebase{ + {ID: "kb-1", TenantID: "tenant-1", Name: "Manual", EmbdID: "embed@factory"}, + }}, + modelProviderSvc: fakeChatModelProvider{driver: &fakeChatModelDriver{}}, + metadataSvc: fakeChatMetadataService{}, + retrievalSvc: retrieval, + } + filter := entity.JSONMap{ + "method": "manual", + "manual": []interface{}{ + map[string]interface{}{"key": "category", "op": "=", "value": "policy"}, + }, + "logic": "and", + } + + _, _, _, err := svc.messagesWithRetrievedKnowledge(context.Background(), "user-1", &entity.Chat{ + ID: "dialog-1", + TenantID: "tenant-1", + PromptConfig: entity.JSONMap{}, + MetaDataFilter: &filter, + KBIDs: entity.JSONSlice{"kb-1"}, + TopN: 3, + TopK: 32, + }, []map[string]interface{}{ + {"role": "user", "content": "question", "doc_ids": []interface{}{"doc-explicit", "doc-policy"}}, + }, []interface{}{map[string]interface{}{"chunks": []interface{}{}, "doc_aggs": []interface{}{}}}) + if err != nil { + t.Fatalf("messagesWithRetrievedKnowledge returned error: %v", err) + } + if len(retrieval.req.DocIDs) != 1 || retrieval.req.DocIDs[0] != "doc-policy" { + t.Fatalf("expected metadata and message doc_ids intersection, got %#v", retrieval.req.DocIDs) + } +} + +func TestMessagesWithRetrievedKnowledgeNoMetadataIntersectionUsesSentinel(t *testing.T) { + retrieval := &fakeChatRetrievalService{result: &nlp.RetrievalResult{}} + svc := &ChatSessionService{ + kbDAO: fakeChatKBStore{kbs: []*entity.Knowledgebase{ + {ID: "kb-1", TenantID: "tenant-1", Name: "Manual", EmbdID: "embed@factory"}, + }}, + modelProviderSvc: fakeChatModelProvider{driver: &fakeChatModelDriver{}}, + metadataSvc: fakeChatMetadataService{}, + retrievalSvc: retrieval, + } + filter := entity.JSONMap{ + "method": "manual", + "manual": []interface{}{ + map[string]interface{}{"key": "category", "op": "=", "value": "policy"}, + }, + "logic": "and", + } + + _, _, _, err := svc.messagesWithRetrievedKnowledge(context.Background(), "user-1", &entity.Chat{ + ID: "dialog-1", + TenantID: "tenant-1", + PromptConfig: entity.JSONMap{}, + MetaDataFilter: &filter, + KBIDs: entity.JSONSlice{"kb-1"}, + TopN: 3, + TopK: 32, + }, []map[string]interface{}{ + {"role": "user", "content": "question", "doc_ids": []interface{}{"doc-explicit"}}, + }, []interface{}{map[string]interface{}{"chunks": []interface{}{}, "doc_aggs": []interface{}{}}}) + if err != nil { + t.Fatalf("messagesWithRetrievedKnowledge returned error: %v", err) + } + if len(retrieval.req.DocIDs) != 1 || retrieval.req.DocIDs[0] != NoMatchDocIDSentinel { + t.Fatalf("expected empty metadata/doc_ids intersection sentinel, got %#v", retrieval.req.DocIDs) + } +} + +func TestMessagesWithRetrievedKnowledgePreservesEmptyMetadataFilterMatches(t *testing.T) { + retrieval := &fakeChatRetrievalService{result: &nlp.RetrievalResult{}} + svc := &ChatSessionService{ + kbDAO: fakeChatKBStore{kbs: []*entity.Knowledgebase{ + {ID: "kb-1", TenantID: "tenant-1", Name: "Manual", EmbdID: "embed@factory"}, + }}, + modelProviderSvc: fakeChatModelProvider{driver: &fakeChatModelDriver{}}, + metadataSvc: fakeChatMetadataService{}, + retrievalSvc: retrieval, + } + filter := entity.JSONMap{"method": "auto"} + + _, _, _, err := svc.messagesWithRetrievedKnowledge(context.Background(), "user-1", &entity.Chat{ + ID: "dialog-1", + TenantID: "tenant-1", + LLMID: "chat@factory", + PromptConfig: entity.JSONMap{}, + MetaDataFilter: &filter, + KBIDs: entity.JSONSlice{"kb-1"}, + TopN: 3, + TopK: 32, + }, []map[string]interface{}{ + {"role": "user", "content": "question"}, + }, []interface{}{map[string]interface{}{"chunks": []interface{}{}, "doc_aggs": []interface{}{}}}) + if err != nil { + t.Fatalf("messagesWithRetrievedKnowledge returned error: %v", err) + } + if len(retrieval.req.DocIDs) != 1 || retrieval.req.DocIDs[0] != NoMatchDocIDSentinel { + t.Fatalf("expected empty metadata filter sentinel, got %#v", retrieval.req.DocIDs) + } +} + +func TestMessagesWithRetrievedKnowledgeFailsClosedWhenMetadataUnavailable(t *testing.T) { + retrieval := &fakeChatRetrievalService{result: &nlp.RetrievalResult{}} + svc := &ChatSessionService{ + kbDAO: fakeChatKBStore{kbs: []*entity.Knowledgebase{ + {ID: "kb-1", TenantID: "tenant-1", Name: "Manual", EmbdID: "embed@factory"}, + }}, + modelProviderSvc: fakeChatModelProvider{driver: &fakeChatModelDriver{}}, + metadataSvc: failingChatMetadataService{}, + retrievalSvc: retrieval, + } + filter := entity.JSONMap{ + "method": "manual", + "manual": []interface{}{ + map[string]interface{}{"key": "category", "op": "=", "value": "policy"}, + }, + "logic": "and", + } + + _, _, _, err := svc.messagesWithRetrievedKnowledge(context.Background(), "user-1", &entity.Chat{ + ID: "dialog-1", + TenantID: "tenant-1", + PromptConfig: entity.JSONMap{}, + MetaDataFilter: &filter, + KBIDs: entity.JSONSlice{"kb-1"}, + TopN: 3, + TopK: 32, + }, []map[string]interface{}{ + {"role": "user", "content": "question", "doc_ids": []interface{}{"doc-explicit"}}, + }, []interface{}{map[string]interface{}{"chunks": []interface{}{}, "doc_aggs": []interface{}{}}}) + if err == nil || !strings.Contains(err.Error(), "flattened metadata") { + t.Fatalf("expected metadata filter error, got %v", err) + } + if retrieval.req != nil { + t.Fatal("retrieval should not run when metadata filtering cannot be evaluated") + } +} + +func TestMessagesWithRetrievedKnowledgeExpandsChildChunks(t *testing.T) { + retrieval := &fakeChatRetrievalService{result: &nlp.RetrievalResult{ + Chunks: []map[string]interface{}{ + { + "chunk_id": "child-1", + "mom_id": "parent-1", + "kb_id": "kb-1", + "doc_id": "doc-1", + "docnm_kwd": "doc.md", + "content_ltks": "child tokens", + "content_with_weight": "child-only passage", + "similarity": 0.8, + }, + }, + }} + svc := &ChatSessionService{ + kbDAO: fakeChatKBStore{kbs: []*entity.Knowledgebase{ + {ID: "kb-1", TenantID: "tenant-1", Name: "Manual", EmbdID: "embed@factory"}, + }}, + docEngine: fakeChatDocEngine{chunk: map[string]interface{}{ + "doc_id": "doc-1", + "docnm_kwd": "doc.md", + "kb_id": "kb-1", + "content_with_weight": "parent passage with surrounding context", + "position_int": []interface{}{1}, + }}, + modelProviderSvc: fakeChatModelProvider{driver: &fakeChatModelDriver{}}, + metadataSvc: fakeChatMetadataService{}, + retrievalSvc: retrieval, + } + + ragMessages, _, _, err := svc.messagesWithRetrievedKnowledge(context.Background(), "user-1", &entity.Chat{ + ID: "dialog-1", + TenantID: "tenant-1", + PromptConfig: entity.JSONMap{}, + KBIDs: entity.JSONSlice{"kb-1"}, + TopN: 3, + TopK: 32, + }, []map[string]interface{}{ + {"role": "user", "content": "question"}, + }, []interface{}{map[string]interface{}{"chunks": []interface{}{}, "doc_aggs": []interface{}{}}}) + if err != nil { + t.Fatalf("messagesWithRetrievedKnowledge returned error: %v", err) + } + content, _ := ragMessages[0]["content"].(string) + if !strings.Contains(content, "parent passage with surrounding context") { + t.Fatalf("expected expanded parent content in prompt, got %q", content) + } + if strings.Contains(content, "child-only passage") { + t.Fatalf("expected child content to be replaced by expanded parent content, got %q", content) + } +} + +func TestMessagesWithRetrievedKnowledgeRejectsCrossTenantKnowledgebase(t *testing.T) { + retrieval := &fakeChatRetrievalService{result: &nlp.RetrievalResult{}} + svc := &ChatSessionService{ + kbDAO: fakeChatKBStore{ + kbs: []*entity.Knowledgebase{ + {ID: "kb-1", TenantID: "tenant-2", Name: "Manual", EmbdID: "embed@factory"}, + }, + accessible: map[string]bool{"kb-1": false}, + }, + modelProviderSvc: fakeChatModelProvider{driver: &fakeChatModelDriver{}}, + metadataSvc: fakeChatMetadataService{}, + retrievalSvc: retrieval, + } + + _, _, _, err := svc.messagesWithRetrievedKnowledge(context.Background(), "user-1", &entity.Chat{ + ID: "dialog-1", + TenantID: "tenant-1", + PromptConfig: entity.JSONMap{}, + KBIDs: entity.JSONSlice{"kb-1"}, + TopN: 3, + TopK: 32, + }, []map[string]interface{}{ + {"role": "user", "content": "question"}, + }, []interface{}{map[string]interface{}{"chunks": []interface{}{}, "doc_aggs": []interface{}{}}}) + if err == nil || !strings.Contains(err.Error(), "not authorized") { + t.Fatalf("expected cross-tenant authorization error, got %v", err) + } + if retrieval.req != nil { + t.Fatal("retrieval should not be called for an unauthorized knowledge base") + } +} + +func TestMessagesWithRetrievedKnowledgeAllowsAccessibleSharedKnowledgebase(t *testing.T) { + retrieval := &fakeChatRetrievalService{result: &nlp.RetrievalResult{}} + svc := &ChatSessionService{ + kbDAO: fakeChatKBStore{ + kbs: []*entity.Knowledgebase{ + {ID: "kb-1", TenantID: "tenant-2", Name: "Shared Manual", EmbdID: "embed@factory"}, + }, + accessible: map[string]bool{"kb-1": true}, + }, + modelProviderSvc: fakeChatModelProvider{driver: &fakeChatModelDriver{}}, + metadataSvc: fakeChatMetadataService{}, + retrievalSvc: retrieval, + } + + _, _, _, err := svc.messagesWithRetrievedKnowledge(context.Background(), "user-1", &entity.Chat{ + ID: "dialog-1", + TenantID: "tenant-1", + PromptConfig: entity.JSONMap{}, + KBIDs: entity.JSONSlice{"kb-1"}, + TopN: 3, + TopK: 32, + }, []map[string]interface{}{ + {"role": "user", "content": "question"}, + }, []interface{}{map[string]interface{}{"chunks": []interface{}{}, "doc_aggs": []interface{}{}}}) + if err != nil { + t.Fatalf("messagesWithRetrievedKnowledge returned error: %v", err) + } + if retrieval.req == nil || len(retrieval.req.TenantIDs) != 1 || retrieval.req.TenantIDs[0] != "tenant-2" { + t.Fatalf("expected retrieval to use shared KB tenant, got %#v", retrieval.req) + } +} + +func TestMessagesWithRetrievedKnowledgeRejectsMixedEmbeddingModels(t *testing.T) { + retrieval := &fakeChatRetrievalService{result: &nlp.RetrievalResult{}} + svc := &ChatSessionService{ + kbDAO: fakeChatKBStore{kbs: []*entity.Knowledgebase{ + {ID: "kb-1", TenantID: "tenant-1", Name: "Manual", EmbdID: "embed-a@factory"}, + {ID: "kb-2", TenantID: "tenant-1", Name: "FAQ", EmbdID: "embed-b@factory"}, + }}, + modelProviderSvc: fakeChatModelProvider{driver: &fakeChatModelDriver{}}, + metadataSvc: fakeChatMetadataService{}, + retrievalSvc: retrieval, + } + + _, _, _, err := svc.messagesWithRetrievedKnowledge(context.Background(), "user-1", &entity.Chat{ + ID: "dialog-1", + TenantID: "tenant-1", + PromptConfig: entity.JSONMap{}, + KBIDs: entity.JSONSlice{"kb-1", "kb-2"}, + TopN: 3, + TopK: 32, + }, []map[string]interface{}{ + {"role": "user", "content": "question"}, + }, []interface{}{map[string]interface{}{"chunks": []interface{}{}, "doc_aggs": []interface{}{}}}) + if err == nil || !strings.Contains(err.Error(), "same embedding model") { + t.Fatalf("expected mixed embedding model error, got %v", err) + } + if retrieval.req != nil { + t.Fatal("retrieval should not run when knowledge bases use different embedding models") + } +} + +func TestValidateKnowledgebaseEmbeddingModelsComparesResolvedNames(t *testing.T) { + firstTenantEmbdID := int64(1) + secondTenantEmbdID := int64(2) + kbs := []*entity.Knowledgebase{ + {ID: "kb-1", TenantID: "tenant-1", Name: "Manual", EmbdID: "same-legacy-name", TenantEmbdID: &firstTenantEmbdID}, + {ID: "kb-2", TenantID: "tenant-1", Name: "FAQ", EmbdID: "same-legacy-name", TenantEmbdID: &secondTenantEmbdID}, + } + resolver := func(tenantID string, kb *entity.Knowledgebase) (string, error) { + if kb.TenantEmbdID != nil && *kb.TenantEmbdID == firstTenantEmbdID { + return "embed-a@factory", nil + } + return "embed-b@factory", nil + } + + _, _, err := validateKnowledgebaseEmbeddingModels(kbs, "tenant-1", resolver) + if err == nil || !strings.Contains(err.Error(), "same embedding model") { + t.Fatalf("expected resolved mixed embedding model error, got %v", err) + } +} + +func TestMessagesWithRetrievedKnowledgePreservesMultimodalContent(t *testing.T) { + retrieval := &fakeChatRetrievalService{ + result: &nlp.RetrievalResult{ + Chunks: []map[string]interface{}{ + {"content_with_weight": "Knowledge for an image question."}, + }, + }, + } + svc := &ChatSessionService{ + kbDAO: fakeChatKBStore{kbs: []*entity.Knowledgebase{ + {ID: "kb-1", TenantID: "tenant-1", Name: "Manual", EmbdID: "embed@factory"}, + }}, + modelProviderSvc: fakeChatModelProvider{driver: &fakeChatModelDriver{}}, + metadataSvc: fakeChatMetadataService{}, + retrievalSvc: retrieval, + } + imageBlock := map[string]interface{}{"type": "image_url", "image_url": map[string]interface{}{"url": "https://example.com/cat.png"}} + messages := []map[string]interface{}{ + {"role": "user", "content": []interface{}{ + map[string]interface{}{"type": "text", "text": "What is in this image?"}, + imageBlock, + }}, + } + + got, _, emptyResponse, err := svc.messagesWithRetrievedKnowledge(context.Background(), "user-1", &entity.Chat{ + ID: "dialog-1", + TenantID: "tenant-1", + PromptConfig: entity.JSONMap{}, + KBIDs: entity.JSONSlice{"kb-1"}, + TopN: 3, + TopK: 32, + }, messages, []interface{}{map[string]interface{}{"chunks": []interface{}{}, "doc_aggs": []interface{}{}}}) + if err != nil { + t.Fatalf("messagesWithRetrievedKnowledge returned error: %v", err) + } + if emptyResponse != nil { + t.Fatalf("expected no empty response, got %q", *emptyResponse) + } + content, ok := got[0]["content"].([]interface{}) + if !ok { + t.Fatalf("expected multimodal content to stay as blocks, got %T", got[0]["content"]) + } + if len(content) != 3 { + t.Fatalf("expected injected text plus original blocks, got %#v", content) + } + injected, ok := content[0].(map[string]interface{}) + if !ok || injected["type"] != "text" || !strings.Contains(injected["text"].(string), "Knowledge for an image question.") { + t.Fatalf("expected injected knowledge text block, got %#v", content[0]) + } + preservedImage, ok := content[2].(map[string]interface{}) + if !ok || preservedImage["type"] != "image_url" { + t.Fatalf("expected original image block to be preserved, got %#v", content[2]) + } + if retrieval.req == nil || retrieval.req.Question != "What is in this image?" { + t.Fatalf("expected retrieval question from text block, got %#v", retrieval.req) + } +} + +func TestMessagesWithRetrievedKnowledgePassesMessageDocIDs(t *testing.T) { + retrieval := &fakeChatRetrievalService{result: &nlp.RetrievalResult{}} + svc := &ChatSessionService{ + kbDAO: fakeChatKBStore{kbs: []*entity.Knowledgebase{ + {ID: "kb-1", TenantID: "tenant-1", Name: "Manual", EmbdID: "embed@factory"}, + }}, + modelProviderSvc: fakeChatModelProvider{driver: &fakeChatModelDriver{}}, + metadataSvc: fakeChatMetadataService{}, + retrievalSvc: retrieval, + } + + _, _, _, err := svc.messagesWithRetrievedKnowledge(context.Background(), "user-1", &entity.Chat{ + ID: "dialog-1", + TenantID: "tenant-1", + PromptConfig: entity.JSONMap{}, + KBIDs: entity.JSONSlice{"kb-1"}, + TopN: 3, + TopK: 32, + }, []map[string]interface{}{ + {"role": "user", "content": "question", "doc_ids": []interface{}{"doc-1", "doc-2", "doc-1"}}, + }, []interface{}{map[string]interface{}{"chunks": []interface{}{}, "doc_aggs": []interface{}{}}}) + if err != nil { + t.Fatalf("messagesWithRetrievedKnowledge returned error: %v", err) + } + if len(retrieval.req.DocIDs) != 2 || retrieval.req.DocIDs[0] != "doc-1" || retrieval.req.DocIDs[1] != "doc-2" { + t.Fatalf("expected scoped doc ids, got %#v", retrieval.req.DocIDs) + } +} + +type failingChatRetrievalService struct { + err error +} + +func (f *failingChatRetrievalService) Retrieval(ctx context.Context, req *nlp.RetrievalRequest) (*nlp.RetrievalResult, error) { + return nil, f.err +} diff --git a/internal/service/metadata_filter.go b/internal/service/metadata_filter.go index 9919f1f571..c5c27bd381 100644 --- a/internal/service/metadata_filter.go +++ b/internal/service/metadata_filter.go @@ -83,6 +83,9 @@ func compareValues(val1, val2, op string) bool { // ManualValueResolver is a callback function to transform manual filter values type ManualValueResolver func(map[string]interface{}) map[string]interface{} +// NoMatchDocIDSentinel forces retrieval to return no documents when filters match nothing. +const NoMatchDocIDSentinel = "-999" + // metaFilterTemplateCache caches the template content var metaFilterTemplateCache string @@ -264,7 +267,8 @@ func ApplyMetaFilter(metaData common.MetaData, filters []MetaFilterCondition, lo // - "==" = "=" "!=" = "≠" // - ">=" = "≥" "<=" = "≤" // - "is" = "=" "not is" = "≠" -// (see common.metadata_utils.operatorMapping for the full list) +// (see common.metadata_utils.operatorMapping for the full list) +// // Value conversion: // - "in" / "not in": comma-separated string → []interface{} (as expected by common.MetaFilter) // - all other operators: passed through as-is (string) @@ -280,11 +284,13 @@ func convertToMetaCondition(f MetaFilterCondition) common.MetaCondition { parts := strings.Split(strVal, ",") arr := make([]interface{}, 0, len(parts)) for _, p := range parts { - if trimmed := strings.TrimSpace(p); trimmed != "" { arr = append(arr, trimmed) } - } - mc.Value = arr - } - return mc + if trimmed := strings.TrimSpace(p); trimmed != "" { + arr = append(arr, trimmed) + } + } + mc.Value = arr + } + return mc } // applySingleCondition applies a single filter condition and returns matching doc IDs @@ -587,9 +593,6 @@ func ApplyMetaDataFilter( return baseDocIDs, false } - docIDs := make([]string, len(baseDocIDs)) - copy(docIDs, baseDocIDs) - method, _ := metaDataFilter["method"].(string) // Helper to run metadata filter with push-down fallback @@ -632,13 +635,14 @@ func ApplyMetaDataFilter( filters, err := GenMetaFilter(ctx, chatModel, metaData, question, nil) if err != nil { common.Warn("Failed to generate meta filter", zap.Error(err)) - return docIDs, false + return baseDocIDs, false } filteredIDs := runMetadataFilter(filters.Conditions, filters.Logic) - docIDs = append(docIDs, filteredIDs...) + docIDs := constrainDocIDs(baseDocIDs, filteredIDs) if len(docIDs) == 0 { return nil, true // Return nil to indicate auto filter returned empty } + return docIDs, false case "semi_auto": selectedKeys := []string{} @@ -673,13 +677,14 @@ func ApplyMetaDataFilter( filters, err := GenMetaFilter(ctx, chatModel, filteredMeta, question, constraints) if err != nil { common.Warn("Failed to generate meta filter", zap.Error(err)) - return docIDs, false + return baseDocIDs, false } filteredIDs := runMetadataFilter(filters.Conditions, filters.Logic) - docIDs = append(docIDs, filteredIDs...) + docIDs := constrainDocIDs(baseDocIDs, filteredIDs) if len(docIDs) == 0 { return nil, true } + return docIDs, false } } @@ -689,6 +694,9 @@ func ApplyMetaDataFilter( if logicVal, ok := metaDataFilter["logic"].(string); ok { logic = logicVal } + if len(manualFilters) == 0 { + return baseDocIDs, false + } // Apply manual_value_resolver callback if provided if len(manualValueResolver) > 0 && manualValueResolver[0] != nil { @@ -720,13 +728,42 @@ func ApplyMetaDataFilter( } filteredIDs := runMetadataFilter(conditions, logic) - docIDs = append(docIDs, filteredIDs...) + docIDs := constrainDocIDs(baseDocIDs, filteredIDs) if len(manualFilters) > 0 && len(docIDs) == 0 { - return []string{"-999"}, false + return []string{NoMatchDocIDSentinel}, false } + return docIDs, false } - return docIDs, false + return baseDocIDs, false +} + +func constrainDocIDs(baseDocIDs, filteredDocIDs []string) []string { + filteredDocIDs = common.Deduplicate(filteredDocIDs) + if len(baseDocIDs) == 0 { + return filteredDocIDs + } + if len(filteredDocIDs) == 0 { + return []string{} + } + + filteredSet := make(map[string]struct{}, len(filteredDocIDs)) + for _, docID := range filteredDocIDs { + filteredSet[docID] = struct{}{} + } + result := make([]string, 0, min(len(baseDocIDs), len(filteredSet))) + seen := make(map[string]struct{}, len(baseDocIDs)) + for _, docID := range baseDocIDs { + if _, allowed := filteredSet[docID]; !allowed { + continue + } + if _, exists := seen[docID]; exists { + continue + } + seen[docID] = struct{}{} + result = append(result, docID) + } + return result } // repairJSON attempts to fix common JSON formatting issues in LLM output. diff --git a/internal/service/metadata_filter_test.go b/internal/service/metadata_filter_test.go index 8aaf4c6719..7e5d46c83f 100644 --- a/internal/service/metadata_filter_test.go +++ b/internal/service/metadata_filter_test.go @@ -455,7 +455,7 @@ func TestApplySingleConditionRelationalFallsBackToString(t *testing.T) { val string want []string }{ - {op: ">", val: "banana", want: []string{"d-c"}}, // "cherry" > "banana" + {op: ">", val: "banana", want: []string{"d-c"}}, // "cherry" > "banana" {op: "<", val: "cherry", want: []string{"d-a", "d-b"}}, // apple, banana {op: ">=", val: "banana", want: []string{"d-b", "d-c"}}, {op: "<=", val: "banana", want: []string{"d-a", "d-b"}}, diff --git a/internal/service/model_service.go b/internal/service/model_service.go index 0f3cff06a2..e0f1fa6c8b 100644 --- a/internal/service/model_service.go +++ b/internal/service/model_service.go @@ -1972,6 +1972,15 @@ func (m *ModelProviderService) GetChatModel(tenantID, compositeModelName string) return modelModule.NewChatModel(driver, &modelName, apiConfig), nil } +// GetRerankModel returns a RerankModel wrapper for the given tenant +func (m *ModelProviderService) GetRerankModel(tenantID, compositeModelName string) (*modelModule.RerankModel, error) { + driver, modelName, apiConfig, _, err := m.getModelConfig(tenantID, compositeModelName) + if err != nil { + return nil, err + } + return modelModule.NewRerankModel(driver, &modelName, apiConfig), nil +} + type AddModelRequest struct { ProviderName string `json:"provider_name"` InstanceName string `json:"instance_name"`