From 7c8927c4fbdf42ac3126e4220d692b39c20d0c16 Mon Sep 17 00:00:00 2001 From: qinling0210 <88864212+qinling0210@users.noreply.github.com> Date: Tue, 24 Mar 2026 20:10:21 +0800 Subject: [PATCH] Implement GetChunk() in Infinity in Go (#13758) ### What problem does this PR solve? Implement GetChunk() in Infinity in Go Add CLI commands: GET CHUNK 'XXX'; LIST CHUNKS OF DOCUMENT 'XXX'; ### Type of change - [x] Refactoring --- admin/client/parser.py | 34 +++ admin/client/ragflow_client.py | 69 +++++ internal/engine/elasticsearch/get.go | 56 ++++ internal/engine/engine.go | 4 +- internal/engine/infinity/get.go | 219 ++++++++++++++ internal/engine/infinity/search.go | 56 ++-- internal/engine/types/types.go | 5 +- internal/handler/chunk.go | 81 ++++++ internal/router/router.go | 2 + internal/service/chunk.go | 408 ++++++++++++++++++++++++--- internal/utility/convert.go | 130 +++++++++ 11 files changed, 989 insertions(+), 75 deletions(-) create mode 100644 internal/engine/elasticsearch/get.go create mode 100644 internal/engine/infinity/get.go diff --git a/admin/client/parser.py b/admin/client/parser.py index 91b09e138e..7465919305 100644 --- a/admin/client/parser.py +++ b/admin/client/parser.py @@ -91,6 +91,8 @@ sql_command: login_user | parse_dataset_async | import_docs_into_dataset | search_on_datasets + | get_chunk + | list_chunks | create_chat_session | drop_chat_session | list_chat_sessions @@ -164,6 +166,7 @@ DEFAULT: "DEFAULT"i CHATS: "CHATS"i CHAT: "CHAT"i FILES: "FILES"i +DOCUMENT: "DOCUMENT"i DOCUMENTS: "DOCUMENTS"i METADATA: "METADATA"i SUMMARY: "SUMMARY"i @@ -194,6 +197,13 @@ FINGERPRINT: "FINGERPRINT"i LICENSE: "LICENSE"i CHECK: "CHECK"i CONFIG: "CONFIG"i +CHUNK: "CHUNK"i +CHUNKS: "CHUNKS"i +GET: "GET"i +PAGE: "PAGE"i +SIZE: "SIZE"i +KEYWORDS: "KEYWORDS"i +AVAILABLE: "AVAILABLE"i login_user: LOGIN USER quoted_string ";" list_services: LIST SERVICES ";" @@ -321,6 +331,8 @@ list_user_model_providers: LIST MODEL PROVIDERS ";" list_user_default_models: LIST DEFAULT MODELS ";" 
import_docs_into_dataset: IMPORT quoted_string INTO DATASET quoted_string ";" search_on_datasets: SEARCH quoted_string ON DATASETS quoted_string ";" +get_chunk: GET CHUNK quoted_string ";" +list_chunks: LIST CHUNKS OF DOCUMENT quoted_string ("PAGE" NUMBER)? ("SIZE" NUMBER)? ("KEYWORDS" quoted_string)? ("AVAILABLE" NUMBER)? ";" parse_dataset_docs: PARSE quoted_string OF DATASET quoted_string ";" parse_dataset_sync: PARSE DATASET quoted_string SYNC ";" @@ -698,6 +710,28 @@ class RAGFlowCLITransformer(Transformer): datasets = datasets.split(" ") return {"type": "search_on_datasets", "datasets": datasets, "question": question} + def get_chunk(self, items): + chunk_id = items[2].children[0].strip("'\"") + return {"type": "get_chunk", "chunk_id": chunk_id} + + def list_chunks(self, items): + doc_id = items[4].children[0].strip("'\"") + result = {"type": "list_chunks", "doc_id": doc_id} + + # Parse optional parameters: PAGE, SIZE, KEYWORDS, AVAILABLE + # items structure varies based on which params are present + for i, item in enumerate(items): + if str(item) == "PAGE": + result["page"] = int(items[i + 1]) + elif str(item) == "SIZE": + result["size"] = int(items[i + 1]) + elif str(item) == "KEYWORDS": + result["keywords"] = items[i + 1].children[0].strip("'\"") + elif str(item) == "AVAILABLE": + result["available_int"] = int(items[i + 1]) + + return result + def benchmark(self, items): concurrency: int = int(items[1]) iterations: int = int(items[2]) diff --git a/admin/client/ragflow_client.py b/admin/client/ragflow_client.py index a6fc85a440..e45ec99c38 100644 --- a/admin/client/ragflow_client.py +++ b/admin/client/ragflow_client.py @@ -1434,6 +1434,61 @@ class RAGFlowClient: print( f"Fail to search datasets: {dataset_names}, code: {res_json['code']}, message: {res_json['message']}") + def get_chunk(self, command_dict): + if self.server_type != "user": + print("This command is only allowed in USER mode") + return + + chunk_id = command_dict["chunk_id"] + response = 
self.http_client.request("GET", f"/chunk/get?chunk_id={chunk_id}", use_api_base=False, + auth_kind="web") + res_json = response.json() + if response.status_code == 200: + if res_json["code"] == 0: + self._print_key_value(res_json["data"]) + else: + print(f"Fail to get chunk, code: {res_json['code']}, message: {res_json['message']}") + else: + print(f"Fail to get chunk, code: {res_json['code']}, message: {res_json['message']}") + + def list_chunks(self, command_dict): + if self.server_type != "user": + print("This command is only allowed in USER mode") + return + + doc_id = command_dict["doc_id"] + payload = { + "doc_id": doc_id, + } + + # Add optional parameters (only if explicitly provided) + if "page" in command_dict: + payload["page"] = command_dict["page"] + if "size" in command_dict: + payload["size"] = command_dict["size"] + if "keywords" in command_dict and command_dict["keywords"]: + payload["keywords"] = command_dict["keywords"] + if "available_int" in command_dict: + payload["available_int"] = command_dict["available_int"] + + response = self.http_client.request("POST", "/chunk/list", json_body=payload, use_api_base=False, + auth_kind="web") + res_json = response.json() + if response.status_code == 200: + if res_json["code"] == 0: + chunks = res_json["data"]["chunks"] + if chunks: + for i, chunk in enumerate(chunks): + print(f"\n--- Chunk {i+1} ---") + for key, value in chunk.items(): + print(f" {key}: {value}") + else: + print("No chunks found") + else: + print(f"Fail to list chunks, code: {res_json['code']}, message: {res_json['message']}") + else: + print(f"Fail to list chunks, code: {res_json['code']}, message: {res_json['message']}") + def show_version(self, command): if self.server_type == "admin": response = self.http_client.request("GET", "/admin/version", use_api_base=True, auth_kind="admin") @@ -1618,6 +1673,14 @@ class RAGFlowClient: print(separator) + def _print_key_value(self, data: dict): + """Print data as key-value pairs (one per line)""" 
+ if not data: + print("No data to print") + return + for key, value in data.items(): + print(f"{key}: {value}") + def run_command(client: RAGFlowClient, command_dict: dict): command_type = command_dict["type"] @@ -1761,6 +1824,10 @@ def run_command(client: RAGFlowClient, command_dict: dict): client.import_docs_into_dataset(command_dict) case "search_on_datasets": return client.search_on_datasets(command_dict) + case "get_chunk": + return client.get_chunk(command_dict) + case "list_chunks": + return client.list_chunks(command_dict) case "meta": _handle_meta_command(command_dict) case _: @@ -1818,6 +1885,8 @@ LIST DOCUMENTS OF DATASET SEARCH ON DATASETS LIST METADATA OF DATASETS [, ]* LIST METADATA SUMMARY OF DATASET DOCUMENTS [, ]* +GET CHUNK +LIST CHUNKS OF DOCUMENT [PAGE ] [SIZE ] [KEYWORDS ] [AVAILABLE <0|1>] Meta Commands: \\?, \\h, \\help Show this help diff --git a/internal/engine/elasticsearch/get.go b/internal/engine/elasticsearch/get.go new file mode 100644 index 0000000000..a2a4071260 --- /dev/null +++ b/internal/engine/elasticsearch/get.go @@ -0,0 +1,56 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +package elasticsearch + +import ( + "context" + "fmt" +) + +// GetChunk gets a chunk by ID +func (e *elasticsearchEngine) GetChunk(ctx context.Context, indexName, chunkID string, kbIDs []string) (interface{}, error) { + // Build query to get the chunk by ID + query := map[string]interface{}{ + "term": map[string]interface{}{ + "id": chunkID, + }, + } + + searchReq := &SearchRequest{ + IndexNames: []string{indexName}, + Query: query, + Size: 1, + From: 0, + } + + // Execute search + result, err := e.Search(ctx, searchReq) + if err != nil { + return nil, fmt.Errorf("failed to search: %w", err) + } + + esResp, ok := result.(*SearchResponse) + if !ok { + return nil, fmt.Errorf("invalid search response type") + } + + if len(esResp.Hits.Hits) == 0 { + return nil, nil + } + + return esResp.Hits.Hits[0].Source, nil +} diff --git a/internal/engine/engine.go b/internal/engine/engine.go index c8e9165426..f6cd56e3cd 100644 --- a/internal/engine/engine.go +++ b/internal/engine/engine.go @@ -49,9 +49,11 @@ type DocEngine interface { // Document operations IndexDocument(ctx context.Context, indexName, docID string, doc interface{}) error BulkIndex(ctx context.Context, indexName string, docs []interface{}) (interface{}, error) - GetDocument(ctx context.Context, indexName, docID string) (interface{}, error) DeleteDocument(ctx context.Context, indexName, docID string) error + // Chunk operations + GetChunk(ctx context.Context, indexName, chunkID string, kbIDs []string) (interface{}, error) + // Health check Ping(ctx context.Context) error Close() error diff --git a/internal/engine/infinity/get.go b/internal/engine/infinity/get.go new file mode 100644 index 0000000000..a8f8b58135 --- /dev/null +++ b/internal/engine/infinity/get.go @@ -0,0 +1,219 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package infinity + +import ( + "context" + "fmt" + "strings" + + infinity "github.com/infiniflow/infinity-go-sdk" + "ragflow/internal/logger" + "ragflow/internal/utility" + + "go.uber.org/zap" +) + +// GetChunk gets a chunk by ID +func (e *infinityEngine) GetChunk(ctx context.Context, tableName, chunkID string, kbIDs []string) (interface{}, error) { + if e.client == nil || e.client.conn == nil { + return nil, fmt.Errorf("Infinity client not initialized") + } + + // Build list of table names to search + var tableNames []string + if strings.HasPrefix(tableName, "ragflow_doc_meta_") { + tableNames = []string{tableName} + } else { + // Search in tables like _ for each kbID + if len(kbIDs) > 0 { + for _, kbID := range kbIDs { + tableNames = append(tableNames, fmt.Sprintf("%s_%s", tableName, kbID)) + } + } + // Also try the base tableName + tableNames = append(tableNames, tableName) + } + + // Try each table and collect results from all tables + db, err := e.client.conn.GetDatabase(e.client.dbName) + if err != nil { + return nil, fmt.Errorf("failed to get database: %w", err) + } + + // Collect chunks from all tables (same as Python's concat_dataframes) + allChunks := make(map[string]map[string]interface{}) + + for _, tblName := range tableNames { + table, err := db.GetTable(tblName) + if err != nil { + continue + } + + // Query with filter for the specific chunk ID + filter := fmt.Sprintf("id = '%s'", chunkID) + result, err := table.Output([]string{"*"}).Filter(filter).ToResult() + if err != nil { + continue + } + + qr, ok := 
result.(*infinity.QueryResult) + if !ok { + continue + } + + if len(qr.Data) == 0 { + continue + } + + // Convert to chunk format + chunks := make([]map[string]interface{}, 0) + for colName, colData := range qr.Data { + for i, val := range colData { + for len(chunks) <= i { + chunks = append(chunks, make(map[string]interface{})) + } + chunks[i][colName] = val + } + } + + // Merge chunks into allChunks (by id), keeping first non-empty value + for _, chunk := range chunks { + if idVal, ok := chunk["id"].(string); ok { + if existing, exists := allChunks[idVal]; exists { + // Merge: keep first non-empty value for each field + for k, v := range chunk { + if _, has := existing[k]; !has || utility.IsEmpty(v) { + existing[k] = v + } + } + } else { + allChunks[idVal] = chunk + } + } + } + } + + // Get the chunk by chunkID + chunk, found := allChunks[chunkID] + if !found { + return nil, nil + } + + getFields(chunk) + + logger.Debug("infinity get chunk", zap.String("chunkID", chunkID), zap.Any("tables", tableNames)) + + return chunk, nil +} + +// getFields applies field mappings to a chunk, similar to Python's get_fields function. 
+func getFields(chunk map[string]interface{}) { + // Field mappings + // docnm -> docnm_kwd, title_tks, title_sm_tks + if val, ok := chunk["docnm"].(string); ok { + chunk["docnm_kwd"] = val + chunk["title_tks"] = val + chunk["title_sm_tks"] = val + } + + // important_keywords -> important_kwd (split by comma), important_tks + if val, ok := chunk["important_keywords"].(string); ok { + if val == "" { + chunk["important_kwd"] = []interface{}{} + } else { + parts := strings.Split(val, ",") + chunk["important_kwd"] = parts + } + chunk["important_tks"] = val + } else { + chunk["important_kwd"] = []interface{}{} + chunk["important_tks"] = []interface{}{} + } + + // questions -> question_kwd (split by newline), question_tks + if val, ok := chunk["questions"].(string); ok { + if val == "" { + chunk["question_kwd"] = []interface{}{} + } else { + parts := strings.Split(val, "\n") + chunk["question_kwd"] = parts + } + chunk["question_tks"] = val + } else { + chunk["question_kwd"] = []interface{}{} + chunk["question_tks"] = []interface{}{} + } + + // content -> content_with_weight, content_ltks, content_sm_ltks + if val, ok := chunk["content"].(string); ok { + chunk["content_with_weight"] = val + chunk["content_ltks"] = val + chunk["content_sm_ltks"] = val + } + + // authors -> authors_tks, authors_sm_tks + if val, ok := chunk["authors"].(string); ok { + chunk["authors_tks"] = val + chunk["authors_sm_tks"] = val + } + + // position_int: convert from hex string to array format (grouped by 5) + if val, ok := chunk["position_int"].(string); ok { + chunk["position_int"] = utility.ConvertHexToPositionIntArray(val) + } else { + chunk["position_int"] = []interface{}{} + } + + // Convert page_num_int and top_int from hex string to array + for _, colName := range []string{"page_num_int", "top_int"} { + if val, ok := chunk[colName].(string); ok && val != "" { + chunk[colName] = utility.ConvertHexToIntArray(val) + } else { + chunk[colName] = []int{} + } + } + + // Post-process: convert 
nil/empty values to empty slices for array-like fields + // and split _kwd fields by "###" (except knowledge_graph_kwd, docnm_kwd, important_kwd, question_kwd) + kwdNoSplit := map[string]bool{ + "knowledge_graph_kwd": true, "docnm_kwd": true, + "important_kwd": true, "question_kwd": true, + } + arrayFields := []string{ + "doc_type_kwd", "important_kwd", "important_tks", "question_tks", + "question_kwd", "authors_tks", "authors_sm_tks", "title_tks", + "title_sm_tks", "content_ltks", "content_sm_ltks", + } + for _, colName := range arrayFields { + if val, ok := chunk[colName]; !ok || val == nil || val == "" { + chunk[colName] = []interface{}{} + } else if !kwdNoSplit[colName] { + // Split by "###" for _kwd fields + if strVal, ok := val.(string); ok && strings.Contains(strVal, "###") { + parts := strings.Split(strVal, "###") + var filtered []interface{} + for _, p := range parts { + if p != "" { + filtered = append(filtered, p) + } + } + chunk[colName] = filtered + } + } + } +} diff --git a/internal/engine/infinity/search.go b/internal/engine/infinity/search.go index 2f026ebd11..4b6641c7be 100644 --- a/internal/engine/infinity/search.go +++ b/internal/engine/infinity/search.go @@ -20,7 +20,7 @@ import ( "context" "fmt" "ragflow/internal/engine/types" - "strconv" + "ragflow/internal/utility" "strings" "unicode/utf8" @@ -458,18 +458,25 @@ func (e *infinityEngine) searchUnified(ctx context.Context, req *types.SearchReq } } + // DocIDs filters by doc_id (document ID) to find all chunks belonging to a document + // This is used by ChunkService.List() to list all chunks for a document if len(req.DocIDs) > 0 { if len(req.DocIDs) == 1 { - filterParts = append(filterParts, fmt.Sprintf("id = '%s'", req.DocIDs[0])) + filterParts = append(filterParts, fmt.Sprintf("doc_id = '%s'", req.DocIDs[0])) } else { docIDs := strings.Join(req.DocIDs, "', '") - filterParts = append(filterParts, fmt.Sprintf("id IN ('%s')", docIDs)) + filterParts = append(filterParts, fmt.Sprintf("doc_id IN 
('%s')", docIDs)) } } - if !isMetadataTable { - // Default filter for available chunks - filterParts = append(filterParts, "available_int=1") + // Only add available_int filter when there's text/vector match or AvailableInt is explicitly set + // This matches Python's behavior where chunk_list doesn't filter by available_int + if !isMetadataTable && (hasTextMatch || hasVectorMatch || req.AvailableInt != nil) { + if req.AvailableInt != nil { + filterParts = append(filterParts, fmt.Sprintf("available_int=%d", *req.AvailableInt)) + } else { + filterParts = append(filterParts, "available_int=1") + } } filterStr := strings.Join(filterParts, " AND ") @@ -637,13 +644,13 @@ func calculateScores(chunks []map[string]interface{}, scoreColumn, pagerankField for i := range chunks { score := 0.0 if scoreVal, ok := chunks[i][scoreColumn]; ok { - if f, ok := toFloat64(scoreVal); ok { + if f, ok := utility.ToFloat64(scoreVal); ok { score += f fmt.Printf("[DEBUG] chunk[%d]: %s=%f\n", i, scoreColumn, f) } } if pagerankVal, ok := chunks[i][pagerankField]; ok { - if f, ok := toFloat64(pagerankVal); ok { + if f, ok := utility.ToFloat64(pagerankVal); ok { score += f } } @@ -699,27 +706,6 @@ func getScore(chunk map[string]interface{}) float64 { return 0.0 } -func toFloat64(val interface{}) (float64, bool) { - switch v := val.(type) { - case float64: - return v, true - case float32: - return float64(v), true - case int: - return float64(v), true - case int64: - return float64(v), true - case string: - f, err := strconv.ParseFloat(v, 64) - if err != nil { - return 0, false - } - return f, true - default: - return 0, false - } -} - // executeTableSearch executes search on a single table func (e *infinityEngine) executeTableSearch(db *infinity.Database, tableName string, outputColumns []string, question string, vector []float64, filterStr string, topK, pageSize, offset int, orderBy *OrderByExpr, rankFeature map[string]float64, similarityThreshold float64, minMatch float64) 
(*types.SearchResponse, error) { // Debug logging @@ -937,6 +923,18 @@ func (e *infinityEngine) executeQuery(table *infinity.Table) (*types.SearchRespo chunks[i][colName] = []interface{}{} } } + // Convert position_int from hex string to array format + if posVal, ok := chunks[i]["position_int"].(string); ok { + chunks[i]["position_int"] = utility.ConvertHexToPositionIntArray(posVal) + } else { + chunks[i]["position_int"] = []interface{}{} + } + // Convert page_num_int and top_int from hex string to array + for _, colName := range []string{"page_num_int", "top_int"} { + if val, ok := chunks[i][colName].(string); ok { + chunks[i][colName] = utility.ConvertHexToIntArray(val) + } + } } return &types.SearchResponse{ diff --git a/internal/engine/types/types.go b/internal/engine/types/types.go index a7990f9c4c..5556774121 100644 --- a/internal/engine/types/types.go +++ b/internal/engine/types/types.go @@ -28,8 +28,9 @@ type SearchRequest struct { Keywords []string // Extracted keywords from question // Filters - KbIDs []string // Knowledge base IDs filter - DocIDs []string // Document IDs filter + KbIDs []string // Knowledge base IDs filter + DocIDs []string // Document IDs filter + AvailableInt *int // Available_int filter (1 = available, 0 = unavailable) // Pagination Page int // Page number (1-based) diff --git a/internal/handler/chunk.go b/internal/handler/chunk.go index 6b855ad4d1..233b8e1221 100644 --- a/internal/handler/chunk.go +++ b/internal/handler/chunk.go @@ -165,3 +165,84 @@ func (h *ChunkHandler) RetrievalTest(c *gin.Context) { "message": "success", }) } + +// Get retrieves a chunk by ID +func (h *ChunkHandler) Get(c *gin.Context) { + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) + return + } + + chunkID := c.Query("chunk_id") + if chunkID == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "chunk_id is required", + }) + return + } + + req := 
&service.GetChunkRequest{ + ChunkID: chunkID, + } + + resp, err := h.chunkService.Get(req, user.ID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "code": 500, + "message": err.Error(), + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "code": 0, + "data": resp.Chunk, + "message": "success", + }) +} + +// List retrieves chunks for a document +func (h *ChunkHandler) List(c *gin.Context) { + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) + return + } + + // Bind JSON request + var req service.ListChunksRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": err.Error(), + }) + return + } + + // Set default values for optional parameters + if req.Page == nil { + defaultPage := 1 + req.Page = &defaultPage + } + if req.Size == nil { + defaultSize := 30 + req.Size = &defaultSize + } + + resp, err := h.chunkService.List(&req, user.ID) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{ + "code": 500, + "message": err.Error(), + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "code": 0, + "data": resp, + "message": "success", + }) +} diff --git a/internal/router/router.go b/internal/router/router.go index 8e6226fb75..e7ddb8958c 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -200,6 +200,8 @@ func (r *Router) Setup(engine *gin.Engine) { chunk := authorized.Group("/v1/chunk") { chunk.POST("/retrieval_test", r.chunkHandler.RetrievalTest) + chunk.GET("/get", r.chunkHandler.Get) + chunk.POST("/list", r.chunkHandler.List) } // LLM routes diff --git a/internal/service/chunk.go b/internal/service/chunk.go index 89b06fa86a..9d227ccb86 100644 --- a/internal/service/chunk.go +++ b/internal/service/chunk.go @@ -20,7 +20,6 @@ import ( "context" "fmt" "ragflow/internal/server" - "strconv" "strings" "go.uber.org/zap" @@ -77,9 +76,10 @@ type RetrievalTestRequest struct { // 
RetrievalTestResponse retrieval test response type RetrievalTestResponse struct { - Chunks []map[string]interface{} `json:"chunks"` - Labels []map[string]interface{} `json:"labels"` - Total int64 `json:"total,omitempty"` + Chunks []map[string]interface{} `json:"chunks"` + DocAggs []map[string]interface{} `json:"doc_aggs"` + Labels *[]map[string]interface{} `json:"labels"` + Total int64 `json:"total,omitempty"` } // RetrievalTest performs retrieval test @@ -283,8 +283,8 @@ func (s *ChunkService) RetrievalTest(req *RetrievalTestRequest, userID string) ( // Perform reranking // Reference: rag/nlp/search.py L404-L429 - tkWeight := 1.0 - *req.VectorSimilarityWeight - vtWeight := *req.VectorSimilarityWeight + vtWeight := getVectorSimilarityWeight(req.VectorSimilarityWeight) + tkWeight := 1.0 - vtWeight useInfinity := s.engineType == server.EngineInfinity sim, term_similarity, vector_similarity := nlp.Rerank( @@ -312,10 +312,71 @@ func (s *ChunkService) RetrievalTest(req *RetrievalTestRequest, userID string) ( convertedChunks := buildRetrievalTestResults(filteredChunks) + // Build doc_aggs by aggregating chunks by docnm + docAggsMap := make(map[string]struct { + docID string + count int + }) + docNameOrder := []string{} // Track insertion order of doc names + for _, chunk := range filteredChunks { + docName := "" + docID := "" + if v, ok := chunk["docnm"].(string); ok { + docName = v + } + if v, ok := chunk["doc_id"].(string); ok { + docID = v + } + if docName == "" { + continue + } + if entry, exists := docAggsMap[docName]; exists { + entry.count++ + docAggsMap[docName] = entry + } else { + docAggsMap[docName] = struct { + docID string + count int + }{docID: docID, count: 1} + docNameOrder = append(docNameOrder, docName) + } + } + + // Convert to list maintaining insertion order + type docAggEntry struct { + docName string + docID string + count int + order int + } + docAggsList := make([]docAggEntry, 0, len(docAggsMap)) + for order, docName := range docNameOrder { + 
entry := docAggsMap[docName] + docAggsList = append(docAggsList, docAggEntry{docName: docName, docID: entry.docID, count: entry.count, order: order}) + } + // Sort by count descending, then by order ascending (for tie-breaking) + for i := 0; i < len(docAggsList)-1; i++ { + for j := i + 1; j < len(docAggsList); j++ { + if docAggsList[j].count > docAggsList[i].count || + (docAggsList[j].count == docAggsList[i].count && docAggsList[j].order < docAggsList[i].order) { + docAggsList[i], docAggsList[j] = docAggsList[j], docAggsList[i] + } + } + } + docAggs := make([]map[string]interface{}, 0, len(docAggsList)) + for _, entry := range docAggsList { + docAggs = append(docAggs, map[string]interface{}{ + "doc_name": entry.docName, + "doc_id": entry.docID, + "count": entry.count, + }) + } + return &RetrievalTestResponse{ - Chunks: convertedChunks, - Labels: []map[string]interface{}{}, // Empty labels for now - Total: int64(len(convertedChunks)), + Chunks: convertedChunks, + DocAggs: docAggs, + Labels: nil, + Total: int64(len(convertedChunks)), }, nil } @@ -457,11 +518,7 @@ func buildRetrievalTestResults(filteredChunks []map[string]interface{}) []map[st result["kb_id"] = v } if v, ok := chunk["position_int"]; ok { - if strVal, ok := v.(string); ok && strVal != "" { - result["positions"] = convertPositionInt(strVal) - } else { - result["positions"] = []interface{}{} - } + result["positions"] = v } if v, ok := chunk["doc_type_kwd"]; ok { result["doc_type_kwd"] = v @@ -490,42 +547,307 @@ func buildRetrievalTestResults(filteredChunks []map[string]interface{}) []map[st return results } -// convertPositionInt converts hex string format "00000001_0000005e_..." to array [[1, 94, ...], ...] 
-func convertPositionInt(hexStr string) []interface{} { - if hexStr == "" { - return []interface{}{} +// GetChunkRequest request for getting a chunk by ID +type GetChunkRequest struct { + ChunkID string `json:"chunk_id"` +} + +// GetChunkResponse response for getting a chunk +type GetChunkResponse struct { + Chunk map[string]interface{} `json:"chunk"` +} + +// Get retrieves a chunk by ID +func (s *ChunkService) Get(req *GetChunkRequest, userID string) (*GetChunkResponse, error) { + if s.docEngine == nil { + return nil, fmt.Errorf("doc engine not initialized") } - parts := strings.Split(hexStr, "_") - var intVals []int - for _, part := range parts { - if part == "" { - continue - } - // Parse hex string (without 0x prefix) - val, err := strconv.ParseInt(part, 16, 64) + if req.ChunkID == "" { + return nil, fmt.Errorf("chunk_id is required") + } + + ctx := context.Background() + + // Get user's tenants + tenants, err := s.userTenantDAO.GetByUserID(userID) + if err != nil { + return nil, fmt.Errorf("failed to get user tenants: %w", err) + } + if len(tenants) == 0 { + return nil, fmt.Errorf("user has no accessible tenants") + } + + // Try each tenant to find the chunk + var chunk map[string]interface{} + for _, tenant := range tenants { + // Get kbIDs for this tenant + kbIDs, err := s.kbDAO.GetKBIDsByTenantID(tenant.TenantID) if err != nil { continue } - intVals = append(intVals, int(val)) + + indexName := fmt.Sprintf("ragflow_%s", tenant.TenantID) + + doc, err := s.docEngine.GetChunk(ctx, indexName, req.ChunkID, kbIDs) + if err != nil { + continue + } + + if doc != nil { + chunk, ok := doc.(map[string]interface{}) + if ok { + // Format to match Python output + result := make(map[string]interface{}) + skipFields := map[string]bool{ + "id": true, "authors": true, "_score": true, "SCORE": true, + } + for k, v := range chunk { + if skipFields[k] || strings.HasSuffix(k, "_vec") || strings.Contains(k, "_sm_") || strings.HasSuffix(k, "_tks") || strings.HasSuffix(k, "_ltks") { 
+ continue + } + switch k { + case "content": + result["content_with_weight"] = v + case "docnm": + result["docnm_kwd"] = v + case "important_keywords": + utility.SetFieldArray(result, "important_kwd", v) + case "questions": + utility.SetFieldArray(result, "question_kwd", v) + case "entities_kwd", "entity_kwd", "entity_type_kwd", "from_entity_kwd", + "name_kwd", "raptor_kwd", "removed_kwd", "source_id", "tag_kwd", + "to_entity_kwd", "toc_kwd", "authors_tks", "doc_type_kwd": + if utility.IsEmpty(v) { + result[k] = []interface{}{} + } else { + result[k] = v + } + case "tag_feas": + if utility.IsEmpty(v) { + result[k] = map[string]interface{}{} + } else { + result[k] = v + } + case "create_timestamp_flt", "rank_flt", "weight_flt": + if floatVal, ok := utility.ToFloat64(v); ok { + result[k] = utility.JSONFloat64(floatVal) + } + default: + result[k] = v + } + } + return &GetChunkResponse{Chunk: result}, nil + } + } } - // Group by 5 elements - var result []interface{} - for i := 0; i < len(intVals); i += 5 { - end := i + 5 - if end > len(intVals) { - end = len(intVals) - } - group := make([]int, end-i) - copy(group, intVals[i:end]) - // Convert to interface{} for JSON serialization - groupIf := make([]interface{}, len(group)) - for j, v := range group { - groupIf[j] = v - } - result = append(result, groupIf) + if chunk == nil { + return nil, fmt.Errorf("chunk not found") } - return result + return &GetChunkResponse{Chunk: chunk}, nil +} + +// ListChunksRequest request for listing chunks +type ListChunksRequest struct { + DocID string `json:"doc_id" binding:"required"` + Page *int `json:"page,omitempty"` + Size *int `json:"size,omitempty"` + Keywords string `json:"keywords,omitempty"` + AvailableInt *int `json:"available_int,omitempty"` +} + +// ListChunksResponse response for listing chunks +type ListChunksResponse struct { + Chunks []map[string]interface{} `json:"chunks"` + Doc map[string]interface{} `json:"doc"` + Total int64 `json:"total"` +} + +// List retrieves 
chunks for a document
+//
+// The document must belong to a knowledge base owned by one of the user's
+// tenants. The search runs against that tenant's index, restricted to the
+// document, with req.Keywords as the query and req.Page/req.Size for paging.
+// The response carries the formatted chunks, the document's metadata and the
+// total reported by the engine.
+func (s *ChunkService) List(req *ListChunksRequest, userID string) (*ListChunksResponse, error) {
+	if s.docEngine == nil {
+		return nil, fmt.Errorf("doc engine not initialized")
+	}
+
+	if req.DocID == "" {
+		return nil, fmt.Errorf("doc_id is required")
+	}
+
+	ctx := context.Background()
+
+	// Get user's tenants
+	tenants, err := s.userTenantDAO.GetByUserID(userID)
+	if err != nil {
+		return nil, fmt.Errorf("failed to get user tenants: %w", err)
+	}
+	if len(tenants) == 0 {
+		return nil, fmt.Errorf("user has no accessible tenants")
+	}
+
+	// Get document to find its tenant
+	docDAO := dao.NewDocumentDAO()
+	doc, err := docDAO.GetByID(req.DocID)
+	if err != nil || doc == nil {
+		return nil, fmt.Errorf("document not found")
+	}
+
+	// Get knowledge base to find tenant
+	kb, err := s.kbDAO.GetByID(doc.KbID)
+	if err != nil || kb == nil {
+		return nil, fmt.Errorf("knowledge base not found")
+	}
+
+	// Find which tenant this document belongs to
+	var targetTenantID string
+	for _, tenant := range tenants {
+		if tenant.TenantID == kb.TenantID {
+			targetTenantID = tenant.TenantID
+			break
+		}
+	}
+	if targetTenantID == "" {
+		return nil, fmt.Errorf("user does not have access to this document")
+	}
+
+	// Get kbIDs for this tenant
+	kbIDs, err := s.kbDAO.GetKBIDsByTenantID(targetTenantID)
+	if err != nil {
+		return nil, fmt.Errorf("failed to get kb ids: %w", err)
+	}
+
+	indexName := fmt.Sprintf("ragflow_%s", targetTenantID)
+
+	page := getPageNum(req.Page)
+	size := getPageSize(req.Size)
+	keywords := req.Keywords
+
+	// Build search request - same as retrieval test but filtered by doc_id.
+	// AvailableInt is passed through as-is: a nil pointer means "no filter".
+	searchReq := &engine.SearchRequest{
+		IndexNames:   []string{indexName},
+		Question:     keywords,
+		KbIDs:        kbIDs,
+		DocIDs:       []string{req.DocID},
+		Page:         page,
+		Size:         size,
+		TopK:         size,
+		AvailableInt: req.AvailableInt,
+	}
+
+	// Execute search through unified engine interface
+	result, err := s.docEngine.Search(ctx, searchReq)
+	if err != nil {
+		return nil, fmt.Errorf("search failed: %w", err)
+	}
+
+	// Convert result to unified response
+	searchResp, ok := result.(*engine.SearchResponse)
+	if !ok {
+		return nil, fmt.Errorf("invalid search response type")
+	}
+
+	// Fields that are never exposed to the caller. Built once, outside the
+	// loop, because the set does not depend on the chunk being formatted.
+	skipFields := map[string]bool{
+		"_id": true, "authors": true, "_score": true, "SCORE": true,
+		"important_kwd_empty_count": true, "kb_id": true, "mom_id": true, "page_num_int": true,
+	}
+
+	// Format output to match Python
+	chunks := make([]map[string]interface{}, 0, len(searchResp.Chunks))
+	for _, chunk := range searchResp.Chunks {
+		// Inline formatChunkForList
+		formatted := make(map[string]interface{})
+		for k, v := range chunk {
+			if skipFields[k] || strings.HasSuffix(k, "_vec") || strings.Contains(k, "_sm_") || strings.HasSuffix(k, "_ltks") || strings.HasSuffix(k, "_tks") {
+				continue
+			}
+			switch k {
+			case "img_id":
+				if strVal, ok := v.(string); ok {
+					formatted["image_id"] = strVal
+				} else {
+					formatted["image_id"] = ""
+				}
+			case "position_int":
+				formatted["positions"] = v
+			case "id":
+				formatted["chunk_id"] = v
+			case "content":
+				formatted["content_with_weight"] = v
+			case "docnm":
+				formatted["docnm_kwd"] = v
+			case "important_keywords":
+				utility.SetFieldArray(formatted, "important_kwd", v)
+			case "questions":
+				utility.SetFieldArray(formatted, "question_kwd", v)
+			case "entities_kwd", "entity_kwd", "entity_type_kwd", "from_entity_kwd",
+				"name_kwd", "raptor_kwd", "removed_kwd",
+				"source_id", "tag_kwd", "to_entity_kwd", "toc_kwd", "doc_type_kwd":
+				if utility.IsEmpty(v) {
+					formatted[k] = []interface{}{}
+				} else {
+					formatted[k] = v
+				}
+			default:
+				// Handle _kwd fields that need "###" splitting
+				strVal, isStr := v.(string)
+				if strings.HasSuffix(k, "_kwd") && k != "knowledge_graph_kwd" && isStr && strings.Contains(strVal, "###") {
+					parts := strings.Split(strVal, "###")
+					// Start from a non-nil slice so that a value made only of
+					// separators is serialized as [] rather than null.
+					filtered := make([]interface{}, 0, len(parts))
+					for _, p := range parts {
+						if p != "" {
+							filtered = append(filtered, p)
+						}
+					}
+					formatted[k] = filtered
+				} else {
+					formatted[k] = v
+				}
+			}
+		}
+		chunks = append(chunks, formatted)
+	}
+
+	// Build document info (matching Python doc.to_dict())
+	timeFormat := "2006-01-02T15:04:05"
+	docInfo := map[string]interface{}{
+		"id":               doc.ID,
+		"thumbnail":        doc.Thumbnail,
+		"kb_id":            doc.KbID,
+		"parser_id":        doc.ParserID,
+		"pipeline_id":      doc.PipelineID,
+		"parser_config":    doc.ParserConfig,
+		"source_type":      doc.SourceType,
+		"type":             doc.Type,
+		"created_by":       doc.CreatedBy,
+		"name":             doc.Name,
+		"location":         doc.Location,
+		"size":             doc.Size,
+		"token_num":        doc.TokenNum,
+		"chunk_num":        doc.ChunkNum,
+		"progress":         utility.JSONFloat64(doc.Progress),
+		"progress_msg":     doc.ProgressMsg,
+		"process_begin_at": utility.FormatTimeToString(doc.ProcessBeginAt, timeFormat),
+		"process_duration": doc.ProcessDuration,
+		"content_hash":     doc.ContentHash,
+		"suffix":           doc.Suffix,
+		"run":              doc.Run,
+		"status":           doc.Status,
+		"create_time":      doc.CreateTime,
+		"create_date":      utility.FormatTimeToString(doc.CreateDate, timeFormat),
+		"update_time":      doc.UpdateTime,
+		"update_date":      utility.FormatTimeToString(doc.UpdateDate, timeFormat),
+	}
+
+	return &ListChunksResponse{
+		Total:  searchResp.Total,
+		Chunks: chunks,
+		Doc:    docInfo,
+	}, nil
 }
