mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 15:31:05 +08:00
feat(go-api): add RAG retrieval to chat completions (#15739)
## Summary

- Add knowledge-base retrieval support to Go chat completions.

## What changed

- Routes KB-backed chat sessions through the Go retrieval service instead of falling back to solo chat.
- Resolves embedding and rerank models, validates accessible knowledge bases, and preserves tenant-aware retrieval.
- Rejects mixed embedding models across selected knowledge bases before retrieval to avoid incompatible vector dimensions.
- Threads the HTTP request context into streaming retrieval so cancelled requests can stop downstream retrieval work.
- Applies metadata filters and message-level `doc_ids` before retrieval.
- Expands parent/child chunks before building references and prompt context.
- Injects retrieved knowledge through a copied dialog prompt config so the caller's original dialog is not mutated.
- Honors configured empty responses when no chunks are found.
- Names the metadata no-match sentinel and reuses it across retrieval/handler paths.
- Adds a defensive content cast while appending streamed answers.
- Adds focused unit coverage for retrieval, metadata filtering, authorization, multimodal messages, references, empty-response behavior, prompt immutability, and mixed embedding models.

---------

Co-authored-by: Yingfeng <yingfeng.zhang@gmail.com>
Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -292,9 +292,10 @@ func (h *ChatSessionHandler) Completion(c *gin.Context) {
|
||||
|
||||
// Create a channel for streaming data
|
||||
streamChan := make(chan string)
|
||||
reqCtx := c.Request.Context()
|
||||
go func() {
|
||||
defer close(streamChan)
|
||||
err := h.chatSessionService.CompletionStream(userID, req.ConversationID, processedMessages, req.LLMID, chatModelConfig, messageID, streamChan)
|
||||
err := h.chatSessionService.CompletionStream(reqCtx, userID, req.ConversationID, processedMessages, req.LLMID, chatModelConfig, messageID, streamChan)
|
||||
if err != nil {
|
||||
streamChan <- fmt.Sprintf("data: %s\n\n", err.Error())
|
||||
}
|
||||
|
||||
@@ -24,9 +24,9 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"ragflow/internal/common"
|
||||
"gorm.io/gorm"
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/engine"
|
||||
"ragflow/internal/entity"
|
||||
modelModule "ragflow/internal/entity/models"
|
||||
@@ -71,16 +71,16 @@ type DocumentDAOIface interface {
|
||||
|
||||
// difyRetrievalRequest is the JSON body / query params for the Dify retrieval endpoint.
|
||||
type difyRetrievalRequest struct {
|
||||
KnowledgeID string `json:"knowledge_id" form:"knowledge_id"`
|
||||
Query string `json:"query" form:"query"`
|
||||
UseKG bool `json:"use_kg" form:"use_kg"`
|
||||
RetrievalSetting *difyRetrievalSetting `json:"retrieval_setting"`
|
||||
MetadataCondition *difyMetadataCondition `json:"metadata_condition"`
|
||||
KnowledgeID string `json:"knowledge_id" form:"knowledge_id"`
|
||||
Query string `json:"query" form:"query"`
|
||||
UseKG bool `json:"use_kg" form:"use_kg"`
|
||||
RetrievalSetting *difyRetrievalSetting `json:"retrieval_setting"`
|
||||
MetadataCondition *difyMetadataCondition `json:"metadata_condition"`
|
||||
}
|
||||
|
||||
type difyRetrievalSetting struct {
|
||||
TopK *int `json:"top_k" form:"top_k"`
|
||||
ScoreThreshold *float64 `json:"score_threshold" form:"score_threshold"`
|
||||
TopK *int `json:"top_k" form:"top_k"`
|
||||
ScoreThreshold *float64 `json:"score_threshold" form:"score_threshold"`
|
||||
}
|
||||
|
||||
// difyCondition is a Dify-format metadata filter condition.
|
||||
@@ -108,8 +108,8 @@ func (c difyMetadataCondition) toMetaFilterConditions() []service.MetaFilterCond
|
||||
v = fmt.Sprint(dc.Value)
|
||||
}
|
||||
result[i] = service.MetaFilterCondition{
|
||||
Key: dc.Name,
|
||||
Op: dc.ComparisonOperator,
|
||||
Key: dc.Name,
|
||||
Op: dc.ComparisonOperator,
|
||||
Value: v,
|
||||
}
|
||||
}
|
||||
@@ -241,15 +241,15 @@ func (h *DifyRetrievalHandler) Retrieval(c *gin.Context) {
|
||||
metas, metaErr := h.metadataSvc.GetFlattedMetaByKBs([]string{req.KnowledgeID})
|
||||
docIDs := make([]string, 0)
|
||||
if metaErr == nil && req.MetadataCondition != nil {
|
||||
logic := req.MetadataCondition.Logic
|
||||
if logic == "" {
|
||||
logic = "and"
|
||||
}
|
||||
filteredIDs := service.ApplyMetaFilter(metas, req.MetadataCondition.toMetaFilterConditions(), logic)
|
||||
logic := req.MetadataCondition.Logic
|
||||
if logic == "" {
|
||||
logic = "and"
|
||||
}
|
||||
filteredIDs := service.ApplyMetaFilter(metas, req.MetadataCondition.toMetaFilterConditions(), logic)
|
||||
docIDs = append(docIDs, filteredIDs...)
|
||||
}
|
||||
if len(docIDs) == 0 && req.MetadataCondition != nil {
|
||||
docIDs = []string{"-999"}
|
||||
docIDs = []string{service.NoMatchDocIDSentinel}
|
||||
}
|
||||
|
||||
// Label question for rank features
|
||||
@@ -258,15 +258,15 @@ func (h *DifyRetrievalHandler) Retrieval(c *gin.Context) {
|
||||
|
||||
// Chunk retrieval
|
||||
sr := &nlp.RetrievalRequest{
|
||||
Question: req.Query,
|
||||
TenantIDs: []string{kb.TenantID},
|
||||
KbIDs: []string{req.KnowledgeID},
|
||||
DocIDs: docIDs,
|
||||
Page: 1,
|
||||
PageSize: pageSize,
|
||||
Top: topK,
|
||||
SimilarityThreshold: scoreThreshold,
|
||||
EmbeddingModel: embModel,
|
||||
Question: req.Query,
|
||||
TenantIDs: []string{kb.TenantID},
|
||||
KbIDs: []string{req.KnowledgeID},
|
||||
DocIDs: docIDs,
|
||||
Page: 1,
|
||||
PageSize: pageSize,
|
||||
Top: topK,
|
||||
SimilarityThreshold: scoreThreshold,
|
||||
EmbeddingModel: embModel,
|
||||
}
|
||||
if rankFeature != nil {
|
||||
sr.RankFeature = &rankFeature
|
||||
|
||||
@@ -17,10 +17,13 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/engine"
|
||||
"ragflow/internal/service/nlp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -31,21 +34,61 @@ import (
|
||||
modelModule "ragflow/internal/entity/models"
|
||||
)
|
||||
|
||||
type chatKnowledgebaseStore interface {
|
||||
Accessible(kbID, userID string) bool
|
||||
GetByIDs(ids []string) ([]*entity.Knowledgebase, error)
|
||||
}
|
||||
|
||||
type chatModelProvider interface {
|
||||
GetChatModel(tenantID, compositeModelName string) (*modelModule.ChatModel, error)
|
||||
GetEmbeddingModel(tenantID, compositeModelName string) (*modelModule.EmbeddingModel, error)
|
||||
GetRerankModel(tenantID, compositeModelName string) (*modelModule.RerankModel, error)
|
||||
GetModelConfigFromProviderInstance(tenantID string, modelType entity.ModelType, modelName string) (modelModule.ModelDriver, string, *modelModule.APIConfig, int, error)
|
||||
GetTenantDefaultModelByType(tenantID string, modelType entity.ModelType) (modelModule.ModelDriver, string, *modelModule.APIConfig, int, error)
|
||||
}
|
||||
|
||||
type chatMetadataService interface {
|
||||
LabelQuestion(question string, kbs []*entity.Knowledgebase) map[string]float64
|
||||
GetFlattedMetaByKBs(kbIDs []string) (common.MetaData, error)
|
||||
}
|
||||
|
||||
type chatRetrievalService interface {
|
||||
Retrieval(ctx context.Context, req *nlp.RetrievalRequest) (*nlp.RetrievalResult, error)
|
||||
}
|
||||
|
||||
// ChatSessionService chat session (conversation) service
|
||||
type ChatSessionService struct {
|
||||
chatSessionDAO *dao.ChatSessionDAO
|
||||
chatDAO *dao.ChatDAO
|
||||
userTenantDAO *dao.UserTenantDAO
|
||||
modelProviderSvc *ModelProviderService
|
||||
kbDAO chatKnowledgebaseStore
|
||||
docEngine engine.DocEngine
|
||||
modelProviderSvc chatModelProvider
|
||||
metadataSvc chatMetadataService
|
||||
retrievalSvc chatRetrievalService
|
||||
}
|
||||
|
||||
// NewChatSessionService create chat session service
|
||||
func NewChatSessionService() *ChatSessionService {
|
||||
docEngine := engine.Get()
|
||||
return newChatSessionServiceWithRetrieval(docEngine, nlp.NewRetrievalService(docEngine, dao.NewDocumentDAO()))
|
||||
}
|
||||
|
||||
// NewChatSessionServiceWithRetrieval creates a chat session service with a retrieval service.
|
||||
func NewChatSessionServiceWithRetrieval(retrievalSvc chatRetrievalService) *ChatSessionService {
|
||||
return newChatSessionServiceWithRetrieval(engine.Get(), retrievalSvc)
|
||||
}
|
||||
|
||||
func newChatSessionServiceWithRetrieval(docEngine engine.DocEngine, retrievalSvc chatRetrievalService) *ChatSessionService {
|
||||
return &ChatSessionService{
|
||||
chatSessionDAO: dao.NewChatSessionDAO(),
|
||||
chatDAO: dao.NewChatDAO(),
|
||||
userTenantDAO: dao.NewUserTenantDAO(),
|
||||
kbDAO: dao.NewKnowledgebaseDAO(),
|
||||
docEngine: docEngine,
|
||||
modelProviderSvc: NewModelProviderService(),
|
||||
metadataSvc: NewMetadataService(),
|
||||
retrievalSvc: retrievalSvc,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -294,7 +337,7 @@ func (s *ChatSessionService) Completion(userID string, conversationID string, me
|
||||
}
|
||||
|
||||
// Perform chat completion with RAG
|
||||
result, err := s.asyncChat(dialog, session, messages, chatModelConfig, messageID, reference, false)
|
||||
result, err := s.asyncChat(userID, dialog, session, messages, chatModelConfig, messageID, reference, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -308,7 +351,11 @@ func (s *ChatSessionService) Completion(userID string, conversationID string, me
|
||||
}
|
||||
|
||||
// CompletionStream performs streaming chat completion with full RAG support
|
||||
func (s *ChatSessionService) CompletionStream(userID string, conversationID string, messages []map[string]interface{}, llmID string, chatModelConfig map[string]interface{}, messageID string, streamChan chan<- string) error {
|
||||
func (s *ChatSessionService) CompletionStream(ctx context.Context, userID string, conversationID string, messages []map[string]interface{}, llmID string, chatModelConfig map[string]interface{}, messageID string, streamChan chan<- string) error {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
// Validate the last message is from user
|
||||
if len(messages) == 0 {
|
||||
streamChan <- fmt.Sprintf("data: %s\n\n", `{"code": 500, "message": "messages cannot be empty", "data": {"answer": "**ERROR**: messages cannot be empty", "reference": []}}`)
|
||||
@@ -356,7 +403,7 @@ func (s *ChatSessionService) CompletionStream(userID string, conversationID stri
|
||||
}
|
||||
|
||||
// Perform streaming chat completion with RAG
|
||||
resultChan, err := s.asyncChatStream(dialog, session, messages, chatModelConfig, messageID, reference)
|
||||
resultChan, err := s.asyncChatStream(ctx, userID, dialog, session, messages, chatModelConfig, messageID, reference)
|
||||
if err != nil {
|
||||
streamChan <- fmt.Sprintf("data: %s\n\n", fmt.Sprintf(`{"code": 500, "message": "%s", "data": {"answer": "**ERROR**: %s", "reference": []}}`, err.Error(), err.Error()))
|
||||
return err
|
||||
@@ -450,7 +497,7 @@ func (s *ChatSessionService) updateSessionMessages(session *entity.ChatSession,
|
||||
}
|
||||
|
||||
// asyncChat performs chat with RAG support (non-streaming)
|
||||
func (s *ChatSessionService) asyncChat(dialog *entity.Chat, session *entity.ChatSession, messages []map[string]interface{}, config map[string]interface{}, messageID string, reference []interface{}, stream bool) (map[string]interface{}, error) {
|
||||
func (s *ChatSessionService) asyncChat(userID string, dialog *entity.Chat, session *entity.ChatSession, messages []map[string]interface{}, config map[string]interface{}, messageID string, reference []interface{}, stream bool) (map[string]interface{}, error) {
|
||||
// Check if we need RAG (knowledge base or tavily)
|
||||
hasKB := len(dialog.KBIDs) > 0
|
||||
hasTavily := false
|
||||
@@ -465,21 +512,20 @@ func (s *ChatSessionService) asyncChat(dialog *entity.Chat, session *entity.Chat
|
||||
return s.asyncChatSolo(dialog, session, messages, config, messageID, reference, stream)
|
||||
}
|
||||
|
||||
// TODO: Full RAG implementation with knowledge base retrieval
|
||||
// This would include:
|
||||
// 1. Get embedding model and rerank model
|
||||
// 2. Extract questions from messages
|
||||
// 3. Retrieve chunks from knowledge bases
|
||||
// 4. Rerank chunks
|
||||
// 5. Build prompt with context
|
||||
// 6. Call LLM
|
||||
if hasKB {
|
||||
return s.asyncChatWithRetrieval(context.Background(), userID, dialog, session, messages, config, messageID, reference, stream)
|
||||
}
|
||||
|
||||
// For now, fall back to solo chat
|
||||
common.Warn("Tavily-backed chat retrieval is not implemented in Go; falling back to solo chat",
|
||||
zap.String("dialog_id", dialog.ID))
|
||||
return s.asyncChatSolo(dialog, session, messages, config, messageID, reference, stream)
|
||||
}
|
||||
|
||||
// asyncChatStream performs streaming chat with RAG support
|
||||
func (s *ChatSessionService) asyncChatStream(dialog *entity.Chat, session *entity.ChatSession, messages []map[string]interface{}, config map[string]interface{}, messageID string, reference []interface{}) (<-chan map[string]interface{}, error) {
|
||||
func (s *ChatSessionService) asyncChatStream(ctx context.Context, userID string, dialog *entity.Chat, session *entity.ChatSession, messages []map[string]interface{}, config map[string]interface{}, messageID string, reference []interface{}) (<-chan map[string]interface{}, error) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
resultChan := make(chan map[string]interface{})
|
||||
|
||||
go func() {
|
||||
@@ -500,14 +546,599 @@ func (s *ChatSessionService) asyncChatStream(dialog *entity.Chat, session *entit
|
||||
return
|
||||
}
|
||||
|
||||
// TODO: Full RAG streaming implementation
|
||||
// For now, fall back to solo chat
|
||||
if hasKB {
|
||||
ragMessages, ragDialog, emptyResponse, err := s.messagesWithRetrievedKnowledge(ctx, userID, dialog, messages, reference)
|
||||
if err != nil {
|
||||
resultChan <- s.structureAnswer(session, "**ERROR**: "+err.Error(), messageID, session.ID, reference)
|
||||
return
|
||||
}
|
||||
if emptyResponse != nil {
|
||||
resultChan <- s.structureAnswer(session, *emptyResponse, messageID, session.ID, reference)
|
||||
return
|
||||
}
|
||||
s.asyncChatSoloStream(ragDialog, session, ragMessages, config, messageID, reference, resultChan)
|
||||
return
|
||||
}
|
||||
|
||||
common.Warn("Tavily-backed streaming chat retrieval is not implemented in Go; falling back to solo chat",
|
||||
zap.String("dialog_id", dialog.ID))
|
||||
s.asyncChatSoloStream(dialog, session, messages, config, messageID, reference, resultChan)
|
||||
}()
|
||||
|
||||
return resultChan, nil
|
||||
}
|
||||
|
||||
func (s *ChatSessionService) asyncChatWithRetrieval(ctx context.Context, userID string, dialog *entity.Chat, session *entity.ChatSession, messages []map[string]interface{}, config map[string]interface{}, messageID string, reference []interface{}, stream bool) (map[string]interface{}, error) {
|
||||
ragMessages, ragDialog, emptyResponse, err := s.messagesWithRetrievedKnowledge(ctx, userID, dialog, messages, reference)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if emptyResponse != nil {
|
||||
var lastRef interface{}
|
||||
if len(reference) > 0 {
|
||||
lastRef = reference[len(reference)-1]
|
||||
}
|
||||
ans := map[string]interface{}{
|
||||
"answer": *emptyResponse,
|
||||
"reference": lastRef,
|
||||
"final": true,
|
||||
}
|
||||
return s.structureAnswerWithConv(session, ans, messageID, session.ID, reference), nil
|
||||
}
|
||||
return s.asyncChatSolo(ragDialog, session, ragMessages, config, messageID, reference, stream)
|
||||
}
|
||||
|
||||
func (s *ChatSessionService) messagesWithRetrievedKnowledge(ctx context.Context, userID string, dialog *entity.Chat, messages []map[string]interface{}, reference []interface{}) ([]map[string]interface{}, *entity.Chat, *string, error) {
|
||||
kbIDs := stringSliceFromJSON(dialog.KBIDs)
|
||||
if len(kbIDs) == 0 {
|
||||
return messages, dialog, nil, nil
|
||||
}
|
||||
if s.retrievalSvc == nil {
|
||||
return nil, nil, nil, errors.New("retrieval service is not configured")
|
||||
}
|
||||
|
||||
question := latestUserQuestion(messages)
|
||||
if question == "" {
|
||||
return messages, dialog, nil, nil
|
||||
}
|
||||
|
||||
kbs, err := s.kbDAO.GetByIDs(kbIDs)
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("failed to load knowledge bases: %w", err)
|
||||
}
|
||||
kbs, err = s.knowledgebasesForDialog(userID, dialog, kbIDs, kbs)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
embeddingTenantID, embeddingModelName, err := validateKnowledgebaseEmbeddingModels(kbs, dialog.TenantID, resolveEmbeddingModelName)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
embeddingModel, err := s.modelProviderSvc.GetEmbeddingModel(embeddingTenantID, embeddingModelName)
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("failed to get embedding model: %w", err)
|
||||
}
|
||||
rerankModel, err := s.rerankModelForDialog(dialog)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
top := int(dialog.TopK)
|
||||
pageSize := int(dialog.TopN)
|
||||
if pageSize <= 0 {
|
||||
pageSize = 6
|
||||
}
|
||||
similarityThreshold := dialog.SimilarityThreshold
|
||||
vectorSimilarityWeight := dialog.VectorSimilarityWeight
|
||||
var rankFeature map[string]float64
|
||||
if s.metadataSvc != nil {
|
||||
rankFeature = s.metadataSvc.LabelQuestion(question, kbs)
|
||||
}
|
||||
baseDocIDs := docIDsFromMessages(messages)
|
||||
docIDs, err := s.filteredDocIDsForDialog(ctx, dialog, kbIDs, question, baseDocIDs)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
tenantIDs := tenantIDsFromKnowledgebases(kbs, dialog.TenantID)
|
||||
|
||||
retrievalResult, err := s.retrievalSvc.Retrieval(ctx, &nlp.RetrievalRequest{
|
||||
Question: question,
|
||||
TenantIDs: tenantIDs,
|
||||
KbIDs: kbIDs,
|
||||
DocIDs: docIDs,
|
||||
Page: 1,
|
||||
PageSize: pageSize,
|
||||
Top: &top,
|
||||
SimilarityThreshold: &similarityThreshold,
|
||||
VectorSimilarityWeight: &vectorSimilarityWeight,
|
||||
RankFeature: &rankFeature,
|
||||
EmbeddingModel: embeddingModel,
|
||||
RerankModel: rerankModel,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, nil, fmt.Errorf("retrieval search failed: %w", err)
|
||||
}
|
||||
if retrievalResult == nil {
|
||||
retrievalResult = &nlp.RetrievalResult{}
|
||||
}
|
||||
|
||||
chunks := retrievalResult.Chunks
|
||||
if s.docEngine != nil {
|
||||
chunks = nlp.RetrievalByChildren(chunks, tenantIDs, s.docEngine, ctx)
|
||||
}
|
||||
setLatestReference(reference, chunks, retrievalResult.DocAggs)
|
||||
knowledge := buildKnowledgeBlock(chunks)
|
||||
if knowledge == "" {
|
||||
return messages, dialog, emptyResponseForDialog(dialog), nil
|
||||
}
|
||||
if ragDialog, ok := dialogWithInjectedKnowledgePrompt(dialog, knowledge); ok {
|
||||
return copyMessages(messages), ragDialog, nil, nil
|
||||
}
|
||||
|
||||
return injectKnowledge(messages, knowledge), dialog, nil, nil
|
||||
}
|
||||
|
||||
// embeddingModelNameResolver returns the embedding model name to use for kb
// when it is accessed on behalf of tenantID. It is a parameter of
// validateKnowledgebaseEmbeddingModels so the resolution strategy can be
// swapped.
type embeddingModelNameResolver func(tenantID string, kb *entity.Knowledgebase) (string, error)
|
||||
|
||||
func validateKnowledgebaseEmbeddingModels(kbs []*entity.Knowledgebase, fallbackTenantID string, resolve embeddingModelNameResolver) (string, string, error) {
|
||||
if len(kbs) == 0 {
|
||||
return fallbackTenantID, "", nil
|
||||
}
|
||||
|
||||
expected := ""
|
||||
expectedKBID := ""
|
||||
expectedTenantID := fallbackTenantID
|
||||
for _, kb := range kbs {
|
||||
if kb == nil {
|
||||
return "", "", errors.New("knowledge base is nil")
|
||||
}
|
||||
tenantID := kb.TenantID
|
||||
if tenantID == "" {
|
||||
tenantID = fallbackTenantID
|
||||
}
|
||||
modelName, err := resolve(tenantID, kb)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
modelName = strings.TrimSpace(modelName)
|
||||
if modelName == "" {
|
||||
return "", "", fmt.Errorf("knowledge base %s has no embedding model", kb.ID)
|
||||
}
|
||||
if expected == "" {
|
||||
expected = modelName
|
||||
expectedKBID = kb.ID
|
||||
expectedTenantID = tenantID
|
||||
continue
|
||||
}
|
||||
if modelName != expected {
|
||||
return "", "", fmt.Errorf("knowledge bases must use the same embedding model: %s resolves to %q, expected %q from %s", kb.ID, modelName, expected, expectedKBID)
|
||||
}
|
||||
}
|
||||
return expectedTenantID, expected, nil
|
||||
}
|
||||
|
||||
func (s *ChatSessionService) rerankModelForDialog(dialog *entity.Chat) (*modelModule.RerankModel, error) {
|
||||
compositeName, err := resolveRerankModelName(dialog)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if compositeName == "" {
|
||||
return nil, nil
|
||||
}
|
||||
rerankModel, err := s.modelProviderSvc.GetRerankModel(dialog.TenantID, compositeName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get rerank model: %w", err)
|
||||
}
|
||||
return rerankModel, nil
|
||||
}
|
||||
|
||||
func (s *ChatSessionService) filteredDocIDsForDialog(ctx context.Context, dialog *entity.Chat, kbIDs []string, question string, baseDocIDs []string) ([]string, error) {
|
||||
if dialog.MetaDataFilter == nil || len(*dialog.MetaDataFilter) == 0 {
|
||||
return baseDocIDs, nil
|
||||
}
|
||||
if s.metadataSvc == nil {
|
||||
return nil, errors.New("metadata service is not configured")
|
||||
}
|
||||
|
||||
filter := make(map[string]interface{}, len(*dialog.MetaDataFilter))
|
||||
for key, value := range *dialog.MetaDataFilter {
|
||||
filter[key] = value
|
||||
}
|
||||
|
||||
metaData, err := s.metadataSvc.GetFlattedMetaByKBs(kbIDs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get flattened metadata for chat retrieval: %w", err)
|
||||
}
|
||||
|
||||
var filterChatModel *modelModule.ChatModel
|
||||
method, _ := filter["method"].(string)
|
||||
if method == "auto" || method == "semi_auto" {
|
||||
filterChatModel, err = s.modelProviderSvc.GetChatModel(dialog.TenantID, dialog.LLMID)
|
||||
if err != nil {
|
||||
common.Warn("Failed to get chat model for chat metadata filter", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
docIDs, empty := ApplyMetaDataFilter(ctx, filter, metaData, question, filterChatModel, baseDocIDs, kbIDs)
|
||||
if empty {
|
||||
return []string{NoMatchDocIDSentinel}, nil
|
||||
}
|
||||
return docIDs, nil
|
||||
}
|
||||
|
||||
func resolveEmbeddingModelName(tenantID string, kb *entity.Knowledgebase) (string, error) {
|
||||
if kb.TenantEmbdID != nil && *kb.TenantEmbdID > 0 {
|
||||
_, compositeName, err := dao.LookupTenantLLMByID(dao.NewTenantLLMDAO(), *kb.TenantEmbdID)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get embedding model by tenant_embd_id: %w", err)
|
||||
}
|
||||
return compositeName, nil
|
||||
}
|
||||
if kb.EmbdID != "" {
|
||||
if strings.Contains(kb.EmbdID, "@") {
|
||||
return kb.EmbdID, nil
|
||||
}
|
||||
_, compositeName, err := dao.LookupTenantLLMByName(dao.NewTenantLLMDAO(), tenantID, kb.EmbdID, entity.ModelTypeEmbedding)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get embedding model by embd_id: %w", err)
|
||||
}
|
||||
return compositeName, nil
|
||||
}
|
||||
|
||||
tenantLLM, err := dao.NewTenantLLMDAO().GetByTenantAndType(tenantID, entity.ModelTypeEmbedding)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get tenant default embedding model: %w", err)
|
||||
}
|
||||
if tenantLLM == nil || tenantLLM.LLMName == nil || *tenantLLM.LLMName == "" {
|
||||
return "", fmt.Errorf("no default embedding model found for tenant %s", tenantID)
|
||||
}
|
||||
return fmt.Sprintf("%s@%s", *tenantLLM.LLMName, tenantLLM.LLMFactory), nil
|
||||
}
|
||||
|
||||
func resolveRerankModelName(dialog *entity.Chat) (string, error) {
|
||||
if dialog.TenantRerankID != nil && *dialog.TenantRerankID > 0 {
|
||||
_, compositeName, err := dao.LookupTenantLLMByID(dao.NewTenantLLMDAO(), *dialog.TenantRerankID)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get rerank model by tenant_rerank_id: %w", err)
|
||||
}
|
||||
return compositeName, nil
|
||||
}
|
||||
if dialog.RerankID == "" {
|
||||
return "", nil
|
||||
}
|
||||
if strings.Contains(dialog.RerankID, "@") {
|
||||
return dialog.RerankID, nil
|
||||
}
|
||||
_, compositeName, err := dao.LookupTenantLLMByName(dao.NewTenantLLMDAO(), dialog.TenantID, dialog.RerankID, entity.ModelTypeRerank)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to get rerank model by rerank_id: %w", err)
|
||||
}
|
||||
return compositeName, nil
|
||||
}
|
||||
|
||||
func stringSliceFromJSON(values entity.JSONSlice) []string {
|
||||
result := make([]string, 0, len(values))
|
||||
seen := make(map[string]struct{}, len(values))
|
||||
for _, value := range values {
|
||||
str, ok := value.(string)
|
||||
if !ok || str == "" {
|
||||
continue
|
||||
}
|
||||
if _, exists := seen[str]; exists {
|
||||
continue
|
||||
}
|
||||
seen[str] = struct{}{}
|
||||
result = append(result, str)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func tenantIDsFromKnowledgebases(kbs []*entity.Knowledgebase, fallback string) []string {
|
||||
seen := make(map[string]struct{}, len(kbs)+1)
|
||||
var tenantIDs []string
|
||||
for _, kb := range kbs {
|
||||
if kb == nil || kb.TenantID == "" {
|
||||
continue
|
||||
}
|
||||
if _, exists := seen[kb.TenantID]; exists {
|
||||
continue
|
||||
}
|
||||
seen[kb.TenantID] = struct{}{}
|
||||
tenantIDs = append(tenantIDs, kb.TenantID)
|
||||
}
|
||||
if len(tenantIDs) == 0 && fallback != "" {
|
||||
tenantIDs = append(tenantIDs, fallback)
|
||||
}
|
||||
return tenantIDs
|
||||
}
|
||||
|
||||
func (s *ChatSessionService) knowledgebasesForDialog(userID string, dialog *entity.Chat, kbIDs []string, loaded []*entity.Knowledgebase) ([]*entity.Knowledgebase, error) {
|
||||
byID := make(map[string]*entity.Knowledgebase, len(loaded))
|
||||
for _, kb := range loaded {
|
||||
if kb != nil {
|
||||
byID[kb.ID] = kb
|
||||
}
|
||||
}
|
||||
|
||||
kbs := make([]*entity.Knowledgebase, 0, len(kbIDs))
|
||||
for _, kbID := range kbIDs {
|
||||
kb := byID[kbID]
|
||||
if kb == nil {
|
||||
return nil, fmt.Errorf("knowledge base %s not found", kbID)
|
||||
}
|
||||
if userID != "" && !s.kbDAO.Accessible(kbID, userID) {
|
||||
return nil, fmt.Errorf("knowledge base %s is not authorized for user", kbID)
|
||||
}
|
||||
if userID == "" && kb.TenantID != dialog.TenantID {
|
||||
return nil, fmt.Errorf("knowledge base %s is not authorized for dialog tenant", kbID)
|
||||
}
|
||||
kbs = append(kbs, kb)
|
||||
}
|
||||
if len(kbs) == 0 {
|
||||
return nil, errors.New("no valid knowledge bases found")
|
||||
}
|
||||
return kbs, nil
|
||||
}
|
||||
|
||||
func docIDsFromMessages(messages []map[string]interface{}) []string {
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
if role, _ := messages[i]["role"].(string); role != "user" {
|
||||
continue
|
||||
}
|
||||
return stringSliceFromValue(messages[i]["doc_ids"])
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func latestUserQuestion(messages []map[string]interface{}) string {
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
if role, _ := messages[i]["role"].(string); role != "user" {
|
||||
continue
|
||||
}
|
||||
return textFromMessageContent(messages[i]["content"])
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func stringSliceFromValue(value interface{}) []string {
|
||||
switch typed := value.(type) {
|
||||
case nil:
|
||||
return nil
|
||||
case []string:
|
||||
return uniqueNonEmptyStrings(typed)
|
||||
case []interface{}:
|
||||
values := make([]string, 0, len(typed))
|
||||
for _, item := range typed {
|
||||
if str, ok := item.(string); ok {
|
||||
values = append(values, str)
|
||||
}
|
||||
}
|
||||
return uniqueNonEmptyStrings(values)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// uniqueNonEmptyStrings trims every value, drops blanks and duplicates, and
// keeps first-seen order. It returns nil when nothing remains.
func uniqueNonEmptyStrings(values []string) []string {
	visited := make(map[string]struct{}, len(values))
	out := make([]string, 0, len(values))
	for idx := range values {
		trimmed := strings.TrimSpace(values[idx])
		if trimmed == "" {
			continue
		}
		if _, dup := visited[trimmed]; dup {
			continue
		}
		visited[trimmed] = struct{}{}
		out = append(out, trimmed)
	}
	if len(out) > 0 {
		return out
	}
	return nil
}
|
||||
|
||||
func emptyResponseForDialog(dialog *entity.Chat) *string {
|
||||
if dialog.PromptConfig == nil {
|
||||
return nil
|
||||
}
|
||||
emptyResponse, ok := dialog.PromptConfig["empty_response"].(string)
|
||||
if !ok || emptyResponse == "" {
|
||||
return nil
|
||||
}
|
||||
return &emptyResponse
|
||||
}
|
||||
|
||||
func buildKnowledgeBlock(chunks []map[string]interface{}) string {
|
||||
var builder strings.Builder
|
||||
for i, chunk := range chunks {
|
||||
content := chunkText(chunk)
|
||||
if content == "" {
|
||||
continue
|
||||
}
|
||||
if builder.Len() > 0 {
|
||||
builder.WriteString("\n\n")
|
||||
}
|
||||
builder.WriteString(fmt.Sprintf("[%d]", i+1))
|
||||
if docName, ok := chunk["docnm_kwd"].(string); ok && docName != "" {
|
||||
builder.WriteString(" ")
|
||||
builder.WriteString(docName)
|
||||
}
|
||||
builder.WriteString("\n")
|
||||
builder.WriteString(content)
|
||||
}
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
// chunkText returns the trimmed text of a chunk, taken from the first of
// "content_with_weight", "content_ltks" and "content" that holds a non-blank
// string. It returns "" when none does.
func chunkText(chunk map[string]interface{}) string {
	candidates := [...]string{"content_with_weight", "content_ltks", "content"}
	for idx := 0; idx < len(candidates); idx++ {
		raw, isString := chunk[candidates[idx]].(string)
		if !isString {
			continue
		}
		if trimmed := strings.TrimSpace(raw); trimmed != "" {
			return trimmed
		}
	}
	return ""
}
|
||||
|
||||
func injectKnowledge(messages []map[string]interface{}, knowledge string) []map[string]interface{} {
|
||||
copied := copyMessages(messages)
|
||||
if len(copied) == 0 {
|
||||
return copied
|
||||
}
|
||||
|
||||
knowledgePrompt := fmt.Sprintf("Use the following knowledge snippets to answer the user's question. If the snippets do not contain the answer, say that the knowledge base does not provide enough information.\n\n%s", knowledge)
|
||||
for i := len(copied) - 1; i >= 0; i-- {
|
||||
if role, _ := copied[i]["role"].(string); role != "user" {
|
||||
continue
|
||||
}
|
||||
copied[i]["content"] = injectKnowledgeIntoContent(copied[i]["content"], knowledgePrompt)
|
||||
return copied
|
||||
}
|
||||
|
||||
copied = append(copied, map[string]interface{}{
|
||||
"role": "system",
|
||||
"content": knowledgePrompt,
|
||||
})
|
||||
return copied
|
||||
}
|
||||
|
||||
func injectKnowledgeIntoContent(content interface{}, knowledgePrompt string) interface{} {
|
||||
switch typed := content.(type) {
|
||||
case []interface{}:
|
||||
injected := make([]interface{}, 0, len(typed)+1)
|
||||
injected = append(injected, knowledgeTextBlock(knowledgePrompt))
|
||||
injected = append(injected, typed...)
|
||||
return injected
|
||||
case []map[string]interface{}:
|
||||
injected := make([]interface{}, 0, len(typed)+1)
|
||||
injected = append(injected, knowledgeTextBlock(knowledgePrompt))
|
||||
for _, block := range typed {
|
||||
injected = append(injected, block)
|
||||
}
|
||||
return injected
|
||||
default:
|
||||
contentText := ""
|
||||
if content != nil {
|
||||
contentText = fmt.Sprint(content)
|
||||
}
|
||||
return strings.TrimSpace(knowledgePrompt + "\n\nQuestion:\n" + contentText)
|
||||
}
|
||||
}
|
||||
|
||||
// knowledgeTextBlock wraps the knowledge prompt in a "text" content block,
// ending with a "Question:" line so that the blocks which follow it read as
// the question.
func knowledgeTextBlock(knowledgePrompt string) map[string]interface{} {
	block := make(map[string]interface{}, 2)
	block["type"] = "text"
	block["text"] = knowledgePrompt + "\n\nQuestion:"
	return block
}
|
||||
|
||||
func textFromMessageContent(content interface{}) string {
|
||||
switch typed := content.(type) {
|
||||
case string:
|
||||
return strings.TrimSpace(typed)
|
||||
case []interface{}:
|
||||
return strings.TrimSpace(strings.Join(textsFromContentBlocks(typed), "\n"))
|
||||
case []map[string]interface{}:
|
||||
blocks := make([]interface{}, 0, len(typed))
|
||||
for _, block := range typed {
|
||||
blocks = append(blocks, block)
|
||||
}
|
||||
return strings.TrimSpace(strings.Join(textsFromContentBlocks(blocks), "\n"))
|
||||
default:
|
||||
if content == nil {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(fmt.Sprint(content))
|
||||
}
|
||||
}
|
||||
|
||||
// textsFromContentBlocks collects the non-blank text carried by a list of
// content blocks, preserving order. A block contributes text when it is a
// string, or a map whose "text" entry is a string; every other block (and
// every blank text) is ignored. Returned texts are whitespace-trimmed.
func textsFromContentBlocks(blocks []interface{}) []string {
	out := make([]string, 0, len(blocks))
	for _, raw := range blocks {
		var candidate string
		switch b := raw.(type) {
		case string:
			candidate = b
		case map[string]interface{}:
			// A missing or non-string "text" leaves candidate empty and is skipped below.
			candidate, _ = b["text"].(string)
		default:
			continue
		}
		if trimmed := strings.TrimSpace(candidate); trimmed != "" {
			out = append(out, trimmed)
		}
	}
	return out
}
|
||||
|
||||
func dialogWithInjectedKnowledgePrompt(dialog *entity.Chat, knowledge string) (*entity.Chat, bool) {
|
||||
if dialog.PromptConfig == nil {
|
||||
return dialog, false
|
||||
}
|
||||
systemPrompt, ok := dialog.PromptConfig["system"].(string)
|
||||
if !ok || !strings.Contains(systemPrompt, "{knowledge}") {
|
||||
return dialog, false
|
||||
}
|
||||
|
||||
copied := cloneJSONMap(dialog.PromptConfig)
|
||||
copied["system"] = strings.ReplaceAll(systemPrompt, "{knowledge}", knowledge)
|
||||
dialogCopy := *dialog
|
||||
dialogCopy.PromptConfig = copied
|
||||
return &dialogCopy, true
|
||||
}
|
||||
|
||||
func cloneJSONMap(values entity.JSONMap) entity.JSONMap {
|
||||
copied := make(entity.JSONMap, len(values))
|
||||
for key, value := range values {
|
||||
copied[key] = value
|
||||
}
|
||||
return copied
|
||||
}
|
||||
|
||||
// copyMessages returns a new slice in which every message is a fresh map with
// the same top-level keys and values as the source message. Values themselves
// are not deep-copied. A nil message becomes an empty map, and the returned
// slice is never nil.
func copyMessages(messages []map[string]interface{}) []map[string]interface{} {
	out := make([]map[string]interface{}, 0, len(messages))
	for _, src := range messages {
		dst := make(map[string]interface{}, len(src))
		for field, val := range src {
			dst[field] = val
		}
		out = append(out, dst)
	}
	return out
}
|
||||
|
||||
func setLatestReference(reference []interface{}, chunks []map[string]interface{}, docAggs []map[string]interface{}) {
|
||||
ref := map[string]interface{}{
|
||||
"chunks": chunksForReference(chunks),
|
||||
"doc_aggs": mapsForReference(docAggs),
|
||||
}
|
||||
if len(reference) == 0 {
|
||||
return
|
||||
}
|
||||
reference[len(reference)-1] = ref
|
||||
}
|
||||
|
||||
// chunksForReference converts retrieved chunks into reference entries. Each
// entry is a fresh map containing every field of the chunk except "vector",
// so embedding vectors are not carried into the reference payload. The input
// chunks are not modified and the result is never nil.
func chunksForReference(chunks []map[string]interface{}) []interface{} {
	out := make([]interface{}, len(chunks))
	for i, chunk := range chunks {
		sanitized := make(map[string]interface{}, len(chunk))
		for field, val := range chunk {
			if field != "vector" {
				sanitized[field] = val
			}
		}
		out[i] = sanitized
	}
	return out
}
|
||||
|
||||
// mapsForReference re-boxes a slice of maps as []interface{}. The maps are
// stored as-is (not copied), so the result shares them with the input. The
// returned slice is never nil.
func mapsForReference(values []map[string]interface{}) []interface{} {
	out := make([]interface{}, len(values))
	for i := range values {
		out[i] = values[i]
	}
	return out
}
|
||||
|
||||
// asyncChatSolo performs simple chat without RAG (non-streaming)
|
||||
func (s *ChatSessionService) asyncChatSolo(dialog *entity.Chat, session *entity.ChatSession, messages []map[string]interface{}, config map[string]interface{}, messageID string, reference []interface{}, stream bool) (map[string]interface{}, error) {
|
||||
common.Info("asyncChatSolo started",
|
||||
@@ -765,7 +1396,8 @@ func (s *ChatSessionService) structureAnswerWithConv(session *entity.ChatSession
|
||||
if ans["final"] == true && ans["answer"] != nil {
|
||||
lastMsg["content"] = ans["answer"]
|
||||
} else {
|
||||
lastMsg["content"] = (lastMsg["content"].(string)) + content
|
||||
existing, _ := lastMsg["content"].(string)
|
||||
lastMsg["content"] = existing + content
|
||||
}
|
||||
lastMsg["created_at"] = float64(time.Now().Unix())
|
||||
lastMsg["id"] = messageID
|
||||
|
||||
949
internal/service/chat_session_test.go
Normal file
949
internal/service/chat_session_test.go
Normal file
@@ -0,0 +1,949 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/engine/types"
|
||||
"ragflow/internal/entity"
|
||||
modelModule "ragflow/internal/entity/models"
|
||||
"ragflow/internal/service/nlp"
|
||||
)
|
||||
|
||||
type fakeChatKBStore struct {
|
||||
kbs []*entity.Knowledgebase
|
||||
accessible map[string]bool
|
||||
}
|
||||
|
||||
func (f fakeChatKBStore) Accessible(kbID, userID string) bool {
|
||||
if f.accessible == nil {
|
||||
return true
|
||||
}
|
||||
return f.accessible[kbID]
|
||||
}
|
||||
|
||||
func (f fakeChatKBStore) GetByIDs(ids []string) ([]*entity.Knowledgebase, error) {
|
||||
return f.kbs, nil
|
||||
}
|
||||
|
||||
type fakeChatMetadataService struct{}
|
||||
|
||||
func (fakeChatMetadataService) LabelQuestion(question string, kbs []*entity.Knowledgebase) map[string]float64 {
|
||||
return map[string]float64{"pagerank_fea": 10}
|
||||
}
|
||||
|
||||
func (fakeChatMetadataService) GetFlattedMetaByKBs(kbIDs []string) (common.MetaData, error) {
|
||||
return common.MetaData{
|
||||
"category": common.MetaValueDocs{
|
||||
"policy": []string{"doc-policy"},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
type failingChatMetadataService struct{}
|
||||
|
||||
func (failingChatMetadataService) LabelQuestion(question string, kbs []*entity.Knowledgebase) map[string]float64 {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (failingChatMetadataService) GetFlattedMetaByKBs(kbIDs []string) (common.MetaData, error) {
|
||||
return nil, errors.New("metadata unavailable")
|
||||
}
|
||||
|
||||
type fakeChatDocEngine struct {
|
||||
chunk map[string]interface{}
|
||||
}
|
||||
|
||||
func (f fakeChatDocEngine) CreateChunkStore(ctx context.Context, baseName, datasetID string, vectorSize int, parserID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f fakeChatDocEngine) InsertChunks(ctx context.Context, chunks []map[string]interface{}, baseName string, datasetID string) ([]string, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (f fakeChatDocEngine) UpdateChunks(ctx context.Context, condition map[string]interface{}, newValue map[string]interface{}, baseName string, datasetID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f fakeChatDocEngine) DeleteChunks(ctx context.Context, condition map[string]interface{}, baseName string, datasetID string) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (f fakeChatDocEngine) Search(ctx context.Context, req *types.SearchRequest) (*types.SearchResult, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (f fakeChatDocEngine) GetChunk(ctx context.Context, baseName, chunkID string, datasetIDs []string) (interface{}, error) {
|
||||
return f.chunk, nil
|
||||
}
|
||||
|
||||
func (f fakeChatDocEngine) DropChunkStore(ctx context.Context, baseName, datasetID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f fakeChatDocEngine) ChunkStoreExists(ctx context.Context, baseName, datasetID string) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (f fakeChatDocEngine) CreateMetadataStore(ctx context.Context, tenantID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f fakeChatDocEngine) InsertMetadata(ctx context.Context, metadata []map[string]interface{}, tenantID string) ([]string, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (f fakeChatDocEngine) UpdateMetadata(ctx context.Context, docID string, datasetID string, metaFields map[string]interface{}, tenantID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f fakeChatDocEngine) DeleteMetadata(ctx context.Context, condition map[string]interface{}, tenantID string) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (f fakeChatDocEngine) DeleteMetadataKeys(ctx context.Context, docID string, datasetID string, keys []string, tenantID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f fakeChatDocEngine) DropMetadataStore(ctx context.Context, tenantID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f fakeChatDocEngine) MetadataStoreExists(ctx context.Context, tenantID string) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (f fakeChatDocEngine) SearchMetadata(ctx context.Context, req *types.SearchMetadataRequest) (*types.SearchMetadataResult, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (f fakeChatDocEngine) IndexDocument(ctx context.Context, indexName, docID string, doc interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f fakeChatDocEngine) DeleteDocument(ctx context.Context, indexName, docID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f fakeChatDocEngine) BulkIndex(ctx context.Context, indexName string, docs []interface{}) (interface{}, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (f fakeChatDocEngine) GetFields(chunks []map[string]interface{}, fields []string) map[string]map[string]interface{} {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f fakeChatDocEngine) GetAggregation(chunks []map[string]interface{}, fieldName string) []map[string]interface{} {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f fakeChatDocEngine) GetHighlight(chunks []map[string]interface{}, keywords []string, fieldName string) map[string]string {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f fakeChatDocEngine) GetChunkIDs(chunks []map[string]interface{}) []string {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f fakeChatDocEngine) KNNScores(ctx context.Context, chunks []map[string]interface{}, queryVector []float64, topK int) (map[string]interface{}, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (f fakeChatDocEngine) GetScores(searchResult map[string]interface{}) map[string]float64 {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f fakeChatDocEngine) FilterDocIdsByMetaPushdown(ctx context.Context, kbIDs []string, conditions []map[string]interface{}, logic string) []string {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f fakeChatDocEngine) Ping(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f fakeChatDocEngine) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f fakeChatDocEngine) GetType() string {
|
||||
return "fake"
|
||||
}
|
||||
|
||||
type fakeChatRetrievalService struct {
|
||||
req *nlp.RetrievalRequest
|
||||
result *nlp.RetrievalResult
|
||||
}
|
||||
|
||||
func (f *fakeChatRetrievalService) Retrieval(ctx context.Context, req *nlp.RetrievalRequest) (*nlp.RetrievalResult, error) {
|
||||
f.req = req
|
||||
return f.result, nil
|
||||
}
|
||||
|
||||
type fakeChatModelProvider struct {
|
||||
driver *fakeChatModelDriver
|
||||
}
|
||||
|
||||
func (f fakeChatModelProvider) GetChatModel(tenantID, compositeModelName string) (*modelModule.ChatModel, error) {
|
||||
modelName := compositeModelName
|
||||
return modelModule.NewChatModel(f.driver, &modelName, &modelModule.APIConfig{}), nil
|
||||
}
|
||||
|
||||
func (f fakeChatModelProvider) GetEmbeddingModel(tenantID, compositeModelName string) (*modelModule.EmbeddingModel, error) {
|
||||
modelName := compositeModelName
|
||||
return modelModule.NewEmbeddingModel(f.driver, &modelName, &modelModule.APIConfig{}, 512), nil
|
||||
}
|
||||
|
||||
func (f fakeChatModelProvider) GetRerankModel(tenantID, compositeModelName string) (*modelModule.RerankModel, error) {
|
||||
modelName := compositeModelName
|
||||
return modelModule.NewRerankModel(f.driver, &modelName, &modelModule.APIConfig{}), nil
|
||||
}
|
||||
|
||||
func (f fakeChatModelProvider) GetModelConfigFromProviderInstance(tenantID string, modelType entity.ModelType, modelName string) (modelModule.ModelDriver, string, *modelModule.APIConfig, int, error) {
|
||||
return f.driver, modelName, &modelModule.APIConfig{}, 0, nil
|
||||
}
|
||||
|
||||
func (f fakeChatModelProvider) GetTenantDefaultModelByType(tenantID string, modelType entity.ModelType) (modelModule.ModelDriver, string, *modelModule.APIConfig, int, error) {
|
||||
modelName := "default@factory"
|
||||
return f.driver, modelName, &modelModule.APIConfig{}, 0, nil
|
||||
}
|
||||
|
||||
type fakeChatModelDriver struct {
|
||||
messages []modelModule.Message
|
||||
}
|
||||
|
||||
func (f *fakeChatModelDriver) NewInstance(baseURL map[string]string) modelModule.ModelDriver {
|
||||
return f
|
||||
}
|
||||
|
||||
func (f *fakeChatModelDriver) Name() string {
|
||||
return "fake"
|
||||
}
|
||||
|
||||
func (f *fakeChatModelDriver) ChatWithMessages(modelName string, messages []modelModule.Message, apiConfig *modelModule.APIConfig, chatModelConfig *modelModule.ChatConfig) (*modelModule.ChatResponse, error) {
|
||||
f.messages = messages
|
||||
answer := "answer from knowledge"
|
||||
return &modelModule.ChatResponse{Answer: &answer}, nil
|
||||
}
|
||||
|
||||
func (f *fakeChatModelDriver) ChatStreamlyWithSender(modelName string, messages []modelModule.Message, apiConfig *modelModule.APIConfig, modelConfig *modelModule.ChatConfig, sender func(*string, *string) error) error {
|
||||
f.messages = messages
|
||||
answer := "stream answer from knowledge"
|
||||
return sender(&answer, nil)
|
||||
}
|
||||
|
||||
func (f *fakeChatModelDriver) Embed(modelName *string, texts []string, apiConfig *modelModule.APIConfig, embeddingConfig *modelModule.EmbeddingConfig) ([]modelModule.EmbeddingData, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (f *fakeChatModelDriver) Rerank(modelName *string, query string, documents []string, apiConfig *modelModule.APIConfig, rerankConfig *modelModule.RerankConfig) (*modelModule.RerankResponse, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (f *fakeChatModelDriver) TranscribeAudio(modelName *string, file *string, apiConfig *modelModule.APIConfig, asrConfig *modelModule.ASRConfig) (*modelModule.ASRResponse, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (f *fakeChatModelDriver) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *modelModule.APIConfig, asrConfig *modelModule.ASRConfig, sender func(*string, *string) error) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeChatModelDriver) AudioSpeech(modelName *string, audioContent *string, apiConfig *modelModule.APIConfig, ttsConfig *modelModule.TTSConfig) (*modelModule.TTSResponse, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (f *fakeChatModelDriver) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *modelModule.APIConfig, ttsConfig *modelModule.TTSConfig, sender func(*string, *string) error) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeChatModelDriver) OCRFile(modelName *string, content []byte, url *string, apiConfig *modelModule.APIConfig, ocrConfig *modelModule.OCRConfig) (*modelModule.OCRFileResponse, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (f *fakeChatModelDriver) ParseFile(modelName *string, content []byte, url *string, apiConfig *modelModule.APIConfig, parseFileConfig *modelModule.ParseFileConfig) (*modelModule.ParseFileResponse, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (f *fakeChatModelDriver) ListModels(apiConfig *modelModule.APIConfig) ([]modelModule.ListModelResponse, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (f *fakeChatModelDriver) Balance(apiConfig *modelModule.APIConfig) (map[string]interface{}, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (f *fakeChatModelDriver) CheckConnection(apiConfig *modelModule.APIConfig) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeChatModelDriver) ListTasks(apiConfig *modelModule.APIConfig) ([]modelModule.ListTaskStatus, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (f *fakeChatModelDriver) ShowTask(taskID string, apiConfig *modelModule.APIConfig) (*modelModule.TaskResponse, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func TestAsyncChatUsesRetrievedKnowledgeForKBDialog(t *testing.T) {
|
||||
driver := &fakeChatModelDriver{}
|
||||
retrieval := &fakeChatRetrievalService{
|
||||
result: &nlp.RetrievalResult{
|
||||
Chunks: []map[string]interface{}{
|
||||
{
|
||||
"chunk_id": "chunk-1",
|
||||
"content_with_weight": "RAGFlow stores conversation references alongside the session.",
|
||||
"doc_id": "doc-1",
|
||||
"docnm_kwd": "manual.md",
|
||||
"vector": []float64{0.1, 0.2},
|
||||
},
|
||||
},
|
||||
DocAggs: []map[string]interface{}{
|
||||
{"doc_id": "doc-1", "doc_name": "manual.md", "count": 1},
|
||||
},
|
||||
},
|
||||
}
|
||||
svc := &ChatSessionService{
|
||||
kbDAO: fakeChatKBStore{kbs: []*entity.Knowledgebase{
|
||||
{ID: "kb-1", TenantID: "tenant-1", Name: "Manual", EmbdID: "embed@factory"},
|
||||
}},
|
||||
modelProviderSvc: fakeChatModelProvider{driver: driver},
|
||||
metadataSvc: fakeChatMetadataService{},
|
||||
retrievalSvc: retrieval,
|
||||
}
|
||||
|
||||
reference := []interface{}{map[string]interface{}{"chunks": []interface{}{}, "doc_aggs": []interface{}{}}}
|
||||
sessionMessage, err := json.Marshal(map[string]interface{}{"messages": []interface{}{}})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal session message: %v", err)
|
||||
}
|
||||
session := &entity.ChatSession{ID: "session-1", Message: sessionMessage}
|
||||
dialog := &entity.Chat{
|
||||
ID: "dialog-1",
|
||||
TenantID: "tenant-1",
|
||||
LLMID: "chat@factory",
|
||||
PromptConfig: entity.JSONMap{"system": "You are helpful."},
|
||||
LLMSetting: entity.JSONMap{},
|
||||
KBIDs: entity.JSONSlice{"kb-1"},
|
||||
TopN: 3,
|
||||
TopK: 32,
|
||||
SimilarityThreshold: 0.2,
|
||||
VectorSimilarityWeight: 0.3,
|
||||
}
|
||||
|
||||
result, err := svc.asyncChat("user-1", dialog, session, []map[string]interface{}{
|
||||
{"role": "user", "content": "Where are references stored?"},
|
||||
}, nil, "message-1", reference, false)
|
||||
if err != nil {
|
||||
t.Fatalf("asyncChat returned error: %v", err)
|
||||
}
|
||||
|
||||
if retrieval.req == nil {
|
||||
t.Fatal("expected retrieval service to be called")
|
||||
}
|
||||
if retrieval.req.Question != "Where are references stored?" {
|
||||
t.Fatalf("unexpected retrieval question: %q", retrieval.req.Question)
|
||||
}
|
||||
if retrieval.req.PageSize != 3 || retrieval.req.Top == nil || *retrieval.req.Top != 32 {
|
||||
t.Fatalf("unexpected retrieval paging: page_size=%d top=%v", retrieval.req.PageSize, retrieval.req.Top)
|
||||
}
|
||||
if len(driver.messages) == 0 {
|
||||
t.Fatal("expected chat model to receive messages")
|
||||
}
|
||||
last := driver.messages[len(driver.messages)-1]
|
||||
content, ok := last.Content.(string)
|
||||
if !ok {
|
||||
t.Fatalf("expected string content, got %T", last.Content)
|
||||
}
|
||||
if !strings.Contains(content, "RAGFlow stores conversation references") {
|
||||
t.Fatalf("expected retrieved content in prompt, got %q", content)
|
||||
}
|
||||
|
||||
ref, ok := result["reference"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("expected reference map, got %T", result["reference"])
|
||||
}
|
||||
chunks, ok := ref["chunks"].([]interface{})
|
||||
if !ok || len(chunks) != 1 {
|
||||
t.Fatalf("expected one reference chunk, got %#v", ref["chunks"])
|
||||
}
|
||||
chunk, ok := chunks[0].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("expected chunk map, got %T", chunks[0])
|
||||
}
|
||||
if _, exists := chunk["vector"]; exists {
|
||||
t.Fatal("reference chunk should not expose vector")
|
||||
}
|
||||
if result["answer"] != "answer from knowledge" {
|
||||
t.Fatalf("unexpected answer: %#v", result["answer"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestAsyncChatPropagatesRetrievalErrors(t *testing.T) {
|
||||
retrievalErr := errors.New("search unavailable")
|
||||
retrieval := &failingChatRetrievalService{err: retrievalErr}
|
||||
svc := &ChatSessionService{
|
||||
kbDAO: fakeChatKBStore{kbs: []*entity.Knowledgebase{
|
||||
{ID: "kb-1", TenantID: "tenant-1", Name: "Manual", EmbdID: "embed@factory"},
|
||||
}},
|
||||
modelProviderSvc: fakeChatModelProvider{driver: &fakeChatModelDriver{}},
|
||||
metadataSvc: fakeChatMetadataService{},
|
||||
retrievalSvc: retrieval,
|
||||
}
|
||||
|
||||
_, err := svc.asyncChat("user-1", &entity.Chat{
|
||||
ID: "dialog-1",
|
||||
TenantID: "tenant-1",
|
||||
LLMID: "chat@factory",
|
||||
PromptConfig: entity.JSONMap{},
|
||||
LLMSetting: entity.JSONMap{},
|
||||
KBIDs: entity.JSONSlice{"kb-1"},
|
||||
TopN: 3,
|
||||
TopK: 32,
|
||||
SimilarityThreshold: 0.2,
|
||||
VectorSimilarityWeight: 0.3,
|
||||
}, &entity.ChatSession{ID: "session-1"}, []map[string]interface{}{
|
||||
{"role": "user", "content": "question"},
|
||||
}, nil, "message-1", []interface{}{map[string]interface{}{"chunks": []interface{}{}, "doc_aggs": []interface{}{}}}, false)
|
||||
if err == nil || !strings.Contains(err.Error(), "retrieval search failed") {
|
||||
t.Fatalf("expected retrieval error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessagesWithRetrievedKnowledgeFillsSystemPlaceholder(t *testing.T) {
|
||||
retrieval := &fakeChatRetrievalService{
|
||||
result: &nlp.RetrievalResult{
|
||||
Chunks: []map[string]interface{}{
|
||||
{"content_with_weight": "Knowledge inserted into the system prompt."},
|
||||
},
|
||||
},
|
||||
}
|
||||
svc := &ChatSessionService{
|
||||
kbDAO: fakeChatKBStore{kbs: []*entity.Knowledgebase{
|
||||
{ID: "kb-1", TenantID: "tenant-1", Name: "Manual", EmbdID: "embed@factory"},
|
||||
}},
|
||||
modelProviderSvc: fakeChatModelProvider{driver: &fakeChatModelDriver{}},
|
||||
metadataSvc: fakeChatMetadataService{},
|
||||
retrievalSvc: retrieval,
|
||||
}
|
||||
dialog := &entity.Chat{
|
||||
ID: "dialog-1",
|
||||
TenantID: "tenant-1",
|
||||
PromptConfig: entity.JSONMap{"system": "Answer from this context: {knowledge}"},
|
||||
KBIDs: entity.JSONSlice{"kb-1"},
|
||||
TopN: 3,
|
||||
TopK: 32,
|
||||
}
|
||||
messages := []map[string]interface{}{
|
||||
{"role": "user", "content": "What context is available?"},
|
||||
}
|
||||
|
||||
got, ragDialog, emptyResponse, err := svc.messagesWithRetrievedKnowledge(context.Background(), "user-1", dialog, messages, []interface{}{
|
||||
map[string]interface{}{"chunks": []interface{}{}, "doc_aggs": []interface{}{}},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("messagesWithRetrievedKnowledge returned error: %v", err)
|
||||
}
|
||||
if emptyResponse != nil {
|
||||
t.Fatalf("expected no empty response, got %q", *emptyResponse)
|
||||
}
|
||||
if got[0]["content"] != "What context is available?" {
|
||||
t.Fatalf("expected user content to stay unchanged, got %q", got[0]["content"])
|
||||
}
|
||||
originalPrompt, _ := dialog.PromptConfig["system"].(string)
|
||||
if !strings.Contains(originalPrompt, "{knowledge}") {
|
||||
t.Fatalf("expected original dialog prompt to remain unchanged, got %q", originalPrompt)
|
||||
}
|
||||
systemPrompt, _ := ragDialog.PromptConfig["system"].(string)
|
||||
if strings.Contains(systemPrompt, "{knowledge}") {
|
||||
t.Fatalf("expected knowledge placeholder to be replaced, got %q", systemPrompt)
|
||||
}
|
||||
if !strings.Contains(systemPrompt, "Knowledge inserted into the system prompt.") {
|
||||
t.Fatalf("expected retrieved knowledge in system prompt, got %q", systemPrompt)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAsyncChatReturnsEmptyResponseWhenRetrievalHasNoKnowledge(t *testing.T) {
|
||||
driver := &fakeChatModelDriver{}
|
||||
svc := &ChatSessionService{
|
||||
kbDAO: fakeChatKBStore{kbs: []*entity.Knowledgebase{
|
||||
{ID: "kb-1", TenantID: "tenant-1", Name: "Manual", EmbdID: "embed@factory"},
|
||||
}},
|
||||
modelProviderSvc: fakeChatModelProvider{driver: driver},
|
||||
metadataSvc: fakeChatMetadataService{},
|
||||
retrievalSvc: &fakeChatRetrievalService{result: &nlp.RetrievalResult{}},
|
||||
}
|
||||
reference := []interface{}{map[string]interface{}{"chunks": []interface{}{}, "doc_aggs": []interface{}{}}}
|
||||
sessionMessage, err := json.Marshal(map[string]interface{}{"messages": []interface{}{}})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal session message: %v", err)
|
||||
}
|
||||
result, err := svc.asyncChat("user-1", &entity.Chat{
|
||||
ID: "dialog-1",
|
||||
TenantID: "tenant-1",
|
||||
LLMID: "chat@factory",
|
||||
PromptConfig: entity.JSONMap{"empty_response": "No relevant content."},
|
||||
LLMSetting: entity.JSONMap{},
|
||||
KBIDs: entity.JSONSlice{"kb-1"},
|
||||
TopN: 3,
|
||||
TopK: 32,
|
||||
}, &entity.ChatSession{ID: "session-1", Message: sessionMessage}, []map[string]interface{}{
|
||||
{"role": "user", "content": "question"},
|
||||
}, nil, "message-1", reference, false)
|
||||
if err != nil {
|
||||
t.Fatalf("asyncChat returned error: %v", err)
|
||||
}
|
||||
if result["answer"] != "No relevant content." {
|
||||
t.Fatalf("unexpected empty response answer: %#v", result["answer"])
|
||||
}
|
||||
if len(driver.messages) != 0 {
|
||||
t.Fatal("chat model should not be called when empty_response is returned")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessagesWithRetrievedKnowledgeAppliesMetadataFilter(t *testing.T) {
|
||||
retrieval := &fakeChatRetrievalService{result: &nlp.RetrievalResult{}}
|
||||
svc := &ChatSessionService{
|
||||
kbDAO: fakeChatKBStore{kbs: []*entity.Knowledgebase{
|
||||
{ID: "kb-1", TenantID: "tenant-1", Name: "Manual", EmbdID: "embed@factory"},
|
||||
}},
|
||||
modelProviderSvc: fakeChatModelProvider{driver: &fakeChatModelDriver{}},
|
||||
metadataSvc: fakeChatMetadataService{},
|
||||
retrievalSvc: retrieval,
|
||||
}
|
||||
filter := entity.JSONMap{
|
||||
"method": "manual",
|
||||
"manual": []interface{}{
|
||||
map[string]interface{}{"key": "category", "op": "=", "value": "policy"},
|
||||
},
|
||||
"logic": "and",
|
||||
}
|
||||
_, _, _, err := svc.messagesWithRetrievedKnowledge(context.Background(), "user-1", &entity.Chat{
|
||||
ID: "dialog-1",
|
||||
TenantID: "tenant-1",
|
||||
PromptConfig: entity.JSONMap{},
|
||||
MetaDataFilter: &filter,
|
||||
KBIDs: entity.JSONSlice{"kb-1"},
|
||||
TopN: 3,
|
||||
TopK: 32,
|
||||
}, []map[string]interface{}{
|
||||
{"role": "user", "content": "question"},
|
||||
}, []interface{}{map[string]interface{}{"chunks": []interface{}{}, "doc_aggs": []interface{}{}}})
|
||||
if err != nil {
|
||||
t.Fatalf("messagesWithRetrievedKnowledge returned error: %v", err)
|
||||
}
|
||||
if retrieval.req == nil {
|
||||
t.Fatal("expected retrieval to be called")
|
||||
}
|
||||
if len(retrieval.req.DocIDs) != 1 || retrieval.req.DocIDs[0] != "doc-policy" {
|
||||
t.Fatalf("expected metadata-filtered doc id, got %#v", retrieval.req.DocIDs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessagesWithRetrievedKnowledgeIntersectsDocIDsWithMetadataFilter(t *testing.T) {
|
||||
retrieval := &fakeChatRetrievalService{result: &nlp.RetrievalResult{}}
|
||||
svc := &ChatSessionService{
|
||||
kbDAO: fakeChatKBStore{kbs: []*entity.Knowledgebase{
|
||||
{ID: "kb-1", TenantID: "tenant-1", Name: "Manual", EmbdID: "embed@factory"},
|
||||
}},
|
||||
modelProviderSvc: fakeChatModelProvider{driver: &fakeChatModelDriver{}},
|
||||
metadataSvc: fakeChatMetadataService{},
|
||||
retrievalSvc: retrieval,
|
||||
}
|
||||
filter := entity.JSONMap{
|
||||
"method": "manual",
|
||||
"manual": []interface{}{
|
||||
map[string]interface{}{"key": "category", "op": "=", "value": "policy"},
|
||||
},
|
||||
"logic": "and",
|
||||
}
|
||||
|
||||
_, _, _, err := svc.messagesWithRetrievedKnowledge(context.Background(), "user-1", &entity.Chat{
|
||||
ID: "dialog-1",
|
||||
TenantID: "tenant-1",
|
||||
PromptConfig: entity.JSONMap{},
|
||||
MetaDataFilter: &filter,
|
||||
KBIDs: entity.JSONSlice{"kb-1"},
|
||||
TopN: 3,
|
||||
TopK: 32,
|
||||
}, []map[string]interface{}{
|
||||
{"role": "user", "content": "question", "doc_ids": []interface{}{"doc-explicit", "doc-policy"}},
|
||||
}, []interface{}{map[string]interface{}{"chunks": []interface{}{}, "doc_aggs": []interface{}{}}})
|
||||
if err != nil {
|
||||
t.Fatalf("messagesWithRetrievedKnowledge returned error: %v", err)
|
||||
}
|
||||
if len(retrieval.req.DocIDs) != 1 || retrieval.req.DocIDs[0] != "doc-policy" {
|
||||
t.Fatalf("expected metadata and message doc_ids intersection, got %#v", retrieval.req.DocIDs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessagesWithRetrievedKnowledgeNoMetadataIntersectionUsesSentinel(t *testing.T) {
|
||||
retrieval := &fakeChatRetrievalService{result: &nlp.RetrievalResult{}}
|
||||
svc := &ChatSessionService{
|
||||
kbDAO: fakeChatKBStore{kbs: []*entity.Knowledgebase{
|
||||
{ID: "kb-1", TenantID: "tenant-1", Name: "Manual", EmbdID: "embed@factory"},
|
||||
}},
|
||||
modelProviderSvc: fakeChatModelProvider{driver: &fakeChatModelDriver{}},
|
||||
metadataSvc: fakeChatMetadataService{},
|
||||
retrievalSvc: retrieval,
|
||||
}
|
||||
filter := entity.JSONMap{
|
||||
"method": "manual",
|
||||
"manual": []interface{}{
|
||||
map[string]interface{}{"key": "category", "op": "=", "value": "policy"},
|
||||
},
|
||||
"logic": "and",
|
||||
}
|
||||
|
||||
_, _, _, err := svc.messagesWithRetrievedKnowledge(context.Background(), "user-1", &entity.Chat{
|
||||
ID: "dialog-1",
|
||||
TenantID: "tenant-1",
|
||||
PromptConfig: entity.JSONMap{},
|
||||
MetaDataFilter: &filter,
|
||||
KBIDs: entity.JSONSlice{"kb-1"},
|
||||
TopN: 3,
|
||||
TopK: 32,
|
||||
}, []map[string]interface{}{
|
||||
{"role": "user", "content": "question", "doc_ids": []interface{}{"doc-explicit"}},
|
||||
}, []interface{}{map[string]interface{}{"chunks": []interface{}{}, "doc_aggs": []interface{}{}}})
|
||||
if err != nil {
|
||||
t.Fatalf("messagesWithRetrievedKnowledge returned error: %v", err)
|
||||
}
|
||||
if len(retrieval.req.DocIDs) != 1 || retrieval.req.DocIDs[0] != NoMatchDocIDSentinel {
|
||||
t.Fatalf("expected empty metadata/doc_ids intersection sentinel, got %#v", retrieval.req.DocIDs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessagesWithRetrievedKnowledgePreservesEmptyMetadataFilterMatches(t *testing.T) {
|
||||
retrieval := &fakeChatRetrievalService{result: &nlp.RetrievalResult{}}
|
||||
svc := &ChatSessionService{
|
||||
kbDAO: fakeChatKBStore{kbs: []*entity.Knowledgebase{
|
||||
{ID: "kb-1", TenantID: "tenant-1", Name: "Manual", EmbdID: "embed@factory"},
|
||||
}},
|
||||
modelProviderSvc: fakeChatModelProvider{driver: &fakeChatModelDriver{}},
|
||||
metadataSvc: fakeChatMetadataService{},
|
||||
retrievalSvc: retrieval,
|
||||
}
|
||||
filter := entity.JSONMap{"method": "auto"}
|
||||
|
||||
_, _, _, err := svc.messagesWithRetrievedKnowledge(context.Background(), "user-1", &entity.Chat{
|
||||
ID: "dialog-1",
|
||||
TenantID: "tenant-1",
|
||||
LLMID: "chat@factory",
|
||||
PromptConfig: entity.JSONMap{},
|
||||
MetaDataFilter: &filter,
|
||||
KBIDs: entity.JSONSlice{"kb-1"},
|
||||
TopN: 3,
|
||||
TopK: 32,
|
||||
}, []map[string]interface{}{
|
||||
{"role": "user", "content": "question"},
|
||||
}, []interface{}{map[string]interface{}{"chunks": []interface{}{}, "doc_aggs": []interface{}{}}})
|
||||
if err != nil {
|
||||
t.Fatalf("messagesWithRetrievedKnowledge returned error: %v", err)
|
||||
}
|
||||
if len(retrieval.req.DocIDs) != 1 || retrieval.req.DocIDs[0] != NoMatchDocIDSentinel {
|
||||
t.Fatalf("expected empty metadata filter sentinel, got %#v", retrieval.req.DocIDs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessagesWithRetrievedKnowledgeFailsClosedWhenMetadataUnavailable(t *testing.T) {
|
||||
retrieval := &fakeChatRetrievalService{result: &nlp.RetrievalResult{}}
|
||||
svc := &ChatSessionService{
|
||||
kbDAO: fakeChatKBStore{kbs: []*entity.Knowledgebase{
|
||||
{ID: "kb-1", TenantID: "tenant-1", Name: "Manual", EmbdID: "embed@factory"},
|
||||
}},
|
||||
modelProviderSvc: fakeChatModelProvider{driver: &fakeChatModelDriver{}},
|
||||
metadataSvc: failingChatMetadataService{},
|
||||
retrievalSvc: retrieval,
|
||||
}
|
||||
filter := entity.JSONMap{
|
||||
"method": "manual",
|
||||
"manual": []interface{}{
|
||||
map[string]interface{}{"key": "category", "op": "=", "value": "policy"},
|
||||
},
|
||||
"logic": "and",
|
||||
}
|
||||
|
||||
_, _, _, err := svc.messagesWithRetrievedKnowledge(context.Background(), "user-1", &entity.Chat{
|
||||
ID: "dialog-1",
|
||||
TenantID: "tenant-1",
|
||||
PromptConfig: entity.JSONMap{},
|
||||
MetaDataFilter: &filter,
|
||||
KBIDs: entity.JSONSlice{"kb-1"},
|
||||
TopN: 3,
|
||||
TopK: 32,
|
||||
}, []map[string]interface{}{
|
||||
{"role": "user", "content": "question", "doc_ids": []interface{}{"doc-explicit"}},
|
||||
}, []interface{}{map[string]interface{}{"chunks": []interface{}{}, "doc_aggs": []interface{}{}}})
|
||||
if err == nil || !strings.Contains(err.Error(), "flattened metadata") {
|
||||
t.Fatalf("expected metadata filter error, got %v", err)
|
||||
}
|
||||
if retrieval.req != nil {
|
||||
t.Fatal("retrieval should not run when metadata filtering cannot be evaluated")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessagesWithRetrievedKnowledgeExpandsChildChunks(t *testing.T) {
|
||||
retrieval := &fakeChatRetrievalService{result: &nlp.RetrievalResult{
|
||||
Chunks: []map[string]interface{}{
|
||||
{
|
||||
"chunk_id": "child-1",
|
||||
"mom_id": "parent-1",
|
||||
"kb_id": "kb-1",
|
||||
"doc_id": "doc-1",
|
||||
"docnm_kwd": "doc.md",
|
||||
"content_ltks": "child tokens",
|
||||
"content_with_weight": "child-only passage",
|
||||
"similarity": 0.8,
|
||||
},
|
||||
},
|
||||
}}
|
||||
svc := &ChatSessionService{
|
||||
kbDAO: fakeChatKBStore{kbs: []*entity.Knowledgebase{
|
||||
{ID: "kb-1", TenantID: "tenant-1", Name: "Manual", EmbdID: "embed@factory"},
|
||||
}},
|
||||
docEngine: fakeChatDocEngine{chunk: map[string]interface{}{
|
||||
"doc_id": "doc-1",
|
||||
"docnm_kwd": "doc.md",
|
||||
"kb_id": "kb-1",
|
||||
"content_with_weight": "parent passage with surrounding context",
|
||||
"position_int": []interface{}{1},
|
||||
}},
|
||||
modelProviderSvc: fakeChatModelProvider{driver: &fakeChatModelDriver{}},
|
||||
metadataSvc: fakeChatMetadataService{},
|
||||
retrievalSvc: retrieval,
|
||||
}
|
||||
|
||||
ragMessages, _, _, err := svc.messagesWithRetrievedKnowledge(context.Background(), "user-1", &entity.Chat{
|
||||
ID: "dialog-1",
|
||||
TenantID: "tenant-1",
|
||||
PromptConfig: entity.JSONMap{},
|
||||
KBIDs: entity.JSONSlice{"kb-1"},
|
||||
TopN: 3,
|
||||
TopK: 32,
|
||||
}, []map[string]interface{}{
|
||||
{"role": "user", "content": "question"},
|
||||
}, []interface{}{map[string]interface{}{"chunks": []interface{}{}, "doc_aggs": []interface{}{}}})
|
||||
if err != nil {
|
||||
t.Fatalf("messagesWithRetrievedKnowledge returned error: %v", err)
|
||||
}
|
||||
content, _ := ragMessages[0]["content"].(string)
|
||||
if !strings.Contains(content, "parent passage with surrounding context") {
|
||||
t.Fatalf("expected expanded parent content in prompt, got %q", content)
|
||||
}
|
||||
if strings.Contains(content, "child-only passage") {
|
||||
t.Fatalf("expected child content to be replaced by expanded parent content, got %q", content)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessagesWithRetrievedKnowledgeRejectsCrossTenantKnowledgebase(t *testing.T) {
|
||||
retrieval := &fakeChatRetrievalService{result: &nlp.RetrievalResult{}}
|
||||
svc := &ChatSessionService{
|
||||
kbDAO: fakeChatKBStore{
|
||||
kbs: []*entity.Knowledgebase{
|
||||
{ID: "kb-1", TenantID: "tenant-2", Name: "Manual", EmbdID: "embed@factory"},
|
||||
},
|
||||
accessible: map[string]bool{"kb-1": false},
|
||||
},
|
||||
modelProviderSvc: fakeChatModelProvider{driver: &fakeChatModelDriver{}},
|
||||
metadataSvc: fakeChatMetadataService{},
|
||||
retrievalSvc: retrieval,
|
||||
}
|
||||
|
||||
_, _, _, err := svc.messagesWithRetrievedKnowledge(context.Background(), "user-1", &entity.Chat{
|
||||
ID: "dialog-1",
|
||||
TenantID: "tenant-1",
|
||||
PromptConfig: entity.JSONMap{},
|
||||
KBIDs: entity.JSONSlice{"kb-1"},
|
||||
TopN: 3,
|
||||
TopK: 32,
|
||||
}, []map[string]interface{}{
|
||||
{"role": "user", "content": "question"},
|
||||
}, []interface{}{map[string]interface{}{"chunks": []interface{}{}, "doc_aggs": []interface{}{}}})
|
||||
if err == nil || !strings.Contains(err.Error(), "not authorized") {
|
||||
t.Fatalf("expected cross-tenant authorization error, got %v", err)
|
||||
}
|
||||
if retrieval.req != nil {
|
||||
t.Fatal("retrieval should not be called for an unauthorized knowledge base")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessagesWithRetrievedKnowledgeAllowsAccessibleSharedKnowledgebase(t *testing.T) {
|
||||
retrieval := &fakeChatRetrievalService{result: &nlp.RetrievalResult{}}
|
||||
svc := &ChatSessionService{
|
||||
kbDAO: fakeChatKBStore{
|
||||
kbs: []*entity.Knowledgebase{
|
||||
{ID: "kb-1", TenantID: "tenant-2", Name: "Shared Manual", EmbdID: "embed@factory"},
|
||||
},
|
||||
accessible: map[string]bool{"kb-1": true},
|
||||
},
|
||||
modelProviderSvc: fakeChatModelProvider{driver: &fakeChatModelDriver{}},
|
||||
metadataSvc: fakeChatMetadataService{},
|
||||
retrievalSvc: retrieval,
|
||||
}
|
||||
|
||||
_, _, _, err := svc.messagesWithRetrievedKnowledge(context.Background(), "user-1", &entity.Chat{
|
||||
ID: "dialog-1",
|
||||
TenantID: "tenant-1",
|
||||
PromptConfig: entity.JSONMap{},
|
||||
KBIDs: entity.JSONSlice{"kb-1"},
|
||||
TopN: 3,
|
||||
TopK: 32,
|
||||
}, []map[string]interface{}{
|
||||
{"role": "user", "content": "question"},
|
||||
}, []interface{}{map[string]interface{}{"chunks": []interface{}{}, "doc_aggs": []interface{}{}}})
|
||||
if err != nil {
|
||||
t.Fatalf("messagesWithRetrievedKnowledge returned error: %v", err)
|
||||
}
|
||||
if retrieval.req == nil || len(retrieval.req.TenantIDs) != 1 || retrieval.req.TenantIDs[0] != "tenant-2" {
|
||||
t.Fatalf("expected retrieval to use shared KB tenant, got %#v", retrieval.req)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessagesWithRetrievedKnowledgeRejectsMixedEmbeddingModels(t *testing.T) {
|
||||
retrieval := &fakeChatRetrievalService{result: &nlp.RetrievalResult{}}
|
||||
svc := &ChatSessionService{
|
||||
kbDAO: fakeChatKBStore{kbs: []*entity.Knowledgebase{
|
||||
{ID: "kb-1", TenantID: "tenant-1", Name: "Manual", EmbdID: "embed-a@factory"},
|
||||
{ID: "kb-2", TenantID: "tenant-1", Name: "FAQ", EmbdID: "embed-b@factory"},
|
||||
}},
|
||||
modelProviderSvc: fakeChatModelProvider{driver: &fakeChatModelDriver{}},
|
||||
metadataSvc: fakeChatMetadataService{},
|
||||
retrievalSvc: retrieval,
|
||||
}
|
||||
|
||||
_, _, _, err := svc.messagesWithRetrievedKnowledge(context.Background(), "user-1", &entity.Chat{
|
||||
ID: "dialog-1",
|
||||
TenantID: "tenant-1",
|
||||
PromptConfig: entity.JSONMap{},
|
||||
KBIDs: entity.JSONSlice{"kb-1", "kb-2"},
|
||||
TopN: 3,
|
||||
TopK: 32,
|
||||
}, []map[string]interface{}{
|
||||
{"role": "user", "content": "question"},
|
||||
}, []interface{}{map[string]interface{}{"chunks": []interface{}{}, "doc_aggs": []interface{}{}}})
|
||||
if err == nil || !strings.Contains(err.Error(), "same embedding model") {
|
||||
t.Fatalf("expected mixed embedding model error, got %v", err)
|
||||
}
|
||||
if retrieval.req != nil {
|
||||
t.Fatal("retrieval should not run when knowledge bases use different embedding models")
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateKnowledgebaseEmbeddingModelsComparesResolvedNames(t *testing.T) {
|
||||
firstTenantEmbdID := int64(1)
|
||||
secondTenantEmbdID := int64(2)
|
||||
kbs := []*entity.Knowledgebase{
|
||||
{ID: "kb-1", TenantID: "tenant-1", Name: "Manual", EmbdID: "same-legacy-name", TenantEmbdID: &firstTenantEmbdID},
|
||||
{ID: "kb-2", TenantID: "tenant-1", Name: "FAQ", EmbdID: "same-legacy-name", TenantEmbdID: &secondTenantEmbdID},
|
||||
}
|
||||
resolver := func(tenantID string, kb *entity.Knowledgebase) (string, error) {
|
||||
if kb.TenantEmbdID != nil && *kb.TenantEmbdID == firstTenantEmbdID {
|
||||
return "embed-a@factory", nil
|
||||
}
|
||||
return "embed-b@factory", nil
|
||||
}
|
||||
|
||||
_, _, err := validateKnowledgebaseEmbeddingModels(kbs, "tenant-1", resolver)
|
||||
if err == nil || !strings.Contains(err.Error(), "same embedding model") {
|
||||
t.Fatalf("expected resolved mixed embedding model error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessagesWithRetrievedKnowledgePreservesMultimodalContent(t *testing.T) {
|
||||
retrieval := &fakeChatRetrievalService{
|
||||
result: &nlp.RetrievalResult{
|
||||
Chunks: []map[string]interface{}{
|
||||
{"content_with_weight": "Knowledge for an image question."},
|
||||
},
|
||||
},
|
||||
}
|
||||
svc := &ChatSessionService{
|
||||
kbDAO: fakeChatKBStore{kbs: []*entity.Knowledgebase{
|
||||
{ID: "kb-1", TenantID: "tenant-1", Name: "Manual", EmbdID: "embed@factory"},
|
||||
}},
|
||||
modelProviderSvc: fakeChatModelProvider{driver: &fakeChatModelDriver{}},
|
||||
metadataSvc: fakeChatMetadataService{},
|
||||
retrievalSvc: retrieval,
|
||||
}
|
||||
imageBlock := map[string]interface{}{"type": "image_url", "image_url": map[string]interface{}{"url": "https://example.com/cat.png"}}
|
||||
messages := []map[string]interface{}{
|
||||
{"role": "user", "content": []interface{}{
|
||||
map[string]interface{}{"type": "text", "text": "What is in this image?"},
|
||||
imageBlock,
|
||||
}},
|
||||
}
|
||||
|
||||
got, _, emptyResponse, err := svc.messagesWithRetrievedKnowledge(context.Background(), "user-1", &entity.Chat{
|
||||
ID: "dialog-1",
|
||||
TenantID: "tenant-1",
|
||||
PromptConfig: entity.JSONMap{},
|
||||
KBIDs: entity.JSONSlice{"kb-1"},
|
||||
TopN: 3,
|
||||
TopK: 32,
|
||||
}, messages, []interface{}{map[string]interface{}{"chunks": []interface{}{}, "doc_aggs": []interface{}{}}})
|
||||
if err != nil {
|
||||
t.Fatalf("messagesWithRetrievedKnowledge returned error: %v", err)
|
||||
}
|
||||
if emptyResponse != nil {
|
||||
t.Fatalf("expected no empty response, got %q", *emptyResponse)
|
||||
}
|
||||
content, ok := got[0]["content"].([]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("expected multimodal content to stay as blocks, got %T", got[0]["content"])
|
||||
}
|
||||
if len(content) != 3 {
|
||||
t.Fatalf("expected injected text plus original blocks, got %#v", content)
|
||||
}
|
||||
injected, ok := content[0].(map[string]interface{})
|
||||
if !ok || injected["type"] != "text" || !strings.Contains(injected["text"].(string), "Knowledge for an image question.") {
|
||||
t.Fatalf("expected injected knowledge text block, got %#v", content[0])
|
||||
}
|
||||
preservedImage, ok := content[2].(map[string]interface{})
|
||||
if !ok || preservedImage["type"] != "image_url" {
|
||||
t.Fatalf("expected original image block to be preserved, got %#v", content[2])
|
||||
}
|
||||
if retrieval.req == nil || retrieval.req.Question != "What is in this image?" {
|
||||
t.Fatalf("expected retrieval question from text block, got %#v", retrieval.req)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMessagesWithRetrievedKnowledgePassesMessageDocIDs(t *testing.T) {
|
||||
retrieval := &fakeChatRetrievalService{result: &nlp.RetrievalResult{}}
|
||||
svc := &ChatSessionService{
|
||||
kbDAO: fakeChatKBStore{kbs: []*entity.Knowledgebase{
|
||||
{ID: "kb-1", TenantID: "tenant-1", Name: "Manual", EmbdID: "embed@factory"},
|
||||
}},
|
||||
modelProviderSvc: fakeChatModelProvider{driver: &fakeChatModelDriver{}},
|
||||
metadataSvc: fakeChatMetadataService{},
|
||||
retrievalSvc: retrieval,
|
||||
}
|
||||
|
||||
_, _, _, err := svc.messagesWithRetrievedKnowledge(context.Background(), "user-1", &entity.Chat{
|
||||
ID: "dialog-1",
|
||||
TenantID: "tenant-1",
|
||||
PromptConfig: entity.JSONMap{},
|
||||
KBIDs: entity.JSONSlice{"kb-1"},
|
||||
TopN: 3,
|
||||
TopK: 32,
|
||||
}, []map[string]interface{}{
|
||||
{"role": "user", "content": "question", "doc_ids": []interface{}{"doc-1", "doc-2", "doc-1"}},
|
||||
}, []interface{}{map[string]interface{}{"chunks": []interface{}{}, "doc_aggs": []interface{}{}}})
|
||||
if err != nil {
|
||||
t.Fatalf("messagesWithRetrievedKnowledge returned error: %v", err)
|
||||
}
|
||||
if len(retrieval.req.DocIDs) != 2 || retrieval.req.DocIDs[0] != "doc-1" || retrieval.req.DocIDs[1] != "doc-2" {
|
||||
t.Fatalf("expected scoped doc ids, got %#v", retrieval.req.DocIDs)
|
||||
}
|
||||
}
|
||||
|
||||
type failingChatRetrievalService struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func (f *failingChatRetrievalService) Retrieval(ctx context.Context, req *nlp.RetrievalRequest) (*nlp.RetrievalResult, error) {
|
||||
return nil, f.err
|
||||
}
|
||||
@@ -83,6 +83,9 @@ func compareValues(val1, val2, op string) bool {
|
||||
// ManualValueResolver is a callback function to transform manual filter values
|
||||
type ManualValueResolver func(map[string]interface{}) map[string]interface{}
|
||||
|
||||
// NoMatchDocIDSentinel forces retrieval to return no documents when filters match nothing.
|
||||
const NoMatchDocIDSentinel = "-999"
|
||||
|
||||
// metaFilterTemplateCache caches the template content
|
||||
var metaFilterTemplateCache string
|
||||
|
||||
@@ -264,7 +267,8 @@ func ApplyMetaFilter(metaData common.MetaData, filters []MetaFilterCondition, lo
|
||||
// - "==" = "=" "!=" = "≠"
|
||||
// - ">=" = "≥" "<=" = "≤"
|
||||
// - "is" = "=" "not is" = "≠"
|
||||
// (see common.metadata_utils.operatorMapping for the full list)
|
||||
// (see common.metadata_utils.operatorMapping for the full list)
|
||||
//
|
||||
// Value conversion:
|
||||
// - "in" / "not in": comma-separated string → []interface{} (as expected by common.MetaFilter)
|
||||
// - all other operators: passed through as-is (string)
|
||||
@@ -280,11 +284,13 @@ func convertToMetaCondition(f MetaFilterCondition) common.MetaCondition {
|
||||
parts := strings.Split(strVal, ",")
|
||||
arr := make([]interface{}, 0, len(parts))
|
||||
for _, p := range parts {
|
||||
if trimmed := strings.TrimSpace(p); trimmed != "" { arr = append(arr, trimmed) }
|
||||
}
|
||||
mc.Value = arr
|
||||
}
|
||||
return mc
|
||||
if trimmed := strings.TrimSpace(p); trimmed != "" {
|
||||
arr = append(arr, trimmed)
|
||||
}
|
||||
}
|
||||
mc.Value = arr
|
||||
}
|
||||
return mc
|
||||
}
|
||||
|
||||
// applySingleCondition applies a single filter condition and returns matching doc IDs
|
||||
@@ -587,9 +593,6 @@ func ApplyMetaDataFilter(
|
||||
return baseDocIDs, false
|
||||
}
|
||||
|
||||
docIDs := make([]string, len(baseDocIDs))
|
||||
copy(docIDs, baseDocIDs)
|
||||
|
||||
method, _ := metaDataFilter["method"].(string)
|
||||
|
||||
// Helper to run metadata filter with push-down fallback
|
||||
@@ -632,13 +635,14 @@ func ApplyMetaDataFilter(
|
||||
filters, err := GenMetaFilter(ctx, chatModel, metaData, question, nil)
|
||||
if err != nil {
|
||||
common.Warn("Failed to generate meta filter", zap.Error(err))
|
||||
return docIDs, false
|
||||
return baseDocIDs, false
|
||||
}
|
||||
filteredIDs := runMetadataFilter(filters.Conditions, filters.Logic)
|
||||
docIDs = append(docIDs, filteredIDs...)
|
||||
docIDs := constrainDocIDs(baseDocIDs, filteredIDs)
|
||||
if len(docIDs) == 0 {
|
||||
return nil, true // Return nil to indicate auto filter returned empty
|
||||
}
|
||||
return docIDs, false
|
||||
|
||||
case "semi_auto":
|
||||
selectedKeys := []string{}
|
||||
@@ -673,13 +677,14 @@ func ApplyMetaDataFilter(
|
||||
filters, err := GenMetaFilter(ctx, chatModel, filteredMeta, question, constraints)
|
||||
if err != nil {
|
||||
common.Warn("Failed to generate meta filter", zap.Error(err))
|
||||
return docIDs, false
|
||||
return baseDocIDs, false
|
||||
}
|
||||
filteredIDs := runMetadataFilter(filters.Conditions, filters.Logic)
|
||||
docIDs = append(docIDs, filteredIDs...)
|
||||
docIDs := constrainDocIDs(baseDocIDs, filteredIDs)
|
||||
if len(docIDs) == 0 {
|
||||
return nil, true
|
||||
}
|
||||
return docIDs, false
|
||||
}
|
||||
}
|
||||
|
||||
@@ -689,6 +694,9 @@ func ApplyMetaDataFilter(
|
||||
if logicVal, ok := metaDataFilter["logic"].(string); ok {
|
||||
logic = logicVal
|
||||
}
|
||||
if len(manualFilters) == 0 {
|
||||
return baseDocIDs, false
|
||||
}
|
||||
|
||||
// Apply manual_value_resolver callback if provided
|
||||
if len(manualValueResolver) > 0 && manualValueResolver[0] != nil {
|
||||
@@ -720,13 +728,42 @@ func ApplyMetaDataFilter(
|
||||
}
|
||||
|
||||
filteredIDs := runMetadataFilter(conditions, logic)
|
||||
docIDs = append(docIDs, filteredIDs...)
|
||||
docIDs := constrainDocIDs(baseDocIDs, filteredIDs)
|
||||
if len(manualFilters) > 0 && len(docIDs) == 0 {
|
||||
return []string{"-999"}, false
|
||||
return []string{NoMatchDocIDSentinel}, false
|
||||
}
|
||||
return docIDs, false
|
||||
}
|
||||
|
||||
return docIDs, false
|
||||
return baseDocIDs, false
|
||||
}
|
||||
|
||||
func constrainDocIDs(baseDocIDs, filteredDocIDs []string) []string {
|
||||
filteredDocIDs = common.Deduplicate(filteredDocIDs)
|
||||
if len(baseDocIDs) == 0 {
|
||||
return filteredDocIDs
|
||||
}
|
||||
if len(filteredDocIDs) == 0 {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
filteredSet := make(map[string]struct{}, len(filteredDocIDs))
|
||||
for _, docID := range filteredDocIDs {
|
||||
filteredSet[docID] = struct{}{}
|
||||
}
|
||||
result := make([]string, 0, min(len(baseDocIDs), len(filteredSet)))
|
||||
seen := make(map[string]struct{}, len(baseDocIDs))
|
||||
for _, docID := range baseDocIDs {
|
||||
if _, allowed := filteredSet[docID]; !allowed {
|
||||
continue
|
||||
}
|
||||
if _, exists := seen[docID]; exists {
|
||||
continue
|
||||
}
|
||||
seen[docID] = struct{}{}
|
||||
result = append(result, docID)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// repairJSON attempts to fix common JSON formatting issues in LLM output.
|
||||
|
||||
@@ -455,7 +455,7 @@ func TestApplySingleConditionRelationalFallsBackToString(t *testing.T) {
|
||||
val string
|
||||
want []string
|
||||
}{
|
||||
{op: ">", val: "banana", want: []string{"d-c"}}, // "cherry" > "banana"
|
||||
{op: ">", val: "banana", want: []string{"d-c"}}, // "cherry" > "banana"
|
||||
{op: "<", val: "cherry", want: []string{"d-a", "d-b"}}, // apple, banana
|
||||
{op: ">=", val: "banana", want: []string{"d-b", "d-c"}},
|
||||
{op: "<=", val: "banana", want: []string{"d-a", "d-b"}},
|
||||
|
||||
@@ -1972,6 +1972,15 @@ func (m *ModelProviderService) GetChatModel(tenantID, compositeModelName string)
|
||||
return modelModule.NewChatModel(driver, &modelName, apiConfig), nil
|
||||
}
|
||||
|
||||
// GetRerankModel returns a RerankModel wrapper for the given tenant
|
||||
func (m *ModelProviderService) GetRerankModel(tenantID, compositeModelName string) (*modelModule.RerankModel, error) {
|
||||
driver, modelName, apiConfig, _, err := m.getModelConfig(tenantID, compositeModelName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return modelModule.NewRerankModel(driver, &modelName, apiConfig), nil
|
||||
}
|
||||
|
||||
type AddModelRequest struct {
|
||||
ProviderName string `json:"provider_name"`
|
||||
InstanceName string `json:"instance_name"`
|
||||
|
||||
Reference in New Issue
Block a user