mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-07-05 10:58:34 +08:00
feat(go-api): add RAG retrieval to chat completions (#15739)
## Summary - Add knowledge-base retrieval support to Go chat completions. ## What changed - Routes KB-backed chat sessions through the Go retrieval service instead of falling back to solo chat. - Resolves embedding and rerank models, validates accessible knowledge bases, and preserves tenant-aware retrieval. - Rejects mixed embedding models across selected knowledge bases before retrieval to avoid incompatible vector dimensions. - Threads the HTTP request context into streaming retrieval so cancelled requests can stop downstream retrieval work. - Applies metadata filters and message-level `doc_ids` before retrieval. - Expands parent/child chunks before building references and prompt context. - Injects retrieved knowledge through a copied dialog prompt config so the caller's original dialog is not mutated. - Honors configured empty responses when no chunks are found. - Names the metadata no-match sentinel and reuses it across retrieval/handler paths. - Adds a defensive content cast while appending streamed answers. - Adds focused unit coverage for retrieval, metadata filtering, authorization, multimodal messages, references, empty-response behavior, prompt immutability, and mixed embedding models. --------- Co-authored-by: Yingfeng <yingfeng.zhang@gmail.com> Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -292,9 +292,10 @@ func (h *ChatSessionHandler) Completion(c *gin.Context) {
|
||||
|
||||
// Create a channel for streaming data
|
||||
streamChan := make(chan string)
|
||||
reqCtx := c.Request.Context()
|
||||
go func() {
|
||||
defer close(streamChan)
|
||||
err := h.chatSessionService.CompletionStream(userID, req.ConversationID, processedMessages, req.LLMID, chatModelConfig, messageID, streamChan)
|
||||
err := h.chatSessionService.CompletionStream(reqCtx, userID, req.ConversationID, processedMessages, req.LLMID, chatModelConfig, messageID, streamChan)
|
||||
if err != nil {
|
||||
streamChan <- fmt.Sprintf("data: %s\n\n", err.Error())
|
||||
}
|
||||
|
||||
@@ -24,9 +24,9 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"ragflow/internal/common"
|
||||
"gorm.io/gorm"
|
||||
"go.uber.org/zap"
|
||||
"gorm.io/gorm"
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/engine"
|
||||
"ragflow/internal/entity"
|
||||
modelModule "ragflow/internal/entity/models"
|
||||
@@ -71,16 +71,16 @@ type DocumentDAOIface interface {
|
||||
|
||||
// difyRetrievalRequest is the JSON body / query params for the Dify retrieval endpoint.
|
||||
type difyRetrievalRequest struct {
|
||||
KnowledgeID string `json:"knowledge_id" form:"knowledge_id"`
|
||||
Query string `json:"query" form:"query"`
|
||||
UseKG bool `json:"use_kg" form:"use_kg"`
|
||||
RetrievalSetting *difyRetrievalSetting `json:"retrieval_setting"`
|
||||
MetadataCondition *difyMetadataCondition `json:"metadata_condition"`
|
||||
KnowledgeID string `json:"knowledge_id" form:"knowledge_id"`
|
||||
Query string `json:"query" form:"query"`
|
||||
UseKG bool `json:"use_kg" form:"use_kg"`
|
||||
RetrievalSetting *difyRetrievalSetting `json:"retrieval_setting"`
|
||||
MetadataCondition *difyMetadataCondition `json:"metadata_condition"`
|
||||
}
|
||||
|
||||
type difyRetrievalSetting struct {
|
||||
TopK *int `json:"top_k" form:"top_k"`
|
||||
ScoreThreshold *float64 `json:"score_threshold" form:"score_threshold"`
|
||||
TopK *int `json:"top_k" form:"top_k"`
|
||||
ScoreThreshold *float64 `json:"score_threshold" form:"score_threshold"`
|
||||
}
|
||||
|
||||
// difyCondition is a Dify-format metadata filter condition.
|
||||
@@ -108,8 +108,8 @@ func (c difyMetadataCondition) toMetaFilterConditions() []service.MetaFilterCond
|
||||
v = fmt.Sprint(dc.Value)
|
||||
}
|
||||
result[i] = service.MetaFilterCondition{
|
||||
Key: dc.Name,
|
||||
Op: dc.ComparisonOperator,
|
||||
Key: dc.Name,
|
||||
Op: dc.ComparisonOperator,
|
||||
Value: v,
|
||||
}
|
||||
}
|
||||
@@ -241,15 +241,15 @@ func (h *DifyRetrievalHandler) Retrieval(c *gin.Context) {
|
||||
metas, metaErr := h.metadataSvc.GetFlattedMetaByKBs([]string{req.KnowledgeID})
|
||||
docIDs := make([]string, 0)
|
||||
if metaErr == nil && req.MetadataCondition != nil {
|
||||
logic := req.MetadataCondition.Logic
|
||||
if logic == "" {
|
||||
logic = "and"
|
||||
}
|
||||
filteredIDs := service.ApplyMetaFilter(metas, req.MetadataCondition.toMetaFilterConditions(), logic)
|
||||
logic := req.MetadataCondition.Logic
|
||||
if logic == "" {
|
||||
logic = "and"
|
||||
}
|
||||
filteredIDs := service.ApplyMetaFilter(metas, req.MetadataCondition.toMetaFilterConditions(), logic)
|
||||
docIDs = append(docIDs, filteredIDs...)
|
||||
}
|
||||
if len(docIDs) == 0 && req.MetadataCondition != nil {
|
||||
docIDs = []string{"-999"}
|
||||
docIDs = []string{service.NoMatchDocIDSentinel}
|
||||
}
|
||||
|
||||
// Label question for rank features
|
||||
@@ -258,15 +258,15 @@ func (h *DifyRetrievalHandler) Retrieval(c *gin.Context) {
|
||||
|
||||
// Chunk retrieval
|
||||
sr := &nlp.RetrievalRequest{
|
||||
Question: req.Query,
|
||||
TenantIDs: []string{kb.TenantID},
|
||||
KbIDs: []string{req.KnowledgeID},
|
||||
DocIDs: docIDs,
|
||||
Page: 1,
|
||||
PageSize: pageSize,
|
||||
Top: topK,
|
||||
SimilarityThreshold: scoreThreshold,
|
||||
EmbeddingModel: embModel,
|
||||
Question: req.Query,
|
||||
TenantIDs: []string{kb.TenantID},
|
||||
KbIDs: []string{req.KnowledgeID},
|
||||
DocIDs: docIDs,
|
||||
Page: 1,
|
||||
PageSize: pageSize,
|
||||
Top: topK,
|
||||
SimilarityThreshold: scoreThreshold,
|
||||
EmbeddingModel: embModel,
|
||||
}
|
||||
if rankFeature != nil {
|
||||
sr.RankFeature = &rankFeature
|
||||
|
||||
Reference in New Issue
Block a user