Files
ragflow/internal/service/chunk/chunk.go
Jack 87b8062df4 feat: implement POST /api/v1/searchbots/ask — streaming RAG with citations and think-tag processing (#15825)
Implements POST /api/v1/searchbots/ask in Go with streaming SSE,
citations, and think-tag processing. 23 files, 90+ unit tests.

---------

Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-09 22:48:50 +08:00

935 lines
30 KiB
Go

//
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
package chunk
import (
"context"
"fmt"
"ragflow/internal/common"
"ragflow/internal/entity"
"ragflow/internal/entity/models"
"ragflow/internal/server"
"strconv"
"strings"
"go.uber.org/zap"
"ragflow/internal/dao"
"ragflow/internal/engine"
"ragflow/internal/engine/types"
"ragflow/internal/service"
"ragflow/internal/service/nlp"
"ragflow/internal/tokenizer"
"ragflow/internal/utility"
)
// ChunkService implements chunk-level operations: retrieval testing plus
// getting, listing, updating and removing chunks stored in the document engine.
// Instances are built by NewChunkService, which fills every field.
type ChunkService struct {
	// docEngine is the document engine used for chunk get/search/update/delete
	// and handed to the retrieval service.
	docEngine engine.DocEngine
	// engineType is the configured doc engine type (cfg.DocEngine.Type).
	// NOTE(review): not read in this part of the file; confirm where it is used.
	engineType server.EngineType
	// embeddingCache is an LRU created with capacity 1000 by NewChunkService.
	// NOTE(review): not read in this part of the file; presumably caches query
	// embeddings — confirm.
	embeddingCache *utility.EmbeddingLRU
	// kbDAO looks up knowledge bases for tenant ownership checks.
	kbDAO *dao.KnowledgebaseDAO
	// userTenantDAO resolves the tenants a user belongs to.
	userTenantDAO *dao.UserTenantDAO
	// documentDAO is passed to the retrieval service.
	documentDAO *dao.DocumentDAO
	// searchService loads the search configuration referenced by search_id.
	searchService *service.SearchService
}
// NewChunkService builds a ChunkService bound to the process-wide document
// engine and the engine type from the server configuration, with freshly
// constructed DAOs, a search service and an embedding LRU cache.
func NewChunkService() *ChunkService {
	// Number of entries the embedding LRU cache holds.
	const embeddingCacheCapacity = 1000

	conf := server.GetConfig()

	svc := new(ChunkService)
	svc.docEngine = engine.Get()
	svc.engineType = conf.DocEngine.Type
	svc.embeddingCache = utility.NewEmbeddingLRU(embeddingCacheCapacity)
	svc.kbDAO = dao.NewKnowledgebaseDAO()
	svc.userTenantDAO = dao.NewUserTenantDAO()
	svc.documentDAO = dao.NewDocumentDAO()
	svc.searchService = service.NewSearchService()
	return svc
}
// RetrievalTest runs a retrieval test for a question against the datasets in
// req and returns the matching chunks, per-document aggregates and the
// rank-feature labels that were applied.
//
// Flow:
//  1. Validate that every dataset is owned by one of the user's tenants and
//     that all datasets share the same embedding model.
//  2. Resolve the metadata filter: a search_id's search_config.meta_data_filter
//     overrides req.Filter. For "auto"/"semi_auto" methods a chat model is
//     fetched (search_config chat_id first, then the tenant default).
//  3. Apply the metadata filter to narrow the document IDs before retrieval.
//  4. Optionally translate the question (cross_languages) and append
//     LLM-extracted keywords (keyword) using the tenant's default chat model.
//  5. Compute rank features via LabelQuestion (tag-based weights or
//     pagerank_fld fallback).
//  6. Resolve the embedding model (tenant_embd_id, then embd_id, then the
//     tenant default) and the optional rerank model.
//  7. Call RetrievalService.Retrieval (query embedding, hybrid search with
//     rank features, rerank, doc_aggs).
//  8. Knowledge-graph retrieval (use_kg) is not implemented; it is only logged.
//  9. Group child chunks under their parents and hydrate placeholder vectors.
//
// tenantIDs[0] (the tenant owning the first dataset) is used for all model
// lookups. Failures in the optional LLM-backed steps (metadata filter,
// translation, keyword extraction) are logged and the step is skipped.
func (s *ChunkService) RetrievalTest(req *service.RetrievalTestRequest, userID string) (*service.RetrievalTestResponse, error) {
	common.Info("RetrievalTest started", zap.String("userID", userID), zap.Any("kbID", req.Datasets), zap.String("question", req.Question))
	common.Debug(fmt.Sprintf("RetrievalTest request:\n"+
		" kbID=%v\n"+
		" question=%s\n"+
		" page=%v, size=%v\n"+
		" docIDs=%v\n"+
		" useKG=%v, topK=%v\n"+
		" crossLanguages=%v\n"+
		" searchID=%v\n"+
		" filter=%v\n"+
		" tenantRerankID=%v\n"+
		" rerankID=%v\n"+
		" keyword=%v\n"+
		" similarityThreshold=%v, vectorSimilarityWeight=%v",
		req.Datasets, req.Question,
		common.PtrString(req.Page), common.PtrString(req.Size), req.DocIDs,
		common.PtrString(req.UseKG), common.PtrString(req.TopK), req.CrossLanguages, common.PtrString(req.SearchID),
		req.Filter,
		common.PtrString(req.TenantRerankID), common.PtrString(req.RerankID),
		common.PtrString(req.Keyword),
		common.PtrString(req.SimilarityThreshold), common.PtrString(req.VectorSimilarityWeight)))

	if req.Question == "" {
		return nil, fmt.Errorf("question is required")
	}
	if len(req.Datasets) == 0 {
		return nil, fmt.Errorf("dataset_ids is required")
	}

	ctx := context.Background()
	tenants, err := s.userTenantDAO.GetByUserID(userID)
	if err != nil {
		return nil, fmt.Errorf("failed to get user tenants: %w", err)
	}
	if len(tenants) == 0 {
		return nil, fmt.Errorf("user has no accessible tenants")
	}
	common.Debug("Retrieved user tenants from database", zap.String("userID", userID), zap.Int("tenantCount", len(tenants)))

	// For each dataset, record the first of the user's tenants that owns it.
	var tenantIDs []string
	var kbRecords []*entity.Knowledgebase
	for _, datasetID := range req.Datasets {
		found := false
		for _, tenant := range tenants {
			kb, err := s.kbDAO.GetByIDAndTenantID(datasetID, tenant.TenantID)
			if err == nil && kb != nil {
				common.Debug("Found knowledge base in database",
					zap.String("datasetID", datasetID),
					zap.String("tenantID", tenant.TenantID),
					zap.String("kbName", kb.Name),
					zap.String("embdID", kb.EmbdID))
				tenantIDs = append(tenantIDs, tenant.TenantID)
				kbRecords = append(kbRecords, kb)
				found = true
				break
			}
		}
		if !found {
			return nil, fmt.Errorf("only owner of dataset is authorized for this operation")
		}
	}

	// A single query embedding is used for all datasets, so they must share
	// the same embedding model.
	if len(kbRecords) > 1 {
		firstEmbdID := kbRecords[0].EmbdID
		for i := 1; i < len(kbRecords); i++ {
			if kbRecords[i].EmbdID != firstEmbdID {
				return nil, fmt.Errorf("cannot retrieve across datasets with different embedding models")
			}
		}
	}

	// Determine meta_data_filter
	var chatID string
	var chatModelForFilter *models.ChatModel
	filter := req.Filter
	if req.SearchID != nil && *req.SearchID != "" {
		// If search_id is set, get meta_data_filter and chat_id from search_config
		searchDetail, err := s.searchService.GetDetail(*req.SearchID)
		if err != nil {
			common.Warn("Failed to get search detail for search_id, proceeding without it", zap.String("searchID", *req.SearchID), zap.Error(err))
		} else if searchConfig, ok := searchDetail["search_config"].(entity.JSONMap); ok && searchConfig != nil {
			if searchMetaFilter, ok := searchConfig["meta_data_filter"].(map[string]interface{}); ok {
				filter = searchMetaFilter
			}
			chatID, _ = searchConfig["chat_id"].(string)
		} else {
			common.Warn("No search_config found in search detail", zap.String("searchID", *req.SearchID))
		}
	}

	// If meta_data_filter method is auto/semi_auto, get chat model
	if filter != nil {
		method, _ := filter["method"].(string)
		if method == "auto" || method == "semi_auto" {
			modelProviderSvc := service.NewModelProviderService()
			if chatID != "" {
				// Use chat_id from search_config (it's actually the model name)
				driver, mdlName, apiConfig, _, getErr := modelProviderSvc.GetModelConfigFromProviderInstance(tenantIDs[0], entity.ModelTypeChat, chatID)
				if getErr != nil {
					common.Warn("Failed to get chat model from search_config chat_id, using tenant default", zap.String("chatID", chatID), zap.Error(getErr))
				} else {
					chatModelForFilter = models.NewChatModel(driver, &mdlName, apiConfig)
					common.Info("Fetched chat model (from search_config) for metadata filter",
						zap.String("chatID", chatID),
						zap.String("tenantID", tenantIDs[0]))
				}
			}
			// If no chatID from search_config, or chatModel not found, use tenant default
			if chatModelForFilter == nil {
				tenantSvc := service.NewTenantService()
				modelName, err := tenantSvc.GetDefaultModelName(tenantIDs[0], entity.ModelTypeChat)
				if err != nil || modelName == "" {
					common.Warn("Failed to get tenant default chat model name for meta_data_filter", zap.Error(err))
				} else {
					driver, mdlName, apiConfig, _, getErr := modelProviderSvc.GetModelConfigFromProviderInstance(tenantIDs[0], entity.ModelTypeChat, modelName)
					if getErr != nil {
						common.Warn("Failed to get chat model for meta_data_filter", zap.Error(getErr))
					} else {
						chatModelForFilter = models.NewChatModel(driver, &mdlName, apiConfig)
						common.Info("Fetched chat model (tenant default) for metadata filter",
							zap.String("tenantID", tenantIDs[0]),
							zap.String("modelName", modelName))
					}
				}
			}
		}
	}

	// Apply meta_data_filter to get filtered doc_ids (filter by metadata before retrieval).
	// On any failure the caller's doc_ids are kept unchanged.
	docIDs := make([]string, len(req.DocIDs))
	copy(docIDs, req.DocIDs)
	if filter != nil {
		// Get flattened metadata
		metadataSvc := service.NewMetadataService()
		flattedMeta, err := metadataSvc.GetFlattedMetaByKBs([]string(req.Datasets))
		if err != nil {
			common.Warn("Failed to get flatted metadata", zap.Error(err))
		} else {
			common.Info("metadata filter conditions", zap.Any("filter", filter))
			filteredDocIDs, filterErr := service.ApplyMetaDataFilter(ctx, filter, flattedMeta, req.Question, chatModelForFilter, req.DocIDs, []string(req.Datasets))
			if filterErr != nil {
				// Do not trust a partial result: it could silently drop the
				// caller's doc_ids restriction.
				common.Warn("Failed to apply metadata filter, using requested doc_ids", zap.Error(filterErr))
			} else {
				docIDs = filteredDocIDs
				common.Info("ApplyMetaDataFilter result", zap.Strings("docIDs", docIDs))
			}
		}
	}

	// Apply cross_languages and keyword extraction with tenant default chat model
	modifiedQuestion := req.Question
	var chatModel *models.ChatModel
	// Get chat model for cross_languages and keyword_extraction
	var llmModelName string
	if len(req.CrossLanguages) > 0 || (req.Keyword != nil && *req.Keyword) {
		tenantSvc := service.NewTenantService()
		modelProviderSvc := service.NewModelProviderService()
		llmModelName, err = tenantSvc.GetDefaultModelName(tenantIDs[0], entity.ModelTypeChat)
		if err != nil || llmModelName == "" {
			common.Warn("Failed to get default chat model name for LLM transformations", zap.Error(err))
		} else {
			driver, mdlName, apiConfig, _, getErr := modelProviderSvc.GetModelConfigFromProviderInstance(tenantIDs[0], entity.ModelTypeChat, llmModelName)
			if getErr != nil {
				common.Warn("Failed to get chat model for LLM transformations", zap.Error(getErr))
			} else {
				chatModel = models.NewChatModel(driver, &mdlName, apiConfig)
				common.Info("Fetched chat model (tenant default) for cross_languages/keyword_extraction",
					zap.String("tenantID", tenantIDs[0]),
					zap.String("modelName", llmModelName))
			}
		}
	}

	// Apply cross_languages on the question (translate question)
	if len(req.CrossLanguages) > 0 {
		translated, err := service.CrossLanguages(ctx, tenantIDs[0], llmModelName, req.Question, req.CrossLanguages)
		if err != nil {
			common.Warn("Failed to translate question", zap.Error(err))
		} else {
			modifiedQuestion = translated
		}
	}

	// Apply keyword extraction on the question (append keywords to question)
	if chatModel != nil && req.Keyword != nil && *req.Keyword {
		extractedKeywords, err := service.KeywordExtraction(ctx, chatModel, modifiedQuestion, 3)
		if err != nil {
			common.Warn("Failed to extract keywords from question", zap.Error(err))
		} else if extractedKeywords != "" {
			modifiedQuestion = modifiedQuestion + " " + extractedKeywords
		}
	}
	if modifiedQuestion != req.Question {
		common.Info("Modified question after transformations",
			zap.String("originalQuestion", req.Question),
			zap.String("modifiedQuestion", modifiedQuestion),
			zap.Strings("crossLanguages", req.CrossLanguages),
			zap.Bool("keywordExtraction", req.Keyword != nil && *req.Keyword))
	}

	// Get tag-based rank features via LabelQuestion
	metadataSvc := service.NewMetadataService()
	labels := metadataSvc.LabelQuestion(modifiedQuestion, kbRecords)
	common.Debug("LabelQuestion result", zap.Any("labels", labels))

	// Determine embedding model: tenant_embd_id, then embd_id ("name@factory"
	// or bare name), then the tenant's default embedding model.
	var embdID string
	var tenantLLM *entity.TenantLLM
	if kbRecords[0].TenantEmbdID != nil && *kbRecords[0].TenantEmbdID > 0 {
		tenantLLM, embdID, err = dao.LookupTenantLLMByID(dao.NewTenantLLMDAO(), *kbRecords[0].TenantEmbdID)
		if err != nil {
			return nil, fmt.Errorf("failed to get embedding model by tenant_embd_id: %w", err)
		}
	} else if kbRecords[0].EmbdID != "" {
		parts := strings.Split(kbRecords[0].EmbdID, "@")
		if len(parts) == 2 && parts[1] != "" {
			tenantLLM, embdID, err = dao.LookupTenantLLMByFactory(dao.NewTenantLLMDAO(), tenantIDs[0], parts[1], parts[0], entity.ModelTypeEmbedding)
		} else {
			tenantLLM, embdID, err = dao.LookupTenantLLMByName(dao.NewTenantLLMDAO(), tenantIDs[0], kbRecords[0].EmbdID, entity.ModelTypeEmbedding)
		}
		if err != nil {
			return nil, fmt.Errorf("failed to get embedding model by embd_id: %w", err)
		}
	} else {
		tenantLLM, err = dao.NewTenantLLMDAO().GetByTenantAndType(tenantIDs[0], entity.ModelTypeEmbedding)
		if err != nil {
			return nil, fmt.Errorf("failed to get tenant default embedding model: %w", err)
		}
		if tenantLLM == nil || tenantLLM.LLMName == nil || *tenantLLM.LLMName == "" {
			return nil, fmt.Errorf("no default embedding model found for tenant %s", tenantIDs[0])
		}
		embdID = fmt.Sprintf("%s@%s", *tenantLLM.LLMName, tenantLLM.LLMFactory)
	}

	// Get embedding model for the tenant
	modelProviderSvc := service.NewModelProviderService()
	embeddingModel, err := modelProviderSvc.GetEmbeddingModel(tenantIDs[0], embdID)
	if err != nil {
		return nil, fmt.Errorf("failed to get embedding model: %w", err)
	}
	common.Info("Fetched embedding model for retrieval",
		zap.String("tenantID", tenantIDs[0]),
		zap.String("embdID", embdID))

	// Get rerank model: tenant_rerank_id takes precedence over rerank_id.
	var rerankModel *models.RerankModel
	var rerankCompositeName string
	if req.TenantRerankID != nil && *req.TenantRerankID != "" {
		tenantRerankIDInt, parseErr := strconv.ParseInt(*req.TenantRerankID, 10, 64)
		if parseErr != nil {
			return nil, fmt.Errorf("invalid tenant_rerank_id: %w", parseErr)
		}
		_, rerankCompositeName, err = dao.LookupTenantLLMByID(dao.NewTenantLLMDAO(), tenantRerankIDInt)
		if err != nil {
			return nil, fmt.Errorf("failed to get rerank model by tenant_rerank_id: %w", err)
		}
	} else if req.RerankID != nil && *req.RerankID != "" {
		_, rerankCompositeName, err = dao.LookupTenantLLMByName(dao.NewTenantLLMDAO(), tenantIDs[0], *req.RerankID, entity.ModelTypeRerank)
		if err != nil {
			return nil, fmt.Errorf("failed to get rerank model by rerank_id: %w", err)
		}
	}
	if rerankCompositeName != "" {
		driver, mdlName, apiConfig, _, getErr := modelProviderSvc.GetModelConfigFromProviderInstance(tenantIDs[0], entity.ModelTypeRerank, rerankCompositeName)
		if getErr != nil {
			return nil, fmt.Errorf("failed to get rerank model: %w", getErr)
		}
		rerankModel = models.NewRerankModel(driver, &mdlName, apiConfig)
	}
	if rerankModel != nil {
		common.Info("Fetched rerank model",
			zap.String("tenantID", tenantIDs[0]),
			zap.String("rerankCompositeName", rerankCompositeName))
	}

	retrievalReq := &nlp.RetrievalRequest{
		TenantIDs:              tenantIDs,
		Question:               modifiedQuestion,
		KbIDs:                  []string(req.Datasets),
		DocIDs:                 docIDs,
		Page:                   common.CoalesceInt(req.Page, 1),
		PageSize:               common.CoalesceInt(req.Size, 30),
		Top:                    req.TopK,
		SimilarityThreshold:    req.SimilarityThreshold,
		VectorSimilarityWeight: req.VectorSimilarityWeight,
		RerankModel:            rerankModel,
		RankFeature:            &labels,
		EmbeddingModel:         embeddingModel,
	}

	// Call RetrievalService to perform retrieval
	retrievalResult, err := nlp.NewRetrievalService(s.docEngine, s.documentDAO).Retrieval(ctx, retrievalReq)
	if err != nil {
		return nil, fmt.Errorf("retrieval search failed: %w", err)
	}
	filteredChunks := retrievalResult.Chunks

	// Handle knowledge graph retrieval
	// TODO: KG retrieval requires GraphRAG infrastructure which is not yet implemented in Go
	if req.UseKG != nil && *req.UseKG {
		common.Warn("use_kg is not yet implemented in Go - skipping KG retrieval")
	}

	// Apply retrieval_by_children - aggregate child chunks into parent chunks
	filteredChunks = nlp.RetrievalByChildren(filteredChunks, tenantIDs, s.docEngine, ctx)

	// Hydrate: ES returns zero vectors; replace with real vectors from FetchChunkVectors.
	// Infinity/OceanBase chunks already carry real vectors and are left unchanged.
	hydrateChunkVectors(ctx, s.docEngine, filteredChunks, req.Datasets, tenantIDs)

	common.Info("RetrievalTest completed", zap.String("userID", userID), zap.Any("kbID", req.Datasets), zap.String("question", req.Question), zap.Int64("chunkCount", int64(len(filteredChunks))))
	return &service.RetrievalTestResponse{
		Chunks:  filteredChunks,
		DocAggs: retrievalResult.DocAggs,
		Labels:  &labels,
		Total:   retrievalResult.Total,
	}, nil
}
// hydrateChunkVectors fills in chunk "vector" entries that are missing or
// all-zero with vectors fetched through FetchChunkVectors. Chunks that already
// carry a non-zero vector are left untouched, so engines that return real
// vectors make this a no-op; there is no branching on engine type.
//
// The vector dimension is taken from the first chunk holding a non-empty
// []float64 vector. If no chunk has one, nothing is fetched. Chunks without an
// "id" string are skipped. Duplicate chunk IDs are fetched once and every
// occurrence is hydrated. Fetched vectors that are empty or all-zero are
// ignored. chunks is modified in place.
func hydrateChunkVectors(ctx context.Context, docEngine engine.DocEngine, chunks []map[string]interface{}, kbIDs []string, tenantIDs []string) {
	if len(chunks) == 0 {
		return
	}

	// One pass: learn the vector dimension and collect the positions of every
	// chunk whose vector needs hydrating, keyed by chunk ID.
	dim := 0
	var missingIDs []string
	missingIdx := make(map[string][]int)
	for i, ck := range chunks {
		v, _ := ck["vector"].([]float64)
		if dim == 0 && len(v) > 0 {
			dim = len(v)
		}
		id, _ := ck["id"].(string)
		if id == "" {
			continue
		}
		if len(v) == 0 || common.IsZeroVector(v) {
			if _, seen := missingIdx[id]; !seen {
				missingIDs = append(missingIDs, id)
			}
			missingIdx[id] = append(missingIdx[id], i)
		}
	}
	if len(missingIDs) == 0 || dim == 0 {
		return
	}

	vectors := FetchChunkVectors(ctx, docEngine, missingIDs, tenantIDs, kbIDs, dim)
	for id, v := range vectors {
		if len(v) == 0 || common.IsZeroVector(v) {
			continue
		}
		for _, idx := range missingIdx[id] {
			chunks[idx]["vector"] = v
		}
	}
}
// Get retrieves a single chunk by ID, trying each tenant the user belongs to.
//
// For every tenant, the chunk is looked up in the tenant's index
// ("ragflow_<tenantID>") restricted to that tenant's knowledge bases; tenants
// whose KB list or lookup fails, or that do not hold the chunk, are skipped.
// The first hit is reshaped for the API response:
//   - "id", "authors", "_score", "SCORE" and fields ending in "_vec", "_tks",
//     "_ltks" or containing "_sm_" are dropped;
//   - content → content_with_weight, docnm → docnm_kwd, important_keywords →
//     important_kwd, questions → question_kwd;
//   - list-valued keyword fields default to [] and tag_feas defaults to {};
//   - create_timestamp_flt, rank_flt and weight_flt are emitted as JSON floats.
//
// Returns an error if the doc engine is not initialized, chunk_id is empty,
// the user has no tenants, or no tenant holds the chunk.
func (s *ChunkService) Get(req *service.GetChunkRequest, userID string) (*service.GetChunkResponse, error) {
	if s.docEngine == nil {
		return nil, fmt.Errorf("doc engine not initialized")
	}
	if req.ChunkID == "" {
		return nil, fmt.Errorf("chunk_id is required")
	}
	ctx := context.Background()

	// Get user's tenants
	tenants, err := s.userTenantDAO.GetByUserID(userID)
	if err != nil {
		return nil, fmt.Errorf("failed to get user tenants: %w", err)
	}
	if len(tenants) == 0 {
		return nil, fmt.Errorf("user has no accessible tenants")
	}

	// Internal fields never exposed through the API (built once, not per tenant).
	skipFields := map[string]bool{
		"id": true, "authors": true, "_score": true, "SCORE": true,
	}

	// Try each tenant to find the chunk
	for _, tenant := range tenants {
		kbIDs, err := s.kbDAO.GetKBIDsByTenantID(tenant.TenantID)
		if err != nil {
			continue
		}
		indexName := fmt.Sprintf("ragflow_%s", tenant.TenantID)
		doc, err := s.docEngine.GetChunk(ctx, indexName, req.ChunkID, kbIDs)
		if err != nil || doc == nil {
			continue
		}
		chunk, ok := doc.(map[string]interface{})
		if !ok {
			continue
		}

		result := make(map[string]interface{}, len(chunk))
		for k, v := range chunk {
			// Note: the "_tks" suffix check also drops authors_tks.
			if skipFields[k] || strings.HasSuffix(k, "_vec") || strings.Contains(k, "_sm_") || strings.HasSuffix(k, "_tks") || strings.HasSuffix(k, "_ltks") {
				continue
			}
			switch k {
			case "content":
				result["content_with_weight"] = v
			case "docnm":
				result["docnm_kwd"] = v
			case "important_keywords":
				utility.SetFieldArray(result, "important_kwd", v)
			case "questions":
				utility.SetFieldArray(result, "question_kwd", v)
			case "entities_kwd", "entity_kwd", "entity_type_kwd", "from_entity_kwd",
				"name_kwd", "raptor_kwd", "removed_kwd", "source_id", "tag_kwd",
				"to_entity_kwd", "toc_kwd", "doc_type_kwd":
				if utility.IsEmpty(v) {
					result[k] = []interface{}{}
				} else {
					result[k] = v
				}
			case "tag_feas":
				if utility.IsEmpty(v) {
					result[k] = map[string]interface{}{}
				} else {
					result[k] = v
				}
			case "create_timestamp_flt", "rank_flt", "weight_flt":
				if floatVal, ok := utility.ToFloat64(v); ok {
					result[k] = utility.JSONFloat64(floatVal)
				}
			default:
				result[k] = v
			}
		}
		return &service.GetChunkResponse{Chunk: result}, nil
	}
	return nil, fmt.Errorf("chunk not found")
}
// List returns one page of the chunks of a document together with the
// document's metadata.
//
// The document's knowledge base must belong to one of the user's tenants.
// Chunks are searched in that tenant's index ("ragflow_<tenantID>"), filtered
// by doc_id (and available_int when req.AvailableInt is set) and matched
// against req.Keywords. Each chunk is reshaped for the API:
//   - internal fields (_id, authors, scores, kb_id, mom_id, page_num_int,
//     important_kwd_empty_count, *_vec, *_tks, *_ltks, *_sm_*) are dropped;
//   - id → chunk_id, img_id → image_id ("" when not a string),
//     position_int → positions, content → content_with_weight,
//     docnm → docnm_kwd;
//   - list-valued keyword fields default to [];
//   - other *_kwd strings containing "###" are split into lists of the
//     non-empty parts (knowledge_graph_kwd is left as-is).
//
// Page defaults to 1 and size to 30; non-positive values are replaced by
// those defaults so the search offset is never negative.
func (s *ChunkService) List(req *service.ListChunksRequest, userID string) (*service.ListChunksResponse, error) {
	if s.docEngine == nil {
		return nil, fmt.Errorf("doc engine not initialized")
	}
	if req.DocID == "" {
		return nil, fmt.Errorf("doc_id is required")
	}
	ctx := context.Background()

	// Get user's tenants
	tenants, err := s.userTenantDAO.GetByUserID(userID)
	if err != nil {
		return nil, fmt.Errorf("failed to get user tenants: %w", err)
	}
	if len(tenants) == 0 {
		return nil, fmt.Errorf("user has no accessible tenants")
	}

	// Get document to find its tenant
	docDAO := dao.NewDocumentDAO()
	doc, err := docDAO.GetByID(req.DocID)
	if err != nil || doc == nil {
		return nil, fmt.Errorf("document not found")
	}

	// Get knowledge base to find tenant
	kb, err := s.kbDAO.GetByID(doc.KbID)
	if err != nil || kb == nil {
		return nil, fmt.Errorf("knowledge base not found")
	}

	// Find which tenant this document belongs to
	var targetTenantID string
	for _, tenant := range tenants {
		if tenant.TenantID == kb.TenantID {
			targetTenantID = tenant.TenantID
			break
		}
	}
	if targetTenantID == "" {
		return nil, fmt.Errorf("user does not have access to this document")
	}

	// Get kbIDs for this tenant
	kbIDs, err := s.kbDAO.GetKBIDsByTenantID(targetTenantID)
	if err != nil {
		return nil, fmt.Errorf("failed to get kb ids: %w", err)
	}

	indexName := fmt.Sprintf("ragflow_%s", targetTenantID)
	page := common.CoalesceInt(req.Page, 1)
	size := common.CoalesceInt(req.Size, 30)
	// Guard against page <= 0 (negative offset) and size <= 0 (empty page).
	if page < 1 {
		page = 1
	}
	if size < 1 {
		size = 30
	}
	keywords := req.Keywords

	// Build search request - same as retrieval test but filtered by doc_id
	searchReq := &types.SearchRequest{
		IndexNames: []string{indexName},
		MatchExprs: []interface{}{keywords},
		KbIDs:      kbIDs,
		Offset:     (page - 1) * size,
		Limit:      size,
		Filter: map[string]interface{}{
			"doc_id": req.DocID,
		},
	}
	// Add available_int filter if specified
	if req.AvailableInt != nil {
		searchReq.Filter["available_int"] = *req.AvailableInt
	}

	// Execute search through unified engine interface
	searchResp, err := s.docEngine.Search(ctx, searchReq)
	if err != nil {
		return nil, fmt.Errorf("search failed: %w", err)
	}

	// Internal fields never exposed through the API (built once, not per chunk).
	skipFields := map[string]bool{
		"_id": true, "authors": true, "_score": true, "SCORE": true,
		"important_kwd_empty_count": true, "kb_id": true, "mom_id": true, "page_num_int": true,
	}

	chunks := make([]map[string]interface{}, 0, len(searchResp.Chunks))
	for _, chunk := range searchResp.Chunks {
		result := make(map[string]interface{}, len(chunk))
		for k, v := range chunk {
			if skipFields[k] || strings.HasSuffix(k, "_vec") || strings.Contains(k, "_sm_") || strings.HasSuffix(k, "_ltks") || strings.HasSuffix(k, "_tks") {
				continue
			}
			switch k {
			case "img_id":
				if strVal, ok := v.(string); ok {
					result["image_id"] = strVal
				} else {
					result["image_id"] = ""
				}
			case "position_int":
				result["positions"] = v
			case "id":
				result["chunk_id"] = v
			case "content":
				result["content_with_weight"] = v
			case "docnm":
				result["docnm_kwd"] = v
			case "important_keywords":
				utility.SetFieldArray(result, "important_kwd", v)
			case "questions":
				utility.SetFieldArray(result, "question_kwd", v)
			case "entities_kwd", "entity_kwd", "entity_type_kwd", "from_entity_kwd",
				"name_kwd", "raptor_kwd", "removed_kwd",
				"source_id", "tag_kwd", "to_entity_kwd", "toc_kwd", "doc_type_kwd":
				if utility.IsEmpty(v) {
					result[k] = []interface{}{}
				} else {
					result[k] = v
				}
			default:
				// Handle _kwd fields that need "###" splitting
				strVal, isStr := v.(string)
				if strings.HasSuffix(k, "_kwd") && k != "knowledge_graph_kwd" && isStr && strings.Contains(strVal, "###") {
					parts := strings.Split(strVal, "###")
					// Non-nil so an all-empty split serializes as [] not null.
					filtered := make([]interface{}, 0, len(parts))
					for _, p := range parts {
						if p != "" {
							filtered = append(filtered, p)
						}
					}
					result[k] = filtered
				} else {
					result[k] = v
				}
			}
		}
		chunks = append(chunks, result)
	}

	// Build document info
	timeFormat := "2006-01-02T15:04:05"
	docInfo := map[string]interface{}{
		"id":               doc.ID,
		"thumbnail":        doc.Thumbnail,
		"kb_id":            doc.KbID,
		"parser_id":        doc.ParserID,
		"pipeline_id":      doc.PipelineID,
		"parser_config":    doc.ParserConfig,
		"source_type":      doc.SourceType,
		"type":             doc.Type,
		"created_by":       doc.CreatedBy,
		"name":             doc.Name,
		"location":         doc.Location,
		"size":             doc.Size,
		"token_num":        doc.TokenNum,
		"chunk_num":        doc.ChunkNum,
		"progress":         utility.JSONFloat64(doc.Progress),
		"progress_msg":     doc.ProgressMsg,
		"process_begin_at": utility.FormatTimeToString(doc.ProcessBeginAt, timeFormat),
		"process_duration": doc.ProcessDuration,
		"content_hash":     doc.ContentHash,
		"suffix":           doc.Suffix,
		"run":              doc.Run,
		"status":           doc.Status,
		"create_time":      doc.CreateTime,
		"create_date":      utility.FormatTimeToString(doc.CreateDate, timeFormat),
		"update_time":      doc.UpdateTime,
		"update_date":      utility.FormatTimeToString(doc.UpdateDate, timeFormat),
	}
	return &service.ListChunksResponse{
		Total:  searchResp.Total,
		Chunks: chunks,
		Doc:    docInfo,
	}, nil
}
// UpdateChunk updates one chunk of a document in a dataset the user can access.
//
// The dataset must belong to one of the user's tenants, the document must
// belong to the dataset, and the chunk (fetched from "ragflow_<tenantID>")
// must exist and, when it reports a doc_id, belong to that document.
// The content is taken from req.Content, or else from the existing chunk
// ("content_with_weight", then "content", then ""). It is always re-tokenized
// into content_ltks / content_sm_ltks; tokenizer errors are logged and the
// tokenizer's returned value is still stored. Optional fields (important
// keywords, questions, availability, positions, tag keywords, tag features)
// are written only when set in req. The update is applied by chunk id.
func (s *ChunkService) UpdateChunk(req *service.UpdateChunkRequest, userID string) error {
	if s.docEngine == nil {
		return fmt.Errorf("doc engine not initialized")
	}
	if req.ChunkID == "" {
		return fmt.Errorf("chunk_id is required")
	}
	ctx := context.Background()

	// Get user's tenants
	tenants, err := s.userTenantDAO.GetByUserID(userID)
	if err != nil {
		return fmt.Errorf("failed to get user tenants: %w", err)
	}
	if len(tenants) == 0 {
		return fmt.Errorf("user has no accessible tenants")
	}

	// Find the tenant that owns this dataset
	var targetTenantID string
	for _, tenant := range tenants {
		kb, err := s.kbDAO.GetByIDAndTenantID(req.DatasetID, tenant.TenantID)
		if err == nil && kb != nil {
			targetTenantID = tenant.TenantID
			break
		}
	}
	if targetTenantID == "" {
		return fmt.Errorf("user does not have access to this dataset")
	}

	// Verify document belongs to dataset
	docDAO := dao.NewDocumentDAO()
	doc, err := docDAO.GetByID(req.DocumentID)
	if err != nil || doc == nil {
		return fmt.Errorf("document not found")
	}
	if doc.KbID != req.DatasetID {
		return fmt.Errorf("document does not belong to this dataset")
	}

	// Fetch existing chunk first
	indexName := fmt.Sprintf("ragflow_%s", targetTenantID)
	existingChunk, err := s.docEngine.GetChunk(ctx, indexName, req.ChunkID, []string{req.DatasetID})
	if err != nil {
		return fmt.Errorf("failed to get existing chunk: %w", err)
	}
	// A missing chunk is reported as such rather than as a format problem.
	if existingChunk == nil {
		return fmt.Errorf("chunk not found")
	}
	existing, ok := existingChunk.(map[string]interface{})
	if !ok {
		return fmt.Errorf("invalid chunk format")
	}
	// Refuse to modify a chunk through another document's path; only checked
	// when the engine reports the chunk's doc_id.
	if chunkDocID, ok := existing["doc_id"].(string); ok && chunkDocID != "" && chunkDocID != req.DocumentID {
		return fmt.Errorf("chunk does not belong to this document")
	}

	// Build update dict
	d := make(map[string]interface{})

	// Content - use new value or existing
	var content string
	if req.Content != nil {
		content = *req.Content
	} else if v, ok := existing["content_with_weight"].(string); ok {
		content = v
	} else if v, ok := existing["content"].(string); ok {
		content = v
	}
	d["content_with_weight"] = content

	// Tokenize content. Failures are logged, not fatal, so an update is never
	// blocked by the tokenizer.
	contentLtks, tokErr := tokenizer.Tokenize(content)
	if tokErr != nil {
		common.Warn("Failed to tokenize chunk content", zap.String("chunkID", req.ChunkID), zap.Error(tokErr))
	}
	contentSmLtks, tokErr := tokenizer.FineGrainedTokenize(contentLtks)
	if tokErr != nil {
		common.Warn("Failed to fine-grained tokenize chunk content", zap.String("chunkID", req.ChunkID), zap.Error(tokErr))
	}
	d["content_ltks"] = contentLtks
	d["content_sm_ltks"] = contentSmLtks

	// Important keywords - convert []string to []interface{} for transformChunkFields
	if req.ImportantKwd != nil {
		impKwd := make([]interface{}, len(req.ImportantKwd))
		for i, v := range req.ImportantKwd {
			impKwd[i] = v
		}
		d["important_kwd"] = impKwd
	}

	// Questions: trimmed, empty entries dropped
	if req.Questions != nil {
		filteredQuestions := []string{}
		for _, q := range req.Questions {
			q = strings.TrimSpace(q)
			if q != "" {
				filteredQuestions = append(filteredQuestions, q)
			}
		}
		d["question_kwd"] = filteredQuestions
	}

	// Available
	if req.Available != nil {
		if *req.Available {
			d["available_int"] = 1
		} else {
			d["available_int"] = 0
		}
	}

	// Positions
	if req.Positions != nil {
		d["position_int"] = req.Positions
	}

	// Tag keywords
	if req.TagKwd != nil {
		d["tag_kwd"] = req.TagKwd
	}

	// Tag features
	if req.TagFeas != nil {
		d["tag_feas"] = req.TagFeas
	}

	// Always include id
	d["id"] = req.ChunkID

	// Call update
	condition := map[string]interface{}{
		"id": req.ChunkID,
	}
	if err := s.docEngine.UpdateChunks(ctx, condition, d, indexName, req.DatasetID); err != nil {
		return fmt.Errorf("failed to update chunk: %w", err)
	}
	return nil
}
// RemoveChunks deletes chunks of one document and returns how many the engine
// reports as deleted.
//
// The document's knowledge base must belong to one of the user's tenants; the
// deletion runs against that tenant's index ("ragflow_<tenantID>"). Exactly
// one of req.ChunkIDs (delete those chunks of the document) or req.DeleteAll
// (delete every chunk of the document) must be given.
func (s *ChunkService) RemoveChunks(req *service.RemoveChunksRequest, userID string) (int64, error) {
	switch {
	case s.docEngine == nil:
		return 0, fmt.Errorf("doc engine not initialized")
	case req.DocID == "":
		return 0, fmt.Errorf("doc_id is required")
	}
	ctx := context.Background()

	tenants, err := s.userTenantDAO.GetByUserID(userID)
	if err != nil {
		return 0, fmt.Errorf("failed to get user tenants: %w", err)
	}
	if len(tenants) == 0 {
		return 0, fmt.Errorf("user has no accessible tenants")
	}

	// Load the document first: its KbID decides which tenant owns it.
	doc, err := dao.NewDocumentDAO().GetByID(req.DocID)
	if err != nil || doc == nil {
		return 0, fmt.Errorf("document not found")
	}

	ownerTenantID := ""
	for _, t := range tenants {
		if kb, lookupErr := s.kbDAO.GetByIDAndTenantID(doc.KbID, t.TenantID); lookupErr == nil && kb != nil {
			ownerTenantID = t.TenantID
			break
		}
	}
	if ownerTenantID == "" {
		return 0, fmt.Errorf("user does not have access to this document")
	}

	hasChunkIDs := len(req.ChunkIDs) > 0
	if hasChunkIDs && req.DeleteAll {
		return 0, fmt.Errorf("chunk_ids and delete_all are mutually exclusive")
	}
	if !hasChunkIDs && !req.DeleteAll {
		return 0, fmt.Errorf("either chunk_ids or delete_all must be provided")
	}

	// Deletion is always scoped to the document; specific IDs narrow it further
	// and are passed as []interface{} for the engine's condition builder.
	condition := map[string]interface{}{"doc_id": req.DocID}
	if hasChunkIDs {
		ids := make([]interface{}, 0, len(req.ChunkIDs))
		for _, chunkID := range req.ChunkIDs {
			ids = append(ids, chunkID)
		}
		condition["id"] = ids
	}

	removed, err := s.docEngine.DeleteChunks(ctx, condition, "ragflow_"+ownerTenantID, doc.KbID)
	if err != nil {
		return 0, fmt.Errorf("failed to delete chunks: %w", err)
	}
	return removed, nil
}