feat: Dify-compatible retrieval API endpoint (#15704)

## Summary

Dify-compatible retrieval API for external knowledge base integration.

## Changes

- **New handler**: DifyRetrievalHandler with POST/GET
/api/v1/dify/retrieval
- **Health check**: GET /api/v1/dify/retrieval/health
- **Full pipeline**: KB validation -> permission check -> embedding ->
metadata filter -> chunk retrieval -> child chunk aggregation ->
optional KG search -> response assembly
- **13 tests** covering the success, error, metadata-filter and KG-mode
paths
- **Testability**: Handler dependencies defined as interfaces
(KBServiceIface, ModelServiceIface, etc.)

## Files

| File | Type |
|------|------|
| internal/handler/dify_retrieval_handler.go | New — handler + interfaces |
| internal/handler/dify_retrieval_handler_test.go | New — 13 tests |
| internal/router/router.go | Modified — route registration |
| cmd/server_main.go | Modified — handler wiring |
| internal/service/kg/pipeline.go | Modified — SetChatModel/SetEmbModel |
| internal/service/kg/retrieval.go | New — helper functions |
| internal/service/kg/scoring.go | Moved from service package |
| internal/service/kg/search.go | New — KG search functions |
| internal/service/kg/types.go | New — type definitions |

---------

Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
Jack
2026-06-05 21:16:25 +08:00
committed by GitHub
parent 1deb1313d2
commit 5a04ac0864
14 changed files with 1394 additions and 1000 deletions

View File

@@ -218,8 +218,20 @@ func startServer(config *server.Config) {
&handler.SearchbotRealLLM{Svc: modelProviderService},
)
// Dify retrieval handler
docDAO := dao.NewDocumentDAO()
retrievalService := nlp.NewRetrievalService(docEngine, docDAO)
difyRetrievalHandler := handler.NewDifyRetrievalHandler(
knowledgebaseService,
modelProviderService,
metadataService,
retrievalService,
docDAO,
docEngine,
)
// Initialize router
r := router.NewRouter(authHandler, userHandler, tenantHandler, documentHandler, datasetsHandler, systemHandler, knowledgebaseHandler, chunkHandler, llmHandler, chatHandler, chatSessionHandler, connectorHandler, searchHandler, fileHandler, memoryHandler, mcpHandler, skillSearchHandler, providerHandler, agentHandler, relatedQuestionsHandler)
r := router.NewRouter(authHandler, userHandler, tenantHandler, documentHandler, datasetsHandler, systemHandler, knowledgebaseHandler, chunkHandler, llmHandler, chatHandler, chatSessionHandler, connectorHandler, searchHandler, fileHandler, memoryHandler, mcpHandler, skillSearchHandler, providerHandler, agentHandler, relatedQuestionsHandler, difyRetrievalHandler)
// Create Gin engine
ginEngine := gin.New()

View File

@@ -0,0 +1,373 @@
//
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
package handler
import (
"context"
"errors"
"fmt"
"net/http"
"strconv"
"strings"
"ragflow/internal/common"
"gorm.io/gorm"
"go.uber.org/zap"
"ragflow/internal/engine"
"ragflow/internal/entity"
modelModule "ragflow/internal/entity/models"
"ragflow/internal/service"
"ragflow/internal/service/kg"
"ragflow/internal/service/nlp"
"github.com/gin-gonic/gin"
)
// --- Interfaces (for testability) ---
// KBServiceIface abstracts KnowledgebaseService for the Dify handler.
// It lists only the knowledge-base operations the handler calls, so tests can
// substitute a mock.
type KBServiceIface interface {
// GetByID looks up a knowledge base by its ID. The handler maps
// gorm.ErrRecordNotFound to HTTP 404 and any other error to HTTP 500.
GetByID(kbID string) (*entity.Knowledgebase, error)
// Accessible reports whether userID may use kbID; the handler rejects the
// request when it returns false.
Accessible(kbID, userID string) bool
}
// ModelServiceIface abstracts ModelProviderService for the Dify handler.
type ModelServiceIface interface {
// GetEmbeddingModel resolves the embedding model embdID for tenantID. The
// handler passes the knowledge base's TenantID and EmbdID.
GetEmbeddingModel(tenantID, embdID string) (*modelModule.EmbeddingModel, error)
// GetChatModel resolves a chat model for tenantID. The handler calls it with
// an empty compositeModelName when use_kg is set.
// NOTE(review): an empty name presumably selects the tenant default — confirm.
GetChatModel(tenantID, compositeModelName string) (*modelModule.ChatModel, error)
}
// MetadataServiceIface abstracts MetadataService for the Dify handler.
type MetadataServiceIface interface {
// GetFlattedMetaByKBs returns the document metadata of the given knowledge
// bases; the handler feeds the result to service.ApplyMetaFilter.
GetFlattedMetaByKBs(kbIDs []string) (common.MetaData, error)
// LabelQuestion returns rank features for the question; a nil map means the
// handler sends no rank feature with the retrieval request.
LabelQuestion(question string, kbs []*entity.Knowledgebase) map[string]float64
}
// RetrievalServiceIface abstracts RetrievalService for the Dify handler.
type RetrievalServiceIface interface {
// Retrieval runs chunk retrieval for req. The handler answers 404 when the
// returned error text contains "not_found" and 500 for any other error.
Retrieval(ctx context.Context, req *nlp.RetrievalRequest) (*nlp.RetrievalResult, error)
}
// DocumentDAOIface abstracts DocumentDAO for the Dify handler.
type DocumentDAOIface interface {
// GetByIDs loads the documents with the given IDs. The handler uses them to
// attach each document's meta fields to the returned records.
GetByIDs(ids []string) ([]*entity.Document, error)
}
// --- Request / Response types ---
// difyRetrievalRequest is the JSON body / query params for the Dify retrieval endpoint.
// KnowledgeID and Query are required; the handler answers 400 when either is empty.
type difyRetrievalRequest struct {
// KnowledgeID is the ID of the knowledge base to search.
KnowledgeID string `json:"knowledge_id" form:"knowledge_id"`
// Query is the user question.
Query string `json:"query" form:"query"`
// UseKG additionally runs the knowledge-graph pipeline when true.
UseKG bool `json:"use_kg" form:"use_kg"`
// RetrievalSetting carries the optional top_k / score_threshold overrides.
// For GET requests the handler also fills it from the flat query parameters.
RetrievalSetting *difyRetrievalSetting `json:"retrieval_setting"`
// MetadataCondition is the optional document metadata filter.
MetadataCondition *difyMetadataCondition `json:"metadata_condition"`
}
// difyRetrievalSetting holds the optional retrieval tuning parameters.
// Both fields are pointers so that "not provided" (nil) can be told apart from
// an explicit zero value.
type difyRetrievalSetting struct {
// TopK, when set, is forwarded as the retrieval request's Top and PageSize.
TopK *int `json:"top_k" form:"top_k"`
// ScoreThreshold, when set, is forwarded as the similarity threshold.
ScoreThreshold *float64 `json:"score_threshold" form:"score_threshold"`
}
// difyCondition is a Dify-format metadata filter condition.
// Dify uses "name"/"comparison_operator" instead of MetaFilterCondition's "key"/"op".
type difyCondition struct {
// Name is the metadata field name; it becomes MetaFilterCondition.Key.
Name string `json:"name"`
// ComparisonOperator becomes MetaFilterCondition.Op unchanged.
ComparisonOperator string `json:"comparison_operator"`
// Value is the raw JSON value to compare against; it is stringified on conversion.
Value interface{} `json:"value"`
}
// difyMetadataCondition is the Dify-format metadata filter: a list of
// conditions plus the logic used to combine them.
type difyMetadataCondition struct {
// Conditions are the individual field comparisons.
Conditions []difyCondition `json:"conditions"`
// Logic combines the conditions; the handler defaults it to "and" when empty.
Logic string `json:"logic"`
}
// toMetaFilterConditions converts Dify-format conditions to internal MetaFilterConditions.
//
// Name maps to Key and ComparisonOperator maps to Op. Value is rendered as a
// string:
//   - nil becomes "";
//   - strings are passed through unchanged;
//   - JSON numbers (decoded as float64) are written in plain decimal notation,
//     because fmt's default float formatting switches to exponent form for
//     large values (20240101 would become "2.0240101e+07") and would break
//     textual comparisons against stored metadata;
//   - every other type falls back to fmt.Sprint.
//
// It returns nil when there are no conditions.
func (c difyMetadataCondition) toMetaFilterConditions() []service.MetaFilterCondition {
	if len(c.Conditions) == 0 {
		return nil
	}
	result := make([]service.MetaFilterCondition, len(c.Conditions))
	for i, dc := range c.Conditions {
		var v string
		switch val := dc.Value.(type) {
		case nil:
			// A missing value is represented by the empty string.
		case string:
			v = val
		case float64:
			v = strconv.FormatFloat(val, 'f', -1, 64)
		default:
			v = fmt.Sprint(val)
		}
		result[i] = service.MetaFilterCondition{
			Key:   dc.Name,
			Op:    dc.ComparisonOperator,
			Value: v,
		}
	}
	return result
}
// difyRecord is one item in the response records array.
type difyRecord struct {
// Content is the chunk text, taken from the chunk's "content_with_weight".
Content string `json:"content"`
// Score is the chunk's "similarity" value.
Score float64 `json:"score"`
// Title is the chunk's "docnm_kwd" value.
Title string `json:"title"`
// Metadata holds the document's meta fields plus "doc_id" and "document_id".
Metadata map[string]interface{} `json:"metadata"`
}
// --- Handler ---
// DifyRetrievalHandler handles Dify-compatible retrieval requests.
// All collaborators except docEngine are held as interfaces so tests can
// replace them with mocks.
type DifyRetrievalHandler struct {
// kbSvc looks up knowledge bases and checks user access.
kbSvc KBServiceIface
// modelSvc resolves the embedding model and, for use_kg, the chat model.
modelSvc ModelServiceIface
// metadataSvc supplies document metadata and question rank features.
metadataSvc MetadataServiceIface
// retrievalSvc performs the chunk retrieval.
retrievalSvc RetrievalServiceIface
// docDAO loads the documents referenced by the retrieved chunks.
docDAO DocumentDAOIface
// docEngine is passed to child-chunk enrichment and to the KG pipeline.
docEngine engine.DocEngine
}
// NewDifyRetrievalHandler wires a DifyRetrievalHandler from its collaborators.
// No KG pipeline is injected here: it is built per request when use_kg=true,
// because it depends on model configuration that is only known per request.
func NewDifyRetrievalHandler(
	kbSvc KBServiceIface,
	modelSvc ModelServiceIface,
	metadataSvc MetadataServiceIface,
	retrievalSvc RetrievalServiceIface,
	docDAO DocumentDAOIface,
	docEngine engine.DocEngine,
) *DifyRetrievalHandler {
	h := new(DifyRetrievalHandler)
	h.kbSvc = kbSvc
	h.modelSvc = modelSvc
	h.metadataSvc = metadataSvc
	h.retrievalSvc = retrievalSvc
	h.docDAO = docDAO
	h.docEngine = docEngine
	return h
}
// Retrieval handles POST/GET /api/v1/dify/retrieval.
// Matches Python: api/apps/restful_apis/dify_retrieval_api.py::retrieval()
//
// Pipeline: authenticate -> bind request -> load and authorize the knowledge
// base -> resolve the embedding model -> optional metadata filter -> chunk
// retrieval -> child-chunk enrichment -> optional KG chunk -> response.
//
// Responses:
//   - 200 {"records": [...]} on success
//   - 400 for a malformed body/query or a missing knowledge_id/query
//   - 401 when the user is unauthenticated or may not access the knowledge base
//   - 404 when the knowledge base is missing or retrieval reports "not_found"
//   - 500 for any other lookup, model, retrieval or document-load failure
func (h *DifyRetrievalHandler) Retrieval(c *gin.Context) {
	user, errCode, errMsg := GetUser(c)
	if errCode != common.CodeSuccess {
		c.JSON(http.StatusUnauthorized, gin.H{"code": errCode, "message": errMsg})
		return
	}

	req, ok := bindDifyRetrievalRequest(c)
	if !ok {
		return
	}

	kb, err := h.kbSvc.GetByID(req.KnowledgeID)
	if err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			c.JSON(http.StatusNotFound, gin.H{"code": common.CodeNotFound, "message": "Knowledgebase not found!"})
		} else {
			c.JSON(http.StatusInternalServerError, gin.H{"code": common.CodeServerError, "message": "failed to query knowledgebase"})
		}
		return
	}
	if !h.kbSvc.Accessible(req.KnowledgeID, user.ID) {
		c.JSON(http.StatusUnauthorized, gin.H{"code": common.CodeAuthenticationError, "message": "No authorization."})
		return
	}

	// Parse retrieval options (nil means service uses defaults)
	var topK *int
	var scoreThreshold *float64
	if rs := req.RetrievalSetting; rs != nil {
		topK = rs.TopK
		scoreThreshold = rs.ScoreThreshold
	}
	pageSize := 1024
	if topK != nil {
		pageSize = *topK
	}

	// Get embedding model
	embModel, err := h.modelSvc.GetEmbeddingModel(kb.TenantID, kb.EmbdID)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"code": common.CodeServerError, "message": fmt.Sprintf("failed to get embedding model: %v", err)})
		return
	}

	// Metadata filter. Metadata is only loaded when a filter was supplied.
	// When the filter selects nothing, or cannot be evaluated, the sentinel
	// ID "-999" is used instead of an empty list.
	// NOTE(review): "-999" presumably matches no document, so the search
	// returns nothing rather than the whole knowledge base — confirm.
	docIDs := make([]string, 0)
	if req.MetadataCondition != nil {
		metas, metaErr := h.metadataSvc.GetFlattedMetaByKBs([]string{req.KnowledgeID})
		if metaErr != nil {
			common.Warn("Dify retrieval: failed to load metadata for filter", zap.String("kbID", req.KnowledgeID), zap.Error(metaErr))
		} else {
			logic := req.MetadataCondition.Logic
			if logic == "" {
				logic = "and"
			}
			docIDs = append(docIDs, service.ApplyMetaFilter(metas, req.MetadataCondition.toMetaFilterConditions(), logic)...)
		}
		if len(docIDs) == 0 {
			docIDs = []string{"-999"}
		}
	}

	// Label question for rank features
	rankFeature := h.metadataSvc.LabelQuestion(req.Query, []*entity.Knowledgebase{kb})

	// Chunk retrieval
	sr := &nlp.RetrievalRequest{
		Question:            req.Query,
		TenantIDs:           []string{kb.TenantID},
		KbIDs:               []string{req.KnowledgeID},
		DocIDs:              docIDs,
		Page:                1,
		PageSize:            pageSize,
		Top:                 topK,
		SimilarityThreshold: scoreThreshold,
		EmbeddingModel:      embModel,
	}
	if rankFeature != nil {
		sr.RankFeature = &rankFeature
	}
	result, err := h.retrievalSvc.Retrieval(c.Request.Context(), sr)
	if err != nil {
		if strings.Contains(err.Error(), "not_found") {
			c.JSON(http.StatusNotFound, gin.H{"code": common.CodeNotFound, "message": "No chunk found! Check the chunk status please!"})
			return
		}
		c.JSON(http.StatusInternalServerError, gin.H{"code": common.CodeServerError, "message": err.Error()})
		return
	}

	// Enrich with child chunks
	chunks := nlp.RetrievalByChildren(result.Chunks, []string{kb.TenantID}, h.docEngine, c.Request.Context())

	// KG retrieval (optional, best effort): a failure only costs the KG chunk.
	if req.UseKG {
		if kgChunk := h.retrieveDifyKGChunk(c.Request.Context(), kb, req.KnowledgeID, req.Query, embModel); kgChunk != nil {
			chunks = append([]map[string]interface{}{kgChunk}, chunks...)
		}
	}

	// Collect doc IDs and fetch documents
	docIDSet := make(map[string]struct{})
	for _, ch := range chunks {
		if docID, ok := ch["doc_id"].(string); ok && docID != "" {
			docIDSet[docID] = struct{}{}
		}
	}
	docMap := make(map[string]*entity.Document, len(docIDSet))
	if len(docIDSet) > 0 {
		allDocIDs := make([]string, 0, len(docIDSet))
		for id := range docIDSet {
			allDocIDs = append(allDocIDs, id)
		}
		docs, err := h.docDAO.GetByIDs(allDocIDs)
		if err != nil {
			c.JSON(http.StatusInternalServerError, gin.H{"code": common.CodeServerError, "message": fmt.Sprintf("failed to load documents: %v", err)})
			return
		}
		for _, d := range docs {
			docMap[d.ID] = d
		}
	}

	c.JSON(http.StatusOK, gin.H{"records": buildDifyRecords(chunks, docMap)})
}