diff --git a/internal/utility/convert.go b/internal/utility/convert.go
index ae6a6e591f..87c52dc35f 100644
--- a/internal/utility/convert.go
+++ b/internal/utility/convert.go
@@ -19,9 +19,32 @@ package utility
 import (
 	"fmt"
 	"os"
+	"strconv"
+	"strings"
 	"time"
 )
 
+// JSONFloat64 is a float64 that always marshals to JSON with a decimal point
+// (e.g. 0.0 instead of 0) while keeping the value's full precision.
+type JSONFloat64 float64
+
+// MarshalJSON implements json.Marshaler.
+//
+// Whole numbers get a trailing ".0"; NaN and infinities, which JSON cannot
+// represent, produce an error.
+func (f JSONFloat64) MarshalJSON() ([]byte, error) {
+	// Precision -1 selects the shortest decimal text that round-trips, so 0.75
+	// is written as 0.75 rather than being rounded to a single decimal place.
+	s := strconv.FormatFloat(float64(f), 'f', -1, 64)
+	if s == "NaN" || strings.HasSuffix(s, "Inf") {
+		return nil, fmt.Errorf("json: unsupported value: %s", s)
+	}
+	if !strings.Contains(s, ".") {
+		s += ".0"
+	}
+	return []byte(s), nil
+}
+
 // GetProjectBaseDirectory returns the current working directory.
 // If an error occurs while getting the current directory, it returns ".".
