mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
Port PR14454 to GO (PruneDeletedChunks) (#14463)
### What problem does this PR solve? Port PR14454 to GO (PruneDeletedChunks) ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
@@ -122,6 +122,16 @@ func (dao *DocumentDAO) GetAllDocIDsByKBIDs(kbIDs []string) ([]map[string]string
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// GetByIDs retrieves documents by multiple IDs
|
||||
func (dao *DocumentDAO) GetByIDs(ids []string) ([]*entity.Document, error) {
|
||||
var documents []*entity.Document
|
||||
err := DB.Where("id IN ?", ids).Find(&documents).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return documents, nil
|
||||
}
|
||||
|
||||
// CountByTenantID counts documents by tenant ID
|
||||
func (dao *DocumentDAO) CountByTenantID(tenantID string) (int64, error) {
|
||||
var count int64
|
||||
|
||||
@@ -797,8 +797,9 @@ func (s *ChatSessionService) buildChatConfig(dialog *entity.Chat, config map[str
|
||||
if v, ok := dialog.LLMSetting["thinking"].(bool); ok {
|
||||
cfg.Thinking = &v
|
||||
}
|
||||
if v, ok := dialog.LLMSetting["max_tokens"].(int); ok {
|
||||
cfg.MaxTokens = &v
|
||||
if v, ok := dialog.LLMSetting["max_tokens"].(float64); ok {
|
||||
intVal := int(v)
|
||||
cfg.MaxTokens = &intVal
|
||||
}
|
||||
if v, ok := dialog.LLMSetting["temperature"].(float64); ok {
|
||||
cfg.Temperature = &v
|
||||
@@ -809,8 +810,14 @@ func (s *ChatSessionService) buildChatConfig(dialog *entity.Chat, config map[str
|
||||
if v, ok := dialog.LLMSetting["do_sample"].(bool); ok {
|
||||
cfg.DoSample = &v
|
||||
}
|
||||
if v, ok := dialog.LLMSetting["stop"].([]string); ok {
|
||||
cfg.Stop = &v
|
||||
if v, ok := dialog.LLMSetting["stop"].([]interface{}); ok {
|
||||
stopStrs := make([]string, 0, len(v))
|
||||
for _, s := range v {
|
||||
if str, ok := s.(string); ok {
|
||||
stopStrs = append(stopStrs, str)
|
||||
}
|
||||
}
|
||||
cfg.Stop = &stopStrs
|
||||
}
|
||||
if v, ok := dialog.LLMSetting["model_class"].(string); ok {
|
||||
cfg.ModelClass = &v
|
||||
@@ -831,8 +838,9 @@ func (s *ChatSessionService) buildChatConfig(dialog *entity.Chat, config map[str
|
||||
if v, ok := config["thinking"].(bool); ok {
|
||||
cfg.Thinking = &v
|
||||
}
|
||||
if v, ok := config["max_tokens"].(int); ok {
|
||||
cfg.MaxTokens = &v
|
||||
if v, ok := config["max_tokens"].(float64); ok {
|
||||
intVal := int(v)
|
||||
cfg.MaxTokens = &intVal
|
||||
}
|
||||
if v, ok := config["temperature"].(float64); ok {
|
||||
cfg.Temperature = &v
|
||||
@@ -843,8 +851,14 @@ func (s *ChatSessionService) buildChatConfig(dialog *entity.Chat, config map[str
|
||||
if v, ok := config["do_sample"].(bool); ok {
|
||||
cfg.DoSample = &v
|
||||
}
|
||||
if v, ok := config["stop"].([]string); ok {
|
||||
cfg.Stop = &v
|
||||
if v, ok := config["stop"].([]interface{}); ok {
|
||||
stopStrs := make([]string, 0, len(v))
|
||||
for _, s := range v {
|
||||
if str, ok := s.(string); ok {
|
||||
stopStrs = append(stopStrs, str)
|
||||
}
|
||||
}
|
||||
cfg.Stop = &stopStrs
|
||||
}
|
||||
if v, ok := config["model_class"].(string); ok {
|
||||
cfg.ModelClass = &v
|
||||
|
||||
@@ -44,6 +44,7 @@ type ChunkService struct {
|
||||
embeddingCache *utility.EmbeddingLRU
|
||||
kbDAO *dao.KnowledgebaseDAO
|
||||
userTenantDAO *dao.UserTenantDAO
|
||||
documentDAO *dao.DocumentDAO
|
||||
searchService *SearchService
|
||||
}
|
||||
|
||||
@@ -56,6 +57,7 @@ func NewChunkService() *ChunkService {
|
||||
embeddingCache: utility.NewEmbeddingLRU(1000), // default capacity
|
||||
kbDAO: dao.NewKnowledgebaseDAO(),
|
||||
userTenantDAO: dao.NewUserTenantDAO(),
|
||||
documentDAO: dao.NewDocumentDAO(),
|
||||
searchService: NewSearchService(),
|
||||
}
|
||||
}
|
||||
@@ -395,7 +397,7 @@ func (s *ChunkService) RetrievalTest(req *RetrievalTestRequest, userID string) (
|
||||
}
|
||||
|
||||
// Call RetrievalService to perform retrieval
|
||||
retrievalResult, err := nlp.NewRetrievalService(s.docEngine).Retrieval(ctx, retrievalReq)
|
||||
retrievalResult, err := nlp.NewRetrievalService(s.docEngine, s.documentDAO).Retrieval(ctx, retrievalReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("retrieval search failed: %w", err)
|
||||
}
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"ragflow/internal/dao"
|
||||
"ragflow/internal/engine"
|
||||
"ragflow/internal/engine/types"
|
||||
"ragflow/internal/entity/models"
|
||||
@@ -34,12 +35,13 @@ import (
|
||||
|
||||
// RetrievalService provides retrieval search functionality
|
||||
type RetrievalService struct {
|
||||
docEngine engine.DocEngine
|
||||
docEngine engine.DocEngine
|
||||
documentDAO *dao.DocumentDAO
|
||||
}
|
||||
|
||||
// NewRetrievalService creates a new RetrievalService with the given doc engine
|
||||
func NewRetrievalService(docEngine engine.DocEngine) *RetrievalService {
|
||||
return &RetrievalService{docEngine: docEngine}
|
||||
func NewRetrievalService(docEngine engine.DocEngine, documentDAO *dao.DocumentDAO) *RetrievalService {
|
||||
return &RetrievalService{docEngine: docEngine, documentDAO: documentDAO}
|
||||
}
|
||||
|
||||
// RetrievalRequest request for retrieval search
|
||||
@@ -146,7 +148,15 @@ func (s *RetrievalService) Retrieval(ctx context.Context, req *RetrievalRequest)
|
||||
return nil, fmt.Errorf("Search failed: %w", err)
|
||||
}
|
||||
|
||||
// Perform reranking
|
||||
// Prune deleted chunks
|
||||
searchResult, err = s.PruneDeletedChunks(searchResult)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("PruneDeletedChunks failed: %w", err)
|
||||
}
|
||||
if searchResult.Total == 0 {
|
||||
return &RetrievalResult{Chunks: []map[string]interface{}{}, DocAggs: []map[string]interface{}{}}, nil
|
||||
}
|
||||
|
||||
vtWeight := *req.VectorSimilarityWeight
|
||||
tkWeight := 1.0 - vtWeight
|
||||
qb := GetQueryBuilder()
|
||||
@@ -778,6 +788,105 @@ func RetrievalByChildren(chunks []map[string]interface{}, tenantIDs []string, do
|
||||
return remainingChunks
|
||||
}
|
||||
|
||||
// PruneDeletedChunks removes chunks whose documents no longer exist
|
||||
func (s *RetrievalService) PruneDeletedChunks(result *RetrievalSearchResult) (*RetrievalSearchResult, error) {
|
||||
if s.documentDAO == nil {
|
||||
return nil, fmt.Errorf("documentDAO is not initialized")
|
||||
}
|
||||
// Collect all doc_ids from chunks
|
||||
chunkDocIDs := make([]string, 0, len(result.Field))
|
||||
for _, chunk := range result.Field {
|
||||
if docID, ok := chunk["doc_id"].(string); ok && docID != "" {
|
||||
chunkDocIDs = append(chunkDocIDs, docID)
|
||||
}
|
||||
}
|
||||
|
||||
if len(chunkDocIDs) == 0 {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Deduplicate chunkDocIDs for correct comparison with existingDocIDs
|
||||
uniqueDocIDs := make([]string, 0, len(chunkDocIDs))
|
||||
seen := make(map[string]struct{}, len(chunkDocIDs))
|
||||
for _, id := range chunkDocIDs {
|
||||
if _, exists := seen[id]; !exists {
|
||||
seen[id] = struct{}{}
|
||||
uniqueDocIDs = append(uniqueDocIDs, id)
|
||||
}
|
||||
}
|
||||
|
||||
// Get existing document IDs
|
||||
docs, err := s.documentDAO.GetByIDs(uniqueDocIDs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("GetByIDs failed: %w", err)
|
||||
}
|
||||
|
||||
existingDocIDs := make(map[string]struct{}, len(docs))
|
||||
for _, doc := range docs {
|
||||
existingDocIDs[doc.ID] = struct{}{}
|
||||
}
|
||||
|
||||
// Early return if all docs exist
|
||||
if len(existingDocIDs) == len(uniqueDocIDs) {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Filter out chunks with deleted documents
|
||||
filteredIDs := make([]string, 0, len(result.IDs))
|
||||
filteredChunks := make([]map[string]interface{}, 0, len(result.IDs))
|
||||
filteredField := make(map[string]map[string]interface{}, len(result.IDs))
|
||||
filteredHighlight := make(map[string]string)
|
||||
removed := 0
|
||||
|
||||
for _, chunkID := range result.IDs {
|
||||
chunk, exists := result.Field[chunkID]
|
||||
if !exists {
|
||||
continue
|
||||
}
|
||||
docID, ok := chunk["doc_id"].(string)
|
||||
if !ok || docID == "" {
|
||||
// Keep chunks without doc_id
|
||||
filteredIDs = append(filteredIDs, chunkID)
|
||||
filteredChunks = append(filteredChunks, chunk)
|
||||
filteredField[chunkID] = chunk
|
||||
if result.Highlight != nil {
|
||||
if hl, ok := result.Highlight[chunkID]; ok {
|
||||
filteredHighlight[chunkID] = hl
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
if _, docExists := existingDocIDs[docID]; !docExists {
|
||||
removed++
|
||||
continue
|
||||
}
|
||||
filteredIDs = append(filteredIDs, chunkID)
|
||||
filteredChunks = append(filteredChunks, chunk)
|
||||
filteredField[chunkID] = chunk
|
||||
if result.Highlight != nil {
|
||||
if hl, ok := result.Highlight[chunkID]; ok {
|
||||
filteredHighlight[chunkID] = hl
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if removed > 0 {
|
||||
logger.Warn("Pruned stale chunks whose documents no longer exist", zap.Int("removed", removed))
|
||||
}
|
||||
|
||||
return &RetrievalSearchResult{
|
||||
Chunks: filteredChunks,
|
||||
Total: int64(len(filteredIDs)),
|
||||
QueryVector: result.QueryVector,
|
||||
Highlight: filteredHighlight,
|
||||
Field: filteredField,
|
||||
IDs: filteredIDs,
|
||||
Keywords: result.Keywords,
|
||||
Aggregation: result.Aggregation,
|
||||
Options: result.Options,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// buildIndexNames creates index names for the given tenant IDs
|
||||
func buildIndexNames(tenantIDs []string) []string {
|
||||
indexNames := make([]string, len(tenantIDs))
|
||||
|
||||
Reference in New Issue
Block a user