Implement GetChunk() in Infinity in GO (#13758)

### What problem does this PR solve?

Implement GetChunk() for the Infinity engine in Go.

Add CLI commands:
GET CHUNK 'XXX';
LIST CHUNKS OF DOCUMENT 'XXX';

### Type of change

- [x] Refactoring
This commit is contained in:
qinling0210
2026-03-24 20:10:21 +08:00
committed by GitHub
parent b308cd3a02
commit 7c8927c4fb
11 changed files with 989 additions and 75 deletions

View File

@@ -91,6 +91,8 @@ sql_command: login_user
| parse_dataset_async
| import_docs_into_dataset
| search_on_datasets
| get_chunk
| list_chunks
| create_chat_session
| drop_chat_session
| list_chat_sessions
@@ -164,6 +166,7 @@ DEFAULT: "DEFAULT"i
CHATS: "CHATS"i
CHAT: "CHAT"i
FILES: "FILES"i
DOCUMENT: "DOCUMENT"i
DOCUMENTS: "DOCUMENTS"i
METADATA: "METADATA"i
SUMMARY: "SUMMARY"i
@@ -194,6 +197,13 @@ FINGERPRINT: "FINGERPRINT"i
LICENSE: "LICENSE"i
CHECK: "CHECK"i
CONFIG: "CONFIG"i
CHUNK: "CHUNK"i
CHUNKS: "CHUNKS"i
GET: "GET"i
PAGE: "PAGE"i
SIZE: "SIZE"i
KEYWORDS: "KEYWORDS"i
AVAILABLE: "AVAILABLE"i
login_user: LOGIN USER quoted_string ";"
list_services: LIST SERVICES ";"
@@ -321,6 +331,8 @@ list_user_model_providers: LIST MODEL PROVIDERS ";"
list_user_default_models: LIST DEFAULT MODELS ";"
import_docs_into_dataset: IMPORT quoted_string INTO DATASET quoted_string ";"
search_on_datasets: SEARCH quoted_string ON DATASETS quoted_string ";"
get_chunk: GET CHUNK quoted_string ";"
// Option keywords use the named, case-insensitive terminals so they are kept in the parse tree for the transformer.
list_chunks: LIST CHUNKS OF DOCUMENT quoted_string (PAGE NUMBER)? (SIZE NUMBER)? (KEYWORDS quoted_string)? (AVAILABLE NUMBER)? ";"
parse_dataset_docs: PARSE quoted_string OF DATASET quoted_string ";"
parse_dataset_sync: PARSE DATASET quoted_string SYNC ";"
@@ -698,6 +710,28 @@ class RAGFlowCLITransformer(Transformer):
datasets = datasets.split(" ")
return {"type": "search_on_datasets", "datasets": datasets, "question": question}
def get_chunk(self, items):
    """Build the command dict for ``GET CHUNK '<chunk_id>'``.

    ``items[2]`` is the quoted-string node; its first child is the raw
    literal, from which the surrounding quote characters are stripped.
    """
    quoted_literal = items[2].children[0]
    command = {"type": "get_chunk"}
    command["chunk_id"] = quoted_literal.strip("'\"")
    return command
def list_chunks(self, items):
    """Build the command dict for ``LIST CHUNKS OF DOCUMENT '<doc_id>' [options]``.

    ``items[4]`` is the quoted document ID. Any of the option keywords
    PAGE, SIZE, KEYWORDS and AVAILABLE found among ``items`` is followed by
    its value, which is stored under ``page``, ``size``, ``keywords`` and
    ``available_int`` respectively. Options that are absent are simply not
    added to the result.

    Keyword text is matched case-insensitively, in line with the
    case-insensitive keyword terminals declared in the grammar.
    """
    doc_id = items[4].children[0].strip("'\"")
    result = {"type": "list_chunks", "doc_id": doc_id}

    def _unquote(node):
        # Quoted strings arrive as a node whose first child is the raw literal.
        return node.children[0].strip("'\"")

    # Option keyword -> (result key, converter applied to the item that follows it).
    option_readers = {
        "PAGE": ("page", int),
        "SIZE": ("size", int),
        "KEYWORDS": ("keywords", _unquote),
        "AVAILABLE": ("available_int", int),
    }

    # Stop one short of the end so items[i + 1] always exists.
    for i, item in enumerate(items[:-1]):
        reader = option_readers.get(str(item).upper())
        if reader is None:
            continue
        key, convert = reader
        result[key] = convert(items[i + 1])
    return result
def benchmark(self, items):
concurrency: int = int(items[1])
iterations: int = int(items[2])

View File

@@ -1434,6 +1434,61 @@ class RAGFlowClient:
print(
f"Fail to search datasets: {dataset_names}, code: {res_json['code']}, message: {res_json['message']}")
def get_chunk(self, command_dict):
    """Fetch one chunk by ID from the server and print its fields.

    Only allowed when ``self.server_type`` is ``"user"``; otherwise a notice
    is printed and no request is made. On success (HTTP 200 and response
    ``code`` 0) the ``data`` object is printed as key/value lines; any other
    outcome prints the ``code`` and ``message`` from the response body.

    :param command_dict: parsed command; ``command_dict["chunk_id"]`` is the ID to fetch.
    """
    from urllib.parse import quote

    if self.server_type != "user":
        print("This command is only allowed in USER mode")
        return

    # Percent-encode the ID: it is embedded in the query string, where characters
    # such as '&', '#', '+', '=' or spaces would otherwise truncate or alter it.
    chunk_id = quote(str(command_dict["chunk_id"]), safe="")
    response = self.http_client.request("GET", f"/chunk/get?chunk_id={chunk_id}", use_api_base=False,
                                        auth_kind="web")
    res_json = response.json()
    if response.status_code == 200 and res_json["code"] == 0:
        self._print_key_value(res_json["data"])
        return
    print(f"Fail to get chunk, code: {res_json['code']}, message: {res_json['message']}")
def list_chunks(self, command_dict):
    """List the chunks of one document and print them.

    Only allowed when ``self.server_type`` is ``"user"``. ``doc_id`` is always
    sent; ``page``, ``size`` and ``available_int`` are forwarded only when
    present in ``command_dict``, and ``keywords`` only when present and
    non-empty. Each returned chunk is printed as a numbered section of
    ``key: value`` lines.
    """
    if self.server_type != "user":
        print("This command is only allowed in USER mode")
        return

    payload = {"doc_id": command_dict["doc_id"]}
    # Forward optional settings only when the command supplied them.
    for option in ("page", "size"):
        if option in command_dict:
            payload[option] = command_dict[option]
    if command_dict.get("keywords"):
        payload["keywords"] = command_dict["keywords"]
    if "available_int" in command_dict:
        payload["available_int"] = command_dict["available_int"]

    response = self.http_client.request("POST", "/chunk/list", json_body=payload, use_api_base=False,
                                        auth_kind="web")
    res_json = response.json()

    succeeded = response.status_code == 200 and res_json["code"] == 0
    if not succeeded:
        print(f"Fail to list chunks, code: {res_json['code']}, message: {res_json['message']}")
        return

    chunks = res_json["data"]["chunks"]
    if not chunks:
        print("No chunks found")
        return

    for number, chunk in enumerate(chunks, start=1):
        print(f"\n--- Chunk {number} ---")
        for key, value in chunk.items():
            print(f" {key}: {value}")
def show_version(self, command):
if self.server_type == "admin":
response = self.http_client.request("GET", "/admin/version", use_api_base=True, auth_kind="admin")
@@ -1618,6 +1673,14 @@ class RAGFlowClient:
print(separator)
def _print_key_value(self, data: dict):
"""Print data as key-value pairs (one per line)"""
if not data:
print("No data to print")
return
for key, value in data.items():
print(f"{key}: {value}")
def run_command(client: RAGFlowClient, command_dict: dict):
command_type = command_dict["type"]
@@ -1761,6 +1824,10 @@ def run_command(client: RAGFlowClient, command_dict: dict):
client.import_docs_into_dataset(command_dict)
case "search_on_datasets":
return client.search_on_datasets(command_dict)
case "get_chunk":
return client.get_chunk(command_dict)
case "list_chunks":
return client.list_chunks(command_dict)
case "meta":
_handle_meta_command(command_dict)
case _:
@@ -1818,6 +1885,8 @@ LIST DOCUMENTS OF DATASET <dataset>
SEARCH <query> ON DATASETS <dataset>
LIST METADATA OF DATASETS <dataset>[, <dataset>]*
LIST METADATA SUMMARY OF DATASET <dataset> DOCUMENTS <doc_id>[, <doc_id>]*
GET CHUNK <chunk_id>
LIST CHUNKS OF DOCUMENT <doc_id> [PAGE <page>] [SIZE <size>] [KEYWORDS <keywords>] [AVAILABLE <0|1>]
Meta Commands:
\\?, \\h, \\help Show this help

View File

@@ -0,0 +1,56 @@
//
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
package elasticsearch
import (
"context"
"fmt"
)
// GetChunk looks up a single chunk by ID in indexName.
//
// It runs a term query on the "id" field limited to one hit and returns that
// hit's source. A nil result together with a nil error means nothing matched.
// kbIDs is not consulted by this implementation.
func (e *elasticsearchEngine) GetChunk(ctx context.Context, indexName, chunkID string, kbIDs []string) (interface{}, error) {
	idTerm := map[string]interface{}{"id": chunkID}
	lookup := &SearchRequest{
		IndexNames: []string{indexName},
		Query:      map[string]interface{}{"term": idTerm},
		Size:       1,
		From:       0,
	}

	raw, err := e.Search(ctx, lookup)
	if err != nil {
		return nil, fmt.Errorf("failed to search: %w", err)
	}

	switch resp := raw.(type) {
	case *SearchResponse:
		if hits := resp.Hits.Hits; len(hits) > 0 {
			return hits[0].Source, nil
		}
		return nil, nil
	default:
		return nil, fmt.Errorf("invalid search response type")
	}
}

View File

@@ -49,9 +49,11 @@ type DocEngine interface {
// Document operations
IndexDocument(ctx context.Context, indexName, docID string, doc interface{}) error
BulkIndex(ctx context.Context, indexName string, docs []interface{}) (interface{}, error)
GetDocument(ctx context.Context, indexName, docID string) (interface{}, error)
DeleteDocument(ctx context.Context, indexName, docID string) error
// Chunk operations
GetChunk(ctx context.Context, indexName, chunkID string, kbIDs []string) (interface{}, error)
// Health check
Ping(ctx context.Context) error
Close() error

View File

@@ -0,0 +1,219 @@
//
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
package infinity
import (
"context"
"fmt"
"strings"
infinity "github.com/infiniflow/infinity-go-sdk"
"ragflow/internal/logger"
"ragflow/internal/utility"
"go.uber.org/zap"
)
// GetChunk gets a chunk by ID.
//
// When tableName starts with "ragflow_doc_meta_" only that table is queried.
// Otherwise every "<tableName>_<kbID>" table is queried, followed by the bare
// tableName. Tables that cannot be opened or queried are skipped, so the
// lookup is best effort across tables. Rows found in several tables are merged
// per field, keeping the first non-empty value. The merged chunk is passed
// through getFields before being returned.
//
// It returns (nil, nil) when no table holds the chunk, and an error when the
// client is not initialized, the database cannot be opened, or chunkID
// contains characters that are unsafe inside a quoted filter literal.
func (e *infinityEngine) GetChunk(ctx context.Context, tableName, chunkID string, kbIDs []string) (interface{}, error) {
	if e.client == nil || e.client.conn == nil {
		return nil, fmt.Errorf("Infinity client not initialized")
	}

	// chunkID is spliced into a single-quoted filter literal below. Reject
	// quotes and backslashes so a crafted ID cannot break out of the literal.
	if strings.Contains(chunkID, "'") || strings.Contains(chunkID, "\\") {
		return nil, fmt.Errorf("invalid chunk id %q", chunkID)
	}

	// Build list of table names to search
	var tableNames []string
	if strings.HasPrefix(tableName, "ragflow_doc_meta_") {
		tableNames = []string{tableName}
	} else {
		// Search in tables like <tableName>_<kb_id> for each kbID, then the base tableName
		tableNames = make([]string, 0, len(kbIDs)+1)
		for _, kbID := range kbIDs {
			tableNames = append(tableNames, fmt.Sprintf("%s_%s", tableName, kbID))
		}
		tableNames = append(tableNames, tableName)
	}

	db, err := e.client.conn.GetDatabase(e.client.dbName)
	if err != nil {
		return nil, fmt.Errorf("failed to get database: %w", err)
	}

	filter := fmt.Sprintf("id = '%s'", chunkID)

	// Collect chunks from all tables (same as Python's concat_dataframes)
	allChunks := make(map[string]map[string]interface{})
	for _, tblName := range tableNames {
		// A candidate table that cannot be opened or queried is skipped rather
		// than failing the whole lookup.
		table, err := db.GetTable(tblName)
		if err != nil {
			continue
		}

		result, err := table.Output([]string{"*"}).Filter(filter).ToResult()
		if err != nil {
			continue
		}

		qr, ok := result.(*infinity.QueryResult)
		if !ok || len(qr.Data) == 0 {
			continue
		}

		// Convert the column-oriented result into one map per row
		chunks := make([]map[string]interface{}, 0)
		for colName, colData := range qr.Data {
			for i, val := range colData {
				for len(chunks) <= i {
					chunks = append(chunks, make(map[string]interface{}))
				}
				chunks[i][colName] = val
			}
		}

		// Merge rows into allChunks (by id), keeping the first non-empty value
		for _, chunk := range chunks {
			idVal, ok := chunk["id"].(string)
			if !ok {
				continue
			}
			existing, exists := allChunks[idVal]
			if !exists {
				allChunks[idVal] = chunk
				continue
			}
			for k, v := range chunk {
				// Only fill fields that are still missing or empty; a value
				// already found in an earlier table must not be overwritten.
				if cur, has := existing[k]; !has || utility.IsEmpty(cur) {
					existing[k] = v
				}
			}
		}
	}

	// Get the chunk by chunkID
	chunk, found := allChunks[chunkID]
	if !found {
		return nil, nil
	}

	getFields(chunk)

	logger.Debug("infinity get chunk", zap.String("chunkID", chunkID), zap.Any("tables", tableNames))

	return chunk, nil
}
// getFields applies field mappings to a chunk in place, similar to Python's get_fields function.
//
// It derives the *_kwd / *_tks / *_ltks aliases from the stored columns,
// decodes the hex-encoded position columns into integer arrays, and replaces
// missing or empty array-like fields with empty slices so they never end up as
// nil or "".
func getFields(chunk map[string]interface{}) {
	// docnm -> docnm_kwd, title_tks, title_sm_tks
	if val, ok := chunk["docnm"].(string); ok {
		chunk["docnm_kwd"] = val
		chunk["title_tks"] = val
		chunk["title_sm_tks"] = val
	}

	// important_keywords -> important_kwd (split by comma), important_tks
	if val, ok := chunk["important_keywords"].(string); ok {
		if val == "" {
			chunk["important_kwd"] = []interface{}{}
		} else {
			chunk["important_kwd"] = strings.Split(val, ",")
		}
		chunk["important_tks"] = val
	} else {
		chunk["important_kwd"] = []interface{}{}
		chunk["important_tks"] = []interface{}{}
	}

	// questions -> question_kwd (split by newline), question_tks
	if val, ok := chunk["questions"].(string); ok {
		if val == "" {
			chunk["question_kwd"] = []interface{}{}
		} else {
			chunk["question_kwd"] = strings.Split(val, "\n")
		}
		chunk["question_tks"] = val
	} else {
		chunk["question_kwd"] = []interface{}{}
		chunk["question_tks"] = []interface{}{}
	}

	// content -> content_with_weight, content_ltks, content_sm_ltks
	if val, ok := chunk["content"].(string); ok {
		chunk["content_with_weight"] = val
		chunk["content_ltks"] = val
		chunk["content_sm_ltks"] = val
	}

	// authors -> authors_tks, authors_sm_tks
	if val, ok := chunk["authors"].(string); ok {
		chunk["authors_tks"] = val
		chunk["authors_sm_tks"] = val
	}

	// position_int: convert from hex string to array format (grouped by 5).
	// The converter yields nil for an empty or unparsable string; fall back to
	// an empty array so the field has the same shape as when the column is absent.
	var positions interface{} = []interface{}{}
	if val, ok := chunk["position_int"].(string); ok {
		if converted := utility.ConvertHexToPositionIntArray(val); converted != nil {
			positions = converted
		}
	}
	chunk["position_int"] = positions

	// Convert page_num_int and top_int from hex string to array, with the same
	// empty-array fallback when nothing could be decoded.
	for _, colName := range []string{"page_num_int", "top_int"} {
		var ints interface{} = []int{}
		if val, ok := chunk[colName].(string); ok && val != "" {
			if converted := utility.ConvertHexToIntArray(val); converted != nil {
				ints = converted
			}
		}
		chunk[colName] = ints
	}

	// Post-process: convert nil/empty values to empty slices for array-like fields
	// and split _kwd fields by "###" (except knowledge_graph_kwd, docnm_kwd, important_kwd, question_kwd)
	kwdNoSplit := map[string]bool{
		"knowledge_graph_kwd": true, "docnm_kwd": true,
		"important_kwd": true, "question_kwd": true,
	}
	arrayFields := []string{
		"doc_type_kwd", "important_kwd", "important_tks", "question_tks",
		"question_kwd", "authors_tks", "authors_sm_tks", "title_tks",
		"title_sm_tks", "content_ltks", "content_sm_ltks",
	}
	for _, colName := range arrayFields {
		val, ok := chunk[colName]
		if !ok || val == nil || val == "" {
			chunk[colName] = []interface{}{}
			continue
		}
		// "###" is only a separator in keyword columns. Token and content
		// columns are free text that may legitimately contain "###" (for
		// example a markdown heading) and must be left untouched.
		if !strings.HasSuffix(colName, "_kwd") || kwdNoSplit[colName] {
			continue
		}
		strVal, isStr := val.(string)
		if !isStr || !strings.Contains(strVal, "###") {
			continue
		}
		var filtered []interface{}
		for _, p := range strings.Split(strVal, "###") {
			if p != "" {
				filtered = append(filtered, p)
			}
		}
		chunk[colName] = filtered
	}
}

View File

@@ -20,7 +20,7 @@ import (
"context"
"fmt"
"ragflow/internal/engine/types"
"strconv"
"ragflow/internal/utility"
"strings"
"unicode/utf8"
@@ -458,18 +458,25 @@ func (e *infinityEngine) searchUnified(ctx context.Context, req *types.SearchReq
}
}
// DocIDs filters by doc_id (document ID) to find all chunks belonging to a document
// This is used by ChunkService.List() to list all chunks for a document
if len(req.DocIDs) > 0 {
if len(req.DocIDs) == 1 {
filterParts = append(filterParts, fmt.Sprintf("id = '%s'", req.DocIDs[0]))
filterParts = append(filterParts, fmt.Sprintf("doc_id = '%s'", req.DocIDs[0]))
} else {
docIDs := strings.Join(req.DocIDs, "', '")
filterParts = append(filterParts, fmt.Sprintf("id IN ('%s')", docIDs))
filterParts = append(filterParts, fmt.Sprintf("doc_id IN ('%s')", docIDs))
}
}
if !isMetadataTable {
// Default filter for available chunks
filterParts = append(filterParts, "available_int=1")
// Only add available_int filter when there's text/vector match or AvailableInt is explicitly set
// This matches Python's behavior where chunk_list doesn't filter by available_int
if !isMetadataTable && (hasTextMatch || hasVectorMatch || req.AvailableInt != nil) {
if req.AvailableInt != nil {
filterParts = append(filterParts, fmt.Sprintf("available_int=%d", *req.AvailableInt))
} else {
filterParts = append(filterParts, "available_int=1")
}
}
filterStr := strings.Join(filterParts, " AND ")
@@ -637,13 +644,13 @@ func calculateScores(chunks []map[string]interface{}, scoreColumn, pagerankField
for i := range chunks {
score := 0.0
if scoreVal, ok := chunks[i][scoreColumn]; ok {
if f, ok := toFloat64(scoreVal); ok {
if f, ok := utility.ToFloat64(scoreVal); ok {
score += f
fmt.Printf("[DEBUG] chunk[%d]: %s=%f\n", i, scoreColumn, f)
}
}
if pagerankVal, ok := chunks[i][pagerankField]; ok {
if f, ok := toFloat64(pagerankVal); ok {
if f, ok := utility.ToFloat64(pagerankVal); ok {
score += f
}
}
@@ -699,27 +706,6 @@ func getScore(chunk map[string]interface{}) float64 {
return 0.0
}
func toFloat64(val interface{}) (float64, bool) {
switch v := val.(type) {
case float64:
return v, true
case float32:
return float64(v), true
case int:
return float64(v), true
case int64:
return float64(v), true
case string:
f, err := strconv.ParseFloat(v, 64)
if err != nil {
return 0, false
}
return f, true
default:
return 0, false
}
}
// executeTableSearch executes search on a single table
func (e *infinityEngine) executeTableSearch(db *infinity.Database, tableName string, outputColumns []string, question string, vector []float64, filterStr string, topK, pageSize, offset int, orderBy *OrderByExpr, rankFeature map[string]float64, similarityThreshold float64, minMatch float64) (*types.SearchResponse, error) {
// Debug logging
@@ -937,6 +923,18 @@ func (e *infinityEngine) executeQuery(table *infinity.Table) (*types.SearchRespo
chunks[i][colName] = []interface{}{}
}
}
// Convert position_int from hex string to array format
if posVal, ok := chunks[i]["position_int"].(string); ok {
chunks[i]["position_int"] = utility.ConvertHexToPositionIntArray(posVal)
} else {
chunks[i]["position_int"] = []interface{}{}
}
// Convert page_num_int and top_int from hex string to array
for _, colName := range []string{"page_num_int", "top_int"} {
if val, ok := chunks[i][colName].(string); ok {
chunks[i][colName] = utility.ConvertHexToIntArray(val)
}
}
}
return &types.SearchResponse{

View File

@@ -28,8 +28,9 @@ type SearchRequest struct {
Keywords []string // Extracted keywords from question
// Filters
KbIDs []string // Knowledge base IDs filter
DocIDs []string // Document IDs filter
KbIDs []string // Knowledge base IDs filter
DocIDs []string // Document IDs filter
AvailableInt *int // Available_int filter (1 = available, 0 = unavailable)
// Pagination
Page int // Page number (1-based)

View File

@@ -165,3 +165,84 @@ func (h *ChunkHandler) RetrievalTest(c *gin.Context) {
"message": "success",
})
}
// Get retrieves a chunk by ID.
//
// The chunk ID is read from the "chunk_id" query parameter. The handler
// answers 400 when it is empty, 500 when the service returns an error, and
// otherwise 200 with the chunk under "data".
func (h *ChunkHandler) Get(c *gin.Context) {
	user, errorCode, errorMessage := GetUser(c)
	if errorCode != common.CodeSuccess {
		jsonError(c, errorCode, errorMessage)
		return
	}

	// fail writes the {code, message} error body; the body code equals the HTTP status.
	fail := func(status int, message string) {
		c.JSON(status, gin.H{
			"code":    status,
			"message": message,
		})
	}

	chunkID := c.Query("chunk_id")
	if len(chunkID) == 0 {
		fail(http.StatusBadRequest, "chunk_id is required")
		return
	}

	resp, err := h.chunkService.Get(&service.GetChunkRequest{ChunkID: chunkID}, user.ID)
	if err != nil {
		fail(http.StatusInternalServerError, err.Error())
		return
	}

	c.JSON(http.StatusOK, gin.H{
		"code":    0,
		"data":    resp.Chunk,
		"message": "success",
	})
}
// List retrieves chunks for a document.
//
// The request is bound from the JSON body. A missing page defaults to 1 and a
// missing size to 30 before the service is called. The handler answers 400
// when binding fails, 500 when the service returns an error, and otherwise 200
// with the service response under "data".
func (h *ChunkHandler) List(c *gin.Context) {
	user, errorCode, errorMessage := GetUser(c)
	if errorCode != common.CodeSuccess {
		jsonError(c, errorCode, errorMessage)
		return
	}

	// fail writes the {code, message} error body; the body code equals the HTTP status.
	fail := func(status int, message string) {
		c.JSON(status, gin.H{
			"code":    status,
			"message": message,
		})
	}

	var req service.ListChunksRequest
	if bindErr := c.ShouldBindJSON(&req); bindErr != nil {
		fail(http.StatusBadRequest, bindErr.Error())
		return
	}

	// Fill in pagination defaults for parameters the client left out.
	page, size := 1, 30
	if req.Page == nil {
		req.Page = &page
	}
	if req.Size == nil {
		req.Size = &size
	}

	resp, err := h.chunkService.List(&req, user.ID)
	if err != nil {
		fail(http.StatusInternalServerError, err.Error())
		return
	}

	c.JSON(http.StatusOK, gin.H{
		"code":    0,
		"data":    resp,
		"message": "success",
	})
}

View File

@@ -200,6 +200,8 @@ func (r *Router) Setup(engine *gin.Engine) {
chunk := authorized.Group("/v1/chunk")
{
chunk.POST("/retrieval_test", r.chunkHandler.RetrievalTest)
chunk.GET("/get", r.chunkHandler.Get)
chunk.POST("/list", r.chunkHandler.List)
}
// LLM routes

View File

@@ -20,7 +20,6 @@ import (
"context"
"fmt"
"ragflow/internal/server"
"strconv"
"strings"
"go.uber.org/zap"
@@ -77,9 +76,10 @@ type RetrievalTestRequest struct {
// RetrievalTestResponse retrieval test response
type RetrievalTestResponse struct {
Chunks []map[string]interface{} `json:"chunks"`
Labels []map[string]interface{} `json:"labels"`
Total int64 `json:"total,omitempty"`
Chunks []map[string]interface{} `json:"chunks"`
DocAggs []map[string]interface{} `json:"doc_aggs"`
Labels *[]map[string]interface{} `json:"labels"`
Total int64 `json:"total,omitempty"`
}
// RetrievalTest performs retrieval test
@@ -283,8 +283,8 @@ func (s *ChunkService) RetrievalTest(req *RetrievalTestRequest, userID string) (
// Perform reranking
// Reference: rag/nlp/search.py L404-L429
tkWeight := 1.0 - *req.VectorSimilarityWeight
vtWeight := *req.VectorSimilarityWeight
vtWeight := getVectorSimilarityWeight(req.VectorSimilarityWeight)
tkWeight := 1.0 - vtWeight
useInfinity := s.engineType == server.EngineInfinity
sim, term_similarity, vector_similarity := nlp.Rerank(
@@ -312,10 +312,71 @@ func (s *ChunkService) RetrievalTest(req *RetrievalTestRequest, userID string) (
convertedChunks := buildRetrievalTestResults(filteredChunks)
// Build doc_aggs by aggregating chunks by docnm
docAggsMap := make(map[string]struct {
docID string
count int
})
docNameOrder := []string{} // Track insertion order of doc names
for _, chunk := range filteredChunks {
docName := ""
docID := ""
if v, ok := chunk["docnm"].(string); ok {
docName = v
}
if v, ok := chunk["doc_id"].(string); ok {
docID = v
}
if docName == "" {
continue
}
if entry, exists := docAggsMap[docName]; exists {
entry.count++
docAggsMap[docName] = entry
} else {
docAggsMap[docName] = struct {
docID string
count int
}{docID: docID, count: 1}
docNameOrder = append(docNameOrder, docName)
}
}
// Convert to list maintaining insertion order
type docAggEntry struct {
docName string
docID string
count int
order int
}
docAggsList := make([]docAggEntry, 0, len(docAggsMap))
for order, docName := range docNameOrder {
entry := docAggsMap[docName]
docAggsList = append(docAggsList, docAggEntry{docName: docName, docID: entry.docID, count: entry.count, order: order})
}
// Sort by count descending, then by order ascending (for tie-breaking)
for i := 0; i < len(docAggsList)-1; i++ {
for j := i + 1; j < len(docAggsList); j++ {
if docAggsList[j].count > docAggsList[i].count ||
(docAggsList[j].count == docAggsList[i].count && docAggsList[j].order < docAggsList[i].order) {
docAggsList[i], docAggsList[j] = docAggsList[j], docAggsList[i]
}
}
}
docAggs := make([]map[string]interface{}, 0, len(docAggsList))
for _, entry := range docAggsList {
docAggs = append(docAggs, map[string]interface{}{
"doc_name": entry.docName,
"doc_id": entry.docID,
"count": entry.count,
})
}
return &RetrievalTestResponse{
Chunks: convertedChunks,
Labels: []map[string]interface{}{}, // Empty labels for now
Total: int64(len(convertedChunks)),
Chunks: convertedChunks,
DocAggs: docAggs,
Labels: nil,
Total: int64(len(convertedChunks)),
}, nil
}
@@ -457,11 +518,7 @@ func buildRetrievalTestResults(filteredChunks []map[string]interface{}) []map[st
result["kb_id"] = v
}
if v, ok := chunk["position_int"]; ok {
if strVal, ok := v.(string); ok && strVal != "" {
result["positions"] = convertPositionInt(strVal)
} else {
result["positions"] = []interface{}{}
}
result["positions"] = v
}
if v, ok := chunk["doc_type_kwd"]; ok {
result["doc_type_kwd"] = v
@@ -490,42 +547,307 @@ func buildRetrievalTestResults(filteredChunks []map[string]interface{}) []map[st
return results
}
// convertPositionInt converts hex string format "00000001_0000005e_..." to array [[1, 94, ...], ...]
func convertPositionInt(hexStr string) []interface{} {
if hexStr == "" {
return []interface{}{}
// GetChunkRequest is the input of ChunkService.Get.
type GetChunkRequest struct {
	// ChunkID is the ID of the chunk to fetch; Get rejects an empty value.
	ChunkID string `json:"chunk_id"`
}
// GetChunkResponse is the result of ChunkService.Get.
type GetChunkResponse struct {
	// Chunk holds the chunk's fields keyed by output field name.
	Chunk map[string]interface{} `json:"chunk"`
}
// Get retrieves a chunk by ID
func (s *ChunkService) Get(req *GetChunkRequest, userID string) (*GetChunkResponse, error) {
if s.docEngine == nil {
return nil, fmt.Errorf("doc engine not initialized")
}
parts := strings.Split(hexStr, "_")
var intVals []int
for _, part := range parts {
if part == "" {
continue
}
// Parse hex string (without 0x prefix)
val, err := strconv.ParseInt(part, 16, 64)
if req.ChunkID == "" {
return nil, fmt.Errorf("chunk_id is required")
}
ctx := context.Background()
// Get user's tenants
tenants, err := s.userTenantDAO.GetByUserID(userID)
if err != nil {
return nil, fmt.Errorf("failed to get user tenants: %w", err)
}
if len(tenants) == 0 {
return nil, fmt.Errorf("user has no accessible tenants")
}
// Try each tenant to find the chunk
var chunk map[string]interface{}
for _, tenant := range tenants {
// Get kbIDs for this tenant
kbIDs, err := s.kbDAO.GetKBIDsByTenantID(tenant.TenantID)
if err != nil {
continue
}
intVals = append(intVals, int(val))
indexName := fmt.Sprintf("ragflow_%s", tenant.TenantID)
doc, err := s.docEngine.GetChunk(ctx, indexName, req.ChunkID, kbIDs)
if err != nil {
continue
}
if doc != nil {
chunk, ok := doc.(map[string]interface{})
if ok {
// Format to match Python output
result := make(map[string]interface{})
skipFields := map[string]bool{
"id": true, "authors": true, "_score": true, "SCORE": true,
}
for k, v := range chunk {
if skipFields[k] || strings.HasSuffix(k, "_vec") || strings.Contains(k, "_sm_") || strings.HasSuffix(k, "_tks") || strings.HasSuffix(k, "_ltks") {
continue
}
switch k {
case "content":
result["content_with_weight"] = v
case "docnm":
result["docnm_kwd"] = v
case "important_keywords":
utility.SetFieldArray(result, "important_kwd", v)
case "questions":
utility.SetFieldArray(result, "question_kwd", v)
case "entities_kwd", "entity_kwd", "entity_type_kwd", "from_entity_kwd",
"name_kwd", "raptor_kwd", "removed_kwd", "source_id", "tag_kwd",
"to_entity_kwd", "toc_kwd", "authors_tks", "doc_type_kwd":
if utility.IsEmpty(v) {
result[k] = []interface{}{}
} else {
result[k] = v
}
case "tag_feas":
if utility.IsEmpty(v) {
result[k] = map[string]interface{}{}
} else {
result[k] = v
}
case "create_timestamp_flt", "rank_flt", "weight_flt":
if floatVal, ok := utility.ToFloat64(v); ok {
result[k] = utility.JSONFloat64(floatVal)
}
default:
result[k] = v
}
}
return &GetChunkResponse{Chunk: result}, nil
}
}
}
// Group by 5 elements
var result []interface{}
for i := 0; i < len(intVals); i += 5 {
end := i + 5
if end > len(intVals) {
end = len(intVals)
}
group := make([]int, end-i)
copy(group, intVals[i:end])
// Convert to interface{} for JSON serialization
groupIf := make([]interface{}, len(group))
for j, v := range group {
groupIf[j] = v
}
result = append(result, groupIf)
if chunk == nil {
return nil, fmt.Errorf("chunk not found")
}
return result
return &GetChunkResponse{Chunk: chunk}, nil
}
// ListChunksRequest is the JSON body accepted when listing a document's chunks.
type ListChunksRequest struct {
	// DocID is the document whose chunks are listed; binding fails when it is missing.
	DocID string `json:"doc_id" binding:"required"`
	// Page and Size are pointers so that an omitted value can be told apart
	// from an explicit one.
	Page *int `json:"page,omitempty"`
	Size *int `json:"size,omitempty"`
	// Keywords, when set, is passed to the search as the question text.
	Keywords string `json:"keywords,omitempty"`
	// AvailableInt, when set, is forwarded to the search as the available_int filter.
	AvailableInt *int `json:"available_int,omitempty"`
}
// ListChunksResponse is the result of ChunkService.List.
type ListChunksResponse struct {
	// Chunks are the formatted chunks of the requested page.
	Chunks []map[string]interface{} `json:"chunks"`
	// Doc describes the document the chunks belong to.
	Doc map[string]interface{} `json:"doc"`
	// Total is the total reported by the search response.
	Total int64 `json:"total"`
}
// List retrieves chunks for a document.
//
// It verifies that the knowledge base owning the document belongs to one of
// the user's tenants, searches that tenant's index filtered by the document ID,
// renames and filters the returned fields, and attaches a description of the
// document itself.
//
// It returns an error when the doc engine is not set, doc_id is empty, the
// user has no tenants, the document or its knowledge base cannot be loaded,
// the user has no access to the document, or the search fails.
func (s *ChunkService) List(req *ListChunksRequest, userID string) (*ListChunksResponse, error) {
	if s.docEngine == nil {
		return nil, fmt.Errorf("doc engine not initialized")
	}

	if req.DocID == "" {
		return nil, fmt.Errorf("doc_id is required")
	}

	ctx := context.Background()

	// Get user's tenants
	tenants, err := s.userTenantDAO.GetByUserID(userID)
	if err != nil {
		return nil, fmt.Errorf("failed to get user tenants: %w", err)
	}
	if len(tenants) == 0 {
		return nil, fmt.Errorf("user has no accessible tenants")
	}

	// Get document to find its tenant
	docDAO := dao.NewDocumentDAO()
	doc, err := docDAO.GetByID(req.DocID)
	if err != nil || doc == nil {
		return nil, fmt.Errorf("document not found")
	}

	// Get knowledge base to find tenant
	kb, err := s.kbDAO.GetByID(doc.KbID)
	if err != nil || kb == nil {
		return nil, fmt.Errorf("knowledge base not found")
	}

	// Find which tenant this document belongs to.
	// This doubles as the access check: the knowledge base's tenant must be
	// one of the user's tenants.
	var targetTenantID string
	for _, tenant := range tenants {
		if tenant.TenantID == kb.TenantID {
			targetTenantID = tenant.TenantID
			break
		}
	}
	if targetTenantID == "" {
		return nil, fmt.Errorf("user does not have access to this document")
	}

	// Get kbIDs for this tenant
	kbIDs, err := s.kbDAO.GetKBIDsByTenantID(targetTenantID)
	if err != nil {
		return nil, fmt.Errorf("failed to get kb ids: %w", err)
	}

	indexName := fmt.Sprintf("ragflow_%s", targetTenantID)
	page := getPageNum(req.Page)
	size := getPageSize(req.Size)
	keywords := req.Keywords

	// Build search request - same as retrieval test but filtered by doc_id
	searchReq := &engine.SearchRequest{
		IndexNames: []string{indexName},
		Question:   keywords,
		KbIDs:      kbIDs,
		DocIDs:     []string{req.DocID},
		Page:       page,
		Size:       size,
		TopK:       size,
	}

	// Add available_int filter if specified
	if req.AvailableInt != nil {
		searchReq.AvailableInt = req.AvailableInt
	}

	// Execute search through unified engine interface
	result, err := s.docEngine.Search(ctx, searchReq)
	if err != nil {
		return nil, fmt.Errorf("search failed: %w", err)
	}

	// Convert result to unified response
	searchResp, ok := result.(*engine.SearchResponse)
	if !ok {
		return nil, fmt.Errorf("invalid search response type")
	}

	// Format output to match Python
	chunks := make([]map[string]interface{}, 0, len(searchResp.Chunks))
	for _, chunk := range searchResp.Chunks {
		// Inline formatChunkForList
		// NOTE: this result shadows the search result above, which is no longer needed here.
		result := make(map[string]interface{})
		// Fields dropped from the output by exact name.
		skipFields := map[string]bool{
			"_id": true, "authors": true, "_score": true, "SCORE": true,
			"important_kwd_empty_count": true, "kb_id": true, "mom_id": true, "page_num_int": true,
		}
		for k, v := range chunk {
			// Also drop vector columns and every tokenized (*_tks, *_ltks, *_sm_*) column.
			if skipFields[k] || strings.HasSuffix(k, "_vec") || strings.Contains(k, "_sm_") || strings.HasSuffix(k, "_ltks") || strings.HasSuffix(k, "_tks") {
				continue
			}
			switch k {
			case "img_id":
				// Exposed as image_id; any non-string value becomes "".
				if strVal, ok := v.(string); ok {
					result["image_id"] = strVal
				} else {
					result["image_id"] = ""
				}
			case "position_int":
				result["positions"] = v
			case "id":
				result["chunk_id"] = v
			case "content":
				result["content_with_weight"] = v
			case "docnm":
				result["docnm_kwd"] = v
			case "important_keywords":
				utility.SetFieldArray(result, "important_kwd", v)
			case "questions":
				utility.SetFieldArray(result, "question_kwd", v)
			case "entities_kwd", "entity_kwd", "entity_type_kwd", "from_entity_kwd",
				"name_kwd", "raptor_kwd", "removed_kwd",
				"source_id", "tag_kwd", "to_entity_kwd", "toc_kwd", "doc_type_kwd":
				// These keep their name; empty values are normalized to an empty array.
				if utility.IsEmpty(v) {
					result[k] = []interface{}{}
				} else {
					result[k] = v
				}
			default:
				// Handle _kwd fields that need "###" splitting
				if strings.HasSuffix(k, "_kwd") && k != "knowledge_graph_kwd" {
					if strVal, ok := v.(string); ok && strings.Contains(strVal, "###") {
						parts := strings.Split(strVal, "###")
						var filtered []interface{}
						for _, p := range parts {
							// Skip the empty pieces produced by leading, trailing or doubled separators.
							if p != "" {
								filtered = append(filtered, p)
							}
						}
						result[k] = filtered
					} else {
						result[k] = v
					}
				} else {
					result[k] = v
				}
			}
		}
		chunks = append(chunks, result)
	}

	// Build document info (matching Python doc.to_dict())
	timeFormat := "2006-01-02T15:04:05"
	docInfo := map[string]interface{}{
		"id":               doc.ID,
		"thumbnail":        doc.Thumbnail,
		"kb_id":            doc.KbID,
		"parser_id":        doc.ParserID,
		"pipeline_id":      doc.PipelineID,
		"parser_config":    doc.ParserConfig,
		"source_type":      doc.SourceType,
		"type":             doc.Type,
		"created_by":       doc.CreatedBy,
		"name":             doc.Name,
		"location":         doc.Location,
		"size":             doc.Size,
		"token_num":        doc.TokenNum,
		"chunk_num":        doc.ChunkNum,
		"progress":         utility.JSONFloat64(doc.Progress),
		"progress_msg":     doc.ProgressMsg,
		"process_begin_at": utility.FormatTimeToString(doc.ProcessBeginAt, timeFormat),
		"process_duration": doc.ProcessDuration,
		"content_hash":     doc.ContentHash,
		"suffix":           doc.Suffix,
		"run":              doc.Run,
		"status":           doc.Status,
		"create_time":      doc.CreateTime,
		"create_date":      utility.FormatTimeToString(doc.CreateDate, timeFormat),
		"update_time":      doc.UpdateTime,
		"update_date":      utility.FormatTimeToString(doc.UpdateDate, timeFormat),
	}

	return &ListChunksResponse{
		Total:  searchResp.Total,
		Chunks: chunks,
		Doc:    docInfo,
	}, nil
}

View File

@@ -19,9 +19,19 @@ package utility
import (
"fmt"
"os"
"strconv"
"strings"
"time"
)
// JSONFloat64 is a float64 that always marshals with a decimal point.
type JSONFloat64 float64

// MarshalJSON encodes f at full precision in plain decimal notation and
// appends ".0" to whole numbers (0 becomes 0.0, 3 becomes 3.0, 0.55 stays
// 0.55). NaN and the infinities have no JSON representation and yield an error.
func (f JSONFloat64) MarshalJSON() ([]byte, error) {
	// -1 asks for the shortest text that round-trips, so no digits are lost.
	text := strconv.FormatFloat(float64(f), 'f', -1, 64)
	switch text {
	case "NaN", "+Inf", "-Inf":
		return nil, fmt.Errorf("json: unsupported value: %s", text)
	}
	if !strings.Contains(text, ".") {
		text += ".0"
	}
	return []byte(text), nil
}
// GetProjectBaseDirectory returns the current working directory.
// If an error occurs while getting the current directory, it returns ".".
//
@@ -87,3 +97,123 @@ func FormatTime(t time.Time) string {
}
return t.Format("2006-01-02 15:04:05")
}
// FormatTimeToString renders *t using the given layout.
// A nil t yields a nil interface value rather than an empty string.
func FormatTimeToString(t *time.Time, format string) interface{} {
	var formatted interface{}
	if t != nil {
		formatted = t.Format(format)
	}
	return formatted
}
// ConvertHexToPositionIntArray decodes a "_"-separated list of hexadecimal
// numbers and returns them as a [][]int in consecutive groups of five; the
// last group may be shorter. Empty and unparsable segments are skipped. The
// result is nil when no number could be decoded.
func ConvertHexToPositionIntArray(hexStr string) interface{} {
	var flat []int
	for _, segment := range strings.Split(hexStr, "_") {
		if segment == "" {
			continue
		}
		if n, err := strconv.ParseInt(segment, 16, 64); err == nil {
			flat = append(flat, int(n))
		}
	}

	total := len(flat)
	if total == 0 {
		return nil
	}

	// Cut the flat list into windows of five values.
	var groups [][]int
	for start := 0; start < total; start += 5 {
		stop := start + 5
		if stop > total {
			stop = total
		}
		groups = append(groups, flat[start:stop])
	}
	return groups
}
// ConvertHexToIntArray decodes a "_"-separated list of hexadecimal numbers
// into a []int. Empty and unparsable segments are skipped. The result is nil
// when no number could be decoded.
func ConvertHexToIntArray(hexStr string) interface{} {
	var decoded []int
	for _, segment := range strings.Split(hexStr, "_") {
		if segment == "" {
			continue
		}
		if n, err := strconv.ParseInt(segment, 16, 64); err == nil {
			decoded = append(decoded, int(n))
		}
	}
	if len(decoded) > 0 {
		return decoded
	}
	return nil
}
// IsEmpty reports whether v is nil, the empty string, or a zero-length
// []interface{}, []string or []int. Every other value, including other slice
// and map types, counts as non-empty.
func IsEmpty(v interface{}) bool {
	switch value := v.(type) {
	case nil:
		return true
	case []interface{}:
		return len(value) == 0
	case []string:
		return len(value) == 0
	case []int:
		return len(value) == 0
	case string:
		return value == ""
	}
	return false
}
// SetFieldArray stores v under destKey in result, substituting an empty
// []interface{} when IsEmpty(v) is true.
func SetFieldArray(result map[string]interface{}, destKey string, v interface{}) {
	stored := v
	if IsEmpty(v) {
		stored = []interface{}{}
	}
	result[destKey] = stored
}
// ToFloat64 converts a numeric value to float64.
//
// It accepts float32/float64, every built-in signed and unsigned integer
// width, and strings that strconv.ParseFloat can parse. The boolean result is
// false, with a zero value, for any other type (including nil) and for
// strings that do not parse.
func ToFloat64(val interface{}) (float64, bool) {
	switch v := val.(type) {
	case float64:
		return v, true
	case float32:
		return float64(v), true
	case int:
		return float64(v), true
	case int8:
		return float64(v), true
	case int16:
		return float64(v), true
	case int32:
		return float64(v), true
	case int64:
		return float64(v), true
	case uint:
		return float64(v), true
	case uint8:
		return float64(v), true
	case uint16:
		return float64(v), true
	case uint32:
		return float64(v), true
	case uint64:
		return float64(v), true
	case string:
		f, err := strconv.ParseFloat(v, 64)
		if err != nil {
			return 0, false
		}
		return f, true
	default:
		return 0, false
	}
}