//
@@ -87,3 +97,140 @@ func FormatTime(t time.Time) string {
 	}
 	return t.Format("2006-01-02 15:04:05")
 }
+
+// FormatTimeToString formats t using the given layout.
+//
+// A nil t yields an untyped nil, which serializes as JSON null; otherwise the
+// formatted string is returned.
+func FormatTimeToString(t *time.Time, format string) interface{} {
+	if t == nil {
+		return nil
+	}
+	return t.Format(format)
+}
+
+// parseUnderscoreHexInts splits hexStr on "_" and parses each non-empty part
+// as a base-16 integer. Parts that fail to parse are skipped.
+func parseUnderscoreHexInts(hexStr string) []int {
+	var vals []int
+	for _, part := range strings.Split(hexStr, "_") {
+		if part == "" {
+			continue
+		}
+		val, err := strconv.ParseInt(part, 16, 64)
+		if err != nil {
+			continue
+		}
+		vals = append(vals, int(val))
+	}
+	return vals
+}
+
+// ConvertHexToPositionIntArray converts a "_"-separated hex string into
+// position groups of up to 5 ints each.
+//
+// It returns nil when no value can be parsed, otherwise a [][]int whose last
+// group may hold fewer than 5 values.
+func ConvertHexToPositionIntArray(hexStr string) interface{} {
+	intVals := parseUnderscoreHexInts(hexStr)
+	if len(intVals) == 0 {
+		return nil
+	}
+
+	const groupSize = 5
+	result := make([][]int, 0, (len(intVals)+groupSize-1)/groupSize)
+	for i := 0; i < len(intVals); i += groupSize {
+		end := i + groupSize
+		if end > len(intVals) {
+			end = len(intVals)
+		}
+		// Cap each group's capacity so that appending to one group can never
+		// overwrite the next group in the shared backing array.
+		result = append(result, intVals[i:end:end])
+	}
+	return result
+}
+
+// ConvertHexToIntArray converts a "_"-separated hex string into a flat []int.
+//
+// It returns nil when no value can be parsed.
+func ConvertHexToIntArray(hexStr string) interface{} {
+	vals := parseUnderscoreHexInts(hexStr)
+	if len(vals) == 0 {
+		// Return an untyped nil rather than a nil []int wrapped in an interface.
+		return nil
+	}
+	return vals
+}
+
+// IsEmpty reports whether v is nil, an empty string, or an empty
+// []interface{}, []string or []int. Any other value is treated as non-empty.
+func IsEmpty(v interface{}) bool {
+	switch val := v.(type) {
+	case nil:
+		return true
+	case []interface{}:
+		return len(val) == 0
+	case []string:
+		return len(val) == 0
+	case []int:
+		return len(val) == 0
+	case string:
+		return val == ""
+	default:
+		return false
+	}
+}
+
+// SetFieldArray stores v under destKey in result, substituting an empty
+// []interface{} when v is empty according to IsEmpty.
+func SetFieldArray(result map[string]interface{}, destKey string, v interface{}) {
+	if IsEmpty(v) {
+		result[destKey] = []interface{}{}
+		return
+	}
+	result[destKey] = v
+}
+
+// ToFloat64 converts a numeric value or a numeric string to float64.
+//
+// The boolean result is false when val has an unsupported type or is a string
+// that strconv.ParseFloat cannot parse.
+func ToFloat64(val interface{}) (float64, bool) {
+	switch v := val.(type) {
+	case float64:
+		return v, true
+	case float32:
+		return float64(v), true
+	case JSONFloat64:
+		return float64(v), true
+	case int:
+		return float64(v), true
+	case int8:
+		return float64(v), true
+	case int16:
+		return float64(v), true
+	case int32:
+		return float64(v), true
+	case int64:
+		return float64(v), true
+	case uint:
+		return float64(v), true
+	case uint8:
+		return float64(v), true
+	case uint16:
+		return float64(v), true
+	case uint32:
+		return float64(v), true
+	case uint64:
+		return float64(v), true
+	case string:
+		f, err := strconv.ParseFloat(v, 64)
+		if err != nil {
+			return 0, false
+		}
+		return f, true
+	default:
+		return 0, false
+	}
+}