From 486ca463aadf1a5ff088e6879efa4ac1f54ae2b2 Mon Sep 17 00:00:00 2001 From: qinling0210 <88864212+qinling0210@users.noreply.github.com> Date: Wed, 29 Apr 2026 17:04:22 +0800 Subject: [PATCH] Port PR14454 to GO (PruneDeletedChunks) (#14463) ### What problem does this PR solve? Port PR14454 to GO (PruneDeletedChunks) ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --- internal/dao/document.go | 10 +++ internal/service/chat_session.go | 30 ++++++-- internal/service/chunk.go | 4 +- internal/service/nlp/retrieval.go | 117 +++++++++++++++++++++++++++++- 4 files changed, 148 insertions(+), 13 deletions(-) diff --git a/internal/dao/document.go b/internal/dao/document.go index ddd13e35ad..e2e055a118 100644 --- a/internal/dao/document.go +++ b/internal/dao/document.go @@ -122,6 +122,16 @@ func (dao *DocumentDAO) GetAllDocIDsByKBIDs(kbIDs []string) ([]map[string]string return result, nil } +// GetByIDs retrieves documents by multiple IDs +func (dao *DocumentDAO) GetByIDs(ids []string) ([]*entity.Document, error) { + var documents []*entity.Document + err := DB.Where("id IN ?", ids).Find(&documents).Error + if err != nil { + return nil, err + } + return documents, nil +} + // CountByTenantID counts documents by tenant ID func (dao *DocumentDAO) CountByTenantID(tenantID string) (int64, error) { var count int64 diff --git a/internal/service/chat_session.go b/internal/service/chat_session.go index d563a2c363..30fdb53d1b 100644 --- a/internal/service/chat_session.go +++ b/internal/service/chat_session.go @@ -797,8 +797,9 @@ func (s *ChatSessionService) buildChatConfig(dialog *entity.Chat, config map[str if v, ok := dialog.LLMSetting["thinking"].(bool); ok { cfg.Thinking = &v } - if v, ok := dialog.LLMSetting["max_tokens"].(int); ok { - cfg.MaxTokens = &v + if v, ok := dialog.LLMSetting["max_tokens"].(float64); ok { + intVal := int(v) + cfg.MaxTokens = &intVal } if v, ok := dialog.LLMSetting["temperature"].(float64); ok { cfg.Temperature = &v @@ -809,8 +810,14 @@ func (s *ChatSessionService) buildChatConfig(dialog *entity.Chat, config map[str if v, ok := dialog.LLMSetting["do_sample"].(bool); ok { cfg.DoSample = &v } - if v, ok := dialog.LLMSetting["stop"].([]string); ok { - cfg.Stop = &v + if v, ok := dialog.LLMSetting["stop"].([]interface{}); ok { + stopStrs := make([]string, 0, len(v)) + for _, s := range v { + if str, ok := s.(string); ok { + stopStrs = append(stopStrs, str) + } + } + cfg.Stop = &stopStrs } if v, ok := dialog.LLMSetting["model_class"].(string); ok { cfg.ModelClass = &v @@ -831,8 +838,9 @@ func (s *ChatSessionService) buildChatConfig(dialog *entity.Chat, config map[str if v, ok := config["thinking"].(bool); ok { cfg.Thinking = &v } - if v, ok := config["max_tokens"].(int); ok { - cfg.MaxTokens = &v + if v, ok := config["max_tokens"].(float64); ok { + intVal := int(v) + cfg.MaxTokens = &intVal } if v, ok := config["temperature"].(float64); ok { cfg.Temperature = &v @@ -843,8 +851,14 @@ func (s *ChatSessionService) buildChatConfig(dialog *entity.Chat, config map[str if v, ok := config["do_sample"].(bool); ok { cfg.DoSample = &v } - if v, ok := config["stop"].([]string); ok { - cfg.Stop = &v + if v, ok := config["stop"].([]interface{}); ok { + stopStrs := make([]string, 0, len(v)) + for _, s := range v { + if str, ok := s.(string); ok { + stopStrs = append(stopStrs, str) + } + } + cfg.Stop = &stopStrs } if v, ok := config["model_class"].(string); ok { cfg.ModelClass = &v diff --git a/internal/service/chunk.go b/internal/service/chunk.go index fe9a71ff27..0da359d9d6 100644 --- a/internal/service/chunk.go +++ b/internal/service/chunk.go @@ -44,6 +44,7 @@ type ChunkService struct { embeddingCache *utility.EmbeddingLRU kbDAO *dao.KnowledgebaseDAO userTenantDAO *dao.UserTenantDAO + documentDAO *dao.DocumentDAO searchService *SearchService } @@ -56,6 +57,7 @@ func NewChunkService() *ChunkService { embeddingCache: utility.NewEmbeddingLRU(1000), // default capacity kbDAO: dao.NewKnowledgebaseDAO(), userTenantDAO: dao.NewUserTenantDAO(), + documentDAO: dao.NewDocumentDAO(), searchService: NewSearchService(), } } @@ -395,7 +397,7 @@ func (s *ChunkService) RetrievalTest(req *RetrievalTestRequest, userID string) ( } // Call RetrievalService to perform retrieval - retrievalResult, err := nlp.NewRetrievalService(s.docEngine).Retrieval(ctx, retrievalReq) + retrievalResult, err := nlp.NewRetrievalService(s.docEngine, s.documentDAO).Retrieval(ctx, retrievalReq) if err != nil { return nil, fmt.Errorf("retrieval search failed: %w", err) } diff --git a/internal/service/nlp/retrieval.go b/internal/service/nlp/retrieval.go index c271d32f40..36e38cf2d4 100644 --- a/internal/service/nlp/retrieval.go +++ b/internal/service/nlp/retrieval.go @@ -20,6 +20,7 @@ import ( "context" "fmt" "math" + "ragflow/internal/dao" "ragflow/internal/engine" "ragflow/internal/engine/types" "ragflow/internal/entity/models" @@ -34,12 +35,13 @@ import ( // RetrievalService provides retrieval search functionality type RetrievalService struct { - docEngine engine.DocEngine + docEngine engine.DocEngine + documentDAO *dao.DocumentDAO } // NewRetrievalService creates a new RetrievalService with the given doc engine -func NewRetrievalService(docEngine engine.DocEngine) *RetrievalService { - return &RetrievalService{docEngine: docEngine} +func NewRetrievalService(docEngine engine.DocEngine, documentDAO *dao.DocumentDAO) *RetrievalService { + return &RetrievalService{docEngine: docEngine, documentDAO: documentDAO} } // RetrievalRequest request for retrieval search @@ -146,7 +148,15 @@ func (s *RetrievalService) Retrieval(ctx context.Context, req *RetrievalRequest) return nil, fmt.Errorf("Search failed: %w", err) } - // Perform reranking + // Prune deleted chunks + searchResult, err = s.PruneDeletedChunks(searchResult) + if err != nil { + return nil, fmt.Errorf("PruneDeletedChunks failed: %w", err) + } + if searchResult.Total == 0 { + return &RetrievalResult{Chunks: []map[string]interface{}{}, DocAggs: []map[string]interface{}{}}, nil + } + vtWeight := *req.VectorSimilarityWeight tkWeight := 1.0 - vtWeight qb := GetQueryBuilder() @@ -778,6 +788,105 @@ func RetrievalByChildren(chunks []map[string]interface{}, tenantIDs []string, do return remainingChunks } +// PruneDeletedChunks removes chunks whose documents no longer exist +func (s *RetrievalService) PruneDeletedChunks(result *RetrievalSearchResult) (*RetrievalSearchResult, error) { + if s.documentDAO == nil { + return nil, fmt.Errorf("documentDAO is not initialized") + } + // Collect all doc_ids from chunks + chunkDocIDs := make([]string, 0, len(result.Field)) + for _, chunk := range result.Field { + if docID, ok := chunk["doc_id"].(string); ok && docID != "" { + chunkDocIDs = append(chunkDocIDs, docID) + } + } + + if len(chunkDocIDs) == 0 { + return result, nil + } + + // Deduplicate chunkDocIDs for correct comparison with existingDocIDs + uniqueDocIDs := make([]string, 0, len(chunkDocIDs)) + seen := make(map[string]struct{}, len(chunkDocIDs)) + for _, id := range chunkDocIDs { + if _, exists := seen[id]; !exists { + seen[id] = struct{}{} + uniqueDocIDs = append(uniqueDocIDs, id) + } + } + + // Get existing document IDs + docs, err := s.documentDAO.GetByIDs(uniqueDocIDs) + if err != nil { + return nil, fmt.Errorf("GetByIDs failed: %w", err) + } + + existingDocIDs := make(map[string]struct{}, len(docs)) + for _, doc := range docs { + existingDocIDs[doc.ID] = struct{}{} + } + + // Early return if all docs exist + if len(existingDocIDs) == len(uniqueDocIDs) { + return result, nil + } + + // Filter out chunks with deleted documents + filteredIDs := make([]string, 0, len(result.IDs)) + filteredChunks := make([]map[string]interface{}, 0, len(result.IDs)) + filteredField := make(map[string]map[string]interface{}, len(result.IDs)) + filteredHighlight := make(map[string]string) + removed := 0 + + for _, chunkID := range result.IDs { + chunk, exists := result.Field[chunkID] + if !exists { + continue + } + docID, ok := chunk["doc_id"].(string) + if !ok || docID == "" { + // Keep chunks without doc_id + filteredIDs = append(filteredIDs, chunkID) + filteredChunks = append(filteredChunks, chunk) + filteredField[chunkID] = chunk + if result.Highlight != nil { + if hl, ok := result.Highlight[chunkID]; ok { + filteredHighlight[chunkID] = hl + } + } + continue + } + if _, docExists := existingDocIDs[docID]; !docExists { + removed++ + continue + } + filteredIDs = append(filteredIDs, chunkID) + filteredChunks = append(filteredChunks, chunk) + filteredField[chunkID] = chunk + if result.Highlight != nil { + if hl, ok := result.Highlight[chunkID]; ok { + filteredHighlight[chunkID] = hl + } + } + } + + if removed > 0 { + logger.Warn("Pruned stale chunks whose documents no longer exist", zap.Int("removed", removed)) + } + + return &RetrievalSearchResult{ + Chunks: filteredChunks, + Total: int64(len(filteredIDs)), + QueryVector: result.QueryVector, + Highlight: filteredHighlight, + Field: filteredField, + IDs: filteredIDs, + Keywords: result.Keywords, + Aggregation: result.Aggregation, + Options: result.Options, + }, nil +} + // buildIndexNames creates index names for the given tenant IDs func buildIndexNames(tenantIDs []string) []string { indexNames := make([]string, len(tenantIDs))