// bindDifyRetrievalRequest decodes the request from the query string (GET) or
// from the JSON body (any other method) and checks the required fields. On
// failure it writes the 400 response itself and returns false.
func bindDifyRetrievalRequest(c *gin.Context) (difyRetrievalRequest, bool) {
	var req difyRetrievalRequest
	if c.Request.Method == http.MethodGet {
		if err := c.ShouldBindQuery(&req); err != nil {
			c.JSON(http.StatusBadRequest, gin.H{"code": common.CodeArgumentError, "message": "invalid query parameters"})
			return req, false
		}
		// Manually extract top_k and score_threshold from query (flat params,
		// not nested). Values that do not parse are ignored.
		if v := c.Query("top_k"); v != "" {
			if parsed, err := strconv.Atoi(v); err == nil {
				if req.RetrievalSetting == nil {
					req.RetrievalSetting = &difyRetrievalSetting{}
				}
				req.RetrievalSetting.TopK = &parsed
			}
		}
		if v := c.Query("score_threshold"); v != "" {
			if parsed, err := strconv.ParseFloat(v, 64); err == nil {
				if req.RetrievalSetting == nil {
					req.RetrievalSetting = &difyRetrievalSetting{}
				}
				req.RetrievalSetting.ScoreThreshold = &parsed
			}
		}
	} else if err := c.ShouldBindJSON(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"code": common.CodeArgumentError, "message": "invalid request body"})
		return req, false
	}
	if req.KnowledgeID == "" || req.Query == "" {
		c.JSON(http.StatusBadRequest, gin.H{"code": common.CodeArgumentError, "message": "knowledge_id and query are required"})
		return req, false
	}
	return req, true
}

// retrieveDifyKGChunk runs the knowledge-graph pipeline for one knowledge base
// and returns its chunk, or nil when no chat model is available, the pipeline
// fails, or the chunk has no content. Failures are logged, never returned.
func (h *DifyRetrievalHandler) retrieveDifyKGChunk(
	ctx context.Context,
	kb *entity.Knowledgebase,
	kbID string,
	query string,
	embModel *modelModule.EmbeddingModel,
) map[string]interface{} {
	chatModel, err := h.modelSvc.GetChatModel(kb.TenantID, "")
	if err != nil {
		common.Warn("KG retrieval: failed to get chat model", zap.String("kbID", kbID), zap.Error(err))
		return nil
	}
	if chatModel == nil {
		return nil
	}
	kgPipeline := kg.NewPipeline(
		h.docEngine,
		[]string{kbID},
		[]string{kb.TenantID},
		query,
	)
	kgPipeline.SetChatModel(chatModel)
	kgPipeline.SetEmbModel(embModel)
	kgResult, err := kgPipeline.Retrieval(ctx)
	if err != nil {
		common.Warn("KG retrieval: pipeline failed", zap.String("kbID", kbID), zap.Error(err))
		return nil
	}
	if content, ok := kgResult["content_with_weight"].(string); !ok || content == "" {
		return nil
	}
	return kgResult
}

// buildDifyRecords turns retrieved chunks into response records, attaching the
// meta fields of each chunk's document from docMap.
//
// A chunk that names a document missing from docMap is dropped. A chunk with
// no doc_id is kept with metadata holding only the (empty) document IDs.
// NOTE(review): the no-doc_id case is presumably the synthetic KG chunk, which
// has no backing document — confirm against kg.Pipeline.Retrieval.
func buildDifyRecords(chunks []map[string]interface{}, docMap map[string]*entity.Document) []difyRecord {
	records := make([]difyRecord, 0, len(chunks))
	for _, ch := range chunks {
		docID, _ := ch["doc_id"].(string)
		doc := docMap[docID]
		if doc == nil && docID != "" {
			continue
		}
		// Remove vector to reduce response size
		delete(ch, "vector")
		meta := make(map[string]interface{})
		if doc != nil && doc.MetaFields != nil {
			for k, v := range *doc.MetaFields {
				meta[k] = v
			}
		}
		meta["doc_id"] = docID
		meta["document_id"] = docID
		score, _ := ch["similarity"].(float64)
		title, _ := ch["docnm_kwd"].(string)
		content, _ := ch["content_with_weight"].(string)
		records = append(records, difyRecord{
			Content:  content,
			Score:    score,
			Title:    title,
			Metadata: meta,
		})
	}
	return records
}
// HealthCheck answers 200 with a fixed {"code": 0, "data": true} payload.
func (h *DifyRetrievalHandler) HealthCheck(c *gin.Context) {
	payload := gin.H{
		"code": 0,
		"data": true,
	}
	c.JSON(http.StatusOK, payload)
}

View File

@@ -0,0 +1,401 @@
//
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
package handler
import (
"context"
"encoding/json"
"errors"
"gorm.io/gorm"
"net/http"
"net/http/httptest"
"strings"
"testing"
"ragflow/internal/common"
"ragflow/internal/engine"
"ragflow/internal/entity"
modelModule "ragflow/internal/entity/models"
"ragflow/internal/service/nlp"
"ragflow/internal/engine/types"
"github.com/gin-gonic/gin"
)
// --- Mock implementations ---
// mockKBService is a KBServiceIface test double. Each method calls its hook
// when one is set and otherwise returns a permissive default.
type mockKBService struct {
	KBServiceIface
	getByIDFn    func(kbID string) (*entity.Knowledgebase, error)
	accessibleFn func(kbID, userID string) bool
}

// GetByID returns the hook's result, or a stub knowledge base owned by "tenant1".
func (m *mockKBService) GetByID(kbID string) (*entity.Knowledgebase, error) {
	if hook := m.getByIDFn; hook != nil {
		return hook(kbID)
	}
	stub := &entity.Knowledgebase{ID: kbID, TenantID: "tenant1", EmbdID: "text-embedding"}
	return stub, nil
}

// Accessible returns the hook's verdict, or true when no hook is set.
func (m *mockKBService) Accessible(kbID, userID string) bool {
	return m.accessibleFn == nil || m.accessibleFn(kbID, userID)
}
// mockModelService is a ModelServiceIface test double that returns empty
// models unless a hook overrides the call.
type mockModelService struct {
	ModelServiceIface
	getEmbeddingFn func(tenantID, embdID string) (*modelModule.EmbeddingModel, error)
	getChatModelFn func(tenantID, llmID string) (*modelModule.ChatModel, error)
}

// GetEmbeddingModel returns the hook's result, or an empty embedding model.
func (m *mockModelService) GetEmbeddingModel(tenantID, embdID string) (*modelModule.EmbeddingModel, error) {
	if hook := m.getEmbeddingFn; hook != nil {
		return hook(tenantID, embdID)
	}
	return new(modelModule.EmbeddingModel), nil
}

// GetChatModel returns the hook's result, or an empty chat model.
func (m *mockModelService) GetChatModel(tenantID, llmID string) (*modelModule.ChatModel, error) {
	if hook := m.getChatModelFn; hook != nil {
		return hook(tenantID, llmID)
	}
	return new(modelModule.ChatModel), nil
}
// mockMetadataService is a MetadataServiceIface test double: empty metadata
// and no rank features unless a hook overrides the call.
type mockMetadataService struct {
	MetadataServiceIface
	getFlattedMetaFn func(kbIDs []string) (common.MetaData, error)
	labelQuestionFn  func(question string, kbs []*entity.Knowledgebase) map[string]float64
}

// GetFlattedMetaByKBs returns the hook's result, or empty metadata.
func (m *mockMetadataService) GetFlattedMetaByKBs(kbIDs []string) (common.MetaData, error) {
	if hook := m.getFlattedMetaFn; hook != nil {
		return hook(kbIDs)
	}
	empty := common.MetaData{}
	return empty, nil
}

// LabelQuestion returns the hook's result, or nil (no rank features).
func (m *mockMetadataService) LabelQuestion(question string, kbs []*entity.Knowledgebase) map[string]float64 {
	if hook := m.labelQuestionFn; hook != nil {
		return hook(question, kbs)
	}
	return nil
}
// mockRetrievalService is a RetrievalServiceIface test double that yields one
// chunk of document "doc1" unless a hook overrides the call.
type mockRetrievalService struct {
	RetrievalServiceIface
	retrievalFn func(ctx context.Context, req *nlp.RetrievalRequest) (*nlp.RetrievalResult, error)
}

// Retrieval returns the hook's result, or a single canned chunk.
func (m *mockRetrievalService) Retrieval(ctx context.Context, req *nlp.RetrievalRequest) (*nlp.RetrievalResult, error) {
	if hook := m.retrievalFn; hook != nil {
		return hook(ctx, req)
	}
	cannedChunk := map[string]interface{}{
		"doc_id":              "doc1",
		"docnm_kwd":           "Test Doc",
		"content_with_weight": "test content",
		"similarity":          0.85,
	}
	return &nlp.RetrievalResult{Chunks: []map[string]interface{}{cannedChunk}}, nil
}
// mockDocDAO is a DocumentDAOIface test double that returns document "doc1"
// with an "author" meta field unless a hook overrides the call.
type mockDocDAO struct {
	DocumentDAOIface
	getByIDsFn func(ids []string) ([]*entity.Document, error)
}

// GetByIDs returns the hook's result, or the single canned document.
func (m *mockDocDAO) GetByIDs(ids []string) ([]*entity.Document, error) {
	if hook := m.getByIDsFn; hook != nil {
		return hook(ids)
	}
	cannedDoc := &entity.Document{
		ID:         "doc1",
		Name:       strPtr("Test Doc"),
		MetaFields: &entity.JSONMap{"author": "Zhang San"},
	}
	return []*entity.Document{cannedDoc}, nil
}
// mockDocEngine stubs the DocEngine interface. It embeds a nil engine.DocEngine,
// so any method not overridden below panics when called.
type mockDocEngine struct {
	engine.DocEngine
}

// Close is a no-op.
func (m *mockDocEngine) Close() error {
	return nil
}

// Ping always succeeds.
func (m *mockDocEngine) Ping(ctx context.Context) error {
	return nil
}

// GetType identifies this engine as "mock".
func (m *mockDocEngine) GetType() string {
	return "mock"
}

// Search returns an empty result.
func (m *mockDocEngine) Search(ctx context.Context, req *types.SearchRequest) (*types.SearchResult, error) {
	emptyResult := new(types.SearchResult)
	return emptyResult, nil
}

// GetChunk returns an empty chunk map.
func (m *mockDocEngine) GetChunk(ctx context.Context, _, _ string, _ []string) (interface{}, error) {
	emptyChunk := map[string]interface{}{}
	return emptyChunk, nil
}
// --- Helper ---
// newDifyTestFixture builds a handler wired entirely with default mocks and an
// empty gin engine in test mode. It is the shared base of both setup helpers.
func newDifyTestFixture() (*DifyRetrievalHandler, *gin.Engine) {
	h := &DifyRetrievalHandler{
		kbSvc:        &mockKBService{},
		modelSvc:     &mockModelService{},
		metadataSvc:  &mockMetadataService{},
		retrievalSvc: &mockRetrievalService{},
		docDAO:       &mockDocDAO{},
		docEngine:    &mockDocEngine{},
	}
	gin.SetMode(gin.TestMode)
	return h, gin.New()
}

// setupDifyTest returns a mock-backed handler and a router exposing the
// retrieval (POST and GET) and health routes. A middleware stores a user with
// the given ID under the "user" context key to simulate an authenticated request.
// NOTE(review): assumes GetUser reads the "user" key — confirm.
func setupDifyTest(userID string) (*DifyRetrievalHandler, *gin.Engine) {
	h, r := newDifyTestFixture()
	r.Use(func(c *gin.Context) {
		c.Set("user", &entity.User{ID: userID})
	})
	r.POST("/api/v1/dify/retrieval", h.Retrieval)
	r.GET("/api/v1/dify/retrieval", h.Retrieval)
	r.GET("/api/v1/dify/retrieval/health", h.HealthCheck)
	return h, r
}

