feat(go-api): add RAG retrieval to chat completions (#15739)

## Summary
- Add knowledge-base retrieval support to Go chat completions.

## What changed
- Routes KB-backed chat sessions through the Go retrieval service
instead of falling back to solo chat.
- Resolves embedding and rerank models, validates accessible knowledge
bases, and preserves tenant-aware retrieval.
- Rejects mixed embedding models across selected knowledge bases before
retrieval to avoid incompatible vector dimensions.
- Threads the HTTP request context into streaming retrieval so cancelled
requests can stop downstream retrieval work.
- Applies metadata filters and message-level `doc_ids` before retrieval.
- Expands parent/child chunks before building references and prompt
context.
- Injects retrieved knowledge through a copied dialog prompt config so
the caller's original dialog is not mutated.
- Honors configured empty responses when no chunks are found.
- Names the metadata no-match sentinel and reuses it across
retrieval/handler paths.
- Adds a defensive content cast while appending streamed answers.
- Adds focused unit coverage for retrieval, metadata filtering,
authorization, multimodal messages, references, empty-response behavior,
prompt immutability, and mixed embedding models.

---------

Co-authored-by: Yingfeng <yingfeng.zhang@gmail.com>
Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
oktofeesh
2026-06-09 20:07:45 -07:00
committed by GitHub
parent 7c1bd9a5a5
commit bbc1f2ecec
7 changed files with 1690 additions and 62 deletions

View File

@@ -292,9 +292,10 @@ func (h *ChatSessionHandler) Completion(c *gin.Context) {
// Create a channel for streaming data
streamChan := make(chan string)
reqCtx := c.Request.Context()
go func() {
defer close(streamChan)
err := h.chatSessionService.CompletionStream(userID, req.ConversationID, processedMessages, req.LLMID, chatModelConfig, messageID, streamChan)
err := h.chatSessionService.CompletionStream(reqCtx, userID, req.ConversationID, processedMessages, req.LLMID, chatModelConfig, messageID, streamChan)
if err != nil {
streamChan <- fmt.Sprintf("data: %s\n\n", err.Error())
}

View File

@@ -24,9 +24,9 @@ import (
"strconv"
"strings"
"ragflow/internal/common"
"gorm.io/gorm"
"go.uber.org/zap"
"gorm.io/gorm"
"ragflow/internal/common"
"ragflow/internal/engine"
"ragflow/internal/entity"
modelModule "ragflow/internal/entity/models"
@@ -71,16 +71,16 @@ type DocumentDAOIface interface {
// difyRetrievalRequest is the JSON body / query params for the Dify retrieval endpoint.
type difyRetrievalRequest struct {
KnowledgeID string `json:"knowledge_id" form:"knowledge_id"`
Query string `json:"query" form:"query"`
UseKG bool `json:"use_kg" form:"use_kg"`
RetrievalSetting *difyRetrievalSetting `json:"retrieval_setting"`
MetadataCondition *difyMetadataCondition `json:"metadata_condition"`
KnowledgeID string `json:"knowledge_id" form:"knowledge_id"`
Query string `json:"query" form:"query"`
UseKG bool `json:"use_kg" form:"use_kg"`
RetrievalSetting *difyRetrievalSetting `json:"retrieval_setting"`
MetadataCondition *difyMetadataCondition `json:"metadata_condition"`
}
type difyRetrievalSetting struct {
TopK *int `json:"top_k" form:"top_k"`
ScoreThreshold *float64 `json:"score_threshold" form:"score_threshold"`
TopK *int `json:"top_k" form:"top_k"`
ScoreThreshold *float64 `json:"score_threshold" form:"score_threshold"`
}
// difyCondition is a Dify-format metadata filter condition.
@@ -108,8 +108,8 @@ func (c difyMetadataCondition) toMetaFilterConditions() []service.MetaFilterCond
v = fmt.Sprint(dc.Value)
}
result[i] = service.MetaFilterCondition{
Key: dc.Name,
Op: dc.ComparisonOperator,
Key: dc.Name,
Op: dc.ComparisonOperator,
Value: v,
}
}
@@ -241,15 +241,15 @@ func (h *DifyRetrievalHandler) Retrieval(c *gin.Context) {
metas, metaErr := h.metadataSvc.GetFlattedMetaByKBs([]string{req.KnowledgeID})
docIDs := make([]string, 0)
if metaErr == nil && req.MetadataCondition != nil {
logic := req.MetadataCondition.Logic
if logic == "" {
logic = "and"
}
filteredIDs := service.ApplyMetaFilter(metas, req.MetadataCondition.toMetaFilterConditions(), logic)
logic := req.MetadataCondition.Logic
if logic == "" {
logic = "and"
}
filteredIDs := service.ApplyMetaFilter(metas, req.MetadataCondition.toMetaFilterConditions(), logic)
docIDs = append(docIDs, filteredIDs...)
}
if len(docIDs) == 0 && req.MetadataCondition != nil {
docIDs = []string{"-999"}
docIDs = []string{service.NoMatchDocIDSentinel}
}
// Label question for rank features
@@ -258,15 +258,15 @@ func (h *DifyRetrievalHandler) Retrieval(c *gin.Context) {
// Chunk retrieval
sr := &nlp.RetrievalRequest{
Question: req.Query,
TenantIDs: []string{kb.TenantID},
KbIDs: []string{req.KnowledgeID},
DocIDs: docIDs,
Page: 1,
PageSize: pageSize,
Top: topK,
SimilarityThreshold: scoreThreshold,
EmbeddingModel: embModel,
Question: req.Query,
TenantIDs: []string{kb.TenantID},
KbIDs: []string{req.KnowledgeID},
DocIDs: docIDs,
Page: 1,
PageSize: pageSize,
Top: topK,
SimilarityThreshold: scoreThreshold,
EmbeddingModel: embModel,
}
if rankFeature != nil {
sr.RankFeature = &rankFeature

View File

@@ -17,10 +17,13 @@
package service
import (
"context"
"encoding/json"
"errors"
"fmt"
"ragflow/internal/common"
"ragflow/internal/engine"
"ragflow/internal/service/nlp"
"strings"
"time"
@@ -31,21 +34,61 @@ import (
modelModule "ragflow/internal/entity/models"
)
// chatKnowledgebaseStore is the knowledge-base lookup surface the chat session
// service depends on; declaring it as an interface lets tests swap in a fake.
type chatKnowledgebaseStore interface {
// Accessible is consulted before retrieval to decide whether userID may use kbID.
Accessible(kbID, userID string) bool
// GetByIDs loads the knowledge bases for the given IDs.
GetByIDs(ids []string) ([]*entity.Knowledgebase, error)
}
// chatModelProvider is the model-resolution surface the chat session service
// depends on. Model names are passed in composite form together with the tenant
// they should be resolved for.
type chatModelProvider interface {
// GetChatModel resolves a chat model for a tenant.
GetChatModel(tenantID, compositeModelName string) (*modelModule.ChatModel, error)
// GetEmbeddingModel resolves the embedding model used to run retrieval.
GetEmbeddingModel(tenantID, compositeModelName string) (*modelModule.EmbeddingModel, error)
// GetRerankModel resolves the optional rerank model configured on a dialog.
GetRerankModel(tenantID, compositeModelName string) (*modelModule.RerankModel, error)
// NOTE(review): the two methods below are not called from the retrieval code
// visible in this change; the meaning of the int result is not shown here.
GetModelConfigFromProviderInstance(tenantID string, modelType entity.ModelType, modelName string) (modelModule.ModelDriver, string, *modelModule.APIConfig, int, error)
GetTenantDefaultModelByType(tenantID string, modelType entity.ModelType) (modelModule.ModelDriver, string, *modelModule.APIConfig, int, error)
}
// chatMetadataService is the metadata surface the chat session service depends on.
type chatMetadataService interface {
// LabelQuestion produces the rank-feature map that is forwarded with the
// retrieval request for the given question and knowledge bases.
LabelQuestion(question string, kbs []*entity.Knowledgebase) map[string]float64
// GetFlattedMetaByKBs returns the flattened document metadata of the given
// knowledge bases; it is the input to the dialog's metadata filter.
GetFlattedMetaByKBs(kbIDs []string) (common.MetaData, error)
}
// chatRetrievalService is the chunk-retrieval surface the chat session service
// depends on.
type chatRetrievalService interface {
// Retrieval executes one retrieval request; ctx is the caller's context.
Retrieval(ctx context.Context, req *nlp.RetrievalRequest) (*nlp.RetrievalResult, error)
}
// ChatSessionService chat session (conversation) service
type ChatSessionService struct {
chatSessionDAO *dao.ChatSessionDAO
chatDAO *dao.ChatDAO
userTenantDAO *dao.UserTenantDAO
modelProviderSvc *ModelProviderService
kbDAO chatKnowledgebaseStore
docEngine engine.DocEngine
modelProviderSvc chatModelProvider
metadataSvc chatMetadataService
retrievalSvc chatRetrievalService
}
// NewChatSessionService creates a chat session service that uses the document
// engine returned by engine.Get, both directly and as the backend of its
// retrieval service.
func NewChatSessionService() *ChatSessionService {
	eng := engine.Get()
	retrieval := nlp.NewRetrievalService(eng, dao.NewDocumentDAO())
	return newChatSessionServiceWithRetrieval(eng, retrieval)
}
// NewChatSessionServiceWithRetrieval creates a chat session service that uses
// the supplied retrieval service together with the document engine returned by
// engine.Get.
func NewChatSessionServiceWithRetrieval(retrievalSvc chatRetrievalService) *ChatSessionService {
	eng := engine.Get()
	return newChatSessionServiceWithRetrieval(eng, retrievalSvc)
}
// newChatSessionServiceWithRetrieval is the shared constructor: it wires the
// default DAOs and services and plugs in the given document engine and
// retrieval service.
func newChatSessionServiceWithRetrieval(docEngine engine.DocEngine, retrievalSvc chatRetrievalService) *ChatSessionService {
	svc := new(ChatSessionService)
	svc.chatSessionDAO = dao.NewChatSessionDAO()
	svc.chatDAO = dao.NewChatDAO()
	svc.userTenantDAO = dao.NewUserTenantDAO()
	svc.kbDAO = dao.NewKnowledgebaseDAO()
	svc.docEngine = docEngine
	svc.modelProviderSvc = NewModelProviderService()
	svc.metadataSvc = NewMetadataService()
	svc.retrievalSvc = retrievalSvc
	return svc
}
@@ -294,7 +337,7 @@ func (s *ChatSessionService) Completion(userID string, conversationID string, me
}
// Perform chat completion with RAG
result, err := s.asyncChat(dialog, session, messages, chatModelConfig, messageID, reference, false)
result, err := s.asyncChat(userID, dialog, session, messages, chatModelConfig, messageID, reference, false)
if err != nil {
return nil, err
}
@@ -308,7 +351,11 @@ func (s *ChatSessionService) Completion(userID string, conversationID string, me
}
// CompletionStream performs streaming chat completion with full RAG support
func (s *ChatSessionService) CompletionStream(userID string, conversationID string, messages []map[string]interface{}, llmID string, chatModelConfig map[string]interface{}, messageID string, streamChan chan<- string) error {
func (s *ChatSessionService) CompletionStream(ctx context.Context, userID string, conversationID string, messages []map[string]interface{}, llmID string, chatModelConfig map[string]interface{}, messageID string, streamChan chan<- string) error {
if ctx == nil {
ctx = context.Background()
}
// Validate the last message is from user
if len(messages) == 0 {
streamChan <- fmt.Sprintf("data: %s\n\n", `{"code": 500, "message": "messages cannot be empty", "data": {"answer": "**ERROR**: messages cannot be empty", "reference": []}}`)
@@ -356,7 +403,7 @@ func (s *ChatSessionService) CompletionStream(userID string, conversationID stri
}
// Perform streaming chat completion with RAG
resultChan, err := s.asyncChatStream(dialog, session, messages, chatModelConfig, messageID, reference)
resultChan, err := s.asyncChatStream(ctx, userID, dialog, session, messages, chatModelConfig, messageID, reference)
if err != nil {
streamChan <- fmt.Sprintf("data: %s\n\n", fmt.Sprintf(`{"code": 500, "message": "%s", "data": {"answer": "**ERROR**: %s", "reference": []}}`, err.Error(), err.Error()))
return err
@@ -450,7 +497,7 @@ func (s *ChatSessionService) updateSessionMessages(session *entity.ChatSession,
}
// asyncChat performs chat with RAG support (non-streaming)
func (s *ChatSessionService) asyncChat(dialog *entity.Chat, session *entity.ChatSession, messages []map[string]interface{}, config map[string]interface{}, messageID string, reference []interface{}, stream bool) (map[string]interface{}, error) {
func (s *ChatSessionService) asyncChat(userID string, dialog *entity.Chat, session *entity.ChatSession, messages []map[string]interface{}, config map[string]interface{}, messageID string, reference []interface{}, stream bool) (map[string]interface{}, error) {
// Check if we need RAG (knowledge base or tavily)
hasKB := len(dialog.KBIDs) > 0
hasTavily := false
@@ -465,21 +512,20 @@ func (s *ChatSessionService) asyncChat(dialog *entity.Chat, session *entity.Chat
return s.asyncChatSolo(dialog, session, messages, config, messageID, reference, stream)
}
// TODO: Full RAG implementation with knowledge base retrieval
// This would include:
// 1. Get embedding model and rerank model
// 2. Extract questions from messages
// 3. Retrieve chunks from knowledge bases
// 4. Rerank chunks
// 5. Build prompt with context
// 6. Call LLM
if hasKB {
return s.asyncChatWithRetrieval(context.Background(), userID, dialog, session, messages, config, messageID, reference, stream)
}
// For now, fall back to solo chat
common.Warn("Tavily-backed chat retrieval is not implemented in Go; falling back to solo chat",
zap.String("dialog_id", dialog.ID))
return s.asyncChatSolo(dialog, session, messages, config, messageID, reference, stream)
}
// asyncChatStream performs streaming chat with RAG support
func (s *ChatSessionService) asyncChatStream(dialog *entity.Chat, session *entity.ChatSession, messages []map[string]interface{}, config map[string]interface{}, messageID string, reference []interface{}) (<-chan map[string]interface{}, error) {
func (s *ChatSessionService) asyncChatStream(ctx context.Context, userID string, dialog *entity.Chat, session *entity.ChatSession, messages []map[string]interface{}, config map[string]interface{}, messageID string, reference []interface{}) (<-chan map[string]interface{}, error) {
if ctx == nil {
ctx = context.Background()
}
resultChan := make(chan map[string]interface{})
go func() {
@@ -500,14 +546,599 @@ func (s *ChatSessionService) asyncChatStream(dialog *entity.Chat, session *entit
return
}
// TODO: Full RAG streaming implementation
// For now, fall back to solo chat
if hasKB {
ragMessages, ragDialog, emptyResponse, err := s.messagesWithRetrievedKnowledge(ctx, userID, dialog, messages, reference)
if err != nil {
resultChan <- s.structureAnswer(session, "**ERROR**: "+err.Error(), messageID, session.ID, reference)
return
}
if emptyResponse != nil {
resultChan <- s.structureAnswer(session, *emptyResponse, messageID, session.ID, reference)
return
}
s.asyncChatSoloStream(ragDialog, session, ragMessages, config, messageID, reference, resultChan)
return
}
common.Warn("Tavily-backed streaming chat retrieval is not implemented in Go; falling back to solo chat",
zap.String("dialog_id", dialog.ID))
s.asyncChatSoloStream(dialog, session, messages, config, messageID, reference, resultChan)
}()
return resultChan, nil
}
// asyncChatWithRetrieval is the non-streaming knowledge-base chat path. It runs
// retrieval first; when retrieval yields a configured empty response that text
// is returned as the final answer, otherwise the (possibly rewritten) dialog
// and messages are handed to asyncChatSolo.
func (s *ChatSessionService) asyncChatWithRetrieval(ctx context.Context, userID string, dialog *entity.Chat, session *entity.ChatSession, messages []map[string]interface{}, config map[string]interface{}, messageID string, reference []interface{}, stream bool) (map[string]interface{}, error) {
	ragMessages, ragDialog, emptyResponse, err := s.messagesWithRetrievedKnowledge(ctx, userID, dialog, messages, reference)
	switch {
	case err != nil:
		return nil, err
	case emptyResponse == nil:
		return s.asyncChatSolo(ragDialog, session, ragMessages, config, messageID, reference, stream)
	}

	// No usable chunks and an empty response is configured: answer with it and
	// attach the most recent reference entry, if there is one.
	var latest interface{}
	if n := len(reference); n > 0 {
		latest = reference[n-1]
	}
	answer := map[string]interface{}{
		"answer":    *emptyResponse,
		"reference": latest,
		"final":     true,
	}
	return s.structureAnswerWithConv(session, answer, messageID, session.ID, reference), nil
}
// messagesWithRetrievedKnowledge runs knowledge-base retrieval for the latest
// user question and prepares what the LLM call should use.
//
// Results, in order:
//   - the messages to send (a copy when knowledge was injected, otherwise the
//     caller's slice),
//   - the dialog to use (a copy with the knowledge substituted into the system
//     prompt when that prompt contains "{knowledge}", otherwise the original),
//   - a non-nil empty-response text when no knowledge was found and the dialog
//     configures one,
//   - an error for any failed lookup, authorization, model resolution or
//     retrieval step.
//
// The last element of reference (if any) is overwritten with the retrieved
// chunks and document aggregations.
func (s *ChatSessionService) messagesWithRetrievedKnowledge(ctx context.Context, userID string, dialog *entity.Chat, messages []map[string]interface{}, reference []interface{}) ([]map[string]interface{}, *entity.Chat, *string, error) {
// Nothing to retrieve from: pass everything through untouched.
kbIDs := stringSliceFromJSON(dialog.KBIDs)
if len(kbIDs) == 0 {
return messages, dialog, nil, nil
}
if s.retrievalSvc == nil {
return nil, nil, nil, errors.New("retrieval service is not configured")
}
// Without a user question there is nothing to search for.
question := latestUserQuestion(messages)
if question == "" {
return messages, dialog, nil, nil
}
kbs, err := s.kbDAO.GetByIDs(kbIDs)
if err != nil {
return nil, nil, nil, fmt.Errorf("failed to load knowledge bases: %w", err)
}
// Every configured knowledge base must exist and be authorized.
kbs, err = s.knowledgebasesForDialog(userID, dialog, kbIDs, kbs)
if err != nil {
return nil, nil, nil, err
}
// All knowledge bases must resolve to one embedding model before searching.
embeddingTenantID, embeddingModelName, err := validateKnowledgebaseEmbeddingModels(kbs, dialog.TenantID, resolveEmbeddingModelName)
if err != nil {
return nil, nil, nil, err
}
embeddingModel, err := s.modelProviderSvc.GetEmbeddingModel(embeddingTenantID, embeddingModelName)
if err != nil {
return nil, nil, nil, fmt.Errorf("failed to get embedding model: %w", err)
}
// rerankModel stays nil when the dialog configures no rerank model.
rerankModel, err := s.rerankModelForDialog(dialog)
if err != nil {
return nil, nil, nil, err
}
// NOTE(review): TopK is forwarded as-is, including 0; confirm how the
// retrieval service treats a zero Top.
top := int(dialog.TopK)
pageSize := int(dialog.TopN)
if pageSize <= 0 {
pageSize = 6
}
similarityThreshold := dialog.SimilarityThreshold
vectorSimilarityWeight := dialog.VectorSimilarityWeight
var rankFeature map[string]float64
if s.metadataSvc != nil {
rankFeature = s.metadataSvc.LabelQuestion(question, kbs)
}
// Message-level doc_ids are the starting set; the dialog's metadata filter
// (if any) is applied on top of them.
baseDocIDs := docIDsFromMessages(messages)
docIDs, err := s.filteredDocIDsForDialog(ctx, dialog, kbIDs, question, baseDocIDs)
if err != nil {
return nil, nil, nil, err
}
tenantIDs := tenantIDsFromKnowledgebases(kbs, dialog.TenantID)
// RankFeature is always a non-nil pointer here, but it points at a nil map
// when no metadata service is configured.
retrievalResult, err := s.retrievalSvc.Retrieval(ctx, &nlp.RetrievalRequest{
Question: question,
TenantIDs: tenantIDs,
KbIDs: kbIDs,
DocIDs: docIDs,
Page: 1,
PageSize: pageSize,
Top: &top,
SimilarityThreshold: &similarityThreshold,
VectorSimilarityWeight: &vectorSimilarityWeight,
RankFeature: &rankFeature,
EmbeddingModel: embeddingModel,
RerankModel: rerankModel,
})
if err != nil {
return nil, nil, nil, fmt.Errorf("retrieval search failed: %w", err)
}
// Treat a nil result like an empty one so the code below needs no nil checks.
if retrievalResult == nil {
retrievalResult = &nlp.RetrievalResult{}
}
chunks := retrievalResult.Chunks
// Parent/child expansion needs the document engine; skipped when none is set.
if s.docEngine != nil {
chunks = nlp.RetrievalByChildren(chunks, tenantIDs, s.docEngine, ctx)
}
setLatestReference(reference, chunks, retrievalResult.DocAggs)
knowledge := buildKnowledgeBlock(chunks)
// No usable chunk text: hand back the configured empty response (nil when
// the dialog has none, in which case the caller continues without knowledge).
if knowledge == "" {
return messages, dialog, emptyResponseForDialog(dialog), nil
}
// Prefer substituting into the system prompt; fall back to prepending the
// knowledge to the latest user message.
if ragDialog, ok := dialogWithInjectedKnowledgePrompt(dialog, knowledge); ok {
return copyMessages(messages), ragDialog, nil, nil
}
return injectKnowledge(messages, knowledge), dialog, nil, nil
}
// embeddingModelNameResolver returns the embedding model name of kb, resolved
// on behalf of tenantID.
type embeddingModelNameResolver func(tenantID string, kb *entity.Knowledgebase) (string, error)

// validateKnowledgebaseEmbeddingModels checks that every knowledge base
// resolves to the same (whitespace-trimmed) embedding model name.
//
// It returns the tenant ID and model name taken from the first knowledge base.
// A knowledge base without its own tenant ID is resolved under
// fallbackTenantID. With no knowledge bases it returns fallbackTenantID and an
// empty model name. A nil entry, a resolver failure, an empty model name or a
// mismatch between knowledge bases is reported as an error.
func validateKnowledgebaseEmbeddingModels(kbs []*entity.Knowledgebase, fallbackTenantID string, resolve embeddingModelNameResolver) (string, string, error) {
	if len(kbs) == 0 {
		return fallbackTenantID, "", nil
	}

	var (
		wantModel  string
		wantKBID   string
		wantTenant = fallbackTenantID
	)
	for _, kb := range kbs {
		if kb == nil {
			return "", "", errors.New("knowledge base is nil")
		}

		owner := kb.TenantID
		if owner == "" {
			owner = fallbackTenantID
		}

		resolved, err := resolve(owner, kb)
		if err != nil {
			return "", "", err
		}
		resolved = strings.TrimSpace(resolved)

		switch {
		case resolved == "":
			return "", "", fmt.Errorf("knowledge base %s has no embedding model", kb.ID)
		case wantModel == "":
			// First knowledge base sets the expectation for the rest.
			wantModel, wantKBID, wantTenant = resolved, kb.ID, owner
		case resolved != wantModel:
			return "", "", fmt.Errorf("knowledge bases must use the same embedding model: %s resolves to %q, expected %q from %s", kb.ID, resolved, wantModel, wantKBID)
		}
	}
	return wantTenant, wantModel, nil
}
// rerankModelForDialog returns the rerank model configured on the dialog, or
// (nil, nil) when the dialog does not name one.
func (s *ChatSessionService) rerankModelForDialog(dialog *entity.Chat) (*modelModule.RerankModel, error) {
	name, err := resolveRerankModelName(dialog)
	switch {
	case err != nil:
		return nil, err
	case name == "":
		return nil, nil
	}

	model, lookupErr := s.modelProviderSvc.GetRerankModel(dialog.TenantID, name)
	if lookupErr != nil {
		return nil, fmt.Errorf("failed to get rerank model: %w", lookupErr)
	}
	return model, nil
}
// filteredDocIDsForDialog narrows baseDocIDs with the dialog's metadata filter.
//
// Without a filter the base IDs are returned unchanged. When the filter reports
// an empty outcome, a single NoMatchDocIDSentinel entry is returned instead of
// an empty list. For the "auto" and "semi_auto" methods a chat model is looked
// up for the filter; a lookup failure is only logged and filtering continues
// with whatever model value was returned.
func (s *ChatSessionService) filteredDocIDsForDialog(ctx context.Context, dialog *entity.Chat, kbIDs []string, question string, baseDocIDs []string) ([]string, error) {
	if hasFilter := dialog.MetaDataFilter != nil && len(*dialog.MetaDataFilter) > 0; !hasFilter {
		return baseDocIDs, nil
	}
	if s.metadataSvc == nil {
		return nil, errors.New("metadata service is not configured")
	}

	// Work on a plain copy so the dialog's own filter is never handed out.
	source := *dialog.MetaDataFilter
	filter := make(map[string]interface{}, len(source))
	for k, v := range source {
		filter[k] = v
	}

	metaData, err := s.metadataSvc.GetFlattedMetaByKBs(kbIDs)
	if err != nil {
		return nil, fmt.Errorf("failed to get flattened metadata for chat retrieval: %w", err)
	}

	var filterChatModel *modelModule.ChatModel
	switch method, _ := filter["method"].(string); method {
	case "auto", "semi_auto":
		model, modelErr := s.modelProviderSvc.GetChatModel(dialog.TenantID, dialog.LLMID)
		if modelErr != nil {
			common.Warn("Failed to get chat model for chat metadata filter", zap.Error(modelErr))
		}
		filterChatModel = model
	}

	docIDs, empty := ApplyMetaDataFilter(ctx, filter, metaData, question, filterChatModel, baseDocIDs, kbIDs)
	if empty {
		return []string{NoMatchDocIDSentinel}, nil
	}
	return docIDs, nil
}
// resolveEmbeddingModelName determines the composite embedding model name for a
// knowledge base. Sources are tried in this order:
//  1. kb.TenantEmbdID, when set and positive (looked up by ID);
//  2. kb.EmbdID, returned as-is when it already contains "@", otherwise looked
//     up by name for tenantID;
//  3. the tenant's default embedding model, formatted as "<name>@<factory>".
func resolveEmbeddingModelName(tenantID string, kb *entity.Knowledgebase) (string, error) {
	switch {
	case kb.TenantEmbdID != nil && *kb.TenantEmbdID > 0:
		_, name, err := dao.LookupTenantLLMByID(dao.NewTenantLLMDAO(), *kb.TenantEmbdID)
		if err != nil {
			return "", fmt.Errorf("failed to get embedding model by tenant_embd_id: %w", err)
		}
		return name, nil

	case strings.Contains(kb.EmbdID, "@"):
		// Already in composite form.
		return kb.EmbdID, nil

	case kb.EmbdID != "":
		_, name, err := dao.LookupTenantLLMByName(dao.NewTenantLLMDAO(), tenantID, kb.EmbdID, entity.ModelTypeEmbedding)
		if err != nil {
			return "", fmt.Errorf("failed to get embedding model by embd_id: %w", err)
		}
		return name, nil
	}

	// Nothing configured on the knowledge base: use the tenant default.
	fallback, err := dao.NewTenantLLMDAO().GetByTenantAndType(tenantID, entity.ModelTypeEmbedding)
	if err != nil {
		return "", fmt.Errorf("failed to get tenant default embedding model: %w", err)
	}
	if fallback == nil || fallback.LLMName == nil || *fallback.LLMName == "" {
		return "", fmt.Errorf("no default embedding model found for tenant %s", tenantID)
	}
	return fmt.Sprintf("%s@%s", *fallback.LLMName, fallback.LLMFactory), nil
}
// resolveRerankModelName determines the composite rerank model name for a
// dialog, or "" when the dialog configures none. dialog.TenantRerankID wins
// when set and positive; otherwise dialog.RerankID is used, returned as-is if
// it already contains "@" and looked up by name for the dialog's tenant if not.
func resolveRerankModelName(dialog *entity.Chat) (string, error) {
	if id := dialog.TenantRerankID; id != nil && *id > 0 {
		_, name, err := dao.LookupTenantLLMByID(dao.NewTenantLLMDAO(), *id)
		if err != nil {
			return "", fmt.Errorf("failed to get rerank model by tenant_rerank_id: %w", err)
		}
		return name, nil
	}

	switch rerankID := dialog.RerankID; {
	case rerankID == "":
		return "", nil
	case strings.Contains(rerankID, "@"):
		return rerankID, nil
	}

	_, name, err := dao.LookupTenantLLMByName(dao.NewTenantLLMDAO(), dialog.TenantID, dialog.RerankID, entity.ModelTypeRerank)
	if err != nil {
		return "", fmt.Errorf("failed to get rerank model by rerank_id: %w", err)
	}
	return name, nil
}
// stringSliceFromJSON extracts the string entries of a JSON slice, dropping
// non-strings, empty strings and repeats while keeping first-seen order. The
// result is never nil.
func stringSliceFromJSON(values entity.JSONSlice) []string {
	out := make([]string, 0, len(values))
	visited := make(map[string]struct{}, len(values))
	for _, raw := range values {
		text, isString := raw.(string)
		if !isString || text == "" {
			continue
		}
		if _, dup := visited[text]; !dup {
			visited[text] = struct{}{}
			out = append(out, text)
		}
	}
	return out
}
// tenantIDsFromKnowledgebases lists the distinct, non-empty tenant IDs of the
// given knowledge bases in first-seen order. When none is found and fallback is
// non-empty, the result is just the fallback.
func tenantIDsFromKnowledgebases(kbs []*entity.Knowledgebase, fallback string) []string {
	var ids []string
	known := make(map[string]struct{}, len(kbs)+1)
	for _, kb := range kbs {
		if kb == nil {
			continue
		}
		tenant := kb.TenantID
		if tenant == "" {
			continue
		}
		if _, dup := known[tenant]; dup {
			continue
		}
		known[tenant] = struct{}{}
		ids = append(ids, tenant)
	}
	if fallback != "" && len(ids) == 0 {
		ids = append(ids, fallback)
	}
	return ids
}
// knowledgebasesForDialog returns the loaded knowledge bases in kbIDs order,
// after checking each one.
//
// A requested ID that was not loaded is an error. With a non-empty userID the
// knowledge base must be accessible to that user; with an empty userID it must
// belong to the dialog's tenant.
func (s *ChatSessionService) knowledgebasesForDialog(userID string, dialog *entity.Chat, kbIDs []string, loaded []*entity.Knowledgebase) ([]*entity.Knowledgebase, error) {
	index := make(map[string]*entity.Knowledgebase, len(loaded))
	for _, kb := range loaded {
		if kb == nil {
			continue
		}
		index[kb.ID] = kb
	}

	selected := make([]*entity.Knowledgebase, 0, len(kbIDs))
	for _, id := range kbIDs {
		kb, found := index[id]
		if !found || kb == nil {
			return nil, fmt.Errorf("knowledge base %s not found", id)
		}

		switch {
		case userID != "":
			if !s.kbDAO.Accessible(id, userID) {
				return nil, fmt.Errorf("knowledge base %s is not authorized for user", id)
			}
		case kb.TenantID != dialog.TenantID:
			return nil, fmt.Errorf("knowledge base %s is not authorized for dialog tenant", id)
		}

		selected = append(selected, kb)
	}

	if len(selected) == 0 {
		return nil, errors.New("no valid knowledge bases found")
	}
	return selected, nil
}
// docIDsFromMessages returns the "doc_ids" of the most recent message whose
// role is "user", or nil when there is no such message.
func docIDsFromMessages(messages []map[string]interface{}) []string {
	idx := len(messages) - 1
	for idx >= 0 {
		role, _ := messages[idx]["role"].(string)
		if role == "user" {
			return stringSliceFromValue(messages[idx]["doc_ids"])
		}
		idx--
	}
	return nil
}
// latestUserQuestion returns the text of the most recent message whose role is
// "user", or "" when there is no such message.
func latestUserQuestion(messages []map[string]interface{}) string {
	idx := len(messages) - 1
	for idx >= 0 {
		role, _ := messages[idx]["role"].(string)
		if role == "user" {
			return textFromMessageContent(messages[idx]["content"])
		}
		idx--
	}
	return ""
}
// stringSliceFromValue converts a loosely typed value into a cleaned string
// list. Only []string and []interface{} are understood (non-string items of
// the latter are dropped); any other value, including nil, yields nil.
func stringSliceFromValue(value interface{}) []string {
	if direct, ok := value.([]string); ok {
		return uniqueNonEmptyStrings(direct)
	}

	items, ok := value.([]interface{})
	if !ok {
		return nil
	}
	collected := make([]string, 0, len(items))
	for i := range items {
		if text, isString := items[i].(string); isString {
			collected = append(collected, text)
		}
	}
	return uniqueNonEmptyStrings(collected)
}
// uniqueNonEmptyStrings trims every value, drops blanks and repeats, and keeps
// first-seen order. It returns nil rather than an empty slice when nothing
// remains.
func uniqueNonEmptyStrings(values []string) []string {
	kept := make([]string, 0, len(values))
	visited := make(map[string]struct{}, len(values))
	for i := range values {
		trimmed := strings.TrimSpace(values[i])
		if trimmed == "" {
			continue
		}
		if _, dup := visited[trimmed]; !dup {
			visited[trimmed] = struct{}{}
			kept = append(kept, trimmed)
		}
	}
	if len(kept) > 0 {
		return kept
	}
	return nil
}
// emptyResponseForDialog returns the dialog's configured "empty_response"
// prompt setting, or nil when it is missing, not a string, or empty.
func emptyResponseForDialog(dialog *entity.Chat) *string {
	if dialog.PromptConfig != nil {
		if text, isString := dialog.PromptConfig["empty_response"].(string); isString && text != "" {
			return &text
		}
	}
	return nil
}
// buildKnowledgeBlock renders the chunks that have text into one prompt block.
// Each entry is "[n]" (optionally followed by the document name) on one line
// and the chunk text on the next; entries are separated by a blank line. n is
// the chunk's 1-based position in chunks, so numbers of skipped (textless)
// chunks are not reused.
func buildKnowledgeBlock(chunks []map[string]interface{}) string {
	var sb strings.Builder
	for idx := range chunks {
		body := chunkText(chunks[idx])
		if body == "" {
			continue
		}
		if sb.Len() != 0 {
			sb.WriteString("\n\n")
		}
		fmt.Fprintf(&sb, "[%d]", idx+1)
		if name, _ := chunks[idx]["docnm_kwd"].(string); name != "" {
			sb.WriteString(" ")
			sb.WriteString(name)
		}
		sb.WriteString("\n")
		sb.WriteString(body)
	}
	return sb.String()
}
// chunkText returns the trimmed text of a chunk, taken from the first of
// "content_with_weight", "content_ltks" and "content" that holds a non-blank
// string; "" when none does.
func chunkText(chunk map[string]interface{}) string {
	candidates := [...]string{"content_with_weight", "content_ltks", "content"}
	for i := 0; i < len(candidates); i++ {
		raw, isString := chunk[candidates[i]].(string)
		if !isString {
			continue
		}
		if trimmed := strings.TrimSpace(raw); trimmed != "" {
			return trimmed
		}
	}
	return ""
}
// injectKnowledge returns a copy of messages in which the knowledge prompt is
// merged into the most recent "user" message. If no message has that role, the
// prompt is appended as a trailing "system" message. An empty message list is
// returned as an empty copy.
func injectKnowledge(messages []map[string]interface{}, knowledge string) []map[string]interface{} {
	out := copyMessages(messages)
	if len(out) == 0 {
		return out
	}
	knowledgePrompt := fmt.Sprintf("Use the following knowledge snippets to answer the user's question. If the snippets do not contain the answer, say that the knowledge base does not provide enough information.\n\n%s", knowledge)

	lastUser := -1
	for i := len(out) - 1; i >= 0; i-- {
		if role, _ := out[i]["role"].(string); role == "user" {
			lastUser = i
			break
		}
	}
	if lastUser < 0 {
		return append(out, map[string]interface{}{
			"role":    "system",
			"content": knowledgePrompt,
		})
	}

	out[lastUser]["content"] = injectKnowledgeIntoContent(out[lastUser]["content"], knowledgePrompt)
	return out
}
// injectKnowledgeIntoContent puts the knowledge prompt in front of a message's
// content. Block-style content ([]interface{} or []map[string]interface{})
// gets a leading text block and is returned as a new []interface{}; any other
// content is rendered as text and returned as a single string of the form
// "<prompt>\n\nQuestion:\n<content>".
func injectKnowledgeIntoContent(content interface{}, knowledgePrompt string) interface{} {
	var blocks []interface{}
	switch typed := content.(type) {
	case []interface{}:
		blocks = typed
	case []map[string]interface{}:
		blocks = make([]interface{}, len(typed))
		for i := range typed {
			blocks[i] = typed[i]
		}
	default:
		text := ""
		if content != nil {
			text = fmt.Sprint(content)
		}
		return strings.TrimSpace(knowledgePrompt + "\n\nQuestion:\n" + text)
	}

	out := make([]interface{}, 0, len(blocks)+1)
	out = append(out, knowledgeTextBlock(knowledgePrompt))
	return append(out, blocks...)
}
// knowledgeTextBlock wraps the knowledge prompt, followed by a "Question:"
// marker, in a "text" content block.
func knowledgeTextBlock(knowledgePrompt string) map[string]interface{} {
	block := make(map[string]interface{}, 2)
	block["type"] = "text"
	block["text"] = knowledgePrompt + "\n\nQuestion:"
	return block
}
// textFromMessageContent flattens a message's content into trimmed plain text.
// Strings are trimmed; block lists contribute their text parts joined by
// newlines; nil yields ""; anything else is formatted with fmt.Sprint.
func textFromMessageContent(content interface{}) string {
	if content == nil {
		return ""
	}

	var blocks []interface{}
	switch typed := content.(type) {
	case string:
		return strings.TrimSpace(typed)
	case []interface{}:
		blocks = typed
	case []map[string]interface{}:
		blocks = make([]interface{}, 0, len(typed))
		for i := range typed {
			blocks = append(blocks, typed[i])
		}
	default:
		return strings.TrimSpace(fmt.Sprint(content))
	}

	joined := strings.Join(textsFromContentBlocks(blocks), "\n")
	return strings.TrimSpace(joined)
}
// textsFromContentBlocks collects the trimmed, non-blank text of each content
// block: a plain string counts as text, a map contributes its string "text"
// entry, every other block is ignored. The result is never nil.
func textsFromContentBlocks(blocks []interface{}) []string {
	out := make([]string, 0, len(blocks))
	for _, block := range blocks {
		var raw string
		if plain, ok := block.(string); ok {
			raw = plain
		} else if fields, ok := block.(map[string]interface{}); ok {
			raw, _ = fields["text"].(string)
		}
		if trimmed := strings.TrimSpace(raw); trimmed != "" {
			out = append(out, trimmed)
		}
	}
	return out
}
// dialogWithInjectedKnowledgePrompt substitutes knowledge for every
// "{knowledge}" placeholder in the dialog's "system" prompt. The substitution
// is made on a copy of the dialog with a copied prompt config, so the caller's
// dialog is left untouched. When there is no prompt config, or the system
// prompt is not a string containing the placeholder, the original dialog is
// returned with false.
func dialogWithInjectedKnowledgePrompt(dialog *entity.Chat, knowledge string) (*entity.Chat, bool) {
	if dialog.PromptConfig == nil {
		return dialog, false
	}
	template, isString := dialog.PromptConfig["system"].(string)
	if !(isString && strings.Contains(template, "{knowledge}")) {
		return dialog, false
	}

	clone := *dialog
	clone.PromptConfig = cloneJSONMap(dialog.PromptConfig)
	clone.PromptConfig["system"] = strings.ReplaceAll(template, "{knowledge}", knowledge)
	return &clone, true
}
// cloneJSONMap returns a new map holding the same top-level entries as values.
// Nested values are shared, not copied. The result is never nil.
func cloneJSONMap(values entity.JSONMap) entity.JSONMap {
	dup := make(entity.JSONMap, len(values))
	for k := range values {
		dup[k] = values[k]
	}
	return dup
}
// copyMessages returns a new slice in which every message is a fresh map with
// the same top-level entries. Values inside the messages are shared with the
// originals. The result is never nil.
func copyMessages(messages []map[string]interface{}) []map[string]interface{} {
	out := make([]map[string]interface{}, 0, len(messages))
	for _, src := range messages {
		dst := make(map[string]interface{}, len(src))
		for k, v := range src {
			dst[k] = v
		}
		out = append(out, dst)
	}
	return out
}
// setLatestReference overwrites the last element of reference with a
// {"chunks", "doc_aggs"} entry built from the retrieval output. An empty
// reference slice is left as it is.
func setLatestReference(reference []interface{}, chunks []map[string]interface{}, docAggs []map[string]interface{}) {
	last := len(reference) - 1
	if last < 0 {
		return
	}
	reference[last] = map[string]interface{}{
		"chunks":   chunksForReference(chunks),
		"doc_aggs": mapsForReference(docAggs),
	}
}
// chunksForReference copies each chunk without its "vector" entry, for use in
// a reference payload. The input chunks are not modified. The result is never
// nil.
func chunksForReference(chunks []map[string]interface{}) []interface{} {
	out := make([]interface{}, len(chunks))
	for i, chunk := range chunks {
		trimmed := make(map[string]interface{}, len(chunk))
		for k, v := range chunk {
			if k != "vector" {
				trimmed[k] = v
			}
		}
		out[i] = trimmed
	}
	return out
}
// mapsForReference re-types a slice of maps as []interface{}. The maps
// themselves are shared with the input, not copied. The result is never nil.
func mapsForReference(values []map[string]interface{}) []interface{} {
	out := make([]interface{}, len(values))
	for i := range values {
		out[i] = values[i]
	}
	return out
}
// asyncChatSolo performs simple chat without RAG (non-streaming)
func (s *ChatSessionService) asyncChatSolo(dialog *entity.Chat, session *entity.ChatSession, messages []map[string]interface{}, config map[string]interface{}, messageID string, reference []interface{}, stream bool) (map[string]interface{}, error) {
common.Info("asyncChatSolo started",
@@ -765,7 +1396,8 @@ func (s *ChatSessionService) structureAnswerWithConv(session *entity.ChatSession
if ans["final"] == true && ans["answer"] != nil {
lastMsg["content"] = ans["answer"]
} else {
lastMsg["content"] = (lastMsg["content"].(string)) + content
existing, _ := lastMsg["content"].(string)
lastMsg["content"] = existing + content
}
lastMsg["created_at"] = float64(time.Now().Unix())
lastMsg["id"] = messageID

View File

@@ -0,0 +1,949 @@
package service
import (
"context"
"encoding/json"
"errors"
"strings"
"testing"
"ragflow/internal/common"
"ragflow/internal/engine/types"
"ragflow/internal/entity"
modelModule "ragflow/internal/entity/models"
"ragflow/internal/service/nlp"
)
// fakeChatKBStore is a test double for the knowledge-base store.
type fakeChatKBStore struct {
	// kbs is returned by GetByIDs regardless of the requested IDs.
	kbs []*entity.Knowledgebase
	// accessible maps a knowledge base ID to its access decision; a nil map
	// grants access to everything.
	accessible map[string]bool
}

// Accessible grants access when no access map is set, otherwise it reports the
// map's entry for kbID. userID is ignored.
func (f fakeChatKBStore) Accessible(kbID, userID string) bool {
	return f.accessible == nil || f.accessible[kbID]
}

// GetByIDs returns the canned knowledge bases and never fails.
func (f fakeChatKBStore) GetByIDs(ids []string) ([]*entity.Knowledgebase, error) {
	return f.kbs, nil
}
// fakeChatMetadataService is a test double that returns fixed rank features
// and fixed metadata.
type fakeChatMetadataService struct{}

// LabelQuestion always returns a single "pagerank_fea" feature with value 10.
func (fakeChatMetadataService) LabelQuestion(question string, kbs []*entity.Knowledgebase) map[string]float64 {
	features := make(map[string]float64, 1)
	features["pagerank_fea"] = 10
	return features
}

// GetFlattedMetaByKBs always returns metadata in which category "policy" maps
// to document "doc-policy".
func (fakeChatMetadataService) GetFlattedMetaByKBs(kbIDs []string) (common.MetaData, error) {
	byCategory := common.MetaValueDocs{
		"policy": []string{"doc-policy"},
	}
	return common.MetaData{"category": byCategory}, nil
}
// failingChatMetadataService is a test double whose metadata lookup always
// fails.
type failingChatMetadataService struct{}

// LabelQuestion returns no rank features.
func (failingChatMetadataService) LabelQuestion(question string, kbs []*entity.Knowledgebase) map[string]float64 {
	var none map[string]float64
	return none
}

// GetFlattedMetaByKBs always fails with "metadata unavailable".
func (failingChatMetadataService) GetFlattedMetaByKBs(kbIDs []string) (common.MetaData, error) {
	err := errors.New("metadata unavailable")
	return nil, err
}
// fakeChatDocEngine is a test double for the document engine. Only GetChunk
// returns configurable data; every other method is a stub that succeeds with a
// zero value (or true for the existence checks).
// NOTE(review): presumably meant to satisfy engine.DocEngine — the full method
// set is not visible here; confirm.
type fakeChatDocEngine struct {
// chunk is the canned value returned by GetChunk.
chunk map[string]interface{}
}
// CreateChunkStore is a no-op stub.
func (f fakeChatDocEngine) CreateChunkStore(ctx context.Context, baseName, datasetID string, vectorSize int, parserID string) error {
return nil
}
// InsertChunks is a no-op stub that reports no inserted IDs.
func (f fakeChatDocEngine) InsertChunks(ctx context.Context, chunks []map[string]interface{}, baseName string, datasetID string) ([]string, error) {
return nil, nil
}
// UpdateChunks is a no-op stub.
func (f fakeChatDocEngine) UpdateChunks(ctx context.Context, condition map[string]interface{}, newValue map[string]interface{}, baseName string, datasetID string) error {
return nil
}
// DeleteChunks is a no-op stub that reports zero deletions.
func (f fakeChatDocEngine) DeleteChunks(ctx context.Context, condition map[string]interface{}, baseName string, datasetID string) (int64, error) {
return 0, nil
}
// Search is a stub that returns a nil result and no error.
func (f fakeChatDocEngine) Search(ctx context.Context, req *types.SearchRequest) (*types.SearchResult, error) {
return nil, nil
}
// GetChunk returns the canned chunk regardless of the arguments.
func (f fakeChatDocEngine) GetChunk(ctx context.Context, baseName, chunkID string, datasetIDs []string) (interface{}, error) {
return f.chunk, nil
}
// DropChunkStore is a no-op stub.
func (f fakeChatDocEngine) DropChunkStore(ctx context.Context, baseName, datasetID string) error {
return nil
}
// ChunkStoreExists always reports true.
func (f fakeChatDocEngine) ChunkStoreExists(ctx context.Context, baseName, datasetID string) (bool, error) {
return true, nil
}
// CreateMetadataStore is a no-op stub.
func (f fakeChatDocEngine) CreateMetadataStore(ctx context.Context, tenantID string) error {
return nil
}
// InsertMetadata is a no-op stub that reports no inserted IDs.
func (f fakeChatDocEngine) InsertMetadata(ctx context.Context, metadata []map[string]interface{}, tenantID string) ([]string, error) {
return nil, nil
}
// UpdateMetadata is a no-op stub.
func (f fakeChatDocEngine) UpdateMetadata(ctx context.Context, docID string, datasetID string, metaFields map[string]interface{}, tenantID string) error {
return nil
}
// DeleteMetadata is a no-op stub that reports zero deletions.
func (f fakeChatDocEngine) DeleteMetadata(ctx context.Context, condition map[string]interface{}, tenantID string) (int64, error) {
return 0, nil
}
// DeleteMetadataKeys is a no-op stub.
func (f fakeChatDocEngine) DeleteMetadataKeys(ctx context.Context, docID string, datasetID string, keys []string, tenantID string) error {
return nil
}
// DropMetadataStore is a no-op stub.
func (f fakeChatDocEngine) DropMetadataStore(ctx context.Context, tenantID string) error {
return nil
}
// MetadataStoreExists always reports true.
func (f fakeChatDocEngine) MetadataStoreExists(ctx context.Context, tenantID string) (bool, error) {
return true, nil
}
// SearchMetadata is a stub that returns a nil result and no error.
func (f fakeChatDocEngine) SearchMetadata(ctx context.Context, req *types.SearchMetadataRequest) (*types.SearchMetadataResult, error) {
return nil, nil
}
func (f fakeChatDocEngine) IndexDocument(ctx context.Context, indexName, docID string, doc interface{}) error {
return nil
}
func (f fakeChatDocEngine) DeleteDocument(ctx context.Context, indexName, docID string) error {
return nil
}
func (f fakeChatDocEngine) BulkIndex(ctx context.Context, indexName string, docs []interface{}) (interface{}, error) {
return nil, nil
}
func (f fakeChatDocEngine) GetFields(chunks []map[string]interface{}, fields []string) map[string]map[string]interface{} {
return nil
}
func (f fakeChatDocEngine) GetAggregation(chunks []map[string]interface{}, fieldName string) []map[string]interface{} {
return nil
}
func (f fakeChatDocEngine) GetHighlight(chunks []map[string]interface{}, keywords []string, fieldName string) map[string]string {
return nil
}
func (f fakeChatDocEngine) GetChunkIDs(chunks []map[string]interface{}) []string {
return nil
}
func (f fakeChatDocEngine) KNNScores(ctx context.Context, chunks []map[string]interface{}, queryVector []float64, topK int) (map[string]interface{}, error) {
return nil, nil
}
func (f fakeChatDocEngine) GetScores(searchResult map[string]interface{}) map[string]float64 {
return nil
}
func (f fakeChatDocEngine) FilterDocIdsByMetaPushdown(ctx context.Context, kbIDs []string, conditions []map[string]interface{}, logic string) []string {
return nil
}
func (f fakeChatDocEngine) Ping(ctx context.Context) error {
return nil
}
func (f fakeChatDocEngine) Close() error {
return nil
}
func (f fakeChatDocEngine) GetType() string {
return "fake"
}
// fakeChatRetrievalService is a retrieval stub that records the request it
// receives and answers with a preconfigured result.
type fakeChatRetrievalService struct {
	// req is the most recent request passed to Retrieval; it stays nil until
	// Retrieval has been called.
	req *nlp.RetrievalRequest
	// result is returned by every Retrieval call.
	result *nlp.RetrievalResult
}

// Retrieval stores req for later inspection and returns the configured result
// with a nil error.
func (r *fakeChatRetrievalService) Retrieval(ctx context.Context, req *nlp.RetrievalRequest) (*nlp.RetrievalResult, error) {
	r.req = req
	return r.result, nil
}
// fakeChatModelProvider is a model-provider stub that builds every model on
// top of the same fake driver.
type fakeChatModelProvider struct {
	// driver backs every model returned by this provider.
	driver *fakeChatModelDriver
}

// GetChatModel wraps the fake driver in a chat model named after
// compositeModelName. tenantID is ignored.
func (f fakeChatModelProvider) GetChatModel(tenantID, compositeModelName string) (*modelModule.ChatModel, error) {
	name := compositeModelName
	chatModel := modelModule.NewChatModel(f.driver, &name, &modelModule.APIConfig{})
	return chatModel, nil
}

// GetEmbeddingModel wraps the fake driver in an embedding model named after
// compositeModelName, passing 512 as the final constructor argument.
func (f fakeChatModelProvider) GetEmbeddingModel(tenantID, compositeModelName string) (*modelModule.EmbeddingModel, error) {
	name := compositeModelName
	embeddingModel := modelModule.NewEmbeddingModel(f.driver, &name, &modelModule.APIConfig{}, 512)
	return embeddingModel, nil
}

// GetRerankModel wraps the fake driver in a rerank model named after
// compositeModelName. tenantID is ignored.
func (f fakeChatModelProvider) GetRerankModel(tenantID, compositeModelName string) (*modelModule.RerankModel, error) {
	name := compositeModelName
	rerankModel := modelModule.NewRerankModel(f.driver, &name, &modelModule.APIConfig{})
	return rerankModel, nil
}

// GetModelConfigFromProviderInstance echoes modelName back together with the
// fake driver, an empty API config and a zero integer.
func (f fakeChatModelProvider) GetModelConfigFromProviderInstance(tenantID string, modelType entity.ModelType, modelName string) (modelModule.ModelDriver, string, *modelModule.APIConfig, int, error) {
	return f.driver, modelName, &modelModule.APIConfig{}, 0, nil
}

// GetTenantDefaultModelByType always names the default model
// "default@factory", whatever the tenant or model type.
func (f fakeChatModelProvider) GetTenantDefaultModelByType(tenantID string, modelType entity.ModelType) (modelModule.ModelDriver, string, *modelModule.APIConfig, int, error) {
	return f.driver, "default@factory", &modelModule.APIConfig{}, 0, nil
}
// fakeChatModelDriver is a model-driver stub. The two chat methods record the
// messages they receive and produce canned answers; every other capability is
// a no-op returning zero values.
type fakeChatModelDriver struct {
	// messages holds the messages from the most recent chat call.
	messages []modelModule.Message
}

// NewInstance returns the receiver itself, so state recorded by later calls
// stays visible to the test that created the driver.
func (f *fakeChatModelDriver) NewInstance(baseURL map[string]string) modelModule.ModelDriver {
	return f
}

// Name identifies this driver as "fake".
func (*fakeChatModelDriver) Name() string {
	return "fake"
}

// ChatWithMessages records messages and answers "answer from knowledge".
func (f *fakeChatModelDriver) ChatWithMessages(modelName string, messages []modelModule.Message, apiConfig *modelModule.APIConfig, chatModelConfig *modelModule.ChatConfig) (*modelModule.ChatResponse, error) {
	f.messages = messages
	reply := "answer from knowledge"
	response := &modelModule.ChatResponse{Answer: &reply}
	return response, nil
}

// ChatStreamlyWithSender records messages and emits a single chunk,
// "stream answer from knowledge", through sender, returning sender's error.
func (f *fakeChatModelDriver) ChatStreamlyWithSender(modelName string, messages []modelModule.Message, apiConfig *modelModule.APIConfig, modelConfig *modelModule.ChatConfig, sender func(*string, *string) error) error {
	f.messages = messages
	chunk := "stream answer from knowledge"
	return sender(&chunk, nil)
}

// Embed returns no embeddings and no error.
func (*fakeChatModelDriver) Embed(modelName *string, texts []string, apiConfig *modelModule.APIConfig, embeddingConfig *modelModule.EmbeddingConfig) ([]modelModule.EmbeddingData, error) {
	return nil, nil
}

// Rerank returns no response and no error.
func (*fakeChatModelDriver) Rerank(modelName *string, query string, documents []string, apiConfig *modelModule.APIConfig, rerankConfig *modelModule.RerankConfig) (*modelModule.RerankResponse, error) {
	return nil, nil
}

// TranscribeAudio returns no response and no error.
func (*fakeChatModelDriver) TranscribeAudio(modelName *string, file *string, apiConfig *modelModule.APIConfig, asrConfig *modelModule.ASRConfig) (*modelModule.ASRResponse, error) {
	return nil, nil
}

// TranscribeAudioWithSender does nothing; sender is never invoked.
func (*fakeChatModelDriver) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *modelModule.APIConfig, asrConfig *modelModule.ASRConfig, sender func(*string, *string) error) error {
	return nil
}

// AudioSpeech returns no response and no error.
func (*fakeChatModelDriver) AudioSpeech(modelName *string, audioContent *string, apiConfig *modelModule.APIConfig, ttsConfig *modelModule.TTSConfig) (*modelModule.TTSResponse, error) {
	return nil, nil
}

// AudioSpeechWithSender does nothing; sender is never invoked.
func (*fakeChatModelDriver) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *modelModule.APIConfig, ttsConfig *modelModule.TTSConfig, sender func(*string, *string) error) error {
	return nil
}

// OCRFile returns no response and no error.
func (*fakeChatModelDriver) OCRFile(modelName *string, content []byte, url *string, apiConfig *modelModule.APIConfig, ocrConfig *modelModule.OCRConfig) (*modelModule.OCRFileResponse, error) {
	return nil, nil
}

// ParseFile returns no response and no error.
func (*fakeChatModelDriver) ParseFile(modelName *string, content []byte, url *string, apiConfig *modelModule.APIConfig, parseFileConfig *modelModule.ParseFileConfig) (*modelModule.ParseFileResponse, error) {
	return nil, nil
}

// ListModels returns no models and no error.
func (*fakeChatModelDriver) ListModels(apiConfig *modelModule.APIConfig) ([]modelModule.ListModelResponse, error) {
	return nil, nil
}

// Balance returns no balance information and no error.
func (*fakeChatModelDriver) Balance(apiConfig *modelModule.APIConfig) (map[string]interface{}, error) {
	return nil, nil
}

// CheckConnection always reports success.
func (*fakeChatModelDriver) CheckConnection(apiConfig *modelModule.APIConfig) error {
	return nil
}

// ListTasks returns no tasks and no error.
func (*fakeChatModelDriver) ListTasks(apiConfig *modelModule.APIConfig) ([]modelModule.ListTaskStatus, error) {
	return nil, nil
}

// ShowTask returns no task and no error.
func (*fakeChatModelDriver) ShowTask(taskID string, apiConfig *modelModule.APIConfig) (*modelModule.TaskResponse, error) {
	return nil, nil
}
// TestAsyncChatUsesRetrievedKnowledgeForKBDialog checks the non-streaming chat
// path for a dialog bound to a knowledge base: retrieval is invoked with the
// user question and the dialog's TopN/TopK, the retrieved text reaches the
// chat model prompt, and the returned reference carries the chunk without its
// vector.
func TestAsyncChatUsesRetrievedKnowledgeForKBDialog(t *testing.T) {
	const question = "Where are references stored?"

	driver := &fakeChatModelDriver{}
	retrievedChunk := map[string]interface{}{
		"chunk_id":            "chunk-1",
		"content_with_weight": "RAGFlow stores conversation references alongside the session.",
		"doc_id":              "doc-1",
		"docnm_kwd":           "manual.md",
		"vector":              []float64{0.1, 0.2},
	}
	retrieval := &fakeChatRetrievalService{
		result: &nlp.RetrievalResult{
			Chunks: []map[string]interface{}{retrievedChunk},
			DocAggs: []map[string]interface{}{
				{"doc_id": "doc-1", "doc_name": "manual.md", "count": 1},
			},
		},
	}
	kbStore := fakeChatKBStore{kbs: []*entity.Knowledgebase{
		{ID: "kb-1", TenantID: "tenant-1", Name: "Manual", EmbdID: "embed@factory"},
	}}
	svc := &ChatSessionService{
		kbDAO:            kbStore,
		modelProviderSvc: fakeChatModelProvider{driver: driver},
		metadataSvc:      fakeChatMetadataService{},
		retrievalSvc:     retrieval,
	}

	reference := []interface{}{map[string]interface{}{"chunks": []interface{}{}, "doc_aggs": []interface{}{}}}
	sessionMessage, err := json.Marshal(map[string]interface{}{"messages": []interface{}{}})
	if err != nil {
		t.Fatalf("failed to marshal session message: %v", err)
	}
	session := &entity.ChatSession{ID: "session-1", Message: sessionMessage}
	dialog := &entity.Chat{
		ID:                     "dialog-1",
		TenantID:               "tenant-1",
		LLMID:                  "chat@factory",
		PromptConfig:           entity.JSONMap{"system": "You are helpful."},
		LLMSetting:             entity.JSONMap{},
		KBIDs:                  entity.JSONSlice{"kb-1"},
		TopN:                   3,
		TopK:                   32,
		SimilarityThreshold:    0.2,
		VectorSimilarityWeight: 0.3,
	}
	messages := []map[string]interface{}{
		{"role": "user", "content": question},
	}

	result, err := svc.asyncChat("user-1", dialog, session, messages, nil, "message-1", reference, false)
	if err != nil {
		t.Fatalf("asyncChat returned error: %v", err)
	}

	// The retrieval request must mirror the question and the dialog paging.
	req := retrieval.req
	if req == nil {
		t.Fatal("expected retrieval service to be called")
	}
	if req.Question != question {
		t.Fatalf("unexpected retrieval question: %q", req.Question)
	}
	if req.PageSize != 3 || req.Top == nil || *req.Top != 32 {
		t.Fatalf("unexpected retrieval paging: page_size=%d top=%v", req.PageSize, req.Top)
	}

	// The last message sent to the model must contain the retrieved text.
	if len(driver.messages) == 0 {
		t.Fatal("expected chat model to receive messages")
	}
	finalMessage := driver.messages[len(driver.messages)-1]
	prompt, isString := finalMessage.Content.(string)
	if !isString {
		t.Fatalf("expected string content, got %T", finalMessage.Content)
	}
	if !strings.Contains(prompt, "RAGFlow stores conversation references") {
		t.Fatalf("expected retrieved content in prompt, got %q", prompt)
	}

	// The reference exposes exactly one chunk, stripped of its vector.
	ref, isMap := result["reference"].(map[string]interface{})
	if !isMap {
		t.Fatalf("expected reference map, got %T", result["reference"])
	}
	refChunks, isList := ref["chunks"].([]interface{})
	if !isList || len(refChunks) != 1 {
		t.Fatalf("expected one reference chunk, got %#v", ref["chunks"])
	}
	refChunk, isChunkMap := refChunks[0].(map[string]interface{})
	if !isChunkMap {
		t.Fatalf("expected chunk map, got %T", refChunks[0])
	}
	if _, hasVector := refChunk["vector"]; hasVector {
		t.Fatal("reference chunk should not expose vector")
	}

	if result["answer"] != "answer from knowledge" {
		t.Fatalf("unexpected answer: %#v", result["answer"])
	}
}
// TestAsyncChatPropagatesRetrievalErrors checks that a failure in the
// retrieval service surfaces from asyncChat as an error mentioning
// "retrieval search failed".
func TestAsyncChatPropagatesRetrievalErrors(t *testing.T) {
	retrieval := &failingChatRetrievalService{err: errors.New("search unavailable")}
	kbStore := fakeChatKBStore{kbs: []*entity.Knowledgebase{
		{ID: "kb-1", TenantID: "tenant-1", Name: "Manual", EmbdID: "embed@factory"},
	}}
	svc := &ChatSessionService{
		kbDAO:            kbStore,
		modelProviderSvc: fakeChatModelProvider{driver: &fakeChatModelDriver{}},
		metadataSvc:      fakeChatMetadataService{},
		retrievalSvc:     retrieval,
	}

	dialog := &entity.Chat{
		ID:                     "dialog-1",
		TenantID:               "tenant-1",
		LLMID:                  "chat@factory",
		PromptConfig:           entity.JSONMap{},
		LLMSetting:             entity.JSONMap{},
		KBIDs:                  entity.JSONSlice{"kb-1"},
		TopN:                   3,
		TopK:                   32,
		SimilarityThreshold:    0.2,
		VectorSimilarityWeight: 0.3,
	}
	session := &entity.ChatSession{ID: "session-1"}
	messages := []map[string]interface{}{
		{"role": "user", "content": "question"},
	}
	reference := []interface{}{map[string]interface{}{"chunks": []interface{}{}, "doc_aggs": []interface{}{}}}

	_, err := svc.asyncChat("user-1", dialog, session, messages, nil, "message-1", reference, false)
	if err != nil && strings.Contains(err.Error(), "retrieval search failed") {
		return
	}
	t.Fatalf("expected retrieval error, got %v", err)
}
// TestMessagesWithRetrievedKnowledgeFillsSystemPlaceholder checks that when
// the system prompt contains a {knowledge} placeholder, retrieved text is
// substituted into the returned dialog's system prompt while the user message
// and the caller's original dialog stay untouched.
func TestMessagesWithRetrievedKnowledgeFillsSystemPlaceholder(t *testing.T) {
	retrieval := &fakeChatRetrievalService{
		result: &nlp.RetrievalResult{
			Chunks: []map[string]interface{}{
				{"content_with_weight": "Knowledge inserted into the system prompt."},
			},
		},
	}
	svc := &ChatSessionService{
		kbDAO: fakeChatKBStore{kbs: []*entity.Knowledgebase{
			{ID: "kb-1", TenantID: "tenant-1", Name: "Manual", EmbdID: "embed@factory"},
		}},
		modelProviderSvc: fakeChatModelProvider{driver: &fakeChatModelDriver{}},
		metadataSvc:      fakeChatMetadataService{},
		retrievalSvc:     retrieval,
	}
	dialog := &entity.Chat{
		ID:           "dialog-1",
		TenantID:     "tenant-1",
		PromptConfig: entity.JSONMap{"system": "Answer from this context: {knowledge}"},
		KBIDs:        entity.JSONSlice{"kb-1"},
		TopN:         3,
		TopK:         32,
	}
	messages := []map[string]interface{}{
		{"role": "user", "content": "What context is available?"},
	}

	got, ragDialog, emptyResponse, err := svc.messagesWithRetrievedKnowledge(context.Background(), "user-1", dialog, messages, []interface{}{
		map[string]interface{}{"chunks": []interface{}{}, "doc_aggs": []interface{}{}},
	})
	if err != nil {
		t.Fatalf("messagesWithRetrievedKnowledge returned error: %v", err)
	}
	if emptyResponse != nil {
		t.Fatalf("expected no empty response, got %q", *emptyResponse)
	}
	// Guard the index below so an empty result fails the test with a clear
	// message instead of an index-out-of-range panic.
	if len(got) == 0 {
		t.Fatal("expected at least one returned message")
	}
	if got[0]["content"] != "What context is available?" {
		t.Fatalf("expected user content to stay unchanged, got %q", got[0]["content"])
	}

	// The caller's dialog must still hold the raw placeholder.
	originalPrompt, _ := dialog.PromptConfig["system"].(string)
	if !strings.Contains(originalPrompt, "{knowledge}") {
		t.Fatalf("expected original dialog prompt to remain unchanged, got %q", originalPrompt)
	}

	// The returned dialog must hold the substituted knowledge.
	systemPrompt, _ := ragDialog.PromptConfig["system"].(string)
	if strings.Contains(systemPrompt, "{knowledge}") {
		t.Fatalf("expected knowledge placeholder to be replaced, got %q", systemPrompt)
	}
	if !strings.Contains(systemPrompt, "Knowledge inserted into the system prompt.") {
		t.Fatalf("expected retrieved knowledge in system prompt, got %q", systemPrompt)
	}
}
// TestAsyncChatReturnsEmptyResponseWhenRetrievalHasNoKnowledge checks that a
// dialog with a configured "empty_response" returns that text verbatim, and
// never calls the chat model, when retrieval yields no chunks.
func TestAsyncChatReturnsEmptyResponseWhenRetrievalHasNoKnowledge(t *testing.T) {
	driver := &fakeChatModelDriver{}
	kbStore := fakeChatKBStore{kbs: []*entity.Knowledgebase{
		{ID: "kb-1", TenantID: "tenant-1", Name: "Manual", EmbdID: "embed@factory"},
	}}
	svc := &ChatSessionService{
		kbDAO:            kbStore,
		modelProviderSvc: fakeChatModelProvider{driver: driver},
		metadataSvc:      fakeChatMetadataService{},
		retrievalSvc:     &fakeChatRetrievalService{result: &nlp.RetrievalResult{}},
	}

	reference := []interface{}{map[string]interface{}{"chunks": []interface{}{}, "doc_aggs": []interface{}{}}}
	sessionMessage, err := json.Marshal(map[string]interface{}{"messages": []interface{}{}})
	if err != nil {
		t.Fatalf("failed to marshal session message: %v", err)
	}
	dialog := &entity.Chat{
		ID:           "dialog-1",
		TenantID:     "tenant-1",
		LLMID:        "chat@factory",
		PromptConfig: entity.JSONMap{"empty_response": "No relevant content."},
		LLMSetting:   entity.JSONMap{},
		KBIDs:        entity.JSONSlice{"kb-1"},
		TopN:         3,
		TopK:         32,
	}
	session := &entity.ChatSession{ID: "session-1", Message: sessionMessage}
	messages := []map[string]interface{}{
		{"role": "user", "content": "question"},
	}

	result, err := svc.asyncChat("user-1", dialog, session, messages, nil, "message-1", reference, false)
	if err != nil {
		t.Fatalf("asyncChat returned error: %v", err)
	}
	if answer := result["answer"]; answer != "No relevant content." {
		t.Fatalf("unexpected empty response answer: %#v", answer)
	}
	if len(driver.messages) != 0 {
		t.Fatal("chat model should not be called when empty_response is returned")
	}
}
// TestMessagesWithRetrievedKnowledgeAppliesMetadataFilter checks that a manual
// metadata filter on the dialog narrows the retrieval request to the document
// IDs matching the filter.
func TestMessagesWithRetrievedKnowledgeAppliesMetadataFilter(t *testing.T) {
	retrieval := &fakeChatRetrievalService{result: &nlp.RetrievalResult{}}
	kbStore := fakeChatKBStore{kbs: []*entity.Knowledgebase{
		{ID: "kb-1", TenantID: "tenant-1", Name: "Manual", EmbdID: "embed@factory"},
	}}
	svc := &ChatSessionService{
		kbDAO:            kbStore,
		modelProviderSvc: fakeChatModelProvider{driver: &fakeChatModelDriver{}},
		metadataSvc:      fakeChatMetadataService{},
		retrievalSvc:     retrieval,
	}

	// category = policy matches only "doc-policy" in the fake metadata.
	filter := entity.JSONMap{
		"method": "manual",
		"manual": []interface{}{
			map[string]interface{}{"key": "category", "op": "=", "value": "policy"},
		},
		"logic": "and",
	}
	dialog := &entity.Chat{
		ID:             "dialog-1",
		TenantID:       "tenant-1",
		PromptConfig:   entity.JSONMap{},
		MetaDataFilter: &filter,
		KBIDs:          entity.JSONSlice{"kb-1"},
		TopN:           3,
		TopK:           32,
	}
	messages := []map[string]interface{}{
		{"role": "user", "content": "question"},
	}
	reference := []interface{}{map[string]interface{}{"chunks": []interface{}{}, "doc_aggs": []interface{}{}}}

	_, _, _, err := svc.messagesWithRetrievedKnowledge(context.Background(), "user-1", dialog, messages, reference)
	if err != nil {
		t.Fatalf("messagesWithRetrievedKnowledge returned error: %v", err)
	}
	if retrieval.req == nil {
		t.Fatal("expected retrieval to be called")
	}
	docIDs := retrieval.req.DocIDs
	if len(docIDs) != 1 || docIDs[0] != "doc-policy" {
		t.Fatalf("expected metadata-filtered doc id, got %#v", docIDs)
	}
}
// TestMessagesWithRetrievedKnowledgeIntersectsDocIDsWithMetadataFilter checks
// that when a message carries explicit doc_ids and the dialog has a metadata
// filter, retrieval is scoped to the documents present in both sets.
func TestMessagesWithRetrievedKnowledgeIntersectsDocIDsWithMetadataFilter(t *testing.T) {
	retrieval := &fakeChatRetrievalService{result: &nlp.RetrievalResult{}}
	svc := &ChatSessionService{
		kbDAO: fakeChatKBStore{kbs: []*entity.Knowledgebase{
			{ID: "kb-1", TenantID: "tenant-1", Name: "Manual", EmbdID: "embed@factory"},
		}},
		modelProviderSvc: fakeChatModelProvider{driver: &fakeChatModelDriver{}},
		metadataSvc:      fakeChatMetadataService{},
		retrievalSvc:     retrieval,
	}
	// category = policy matches only "doc-policy" in the fake metadata.
	filter := entity.JSONMap{
		"method": "manual",
		"manual": []interface{}{
			map[string]interface{}{"key": "category", "op": "=", "value": "policy"},
		},
		"logic": "and",
	}
	_, _, _, err := svc.messagesWithRetrievedKnowledge(context.Background(), "user-1", &entity.Chat{
		ID:             "dialog-1",
		TenantID:       "tenant-1",
		PromptConfig:   entity.JSONMap{},
		MetaDataFilter: &filter,
		KBIDs:          entity.JSONSlice{"kb-1"},
		TopN:           3,
		TopK:           32,
	}, []map[string]interface{}{
		{"role": "user", "content": "question", "doc_ids": []interface{}{"doc-explicit", "doc-policy"}},
	}, []interface{}{map[string]interface{}{"chunks": []interface{}{}, "doc_aggs": []interface{}{}}})
	if err != nil {
		t.Fatalf("messagesWithRetrievedKnowledge returned error: %v", err)
	}
	// Without this guard a skipped retrieval call would crash the test with a
	// nil-pointer dereference instead of reporting a failure.
	if retrieval.req == nil {
		t.Fatal("expected retrieval to be called")
	}
	if len(retrieval.req.DocIDs) != 1 || retrieval.req.DocIDs[0] != "doc-policy" {
		t.Fatalf("expected metadata and message doc_ids intersection, got %#v", retrieval.req.DocIDs)
	}
}
// TestMessagesWithRetrievedKnowledgeNoMetadataIntersectionUsesSentinel checks
// that when the message doc_ids and the metadata filter share no document,
// retrieval is scoped to NoMatchDocIDSentinel rather than left unrestricted.
func TestMessagesWithRetrievedKnowledgeNoMetadataIntersectionUsesSentinel(t *testing.T) {
	retrieval := &fakeChatRetrievalService{result: &nlp.RetrievalResult{}}
	svc := &ChatSessionService{
		kbDAO: fakeChatKBStore{kbs: []*entity.Knowledgebase{
			{ID: "kb-1", TenantID: "tenant-1", Name: "Manual", EmbdID: "embed@factory"},
		}},
		modelProviderSvc: fakeChatModelProvider{driver: &fakeChatModelDriver{}},
		metadataSvc:      fakeChatMetadataService{},
		retrievalSvc:     retrieval,
	}
	// category = policy matches only "doc-policy", which the message does not list.
	filter := entity.JSONMap{
		"method": "manual",
		"manual": []interface{}{
			map[string]interface{}{"key": "category", "op": "=", "value": "policy"},
		},
		"logic": "and",
	}
	_, _, _, err := svc.messagesWithRetrievedKnowledge(context.Background(), "user-1", &entity.Chat{
		ID:             "dialog-1",
		TenantID:       "tenant-1",
		PromptConfig:   entity.JSONMap{},
		MetaDataFilter: &filter,
		KBIDs:          entity.JSONSlice{"kb-1"},
		TopN:           3,
		TopK:           32,
	}, []map[string]interface{}{
		{"role": "user", "content": "question", "doc_ids": []interface{}{"doc-explicit"}},
	}, []interface{}{map[string]interface{}{"chunks": []interface{}{}, "doc_aggs": []interface{}{}}})
	if err != nil {
		t.Fatalf("messagesWithRetrievedKnowledge returned error: %v", err)
	}
	// Without this guard a skipped retrieval call would crash the test with a
	// nil-pointer dereference instead of reporting a failure.
	if retrieval.req == nil {
		t.Fatal("expected retrieval to be called")
	}
	if len(retrieval.req.DocIDs) != 1 || retrieval.req.DocIDs[0] != NoMatchDocIDSentinel {
		t.Fatalf("expected empty metadata/doc_ids intersection sentinel, got %#v", retrieval.req.DocIDs)
	}
}
// TestMessagesWithRetrievedKnowledgePreservesEmptyMetadataFilterMatches checks
// that a dialog using the "auto" metadata filter method ends up scoping
// retrieval to NoMatchDocIDSentinel in this fixture, i.e. an empty filter
// match is preserved rather than widened to all documents.
func TestMessagesWithRetrievedKnowledgePreservesEmptyMetadataFilterMatches(t *testing.T) {
	retrieval := &fakeChatRetrievalService{result: &nlp.RetrievalResult{}}
	svc := &ChatSessionService{
		kbDAO: fakeChatKBStore{kbs: []*entity.Knowledgebase{
			{ID: "kb-1", TenantID: "tenant-1", Name: "Manual", EmbdID: "embed@factory"},
		}},
		modelProviderSvc: fakeChatModelProvider{driver: &fakeChatModelDriver{}},
		metadataSvc:      fakeChatMetadataService{},
		retrievalSvc:     retrieval,
	}
	filter := entity.JSONMap{"method": "auto"}
	_, _, _, err := svc.messagesWithRetrievedKnowledge(context.Background(), "user-1", &entity.Chat{
		ID:             "dialog-1",
		TenantID:       "tenant-1",
		LLMID:          "chat@factory",
		PromptConfig:   entity.JSONMap{},
		MetaDataFilter: &filter,
		KBIDs:          entity.JSONSlice{"kb-1"},
		TopN:           3,
		TopK:           32,
	}, []map[string]interface{}{
		{"role": "user", "content": "question"},
	}, []interface{}{map[string]interface{}{"chunks": []interface{}{}, "doc_aggs": []interface{}{}}})
	if err != nil {
		t.Fatalf("messagesWithRetrievedKnowledge returned error: %v", err)
	}
	// Without this guard a skipped retrieval call would crash the test with a
	// nil-pointer dereference instead of reporting a failure.
	if retrieval.req == nil {
		t.Fatal("expected retrieval to be called")
	}
	if len(retrieval.req.DocIDs) != 1 || retrieval.req.DocIDs[0] != NoMatchDocIDSentinel {
		t.Fatalf("expected empty metadata filter sentinel, got %#v", retrieval.req.DocIDs)
	}
}
// TestMessagesWithRetrievedKnowledgeFailsClosedWhenMetadataUnavailable checks
// that when the metadata needed to evaluate a filter cannot be loaded, the
// call returns an error and retrieval is never attempted.
func TestMessagesWithRetrievedKnowledgeFailsClosedWhenMetadataUnavailable(t *testing.T) {
	retrieval := &fakeChatRetrievalService{result: &nlp.RetrievalResult{}}
	kbStore := fakeChatKBStore{kbs: []*entity.Knowledgebase{
		{ID: "kb-1", TenantID: "tenant-1", Name: "Manual", EmbdID: "embed@factory"},
	}}
	svc := &ChatSessionService{
		kbDAO:            kbStore,
		modelProviderSvc: fakeChatModelProvider{driver: &fakeChatModelDriver{}},
		metadataSvc:      failingChatMetadataService{},
		retrievalSvc:     retrieval,
	}

	filter := entity.JSONMap{
		"method": "manual",
		"manual": []interface{}{
			map[string]interface{}{"key": "category", "op": "=", "value": "policy"},
		},
		"logic": "and",
	}
	dialog := &entity.Chat{
		ID:             "dialog-1",
		TenantID:       "tenant-1",
		PromptConfig:   entity.JSONMap{},
		MetaDataFilter: &filter,
		KBIDs:          entity.JSONSlice{"kb-1"},
		TopN:           3,
		TopK:           32,
	}
	messages := []map[string]interface{}{
		{"role": "user", "content": "question", "doc_ids": []interface{}{"doc-explicit"}},
	}
	reference := []interface{}{map[string]interface{}{"chunks": []interface{}{}, "doc_aggs": []interface{}{}}}

	_, _, _, err := svc.messagesWithRetrievedKnowledge(context.Background(), "user-1", dialog, messages, reference)
	metadataFailed := err != nil && strings.Contains(err.Error(), "flattened metadata")
	if !metadataFailed {
		t.Fatalf("expected metadata filter error, got %v", err)
	}
	if retrieval.req != nil {
		t.Fatal("retrieval should not run when metadata filtering cannot be evaluated")
	}
}
// TestMessagesWithRetrievedKnowledgeExpandsChildChunks checks that a retrieved
// child chunk carrying a "mom_id" is replaced by its parent chunk's content
// (as served by the document engine) in the prompt handed to the model.
func TestMessagesWithRetrievedKnowledgeExpandsChildChunks(t *testing.T) {
	retrieval := &fakeChatRetrievalService{result: &nlp.RetrievalResult{
		Chunks: []map[string]interface{}{
			{
				"chunk_id":            "child-1",
				"mom_id":              "parent-1",
				"kb_id":               "kb-1",
				"doc_id":              "doc-1",
				"docnm_kwd":           "doc.md",
				"content_ltks":        "child tokens",
				"content_with_weight": "child-only passage",
				"similarity":          0.8,
			},
		},
	}}
	svc := &ChatSessionService{
		kbDAO: fakeChatKBStore{kbs: []*entity.Knowledgebase{
			{ID: "kb-1", TenantID: "tenant-1", Name: "Manual", EmbdID: "embed@factory"},
		}},
		// The fake engine answers every GetChunk call with the parent chunk.
		docEngine: fakeChatDocEngine{chunk: map[string]interface{}{
			"doc_id":              "doc-1",
			"docnm_kwd":           "doc.md",
			"kb_id":               "kb-1",
			"content_with_weight": "parent passage with surrounding context",
			"position_int":        []interface{}{1},
		}},
		modelProviderSvc: fakeChatModelProvider{driver: &fakeChatModelDriver{}},
		metadataSvc:      fakeChatMetadataService{},
		retrievalSvc:     retrieval,
	}
	ragMessages, _, _, err := svc.messagesWithRetrievedKnowledge(context.Background(), "user-1", &entity.Chat{
		ID:           "dialog-1",
		TenantID:     "tenant-1",
		PromptConfig: entity.JSONMap{},
		KBIDs:        entity.JSONSlice{"kb-1"},
		TopN:         3,
		TopK:         32,
	}, []map[string]interface{}{
		{"role": "user", "content": "question"},
	}, []interface{}{map[string]interface{}{"chunks": []interface{}{}, "doc_aggs": []interface{}{}}})
	if err != nil {
		t.Fatalf("messagesWithRetrievedKnowledge returned error: %v", err)
	}
	// Guard the index below so an empty result fails the test with a clear
	// message instead of an index-out-of-range panic.
	if len(ragMessages) == 0 {
		t.Fatal("expected at least one returned message")
	}
	content, ok := ragMessages[0]["content"].(string)
	if !ok {
		t.Fatalf("expected string content, got %T", ragMessages[0]["content"])
	}
	if !strings.Contains(content, "parent passage with surrounding context") {
		t.Fatalf("expected expanded parent content in prompt, got %q", content)
	}
	if strings.Contains(content, "child-only passage") {
		t.Fatalf("expected child content to be replaced by expanded parent content, got %q", content)
	}
}
// TestMessagesWithRetrievedKnowledgeRejectsCrossTenantKnowledgebase checks
// that a knowledge base owned by another tenant and marked inaccessible
// produces a "not authorized" error before any retrieval happens.
func TestMessagesWithRetrievedKnowledgeRejectsCrossTenantKnowledgebase(t *testing.T) {
	retrieval := &fakeChatRetrievalService{result: &nlp.RetrievalResult{}}
	kbStore := fakeChatKBStore{
		kbs: []*entity.Knowledgebase{
			{ID: "kb-1", TenantID: "tenant-2", Name: "Manual", EmbdID: "embed@factory"},
		},
		accessible: map[string]bool{"kb-1": false},
	}
	svc := &ChatSessionService{
		kbDAO:            kbStore,
		modelProviderSvc: fakeChatModelProvider{driver: &fakeChatModelDriver{}},
		metadataSvc:      fakeChatMetadataService{},
		retrievalSvc:     retrieval,
	}

	dialog := &entity.Chat{
		ID:           "dialog-1",
		TenantID:     "tenant-1",
		PromptConfig: entity.JSONMap{},
		KBIDs:        entity.JSONSlice{"kb-1"},
		TopN:         3,
		TopK:         32,
	}
	messages := []map[string]interface{}{
		{"role": "user", "content": "question"},
	}
	reference := []interface{}{map[string]interface{}{"chunks": []interface{}{}, "doc_aggs": []interface{}{}}}

	_, _, _, err := svc.messagesWithRetrievedKnowledge(context.Background(), "user-1", dialog, messages, reference)
	rejected := err != nil && strings.Contains(err.Error(), "not authorized")
	if !rejected {
		t.Fatalf("expected cross-tenant authorization error, got %v", err)
	}
	if retrieval.req != nil {
		t.Fatal("retrieval should not be called for an unauthorized knowledge base")
	}
}
// TestMessagesWithRetrievedKnowledgeAllowsAccessibleSharedKnowledgebase checks
// that a knowledge base owned by another tenant but accessible to the user is
// accepted, and that retrieval is issued against the owning tenant.
func TestMessagesWithRetrievedKnowledgeAllowsAccessibleSharedKnowledgebase(t *testing.T) {
	retrieval := &fakeChatRetrievalService{result: &nlp.RetrievalResult{}}
	kbStore := fakeChatKBStore{
		kbs: []*entity.Knowledgebase{
			{ID: "kb-1", TenantID: "tenant-2", Name: "Shared Manual", EmbdID: "embed@factory"},
		},
		accessible: map[string]bool{"kb-1": true},
	}
	svc := &ChatSessionService{
		kbDAO:            kbStore,
		modelProviderSvc: fakeChatModelProvider{driver: &fakeChatModelDriver{}},
		metadataSvc:      fakeChatMetadataService{},
		retrievalSvc:     retrieval,
	}

	dialog := &entity.Chat{
		ID:           "dialog-1",
		TenantID:     "tenant-1",
		PromptConfig: entity.JSONMap{},
		KBIDs:        entity.JSONSlice{"kb-1"},
		TopN:         3,
		TopK:         32,
	}
	messages := []map[string]interface{}{
		{"role": "user", "content": "question"},
	}
	reference := []interface{}{map[string]interface{}{"chunks": []interface{}{}, "doc_aggs": []interface{}{}}}

	_, _, _, err := svc.messagesWithRetrievedKnowledge(context.Background(), "user-1", dialog, messages, reference)
	if err != nil {
		t.Fatalf("messagesWithRetrievedKnowledge returned error: %v", err)
	}
	// The request must target the KB owner's tenant, not the dialog's tenant.
	usesOwnerTenant := retrieval.req != nil &&
		len(retrieval.req.TenantIDs) == 1 &&
		retrieval.req.TenantIDs[0] == "tenant-2"
	if !usesOwnerTenant {
		t.Fatalf("expected retrieval to use shared KB tenant, got %#v", retrieval.req)
	}
}
// TestMessagesWithRetrievedKnowledgeRejectsMixedEmbeddingModels checks that
// selecting knowledge bases with different embedding models fails with a
// "same embedding model" error before any retrieval happens.
func TestMessagesWithRetrievedKnowledgeRejectsMixedEmbeddingModels(t *testing.T) {
	retrieval := &fakeChatRetrievalService{result: &nlp.RetrievalResult{}}
	kbStore := fakeChatKBStore{kbs: []*entity.Knowledgebase{
		{ID: "kb-1", TenantID: "tenant-1", Name: "Manual", EmbdID: "embed-a@factory"},
		{ID: "kb-2", TenantID: "tenant-1", Name: "FAQ", EmbdID: "embed-b@factory"},
	}}
	svc := &ChatSessionService{
		kbDAO:            kbStore,
		modelProviderSvc: fakeChatModelProvider{driver: &fakeChatModelDriver{}},
		metadataSvc:      fakeChatMetadataService{},
		retrievalSvc:     retrieval,
	}

	dialog := &entity.Chat{
		ID:           "dialog-1",
		TenantID:     "tenant-1",
		PromptConfig: entity.JSONMap{},
		KBIDs:        entity.JSONSlice{"kb-1", "kb-2"},
		TopN:         3,
		TopK:         32,
	}
	messages := []map[string]interface{}{
		{"role": "user", "content": "question"},
	}
	reference := []interface{}{map[string]interface{}{"chunks": []interface{}{}, "doc_aggs": []interface{}{}}}

	_, _, _, err := svc.messagesWithRetrievedKnowledge(context.Background(), "user-1", dialog, messages, reference)
	rejected := err != nil && strings.Contains(err.Error(), "same embedding model")
	if !rejected {
		t.Fatalf("expected mixed embedding model error, got %v", err)
	}
	if retrieval.req != nil {
		t.Fatal("retrieval should not run when knowledge bases use different embedding models")
	}
}
// TestValidateKnowledgebaseEmbeddingModelsComparesResolvedNames checks that
// validation compares the embedding model names produced by the resolver, not
// the raw EmbdID strings: two knowledge bases sharing an EmbdID but resolving
// to different models must be rejected.
func TestValidateKnowledgebaseEmbeddingModelsComparesResolvedNames(t *testing.T) {
	firstTenantEmbdID := int64(1)
	secondTenantEmbdID := int64(2)
	kbs := []*entity.Knowledgebase{
		{ID: "kb-1", TenantID: "tenant-1", Name: "Manual", EmbdID: "same-legacy-name", TenantEmbdID: &firstTenantEmbdID},
		{ID: "kb-2", TenantID: "tenant-1", Name: "FAQ", EmbdID: "same-legacy-name", TenantEmbdID: &secondTenantEmbdID},
	}

	// Resolve by tenant embedding ID so the two KBs map to different models.
	resolver := func(tenantID string, kb *entity.Knowledgebase) (string, error) {
		if kb.TenantEmbdID == nil || *kb.TenantEmbdID != firstTenantEmbdID {
			return "embed-b@factory", nil
		}
		return "embed-a@factory", nil
	}

	_, _, err := validateKnowledgebaseEmbeddingModels(kbs, "tenant-1", resolver)
	if err != nil && strings.Contains(err.Error(), "same embedding model") {
		return
	}
	t.Fatalf("expected resolved mixed embedding model error, got %v", err)
}
// TestMessagesWithRetrievedKnowledgePreservesMultimodalContent checks that for
// a user message made of content blocks, the retrieved knowledge is injected
// as a leading text block, the original blocks (including the image) are
// kept, and the retrieval question is taken from the message's text block.
func TestMessagesWithRetrievedKnowledgePreservesMultimodalContent(t *testing.T) {
	retrieval := &fakeChatRetrievalService{
		result: &nlp.RetrievalResult{
			Chunks: []map[string]interface{}{
				{"content_with_weight": "Knowledge for an image question."},
			},
		},
	}
	svc := &ChatSessionService{
		kbDAO: fakeChatKBStore{kbs: []*entity.Knowledgebase{
			{ID: "kb-1", TenantID: "tenant-1", Name: "Manual", EmbdID: "embed@factory"},
		}},
		modelProviderSvc: fakeChatModelProvider{driver: &fakeChatModelDriver{}},
		metadataSvc:      fakeChatMetadataService{},
		retrievalSvc:     retrieval,
	}
	imageBlock := map[string]interface{}{"type": "image_url", "image_url": map[string]interface{}{"url": "https://example.com/cat.png"}}
	messages := []map[string]interface{}{
		{"role": "user", "content": []interface{}{
			map[string]interface{}{"type": "text", "text": "What is in this image?"},
			imageBlock,
		}},
	}
	got, _, emptyResponse, err := svc.messagesWithRetrievedKnowledge(context.Background(), "user-1", &entity.Chat{
		ID:           "dialog-1",
		TenantID:     "tenant-1",
		PromptConfig: entity.JSONMap{},
		KBIDs:        entity.JSONSlice{"kb-1"},
		TopN:         3,
		TopK:         32,
	}, messages, []interface{}{map[string]interface{}{"chunks": []interface{}{}, "doc_aggs": []interface{}{}}})
	if err != nil {
		t.Fatalf("messagesWithRetrievedKnowledge returned error: %v", err)
	}
	if emptyResponse != nil {
		t.Fatalf("expected no empty response, got %q", *emptyResponse)
	}
	// Guard the index below so an empty result fails the test with a clear
	// message instead of an index-out-of-range panic.
	if len(got) == 0 {
		t.Fatal("expected at least one returned message")
	}
	content, ok := got[0]["content"].([]interface{})
	if !ok {
		t.Fatalf("expected multimodal content to stay as blocks, got %T", got[0]["content"])
	}
	if len(content) != 3 {
		t.Fatalf("expected injected text plus original blocks, got %#v", content)
	}
	injected, ok := content[0].(map[string]interface{})
	// Use the comma-ok form so a missing or non-string "text" value fails the
	// assertion below instead of panicking; indexing a nil map is safe here.
	injectedText, textIsString := injected["text"].(string)
	if !ok || injected["type"] != "text" || !textIsString || !strings.Contains(injectedText, "Knowledge for an image question.") {
		t.Fatalf("expected injected knowledge text block, got %#v", content[0])
	}
	preservedImage, ok := content[2].(map[string]interface{})
	if !ok || preservedImage["type"] != "image_url" {
		t.Fatalf("expected original image block to be preserved, got %#v", content[2])
	}
	if retrieval.req == nil || retrieval.req.Question != "What is in this image?" {
		t.Fatalf("expected retrieval question from text block, got %#v", retrieval.req)
	}
}
// TestMessagesWithRetrievedKnowledgePassesMessageDocIDs checks that doc_ids
// attached to the user message are forwarded to retrieval, in order and with
// duplicates removed.
func TestMessagesWithRetrievedKnowledgePassesMessageDocIDs(t *testing.T) {
	retrieval := &fakeChatRetrievalService{result: &nlp.RetrievalResult{}}
	svc := &ChatSessionService{
		kbDAO: fakeChatKBStore{kbs: []*entity.Knowledgebase{
			{ID: "kb-1", TenantID: "tenant-1", Name: "Manual", EmbdID: "embed@factory"},
		}},
		modelProviderSvc: fakeChatModelProvider{driver: &fakeChatModelDriver{}},
		metadataSvc:      fakeChatMetadataService{},
		retrievalSvc:     retrieval,
	}
	_, _, _, err := svc.messagesWithRetrievedKnowledge(context.Background(), "user-1", &entity.Chat{
		ID:           "dialog-1",
		TenantID:     "tenant-1",
		PromptConfig: entity.JSONMap{},
		KBIDs:        entity.JSONSlice{"kb-1"},
		TopN:         3,
		TopK:         32,
	}, []map[string]interface{}{
		// "doc-1" appears twice on purpose to exercise de-duplication.
		{"role": "user", "content": "question", "doc_ids": []interface{}{"doc-1", "doc-2", "doc-1"}},
	}, []interface{}{map[string]interface{}{"chunks": []interface{}{}, "doc_aggs": []interface{}{}}})
	if err != nil {
		t.Fatalf("messagesWithRetrievedKnowledge returned error: %v", err)
	}
	// Without this guard a skipped retrieval call would crash the test with a
	// nil-pointer dereference instead of reporting a failure.
	if retrieval.req == nil {
		t.Fatal("expected retrieval to be called")
	}
	if len(retrieval.req.DocIDs) != 2 || retrieval.req.DocIDs[0] != "doc-1" || retrieval.req.DocIDs[1] != "doc-2" {
		t.Fatalf("expected scoped doc ids, got %#v", retrieval.req.DocIDs)
	}
}
// failingChatRetrievalService is a retrieval stub that fails every call with a
// preconfigured error.
type failingChatRetrievalService struct {
	// err is returned by every Retrieval call.
	err error
}

// Retrieval ignores its arguments and returns a nil result together with the
// configured error.
func (r *failingChatRetrievalService) Retrieval(ctx context.Context, req *nlp.RetrievalRequest) (*nlp.RetrievalResult, error) {
	return nil, r.err
}

View File

@@ -83,6 +83,9 @@ func compareValues(val1, val2, op string) bool {
// ManualValueResolver is a callback used to transform manual filter values.
// It receives a map and returns the map to use in its place.
// NOTE(review): the expected keys of the map are not visible here — confirm
// against the callers.
type ManualValueResolver func(map[string]interface{}) map[string]interface{}
// NoMatchDocIDSentinel is a placeholder document ID ("-999") returned as the
// only entry of a doc-ID list when filters match nothing, so that retrieval
// scoped to that list is intended to return no documents.
const NoMatchDocIDSentinel = "-999"
// metaFilterTemplateCache caches the template content.
// NOTE(review): presumably the meta-filter prompt template, filled in
// elsewhere in this file — confirm where it is populated.
var metaFilterTemplateCache string
@@ -264,7 +267,8 @@ func ApplyMetaFilter(metaData common.MetaData, filters []MetaFilterCondition, lo
// - "==" = "=" "!=" = "≠"
// - ">=" = "≥" "<=" = "≤"
// - "is" = "=" "not is" = "≠"
// (see common.metadata_utils.operatorMapping for the full list)
// (see common.metadata_utils.operatorMapping for the full list)
//
// Value conversion:
// - "in" / "not in": comma-separated string → []interface{} (as expected by common.MetaFilter)
// - all other operators: passed through as-is (string)
@@ -280,11 +284,13 @@ func convertToMetaCondition(f MetaFilterCondition) common.MetaCondition {
parts := strings.Split(strVal, ",")
arr := make([]interface{}, 0, len(parts))
for _, p := range parts {
if trimmed := strings.TrimSpace(p); trimmed != "" { arr = append(arr, trimmed) }
}
mc.Value = arr
}
return mc
if trimmed := strings.TrimSpace(p); trimmed != "" {
arr = append(arr, trimmed)
}
}
mc.Value = arr
}
return mc
}
// applySingleCondition applies a single filter condition and returns matching doc IDs
@@ -587,9 +593,6 @@ func ApplyMetaDataFilter(
return baseDocIDs, false
}
docIDs := make([]string, len(baseDocIDs))
copy(docIDs, baseDocIDs)
method, _ := metaDataFilter["method"].(string)
// Helper to run metadata filter with push-down fallback
@@ -632,13 +635,14 @@ func ApplyMetaDataFilter(
filters, err := GenMetaFilter(ctx, chatModel, metaData, question, nil)
if err != nil {
common.Warn("Failed to generate meta filter", zap.Error(err))
return docIDs, false
return baseDocIDs, false
}
filteredIDs := runMetadataFilter(filters.Conditions, filters.Logic)
docIDs = append(docIDs, filteredIDs...)
docIDs := constrainDocIDs(baseDocIDs, filteredIDs)
if len(docIDs) == 0 {
return nil, true // Return nil to indicate auto filter returned empty
}
return docIDs, false
case "semi_auto":
selectedKeys := []string{}
@@ -673,13 +677,14 @@ func ApplyMetaDataFilter(
filters, err := GenMetaFilter(ctx, chatModel, filteredMeta, question, constraints)
if err != nil {
common.Warn("Failed to generate meta filter", zap.Error(err))
return docIDs, false
return baseDocIDs, false
}
filteredIDs := runMetadataFilter(filters.Conditions, filters.Logic)
docIDs = append(docIDs, filteredIDs...)
docIDs := constrainDocIDs(baseDocIDs, filteredIDs)
if len(docIDs) == 0 {
return nil, true
}
return docIDs, false
}
}
@@ -689,6 +694,9 @@ func ApplyMetaDataFilter(
if logicVal, ok := metaDataFilter["logic"].(string); ok {
logic = logicVal
}
if len(manualFilters) == 0 {
return baseDocIDs, false
}
// Apply manual_value_resolver callback if provided
if len(manualValueResolver) > 0 && manualValueResolver[0] != nil {
@@ -720,13 +728,42 @@ func ApplyMetaDataFilter(
}
filteredIDs := runMetadataFilter(conditions, logic)
docIDs = append(docIDs, filteredIDs...)
docIDs := constrainDocIDs(baseDocIDs, filteredIDs)
if len(manualFilters) > 0 && len(docIDs) == 0 {
return []string{"-999"}, false
return []string{NoMatchDocIDSentinel}, false
}
return docIDs, false
}
return docIDs, false
return baseDocIDs, false
}
// constrainDocIDs restricts baseDocIDs to the IDs that also occur in
// filteredDocIDs.
//
// filteredDocIDs is first passed through common.Deduplicate. When baseDocIDs
// is empty there is nothing to constrain, so that de-duplicated list is
// returned as-is. When the de-duplicated list is empty, the result is an
// empty, non-nil slice. Otherwise the result holds every ID of baseDocIDs that
// is also in filteredDocIDs, in baseDocIDs order, each at most once.
func constrainDocIDs(baseDocIDs, filteredDocIDs []string) []string {
	filteredDocIDs = common.Deduplicate(filteredDocIDs)
	switch {
	case len(baseDocIDs) == 0:
		return filteredDocIDs
	case len(filteredDocIDs) == 0:
		return []string{}
	}

	// remaining holds the allowed IDs that have not been emitted yet.
	remaining := make(map[string]struct{}, len(filteredDocIDs))
	for _, id := range filteredDocIDs {
		remaining[id] = struct{}{}
	}

	kept := make([]string, 0, min(len(baseDocIDs), len(remaining)))
	for _, id := range baseDocIDs {
		if _, ok := remaining[id]; !ok {
			// Either not allowed by the filter or already emitted.
			continue
		}
		// Consume the ID so a repeat in baseDocIDs is skipped.
		delete(remaining, id)
		kept = append(kept, id)
	}
	return kept
}
// repairJSON attempts to fix common JSON formatting issues in LLM output.

View File

@@ -455,7 +455,7 @@ func TestApplySingleConditionRelationalFallsBackToString(t *testing.T) {
val string
want []string
}{
{op: ">", val: "banana", want: []string{"d-c"}}, // "cherry" > "banana"
{op: ">", val: "banana", want: []string{"d-c"}}, // "cherry" > "banana"
{op: "<", val: "cherry", want: []string{"d-a", "d-b"}}, // apple, banana
{op: ">=", val: "banana", want: []string{"d-b", "d-c"}},
{op: "<=", val: "banana", want: []string{"d-a", "d-b"}},

View File

@@ -1972,6 +1972,15 @@ func (m *ModelProviderService) GetChatModel(tenantID, compositeModelName string)
return modelModule.NewChatModel(driver, &modelName, apiConfig), nil
}
// GetRerankModel builds a RerankModel wrapper for the given tenant and
// composite model name.
//
// The driver, model name, and API config are taken from getModelConfig; its
// fourth result is not needed here and is discarded. An error from
// getModelConfig is returned unchanged together with a nil model.
func (m *ModelProviderService) GetRerankModel(tenantID, compositeModelName string) (*modelModule.RerankModel, error) {
	drv, name, cfg, _, err := m.getModelConfig(tenantID, compositeModelName)
	if err != nil {
		return nil, err
	}
	rerankModel := modelModule.NewRerankModel(drv, &name, cfg)
	return rerankModel, nil
}
type AddModelRequest struct {
ProviderName string `json:"provider_name"`
InstanceName string `json:"instance_name"`