// setupDifyTestNoAuth returns a mock-backed handler and a router exposing only
// POST retrieval, with no user set in the context (unauthenticated request).
func setupDifyTestNoAuth() (*DifyRetrievalHandler, *gin.Engine) {
	h, r := newDifyTestFixture()
	r.POST("/api/v1/dify/retrieval", h.Retrieval)
	return h, r
}
// --- Tests ---
// TestDifyRetrieval_HealthCheck verifies the health route answers 200 with data=true.
func TestDifyRetrieval_HealthCheck(t *testing.T) {
	_, router := setupDifyTest("user1")

	rec := httptest.NewRecorder()
	httpReq, _ := http.NewRequest(http.MethodGet, "/api/v1/dify/retrieval/health", nil)
	router.ServeHTTP(rec, httpReq)

	if got := rec.Code; got != http.StatusOK {
		t.Errorf("expected 200, got %d", got)
	}

	var payload map[string]interface{}
	if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil {
		t.Fatal(err)
	}
	if data := payload["data"]; data != true {
		t.Errorf("expected data=true, got %v", data)
	}
}
// TestDifyRetrieval_Basic verifies a plain POST returns 200 and at least one record.
func TestDifyRetrieval_Basic(t *testing.T) {
	_, router := setupDifyTest("user1")

	payloadIn := strings.NewReader(`{"knowledge_id": "kb1", "query": "test question"}`)
	httpReq, _ := http.NewRequest(http.MethodPost, "/api/v1/dify/retrieval", payloadIn)
	httpReq.Header.Set("Content-Type", "application/json")
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, httpReq)

	if got := rec.Code; got != http.StatusOK {
		t.Errorf("expected 200, got %d: %s", got, rec.Body.String())
	}

	var payloadOut map[string]interface{}
	if err := json.Unmarshal(rec.Body.Bytes(), &payloadOut); err != nil {
		t.Fatal(err)
	}
	if list, isList := payloadOut["records"].([]interface{}); !isList || len(list) == 0 {
		t.Errorf("expected non-empty records, got %v", payloadOut["records"])
	}
}
// TestDifyRetrieval_GET verifies the endpoint accepts its arguments as query parameters.
func TestDifyRetrieval_GET(t *testing.T) {
	_, router := setupDifyTest("user1")

	const target = "/api/v1/dify/retrieval?knowledge_id=kb1&query=test"
	httpReq, _ := http.NewRequest(http.MethodGet, target, nil)
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, httpReq)

	if got := rec.Code; got != http.StatusOK {
		t.Errorf("expected 200, got %d: %s", got, rec.Body.String())
	}
}
// TestDifyRetrieval_MissingArgs verifies that omitting knowledge_id, query, or
// both yields 400.
func TestDifyRetrieval_MissingArgs(t *testing.T) {
	_, router := setupDifyTest("user1")

	type missingCase struct {
		name string
		body string
	}
	cases := []missingCase{
		{name: "no knowledge_id", body: `{"query": "test"}`},
		{name: "no query", body: `{"knowledge_id": "kb1"}`},
		{name: "empty body", body: `{}`},
	}

	for i := range cases {
		current := cases[i]
		t.Run(current.name, func(t *testing.T) {
			httpReq, _ := http.NewRequest(http.MethodPost, "/api/v1/dify/retrieval", strings.NewReader(current.body))
			httpReq.Header.Set("Content-Type", "application/json")
			rec := httptest.NewRecorder()
			router.ServeHTTP(rec, httpReq)
			if got := rec.Code; got != http.StatusBadRequest {
				t.Errorf("expected 400, got %d", got)
			}
		})
	}
}
// TestDifyRetrieval_KBNotFound verifies gorm.ErrRecordNotFound from the KB lookup maps to 404.
func TestDifyRetrieval_KBNotFound(t *testing.T) {
	handlerUnderTest, router := setupDifyTest("user1")
	notFound := func(kbID string) (*entity.Knowledgebase, error) {
		return nil, gorm.ErrRecordNotFound
	}
	handlerUnderTest.kbSvc = &mockKBService{getByIDFn: notFound}

	payload := strings.NewReader(`{"knowledge_id": "nonexistent", "query": "test"}`)
	httpReq, _ := http.NewRequest(http.MethodPost, "/api/v1/dify/retrieval", payload)
	httpReq.Header.Set("Content-Type", "application/json")
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, httpReq)

	if got := rec.Code; got != http.StatusNotFound {
		t.Errorf("expected 404, got %d", got)
	}
}
// TestDifyRetrieval_NoAuth verifies a request without a user in the context gets 401.
func TestDifyRetrieval_NoAuth(t *testing.T) {
	_, router := setupDifyTestNoAuth()

	payload := strings.NewReader(`{"knowledge_id": "kb1", "query": "test"}`)
	httpReq, _ := http.NewRequest(http.MethodPost, "/api/v1/dify/retrieval", payload)
	httpReq.Header.Set("Content-Type", "application/json")
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, httpReq)

	if got := rec.Code; got != http.StatusUnauthorized {
		t.Errorf("expected 401, got %d", got)
	}
}
// TestDifyRetrieval_Unauthorized verifies a user without access to the knowledge base gets 401.
func TestDifyRetrieval_Unauthorized(t *testing.T) {
	handlerUnderTest, router := setupDifyTest("user1")
	denyAll := func(kbID, userID string) bool {
		return false
	}
	handlerUnderTest.kbSvc = &mockKBService{accessibleFn: denyAll}

	payload := strings.NewReader(`{"knowledge_id": "kb1", "query": "test"}`)
	httpReq, _ := http.NewRequest(http.MethodPost, "/api/v1/dify/retrieval", payload)
	httpReq.Header.Set("Content-Type", "application/json")
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, httpReq)

	if got := rec.Code; got != http.StatusUnauthorized {
		t.Errorf("expected 401, got %d", got)
	}
}
// TestDifyRetrieval_WithMetadataFilter verifies a request carrying a
// metadata_condition is accepted and answered with 200.
func TestDifyRetrieval_WithMetadataFilter(t *testing.T) {
	handlerUnderTest, router := setupDifyTest("user1")
	emptyMeta := func(kbIDs []string) (common.MetaData, error) {
		return common.MetaData{}, nil
	}
	handlerUnderTest.metadataSvc = &mockMetadataService{getFlattedMetaFn: emptyMeta}

	payload := strings.NewReader(`{"knowledge_id":"kb1","query":"test","metadata_condition":{"conditions":[{"name":"author","comparison_operator":"eq","value":"Zhang San"}],"logic":"and"}}`)
	httpReq, _ := http.NewRequest(http.MethodPost, "/api/v1/dify/retrieval", payload)
	httpReq.Header.Set("Content-Type", "application/json")
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, httpReq)

	if got := rec.Code; got != http.StatusOK {
		t.Errorf("expected 200, got %d: %s", got, rec.Body.String())
	}
}
// TestDifyRetrieval_InvalidJSON verifies a malformed JSON body yields 400.
func TestDifyRetrieval_InvalidJSON(t *testing.T) {
	_, router := setupDifyTest("user1")

	malformed := strings.NewReader("{invalid json")
	httpReq, _ := http.NewRequest(http.MethodPost, "/api/v1/dify/retrieval", malformed)
	httpReq.Header.Set("Content-Type", "application/json")
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, httpReq)

	if got := rec.Code; got != http.StatusBadRequest {
		t.Errorf("expected 400, got %d", got)
	}
}
// TestDifyRetrieval_UseKG verifies a use_kg request with rank features still
// completes with 200.
func TestDifyRetrieval_UseKG(t *testing.T) {
	handlerUnderTest, router := setupDifyTest("user1")
	tagged := func(question string, kbs []*entity.Knowledgebase) map[string]float64 {
		return map[string]float64{"tag_1": 0.8}
	}
	handlerUnderTest.metadataSvc = &mockMetadataService{labelQuestionFn: tagged}

	payload := strings.NewReader(`{"knowledge_id":"kb1","query":"test","use_kg":true}`)
	httpReq, _ := http.NewRequest(http.MethodPost, "/api/v1/dify/retrieval", payload)
	httpReq.Header.Set("Content-Type", "application/json")
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, httpReq)

	if got := rec.Code; got != http.StatusOK {
		t.Errorf("expected 200, got %d: %s", got, rec.Body.String())
	}
}
// strPtr returns a pointer to a fresh copy of s.
func strPtr(s string) *string {
	p := new(string)
	*p = s
	return p
}
// TestDifyRetrieval_KBDBError verifies a non-not-found KB lookup error maps to 500.
func TestDifyRetrieval_KBDBError(t *testing.T) {
	handlerUnderTest, router := setupDifyTest("user1")
	dbDown := func(kbID string) (*entity.Knowledgebase, error) {
		return nil, errors.New("connection refused")
	}
	handlerUnderTest.kbSvc = &mockKBService{getByIDFn: dbDown}

	payload := strings.NewReader(`{"knowledge_id": "kb1", "query": "test"}`)
	httpReq, _ := http.NewRequest(http.MethodPost, "/api/v1/dify/retrieval", payload)
	httpReq.Header.Set("Content-Type", "application/json")
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, httpReq)

	if got := rec.Code; got != http.StatusInternalServerError {
		t.Errorf("expected 500 for DB error, got %d", got)
	}
}
// TestDifyRetrieval_DocLoadError verifies a document-load failure maps to 500.
func TestDifyRetrieval_DocLoadError(t *testing.T) {
	handlerUnderTest, router := setupDifyTest("user1")
	loadFails := func(ids []string) ([]*entity.Document, error) {
		return nil, errors.New("db unavailable")
	}
	handlerUnderTest.docDAO = &mockDocDAO{getByIDsFn: loadFails}

	payload := strings.NewReader(`{"knowledge_id": "kb1", "query": "test"}`)
	httpReq, _ := http.NewRequest(http.MethodPost, "/api/v1/dify/retrieval", payload)
	httpReq.Header.Set("Content-Type", "application/json")
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, httpReq)

	if got := rec.Code; got != http.StatusInternalServerError {
		t.Errorf("expected 500 for doc load error, got %d", got)
	}
}
// TestDifyRetrieval_RetrievalNotFound verifies a retrieval error whose text
// contains "not_found" maps to 404.
func TestDifyRetrieval_RetrievalNotFound(t *testing.T) {
	handlerUnderTest, router := setupDifyTest("user1")
	noChunks := func(ctx context.Context, req *nlp.RetrievalRequest) (*nlp.RetrievalResult, error) {
		return nil, errors.New("no chunk found: not_found")
	}
	handlerUnderTest.retrievalSvc = &mockRetrievalService{retrievalFn: noChunks}

	payload := strings.NewReader(`{"knowledge_id": "kb1", "query": "test"}`)
	httpReq, _ := http.NewRequest(http.MethodPost, "/api/v1/dify/retrieval", payload)
	httpReq.Header.Set("Content-Type", "application/json")
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, httpReq)

	if got := rec.Code; got != http.StatusNotFound {
		t.Errorf("expected 404 for not_found, got %d", got)
	}
}

View File

@@ -43,6 +43,7 @@ type Router struct {
providerHandler *handler.ProviderHandler
agentHandler *handler.AgentHandler
relatedQuestionsHandler *handler.SearchbotHandler
difyRetrievalHandler *handler.DifyRetrievalHandler
}
// NewRouter create router
@@ -67,6 +68,7 @@ func NewRouter(
providerHandler *handler.ProviderHandler,
agentHandler *handler.AgentHandler,
relatedQuestionsHandler *handler.SearchbotHandler,
difyRetrievalHandler *handler.DifyRetrievalHandler,
) *Router {
return &Router{
authHandler: authHandler,
@@ -89,6 +91,7 @@ func NewRouter(
providerHandler: providerHandler,
agentHandler: agentHandler,
relatedQuestionsHandler: relatedQuestionsHandler,
difyRetrievalHandler: difyRetrievalHandler,
}
}
@@ -517,6 +520,14 @@ func (r *Router) Setup(engine *gin.Engine) {
}
// Dify retrieval routes
dify := authorized.Group("/api/v1/dify")
{
dify.POST("/retrieval", r.difyRetrievalHandler.Retrieval)
dify.GET("/retrieval", r.difyRetrievalHandler.Retrieval)
}
apiNoAuth.GET("/dify/retrieval/health", r.difyRetrievalHandler.HealthCheck)
// Handle undefined routes
engine.NoRoute(handler.HandleNoRoute)
}

View File

@@ -14,7 +14,7 @@
// limitations under the License.
//
package service
package kg
import (
"context"
@@ -30,9 +30,9 @@ import (
"go.uber.org/zap"
)
// KGSearchPipeline encapsulates the knowledge graph retrieval pipeline.
// Pipeline encapsulates the knowledge graph retrieval pipeline.
// Matches Python: rag/graphrag/search.py::KGSearch
type KGSearchPipeline struct {
type Pipeline struct {
docEngine engine.DocEngine
chatModel *modelModule.ChatModel
embModel *modelModule.EmbeddingModel
@@ -50,51 +50,51 @@ type KGSearchPipeline struct {
maxToken int
}
// KGSearchOption configures a KGSearchPipeline.
type KGSearchOption func(*KGSearchPipeline)
// Option configures a Pipeline.
type Option func(*Pipeline)
// WithKGSimThreshold sets the similarity threshold for entity and relation search.
// WithSimThreshold sets the similarity threshold for entity and relation search.
// Default: 0.3 (matches Python ent_sim_threshold, rel_sim_threshold).
func WithKGSimThreshold(v float64) KGSearchOption {
return func(p *KGSearchPipeline) { p.entSimThreshold = v; p.relSimThreshold = v }
func WithSimThreshold(v float64) Option {
return func(p *Pipeline) { p.entSimThreshold = v; p.relSimThreshold = v }
}
// WithKGDenseTopK sets the TopK for dense vector search.
// WithDenseTopK sets the TopK for dense vector search.
// Default: 1024 (matches Python get_vector topk).
func WithKGDenseTopK(v int) KGSearchOption {
return func(p *KGSearchPipeline) { p.denseTopK = v }
func WithDenseTopK(v int) Option {
return func(p *Pipeline) { p.denseTopK = v }
}
// NewKGSearchPipeline creates a KG search pipeline with the given dependencies.
// NewPipeline creates a KG search pipeline with the given dependencies.
//
// docEngine: search engine backend
// kbIDs: knowledge base IDs to search
// tenantIDs: tenant IDs (converted to index names internally)
// question: user query string
// opts: optional configuration (WithKGSimThreshold, WithKGDenseTopK)
// opts: optional configuration (WithSimThreshold, WithDenseTopK)
//
// chatModel and embModel are not constructor arguments; set them after
// construction via SetChatModel/SetEmbModel.
func NewKGSearchPipeline(
func NewPipeline(
docEngine engine.DocEngine,
kbIDs []string,
tenantIDs []string,
question string,
opts ...KGSearchOption,
) *KGSearchPipeline {
opts ...Option,
) *Pipeline {
idxnms := make([]string, len(tenantIDs))
for i, tid := range tenantIDs {
idxnms[i] = indexName(tid)
}
p := &KGSearchPipeline{
p := &Pipeline{
docEngine: docEngine,
kbIDs: kbIDs,
idxnms: idxnms,
question: question,
entSimThreshold: defaultKGSimThreshold,
relSimThreshold: defaultKGSimThreshold,
denseTopK: defaultKGDenseTopK,
entSimThreshold: defaultSimThreshold,
relSimThreshold: defaultSimThreshold,
denseTopK: defaultDenseTopK,
entTopN: 6,
relTopN: 6,
commTopN: 1,
@@ -106,12 +106,22 @@ func NewKGSearchPipeline(
return p
}
// SetChatModel sets the chat model for LLM-based query rewrite.
// Retrieval only attempts the LLM rewrite when a chat model is set; otherwise
// it falls back to the raw question.
func (p *Pipeline) SetChatModel(chatModel *modelModule.ChatModel) {
p.chatModel = chatModel
}
// SetEmbModel sets the embedding model for dense/hybrid search.
// It stores embModel on the pipeline; call it before Retrieval.
// NOTE(review): entity and relation search describe the dense vector as
// optional, so a nil model presumably means keyword-only search — confirm.
func (p *Pipeline) SetEmbModel(embModel *modelModule.EmbeddingModel) {
p.embModel = embModel
}
// Retrieval runs the full KG retrieval pipeline and returns a synthetic chunk.
func (p *KGSearchPipeline) Retrieval(ctx context.Context) (map[string]interface{}, error) {
func (p *Pipeline) Retrieval(ctx context.Context) (map[string]interface{}, error) {
// 1. Query rewrite via LLM, or fall back to raw question
ty2entsJSON := ""
if p.chatModel != nil {
typeSamples, err := searchKGTypeSamples(ctx, p.docEngine, p.idxnms, p.kbIDs)
typeSamples, err := searchTypeSamples(ctx, p.docEngine, p.idxnms, p.kbIDs)
if err != nil {
common.Warn("KG type samples search failed", zap.String("kbIDs", fmt.Sprint(p.kbIDs)))
}
@@ -157,11 +167,11 @@ func (p *KGSearchPipeline) Retrieval(ctx context.Context) (map[string]interface{
scoredRels := SortAndTrimRelations(relsFromText, p.relTopN)
// 7. Build KG content with token budget
entsRelsContent := BuildKGContent(scoredEnts, scoredRels, p.maxToken)
entsRelsContent := BuildContent(scoredEnts, scoredRels, p.maxToken)
used := NumTokensFromString(entsRelsContent)
remaining := p.maxToken - used
// 8. Search community reports with remaining token budget
communityContent := searchKGCommunityContent(ctx, p.docEngine, p.idxnms, p.kbIDs, scoredEnts, p.commTopN, &remaining)
communityContent := searchCommunityContent(ctx, p.docEngine, p.idxnms, p.kbIDs, scoredEnts, p.commTopN, &remaining)
// 9. Build synthetic chunk
return map[string]interface{}{
@@ -182,7 +192,7 @@ func (p *KGSearchPipeline) Retrieval(ctx context.Context) (map[string]interface{
}
// searchEntities searches KG entities by keyword text and optional dense vector.
func (p *KGSearchPipeline) searchEntities(ctx context.Context, entities []string) (map[string]*KGEntity, error) {
func (p *Pipeline) searchEntities(ctx context.Context, entities []string) (map[string]*KGEntity, error) {
entsReq := &types.SearchRequest{
IndexNames: p.idxnms,
KbIDs: p.kbIDs,
@@ -207,14 +217,14 @@ func (p *KGSearchPipeline) searchEntities(ctx context.Context, entities []string
if name == "" {
continue
}
e := kgEntityFromChunk(name, chunk)
e := entityFromChunk(name, chunk)
result[name] = &e
}
return result, nil
}
// searchEntityTypes searches KG entities by type keywords.
func (p *KGSearchPipeline) searchEntityTypes(ctx context.Context, typeKeywords []string) map[string]struct{} {
func (p *Pipeline) searchEntityTypes(ctx context.Context, typeKeywords []string) map[string]struct{} {
typesReq := &types.SearchRequest{
IndexNames: p.idxnms,
KbIDs: p.kbIDs,
@@ -244,7 +254,7 @@ func (p *KGSearchPipeline) searchEntityTypes(ctx context.Context, typeKeywords [
}
// searchRelations searches KG relations by entity text and optional dense vector.
func (p *KGSearchPipeline) searchRelations(ctx context.Context, entities []string) map[Edge]*KGRelation {
func (p *Pipeline) searchRelations(ctx context.Context, entities []string) map[Edge]*KGRelation {
relsReq := &types.SearchRequest{
IndexNames: p.idxnms,
KbIDs: p.kbIDs,
@@ -265,7 +275,7 @@ func (p *KGSearchPipeline) searchRelations(ctx context.Context, entities []strin
common.Warn("KG relations search failed", zap.String("kbIDs", fmt.Sprint(p.kbIDs)))
} else {
for _, chunk := range FilterChunksByScore(relsResult.Chunks, p.relSimThreshold) {
edge, rel := kgRelationFromChunk(chunk)
edge, rel := relationFromChunk(chunk)
if edge.From == "" || edge.To == "" {
continue
}

View File

@@ -1,20 +1,4 @@
//
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
package service
package kg
import (
"context"
@@ -27,67 +11,9 @@ import (
modelModule "ragflow/internal/entity/models"
)
// indexName builds the search index name from a tenant ID.
// Matches Python: rag/nlp/search.py::index_name()
func indexName(tenantID string) string {
return "ragflow_" + tenantID
}
// Python alignment defaults — match rag/graphrag/search.py retrieval() params
const (
defaultKGSimThreshold = 0.3 // Python: ent_sim_threshold, rel_sim_threshold
defaultKGDenseTopK = 1024 // Python: get_vector() topk
)
// kgEntityFromChunk parses a single entity chunk into a KGEntity.
func kgEntityFromChunk(name string, chunk map[string]interface{}) KGEntity {
e := KGEntity{}
if v, ok := chunk["_score"].(float64); ok {
e.Similarity = v
} else if v, ok := chunk["score"].(float64); ok {
e.Similarity = v
}
if v, ok := chunk["rank_flt"].(float64); ok {
e.PageRank = v
}
e.Description, _ = chunk["content_with_weight"].(string)
if raw, ok := chunk["n_hop_with_weight"].(string); ok && raw != "" {
var nhopData []struct {
Path []string `json:"path"`
Weights []float64 `json:"weights"`
}
if err := json.Unmarshal([]byte(raw), &nhopData); err == nil {
for _, item := range nhopData {
e.NhopEnts = append(e.NhopEnts, NhopEntity{
Path: item.Path,
Weights: item.Weights,
})
}
}
}
return e
}
// kgRelationFromChunk parses a single relation chunk into a KGRelation.
func kgRelationFromChunk(chunk map[string]interface{}) (Edge, KGRelation) {
r := KGRelation{}
r.Description, _ = chunk["content_with_weight"].(string)
if v, ok := chunk["weight_int"].(float64); ok {
r.PageRank = float64(v)
} else if v, ok := chunk["weight_int"].(int); ok {
r.PageRank = float64(v)
}
from, _ := chunk["from_entity_kwd"].(string)
to, _ := chunk["to_entity_kwd"].(string)
return Edge{From: from, To: to}, r
}
// KGSearchRetrieval performs a full knowledge graph retrieval and returns
// a synthetic chunk to be inserted into search results.
// Corresponds to Python: rag/graphrag/search.py::KGSearch.retrieval()
//
// This is a convenience wrapper around KGSearchPipeline.
func KGSearchRetrieval(
// Retrieval performs a full knowledge graph retrieval and returns
// a synthetic chunk. Convenience wrapper around Pipeline.
func Retrieval(
ctx context.Context,
docEngine engine.DocEngine,
chatModel *modelModule.ChatModel,
@@ -96,16 +22,16 @@ func KGSearchRetrieval(
tenantIDs []string,
question string,
) (map[string]interface{}, error) {
p := &KGSearchPipeline{
p := &Pipeline{
docEngine: docEngine,
chatModel: chatModel,
embModel: embModel,
kbIDs: kbIDs,
idxnms: makeIndexNames(tenantIDs),
question: question,
entSimThreshold: defaultKGSimThreshold,
relSimThreshold: defaultKGSimThreshold,
denseTopK: defaultKGDenseTopK,
entSimThreshold: defaultSimThreshold,
relSimThreshold: defaultSimThreshold,
denseTopK: defaultDenseTopK,
entTopN: 6,
relTopN: 6,
commTopN: 1,
@@ -123,8 +49,13 @@ func makeIndexNames(tenantIDs []string) []string {
return idxnms
}
// searchKGTypeSamples searches for ty2ents data.
func searchKGTypeSamples(ctx context.Context, docEngine engine.DocEngine, idxnms []string, kbIDs []string) (map[string][]string, error) {
// indexName builds the search index name from a tenant ID.
func indexName(tenantID string) string {
return "ragflow_" + tenantID
}
// searchTypeSamples searches for ty2ents data.
func searchTypeSamples(ctx context.Context, docEngine engine.DocEngine, idxnms []string, kbIDs []string) (map[string][]string, error) {
req := &types.SearchRequest{
IndexNames: idxnms,
KbIDs: kbIDs,
@@ -153,8 +84,8 @@ func searchKGTypeSamples(ctx context.Context, docEngine engine.DocEngine, idxnms
return typeMap, nil
}
// searchKGCommunityContent searches for community reports and formats them.
func searchKGCommunityContent(ctx context.Context, docEngine engine.DocEngine, idxnms []string, kbIDs []string, scoredEnts []ScoredEntity, topN int, maxToken *int) string {
// searchCommunityContent searches for community reports and formats them.
func searchCommunityContent(ctx context.Context, docEngine engine.DocEngine, idxnms []string, kbIDs []string, scoredEnts []ScoredEntity, topN int, maxToken *int) string {
if maxToken == nil || len(scoredEnts) == 0 || *maxToken <= 0 {
return ""
}
@@ -189,7 +120,6 @@ func searchKGCommunityContent(ctx context.Context, docEngine engine.DocEngine, i
if title == "" && raw == "" {
continue
}
// Parse JSON for nested report/evidences fields (Python: json.loads)
report := raw
evidence := ""
var parsed map[string]interface{}
@@ -212,30 +142,52 @@ func searchKGCommunityContent(ctx context.Context, docEngine engine.DocEngine, i
return bld
}
// buildMatchDenseExpr constructs a MatchDenseExpr from an embedding vector.
// This is a pure function — no I/O, no external dependencies.
func buildMatchDenseExpr(vector []float64, topN int, similarity float64) *types.MatchDenseExpr {
vectorColumnName := fmt.Sprintf("q_%d_vec", len(vector))
return &types.MatchDenseExpr{
VectorColumnName: vectorColumnName,
EmbeddingData: vector,
EmbeddingDataType: "float",
DistanceType: "cosine",
TopN: topN,
ExtraOptions: map[string]interface{}{"similarity": similarity},
// entityFromChunk parses a single entity chunk into a KGEntity.
func entityFromChunk(name string, chunk map[string]interface{}) KGEntity {
e := KGEntity{}
if v, ok := chunk["_score"].(float64); ok {
e.Similarity = v
} else if v, ok := chunk["score"].(float64); ok {
e.Similarity = v
}
if v, ok := chunk["rank_flt"].(float64); ok {
e.PageRank = v
}
e.Description, _ = chunk["content_with_weight"].(string)
if raw, ok := chunk["n_hop_with_weight"].(string); ok && raw != "" {
var nhopData []struct {
Path []string `json:"path"`
Weights []float64 `json:"weights"`
}
if err := json.Unmarshal([]byte(raw), &nhopData); err == nil {
for _, item := range nhopData {
e.NhopEnts = append(e.NhopEnts, NhopEntity{
Path: item.Path,
Weights: item.Weights,
})
}
}
}
return e
}
// buildFusionExpr constructs a FusionExpr for weighted-sum hybrid search.
// This is a pure function — no I/O, no external dependencies.
func buildFusionExpr(textWeight, vectorWeight float64, topN int) *types.FusionExpr {
return &types.FusionExpr{
Method: "weighted_sum",
TopN: topN,
FusionParams: map[string]interface{}{
"weights": fmt.Sprintf("%.2f,%.2f", textWeight, vectorWeight),
},
// relationFromChunk parses a single relation chunk into a KGRelation.
func relationFromChunk(chunk map[string]interface{}) (Edge, KGRelation) {
r := KGRelation{}
r.Description, _ = chunk["content_with_weight"].(string)
if v, ok := chunk["_score"].(float64); ok {
r.Sim = v
} else if v, ok := chunk["score"].(float64); ok {
r.Sim = v
}
if v, ok := chunk["weight_int"].(float64); ok {
r.PageRank = float64(v)
} else if v, ok := chunk["weight_int"].(int); ok {
r.PageRank = float64(v)
}
from, _ := chunk["from_entity_kwd"].(string)
to, _ := chunk["to_entity_kwd"].(string)
return Edge{From: from, To: to}, r
}
// buildSearchExprs constructs MatchExprs for KG entity/relation search.
@@ -252,12 +204,35 @@ func buildSearchExprs(embModel *modelModule.EmbeddingModel, matchText *types.Mat
return []interface{}{matchText}
}
denseExpr := buildMatchDenseExpr(embeddings[0].Embedding, denseTopK, simThreshold)
fusionExpr := buildFusionExpr(0.5, 0.5, matchText.TopN)
fusionExpr := buildFusionExpr(defaultTextWeight, defaultVectorWeight, matchText.TopN)
return []interface{}{matchText, denseExpr, fusionExpr}
}
// buildMatchDenseExpr constructs a MatchDenseExpr from an embedding vector.
func buildMatchDenseExpr(vector []float64, topN int, similarity float64) *types.MatchDenseExpr {
vectorColumnName := fmt.Sprintf("q_%d_vec", len(vector))
return &types.MatchDenseExpr{
VectorColumnName: vectorColumnName,
EmbeddingData: vector,
EmbeddingDataType: "float",
DistanceType: "cosine",
TopN: topN,
ExtraOptions: map[string]interface{}{"similarity": similarity},
}
}
// buildFusionExpr constructs a FusionExpr for weighted-sum hybrid search.
func buildFusionExpr(textWeight, vectorWeight float64, topN int) *types.FusionExpr {
return &types.FusionExpr{
Method: "weighted_sum",
TopN: topN,
FusionParams: map[string]interface{}{
"weights": fmt.Sprintf("%.2f,%.2f", textWeight, vectorWeight),
},
}
}
// queryRewrite attempts LLM-based query rewrite, falling back to raw question.
// ty2entsJSON is the JSON-encoded type→entities mapping for prompt context.
func queryRewrite(chatModel *modelModule.ChatModel, question string, ty2entsJSON string) (typeKeywords, entities []string) {
if question == "" {
return nil, nil
@@ -276,6 +251,14 @@ func queryRewrite(chatModel *modelModule.ChatModel, question string, ty2entsJSON
}
}
}
// Fallback: use raw question as single entity
return nil, []string{question}
}
// Python alignment defaults
const (
defaultSimThreshold = 0.3
defaultDenseTopK = 1024
// defaultTextWeight / defaultVectorWeight are fusion weights for hybrid search (equal by default).
defaultTextWeight = 0.5
defaultVectorWeight = 0.5
)

View File

@@ -14,7 +14,7 @@
// limitations under the License.
//
package service
package kg
import (
"context"
@@ -56,16 +56,16 @@ func (m *mockRetrievalEngine) Search(ctx context.Context, req *types.SearchReque
return &types.SearchResult{}, nil
}
// --- kgEntityFromChunk ---
// --- entityFromChunk ---
func TestKgEntityFromChunk_Basic(t *testing.T) {
func TestEntityFromChunk_Basic(t *testing.T) {
chunk := map[string]interface{}{
"_score": 0.85,
"rank_flt": 0.9,
"content_with_weight": "Founder of SpaceX",
"n_hop_with_weight": `[{"path":["A","B"],"weights":[0.8]}]`,
}
e := kgEntityFromChunk("Elon Musk", chunk)
e := entityFromChunk("Elon Musk", chunk)
if e.Similarity != 0.85 {
t.Errorf("expected Sim=0.85, got %f", e.Similarity)
}
@@ -80,32 +80,32 @@ func TestKgEntityFromChunk_Basic(t *testing.T) {
}
}
func TestKgEntityFromChunk_ScoreFallback(t *testing.T) {
func TestEntityFromChunk_ScoreFallback(t *testing.T) {
chunk := map[string]interface{}{"score": 0.75}
e := kgEntityFromChunk("Test", chunk)
e := entityFromChunk("Test", chunk)
if e.Similarity != 0.75 {
t.Errorf("expected Sim=0.75 from score field, got %f", e.Similarity)
}
}
func TestKgEntityFromChunk_MissingFields(t *testing.T) {
func TestEntityFromChunk_MissingFields(t *testing.T) {
chunk := map[string]interface{}{}
e := kgEntityFromChunk("Empty", chunk)
e := entityFromChunk("Empty", chunk)
if e.Similarity != 0 || e.PageRank != 0 || len(e.NhopEnts) != 0 {
t.Errorf("expected zero defaults, got %+v", e)
}
}
// --- kgRelationFromChunk ---
// --- relationFromChunk ---
func TestKgRelationFromChunk_Basic(t *testing.T) {
func TestRelationFromChunk_Basic(t *testing.T) {
chunk := map[string]interface{}{
"from_entity_kwd": "Elon Musk",
"to_entity_kwd": "SpaceX",
"weight_int": float64(5),
"content_with_weight": "Founder",
}
edge, rel := kgRelationFromChunk(chunk)
edge, rel := relationFromChunk(chunk)
if edge.From != "Elon Musk" || edge.To != "SpaceX" {
t.Errorf("expected Elon Musk→SpaceX, got %v", edge)
}
@@ -114,17 +114,17 @@ func TestKgRelationFromChunk_Basic(t *testing.T) {
}
}
func TestKgRelationFromChunk_MissingFrom(t *testing.T) {
func TestRelationFromChunk_MissingFrom(t *testing.T) {
chunk := map[string]interface{}{"to_entity_kwd": "B"}
edge, _ := kgRelationFromChunk(chunk)
edge, _ := relationFromChunk(chunk)
if edge.From != "" {
t.Error("expected empty from")
}
}
// --- searchKGTypeSamples ---
// --- searchTypeSamples ---
func TestSearchKGTypeSamples_Success(t *testing.T) {
func TestSearchTypeSamples_Success(t *testing.T) {
data, _ := json.Marshal(map[string][]string{"PERSON": {"Elon Musk"}})
mock := &mockRetrievalEngine{
results: map[string]*types.SearchResult{
@@ -133,7 +133,7 @@ func TestSearchKGTypeSamples_Success(t *testing.T) {
}},
},
}
result, err := searchKGTypeSamples(context.Background(), mock, []string{"ragflow_tenant1"}, []string{"kb1"})
result, err := searchTypeSamples(context.Background(), mock, []string{"ragflow_tenant1"}, []string{"kb1"})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
@@ -142,9 +142,9 @@ func TestSearchKGTypeSamples_Success(t *testing.T) {
}
}
func TestSearchKGTypeSamples_Empty(t *testing.T) {
func TestSearchTypeSamples_Empty(t *testing.T) {
mock := &mockRetrievalEngine{}
result, err := searchKGTypeSamples(context.Background(), mock, []string{"ragflow_tenant1"}, []string{"kb1"})
result, err := searchTypeSamples(context.Background(), mock, []string{"ragflow_tenant1"}, []string{"kb1"})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
@@ -153,9 +153,9 @@ func TestSearchKGTypeSamples_Empty(t *testing.T) {
}
}
// --- KGSearchRetrieval ---
// --- Retrieval ---
func TestKGSearchRetrieval_Basic(t *testing.T) {
func TestRetrieval_Basic(t *testing.T) {
mock := &mockRetrievalEngine{
results: map[string]*types.SearchResult{
"entity": {Chunks: []map[string]interface{}{
@@ -172,9 +172,9 @@ func TestKGSearchRetrieval_Basic(t *testing.T) {
}},
},
}
result, err := KGSearchRetrieval(context.Background(), mock, nil, nil, []string{"kb1"}, []string{"tenant1"}, "Elon Musk")
result, err := Retrieval(context.Background(), mock, nil, nil, []string{"kb1"}, []string{"tenant1"}, "Elon Musk")
if err != nil {
t.Fatalf("KGSearchRetrieval failed: %v", err)
t.Fatalf("Retrieval failed: %v", err)
}
if result == nil {
t.Fatal("expected non-nil result")
@@ -194,11 +194,11 @@ func TestKGSearchRetrieval_Basic(t *testing.T) {
}
}
func TestKGSearchRetrieval_NoEntities(t *testing.T) {
func TestRetrieval_NoEntities(t *testing.T) {
mock := &mockRetrievalEngine{}
result, err := KGSearchRetrieval(context.Background(), mock, nil, nil, []string{"kb1"}, []string{"tenant1"}, "test")
result, err := Retrieval(context.Background(), mock, nil, nil, []string{"kb1"}, []string{"tenant1"}, "test")
if err != nil {
t.Fatalf("KGSearchRetrieval failed: %v", err)
t.Fatalf("Retrieval failed: %v", err)
}
if result == nil {
t.Fatal("expected non-nil result")
@@ -211,7 +211,7 @@ func TestKGSearchRetrieval_NoEntities(t *testing.T) {
// TestRetrieval_WithChatModel verifies that Retrieval succeeds when a chat model is supplied, exercising the ty2entsJSON construction path.
func TestKGSearchRetrieval_WithChatModel(t *testing.T) {
func TestRetrieval_WithChatModel(t *testing.T) {
mock := &mockRetrievalEngine{
results: map[string]*types.SearchResult{
"entity": {Chunks: []map[string]interface{}{
@@ -225,9 +225,9 @@ func TestKGSearchRetrieval_WithChatModel(t *testing.T) {
// chatModel with nil ModelName so queryRewrite falls back to raw question,
// but the ty2entsJSON construction path is still exercised.
chatModel := &modelModule.ChatModel{ModelName: nil, APIConfig: nil}
result, err := KGSearchRetrieval(context.Background(), mock, chatModel, nil, []string{"kb1"}, []string{"tenant1"}, "Elon Musk")
result, err := Retrieval(context.Background(), mock, chatModel, nil, []string{"kb1"}, []string{"tenant1"}, "Elon Musk")
if err != nil {
t.Fatalf("KGSearchRetrieval failed: %v", err)
t.Fatalf("Retrieval failed: %v", err)
}
if result == nil {
t.Fatal("expected non-nil result")
@@ -415,7 +415,7 @@ func TestBuildSearchExprs_WithEmbModel(t *testing.T) {
MatchingText: "Elon Musk SpaceX",
TopN: 50,
}
exprs := buildSearchExprs(embModel, matchText, defaultKGSimThreshold, defaultKGDenseTopK)
exprs := buildSearchExprs(embModel, matchText, defaultSimThreshold, defaultDenseTopK)
// Verify Embed was called with matchText.MatchingText, not raw question
if len(driver.capturedTexts) != 1 || driver.capturedTexts[0] != "Elon Musk SpaceX" {
t.Errorf("expected Embed to receive %q, got %v", "Elon Musk SpaceX", driver.capturedTexts)
@@ -439,11 +439,11 @@ func TestBuildSearchExprs_WithEmbModel(t *testing.T) {
if md.VectorColumnName != "q_3_vec" {
t.Errorf("expected q_3_vec, got %q", md.VectorColumnName)
}
if md.TopN != defaultKGDenseTopK {
t.Errorf("expected TopN=%d (Python alignment), got %d", defaultKGDenseTopK, md.TopN)
if md.TopN != defaultDenseTopK {
t.Errorf("expected TopN=%d (Python alignment), got %d", defaultDenseTopK, md.TopN)
}
if md.ExtraOptions["similarity"] != defaultKGSimThreshold {
t.Errorf("expected similarity=%v (Python alignment), got %v", defaultKGSimThreshold, md.ExtraOptions["similarity"])
if md.ExtraOptions["similarity"] != defaultSimThreshold {
t.Errorf("expected similarity=%v (Python alignment), got %v", defaultSimThreshold, md.ExtraOptions["similarity"])
}
// Index 2: FusionExpr
fu, ok := exprs[2].(*types.FusionExpr)
@@ -463,7 +463,7 @@ func TestBuildSearchExprs_EmbModelFallback(t *testing.T) {
MatchingText: "fallback test",
TopN: 10,
}
exprs := buildSearchExprs(embModel, matchText, defaultKGSimThreshold, defaultKGDenseTopK)
exprs := buildSearchExprs(embModel, matchText, defaultSimThreshold, defaultDenseTopK)
// Should fall back to text-only when Embed fails
if len(exprs) != 1 {
t.Fatalf("expected 1 expr (text-only fallback), got %d", len(exprs))
@@ -476,11 +476,11 @@ func TestBuildSearchExprs_EmbModelFallback(t *testing.T) {
// --- Python alignment defaults ---
func TestDefaultValuesMatchPython(t *testing.T) {
if defaultKGSimThreshold != 0.3 {
t.Errorf("expected 0.3 (Python ent_sim_threshold), got %f", defaultKGSimThreshold)
if defaultSimThreshold != 0.3 {
t.Errorf("expected 0.3 (Python ent_sim_threshold), got %f", defaultSimThreshold)
}
if defaultKGDenseTopK != 1024 {
t.Errorf("expected 1024 (Python get_vector topk), got %d", defaultKGDenseTopK)
if defaultDenseTopK != 1024 {
t.Errorf("expected 1024 (Python get_vector topk), got %d", defaultDenseTopK)
}
}
@@ -506,11 +506,11 @@ func TestIndexName_Empty(t *testing.T) {
}
}
// --- searchKGCommunityContent ---
// --- searchCommunityContent ---
func TestSearchKGCommunityContent_EmptyEntities(t *testing.T) {
mock := &mockRetrievalEngine{}
result := searchKGCommunityContent(context.Background(), mock, []string{"ragflow_t1"}, []string{"kb1"}, nil, 1, intPtr(100))
result := searchCommunityContent(context.Background(), mock, []string{"ragflow_t1"}, []string{"kb1"}, nil, 1, intPtr(100))
if result != "" {
t.Errorf("expected empty, got %q", result)
}
@@ -527,7 +527,7 @@ func TestSearchKGCommunityContent_WithContent(t *testing.T) {
}},
},
}
result := searchKGCommunityContent(context.Background(), mock, []string{"ragflow_t1"}, []string{"kb1"}, []ScoredEntity{{Entity: "E1"}}, 1, intPtr(500))
result := searchCommunityContent(context.Background(), mock, []string{"ragflow_t1"}, []string{"kb1"}, []ScoredEntity{{Entity: "E1"}}, 1, intPtr(500))
if result == "" {
t.Fatal("expected non-empty result")
}
@@ -547,7 +547,7 @@ func TestSearchKGCommunityContent_WithContent(t *testing.T) {
func TestSearchKGCommunityContent_NilMaxToken(t *testing.T) {
mock := &mockRetrievalEngine{}
result := searchKGCommunityContent(context.Background(), mock, []string{"ragflow_t1"}, []string{"kb1"}, []ScoredEntity{{Entity: "E1"}}, 1, nil)
result := searchCommunityContent(context.Background(), mock, []string{"ragflow_t1"}, []string{"kb1"}, []ScoredEntity{{Entity: "E1"}}, 1, nil)
if result != "" {
t.Errorf("expected empty when maxToken is nil, got %q", result)
}
@@ -555,7 +555,7 @@ func TestSearchKGCommunityContent_NilMaxToken(t *testing.T) {
func TestSearchKGCommunityContent_ZeroMaxToken(t *testing.T) {
mock := &mockRetrievalEngine{}
result := searchKGCommunityContent(context.Background(), mock, []string{"ragflow_t1"}, []string{"kb1"}, []ScoredEntity{{Entity: "E1"}}, 1, intPtr(0))
result := searchCommunityContent(context.Background(), mock, []string{"ragflow_t1"}, []string{"kb1"}, []ScoredEntity{{Entity: "E1"}}, 1, intPtr(0))
if result != "" {
t.Errorf("expected empty when maxToken=0, got %q", result)
}

View File

@@ -14,7 +14,7 @@
// limitations under the License.
//
package service
package kg
import (
"bytes"
@@ -227,9 +227,9 @@ func FormatRelationsToCSV(relations []ScoredRelation, maxToken int) (csv string,
return b.String(), maxToken
}
// BuildKGContent assembles the final knowledge graph content string.
// BuildContent assembles the final knowledge graph content string.
// Python equivalent: lines 267-291
func BuildKGContent(
func BuildContent(
entities []ScoredEntity,
relations []ScoredRelation,
maxToken int,

View File

@@ -0,0 +1,306 @@
package kg
import (
"context"
"fmt"
"ragflow/internal/engine"
"encoding/json"
"ragflow/internal/engine/types"
modelModule "ragflow/internal/entity/models"
)
// NhopEntityNames extracts the unique entity names that appear in the "path"
// arrays of an n_hop_with_weight JSON string.
//
// The input is expected to be a JSON array of objects shaped like
// {"path": [...], "weights": [...]}. Names are returned in order of first
// appearance, so the result is deterministic for a given input (ranging over
// a map would yield a different order on every call).
//
// It returns nil when nHopJSON is empty or cannot be decoded, and an empty
// non-nil slice when the JSON decodes but contains no path entries.
func NhopEntityNames(nHopJSON string) []string {
	if nHopJSON == "" {
		return nil
	}
	// Weights is decoded even though it is not used, so that a payload with a
	// malformed "weights" field is still rejected as a whole.
	var nhopData []struct {
		Path    []string  `json:"path"`
		Weights []float64 `json:"weights"`
	}
	if err := json.Unmarshal([]byte(nHopJSON), &nhopData); err != nil {
		return nil
	}
	seen := make(map[string]struct{})
	names := []string{}
	for _, item := range nhopData {
		for _, name := range item.Path {
			if _, dup := seen[name]; dup {
				continue
			}
			seen[name] = struct{}{}
			names = append(names, name)
		}
	}
	return names
}
// SearchEntities looks up KG entities relevant to question within kbIDs.
//
// The question is embedded via embModel (when one is provided) to build a
// dense expression, an entity search request is assembled from it, and the
// chunks returned by docEngine are parsed into KGEntity values. An embedding
// error is returned as-is; a search error is wrapped with context.
func SearchEntities(ctx context.Context, docEngine engine.DocEngine, kbIDs []string, question string, embModel *modelModule.EmbeddingModel, topN int) ([]KGEntity, error) {
	denseExpr, embedErr := buildDenseExpr(embModel, question, topN)
	if embedErr != nil {
		return nil, embedErr
	}

	hits, searchErr := docEngine.Search(ctx, buildEntitySearchRequest(kbIDs, question, denseExpr, topN))
	if searchErr != nil {
		return nil, fmt.Errorf("KG entity search failed: %w", searchErr)
	}

	return ParseEntityChunks(hits.Chunks), nil
}
// SearchEntitiesByTypes looks up KG entities in kbIDs whose type matches one
// of typeKeywords, returning at most topN parsed entities. A search error is
// wrapped with context.
func SearchEntitiesByTypes(ctx context.Context, docEngine engine.DocEngine, kbIDs []string, typeKeywords []string, topN int) ([]KGEntity, error) {
	hits, searchErr := docEngine.Search(ctx, buildEntityTypeSearchRequest(kbIDs, typeKeywords, topN))
	if searchErr != nil {
		return nil, fmt.Errorf("KG entity type search failed: %w", searchErr)
	}

	return ParseEntityChunks(hits.Chunks), nil
}
// SearchRelations looks up KG relations relevant to question within kbIDs.
//
// The question is embedded via embModel (when one is provided) to build a
// dense expression, a relation search request is assembled from it, and the
// chunks returned by docEngine are parsed into KGRelation values. An embedding
// error is returned as-is; a search error is wrapped with context.
func SearchRelations(ctx context.Context, docEngine engine.DocEngine, kbIDs []string, question string, embModel *modelModule.EmbeddingModel, topN int) ([]KGRelation, error) {
	denseExpr, embedErr := buildDenseExpr(embModel, question, topN)
	if embedErr != nil {
		return nil, embedErr
	}

	hits, searchErr := docEngine.Search(ctx, buildRelationSearchRequest(kbIDs, question, denseExpr, topN))
	if searchErr != nil {
		return nil, fmt.Errorf("KG relation search failed: %w", searchErr)
	}

	return ParseRelationChunks(hits.Chunks), nil
}
// SearchCommunityReports fetches up to topN community reports in kbIDs that
// are associated with any of entityNames and parses them into
// KGCommunityReport values. A search error is wrapped with context.
func SearchCommunityReports(ctx context.Context, docEngine engine.DocEngine, kbIDs []string, entityNames []string, topN int) ([]KGCommunityReport, error) {
	hits, searchErr := docEngine.Search(ctx, buildCommunitySearchRequest(kbIDs, entityNames, topN))
	if searchErr != nil {
		return nil, fmt.Errorf("KG community search failed: %w", searchErr)
	}

	return ParseCommunityReportChunks(hits.Chunks), nil
}
// SearchTypeSamples retrieves the type→entities mapping stored for kbIDs by
// querying docEngine for "ty2ents" chunks and merging their decoded contents.
// A search error is returned unwrapped.
func SearchTypeSamples(ctx context.Context, docEngine engine.DocEngine, kbIDs []string) (map[string][]string, error) {
	hits, searchErr := docEngine.Search(ctx, buildTypeSamplesSearchRequest(kbIDs))
	if searchErr != nil {
		return nil, searchErr
	}

	return ParseTypeSamplesChunks(hits.Chunks), nil
}
// buildDenseExpr embeds question with embModel and wraps the resulting vector
// in a cosine-distance MatchDenseExpr limited to topN hits.
//
// It returns (nil, nil) — meaning "no dense component" — when embModel is nil,
// when question is empty, or when the driver yields no usable embedding. A
// driver failure is returned as a wrapped error.
//
// NOTE(review): the 0.3 similarity below has the same value as
// defaultSimThreshold declared elsewhere in this package; consider using the
// constant so the two cannot drift apart.
func buildDenseExpr(embModel *modelModule.EmbeddingModel, question string, topN int) (*types.MatchDenseExpr, error) {
	if question == "" || embModel == nil {
		return nil, nil
	}

	vectors, err := embModel.ModelDriver.Embed(
		embModel.ModelName,
		[]string{question},
		embModel.APIConfig,
		&modelModule.EmbeddingConfig{Dimension: 0},
	)
	if err != nil {
		return nil, fmt.Errorf("KG entity embed failed: %w", err)
	}
	if len(vectors) == 0 || len(vectors[0].Embedding) == 0 {
		return nil, nil
	}

	queryVec := vectors[0].Embedding
	// The vector column name encodes the embedding dimension, e.g. "q_768_vec".
	columnName := fmt.Sprintf("q_%d_vec", len(queryVec))
	denseExpr := &types.MatchDenseExpr{
		VectorColumnName:  columnName,
		EmbeddingData:     queryVec,
		EmbeddingDataType: "float",
		DistanceType:      "cosine",
		TopN:              topN,
		ExtraOptions:      map[string]interface{}{"similarity": 0.3},
	}
	return denseExpr, nil
}
// buildHybridExpr assembles the match expressions for a search.
//
// With a dense expression available it returns {dense, text, fusion}, where
// fusion is a weighted-sum expression built from the package default weights.
// Without one (dense == nil) it falls back to text-only matching.
//
// NOTE(review): buildSearchExprs in this package emits the text expression
// before the dense one; confirm the engine does not depend on that order when
// applying the "text,vector" fusion weights.
func buildHybridExpr(dense *types.MatchDenseExpr, text *types.MatchTextExpr, topN int) []interface{} {
	if dense != nil {
		return []interface{}{
			dense,
			text,
			buildFusionExpr(defaultTextWeight, defaultVectorWeight, topN),
		}
	}
	return []interface{}{text}
}
// buildEntitySearchRequest builds the SearchRequest used to find KG entities
// in kbIDs. Results are restricted to chunks tagged as "entity" and capped at
// topN. When question is non-empty, a text match over the entity name and
// content fields is added and combined with dense (if non-nil) via
// buildHybridExpr; otherwise the request carries no match expressions.
func buildEntitySearchRequest(kbIDs []string, question string, dense *types.MatchDenseExpr, topN int) *types.SearchRequest {
	searchReq := &types.SearchRequest{
		KbIDs:        kbIDs,
		SelectFields: []string{"entity_kwd", "entity_type_kwd", "rank_flt", "content_with_weight", "n_hop_with_weight", "_score"},
		Limit:        topN,
		Filter:       map[string]interface{}{"knowledge_graph_kwd": "entity"},
	}
	if question == "" {
		return searchReq
	}

	keywordMatch := &types.MatchTextExpr{
		Fields:       []string{"entity_kwd^10", "content_ltks^2"},
		MatchingText: question,
		TopN:         topN,
	}
	searchReq.MatchExprs = buildHybridExpr(dense, keywordMatch, topN)
	return searchReq
}
// buildEntityTypeSearchRequest builds the SearchRequest used to list KG
// entities in kbIDs by type. Results are restricted to chunks tagged as
// "entity" and capped at topN; when typeKeywords is non-empty, an additional
// "entity_type_kwd" filter holding those keywords is applied.
func buildEntityTypeSearchRequest(kbIDs []string, typeKeywords []string, topN int) *types.SearchRequest {
	searchReq := &types.SearchRequest{
		KbIDs:        kbIDs,
		SelectFields: []string{"entity_kwd", "entity_type_kwd"},
		Limit:        topN,
		Filter:       map[string]interface{}{"knowledge_graph_kwd": "entity"},
	}
	if count := len(typeKeywords); count > 0 {
		// The filter value is stored as a []interface{} rather than []string.
		typeValues := make([]interface{}, 0, count)
		for _, keyword := range typeKeywords {
			typeValues = append(typeValues, keyword)
		}
		searchReq.Filter["entity_type_kwd"] = typeValues
	}
	return searchReq
}
// buildRelationSearchRequest builds the SearchRequest used to find KG
// relations in kbIDs. Results are restricted to chunks tagged as "relation"
// and capped at topN. When question is non-empty, a text match over the
// content and endpoint-entity fields is added and combined with dense (if
// non-nil) via buildHybridExpr; otherwise the request carries no match
// expressions.
func buildRelationSearchRequest(kbIDs []string, question string, dense *types.MatchDenseExpr, topN int) *types.SearchRequest {
	searchReq := &types.SearchRequest{
		KbIDs:        kbIDs,
		SelectFields: []string{"from_entity_kwd", "to_entity_kwd", "weight_int", "content_with_weight", "_score"},
		Limit:        topN,
		Filter:       map[string]interface{}{"knowledge_graph_kwd": "relation"},
	}
	if question == "" {
		return searchReq
	}

	keywordMatch := &types.MatchTextExpr{
		Fields:       []string{"content_ltks", "from_entity_kwd", "to_entity_kwd"},
		MatchingText: question,
		TopN:         topN,
	}
	searchReq.MatchExprs = buildHybridExpr(dense, keywordMatch, topN)
	return searchReq
}
// buildCommunitySearchRequest builds the SearchRequest used to fetch KG
// community reports in kbIDs. Results are restricted to chunks tagged as
// "community_report", ordered by "weight_flt" descending, and capped at topN;
// when entityNames is non-empty, an additional "entities_kwd" filter holding
// those names is applied.
func buildCommunitySearchRequest(kbIDs []string, entityNames []string, topN int) *types.SearchRequest {
	searchReq := &types.SearchRequest{
		KbIDs:        kbIDs,
		SelectFields: []string{"docnm_kwd", "content_with_weight", "weight_flt", "entities_kwd"},
		Limit:        topN,
		Filter:       map[string]interface{}{"knowledge_graph_kwd": "community_report"},
		OrderBy:      (&types.OrderByExpr{}).Desc("weight_flt"),
	}
	if count := len(entityNames); count > 0 {
		// The filter value is stored as a []interface{} rather than []string.
		entityValues := make([]interface{}, 0, count)
		for _, entityName := range entityNames {
			entityValues = append(entityValues, entityName)
		}
		searchReq.Filter["entities_kwd"] = entityValues
	}
	return searchReq
}
// buildTypeSamplesSearchRequest builds the SearchRequest that fetches the
// "ty2ents" chunks for kbIDs. Only the content field is selected, and up to
// 10000 chunks are requested.
func buildTypeSamplesSearchRequest(kbIDs []string) *types.SearchRequest {
	searchReq := &types.SearchRequest{
		KbIDs:        kbIDs,
		SelectFields: []string{"content_with_weight"},
		Limit:        10000,
		Filter:       map[string]interface{}{"knowledge_graph_kwd": "ty2ents"},
	}
	return searchReq
}
// ParseEntityChunks turns raw search-result chunks into KGEntity values.
//
// The entity name is read from "entity_kwd", which may be either a string or
// a list whose first element is a string; chunks without a usable name are
// skipped. Similarity is taken from "_score", falling back to "score". Fields
// that are missing or of an unexpected type are left at their zero value. The
// result is nil when no chunk yields an entity.
func ParseEntityChunks(chunks []map[string]interface{}) []KGEntity {
	var parsed []KGEntity
	for _, hit := range chunks {
		var entityName string
		switch rawName := hit["entity_kwd"].(type) {
		case string:
			entityName = rawName
		case []interface{}:
			// Some results carry the keyword field as a list; use its head.
			if len(rawName) > 0 {
				entityName, _ = rawName[0].(string)
			}
		}
		if entityName == "" {
			continue
		}

		entityType, _ := hit["entity_type_kwd"].(string)
		pageRank, _ := hit["rank_flt"].(float64)
		description, _ := hit["content_with_weight"].(string)

		var similarity float64
		for _, scoreKey := range [...]string{"_score", "score"} {
			if value, isFloat := hit[scoreKey].(float64); isFloat {
				similarity = value
				break
			}
		}

		parsed = append(parsed, KGEntity{
			Name:        entityName,
			Type:        entityType,
			PageRank:    pageRank,
			Similarity:  similarity,
			Description: description,
		})
	}
	return parsed
}
// ParseRelationChunks turns raw search-result chunks into KGRelation values.
//
// Chunks lacking a non-empty "from_entity_kwd" or "to_entity_kwd" string are
// skipped. Sim is taken from "_score", falling back to "score"; PageRank is
// taken from "weight_int", accepted as either float64 or int. Fields that are
// missing or of an unexpected type are left at their zero value. The result
// is nil when no chunk yields a relation.
func ParseRelationChunks(chunks []map[string]interface{}) []KGRelation {
	var parsed []KGRelation
	for _, hit := range chunks {
		source, _ := hit["from_entity_kwd"].(string)
		target, _ := hit["to_entity_kwd"].(string)
		if source == "" || target == "" {
			continue
		}

		var similarity float64
		for _, scoreKey := range [...]string{"_score", "score"} {
			if value, isFloat := hit[scoreKey].(float64); isFloat {
				similarity = value
				break
			}
		}

		var pageRank float64
		switch weight := hit["weight_int"].(type) {
		case float64:
			pageRank = weight
		case int:
			pageRank = float64(weight)
		}

		description, _ := hit["content_with_weight"].(string)

		parsed = append(parsed, KGRelation{
			From:        source,
			To:          target,
			Sim:         similarity,
			PageRank:    pageRank,
			Description: description,
		})
	}
	return parsed
}
// ParseCommunityReportChunks turns raw search-result chunks into
// KGCommunityReport values.
//
// Title comes from "docnm_kwd" and Content from "content_with_weight"; a chunk
// where both are empty is skipped. Weight ("weight_flt", float64) and Entities
// ("entities_kwd", string) are left at their zero value when missing or of an
// unexpected type. The result is nil when no chunk yields a report.
func ParseCommunityReportChunks(chunks []map[string]interface{}) []KGCommunityReport {
	var parsed []KGCommunityReport
	for _, hit := range chunks {
		reportTitle, _ := hit["docnm_kwd"].(string)
		reportBody, _ := hit["content_with_weight"].(string)
		if reportTitle == "" && reportBody == "" {
			continue
		}

		weight, _ := hit["weight_flt"].(float64)
		entityList, _ := hit["entities_kwd"].(string)

		parsed = append(parsed, KGCommunityReport{
			Title:    reportTitle,
			Content:  reportBody,
			Weight:   weight,
			Entities: entityList,
		})
	}
	return parsed
}
// ParseTypeSamplesChunks merges raw search-result chunks into a single
// type→entities map.
//
// Each chunk's "content_with_weight" is expected to be a JSON object mapping a
// type name to a list of entity names. Chunks whose content is missing, not a
// string, empty, or not decodable are skipped. Entities for a type that occurs
// in several chunks are concatenated in chunk order. The returned map is never
// nil.
func ParseTypeSamplesChunks(chunks []map[string]interface{}) map[string][]string {
	merged := map[string][]string{}
	for _, hit := range chunks {
		payload, isString := hit["content_with_weight"].(string)
		if payload == "" || !isString {
			continue
		}

		var decoded map[string][]string
		if json.Unmarshal([]byte(payload), &decoded) != nil {
			continue
		}

		for typeName, names := range decoded {
			merged[typeName] = append(merged[typeName], names...)
		}
	}
	return merged
}

View File

@@ -14,7 +14,7 @@
// limitations under the License.
//
package service
package kg
import (
"context"
@@ -159,13 +159,13 @@ func TestBuildTypeSamplesSearchRequest(t *testing.T) {
}
}
// --- ParseKGEntityChunks ---
// --- ParseEntityChunks ---
func TestParseKGEntityChunks_Basic(t *testing.T) {
func TestParseEntityChunks_Basic(t *testing.T) {
chunks := []map[string]interface{}{
{"entity_kwd": "Elon Musk", "entity_type_kwd": "PERSON", "rank_flt": 0.9, "_score": 0.85, "content_with_weight": "Founder of SpaceX"},
}
entities := ParseKGEntityChunks(chunks)
entities := ParseEntityChunks(chunks)
if len(entities) != 1 {
t.Fatalf("expected 1, got %d", len(entities))
}
@@ -174,98 +174,98 @@ func TestParseKGEntityChunks_Basic(t *testing.T) {
}
}
func TestParseKGEntityChunks_List(t *testing.T) {
func TestParseEntityChunks_List(t *testing.T) {
chunks := []map[string]interface{}{
{"entity_kwd": []interface{}{"Elon Musk", "elon_musk"}},
}
entities := ParseKGEntityChunks(chunks)
entities := ParseEntityChunks(chunks)
if len(entities) != 1 || entities[0].Name != "Elon Musk" {
t.Errorf("expected first list element, got %q", entities[0].Name)
}
}
func TestParseKGEntityChunks_EmptyName(t *testing.T) {
func TestParseEntityChunks_EmptyName(t *testing.T) {
chunks := []map[string]interface{}{{"entity_type_kwd": "PERSON"}}
if len(ParseKGEntityChunks(chunks)) != 0 {
if len(ParseEntityChunks(chunks)) != 0 {
t.Error("expected 0 for missing name")
}
}
func TestParseKGEntityChunks_ScoreFallback(t *testing.T) {
func TestParseEntityChunks_ScoreFallback(t *testing.T) {
chunks := []map[string]interface{}{{"entity_kwd": "Test", "score": 0.75}}
if ParseKGEntityChunks(chunks)[0].Similarity != 0.75 {
if ParseEntityChunks(chunks)[0].Similarity != 0.75 {
t.Error("expected 0.75 from score field")
}
}
func TestParseKGEntityChunks_NilInput(t *testing.T) {
if len(ParseKGEntityChunks(nil)) != 0 {
func TestParseEntityChunks_NilInput(t *testing.T) {
if len(ParseEntityChunks(nil)) != 0 {
t.Error("expected 0 for nil input")
}
}
// --- ParseKGRelationChunks ---
// --- ParseRelationChunks ---
func TestParseKGRelationChunks_Basic(t *testing.T) {
func TestParseRelationChunks_Basic(t *testing.T) {
chunks := []map[string]interface{}{
{"from_entity_kwd": "Elon Musk", "to_entity_kwd": "SpaceX", "weight_int": float64(5), "content_with_weight": "Founder"},
}
relations := ParseKGRelationChunks(chunks)
if len(relations) != 1 || relations[0].From != "Elon Musk" || relations[0].Weight != 5 {
relations := ParseRelationChunks(chunks)
if len(relations) != 1 || relations[0].From != "Elon Musk" || relations[0].PageRank != 5 {
t.Errorf("unexpected: %+v", relations[0])
}
}
func TestParseKGRelationChunks_IntWeight(t *testing.T) {
func TestParseRelationChunks_IntWeight(t *testing.T) {
chunks := []map[string]interface{}{{"from_entity_kwd": "A", "to_entity_kwd": "B", "weight_int": 3}}
if ParseKGRelationChunks(chunks)[0].Weight != 3 {
if ParseRelationChunks(chunks)[0].PageRank != 3 {
t.Error("expected weight 3")
}
}
func TestParseKGRelationChunks_EmptyFrom(t *testing.T) {
if len(ParseKGRelationChunks([]map[string]interface{}{{"to_entity_kwd": "B"}})) != 0 {
func TestParseRelationChunks_EmptyFrom(t *testing.T) {
if len(ParseRelationChunks([]map[string]interface{}{{"to_entity_kwd": "B"}})) != 0 {
t.Error("expected 0 for missing from")
}
}
func TestParseKGRelationChunks_NilInput(t *testing.T) {
if len(ParseKGRelationChunks(nil)) != 0 {
func TestParseRelationChunks_NilInput(t *testing.T) {
if len(ParseRelationChunks(nil)) != 0 {
t.Error("expected 0 for nil")
}
}
// --- ParseKGCommunityReportChunks ---
// --- ParseCommunityReportChunks ---
func TestParseKGCommunityReportChunks_Basic(t *testing.T) {
func TestParseCommunityReportChunks_Basic(t *testing.T) {
chunks := []map[string]interface{}{
{"docnm_kwd": "Report 1", "content_with_weight": "content", "weight_flt": 0.95, "entities_kwd": "A, B"},
}
reports := ParseKGCommunityReportChunks(chunks)
reports := ParseCommunityReportChunks(chunks)
if len(reports) != 1 || reports[0].Title != "Report 1" || reports[0].Weight != 0.95 {
t.Errorf("unexpected: %+v", reports[0])
}
}
func TestParseKGCommunityReportChunks_EmptyTitle(t *testing.T) {
if len(ParseKGCommunityReportChunks([]map[string]interface{}{{"weight_flt": 0.5}})) != 0 {
func TestParseCommunityReportChunks_EmptyTitle(t *testing.T) {
if len(ParseCommunityReportChunks([]map[string]interface{}{{"weight_flt": 0.5}})) != 0 {
t.Error("expected 0 for empty title and content")
}
}
func TestParseKGCommunityReportChunks_NilInput(t *testing.T) {
if len(ParseKGCommunityReportChunks(nil)) != 0 {
func TestParseCommunityReportChunks_NilInput(t *testing.T) {
if len(ParseCommunityReportChunks(nil)) != 0 {
t.Error("expected 0 for nil")
}
}
// --- ParseKGTypeSamplesChunks ---
// --- ParseTypeSamplesChunks ---
func TestParseKGTypeSamplesChunks_ValidJSON(t *testing.T) {
func TestParseTypeSamplesChunks_ValidJSON(t *testing.T) {
chunks := []map[string]interface{}{
{"content_with_weight": `{"PERSON": ["Elon Musk", "Einstein"], "ORGANIZATION": ["SpaceX"]}`},
}
result := ParseKGTypeSamplesChunks(chunks)
result := ParseTypeSamplesChunks(chunks)
if len(result) != 2 {
t.Fatalf("expected 2 types, got %d: %v", len(result), result)
}
@@ -277,18 +277,18 @@ func TestParseKGTypeSamplesChunks_ValidJSON(t *testing.T) {
}
}
func TestParseKGTypeSamplesChunks_InvalidJSON(t *testing.T) {
func TestParseTypeSamplesChunks_InvalidJSON(t *testing.T) {
chunks := []map[string]interface{}{
{"content_with_weight": "not json"},
}
result := ParseKGTypeSamplesChunks(chunks)
result := ParseTypeSamplesChunks(chunks)
if len(result) != 0 {
t.Error("expected empty for invalid JSON")
}
}
func TestParseKGTypeSamplesChunks_Empty(t *testing.T) {
result := ParseKGTypeSamplesChunks(nil)
func TestParseTypeSamplesChunks_Empty(t *testing.T) {
result := ParseTypeSamplesChunks(nil)
if len(result) != 0 {
t.Error("expected empty for nil")
}
@@ -356,9 +356,9 @@ func TestBuildKGDenseExpr_WithModel(t *testing.T) {
},
APIConfig: &modelModule.APIConfig{},
}
dense, err := buildKGDenseExpr(embModel, "test question", 10)
dense, err := buildDenseExpr(embModel, "test question", 10)
if err != nil {
t.Fatalf("buildKGDenseExpr failed: %v", err)
t.Fatalf("buildDenseExpr failed: %v", err)
}
if dense == nil {
t.Fatal("expected non-nil MatchDenseExpr")
@@ -372,14 +372,14 @@ func TestBuildKGDenseExpr_WithModel(t *testing.T) {
}
func TestBuildKGDenseExpr_NilModel(t *testing.T) {
dense, err := buildKGDenseExpr(nil, "test", 10)
dense, err := buildDenseExpr(nil, "test", 10)
if dense != nil || err != nil {
t.Errorf("expected nil,nil for nil model, got dense=%v err=%v", dense, err)
}
}
func TestBuildKGDenseExpr_EmptyQuestion(t *testing.T) {
dense, err := buildKGDenseExpr(&modelModule.EmbeddingModel{}, "", 10)
dense, err := buildDenseExpr(&modelModule.EmbeddingModel{}, "", 10)
if dense != nil || err != nil {
t.Errorf("expected nil,nil for empty question, got dense=%v err=%v", dense, err)
}
@@ -387,7 +387,7 @@ func TestBuildKGDenseExpr_EmptyQuestion(t *testing.T) {
// --- Search integration with mock ---
func TestSearchKGEntities_WithMock(t *testing.T) {
func TestSearchEntities_WithMock(t *testing.T) {
mock := &mockKGEngine{
searchFunc: func(ctx context.Context, req *types.SearchRequest) (*types.SearchResult, error) {
if req.Filter["knowledge_graph_kwd"] != "entity" {
@@ -400,16 +400,16 @@ func TestSearchKGEntities_WithMock(t *testing.T) {
}, nil
},
}
entities, err := SearchKGEntities(context.Background(), mock, []string{"kb1"}, "Elon", nil, 10)
entities, err := SearchEntities(context.Background(), mock, []string{"kb1"}, "Elon", nil, 10)
if err != nil {
t.Fatalf("SearchKGEntities failed: %v", err)
t.Fatalf("SearchEntities failed: %v", err)
}
if len(entities) != 1 || entities[0].Name != "Elon Musk" {
t.Errorf("expected [Elon Musk], got %v", entities)
}
}
func TestSearchKGEntitiesByTypes_WithMock(t *testing.T) {
func TestSearchEntitiesByTypes_WithMock(t *testing.T) {
mock := &mockKGEngine{
searchFunc: func(ctx context.Context, req *types.SearchRequest) (*types.SearchResult, error) {
return &types.SearchResult{
@@ -419,20 +419,20 @@ func TestSearchKGEntitiesByTypes_WithMock(t *testing.T) {
}, nil
},
}
entities, err := SearchKGEntitiesByTypes(context.Background(), mock, []string{"kb1"}, []string{"ORGANIZATION"}, 10)
entities, err := SearchEntitiesByTypes(context.Background(), mock, []string{"kb1"}, []string{"ORGANIZATION"}, 10)
if err != nil {
t.Fatalf("SearchKGEntitiesByTypes failed: %v", err)
t.Fatalf("SearchEntitiesByTypes failed: %v", err)
}
if len(entities) != 1 || entities[0].Type != "ORGANIZATION" {
t.Errorf("expected ORGANIZATION, got %v", entities)
}
}
func TestSearchKGTypeSamples_WithMock(t *testing.T) {
func TestSearchTypeSamples_WithMock(t *testing.T) {
mock := &mockKGEngine{}
samples, err := SearchKGTypeSamples(context.Background(), mock, []string{"kb1"})
samples, err := SearchTypeSamples(context.Background(), mock, []string{"kb1"})
if err != nil {
t.Fatalf("SearchKGTypeSamples failed: %v", err)
t.Fatalf("SearchTypeSamples failed: %v", err)
}
if samples == nil {
samples = map[string][]string{}

View File

@@ -0,0 +1,3 @@
package kg
// strPtr returns a pointer to a fresh string holding the value of s.
func strPtr(s string) *string {
	p := new(string)
	*p = s
	return p
}

View File

@@ -0,0 +1,60 @@
package kg
// KGEntity represents a knowledge graph entity.
// The trailing comment on each field names the search-result key it mirrors.
type KGEntity struct {
Name string // entity_kwd (the tests expect the first element when the value is a list)
Type string // entity_type_kwd
PageRank float64 // rank_flt
Similarity float64 // _score (the tests also expect a fallback to "score")
Description string // content_with_weight
NhopEnts []NhopEntity // n_hop_with_weight (parsed JSON)
}
// NhopEntity represents an N-hop neighbor path.
// NOTE(review): fixtures pair N path names with N-1 weights (one per hop); confirm producers keep that invariant.
type NhopEntity struct {
Path []string // entity names along the path
Weights []float64 // pagerank weights per hop
}
// KGRelation represents a relation between two entities.
// The trailing comment on each field names the search-result key it mirrors.
type KGRelation struct {
From string // from_entity_kwd
To string // to_entity_kwd
Description string // content_with_weight
Sim float64 // seeded from _score (or score) when parsed; score accumulated during pipeline scoring
PageRank float64 // rank_flt or weight_int as float64
}
// Edge represents a directed (from_entity, to_entity) pair.
// Both fields are strings, so Edge values are comparable and can be used as map keys.
type Edge struct {
From, To string
}
// EdgeScore represents the accumulated score for an edge from N-hop analysis.
type EdgeScore struct {
Sim float64 // similarity component of the edge score
PageRank float64 // pagerank component of the edge score
}
// ScoredEntity is a scored entity ready for output.
type ScoredEntity struct {
Entity string // entity name
Score float64 // final ranking score
Description string // free text, or a JSON object with a "description" key (both forms appear in fixtures)
}
// ScoredRelation is a scored relation ready for output.
type ScoredRelation struct {
From string // source entity name
To string // target entity name
Score float64 // final ranking score
Description string // free text, or a JSON object with a "description" key (both forms appear in fixtures)
}
// KGCommunityReport represents a community report.
// The trailing comment on each field names the search-result key it mirrors.
type KGCommunityReport struct {
Title string // docnm_kwd
Content string // content_with_weight
Weight float64 // weight_flt
Entities string // entities_kwd (ParseCommunityReportChunks copies it only when the value is a string)
}

View File

@@ -1,368 +0,0 @@
//
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
package service
import (
"strings"
"testing"
)
// --- AnalyzeNHopPaths ---
// TestAnalyzeNHopPaths_Basic feeds one seed entity carrying a single
// three-node path and expects one scored edge per hop.
func TestAnalyzeNHopPaths_Basic(t *testing.T) {
	seed := &KGEntity{
		Similarity: 0.9,
		NhopEnts: []NhopEntity{
			{Path: []string{"A", "B", "C"}, Weights: []float64{0.8, 0.5}},
		},
	}
	scores := AnalyzeNHopPaths(map[string]*KGEntity{"A": seed})

	// Expected values: A→B = 0.9 / (2+0) = 0.45, B→C = 0.9 / (2+1) = 0.3.
	if n := len(scores); n != 2 {
		t.Fatalf("expected 2 edges, got %d", n)
	}
	if got := scores[Edge{From: "A", To: "B"}].Sim; got != 0.45 {
		t.Errorf("expected A→B sim=0.45, got %f", got)
	}
	if got := scores[Edge{From: "B", To: "C"}].Sim; got != 0.3 {
		t.Errorf("expected B→C sim=0.3, got %f", got)
	}
}
// TestAnalyzeNHopPaths_MultipleContributors uses two seed entities whose
// paths end at the same node and expects a separate score for each edge.
func TestAnalyzeNHopPaths_MultipleContributors(t *testing.T) {
	seeds := make(map[string]*KGEntity)
	seeds["A"] = &KGEntity{
		Similarity: 0.8,
		NhopEnts:   []NhopEntity{{Path: []string{"A", "B"}, Weights: []float64{0.7}}},
	}
	seeds["X"] = &KGEntity{
		Similarity: 0.6,
		NhopEnts:   []NhopEntity{{Path: []string{"X", "B"}, Weights: []float64{0.5}}},
	}
	scores := AnalyzeNHopPaths(seeds)

	// Expected values: A→B = 0.8 / 2 = 0.4, X→B = 0.6 / 2 = 0.3.
	if got := scores[Edge{From: "A", To: "B"}].Sim; got != 0.4 {
		t.Errorf("expected A→B sim=0.4, got %f", got)
	}
	if got := scores[Edge{From: "X", To: "B"}].Sim; got != 0.3 {
		t.Errorf("expected X→B sim=0.3, got %f", got)
	}
}
// TestAnalyzeNHopPaths_Empty expects no edges for a nil entity map.
func TestAnalyzeNHopPaths_Empty(t *testing.T) {
	if n := len(AnalyzeNHopPaths(nil)); n != 0 {
		t.Errorf("expected empty, got %d", n)
	}
}
// --- DoubleHitBoost ---
// TestDoubleHitBoost expects the similarity of an entity listed in the type
// hit set to double, and an unlisted entity to stay unchanged.
func TestDoubleHitBoost(t *testing.T) {
	pool := map[string]*KGEntity{
		"A": {Similarity: 0.5},
		"B": {Similarity: 0.3},
	}
	DoubleHitBoost(pool, map[string]struct{}{"A": {}})

	if got := pool["A"].Similarity; got != 1.0 {
		t.Errorf("expected A sim=1.0 after boost, got %f", got)
	}
	if got := pool["B"].Similarity; got != 0.3 {
		t.Errorf("expected B sim unchanged at 0.3, got %f", got)
	}
}
// TestDoubleHitBoost_Empty expects no change when the type hit set is empty.
func TestDoubleHitBoost_Empty(t *testing.T) {
	pool := map[string]*KGEntity{"A": {Similarity: 0.5}}
	noHits := map[string]struct{}{}
	DoubleHitBoost(pool, noHits)
	if got := pool["A"].Similarity; got != 0.5 {
		t.Errorf("expected unchanged, got %f", got)
	}
}
// --- FuseRelationScores ---
// TestFuseRelationScores_NhopContribution expects an existing relation's
// similarity to be scaled by the matching N-hop score.
func TestFuseRelationScores_NhopContribution(t *testing.T) {
	ab := Edge{From: "A", To: "B"}
	rels := map[Edge]*KGRelation{ab: {Sim: 0.5, PageRank: 0.8}}
	typeHits := map[string]struct{}{}
	nhop := map[Edge]EdgeScore{ab: {Sim: 0.3}}

	FuseRelationScores(rels, typeHits, nhop)

	// Expected value: sim = 0.5 * (0.3 + 1) = 0.65.
	if got := rels[ab].Sim; got != 0.65 {
		t.Errorf("expected 0.65, got %f", got)
	}
}
// TestFuseRelationScores_TypeBoost expects a relation whose two endpoints
// are both in the type hit set to be boosted accordingly.
func TestFuseRelationScores_TypeBoost(t *testing.T) {
	ab := Edge{From: "A", To: "B"}
	rels := map[Edge]*KGRelation{ab: {Sim: 0.5}}
	typeHits := map[string]struct{}{"A": {}, "B": {}}

	FuseRelationScores(rels, typeHits, map[Edge]EdgeScore{})

	// Both endpoints are type hits: s=2, so sim = 0.5 * (2+1) = 1.5.
	if got := rels[ab].Sim; got != 1.5 {
		t.Errorf("expected 1.5, got %f", got)
	}
}
// TestFuseRelationScores_NhopNewEdge expects an edge that exists only in
// the N-hop scores to be added to the relation map with its N-hop similarity.
func TestFuseRelationScores_NhopNewEdge(t *testing.T) {
	ab := Edge{From: "A", To: "B"}
	rels := make(map[Edge]*KGRelation)

	FuseRelationScores(
		rels,
		map[string]struct{}{},
		map[Edge]EdgeScore{ab: {Sim: 0.4, PageRank: 0.7}},
	)

	added, present := rels[ab]
	if !present {
		t.Fatal("expected new edge from N-hop")
	}
	if added.Sim != 0.4 {
		t.Errorf("expected sim=0.4, got %f", added.Sim)
	}
}
// --- SortAndTrim ---
func TestSortAndTrimEntities(t *testing.T) {
ents := map[string]*KGEntity{
"A": {Similarity: 0.5, PageRank: 0.9},
"B": {Similarity: 0.8, PageRank: 0.3},
"C": {Similarity: 0.9, PageRank: 0.1},
}
result := SortAndTrimEntities(ents, 2)
if len(result) != 2 {
t.Fatalf("expected 2, got %d", len(result))
}
// A: 0.45, B: 0.24, C: 0.09 → top 2 should be A, B
if result[0].Entity != "A" {
t.Errorf("expected A first (0.45), got %s (%f)", result[0].Entity, result[0].Score)
}
}
// TestSortAndTrimEntities_DefaultTopN expects topN=0 to keep every entity.
func TestSortAndTrimEntities_DefaultTopN(t *testing.T) {
	pool := map[string]*KGEntity{
		"A": {Similarity: 0.5, PageRank: 0.9},
		"B": {Similarity: 0.8, PageRank: 0.3},
	}
	if n := len(SortAndTrimEntities(pool, 0)); n != 2 {
		t.Errorf("expected default topN to include all, got %d", n)
	}
}
func TestSortAndTrimRelations(t *testing.T) {
rels := map[Edge]*KGRelation{
{"A", "B"}: {Sim: 0.9, PageRank: 0.1},
{"C", "D"}: {Sim: 0.3, PageRank: 0.8},
}
result := SortAndTrimRelations(rels, 1)
if len(result) != 1 {
t.Fatalf("expected 1, got %d", len(result))
}
// A→B: 0.09, C→D: 0.24 → C→D should be first
if result[0].From != "C" {
t.Errorf("expected C first (0.24), got %s (%f)", result[0].From, result[0].Score)
}
}
// --- Format and Build ---
// TestBuildKGContent_Basic expects the generated content to carry the
// description text of both the entity and the relation when the token
// budget is generous.
func TestBuildKGContent_Basic(t *testing.T) {
	ents := []ScoredEntity{
		{Entity: "A", Score: 0.45, Description: `{"description": "Entity A desc"}`},
	}
	rels := []ScoredRelation{
		{From: "A", To: "B", Score: 0.3, Description: `{"description": "rel A-B"}`},
	}
	out := BuildKGContent(ents, rels, 10000)

	hasEntity := strings.Contains(out, "Entity A desc")
	if !hasEntity {
		t.Errorf("expected entity description in output, got: %s", out)
	}
	hasRelation := strings.Contains(out, "rel A-B")
	if !hasRelation {
		t.Errorf("expected relation description in output, got: %s", out)
	}
}
func TestBuildKGContent_TokenBudget(t *testing.T) {
longDesc := strings.Repeat("This is a very long description. ", 50)
entities := []ScoredEntity{
{Entity: "LongEntityName", Score: 1.0, Description: longDesc},
}
relations := []ScoredRelation{
{From: "X", To: "Y", Score: 1.0, Description: "relation desc"},
}
result := BuildKGContent(entities, relations, 50)
// Token budget is very small, should truncate and not include relations
if strings.Contains(result, "relation desc") {
t.Log("Note: relations included despite small budget (depending on token count)")
}
}
func TestFormatEntitiesToCSV_HeaderExceedsBudget(t *testing.T) {
entities := []ScoredEntity{
{Entity: "A", Score: 1.0, Description: "d"},
}
result, remaining := FormatEntitiesToCSV(entities, 3)
tokens := NumTokensFromString(result)
// Header lines (---- Entities ----\n, Entity,Score,Description\n) are written
// before the token budget check. They consume ~11 tokens but are not deducted
// from maxToken. This is a known limitation shared with Python.
if tokens > 3 {
t.Logf("output %d tokens exceeds budget of %d (header not counted, remaining=%d)", tokens, 3, remaining)
}
}
// TestFilterChunksByScore_AllPass expects every chunk to be kept when all
// scores are above the threshold.
func TestFilterChunksByScore_AllPass(t *testing.T) {
	kept := FilterChunksByScore([]map[string]interface{}{
		{"entity_kwd": "A", "_score": 0.5},
		{"entity_kwd": "B", "_score": 0.8},
	}, 0.3)
	if n := len(kept); n != 2 {
		t.Errorf("expected all 2 chunks to pass, got %d", n)
	}
}
// TestFilterChunksByScore_SomeFiltered expects only the chunk scoring above
// the threshold to be kept.
func TestFilterChunksByScore_SomeFiltered(t *testing.T) {
	kept := FilterChunksByScore([]map[string]interface{}{
		{"entity_kwd": "A", "_score": 0.2},
		{"entity_kwd": "B", "_score": 0.9},
	}, 0.3)
	onlyB := len(kept) == 1 && kept[0]["entity_kwd"] == "B"
	if !onlyB {
		t.Errorf("expected only B to pass, got %v", kept)
	}
}
// TestFilterChunksByScore_MissingScore expects a chunk without any score to
// be dropped and a chunk carrying only the "score" key to be kept.
func TestFilterChunksByScore_MissingScore(t *testing.T) {
	kept := FilterChunksByScore([]map[string]interface{}{
		{"entity_kwd": "A"}, // no _score, so it counts as 0
		{"entity_kwd": "B", "score": 0.5},
	}, 0.3)
	onlyB := len(kept) == 1 && kept[0]["entity_kwd"] == "B"
	if !onlyB {
		t.Errorf("expected only B (using 'score' field), got %v", kept)
	}
}
// TestFilterChunksByScore_NilInput expects a nil result for nil input.
func TestFilterChunksByScore_NilInput(t *testing.T) {
	if kept := FilterChunksByScore(nil, 0.3); kept != nil {
		t.Errorf("expected nil, got %v", kept)
	}
}
// TestFilterChunksByScore_ZeroThreshold expects a zero-score chunk to be
// kept when the threshold is zero.
func TestFilterChunksByScore_ZeroThreshold(t *testing.T) {
	input := []map[string]interface{}{{"entity_kwd": "A", "_score": 0.0}}
	if n := len(FilterChunksByScore(input, 0)); n != 1 {
		t.Errorf("expected all pass when threshold=0, got %d", n)
	}
}
// TestExtractDescription_JSON expects the "description" value to be pulled
// out of a JSON object.
func TestExtractDescription_JSON(t *testing.T) {
	const input = `{"description": "Entity A description", "other": "value"}`
	if got := extractDescription(input); got != "Entity A description" {
		t.Errorf("expected 'Entity A description', got %q", got)
	}
}
// TestExtractDescription_Plain expects non-JSON text to be returned as is.
func TestExtractDescription_Plain(t *testing.T) {
	const input = "plain description"
	if got := extractDescription(input); got != "plain description" {
		t.Errorf("expected 'plain description', got %q", got)
	}
}
// TestExtractDescription_EscapedQuote expects escaped quotes inside the
// description value to be unescaped and kept.
func TestExtractDescription_EscapedQuote(t *testing.T) {
	const input = `{"description": "has \"quote\" inside"}`
	const want = `has "quote" inside`
	if got := extractDescription(input); got != want {
		t.Errorf("expected full description with quote, got %q", got)
	}
}
// TestExtractDescription_NonStringValue expects the raw input back when the
// "description" value is not a string.
func TestExtractDescription_NonStringValue(t *testing.T) {
	const input = `{"description": null, "other": "val"}`
	if got := extractDescription(input); got != input {
		t.Errorf("expected raw JSON when description is null, got %q", got)
	}
}
// TestExtractDescription_EmptyString expects empty input to yield an empty result.
func TestExtractDescription_EmptyString(t *testing.T) {
	if got := extractDescription(""); got != "" {
		t.Errorf("expected empty, got %q", got)
	}
}
// TestFormatCSVLine_Normal expects plain fields to be joined by commas
// without quoting, followed by a newline.
func TestFormatCSVLine_Normal(t *testing.T) {
	const want = "Elon Musk,0.85,CEO of SpaceX\n"
	if got := formatCSVLine("Elon Musk", "0.85", "CEO of SpaceX"); got != want {
		t.Errorf("expected unquoted CSV, got %q", got)
	}
}
// TestFormatCSVLine_CommaInField expects fields that contain a comma to be
// wrapped in double quotes.
func TestFormatCSVLine_CommaInField(t *testing.T) {
	want := `"Musk, Elon",0.85,"CEO, SpaceX"` + "\n"
	got := formatCSVLine("Musk, Elon", "0.85", "CEO, SpaceX")
	if got != want {
		t.Errorf("expected %q, got %q", want, got)
	}
}
// TestFormatCSVLine_QuoteInField expects a field that contains double quotes
// to be quoted, with each inner quote doubled.
func TestFormatCSVLine_QuoteInField(t *testing.T) {
	want := `Elon Musk,0.85,"CEO of ""SpaceX"""` + "\n"
	got := formatCSVLine("Elon Musk", "0.85", `CEO of "SpaceX"`)
	if got != want {
		t.Errorf("expected %q, got %q", want, got)
	}
}
// TestFormatCSVLine_EmptyField expects three empty fields to produce only
// the separators and the newline.
func TestFormatCSVLine_EmptyField(t *testing.T) {
	if got := formatCSVLine("", "", ""); got != ",,\n" {
		t.Errorf("expected empty fields, got %q", got)
	}
}
// TestNumTokensFromString expects a positive token count for a non-empty sentence.
func TestNumTokensFromString(t *testing.T) {
	const sentence = "This is a test string with multiple words"
	if count := NumTokensFromString(sentence); count <= 0 {
		t.Errorf("expected positive token count, got %d", count)
	}
}

View File

@@ -1,397 +0,0 @@
//
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
package service
import (
"context"
"encoding/json"
"fmt"
"ragflow/internal/engine"
"ragflow/internal/engine/types"
modelModule "ragflow/internal/entity/models"
)
// NhopEntity represents an N-hop neighbor path.
// NOTE(review): fixtures pair N path names with N-1 weights (one per hop); confirm producers keep that invariant.
type NhopEntity struct {
Path []string // entity names along the path
Weights []float64 // pagerank weights per hop
}
// KGEntity represents a knowledge graph entity.
// The trailing comment on each field names the search-result key it mirrors.
type KGEntity struct {
Name string // entity_kwd (ParseKGEntityChunks takes the first element when the value is a list)
Type string // entity_type_kwd
PageRank float64 // rank_flt
Similarity float64 // _score (ParseKGEntityChunks falls back to "score")
Description string // content_with_weight
NhopEnts []NhopEntity // n_hop_with_weight (parsed JSON); not populated by ParseKGEntityChunks
}
// Edge represents a directed (from_entity, to_entity) pair.
// Both fields are strings, so Edge values are comparable and can be used as map keys.
type Edge struct {
From, To string
}
// EdgeScore represents the accumulated score for an edge from N-hop analysis.
type EdgeScore struct {
Sim float64 // similarity component of the edge score
PageRank float64 // pagerank component of the edge score
}
// ScoredEntity is a scored entity ready for output.
type ScoredEntity struct {
Entity string // entity name
Score float64 // final ranking score
Description string // free text, or a JSON object with a "description" key (both forms appear in fixtures)
}
// ScoredRelation is a scored relation ready for output.
type ScoredRelation struct {
From string // source entity name
To string // target entity name
Score float64 // final ranking score
Description string // free text, or a JSON object with a "description" key (both forms appear in fixtures)
}
// KGRelation represents a relation between two entities.
// ParseKGRelationChunks fills From, To, Description and Weight only; it
// leaves Sim and PageRank at zero.
type KGRelation struct {
From string // from_entity_kwd
To string // to_entity_kwd
Weight int // weight_int (a float64 value is truncated to int when parsed)
Description string // content_with_weight
Sim float64 // score accumulated during pipeline scoring
PageRank float64 // rank_flt or weight_int as float64
}
// KGCommunityReport represents a community report.
// The trailing comment on each field names the search-result key it mirrors.
type KGCommunityReport struct {
Title string // docnm_kwd
Content string // content_with_weight
Weight float64 // weight_flt
Entities string // entities_kwd (ParseKGCommunityReportChunks copies it only when the value is a string)
}
// buildKGDenseExpr embeds question with embModel and wraps the resulting
// vector in a cosine MatchDenseExpr for KG hybrid search.
//
// It returns (nil, nil) when embModel is nil, question is empty, or the
// driver returns no usable vector. A driver failure is returned wrapped.
func buildKGDenseExpr(embModel *modelModule.EmbeddingModel, question string, topN int) (*types.MatchDenseExpr, error) {
	if question == "" || embModel == nil {
		return nil, nil
	}

	vectors, err := embModel.ModelDriver.Embed(
		embModel.ModelName,
		[]string{question},
		embModel.APIConfig,
		&modelModule.EmbeddingConfig{Dimension: 0},
	)
	if err != nil {
		return nil, fmt.Errorf("KG entity embed failed: %w", err)
	}
	if len(vectors) == 0 {
		return nil, nil
	}
	queryVec := vectors[0].Embedding
	if len(queryVec) == 0 {
		return nil, nil
	}

	expr := &types.MatchDenseExpr{
		// The vector column name encodes the embedding dimension.
		VectorColumnName:  fmt.Sprintf("q_%d_vec", len(queryVec)),
		EmbeddingData:     queryVec,
		EmbeddingDataType: "float",
		DistanceType:      "cosine",
		TopN:              topN,
		ExtraOptions:      map[string]interface{}{"similarity": 0.3},
	}
	return expr, nil
}
// buildHybridExpr assembles the match expressions for a hybrid search, in
// this order: the dense expression, the text expression, and a weighted_sum
// fusion expression limited to topN.
func buildHybridExpr(dense *types.MatchDenseExpr, text *types.MatchTextExpr, topN int) []interface{} {
	fusion := &types.FusionExpr{
		Method:       "weighted_sum",
		TopN:         topN,
		FusionParams: map[string]interface{}{"weights": "0.05,0.95"},
	}
	return []interface{}{dense, text, fusion}
}
// buildEntitySearchRequest builds the SearchRequest used to look up KG
// entities. An empty question produces a filter-only request. Otherwise a
// text match on entity_kwd and content_ltks is added. When dense is non-nil
// the text match is combined with it through buildHybridExpr and a pagerank
// rank feature is set; when dense is nil the search is text-only.
func buildEntitySearchRequest(kbIDs []string, question string, dense *types.MatchDenseExpr, topN int) *types.SearchRequest {
	req := &types.SearchRequest{
		KbIDs:        kbIDs,
		SelectFields: []string{"entity_kwd", "entity_type_kwd", "rank_flt", "content_with_weight"},
		Limit:        topN,
		Filter:       map[string]interface{}{"knowledge_graph_kwd": "entity"},
	}
	if question != "" {
		textExpr := &types.MatchTextExpr{
			Fields:       []string{"entity_kwd^10", "content_ltks^2"},
			MatchingText: question,
			TopN:         topN,
		}
		if dense == nil {
			req.MatchExprs = []interface{}{textExpr}
		} else {
			req.MatchExprs = buildHybridExpr(dense, textExpr, topN)
			req.RankFeature = map[string]float64{"pagerank_fea": 10.0}
		}
	}
	return req
}
// buildEntityTypeSearchRequest builds the SearchRequest used to list KG
// entities by type. When typeKeywords is non-empty the keywords are added
// as an entity_type_kwd filter; otherwise only the entity filter applies.
func buildEntityTypeSearchRequest(kbIDs []string, typeKeywords []string, topN int) *types.SearchRequest {
	filter := map[string]interface{}{"knowledge_graph_kwd": "entity"}
	if len(typeKeywords) > 0 {
		// The filter value is an []interface{}, so each keyword is boxed.
		kinds := make([]interface{}, 0, len(typeKeywords))
		for _, kw := range typeKeywords {
			kinds = append(kinds, kw)
		}
		filter["entity_type_kwd"] = kinds
	}
	return &types.SearchRequest{
		KbIDs:        kbIDs,
		SelectFields: []string{"entity_kwd", "entity_type_kwd", "rank_flt", "content_with_weight"},
		Limit:        topN,
		Filter:       filter,
	}
}
// buildRelationSearchRequest builds the SearchRequest used to look up KG
// relations. An empty question produces a filter-only request. Otherwise a
// text match on content_ltks is added, combined with dense through
// buildHybridExpr when dense is non-nil.
func buildRelationSearchRequest(kbIDs []string, question string, dense *types.MatchDenseExpr, topN int) *types.SearchRequest {
	req := &types.SearchRequest{
		KbIDs:        kbIDs,
		SelectFields: []string{"from_entity_kwd", "to_entity_kwd", "weight_int", "content_with_weight"},
		Limit:        topN,
		Filter:       map[string]interface{}{"knowledge_graph_kwd": "relation"},
	}
	if question == "" {
		return req
	}

	textExpr := &types.MatchTextExpr{
		Fields:       []string{"content_ltks"},
		MatchingText: question,
		TopN:         topN,
	}
	if dense == nil {
		req.MatchExprs = []interface{}{textExpr}
		return req
	}
	req.MatchExprs = buildHybridExpr(dense, textExpr, topN)
	return req
}
// buildCommunitySearchRequest builds the SearchRequest used to fetch KG
// community reports, ordered by weight_flt descending. When entityNames is
// non-empty the names are added as an entities_kwd filter.
func buildCommunitySearchRequest(kbIDs []string, entityNames []string, topN int) *types.SearchRequest {
	filter := map[string]interface{}{"knowledge_graph_kwd": "community_report"}
	if len(entityNames) > 0 {
		// The filter value is an []interface{}, so each name is boxed.
		members := make([]interface{}, 0, len(entityNames))
		for _, entity := range entityNames {
			members = append(members, entity)
		}
		filter["entities_kwd"] = members
	}
	return &types.SearchRequest{
		KbIDs:        kbIDs,
		SelectFields: []string{"docnm_kwd", "content_with_weight", "weight_flt", "entities_kwd"},
		Limit:        topN,
		Filter:       filter,
		OrderBy:      (&types.OrderByExpr{}).Desc("weight_flt"),
	}
}
// buildTypeSamplesSearchRequest builds the SearchRequest that fetches the
// ty2ents records, selecting only content_with_weight, with a limit of 10000.
func buildTypeSamplesSearchRequest(kbIDs []string) *types.SearchRequest {
	req := new(types.SearchRequest)
	req.KbIDs = kbIDs
	req.SelectFields = []string{"content_with_weight"}
	req.Limit = 10000
	req.Filter = map[string]interface{}{"knowledge_graph_kwd": "ty2ents"}
	return req
}
// ParseKGEntityChunks builds KGEntity values from raw search result chunks.
// The name comes from entity_kwd, which may be a string or a list whose
// first element is used. Chunks without a usable name are skipped.
// Similarity is read from _score, falling back to score. The result is nil
// when no chunk qualifies.
func ParseKGEntityChunks(chunks []map[string]interface{}) []KGEntity {
	var out []KGEntity
	for _, row := range chunks {
		var name string
		switch raw := row["entity_kwd"].(type) {
		case string:
			name = raw
		case []interface{}:
			if len(raw) > 0 {
				name, _ = raw[0].(string)
			}
		}
		if name == "" {
			continue
		}

		// A failed assertion yields the zero value, which matches an unset field.
		kind, _ := row["entity_type_kwd"].(string)
		desc, _ := row["content_with_weight"].(string)
		rank, _ := row["rank_flt"].(float64)
		sim, hasScore := row["_score"].(float64)
		if !hasScore {
			sim, _ = row["score"].(float64)
		}

		out = append(out, KGEntity{
			Name:        name,
			Type:        kind,
			PageRank:    rank,
			Similarity:  sim,
			Description: desc,
		})
	}
	return out
}
// ParseKGRelationChunks builds KGRelation values from raw search result
// chunks. Chunks missing either endpoint are skipped. weight_int is accepted
// as float64 (truncated to int) or as int. Only From, To, Description and
// Weight are populated. The result is nil when no chunk qualifies.
func ParseKGRelationChunks(chunks []map[string]interface{}) []KGRelation {
	var out []KGRelation
	for _, row := range chunks {
		src, _ := row["from_entity_kwd"].(string)
		dst, _ := row["to_entity_kwd"].(string)
		if src == "" || dst == "" {
			continue
		}

		var weight int
		switch w := row["weight_int"].(type) {
		case float64:
			weight = int(w)
		case int:
			weight = w
		}
		desc, _ := row["content_with_weight"].(string)

		out = append(out, KGRelation{
			From:        src,
			To:          dst,
			Weight:      weight,
			Description: desc,
		})
	}
	return out
}
// ParseKGCommunityReportChunks builds KGCommunityReport values from raw
// search result chunks. A chunk is dropped when both its docnm_kwd title and
// its content_with_weight body are missing or empty. The result is nil when
// no chunk qualifies.
func ParseKGCommunityReportChunks(chunks []map[string]interface{}) []KGCommunityReport {
	var out []KGCommunityReport
	for _, row := range chunks {
		heading, _ := row["docnm_kwd"].(string)
		body, _ := row["content_with_weight"].(string)
		if heading == "" && body == "" {
			continue
		}
		// A failed assertion yields the zero value, which matches an unset field.
		members, _ := row["entities_kwd"].(string)
		weight, _ := row["weight_flt"].(float64)
		out = append(out, KGCommunityReport{
			Title:    heading,
			Content:  body,
			Weight:   weight,
			Entities: members,
		})
	}
	return out
}
// ParseKGTypeSamplesChunks merges the JSON payloads found in each chunk's
// content_with_weight field into a single type→entities map. Chunks whose
// payload is absent, not a string, empty, or not a valid JSON object of
// string lists are skipped. Entities for a type seen in several chunks are
// concatenated in chunk order. The returned map is never nil.
func ParseKGTypeSamplesChunks(chunks []map[string]interface{}) map[string][]string {
	merged := map[string][]string{}
	for i := range chunks {
		payload, isText := chunks[i]["content_with_weight"].(string)
		if !isText || payload == "" {
			continue
		}
		var byType map[string][]string
		if json.Unmarshal([]byte(payload), &byType) != nil {
			continue
		}
		for kind, names := range byType {
			merged[kind] = append(merged[kind], names...)
		}
	}
	return merged
}
// NhopEntityNames returns the distinct entity names found in an
// n_hop_with_weight JSON document, in order of first appearance.
// The expected shape is: [{"path": ["A", "B", "C"], "weights": [0.8, 0.5]}, ...]
// It returns nil when the input cannot be decoded into that shape or when
// no name is present.
func NhopEntityNames(nHopJSON string) []string {
	var hops []struct {
		Path    []string  `json:"path"`
		Weights []float64 `json:"weights"`
	}
	if json.Unmarshal([]byte(nHopJSON), &hops) != nil {
		return nil
	}

	var ordered []string
	visited := map[string]bool{}
	for _, hop := range hops {
		for _, entity := range hop.Path {
			if visited[entity] {
				continue
			}
			visited[entity] = true
			ordered = append(ordered, entity)
		}
	}
	return ordered
}
// SearchKGEntities runs an entity search for question against docEngine and
// parses the returned chunks into KGEntity values. An embedding failure is
// returned as is; a search failure is returned wrapped.
func SearchKGEntities(ctx context.Context, docEngine engine.DocEngine, kbIDs []string, question string, embModel *modelModule.EmbeddingModel, topN int) ([]KGEntity, error) {
	denseExpr, embedErr := buildKGDenseExpr(embModel, question, topN)
	if embedErr != nil {
		return nil, embedErr
	}
	found, searchErr := docEngine.Search(ctx, buildEntitySearchRequest(kbIDs, question, denseExpr, topN))
	if searchErr != nil {
		return nil, fmt.Errorf("KG entity search failed: %w", searchErr)
	}
	return ParseKGEntityChunks(found.Chunks), nil
}
// SearchKGEntitiesByTypes runs an entity search filtered by typeKeywords
// against docEngine and parses the returned chunks into KGEntity values.
// A search failure is returned wrapped.
func SearchKGEntitiesByTypes(ctx context.Context, docEngine engine.DocEngine, kbIDs []string, typeKeywords []string, topN int) ([]KGEntity, error) {
	found, searchErr := docEngine.Search(ctx, buildEntityTypeSearchRequest(kbIDs, typeKeywords, topN))
	if searchErr != nil {
		return nil, fmt.Errorf("KG entity type search failed: %w", searchErr)
	}
	return ParseKGEntityChunks(found.Chunks), nil
}
// SearchKGRelations runs a relation search for question against docEngine
// and parses the returned chunks into KGRelation values. An embedding
// failure is returned as is; a search failure is returned wrapped.
func SearchKGRelations(ctx context.Context, docEngine engine.DocEngine, kbIDs []string, question string, embModel *modelModule.EmbeddingModel, topN int) ([]KGRelation, error) {
	denseExpr, embedErr := buildKGDenseExpr(embModel, question, topN)
	if embedErr != nil {
		return nil, embedErr
	}
	found, searchErr := docEngine.Search(ctx, buildRelationSearchRequest(kbIDs, question, denseExpr, topN))
	if searchErr != nil {
		return nil, fmt.Errorf("KG relation search failed: %w", searchErr)
	}
	return ParseKGRelationChunks(found.Chunks), nil
}
// SearchKGCommunityReports runs a community report search for entityNames
// against docEngine and parses the returned chunks into KGCommunityReport
// values. A search failure is returned wrapped.
func SearchKGCommunityReports(ctx context.Context, docEngine engine.DocEngine, kbIDs []string, entityNames []string, topN int) ([]KGCommunityReport, error) {
	found, searchErr := docEngine.Search(ctx, buildCommunitySearchRequest(kbIDs, entityNames, topN))
	if searchErr != nil {
		return nil, fmt.Errorf("KG community search failed: %w", searchErr)
	}
	return ParseKGCommunityReportChunks(found.Chunks), nil
}
// SearchKGTypeSamples fetches the ty2ents records through docEngine and
// merges them into a type→entities map. A search failure is returned wrapped.
func SearchKGTypeSamples(ctx context.Context, docEngine engine.DocEngine, kbIDs []string) (map[string][]string, error) {
	found, searchErr := docEngine.Search(ctx, buildTypeSamplesSearchRequest(kbIDs))
	if searchErr != nil {
		return nil, fmt.Errorf("KG type samples search failed: %w", searchErr)
	}
	return ParseKGTypeSamplesChunks(found.Chunks), nil
}