mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 15:31:05 +08:00
feat: Dify-compatible retrieval API endpoint (#15704)
## Summary

Dify-compatible retrieval API for external knowledge base integration.

## Changes

- **New handler**: DifyRetrievalHandler with POST/GET /api/v1/dify/retrieval
- **Health check**: GET /api/v1/dify/retrieval/health
- **Full pipeline**: KB validation -> permission check -> embedding -> metadata filter -> chunk retrieval -> child chunk aggregation -> optional KG search -> response assembly
- **12 tests** covering all paths (success, errors, metadata filter, KG mode)
- **Testability**: Handler dependencies defined as interfaces (KBServiceIface, ModelServiceIface, etc.)

## Files

| File | Type |
|------|------|
| internal/handler/dify_retrieval_handler.go | New — handler + interfaces |
| internal/handler/dify_retrieval_handler_test.go | New — 12 tests |
| internal/router/router.go | Modified — route registration |
| cmd/server_main.go | Modified — handler wiring |
| internal/service/kg/pipeline.go | Modified — SetChatModel/SetEmbModel |
| internal/service/kg/retrieval.go | New — helper functions |
| internal/service/kg/scoring.go | Moved from service package |
| internal/service/kg/search.go | New — KG search functions |
| internal/service/kg/types.go | New — type definitions |

---------

Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -218,8 +218,20 @@ func startServer(config *server.Config) {
|
||||
&handler.SearchbotRealLLM{Svc: modelProviderService},
|
||||
)
|
||||
|
||||
// Dify retrieval handler
|
||||
docDAO := dao.NewDocumentDAO()
|
||||
retrievalService := nlp.NewRetrievalService(docEngine, docDAO)
|
||||
difyRetrievalHandler := handler.NewDifyRetrievalHandler(
|
||||
knowledgebaseService,
|
||||
modelProviderService,
|
||||
metadataService,
|
||||
retrievalService,
|
||||
docDAO,
|
||||
docEngine,
|
||||
)
|
||||
|
||||
// Initialize router
|
||||
r := router.NewRouter(authHandler, userHandler, tenantHandler, documentHandler, datasetsHandler, systemHandler, knowledgebaseHandler, chunkHandler, llmHandler, chatHandler, chatSessionHandler, connectorHandler, searchHandler, fileHandler, memoryHandler, mcpHandler, skillSearchHandler, providerHandler, agentHandler, relatedQuestionsHandler)
|
||||
r := router.NewRouter(authHandler, userHandler, tenantHandler, documentHandler, datasetsHandler, systemHandler, knowledgebaseHandler, chunkHandler, llmHandler, chatHandler, chatSessionHandler, connectorHandler, searchHandler, fileHandler, memoryHandler, mcpHandler, skillSearchHandler, providerHandler, agentHandler, relatedQuestionsHandler, difyRetrievalHandler)
|
||||
|
||||
// Create Gin engine
|
||||
ginEngine := gin.New()
|
||||
|
||||
373
internal/handler/dify_retrieval_handler.go
Normal file
373
internal/handler/dify_retrieval_handler.go
Normal file
@@ -0,0 +1,373 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"ragflow/internal/common"
|
||||
"gorm.io/gorm"
|
||||
"go.uber.org/zap"
|
||||
"ragflow/internal/engine"
|
||||
"ragflow/internal/entity"
|
||||
modelModule "ragflow/internal/entity/models"
|
||||
"ragflow/internal/service"
|
||||
"ragflow/internal/service/kg"
|
||||
"ragflow/internal/service/nlp"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// --- Interfaces (for testability) ---
|
||||
|
||||
// KBServiceIface is the subset of KnowledgebaseService that the Dify handler
// calls. It is declared here so tests can substitute a mock.
|
||||
type KBServiceIface interface {
|
||||
// GetByID looks up a knowledge base by ID. The handler answers 404 when the
// error is gorm.ErrRecordNotFound and 500 for any other error.
GetByID(kbID string) (*entity.Knowledgebase, error)
|
||||
// Accessible reports whether userID may use kbID; the handler rejects the
// request with "No authorization." when it returns false.
Accessible(kbID, userID string) bool
|
||||
}
|
||||
|
||||
// ModelServiceIface is the subset of ModelProviderService that the Dify
// handler calls. It is declared here so tests can substitute a mock.
|
||||
type ModelServiceIface interface {
|
||||
// GetEmbeddingModel is called with the knowledge base's TenantID and EmbdID;
// the returned model is handed to chunk retrieval and to the KG pipeline.
GetEmbeddingModel(tenantID, embdID string) (*modelModule.EmbeddingModel, error)
|
||||
// GetChatModel is only called when use_kg is set, with an empty
// compositeModelName (presumably selecting the tenant default — confirm).
GetChatModel(tenantID, compositeModelName string) (*modelModule.ChatModel, error)
|
||||
}
|
||||
|
||||
// MetadataServiceIface is the subset of MetadataService that the Dify handler
// calls. It is declared here so tests can substitute a mock.
|
||||
type MetadataServiceIface interface {
|
||||
// GetFlattedMetaByKBs returns the metadata that the handler passes to
// service.ApplyMetaFilter together with the request's metadata conditions.
GetFlattedMetaByKBs(kbIDs []string) (common.MetaData, error)
|
||||
// LabelQuestion returns rank features for the question; the handler forwards
// a non-nil result as RetrievalRequest.RankFeature.
LabelQuestion(question string, kbs []*entity.Knowledgebase) map[string]float64
|
||||
}
|
||||
|
||||
// RetrievalServiceIface is the subset of RetrievalService that the Dify
// handler calls. It is declared here so tests can substitute a mock.
|
||||
type RetrievalServiceIface interface {
|
||||
// Retrieval runs chunk retrieval for req. The handler reads result.Chunks and
// maps an error whose text contains "not_found" to HTTP 404.
Retrieval(ctx context.Context, req *nlp.RetrievalRequest) (*nlp.RetrievalResult, error)
|
||||
}
|
||||
|
||||
// DocumentDAOIface is the subset of DocumentDAO that the Dify handler calls.
// It is declared here so tests can substitute a mock.
|
||||
type DocumentDAOIface interface {
|
||||
// GetByIDs loads the documents referenced by the retrieved chunks; the
// handler indexes the result by Document.ID to attach metadata to records.
GetByIDs(ids []string) ([]*entity.Document, error)
|
||||
}
|
||||
|
||||
// --- Request / Response types ---
|
||||
|
||||
// difyRetrievalRequest is the JSON body (POST) or query string (GET) accepted
// by the Dify retrieval endpoint. KnowledgeID and Query are mandatory; the
// handler answers 400 when either is empty.
|
||||
type difyRetrievalRequest struct {
|
||||
// KnowledgeID is the ID of the knowledge base to search.
KnowledgeID string `json:"knowledge_id" form:"knowledge_id"`
|
||||
// Query is the user question to retrieve chunks for.
Query string `json:"query" form:"query"`
|
||||
// UseKG additionally runs the knowledge-graph pipeline when true.
UseKG bool `json:"use_kg" form:"use_kg"`
|
||||
// RetrievalSetting is optional. It has no form tag of its own; for GET the
// handler fills it from the flat top_k / score_threshold query parameters.
RetrievalSetting *difyRetrievalSetting `json:"retrieval_setting"`
|
||||
// MetadataCondition is optional and carries no form tag.
MetadataCondition *difyMetadataCondition `json:"metadata_condition"`
|
||||
}
|
||||
|
||||
// difyRetrievalSetting holds the optional retrieval tuning knobs of a request.
// Both fields are pointers so that "not supplied" (nil) can be told apart from
// an explicit zero.
type difyRetrievalSetting struct {
|
||||
// TopK, when set, is used as the page size and forwarded as RetrievalRequest.Top.
TopK *int `json:"top_k" form:"top_k"`
|
||||
// ScoreThreshold, when set, is forwarded as RetrievalRequest.SimilarityThreshold.
ScoreThreshold *float64 `json:"score_threshold" form:"score_threshold"`
|
||||
}
|
||||
|
||||
// difyCondition is a single metadata filter condition in Dify's wire format.
|
||||
// Dify uses "name"/"comparison_operator" instead of MetaFilterCondition's "key"/"op".
|
||||
type difyCondition struct {
|
||||
// Name is copied to MetaFilterCondition.Key.
Name string `json:"name"`
|
||||
// ComparisonOperator is copied verbatim to MetaFilterCondition.Op.
ComparisonOperator string `json:"comparison_operator"`
|
||||
// Value may be any JSON value; it is rendered with fmt.Sprint during
// conversion, and a missing/null value becomes the empty string.
Value interface{} `json:"value"`
|
||||
}
|
||||
|
||||
// difyMetadataCondition is the Dify-format metadata filter attached to a
// retrieval request.
type difyMetadataCondition struct {
|
||||
// Conditions are the individual filter clauses.
Conditions []difyCondition `json:"conditions"`
|
||||
// Logic says how the clauses are combined; the handler substitutes "and"
// when it is empty.
Logic string `json:"logic"`
|
||||
}
|
||||
|
||||
// toMetaFilterConditions converts Dify-format conditions to internal MetaFilterConditions.
|
||||
func (c difyMetadataCondition) toMetaFilterConditions() []service.MetaFilterCondition {
|
||||
if len(c.Conditions) == 0 {
|
||||
return nil
|
||||
}
|
||||
result := make([]service.MetaFilterCondition, len(c.Conditions))
|
||||
for i, dc := range c.Conditions {
|
||||
v := ""
|
||||
if dc.Value != nil {
|
||||
v = fmt.Sprint(dc.Value)
|
||||
}
|
||||
result[i] = service.MetaFilterCondition{
|
||||
Key: dc.Name,
|
||||
Op: dc.ComparisonOperator,
|
||||
Value: v,
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// difyRecord is one element of the "records" array returned to Dify.
|
||||
type difyRecord struct {
|
||||
// Content is the chunk's "content_with_weight" value.
Content string `json:"content"`
|
||||
// Score is the chunk's "similarity" value (0 when absent or not a float64).
Score float64 `json:"score"`
|
||||
// Title is the chunk's "docnm_kwd" value.
Title string `json:"title"`
|
||||
// Metadata is the owning document's MetaFields plus "doc_id" and
// "document_id", both set to the chunk's document ID.
Metadata map[string]interface{} `json:"metadata"`
|
||||
}
|
||||
|
||||
// --- Handler ---
|
||||
|
||||
// DifyRetrievalHandler serves the Dify-compatible retrieval endpoint. All
// collaborators except docEngine are held as interfaces so tests can inject
// mocks.
|
||||
type DifyRetrievalHandler struct {
|
||||
// kbSvc resolves the knowledge base and checks user access.
kbSvc KBServiceIface
|
||||
// modelSvc supplies the embedding model and, for use_kg, the chat model.
modelSvc ModelServiceIface
|
||||
// metadataSvc supplies document metadata for filtering and rank features.
metadataSvc MetadataServiceIface
|
||||
// retrievalSvc performs the chunk retrieval itself.
retrievalSvc RetrievalServiceIface
|
||||
// docDAO loads the documents that own the retrieved chunks.
docDAO DocumentDAOIface
|
||||
// docEngine is passed to nlp.RetrievalByChildren and to the KG pipeline.
docEngine engine.DocEngine
|
||||
}
|
||||
|
||||
// NewDifyRetrievalHandler creates a new DifyRetrievalHandler.
|
||||
// The KG pipeline is created inline when use_kg=true to avoid injecting
|
||||
// a pipeline that depends on per-request model configuration.
|
||||
func NewDifyRetrievalHandler(
|
||||
kbSvc KBServiceIface,
|
||||
modelSvc ModelServiceIface,
|
||||
metadataSvc MetadataServiceIface,
|
||||
retrievalSvc RetrievalServiceIface,
|
||||
docDAO DocumentDAOIface,
|
||||
docEngine engine.DocEngine,
|
||||
) *DifyRetrievalHandler {
|
||||
return &DifyRetrievalHandler{
|
||||
kbSvc: kbSvc,
|
||||
modelSvc: modelSvc,
|
||||
metadataSvc: metadataSvc,
|
||||
retrievalSvc: retrievalSvc,
|
||||
docDAO: docDAO,
|
||||
docEngine: docEngine,
|
||||
}
|
||||
}
|
||||
|
||||
// Retrieval handles POST/GET /api/v1/dify/retrieval.
|
||||
// Matches Python: api/apps/restful_apis/dify_retrieval_api.py::retrieval()
|
||||
func (h *DifyRetrievalHandler) Retrieval(c *gin.Context) {
|
||||
user, errCode, errMsg := GetUser(c)
|
||||
if errCode != common.CodeSuccess {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"code": errCode, "message": errMsg})
|
||||
return
|
||||
}
|
||||
|
||||
var req difyRetrievalRequest
|
||||
if c.Request.Method == http.MethodGet {
|
||||
if err := c.ShouldBindQuery(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"code": common.CodeArgumentError, "message": "invalid query parameters"})
|
||||
return
|
||||
}
|
||||
// Manually extract top_k and score_threshold from query (flat params, not nested)
|
||||
if v := c.Query("top_k"); v != "" {
|
||||
if parsed, err := strconv.Atoi(v); err == nil {
|
||||
if req.RetrievalSetting == nil {
|
||||
req.RetrievalSetting = &difyRetrievalSetting{}
|
||||
}
|
||||
req.RetrievalSetting.TopK = &parsed
|
||||
}
|
||||
}
|
||||
if v := c.Query("score_threshold"); v != "" {
|
||||
if parsed, err := strconv.ParseFloat(v, 64); err == nil {
|
||||
if req.RetrievalSetting == nil {
|
||||
req.RetrievalSetting = &difyRetrievalSetting{}
|
||||
}
|
||||
req.RetrievalSetting.ScoreThreshold = &parsed
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"code": common.CodeArgumentError, "message": "invalid request body"})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if req.KnowledgeID == "" || req.Query == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"code": common.CodeArgumentError, "message": "knowledge_id and query are required"})
|
||||
return
|
||||
}
|
||||
|
||||
kb, err := h.kbSvc.GetByID(req.KnowledgeID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"code": common.CodeNotFound, "message": "Knowledgebase not found!"})
|
||||
} else {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"code": common.CodeServerError, "message": "failed to query knowledgebase"})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if !h.kbSvc.Accessible(req.KnowledgeID, user.ID) {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"code": common.CodeAuthenticationError, "message": "No authorization."})
|
||||
return
|
||||
}
|
||||
|
||||
// Parse retrieval options (nil means service uses defaults)
|
||||
var topK *int
|
||||
if req.RetrievalSetting != nil && req.RetrievalSetting.TopK != nil {
|
||||
topK = req.RetrievalSetting.TopK
|
||||
}
|
||||
var scoreThreshold *float64
|
||||
if req.RetrievalSetting != nil && req.RetrievalSetting.ScoreThreshold != nil {
|
||||
scoreThreshold = req.RetrievalSetting.ScoreThreshold
|
||||
}
|
||||
pageSize := 1024
|
||||
if topK != nil {
|
||||
pageSize = *topK
|
||||
}
|
||||
|
||||
// Get embedding model
|
||||
embModel, err := h.modelSvc.GetEmbeddingModel(kb.TenantID, kb.EmbdID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"code": common.CodeServerError, "message": fmt.Sprintf("failed to get embedding model: %v", err)})
|
||||
return
|
||||
}
|
||||
|
||||
// Metadata filter
|
||||
metas, metaErr := h.metadataSvc.GetFlattedMetaByKBs([]string{req.KnowledgeID})
|
||||
docIDs := make([]string, 0)
|
||||
if metaErr == nil && req.MetadataCondition != nil {
|
||||
logic := req.MetadataCondition.Logic
|
||||
if logic == "" {
|
||||
logic = "and"
|
||||
}
|
||||
filteredIDs := service.ApplyMetaFilter(metas, req.MetadataCondition.toMetaFilterConditions(), logic)
|
||||
docIDs = append(docIDs, filteredIDs...)
|
||||
}
|
||||
if len(docIDs) == 0 && req.MetadataCondition != nil {
|
||||
docIDs = []string{"-999"}
|
||||
}
|
||||
|
||||
// Label question for rank features
|
||||
kbs := []*entity.Knowledgebase{kb}
|
||||
rankFeature := h.metadataSvc.LabelQuestion(req.Query, kbs)
|
||||
|
||||
// Chunk retrieval
|
||||
sr := &nlp.RetrievalRequest{
|
||||
Question: req.Query,
|
||||
TenantIDs: []string{kb.TenantID},
|
||||
KbIDs: []string{req.KnowledgeID},
|
||||
DocIDs: docIDs,
|
||||
Page: 1,
|
||||
PageSize: pageSize,
|
||||
Top: topK,
|
||||
SimilarityThreshold: scoreThreshold,
|
||||
EmbeddingModel: embModel,
|
||||
}
|
||||
if rankFeature != nil {
|
||||
sr.RankFeature = &rankFeature
|
||||
}
|
||||
|
||||
result, err := h.retrievalSvc.Retrieval(c.Request.Context(), sr)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "not_found") {
|
||||
c.JSON(http.StatusNotFound, gin.H{"code": common.CodeNotFound, "message": "No chunk found! Check the chunk status please!"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"code": common.CodeServerError, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
// Enrich with child chunks
|
||||
chunks := nlp.RetrievalByChildren(result.Chunks, []string{kb.TenantID}, h.docEngine, c.Request.Context())
|
||||
|
||||
// KG retrieval (optional)
|
||||
if req.UseKG {
|
||||
chatModel, kgErr := h.modelSvc.GetChatModel(kb.TenantID, "")
|
||||
if kgErr != nil {
|
||||
common.Warn("KG retrieval: failed to get chat model", zap.String("kbID", req.KnowledgeID), zap.Error(kgErr))
|
||||
} else if chatModel != nil {
|
||||
kgPipeline := kg.NewPipeline(
|
||||
h.docEngine,
|
||||
[]string{req.KnowledgeID},
|
||||
[]string{kb.TenantID},
|
||||
req.Query,
|
||||
)
|
||||
kgPipeline.SetChatModel(chatModel)
|
||||
kgPipeline.SetEmbModel(embModel)
|
||||
if kgResult, kgErr := kgPipeline.Retrieval(c.Request.Context()); kgErr == nil {
|
||||
if content, ok := kgResult["content_with_weight"].(string); ok && content != "" {
|
||||
chunks = append([]map[string]interface{}{kgResult}, chunks...)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Collect doc IDs and fetch documents
|
||||
docIDSet := make(map[string]struct{})
|
||||
for _, ch := range chunks {
|
||||
if docID, ok := ch["doc_id"].(string); ok && docID != "" {
|
||||
docIDSet[docID] = struct{}{}
|
||||
}
|
||||
}
|
||||
allDocIDs := make([]string, 0, len(docIDSet))
|
||||
for id := range docIDSet {
|
||||
allDocIDs = append(allDocIDs, id)
|
||||
}
|
||||
|
||||
docMap := make(map[string]*entity.Document)
|
||||
if len(allDocIDs) > 0 {
|
||||
docs, err := h.docDAO.GetByIDs(allDocIDs)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"code": common.CodeServerError, "message": fmt.Sprintf("failed to load documents: %v", err)})
|
||||
return
|
||||
}
|
||||
for _, d := range docs {
|
||||
docMap[d.ID] = d
|
||||
}
|
||||
}
|
||||
|
||||
// Build response
|
||||
records := make([]difyRecord, 0, len(chunks))
|
||||
for _, ch := range chunks {
|
||||
docID, _ := ch["doc_id"].(string)
|
||||
doc := docMap[docID]
|
||||
if doc == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Remove vector to reduce response size
|
||||
delete(ch, "vector")
|
||||
|
||||
meta := make(map[string]interface{})
|
||||
if doc.MetaFields != nil {
|
||||
for k, v := range *doc.MetaFields {
|
||||
meta[k] = v
|
||||
}
|
||||
}
|
||||
meta["doc_id"] = docID
|
||||
meta["document_id"] = docID
|
||||
|
||||
score, _ := ch["similarity"].(float64)
|
||||
title, _ := ch["docnm_kwd"].(string)
|
||||
content, _ := ch["content_with_weight"].(string)
|
||||
|
||||
records = append(records, difyRecord{
|
||||
Content: content,
|
||||
Score: score,
|
||||
Title: title,
|
||||
Metadata: meta,
|
||||
})
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"records": records})
|
||||
}
|
||||
|
||||
// HealthCheck returns a simple health check response.
|
||||
func (h *DifyRetrievalHandler) HealthCheck(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"code": 0, "data": true})
|
||||
}
|
||||
401
internal/handler/dify_retrieval_handler_test.go
Normal file
401
internal/handler/dify_retrieval_handler_test.go
Normal file
@@ -0,0 +1,401 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"gorm.io/gorm"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"ragflow/internal/common"
|
||||
"ragflow/internal/engine"
|
||||
"ragflow/internal/entity"
|
||||
modelModule "ragflow/internal/entity/models"
|
||||
"ragflow/internal/service/nlp"
|
||||
"ragflow/internal/engine/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// --- Mock implementations ---
|
||||
|
||||
type mockKBService struct {
|
||||
KBServiceIface
|
||||
getByIDFn func(kbID string) (*entity.Knowledgebase, error)
|
||||
accessibleFn func(kbID, userID string) bool
|
||||
}
|
||||
|
||||
func (m *mockKBService) GetByID(kbID string) (*entity.Knowledgebase, error) {
|
||||
if m.getByIDFn != nil {
|
||||
return m.getByIDFn(kbID)
|
||||
}
|
||||
return &entity.Knowledgebase{
|
||||
ID: kbID, TenantID: "tenant1", EmbdID: "text-embedding",
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *mockKBService) Accessible(kbID, userID string) bool {
|
||||
if m.accessibleFn != nil {
|
||||
return m.accessibleFn(kbID, userID)
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
type mockModelService struct {
|
||||
ModelServiceIface
|
||||
getEmbeddingFn func(tenantID, embdID string) (*modelModule.EmbeddingModel, error)
|
||||
getChatModelFn func(tenantID, llmID string) (*modelModule.ChatModel, error)
|
||||
}
|
||||
|
||||
func (m *mockModelService) GetEmbeddingModel(tenantID, embdID string) (*modelModule.EmbeddingModel, error) {
|
||||
if m.getEmbeddingFn != nil {
|
||||
return m.getEmbeddingFn(tenantID, embdID)
|
||||
}
|
||||
return &modelModule.EmbeddingModel{}, nil
|
||||
}
|
||||
|
||||
func (m *mockModelService) GetChatModel(tenantID, llmID string) (*modelModule.ChatModel, error) {
|
||||
if m.getChatModelFn != nil {
|
||||
return m.getChatModelFn(tenantID, llmID)
|
||||
}
|
||||
return &modelModule.ChatModel{}, nil
|
||||
}
|
||||
|
||||
type mockMetadataService struct {
|
||||
MetadataServiceIface
|
||||
getFlattedMetaFn func(kbIDs []string) (common.MetaData, error)
|
||||
labelQuestionFn func(question string, kbs []*entity.Knowledgebase) map[string]float64
|
||||
}
|
||||
|
||||
func (m *mockMetadataService) GetFlattedMetaByKBs(kbIDs []string) (common.MetaData, error) {
|
||||
if m.getFlattedMetaFn != nil {
|
||||
return m.getFlattedMetaFn(kbIDs)
|
||||
}
|
||||
return common.MetaData{}, nil
|
||||
}
|
||||
|
||||
func (m *mockMetadataService) LabelQuestion(question string, kbs []*entity.Knowledgebase) map[string]float64 {
|
||||
if m.labelQuestionFn != nil {
|
||||
return m.labelQuestionFn(question, kbs)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type mockRetrievalService struct {
|
||||
RetrievalServiceIface
|
||||
retrievalFn func(ctx context.Context, req *nlp.RetrievalRequest) (*nlp.RetrievalResult, error)
|
||||
}
|
||||
|
||||
func (m *mockRetrievalService) Retrieval(ctx context.Context, req *nlp.RetrievalRequest) (*nlp.RetrievalResult, error) {
|
||||
if m.retrievalFn != nil {
|
||||
return m.retrievalFn(ctx, req)
|
||||
}
|
||||
return &nlp.RetrievalResult{
|
||||
Chunks: []map[string]interface{}{
|
||||
{"doc_id": "doc1", "docnm_kwd": "Test Doc", "content_with_weight": "test content", "similarity": 0.85},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
type mockDocDAO struct {
|
||||
DocumentDAOIface
|
||||
getByIDsFn func(ids []string) ([]*entity.Document, error)
|
||||
}
|
||||
|
||||
func (m *mockDocDAO) GetByIDs(ids []string) ([]*entity.Document, error) {
|
||||
if m.getByIDsFn != nil {
|
||||
return m.getByIDsFn(ids)
|
||||
}
|
||||
return []*entity.Document{
|
||||
{ID: "doc1", Name: strPtr("Test Doc"), MetaFields: &entity.JSONMap{"author": "Zhang San"}},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// mockDocEngine stubs the DocEngine interface (embed = panic on unimplemented).
|
||||
type mockDocEngine struct {
|
||||
engine.DocEngine
|
||||
}
|
||||
|
||||
func (m *mockDocEngine) Close() error { return nil }
|
||||
func (m *mockDocEngine) Ping(ctx context.Context) error { return nil }
|
||||
func (m *mockDocEngine) GetType() string { return "mock" }
|
||||
func (m *mockDocEngine) Search(ctx context.Context, req *types.SearchRequest) (*types.SearchResult, error) {
|
||||
return &types.SearchResult{}, nil
|
||||
}
|
||||
func (m *mockDocEngine) GetChunk(ctx context.Context, _, _ string, _ []string) (interface{}, error) {
|
||||
return map[string]interface{}{}, nil
|
||||
}
|
||||
|
||||
// --- Helper ---
|
||||
|
||||
func setupDifyTest(userID string) (*DifyRetrievalHandler, *gin.Engine) {
|
||||
h := &DifyRetrievalHandler{
|
||||
kbSvc: &mockKBService{},
|
||||
modelSvc: &mockModelService{},
|
||||
metadataSvc: &mockMetadataService{},
|
||||
retrievalSvc: &mockRetrievalService{},
|
||||
docDAO: &mockDocDAO{},
|
||||
docEngine: &mockDocEngine{},
|
||||
}
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
r.Use(func(c *gin.Context) {
|
||||
c.Set("user", &entity.User{ID: userID})
|
||||
})
|
||||
r.POST("/api/v1/dify/retrieval", h.Retrieval)
|
||||
r.GET("/api/v1/dify/retrieval", h.Retrieval)
|
||||
r.GET("/api/v1/dify/retrieval/health", h.HealthCheck)
|
||||
return h, r
|
||||
}
|
||||
|
||||
func setupDifyTestNoAuth() (*DifyRetrievalHandler, *gin.Engine) {
|
||||
h := &DifyRetrievalHandler{
|
||||
kbSvc: &mockKBService{},
|
||||
modelSvc: &mockModelService{},
|
||||
metadataSvc: &mockMetadataService{},
|
||||
retrievalSvc: &mockRetrievalService{},
|
||||
docDAO: &mockDocDAO{},
|
||||
docEngine: &mockDocEngine{},
|
||||
}
|
||||
gin.SetMode(gin.TestMode)
|
||||
r := gin.New()
|
||||
r.POST("/api/v1/dify/retrieval", h.Retrieval)
|
||||
return h, r
|
||||
}
|
||||
|
||||
// --- Tests ---
|
||||
|
||||
func TestDifyRetrieval_HealthCheck(t *testing.T) {
|
||||
_, r := setupDifyTest("user1")
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("GET", "/api/v1/dify/retrieval/health", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected 200, got %d", w.Code)
|
||||
}
|
||||
var resp map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if resp["data"] != true {
|
||||
t.Errorf("expected data=true, got %v", resp["data"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestDifyRetrieval_Basic(t *testing.T) {
|
||||
_, r := setupDifyTest("user1")
|
||||
body := `{"knowledge_id": "kb1", "query": "test question"}`
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/api/v1/dify/retrieval", strings.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
var resp map[string]interface{}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
records, ok := resp["records"].([]interface{})
|
||||
if !ok || len(records) == 0 {
|
||||
t.Errorf("expected non-empty records, got %v", resp["records"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestDifyRetrieval_GET(t *testing.T) {
|
||||
_, r := setupDifyTest("user1")
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("GET", "/api/v1/dify/retrieval?knowledge_id=kb1&query=test", nil)
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestDifyRetrieval_MissingArgs(t *testing.T) {
|
||||
_, r := setupDifyTest("user1")
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
}{
|
||||
{"no knowledge_id", `{"query": "test"}`},
|
||||
{"no query", `{"knowledge_id": "kb1"}`},
|
||||
{"empty body", `{}`},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/api/v1/dify/retrieval", strings.NewReader(tc.body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400, got %d", w.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDifyRetrieval_KBNotFound(t *testing.T) {
|
||||
h, r := setupDifyTest("user1")
|
||||
h.kbSvc = &mockKBService{
|
||||
getByIDFn: func(kbID string) (*entity.Knowledgebase, error) {
|
||||
return nil, gorm.ErrRecordNotFound
|
||||
},
|
||||
}
|
||||
w := httptest.NewRecorder()
|
||||
body := `{"knowledge_id": "nonexistent", "query": "test"}`
|
||||
req, _ := http.NewRequest("POST", "/api/v1/dify/retrieval", strings.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("expected 404, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDifyRetrieval_NoAuth(t *testing.T) {
|
||||
_, r := setupDifyTestNoAuth()
|
||||
w := httptest.NewRecorder()
|
||||
body := `{"knowledge_id": "kb1", "query": "test"}`
|
||||
req, _ := http.NewRequest("POST", "/api/v1/dify/retrieval", strings.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("expected 401, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDifyRetrieval_Unauthorized(t *testing.T) {
|
||||
h, r := setupDifyTest("user1")
|
||||
h.kbSvc = &mockKBService{
|
||||
accessibleFn: func(kbID, userID string) bool { return false },
|
||||
}
|
||||
w := httptest.NewRecorder()
|
||||
body := `{"knowledge_id": "kb1", "query": "test"}`
|
||||
req, _ := http.NewRequest("POST", "/api/v1/dify/retrieval", strings.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusUnauthorized {
|
||||
t.Errorf("expected 401, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDifyRetrieval_WithMetadataFilter(t *testing.T) {
|
||||
h, r := setupDifyTest("user1")
|
||||
h.metadataSvc = &mockMetadataService{
|
||||
getFlattedMetaFn: func(kbIDs []string) (common.MetaData, error) {
|
||||
return common.MetaData{}, nil
|
||||
},
|
||||
}
|
||||
body := `{"knowledge_id":"kb1","query":"test","metadata_condition":{"conditions":[{"name":"author","comparison_operator":"eq","value":"Zhang San"}],"logic":"and"}}`
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/api/v1/dify/retrieval", strings.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestDifyRetrieval_InvalidJSON(t *testing.T) {
|
||||
_, r := setupDifyTest("user1")
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/api/v1/dify/retrieval", strings.NewReader("{invalid json"))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusBadRequest {
|
||||
t.Errorf("expected 400, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDifyRetrieval_UseKG(t *testing.T) {
|
||||
h, r := setupDifyTest("user1")
|
||||
h.metadataSvc = &mockMetadataService{
|
||||
labelQuestionFn: func(question string, kbs []*entity.Knowledgebase) map[string]float64 {
|
||||
return map[string]float64{"tag_1": 0.8}
|
||||
},
|
||||
}
|
||||
body := `{"knowledge_id":"kb1","query":"test","use_kg":true}`
|
||||
w := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("POST", "/api/v1/dify/retrieval", strings.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected 200, got %d: %s", w.Code, w.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
// strPtr returns a pointer to a fresh copy of s.
func strPtr(s string) *string {
	p := new(string)
	*p = s
	return p
}
|
||||
|
||||
func TestDifyRetrieval_KBDBError(t *testing.T) {
|
||||
h, r := setupDifyTest("user1")
|
||||
h.kbSvc = &mockKBService{
|
||||
getByIDFn: func(kbID string) (*entity.Knowledgebase, error) {
|
||||
return nil, errors.New("connection refused")
|
||||
},
|
||||
}
|
||||
w := httptest.NewRecorder()
|
||||
body := `{"knowledge_id": "kb1", "query": "test"}`
|
||||
req, _ := http.NewRequest("POST", "/api/v1/dify/retrieval", strings.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("expected 500 for DB error, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDifyRetrieval_DocLoadError(t *testing.T) {
|
||||
h, r := setupDifyTest("user1")
|
||||
h.docDAO = &mockDocDAO{
|
||||
getByIDsFn: func(ids []string) ([]*entity.Document, error) {
|
||||
return nil, errors.New("db unavailable")
|
||||
},
|
||||
}
|
||||
w := httptest.NewRecorder()
|
||||
body := `{"knowledge_id": "kb1", "query": "test"}`
|
||||
req, _ := http.NewRequest("POST", "/api/v1/dify/retrieval", strings.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusInternalServerError {
|
||||
t.Errorf("expected 500 for doc load error, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDifyRetrieval_RetrievalNotFound(t *testing.T) {
|
||||
h, r := setupDifyTest("user1")
|
||||
h.retrievalSvc = &mockRetrievalService{
|
||||
retrievalFn: func(ctx context.Context, req *nlp.RetrievalRequest) (*nlp.RetrievalResult, error) {
|
||||
return nil, errors.New("no chunk found: not_found")
|
||||
},
|
||||
}
|
||||
w := httptest.NewRecorder()
|
||||
body := `{"knowledge_id": "kb1", "query": "test"}`
|
||||
req, _ := http.NewRequest("POST", "/api/v1/dify/retrieval", strings.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusNotFound {
|
||||
t.Errorf("expected 404 for not_found, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
@@ -43,6 +43,7 @@ type Router struct {
|
||||
providerHandler *handler.ProviderHandler
|
||||
agentHandler *handler.AgentHandler
|
||||
relatedQuestionsHandler *handler.SearchbotHandler
|
||||
difyRetrievalHandler *handler.DifyRetrievalHandler
|
||||
}
|
||||
|
||||
// NewRouter create router
|
||||
@@ -67,6 +68,7 @@ func NewRouter(
|
||||
providerHandler *handler.ProviderHandler,
|
||||
agentHandler *handler.AgentHandler,
|
||||
relatedQuestionsHandler *handler.SearchbotHandler,
|
||||
difyRetrievalHandler *handler.DifyRetrievalHandler,
|
||||
) *Router {
|
||||
return &Router{
|
||||
authHandler: authHandler,
|
||||
@@ -89,6 +91,7 @@ func NewRouter(
|
||||
providerHandler: providerHandler,
|
||||
agentHandler: agentHandler,
|
||||
relatedQuestionsHandler: relatedQuestionsHandler,
|
||||
difyRetrievalHandler: difyRetrievalHandler,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -517,6 +520,14 @@ func (r *Router) Setup(engine *gin.Engine) {
|
||||
|
||||
}
|
||||
|
||||
// Dify retrieval routes
|
||||
dify := authorized.Group("/api/v1/dify")
|
||||
{
|
||||
dify.POST("/retrieval", r.difyRetrievalHandler.Retrieval)
|
||||
dify.GET("/retrieval", r.difyRetrievalHandler.Retrieval)
|
||||
}
|
||||
apiNoAuth.GET("/dify/retrieval/health", r.difyRetrievalHandler.HealthCheck)
|
||||
|
||||
// Handle undefined routes
|
||||
engine.NoRoute(handler.HandleNoRoute)
|
||||
}
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package service
|
||||
package kg
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -30,9 +30,9 @@ import (
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// KGSearchPipeline encapsulates the knowledge graph retrieval pipeline.
|
||||
// Pipeline encapsulates the knowledge graph retrieval pipeline.
|
||||
// Matches Python: rag/graphrag/search.py::KGSearch
|
||||
type KGSearchPipeline struct {
|
||||
type Pipeline struct {
|
||||
docEngine engine.DocEngine
|
||||
chatModel *modelModule.ChatModel
|
||||
embModel *modelModule.EmbeddingModel
|
||||
@@ -50,51 +50,51 @@ type KGSearchPipeline struct {
|
||||
maxToken int
|
||||
}
|
||||
|
||||
// KGSearchOption configures a KGSearchPipeline.
|
||||
type KGSearchOption func(*KGSearchPipeline)
|
||||
// Option configures a Pipeline.
|
||||
type Option func(*Pipeline)
|
||||
|
||||
// WithKGSimThreshold sets the similarity threshold for entity and relation search.
|
||||
// WithSimThreshold sets the similarity threshold for entity and relation search.
|
||||
// Default: 0.3 (matches Python ent_sim_threshold, rel_sim_threshold).
|
||||
func WithKGSimThreshold(v float64) KGSearchOption {
|
||||
return func(p *KGSearchPipeline) { p.entSimThreshold = v; p.relSimThreshold = v }
|
||||
func WithSimThreshold(v float64) Option {
|
||||
return func(p *Pipeline) { p.entSimThreshold = v; p.relSimThreshold = v }
|
||||
}
|
||||
|
||||
// WithKGDenseTopK sets the TopK for dense vector search.
|
||||
// WithDenseTopK sets the TopK for dense vector search.
|
||||
// Default: 1024 (matches Python get_vector topk).
|
||||
func WithKGDenseTopK(v int) KGSearchOption {
|
||||
return func(p *KGSearchPipeline) { p.denseTopK = v }
|
||||
func WithDenseTopK(v int) Option {
|
||||
return func(p *Pipeline) { p.denseTopK = v }
|
||||
}
|
||||
|
||||
// NewKGSearchPipeline creates a KG search pipeline with the given dependencies.
|
||||
// NewPipeline creates a KG search pipeline with the given dependencies.
|
||||
//
|
||||
// docEngine: search engine backend
|
||||
// kbIDs: knowledge base IDs to search
|
||||
// tenantIDs: tenant IDs (converted to index names internally)
|
||||
// question: user query string
|
||||
// opts: optional configuration (WithKGSimThreshold, WithKGDenseTopK)
|
||||
// opts: optional configuration (WithSimThreshold, WithDenseTopK)
|
||||
//
|
||||
// chatModel and embModel should be set via the SetChatModel/SetEmbModel setters
|
||||
// or passed directly after construction.
|
||||
func NewKGSearchPipeline(
|
||||
func NewPipeline(
|
||||
docEngine engine.DocEngine,
|
||||
kbIDs []string,
|
||||
tenantIDs []string,
|
||||
question string,
|
||||
opts ...KGSearchOption,
|
||||
) *KGSearchPipeline {
|
||||
opts ...Option,
|
||||
) *Pipeline {
|
||||
idxnms := make([]string, len(tenantIDs))
|
||||
for i, tid := range tenantIDs {
|
||||
idxnms[i] = indexName(tid)
|
||||
}
|
||||
p := &KGSearchPipeline{
|
||||
p := &Pipeline{
|
||||
docEngine: docEngine,
|
||||
kbIDs: kbIDs,
|
||||
idxnms: idxnms,
|
||||
question: question,
|
||||
|
||||
entSimThreshold: defaultKGSimThreshold,
|
||||
relSimThreshold: defaultKGSimThreshold,
|
||||
denseTopK: defaultKGDenseTopK,
|
||||
entSimThreshold: defaultSimThreshold,
|
||||
relSimThreshold: defaultSimThreshold,
|
||||
denseTopK: defaultDenseTopK,
|
||||
entTopN: 6,
|
||||
relTopN: 6,
|
||||
commTopN: 1,
|
||||
@@ -106,12 +106,22 @@ func NewKGSearchPipeline(
|
||||
return p
|
||||
}
|
||||
|
||||
// SetChatModel sets the chat model for LLM-based query rewrite.
|
||||
func (p *Pipeline) SetChatModel(chatModel *modelModule.ChatModel) {
|
||||
p.chatModel = chatModel
|
||||
}
|
||||
|
||||
// SetEmbModel sets the embedding model for dense/hybrid search.
|
||||
func (p *Pipeline) SetEmbModel(embModel *modelModule.EmbeddingModel) {
|
||||
p.embModel = embModel
|
||||
}
|
||||
|
||||
// Retrieval runs the full KG retrieval pipeline and returns a synthetic chunk.
|
||||
func (p *KGSearchPipeline) Retrieval(ctx context.Context) (map[string]interface{}, error) {
|
||||
func (p *Pipeline) Retrieval(ctx context.Context) (map[string]interface{}, error) {
|
||||
// 1. Query rewrite via LLM, or fall back to raw question
|
||||
ty2entsJSON := ""
|
||||
if p.chatModel != nil {
|
||||
typeSamples, err := searchKGTypeSamples(ctx, p.docEngine, p.idxnms, p.kbIDs)
|
||||
typeSamples, err := searchTypeSamples(ctx, p.docEngine, p.idxnms, p.kbIDs)
|
||||
if err != nil {
|
||||
common.Warn("KG type samples search failed", zap.String("kbIDs", fmt.Sprint(p.kbIDs)))
|
||||
}
|
||||
@@ -157,11 +167,11 @@ func (p *KGSearchPipeline) Retrieval(ctx context.Context) (map[string]interface{
|
||||
scoredRels := SortAndTrimRelations(relsFromText, p.relTopN)
|
||||
|
||||
// 7. Build KG content with token budget
|
||||
entsRelsContent := BuildKGContent(scoredEnts, scoredRels, p.maxToken)
|
||||
entsRelsContent := BuildContent(scoredEnts, scoredRels, p.maxToken)
|
||||
used := NumTokensFromString(entsRelsContent)
|
||||
remaining := p.maxToken - used
|
||||
// 8. Search community reports with remaining token budget
|
||||
communityContent := searchKGCommunityContent(ctx, p.docEngine, p.idxnms, p.kbIDs, scoredEnts, p.commTopN, &remaining)
|
||||
communityContent := searchCommunityContent(ctx, p.docEngine, p.idxnms, p.kbIDs, scoredEnts, p.commTopN, &remaining)
|
||||
|
||||
// 9. Build synthetic chunk
|
||||
return map[string]interface{}{
|
||||
@@ -182,7 +192,7 @@ func (p *KGSearchPipeline) Retrieval(ctx context.Context) (map[string]interface{
|
||||
}
|
||||
|
||||
// searchEntities searches KG entities by keyword text and optional dense vector.
|
||||
func (p *KGSearchPipeline) searchEntities(ctx context.Context, entities []string) (map[string]*KGEntity, error) {
|
||||
func (p *Pipeline) searchEntities(ctx context.Context, entities []string) (map[string]*KGEntity, error) {
|
||||
entsReq := &types.SearchRequest{
|
||||
IndexNames: p.idxnms,
|
||||
KbIDs: p.kbIDs,
|
||||
@@ -207,14 +217,14 @@ func (p *KGSearchPipeline) searchEntities(ctx context.Context, entities []string
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
e := kgEntityFromChunk(name, chunk)
|
||||
e := entityFromChunk(name, chunk)
|
||||
result[name] = &e
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// searchEntityTypes searches KG entities by type keywords.
|
||||
func (p *KGSearchPipeline) searchEntityTypes(ctx context.Context, typeKeywords []string) map[string]struct{} {
|
||||
func (p *Pipeline) searchEntityTypes(ctx context.Context, typeKeywords []string) map[string]struct{} {
|
||||
typesReq := &types.SearchRequest{
|
||||
IndexNames: p.idxnms,
|
||||
KbIDs: p.kbIDs,
|
||||
@@ -244,7 +254,7 @@ func (p *KGSearchPipeline) searchEntityTypes(ctx context.Context, typeKeywords [
|
||||
}
|
||||
|
||||
// searchRelations searches KG relations by entity text and optional dense vector.
|
||||
func (p *KGSearchPipeline) searchRelations(ctx context.Context, entities []string) map[Edge]*KGRelation {
|
||||
func (p *Pipeline) searchRelations(ctx context.Context, entities []string) map[Edge]*KGRelation {
|
||||
relsReq := &types.SearchRequest{
|
||||
IndexNames: p.idxnms,
|
||||
KbIDs: p.kbIDs,
|
||||
@@ -265,7 +275,7 @@ func (p *KGSearchPipeline) searchRelations(ctx context.Context, entities []strin
|
||||
common.Warn("KG relations search failed", zap.String("kbIDs", fmt.Sprint(p.kbIDs)))
|
||||
} else {
|
||||
for _, chunk := range FilterChunksByScore(relsResult.Chunks, p.relSimThreshold) {
|
||||
edge, rel := kgRelationFromChunk(chunk)
|
||||
edge, rel := relationFromChunk(chunk)
|
||||
if edge.From == "" || edge.To == "" {
|
||||
continue
|
||||
}
|
||||
@@ -1,20 +1,4 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package service
|
||||
package kg
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -27,67 +11,9 @@ import (
|
||||
modelModule "ragflow/internal/entity/models"
|
||||
)
|
||||
|
||||
// indexName builds the search index name from a tenant ID.
|
||||
// Matches Python: rag/nlp/search.py::index_name()
|
||||
func indexName(tenantID string) string {
|
||||
return "ragflow_" + tenantID
|
||||
}
|
||||
|
||||
// Python alignment defaults — match rag/graphrag/search.py retrieval() params
|
||||
const (
|
||||
defaultKGSimThreshold = 0.3 // Python: ent_sim_threshold, rel_sim_threshold
|
||||
defaultKGDenseTopK = 1024 // Python: get_vector() topk
|
||||
)
|
||||
|
||||
// kgEntityFromChunk parses a single entity chunk into a KGEntity.
|
||||
func kgEntityFromChunk(name string, chunk map[string]interface{}) KGEntity {
|
||||
e := KGEntity{}
|
||||
if v, ok := chunk["_score"].(float64); ok {
|
||||
e.Similarity = v
|
||||
} else if v, ok := chunk["score"].(float64); ok {
|
||||
e.Similarity = v
|
||||
}
|
||||
if v, ok := chunk["rank_flt"].(float64); ok {
|
||||
e.PageRank = v
|
||||
}
|
||||
e.Description, _ = chunk["content_with_weight"].(string)
|
||||
if raw, ok := chunk["n_hop_with_weight"].(string); ok && raw != "" {
|
||||
var nhopData []struct {
|
||||
Path []string `json:"path"`
|
||||
Weights []float64 `json:"weights"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(raw), &nhopData); err == nil {
|
||||
for _, item := range nhopData {
|
||||
e.NhopEnts = append(e.NhopEnts, NhopEntity{
|
||||
Path: item.Path,
|
||||
Weights: item.Weights,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
return e
|
||||
}
|
||||
|
||||
// kgRelationFromChunk parses a single relation chunk into a KGRelation.
|
||||
func kgRelationFromChunk(chunk map[string]interface{}) (Edge, KGRelation) {
|
||||
r := KGRelation{}
|
||||
r.Description, _ = chunk["content_with_weight"].(string)
|
||||
if v, ok := chunk["weight_int"].(float64); ok {
|
||||
r.PageRank = float64(v)
|
||||
} else if v, ok := chunk["weight_int"].(int); ok {
|
||||
r.PageRank = float64(v)
|
||||
}
|
||||
from, _ := chunk["from_entity_kwd"].(string)
|
||||
to, _ := chunk["to_entity_kwd"].(string)
|
||||
return Edge{From: from, To: to}, r
|
||||
}
|
||||
|
||||
// KGSearchRetrieval performs a full knowledge graph retrieval and returns
|
||||
// a synthetic chunk to be inserted into search results.
|
||||
// Corresponds to Python: rag/graphrag/search.py::KGSearch.retrieval()
|
||||
//
|
||||
// This is a convenience wrapper around KGSearchPipeline.
|
||||
func KGSearchRetrieval(
|
||||
// Retrieval performs a full knowledge graph retrieval and returns
|
||||
// a synthetic chunk. Convenience wrapper around Pipeline.
|
||||
func Retrieval(
|
||||
ctx context.Context,
|
||||
docEngine engine.DocEngine,
|
||||
chatModel *modelModule.ChatModel,
|
||||
@@ -96,16 +22,16 @@ func KGSearchRetrieval(
|
||||
tenantIDs []string,
|
||||
question string,
|
||||
) (map[string]interface{}, error) {
|
||||
p := &KGSearchPipeline{
|
||||
p := &Pipeline{
|
||||
docEngine: docEngine,
|
||||
chatModel: chatModel,
|
||||
embModel: embModel,
|
||||
kbIDs: kbIDs,
|
||||
idxnms: makeIndexNames(tenantIDs),
|
||||
question: question,
|
||||
entSimThreshold: defaultKGSimThreshold,
|
||||
relSimThreshold: defaultKGSimThreshold,
|
||||
denseTopK: defaultKGDenseTopK,
|
||||
entSimThreshold: defaultSimThreshold,
|
||||
relSimThreshold: defaultSimThreshold,
|
||||
denseTopK: defaultDenseTopK,
|
||||
entTopN: 6,
|
||||
relTopN: 6,
|
||||
commTopN: 1,
|
||||
@@ -123,8 +49,13 @@ func makeIndexNames(tenantIDs []string) []string {
|
||||
return idxnms
|
||||
}
|
||||
|
||||
// searchKGTypeSamples searches for ty2ents data.
|
||||
func searchKGTypeSamples(ctx context.Context, docEngine engine.DocEngine, idxnms []string, kbIDs []string) (map[string][]string, error) {
|
||||
// indexName returns the search index name for a tenant: the tenant ID
// prefixed with "ragflow_". An empty tenant ID yields just the prefix.
func indexName(tenantID string) string {
	const prefix = "ragflow_"
	return prefix + tenantID
}
|
||||
|
||||
// searchTypeSamples searches for ty2ents data.
|
||||
func searchTypeSamples(ctx context.Context, docEngine engine.DocEngine, idxnms []string, kbIDs []string) (map[string][]string, error) {
|
||||
req := &types.SearchRequest{
|
||||
IndexNames: idxnms,
|
||||
KbIDs: kbIDs,
|
||||
@@ -153,8 +84,8 @@ func searchKGTypeSamples(ctx context.Context, docEngine engine.DocEngine, idxnms
|
||||
return typeMap, nil
|
||||
}
|
||||
|
||||
// searchKGCommunityContent searches for community reports and formats them.
|
||||
func searchKGCommunityContent(ctx context.Context, docEngine engine.DocEngine, idxnms []string, kbIDs []string, scoredEnts []ScoredEntity, topN int, maxToken *int) string {
|
||||
// searchCommunityContent searches for community reports and formats them.
|
||||
func searchCommunityContent(ctx context.Context, docEngine engine.DocEngine, idxnms []string, kbIDs []string, scoredEnts []ScoredEntity, topN int, maxToken *int) string {
|
||||
if maxToken == nil || len(scoredEnts) == 0 || *maxToken <= 0 {
|
||||
return ""
|
||||
}
|
||||
@@ -189,7 +120,6 @@ func searchKGCommunityContent(ctx context.Context, docEngine engine.DocEngine, i
|
||||
if title == "" && raw == "" {
|
||||
continue
|
||||
}
|
||||
// Parse JSON for nested report/evidences fields (Python: json.loads)
|
||||
report := raw
|
||||
evidence := ""
|
||||
var parsed map[string]interface{}
|
||||
@@ -212,30 +142,52 @@ func searchKGCommunityContent(ctx context.Context, docEngine engine.DocEngine, i
|
||||
return bld
|
||||
}
|
||||
|
||||
// buildMatchDenseExpr constructs a MatchDenseExpr from an embedding vector.
|
||||
// This is a pure function — no I/O, no external dependencies.
|
||||
func buildMatchDenseExpr(vector []float64, topN int, similarity float64) *types.MatchDenseExpr {
|
||||
vectorColumnName := fmt.Sprintf("q_%d_vec", len(vector))
|
||||
return &types.MatchDenseExpr{
|
||||
VectorColumnName: vectorColumnName,
|
||||
EmbeddingData: vector,
|
||||
EmbeddingDataType: "float",
|
||||
DistanceType: "cosine",
|
||||
TopN: topN,
|
||||
ExtraOptions: map[string]interface{}{"similarity": similarity},
|
||||
// entityFromChunk parses a single entity chunk into a KGEntity.
|
||||
func entityFromChunk(name string, chunk map[string]interface{}) KGEntity {
|
||||
e := KGEntity{}
|
||||
if v, ok := chunk["_score"].(float64); ok {
|
||||
e.Similarity = v
|
||||
} else if v, ok := chunk["score"].(float64); ok {
|
||||
e.Similarity = v
|
||||
}
|
||||
if v, ok := chunk["rank_flt"].(float64); ok {
|
||||
e.PageRank = v
|
||||
}
|
||||
e.Description, _ = chunk["content_with_weight"].(string)
|
||||
if raw, ok := chunk["n_hop_with_weight"].(string); ok && raw != "" {
|
||||
var nhopData []struct {
|
||||
Path []string `json:"path"`
|
||||
Weights []float64 `json:"weights"`
|
||||
}
|
||||
if err := json.Unmarshal([]byte(raw), &nhopData); err == nil {
|
||||
for _, item := range nhopData {
|
||||
e.NhopEnts = append(e.NhopEnts, NhopEntity{
|
||||
Path: item.Path,
|
||||
Weights: item.Weights,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
return e
|
||||
}
|
||||
|
||||
// buildFusionExpr constructs a FusionExpr for weighted-sum hybrid search.
|
||||
// This is a pure function — no I/O, no external dependencies.
|
||||
func buildFusionExpr(textWeight, vectorWeight float64, topN int) *types.FusionExpr {
|
||||
return &types.FusionExpr{
|
||||
Method: "weighted_sum",
|
||||
TopN: topN,
|
||||
FusionParams: map[string]interface{}{
|
||||
"weights": fmt.Sprintf("%.2f,%.2f", textWeight, vectorWeight),
|
||||
},
|
||||
// relationFromChunk parses a single relation chunk into a KGRelation.
|
||||
func relationFromChunk(chunk map[string]interface{}) (Edge, KGRelation) {
|
||||
r := KGRelation{}
|
||||
r.Description, _ = chunk["content_with_weight"].(string)
|
||||
if v, ok := chunk["_score"].(float64); ok {
|
||||
r.Sim = v
|
||||
} else if v, ok := chunk["score"].(float64); ok {
|
||||
r.Sim = v
|
||||
}
|
||||
if v, ok := chunk["weight_int"].(float64); ok {
|
||||
r.PageRank = float64(v)
|
||||
} else if v, ok := chunk["weight_int"].(int); ok {
|
||||
r.PageRank = float64(v)
|
||||
}
|
||||
from, _ := chunk["from_entity_kwd"].(string)
|
||||
to, _ := chunk["to_entity_kwd"].(string)
|
||||
return Edge{From: from, To: to}, r
|
||||
}
|
||||
|
||||
// buildSearchExprs constructs MatchExprs for KG entity/relation search.
|
||||
@@ -252,12 +204,35 @@ func buildSearchExprs(embModel *modelModule.EmbeddingModel, matchText *types.Mat
|
||||
return []interface{}{matchText}
|
||||
}
|
||||
denseExpr := buildMatchDenseExpr(embeddings[0].Embedding, denseTopK, simThreshold)
|
||||
fusionExpr := buildFusionExpr(0.5, 0.5, matchText.TopN)
|
||||
fusionExpr := buildFusionExpr(defaultTextWeight, defaultVectorWeight, matchText.TopN)
|
||||
return []interface{}{matchText, denseExpr, fusionExpr}
|
||||
}
|
||||
|
||||
// buildMatchDenseExpr constructs a MatchDenseExpr from an embedding vector.
|
||||
func buildMatchDenseExpr(vector []float64, topN int, similarity float64) *types.MatchDenseExpr {
|
||||
vectorColumnName := fmt.Sprintf("q_%d_vec", len(vector))
|
||||
return &types.MatchDenseExpr{
|
||||
VectorColumnName: vectorColumnName,
|
||||
EmbeddingData: vector,
|
||||
EmbeddingDataType: "float",
|
||||
DistanceType: "cosine",
|
||||
TopN: topN,
|
||||
ExtraOptions: map[string]interface{}{"similarity": similarity},
|
||||
}
|
||||
}
|
||||
|
||||
// buildFusionExpr constructs a FusionExpr for weighted-sum hybrid search.
|
||||
func buildFusionExpr(textWeight, vectorWeight float64, topN int) *types.FusionExpr {
|
||||
return &types.FusionExpr{
|
||||
Method: "weighted_sum",
|
||||
TopN: topN,
|
||||
FusionParams: map[string]interface{}{
|
||||
"weights": fmt.Sprintf("%.2f,%.2f", textWeight, vectorWeight),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// queryRewrite attempts LLM-based query rewrite, falling back to raw question.
|
||||
// ty2entsJSON is the JSON-encoded type→entities mapping for prompt context.
|
||||
func queryRewrite(chatModel *modelModule.ChatModel, question string, ty2entsJSON string) (typeKeywords, entities []string) {
|
||||
if question == "" {
|
||||
return nil, nil
|
||||
@@ -276,6 +251,14 @@ func queryRewrite(chatModel *modelModule.ChatModel, question string, ty2entsJSON
|
||||
}
|
||||
}
|
||||
}
|
||||
// Fallback: use raw question as single entity
|
||||
return nil, []string{question}
|
||||
}
|
||||
|
||||
// Defaults aligned with the Python implementation.
const (
	// defaultSimThreshold is the default entity/relation similarity threshold
	// (Python: ent_sim_threshold, rel_sim_threshold).
	defaultSimThreshold = 0.3
	// defaultDenseTopK is the default TopK for dense vector search
	// (Python: get_vector() topk).
	defaultDenseTopK = 1024
	// defaultTextWeight and defaultVectorWeight are the fusion weights for
	// hybrid search; text and vector are weighted equally by default.
	defaultTextWeight   = 0.5
	defaultVectorWeight = 0.5
)
|
||||
@@ -14,7 +14,7 @@
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package service
|
||||
package kg
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -56,16 +56,16 @@ func (m *mockRetrievalEngine) Search(ctx context.Context, req *types.SearchReque
|
||||
return &types.SearchResult{}, nil
|
||||
}
|
||||
|
||||
// --- kgEntityFromChunk ---
|
||||
// --- entityFromChunk ---
|
||||
|
||||
func TestKgEntityFromChunk_Basic(t *testing.T) {
|
||||
func TestEntityFromChunk_Basic(t *testing.T) {
|
||||
chunk := map[string]interface{}{
|
||||
"_score": 0.85,
|
||||
"rank_flt": 0.9,
|
||||
"content_with_weight": "Founder of SpaceX",
|
||||
"n_hop_with_weight": `[{"path":["A","B"],"weights":[0.8]}]`,
|
||||
}
|
||||
e := kgEntityFromChunk("Elon Musk", chunk)
|
||||
e := entityFromChunk("Elon Musk", chunk)
|
||||
if e.Similarity != 0.85 {
|
||||
t.Errorf("expected Sim=0.85, got %f", e.Similarity)
|
||||
}
|
||||
@@ -80,32 +80,32 @@ func TestKgEntityFromChunk_Basic(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestKgEntityFromChunk_ScoreFallback(t *testing.T) {
|
||||
func TestEntityFromChunk_ScoreFallback(t *testing.T) {
|
||||
chunk := map[string]interface{}{"score": 0.75}
|
||||
e := kgEntityFromChunk("Test", chunk)
|
||||
e := entityFromChunk("Test", chunk)
|
||||
if e.Similarity != 0.75 {
|
||||
t.Errorf("expected Sim=0.75 from score field, got %f", e.Similarity)
|
||||
}
|
||||
}
|
||||
|
||||
func TestKgEntityFromChunk_MissingFields(t *testing.T) {
|
||||
func TestEntityFromChunk_MissingFields(t *testing.T) {
|
||||
chunk := map[string]interface{}{}
|
||||
e := kgEntityFromChunk("Empty", chunk)
|
||||
e := entityFromChunk("Empty", chunk)
|
||||
if e.Similarity != 0 || e.PageRank != 0 || len(e.NhopEnts) != 0 {
|
||||
t.Errorf("expected zero defaults, got %+v", e)
|
||||
}
|
||||
}
|
||||
|
||||
// --- kgRelationFromChunk ---
|
||||
// --- relationFromChunk ---
|
||||
|
||||
func TestKgRelationFromChunk_Basic(t *testing.T) {
|
||||
func TestRelationFromChunk_Basic(t *testing.T) {
|
||||
chunk := map[string]interface{}{
|
||||
"from_entity_kwd": "Elon Musk",
|
||||
"to_entity_kwd": "SpaceX",
|
||||
"weight_int": float64(5),
|
||||
"content_with_weight": "Founder",
|
||||
}
|
||||
edge, rel := kgRelationFromChunk(chunk)
|
||||
edge, rel := relationFromChunk(chunk)
|
||||
if edge.From != "Elon Musk" || edge.To != "SpaceX" {
|
||||
t.Errorf("expected Elon Musk→SpaceX, got %v", edge)
|
||||
}
|
||||
@@ -114,17 +114,17 @@ func TestKgRelationFromChunk_Basic(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestKgRelationFromChunk_MissingFrom(t *testing.T) {
|
||||
func TestRelationFromChunk_MissingFrom(t *testing.T) {
|
||||
chunk := map[string]interface{}{"to_entity_kwd": "B"}
|
||||
edge, _ := kgRelationFromChunk(chunk)
|
||||
edge, _ := relationFromChunk(chunk)
|
||||
if edge.From != "" {
|
||||
t.Error("expected empty from")
|
||||
}
|
||||
}
|
||||
|
||||
// --- searchKGTypeSamples ---
|
||||
// --- searchTypeSamples ---
|
||||
|
||||
func TestSearchKGTypeSamples_Success(t *testing.T) {
|
||||
func TestSearchTypeSamples_Success(t *testing.T) {
|
||||
data, _ := json.Marshal(map[string][]string{"PERSON": {"Elon Musk"}})
|
||||
mock := &mockRetrievalEngine{
|
||||
results: map[string]*types.SearchResult{
|
||||
@@ -133,7 +133,7 @@ func TestSearchKGTypeSamples_Success(t *testing.T) {
|
||||
}},
|
||||
},
|
||||
}
|
||||
result, err := searchKGTypeSamples(context.Background(), mock, []string{"ragflow_tenant1"}, []string{"kb1"})
|
||||
result, err := searchTypeSamples(context.Background(), mock, []string{"ragflow_tenant1"}, []string{"kb1"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
@@ -142,9 +142,9 @@ func TestSearchKGTypeSamples_Success(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchKGTypeSamples_Empty(t *testing.T) {
|
||||
func TestSearchTypeSamples_Empty(t *testing.T) {
|
||||
mock := &mockRetrievalEngine{}
|
||||
result, err := searchKGTypeSamples(context.Background(), mock, []string{"ragflow_tenant1"}, []string{"kb1"})
|
||||
result, err := searchTypeSamples(context.Background(), mock, []string{"ragflow_tenant1"}, []string{"kb1"})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
@@ -153,9 +153,9 @@ func TestSearchKGTypeSamples_Empty(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// --- KGSearchRetrieval ---
|
||||
// --- Retrieval ---
|
||||
|
||||
func TestKGSearchRetrieval_Basic(t *testing.T) {
|
||||
func TestRetrieval_Basic(t *testing.T) {
|
||||
mock := &mockRetrievalEngine{
|
||||
results: map[string]*types.SearchResult{
|
||||
"entity": {Chunks: []map[string]interface{}{
|
||||
@@ -172,9 +172,9 @@ func TestKGSearchRetrieval_Basic(t *testing.T) {
|
||||
}},
|
||||
},
|
||||
}
|
||||
result, err := KGSearchRetrieval(context.Background(), mock, nil, nil, []string{"kb1"}, []string{"tenant1"}, "Elon Musk")
|
||||
result, err := Retrieval(context.Background(), mock, nil, nil, []string{"kb1"}, []string{"tenant1"}, "Elon Musk")
|
||||
if err != nil {
|
||||
t.Fatalf("KGSearchRetrieval failed: %v", err)
|
||||
t.Fatalf("Retrieval failed: %v", err)
|
||||
}
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil result")
|
||||
@@ -194,11 +194,11 @@ func TestKGSearchRetrieval_Basic(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestKGSearchRetrieval_NoEntities(t *testing.T) {
|
||||
func TestRetrieval_NoEntities(t *testing.T) {
|
||||
mock := &mockRetrievalEngine{}
|
||||
result, err := KGSearchRetrieval(context.Background(), mock, nil, nil, []string{"kb1"}, []string{"tenant1"}, "test")
|
||||
result, err := Retrieval(context.Background(), mock, nil, nil, []string{"kb1"}, []string{"tenant1"}, "test")
|
||||
if err != nil {
|
||||
t.Fatalf("KGSearchRetrieval failed: %v", err)
|
||||
t.Fatalf("Retrieval failed: %v", err)
|
||||
}
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil result")
|
||||
@@ -211,7 +211,7 @@ func TestKGSearchRetrieval_NoEntities(t *testing.T) {
|
||||
|
||||
// TestEntitySearch_MultiEntities verifies that all entities are used in search query.
|
||||
|
||||
func TestKGSearchRetrieval_WithChatModel(t *testing.T) {
|
||||
func TestRetrieval_WithChatModel(t *testing.T) {
|
||||
mock := &mockRetrievalEngine{
|
||||
results: map[string]*types.SearchResult{
|
||||
"entity": {Chunks: []map[string]interface{}{
|
||||
@@ -225,9 +225,9 @@ func TestKGSearchRetrieval_WithChatModel(t *testing.T) {
|
||||
// chatModel with nil ModelName so queryRewrite falls back to raw question,
|
||||
// but the ty2entsJSON construction path is still exercised.
|
||||
chatModel := &modelModule.ChatModel{ModelName: nil, APIConfig: nil}
|
||||
result, err := KGSearchRetrieval(context.Background(), mock, chatModel, nil, []string{"kb1"}, []string{"tenant1"}, "Elon Musk")
|
||||
result, err := Retrieval(context.Background(), mock, chatModel, nil, []string{"kb1"}, []string{"tenant1"}, "Elon Musk")
|
||||
if err != nil {
|
||||
t.Fatalf("KGSearchRetrieval failed: %v", err)
|
||||
t.Fatalf("Retrieval failed: %v", err)
|
||||
}
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil result")
|
||||
@@ -415,7 +415,7 @@ func TestBuildSearchExprs_WithEmbModel(t *testing.T) {
|
||||
MatchingText: "Elon Musk SpaceX",
|
||||
TopN: 50,
|
||||
}
|
||||
exprs := buildSearchExprs(embModel, matchText, defaultKGSimThreshold, defaultKGDenseTopK)
|
||||
exprs := buildSearchExprs(embModel, matchText, defaultSimThreshold, defaultDenseTopK)
|
||||
// Verify Embed was called with matchText.MatchingText, not raw question
|
||||
if len(driver.capturedTexts) != 1 || driver.capturedTexts[0] != "Elon Musk SpaceX" {
|
||||
t.Errorf("expected Embed to receive %q, got %v", "Elon Musk SpaceX", driver.capturedTexts)
|
||||
@@ -439,11 +439,11 @@ func TestBuildSearchExprs_WithEmbModel(t *testing.T) {
|
||||
if md.VectorColumnName != "q_3_vec" {
|
||||
t.Errorf("expected q_3_vec, got %q", md.VectorColumnName)
|
||||
}
|
||||
if md.TopN != defaultKGDenseTopK {
|
||||
t.Errorf("expected TopN=%d (Python alignment), got %d", defaultKGDenseTopK, md.TopN)
|
||||
if md.TopN != defaultDenseTopK {
|
||||
t.Errorf("expected TopN=%d (Python alignment), got %d", defaultDenseTopK, md.TopN)
|
||||
}
|
||||
if md.ExtraOptions["similarity"] != defaultKGSimThreshold {
|
||||
t.Errorf("expected similarity=%v (Python alignment), got %v", defaultKGSimThreshold, md.ExtraOptions["similarity"])
|
||||
if md.ExtraOptions["similarity"] != defaultSimThreshold {
|
||||
t.Errorf("expected similarity=%v (Python alignment), got %v", defaultSimThreshold, md.ExtraOptions["similarity"])
|
||||
}
|
||||
// Index 2: FusionExpr
|
||||
fu, ok := exprs[2].(*types.FusionExpr)
|
||||
@@ -463,7 +463,7 @@ func TestBuildSearchExprs_EmbModelFallback(t *testing.T) {
|
||||
MatchingText: "fallback test",
|
||||
TopN: 10,
|
||||
}
|
||||
exprs := buildSearchExprs(embModel, matchText, defaultKGSimThreshold, defaultKGDenseTopK)
|
||||
exprs := buildSearchExprs(embModel, matchText, defaultSimThreshold, defaultDenseTopK)
|
||||
// Should fall back to text-only when Embed fails
|
||||
if len(exprs) != 1 {
|
||||
t.Fatalf("expected 1 expr (text-only fallback), got %d", len(exprs))
|
||||
@@ -476,11 +476,11 @@ func TestBuildSearchExprs_EmbModelFallback(t *testing.T) {
|
||||
// --- Python alignment defaults ---
|
||||
|
||||
func TestDefaultValuesMatchPython(t *testing.T) {
|
||||
if defaultKGSimThreshold != 0.3 {
|
||||
t.Errorf("expected 0.3 (Python ent_sim_threshold), got %f", defaultKGSimThreshold)
|
||||
if defaultSimThreshold != 0.3 {
|
||||
t.Errorf("expected 0.3 (Python ent_sim_threshold), got %f", defaultSimThreshold)
|
||||
}
|
||||
if defaultKGDenseTopK != 1024 {
|
||||
t.Errorf("expected 1024 (Python get_vector topk), got %d", defaultKGDenseTopK)
|
||||
if defaultDenseTopK != 1024 {
|
||||
t.Errorf("expected 1024 (Python get_vector topk), got %d", defaultDenseTopK)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -506,11 +506,11 @@ func TestIndexName_Empty(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// --- searchKGCommunityContent ---
|
||||
// --- searchCommunityContent ---
|
||||
|
||||
func TestSearchKGCommunityContent_EmptyEntities(t *testing.T) {
|
||||
mock := &mockRetrievalEngine{}
|
||||
result := searchKGCommunityContent(context.Background(), mock, []string{"ragflow_t1"}, []string{"kb1"}, nil, 1, intPtr(100))
|
||||
result := searchCommunityContent(context.Background(), mock, []string{"ragflow_t1"}, []string{"kb1"}, nil, 1, intPtr(100))
|
||||
if result != "" {
|
||||
t.Errorf("expected empty, got %q", result)
|
||||
}
|
||||
@@ -527,7 +527,7 @@ func TestSearchKGCommunityContent_WithContent(t *testing.T) {
|
||||
}},
|
||||
},
|
||||
}
|
||||
result := searchKGCommunityContent(context.Background(), mock, []string{"ragflow_t1"}, []string{"kb1"}, []ScoredEntity{{Entity: "E1"}}, 1, intPtr(500))
|
||||
result := searchCommunityContent(context.Background(), mock, []string{"ragflow_t1"}, []string{"kb1"}, []ScoredEntity{{Entity: "E1"}}, 1, intPtr(500))
|
||||
if result == "" {
|
||||
t.Fatal("expected non-empty result")
|
||||
}
|
||||
@@ -547,7 +547,7 @@ func TestSearchKGCommunityContent_WithContent(t *testing.T) {
|
||||
|
||||
func TestSearchKGCommunityContent_NilMaxToken(t *testing.T) {
|
||||
mock := &mockRetrievalEngine{}
|
||||
result := searchKGCommunityContent(context.Background(), mock, []string{"ragflow_t1"}, []string{"kb1"}, []ScoredEntity{{Entity: "E1"}}, 1, nil)
|
||||
result := searchCommunityContent(context.Background(), mock, []string{"ragflow_t1"}, []string{"kb1"}, []ScoredEntity{{Entity: "E1"}}, 1, nil)
|
||||
if result != "" {
|
||||
t.Errorf("expected empty when maxToken is nil, got %q", result)
|
||||
}
|
||||
@@ -555,7 +555,7 @@ func TestSearchKGCommunityContent_NilMaxToken(t *testing.T) {
|
||||
|
||||
func TestSearchKGCommunityContent_ZeroMaxToken(t *testing.T) {
|
||||
mock := &mockRetrievalEngine{}
|
||||
result := searchKGCommunityContent(context.Background(), mock, []string{"ragflow_t1"}, []string{"kb1"}, []ScoredEntity{{Entity: "E1"}}, 1, intPtr(0))
|
||||
result := searchCommunityContent(context.Background(), mock, []string{"ragflow_t1"}, []string{"kb1"}, []ScoredEntity{{Entity: "E1"}}, 1, intPtr(0))
|
||||
if result != "" {
|
||||
t.Errorf("expected empty when maxToken=0, got %q", result)
|
||||
}
|
||||
@@ -14,7 +14,7 @@
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package service
|
||||
package kg
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
@@ -227,9 +227,9 @@ func FormatRelationsToCSV(relations []ScoredRelation, maxToken int) (csv string,
|
||||
return b.String(), maxToken
|
||||
}
|
||||
|
||||
// BuildKGContent assembles the final knowledge graph content string.
|
||||
// BuildContent assembles the final knowledge graph content string.
|
||||
// Python equivalent: lines 267-291
|
||||
func BuildKGContent(
|
||||
func BuildContent(
|
||||
entities []ScoredEntity,
|
||||
relations []ScoredRelation,
|
||||
maxToken int,
|
||||
306
internal/service/kg/search.go
Normal file
306
internal/service/kg/search.go
Normal file
@@ -0,0 +1,306 @@
|
||||
package kg
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"ragflow/internal/engine"
|
||||
"encoding/json"
|
||||
"ragflow/internal/engine/types"
|
||||
modelModule "ragflow/internal/entity/models"
|
||||
)
|
||||
|
||||
// NhopEntityNames extracts the unique entity names that appear in the paths of
// an n_hop_with_weight JSON string.
//
// The input is decoded as a JSON array of objects with a "path" list of names
// (the "weights" list is decoded but not used here). Names are returned in the
// order in which they are first encountered, so repeated calls with the same
// input always produce the same slice. An empty or undecodable input yields
// nil; a decodable input without any path entries yields an empty, non-nil
// slice.
func NhopEntityNames(nHopJSON string) []string {
	if nHopJSON == "" {
		return nil
	}
	var nhopData []struct {
		Path    []string  `json:"path"`
		Weights []float64 `json:"weights"`
	}
	if err := json.Unmarshal([]byte(nHopJSON), &nhopData); err != nil {
		// Malformed payloads are treated the same as "no neighbours".
		return nil
	}
	seen := make(map[string]struct{})
	result := make([]string, 0)
	for _, item := range nhopData {
		for _, name := range item.Path {
			if _, dup := seen[name]; dup {
				continue
			}
			seen[name] = struct{}{}
			// Append on first sight instead of ranging over the map
			// afterwards: map iteration order is randomized in Go.
			result = append(result, name)
		}
	}
	return result
}
|
||||
|
||||
// SearchEntities searches for KG entities matching a question.
|
||||
func SearchEntities(ctx context.Context, docEngine engine.DocEngine, kbIDs []string, question string, embModel *modelModule.EmbeddingModel, topN int) ([]KGEntity, error) {
|
||||
dense, err := buildDenseExpr(embModel, question, topN)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
searchReq := buildEntitySearchRequest(kbIDs, question, dense, topN)
|
||||
result, err := docEngine.Search(ctx, searchReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("KG entity search failed: %w", err)
|
||||
}
|
||||
return ParseEntityChunks(result.Chunks), nil
|
||||
}
|
||||
|
||||
// SearchEntitiesByTypes searches for KG entities by type keywords.
|
||||
func SearchEntitiesByTypes(ctx context.Context, docEngine engine.DocEngine, kbIDs []string, typeKeywords []string, topN int) ([]KGEntity, error) {
|
||||
searchReq := buildEntityTypeSearchRequest(kbIDs, typeKeywords, topN)
|
||||
result, err := docEngine.Search(ctx, searchReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("KG entity type search failed: %w", err)
|
||||
}
|
||||
return ParseEntityChunks(result.Chunks), nil
|
||||
}
|
||||
|
||||
// SearchRelations searches for KG relations matching a question.
|
||||
func SearchRelations(ctx context.Context, docEngine engine.DocEngine, kbIDs []string, question string, embModel *modelModule.EmbeddingModel, topN int) ([]KGRelation, error) {
|
||||
dense, err := buildDenseExpr(embModel, question, topN)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
searchReq := buildRelationSearchRequest(kbIDs, question, dense, topN)
|
||||
result, err := docEngine.Search(ctx, searchReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("KG relation search failed: %w", err)
|
||||
}
|
||||
return ParseRelationChunks(result.Chunks), nil
|
||||
}
|
||||
|
||||
// SearchCommunityReports searches for community reports related to given entities.
|
||||
func SearchCommunityReports(ctx context.Context, docEngine engine.DocEngine, kbIDs []string, entityNames []string, topN int) ([]KGCommunityReport, error) {
|
||||
searchReq := buildCommunitySearchRequest(kbIDs, entityNames, topN)
|
||||
result, err := docEngine.Search(ctx, searchReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("KG community search failed: %w", err)
|
||||
}
|
||||
return ParseCommunityReportChunks(result.Chunks), nil
|
||||
}
|
||||
|
||||
// SearchTypeSamples retrieves the typeu2192entities mapping from ES.
|
||||
func SearchTypeSamples(ctx context.Context, docEngine engine.DocEngine, kbIDs []string) (map[string][]string, error) {
|
||||
searchReq := buildTypeSamplesSearchRequest(kbIDs)
|
||||
result, err := docEngine.Search(ctx, searchReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ParseTypeSamplesChunks(result.Chunks), nil
|
||||
}
|
||||
|
||||
// buildDenseExpr computes the query vector and returns a MatchDenseExpr.
|
||||
func buildDenseExpr(embModel *modelModule.EmbeddingModel, question string, topN int) (*types.MatchDenseExpr, error) {
|
||||
if embModel == nil || question == "" {
|
||||
return nil, nil
|
||||
}
|
||||
embCfg := &modelModule.EmbeddingConfig{Dimension: 0}
|
||||
embeddings, err := embModel.ModelDriver.Embed(embModel.ModelName, []string{question}, embModel.APIConfig, embCfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("KG entity embed failed: %w", err)
|
||||
}
|
||||
if len(embeddings) == 0 || len(embeddings[0].Embedding) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
vector := embeddings[0].Embedding
|
||||
return &types.MatchDenseExpr{
|
||||
VectorColumnName: fmt.Sprintf("q_%d_vec", len(vector)),
|
||||
EmbeddingData: vector,
|
||||
EmbeddingDataType: "float",
|
||||
DistanceType: "cosine",
|
||||
TopN: topN,
|
||||
ExtraOptions: map[string]interface{}{"similarity": 0.3},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// buildHybridExpr returns MatchExprs for hybrid search (dense + text + fusion).
|
||||
func buildHybridExpr(dense *types.MatchDenseExpr, text *types.MatchTextExpr, topN int) []interface{} {
|
||||
if dense == nil {
|
||||
return []interface{}{text}
|
||||
}
|
||||
fusion := buildFusionExpr(defaultTextWeight, defaultVectorWeight, topN)
|
||||
return []interface{}{dense, text, fusion}
|
||||
}
|
||||
|
||||
// buildEntitySearchRequest constructs a SearchRequest for KG entities.
|
||||
func buildEntitySearchRequest(kbIDs []string, question string, dense *types.MatchDenseExpr, topN int) *types.SearchRequest {
|
||||
req := &types.SearchRequest{
|
||||
KbIDs: kbIDs,
|
||||
SelectFields: []string{"entity_kwd", "entity_type_kwd", "rank_flt", "content_with_weight", "n_hop_with_weight", "_score"},
|
||||
Limit: topN,
|
||||
Filter: map[string]interface{}{"knowledge_graph_kwd": "entity"},
|
||||
}
|
||||
if question != "" {
|
||||
textExpr := &types.MatchTextExpr{
|
||||
Fields: []string{"entity_kwd^10", "content_ltks^2"},
|
||||
MatchingText: question,
|
||||
TopN: topN,
|
||||
}
|
||||
req.MatchExprs = buildHybridExpr(dense, textExpr, topN)
|
||||
}
|
||||
return req
|
||||
}
|
||||
|
||||
// buildEntityTypeSearchRequest constructs a SearchRequest for KG entities by type.
|
||||
func buildEntityTypeSearchRequest(kbIDs []string, typeKeywords []string, topN int) *types.SearchRequest {
|
||||
req := &types.SearchRequest{
|
||||
KbIDs: kbIDs,
|
||||
SelectFields: []string{"entity_kwd", "entity_type_kwd"},
|
||||
Limit: topN,
|
||||
Filter: map[string]interface{}{"knowledge_graph_kwd": "entity"},
|
||||
}
|
||||
if len(typeKeywords) > 0 {
|
||||
filters := make([]interface{}, len(typeKeywords))
|
||||
for i, t := range typeKeywords {
|
||||
filters[i] = t
|
||||
}
|
||||
req.Filter["entity_type_kwd"] = filters
|
||||
}
|
||||
return req
|
||||
}
|
||||
|
||||
// buildRelationSearchRequest constructs a SearchRequest for KG relations.
|
||||
func buildRelationSearchRequest(kbIDs []string, question string, dense *types.MatchDenseExpr, topN int) *types.SearchRequest {
|
||||
req := &types.SearchRequest{
|
||||
KbIDs: kbIDs,
|
||||
SelectFields: []string{"from_entity_kwd", "to_entity_kwd", "weight_int", "content_with_weight", "_score"},
|
||||
Limit: topN,
|
||||
Filter: map[string]interface{}{"knowledge_graph_kwd": "relation"},
|
||||
}
|
||||
if question != "" {
|
||||
textExpr := &types.MatchTextExpr{
|
||||
Fields: []string{"content_ltks", "from_entity_kwd", "to_entity_kwd"},
|
||||
MatchingText: question,
|
||||
TopN: topN,
|
||||
}
|
||||
req.MatchExprs = buildHybridExpr(dense, textExpr, topN)
|
||||
}
|
||||
return req
|
||||
}
|
||||
|
||||
// buildCommunitySearchRequest constructs a SearchRequest for KG community reports.
|
||||
func buildCommunitySearchRequest(kbIDs []string, entityNames []string, topN int) *types.SearchRequest {
|
||||
req := &types.SearchRequest{
|
||||
KbIDs: kbIDs,
|
||||
SelectFields: []string{"docnm_kwd", "content_with_weight", "weight_flt", "entities_kwd"},
|
||||
Limit: topN,
|
||||
Filter: map[string]interface{}{"knowledge_graph_kwd": "community_report"},
|
||||
OrderBy: (&types.OrderByExpr{}).Desc("weight_flt"),
|
||||
}
|
||||
if len(entityNames) > 0 {
|
||||
filters := make([]interface{}, len(entityNames))
|
||||
for i, name := range entityNames {
|
||||
filters[i] = name
|
||||
}
|
||||
req.Filter["entities_kwd"] = filters
|
||||
}
|
||||
return req
|
||||
}
|
||||
|
||||
// buildTypeSamplesSearchRequest constructs a SearchRequest for type samples.
|
||||
func buildTypeSamplesSearchRequest(kbIDs []string) *types.SearchRequest {
|
||||
return &types.SearchRequest{
|
||||
KbIDs: kbIDs,
|
||||
SelectFields: []string{"content_with_weight"},
|
||||
Limit: 10000,
|
||||
Filter: map[string]interface{}{"knowledge_graph_kwd": "ty2ents"},
|
||||
}
|
||||
}
|
||||
|
||||
// ParseEntityChunks converts raw search result chunks into KGEntity slices.
|
||||
func ParseEntityChunks(chunks []map[string]interface{}) []KGEntity {
|
||||
var entities []KGEntity
|
||||
for _, chunk := range chunks {
|
||||
name, _ := chunk["entity_kwd"].(string)
|
||||
if name == "" {
|
||||
// Try extracting from list
|
||||
if list, ok := chunk["entity_kwd"].([]interface{}); ok && len(list) > 0 {
|
||||
name, _ = list[0].(string)
|
||||
}
|
||||
}
|
||||
if name == "" {
|
||||
continue
|
||||
}
|
||||
typ, _ := chunk["entity_type_kwd"].(string)
|
||||
e := KGEntity{Name: name, Type: typ}
|
||||
if v, ok := chunk["rank_flt"].(float64); ok {
|
||||
e.PageRank = v
|
||||
}
|
||||
if v, ok := chunk["_score"].(float64); ok {
|
||||
e.Similarity = v
|
||||
} else if v, ok := chunk["score"].(float64); ok {
|
||||
e.Similarity = v
|
||||
}
|
||||
e.Description, _ = chunk["content_with_weight"].(string)
|
||||
entities = append(entities, e)
|
||||
}
|
||||
return entities
|
||||
}
|
||||
|
||||
// ParseRelationChunks converts raw search result chunks into KGRelation slices.
|
||||
func ParseRelationChunks(chunks []map[string]interface{}) []KGRelation {
|
||||
var relations []KGRelation
|
||||
for _, chunk := range chunks {
|
||||
from, _ := chunk["from_entity_kwd"].(string)
|
||||
to, _ := chunk["to_entity_kwd"].(string)
|
||||
if from == "" || to == "" {
|
||||
continue
|
||||
}
|
||||
r := KGRelation{From: from, To: to}
|
||||
if v, ok := chunk["_score"].(float64); ok {
|
||||
r.Sim = v
|
||||
} else if v, ok := chunk["score"].(float64); ok {
|
||||
r.Sim = v
|
||||
}
|
||||
if v, ok := chunk["weight_int"].(float64); ok {
|
||||
r.PageRank = v
|
||||
} else if v, ok := chunk["weight_int"].(int); ok {
|
||||
r.PageRank = float64(v)
|
||||
}
|
||||
r.Description, _ = chunk["content_with_weight"].(string)
|
||||
relations = append(relations, r)
|
||||
}
|
||||
return relations
|
||||
}
|
||||
|
||||
// ParseCommunityReportChunks converts raw search result chunks into KGCommunityReport slices.
|
||||
func ParseCommunityReportChunks(chunks []map[string]interface{}) []KGCommunityReport {
|
||||
var reports []KGCommunityReport
|
||||
for _, chunk := range chunks {
|
||||
title, _ := chunk["docnm_kwd"].(string)
|
||||
content, _ := chunk["content_with_weight"].(string)
|
||||
if title == "" && content == "" {
|
||||
continue
|
||||
}
|
||||
r := KGCommunityReport{Title: title, Content: content}
|
||||
if v, ok := chunk["weight_flt"].(float64); ok {
|
||||
r.Weight = v
|
||||
}
|
||||
r.Entities, _ = chunk["entities_kwd"].(string)
|
||||
reports = append(reports, r)
|
||||
}
|
||||
return reports
|
||||
}
|
||||
|
||||
// ParseTypeSamplesChunks converts raw search result chunks into a
// type→entities map.
//
// Each chunk's "content_with_weight" string is decoded as a JSON object that
// maps a type name to a list of entity names; the lists of all chunks are
// concatenated per type. Chunks whose content is missing, empty, not a string
// or not valid JSON of that shape are ignored. The returned map is never nil.
func ParseTypeSamplesChunks(chunks []map[string]interface{}) map[string][]string {
	merged := map[string][]string{}
	for _, c := range chunks {
		raw, isString := c["content_with_weight"].(string)
		if !isString || raw == "" {
			continue
		}

		var perChunk map[string][]string
		if json.Unmarshal([]byte(raw), &perChunk) != nil {
			// Undecodable chunk: skip it entirely.
			continue
		}
		for typeName, names := range perChunk {
			merged[typeName] = append(merged[typeName], names...)
		}
	}
	return merged
}
|
||||
@@ -14,7 +14,7 @@
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package service
|
||||
package kg
|
||||
|
||||
import (
|
||||
"context"
|
||||
@@ -159,13 +159,13 @@ func TestBuildTypeSamplesSearchRequest(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// --- ParseKGEntityChunks ---
|
||||
// --- ParseEntityChunks ---
|
||||
|
||||
func TestParseKGEntityChunks_Basic(t *testing.T) {
|
||||
func TestParseEntityChunks_Basic(t *testing.T) {
|
||||
chunks := []map[string]interface{}{
|
||||
{"entity_kwd": "Elon Musk", "entity_type_kwd": "PERSON", "rank_flt": 0.9, "_score": 0.85, "content_with_weight": "Founder of SpaceX"},
|
||||
}
|
||||
entities := ParseKGEntityChunks(chunks)
|
||||
entities := ParseEntityChunks(chunks)
|
||||
if len(entities) != 1 {
|
||||
t.Fatalf("expected 1, got %d", len(entities))
|
||||
}
|
||||
@@ -174,98 +174,98 @@ func TestParseKGEntityChunks_Basic(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseKGEntityChunks_List(t *testing.T) {
|
||||
func TestParseEntityChunks_List(t *testing.T) {
|
||||
chunks := []map[string]interface{}{
|
||||
{"entity_kwd": []interface{}{"Elon Musk", "elon_musk"}},
|
||||
}
|
||||
entities := ParseKGEntityChunks(chunks)
|
||||
entities := ParseEntityChunks(chunks)
|
||||
if len(entities) != 1 || entities[0].Name != "Elon Musk" {
|
||||
t.Errorf("expected first list element, got %q", entities[0].Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseKGEntityChunks_EmptyName(t *testing.T) {
|
||||
func TestParseEntityChunks_EmptyName(t *testing.T) {
|
||||
chunks := []map[string]interface{}{{"entity_type_kwd": "PERSON"}}
|
||||
if len(ParseKGEntityChunks(chunks)) != 0 {
|
||||
if len(ParseEntityChunks(chunks)) != 0 {
|
||||
t.Error("expected 0 for missing name")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseKGEntityChunks_ScoreFallback(t *testing.T) {
|
||||
func TestParseEntityChunks_ScoreFallback(t *testing.T) {
|
||||
chunks := []map[string]interface{}{{"entity_kwd": "Test", "score": 0.75}}
|
||||
if ParseKGEntityChunks(chunks)[0].Similarity != 0.75 {
|
||||
if ParseEntityChunks(chunks)[0].Similarity != 0.75 {
|
||||
t.Error("expected 0.75 from score field")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseKGEntityChunks_NilInput(t *testing.T) {
|
||||
if len(ParseKGEntityChunks(nil)) != 0 {
|
||||
func TestParseEntityChunks_NilInput(t *testing.T) {
|
||||
if len(ParseEntityChunks(nil)) != 0 {
|
||||
t.Error("expected 0 for nil input")
|
||||
}
|
||||
}
|
||||
|
||||
// --- ParseKGRelationChunks ---
|
||||
// --- ParseRelationChunks ---
|
||||
|
||||
func TestParseKGRelationChunks_Basic(t *testing.T) {
|
||||
func TestParseRelationChunks_Basic(t *testing.T) {
|
||||
chunks := []map[string]interface{}{
|
||||
{"from_entity_kwd": "Elon Musk", "to_entity_kwd": "SpaceX", "weight_int": float64(5), "content_with_weight": "Founder"},
|
||||
}
|
||||
relations := ParseKGRelationChunks(chunks)
|
||||
if len(relations) != 1 || relations[0].From != "Elon Musk" || relations[0].Weight != 5 {
|
||||
relations := ParseRelationChunks(chunks)
|
||||
if len(relations) != 1 || relations[0].From != "Elon Musk" || relations[0].PageRank != 5 {
|
||||
t.Errorf("unexpected: %+v", relations[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseKGRelationChunks_IntWeight(t *testing.T) {
|
||||
func TestParseRelationChunks_IntWeight(t *testing.T) {
|
||||
chunks := []map[string]interface{}{{"from_entity_kwd": "A", "to_entity_kwd": "B", "weight_int": 3}}
|
||||
if ParseKGRelationChunks(chunks)[0].Weight != 3 {
|
||||
if ParseRelationChunks(chunks)[0].PageRank != 3 {
|
||||
t.Error("expected weight 3")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseKGRelationChunks_EmptyFrom(t *testing.T) {
|
||||
if len(ParseKGRelationChunks([]map[string]interface{}{{"to_entity_kwd": "B"}})) != 0 {
|
||||
func TestParseRelationChunks_EmptyFrom(t *testing.T) {
|
||||
if len(ParseRelationChunks([]map[string]interface{}{{"to_entity_kwd": "B"}})) != 0 {
|
||||
t.Error("expected 0 for missing from")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseKGRelationChunks_NilInput(t *testing.T) {
|
||||
if len(ParseKGRelationChunks(nil)) != 0 {
|
||||
func TestParseRelationChunks_NilInput(t *testing.T) {
|
||||
if len(ParseRelationChunks(nil)) != 0 {
|
||||
t.Error("expected 0 for nil")
|
||||
}
|
||||
}
|
||||
|
||||
// --- ParseKGCommunityReportChunks ---
|
||||
// --- ParseCommunityReportChunks ---
|
||||
|
||||
func TestParseKGCommunityReportChunks_Basic(t *testing.T) {
|
||||
func TestParseCommunityReportChunks_Basic(t *testing.T) {
|
||||
chunks := []map[string]interface{}{
|
||||
{"docnm_kwd": "Report 1", "content_with_weight": "content", "weight_flt": 0.95, "entities_kwd": "A, B"},
|
||||
}
|
||||
reports := ParseKGCommunityReportChunks(chunks)
|
||||
reports := ParseCommunityReportChunks(chunks)
|
||||
if len(reports) != 1 || reports[0].Title != "Report 1" || reports[0].Weight != 0.95 {
|
||||
t.Errorf("unexpected: %+v", reports[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseKGCommunityReportChunks_EmptyTitle(t *testing.T) {
|
||||
if len(ParseKGCommunityReportChunks([]map[string]interface{}{{"weight_flt": 0.5}})) != 0 {
|
||||
func TestParseCommunityReportChunks_EmptyTitle(t *testing.T) {
|
||||
if len(ParseCommunityReportChunks([]map[string]interface{}{{"weight_flt": 0.5}})) != 0 {
|
||||
t.Error("expected 0 for empty title and content")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseKGCommunityReportChunks_NilInput(t *testing.T) {
|
||||
if len(ParseKGCommunityReportChunks(nil)) != 0 {
|
||||
func TestParseCommunityReportChunks_NilInput(t *testing.T) {
|
||||
if len(ParseCommunityReportChunks(nil)) != 0 {
|
||||
t.Error("expected 0 for nil")
|
||||
}
|
||||
}
|
||||
|
||||
// --- ParseKGTypeSamplesChunks ---
|
||||
// --- ParseTypeSamplesChunks ---
|
||||
|
||||
func TestParseKGTypeSamplesChunks_ValidJSON(t *testing.T) {
|
||||
func TestParseTypeSamplesChunks_ValidJSON(t *testing.T) {
|
||||
chunks := []map[string]interface{}{
|
||||
{"content_with_weight": `{"PERSON": ["Elon Musk", "Einstein"], "ORGANIZATION": ["SpaceX"]}`},
|
||||
}
|
||||
result := ParseKGTypeSamplesChunks(chunks)
|
||||
result := ParseTypeSamplesChunks(chunks)
|
||||
if len(result) != 2 {
|
||||
t.Fatalf("expected 2 types, got %d: %v", len(result), result)
|
||||
}
|
||||
@@ -277,18 +277,18 @@ func TestParseKGTypeSamplesChunks_ValidJSON(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseKGTypeSamplesChunks_InvalidJSON(t *testing.T) {
|
||||
func TestParseTypeSamplesChunks_InvalidJSON(t *testing.T) {
|
||||
chunks := []map[string]interface{}{
|
||||
{"content_with_weight": "not json"},
|
||||
}
|
||||
result := ParseKGTypeSamplesChunks(chunks)
|
||||
result := ParseTypeSamplesChunks(chunks)
|
||||
if len(result) != 0 {
|
||||
t.Error("expected empty for invalid JSON")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseKGTypeSamplesChunks_Empty(t *testing.T) {
|
||||
result := ParseKGTypeSamplesChunks(nil)
|
||||
func TestParseTypeSamplesChunks_Empty(t *testing.T) {
|
||||
result := ParseTypeSamplesChunks(nil)
|
||||
if len(result) != 0 {
|
||||
t.Error("expected empty for nil")
|
||||
}
|
||||
@@ -356,9 +356,9 @@ func TestBuildKGDenseExpr_WithModel(t *testing.T) {
|
||||
},
|
||||
APIConfig: &modelModule.APIConfig{},
|
||||
}
|
||||
dense, err := buildKGDenseExpr(embModel, "test question", 10)
|
||||
dense, err := buildDenseExpr(embModel, "test question", 10)
|
||||
if err != nil {
|
||||
t.Fatalf("buildKGDenseExpr failed: %v", err)
|
||||
t.Fatalf("buildDenseExpr failed: %v", err)
|
||||
}
|
||||
if dense == nil {
|
||||
t.Fatal("expected non-nil MatchDenseExpr")
|
||||
@@ -372,14 +372,14 @@ func TestBuildKGDenseExpr_WithModel(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBuildKGDenseExpr_NilModel(t *testing.T) {
|
||||
dense, err := buildKGDenseExpr(nil, "test", 10)
|
||||
dense, err := buildDenseExpr(nil, "test", 10)
|
||||
if dense != nil || err != nil {
|
||||
t.Errorf("expected nil,nil for nil model, got dense=%v err=%v", dense, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildKGDenseExpr_EmptyQuestion(t *testing.T) {
|
||||
dense, err := buildKGDenseExpr(&modelModule.EmbeddingModel{}, "", 10)
|
||||
dense, err := buildDenseExpr(&modelModule.EmbeddingModel{}, "", 10)
|
||||
if dense != nil || err != nil {
|
||||
t.Errorf("expected nil,nil for empty question, got dense=%v err=%v", dense, err)
|
||||
}
|
||||
@@ -387,7 +387,7 @@ func TestBuildKGDenseExpr_EmptyQuestion(t *testing.T) {
|
||||
|
||||
// --- Search integration with mock ---
|
||||
|
||||
func TestSearchKGEntities_WithMock(t *testing.T) {
|
||||
func TestSearchEntities_WithMock(t *testing.T) {
|
||||
mock := &mockKGEngine{
|
||||
searchFunc: func(ctx context.Context, req *types.SearchRequest) (*types.SearchResult, error) {
|
||||
if req.Filter["knowledge_graph_kwd"] != "entity" {
|
||||
@@ -400,16 +400,16 @@ func TestSearchKGEntities_WithMock(t *testing.T) {
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
entities, err := SearchKGEntities(context.Background(), mock, []string{"kb1"}, "Elon", nil, 10)
|
||||
entities, err := SearchEntities(context.Background(), mock, []string{"kb1"}, "Elon", nil, 10)
|
||||
if err != nil {
|
||||
t.Fatalf("SearchKGEntities failed: %v", err)
|
||||
t.Fatalf("SearchEntities failed: %v", err)
|
||||
}
|
||||
if len(entities) != 1 || entities[0].Name != "Elon Musk" {
|
||||
t.Errorf("expected [Elon Musk], got %v", entities)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchKGEntitiesByTypes_WithMock(t *testing.T) {
|
||||
func TestSearchEntitiesByTypes_WithMock(t *testing.T) {
|
||||
mock := &mockKGEngine{
|
||||
searchFunc: func(ctx context.Context, req *types.SearchRequest) (*types.SearchResult, error) {
|
||||
return &types.SearchResult{
|
||||
@@ -419,20 +419,20 @@ func TestSearchKGEntitiesByTypes_WithMock(t *testing.T) {
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
entities, err := SearchKGEntitiesByTypes(context.Background(), mock, []string{"kb1"}, []string{"ORGANIZATION"}, 10)
|
||||
entities, err := SearchEntitiesByTypes(context.Background(), mock, []string{"kb1"}, []string{"ORGANIZATION"}, 10)
|
||||
if err != nil {
|
||||
t.Fatalf("SearchKGEntitiesByTypes failed: %v", err)
|
||||
t.Fatalf("SearchEntitiesByTypes failed: %v", err)
|
||||
}
|
||||
if len(entities) != 1 || entities[0].Type != "ORGANIZATION" {
|
||||
t.Errorf("expected ORGANIZATION, got %v", entities)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchKGTypeSamples_WithMock(t *testing.T) {
|
||||
func TestSearchTypeSamples_WithMock(t *testing.T) {
|
||||
mock := &mockKGEngine{}
|
||||
samples, err := SearchKGTypeSamples(context.Background(), mock, []string{"kb1"})
|
||||
samples, err := SearchTypeSamples(context.Background(), mock, []string{"kb1"})
|
||||
if err != nil {
|
||||
t.Fatalf("SearchKGTypeSamples failed: %v", err)
|
||||
t.Fatalf("SearchTypeSamples failed: %v", err)
|
||||
}
|
||||
if samples == nil {
|
||||
samples = map[string][]string{}
|
||||
3
internal/service/kg/testutil_test.go
Normal file
3
internal/service/kg/testutil_test.go
Normal file
@@ -0,0 +1,3 @@
|
||||
package kg
|
||||
|
||||
func strPtr(s string) *string { return &s }
|
||||
60
internal/service/kg/types.go
Normal file
60
internal/service/kg/types.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package kg
|
||||
|
||||
// KGEntity represents a knowledge graph entity.
|
||||
type KGEntity struct {
|
||||
Name string // entity_kwd
|
||||
Type string // entity_type_kwd
|
||||
PageRank float64 // rank_flt
|
||||
Similarity float64 // _score
|
||||
Description string // content_with_weight
|
||||
NhopEnts []NhopEntity // n_hop_with_weight (parsed JSON)
|
||||
}
|
||||
|
||||
// NhopEntity is one N-hop neighbor path of an entity.
type NhopEntity struct {
	// Path lists the entity names along the path.
	Path []string
	// Weights holds the pagerank weight of each hop.
	Weights []float64
}
|
||||
|
||||
// KGRelation is a relation between two knowledge-graph entities.
// Each field notes the chunk field it corresponds to.
type KGRelation struct {
	// From is the source entity (from_entity_kwd).
	From string
	// To is the target entity (to_entity_kwd).
	To string
	// Description is the relation description (content_with_weight).
	Description string
	// Sim is the score accumulated during pipeline scoring.
	Sim float64
	// PageRank is rank_flt or weight_int, stored as float64.
	PageRank float64
}
|
||||
|
||||
// Edge is a directed (from_entity, to_entity) pair. It is a comparable value
// type, so it can be used as a map key.
type Edge struct {
	// From is the entity the edge starts at.
	From string
	// To is the entity the edge points to.
	To string
}
|
||||
|
||||
// EdgeScore is the score accumulated for one edge during N-hop analysis.
type EdgeScore struct {
	// Sim is the accumulated similarity contribution.
	Sim float64
	// PageRank is the accumulated pagerank contribution.
	PageRank float64
}
|
||||
|
||||
// ScoredEntity is an entity with its final score, ready for output.
type ScoredEntity struct {
	// Entity is the entity name.
	Entity string
	// Score is the final score assigned to the entity.
	Score float64
	// Description is the entity description text.
	Description string
}
|
||||
|
||||
// ScoredRelation is a relation with its final score, ready for output.
type ScoredRelation struct {
	// From is the source entity name.
	From string
	// To is the target entity name.
	To string
	// Score is the final score assigned to the relation.
	Score float64
	// Description is the relation description text.
	Description string
}
|
||||
|
||||
// KGCommunityReport is a community report as read from a search hit.
// Each field notes the chunk field it corresponds to.
type KGCommunityReport struct {
	// Title is the report title (docnm_kwd).
	Title string
	// Content is the report body (content_with_weight).
	Content string
	// Weight is the report weight (weight_flt).
	Weight float64
	// Entities is the entities field of the report (entities_kwd).
	Entities string
}
|
||||
@@ -1,368 +0,0 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// --- AnalyzeNHopPaths ---
|
||||
|
||||
func TestAnalyzeNHopPaths_Basic(t *testing.T) {
|
||||
ents := map[string]*KGEntity{
|
||||
"A": {
|
||||
Similarity: 0.9,
|
||||
NhopEnts: []NhopEntity{
|
||||
{Path: []string{"A", "B", "C"}, Weights: []float64{0.8, 0.5}},
|
||||
},
|
||||
},
|
||||
}
|
||||
result := AnalyzeNHopPaths(ents)
|
||||
// A→B: 0.9 / (2+0) = 0.45
|
||||
// B→C: 0.9 / (2+1) = 0.3
|
||||
if len(result) != 2 {
|
||||
t.Fatalf("expected 2 edges, got %d", len(result))
|
||||
}
|
||||
if result[Edge{"A", "B"}].Sim != 0.45 {
|
||||
t.Errorf("expected A→B sim=0.45, got %f", result[Edge{"A", "B"}].Sim)
|
||||
}
|
||||
if result[Edge{"B", "C"}].Sim != 0.3 {
|
||||
t.Errorf("expected B→C sim=0.3, got %f", result[Edge{"B", "C"}].Sim)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAnalyzeNHopPaths_MultipleContributors(t *testing.T) {
|
||||
ents := map[string]*KGEntity{
|
||||
"A": {
|
||||
Similarity: 0.8,
|
||||
NhopEnts: []NhopEntity{
|
||||
{Path: []string{"A", "B"}, Weights: []float64{0.7}},
|
||||
},
|
||||
},
|
||||
"X": {
|
||||
Similarity: 0.6,
|
||||
NhopEnts: []NhopEntity{
|
||||
{Path: []string{"X", "B"}, Weights: []float64{0.5}},
|
||||
},
|
||||
},
|
||||
}
|
||||
result := AnalyzeNHopPaths(ents)
|
||||
// A→B: 0.8 / 2 = 0.4
|
||||
// X→B: 0.6 / 2 = 0.3
|
||||
if result[Edge{"A", "B"}].Sim != 0.4 {
|
||||
t.Errorf("expected A→B sim=0.4, got %f", result[Edge{"A", "B"}].Sim)
|
||||
}
|
||||
if result[Edge{"X", "B"}].Sim != 0.3 {
|
||||
t.Errorf("expected X→B sim=0.3, got %f", result[Edge{"X", "B"}].Sim)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAnalyzeNHopPaths_Empty(t *testing.T) {
|
||||
result := AnalyzeNHopPaths(nil)
|
||||
if len(result) != 0 {
|
||||
t.Errorf("expected empty, got %d", len(result))
|
||||
}
|
||||
}
|
||||
|
||||
// --- DoubleHitBoost ---
|
||||
|
||||
func TestDoubleHitBoost(t *testing.T) {
|
||||
ents := map[string]*KGEntity{
|
||||
"A": {Similarity: 0.5},
|
||||
"B": {Similarity: 0.3},
|
||||
}
|
||||
types := map[string]struct{}{"A": {}}
|
||||
DoubleHitBoost(ents, types)
|
||||
if ents["A"].Similarity != 1.0 {
|
||||
t.Errorf("expected A sim=1.0 after boost, got %f", ents["A"].Similarity)
|
||||
}
|
||||
if ents["B"].Similarity != 0.3 {
|
||||
t.Errorf("expected B sim unchanged at 0.3, got %f", ents["B"].Similarity)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDoubleHitBoost_Empty(t *testing.T) {
|
||||
ents := map[string]*KGEntity{"A": {Similarity: 0.5}}
|
||||
DoubleHitBoost(ents, map[string]struct{}{})
|
||||
if ents["A"].Similarity != 0.5 {
|
||||
t.Errorf("expected unchanged, got %f", ents["A"].Similarity)
|
||||
}
|
||||
}
|
||||
|
||||
// --- FuseRelationScores ---
|
||||
|
||||
func TestFuseRelationScores_NhopContribution(t *testing.T) {
|
||||
rels := map[Edge]*KGRelation{
|
||||
{"A", "B"}: {Sim: 0.5, PageRank: 0.8},
|
||||
}
|
||||
types := map[string]struct{}{}
|
||||
nhop := map[Edge]EdgeScore{
|
||||
{"A", "B"}: {Sim: 0.3},
|
||||
}
|
||||
FuseRelationScores(rels, types, nhop)
|
||||
// sim = 0.5 * (0.3 + 1) = 0.65
|
||||
if rels[Edge{"A", "B"}].Sim != 0.65 {
|
||||
t.Errorf("expected 0.65, got %f", rels[Edge{"A", "B"}].Sim)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFuseRelationScores_TypeBoost(t *testing.T) {
|
||||
rels := map[Edge]*KGRelation{
|
||||
{"A", "B"}: {Sim: 0.5},
|
||||
}
|
||||
types := map[string]struct{}{"A": {}, "B": {}}
|
||||
nhop := map[Edge]EdgeScore{}
|
||||
FuseRelationScores(rels, types, nhop)
|
||||
// Both endpoints in types: s=2, sim = 0.5 * (2+1) = 1.5
|
||||
if rels[Edge{"A", "B"}].Sim != 1.5 {
|
||||
t.Errorf("expected 1.5, got %f", rels[Edge{"A", "B"}].Sim)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFuseRelationScores_NhopNewEdge(t *testing.T) {
|
||||
rels := map[Edge]*KGRelation{}
|
||||
types := map[string]struct{}{}
|
||||
nhop := map[Edge]EdgeScore{
|
||||
{"A", "B"}: {Sim: 0.4, PageRank: 0.7},
|
||||
}
|
||||
FuseRelationScores(rels, types, nhop)
|
||||
if _, ok := rels[Edge{"A", "B"}]; !ok {
|
||||
t.Fatal("expected new edge from N-hop")
|
||||
}
|
||||
if rels[Edge{"A", "B"}].Sim != 0.4 {
|
||||
t.Errorf("expected sim=0.4, got %f", rels[Edge{"A", "B"}].Sim)
|
||||
}
|
||||
}
|
||||
|
||||
// --- SortAndTrim ---
|
||||
|
||||
func TestSortAndTrimEntities(t *testing.T) {
|
||||
ents := map[string]*KGEntity{
|
||||
"A": {Similarity: 0.5, PageRank: 0.9},
|
||||
"B": {Similarity: 0.8, PageRank: 0.3},
|
||||
"C": {Similarity: 0.9, PageRank: 0.1},
|
||||
}
|
||||
result := SortAndTrimEntities(ents, 2)
|
||||
if len(result) != 2 {
|
||||
t.Fatalf("expected 2, got %d", len(result))
|
||||
}
|
||||
// A: 0.45, B: 0.24, C: 0.09 → top 2 should be A, B
|
||||
if result[0].Entity != "A" {
|
||||
t.Errorf("expected A first (0.45), got %s (%f)", result[0].Entity, result[0].Score)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSortAndTrimEntities_DefaultTopN(t *testing.T) {
|
||||
ents := map[string]*KGEntity{
|
||||
"A": {Similarity: 0.5, PageRank: 0.9},
|
||||
"B": {Similarity: 0.8, PageRank: 0.3},
|
||||
}
|
||||
result := SortAndTrimEntities(ents, 0)
|
||||
if len(result) != 2 {
|
||||
t.Errorf("expected default topN to include all, got %d", len(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSortAndTrimRelations(t *testing.T) {
|
||||
rels := map[Edge]*KGRelation{
|
||||
{"A", "B"}: {Sim: 0.9, PageRank: 0.1},
|
||||
{"C", "D"}: {Sim: 0.3, PageRank: 0.8},
|
||||
}
|
||||
result := SortAndTrimRelations(rels, 1)
|
||||
if len(result) != 1 {
|
||||
t.Fatalf("expected 1, got %d", len(result))
|
||||
}
|
||||
// A→B: 0.09, C→D: 0.24 → C→D should be first
|
||||
if result[0].From != "C" {
|
||||
t.Errorf("expected C first (0.24), got %s (%f)", result[0].From, result[0].Score)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Format and Build ---
|
||||
|
||||
func TestBuildKGContent_Basic(t *testing.T) {
|
||||
entities := []ScoredEntity{
|
||||
{Entity: "A", Score: 0.45, Description: `{"description": "Entity A desc"}`},
|
||||
}
|
||||
relations := []ScoredRelation{
|
||||
{From: "A", To: "B", Score: 0.3, Description: `{"description": "rel A-B"}`},
|
||||
}
|
||||
result := BuildKGContent(entities, relations, 10000)
|
||||
if !strings.Contains(result, "Entity A desc") {
|
||||
t.Errorf("expected entity description in output, got: %s", result)
|
||||
}
|
||||
if !strings.Contains(result, "rel A-B") {
|
||||
t.Errorf("expected relation description in output, got: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildKGContent_TokenBudget(t *testing.T) {
|
||||
longDesc := strings.Repeat("This is a very long description. ", 50)
|
||||
entities := []ScoredEntity{
|
||||
{Entity: "LongEntityName", Score: 1.0, Description: longDesc},
|
||||
}
|
||||
relations := []ScoredRelation{
|
||||
{From: "X", To: "Y", Score: 1.0, Description: "relation desc"},
|
||||
}
|
||||
result := BuildKGContent(entities, relations, 50)
|
||||
// Token budget is very small, should truncate and not include relations
|
||||
if strings.Contains(result, "relation desc") {
|
||||
t.Log("Note: relations included despite small budget (depending on token count)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatEntitiesToCSV_HeaderExceedsBudget(t *testing.T) {
|
||||
entities := []ScoredEntity{
|
||||
{Entity: "A", Score: 1.0, Description: "d"},
|
||||
}
|
||||
result, remaining := FormatEntitiesToCSV(entities, 3)
|
||||
tokens := NumTokensFromString(result)
|
||||
// Header lines (---- Entities ----\n, Entity,Score,Description\n) are written
|
||||
// before the token budget check. They consume ~11 tokens but are not deducted
|
||||
// from maxToken. This is a known limitation shared with Python.
|
||||
if tokens > 3 {
|
||||
t.Logf("output %d tokens exceeds budget of %d (header not counted, remaining=%d)", tokens, 3, remaining)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilterChunksByScore_AllPass(t *testing.T) {
|
||||
chunks := []map[string]interface{}{
|
||||
{"entity_kwd": "A", "_score": 0.5},
|
||||
{"entity_kwd": "B", "_score": 0.8},
|
||||
}
|
||||
result := FilterChunksByScore(chunks, 0.3)
|
||||
if len(result) != 2 {
|
||||
t.Errorf("expected all 2 chunks to pass, got %d", len(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilterChunksByScore_SomeFiltered(t *testing.T) {
|
||||
chunks := []map[string]interface{}{
|
||||
{"entity_kwd": "A", "_score": 0.2},
|
||||
{"entity_kwd": "B", "_score": 0.9},
|
||||
}
|
||||
result := FilterChunksByScore(chunks, 0.3)
|
||||
if len(result) != 1 || result[0]["entity_kwd"] != "B" {
|
||||
t.Errorf("expected only B to pass, got %v", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilterChunksByScore_MissingScore(t *testing.T) {
|
||||
chunks := []map[string]interface{}{
|
||||
{"entity_kwd": "A"}, // no _score → treated as 0
|
||||
{"entity_kwd": "B", "score": 0.5},
|
||||
}
|
||||
result := FilterChunksByScore(chunks, 0.3)
|
||||
if len(result) != 1 || result[0]["entity_kwd"] != "B" {
|
||||
t.Errorf("expected only B (using 'score' field), got %v", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilterChunksByScore_NilInput(t *testing.T) {
|
||||
result := FilterChunksByScore(nil, 0.3)
|
||||
if result != nil {
|
||||
t.Errorf("expected nil, got %v", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilterChunksByScore_ZeroThreshold(t *testing.T) {
|
||||
chunks := []map[string]interface{}{
|
||||
{"entity_kwd": "A", "_score": 0.0},
|
||||
}
|
||||
result := FilterChunksByScore(chunks, 0)
|
||||
if len(result) != 1 {
|
||||
t.Errorf("expected all pass when threshold=0, got %d", len(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractDescription_JSON(t *testing.T) {
|
||||
result := extractDescription(`{"description": "Entity A description", "other": "value"}`)
|
||||
if result != "Entity A description" {
|
||||
t.Errorf("expected 'Entity A description', got %q", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractDescription_Plain(t *testing.T) {
|
||||
result := extractDescription("plain description")
|
||||
if result != "plain description" {
|
||||
t.Errorf("expected 'plain description', got %q", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractDescription_EscapedQuote(t *testing.T) {
|
||||
result := extractDescription(`{"description": "has \"quote\" inside"}`)
|
||||
if result != `has "quote" inside` {
|
||||
t.Errorf("expected full description with quote, got %q", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractDescription_NonStringValue(t *testing.T) {
|
||||
result := extractDescription(`{"description": null, "other": "val"}`)
|
||||
if result != `{"description": null, "other": "val"}` {
|
||||
t.Errorf("expected raw JSON when description is null, got %q", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractDescription_EmptyString(t *testing.T) {
|
||||
result := extractDescription("")
|
||||
if result != "" {
|
||||
t.Errorf("expected empty, got %q", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatCSVLine_Normal(t *testing.T) {
|
||||
result := formatCSVLine("Elon Musk", "0.85", "CEO of SpaceX")
|
||||
// Normal values should not be quoted
|
||||
if result != "Elon Musk,0.85,CEO of SpaceX\n" {
|
||||
t.Errorf("expected unquoted CSV, got %q", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatCSVLine_CommaInField(t *testing.T) {
|
||||
result := formatCSVLine("Musk, Elon", "0.85", "CEO, SpaceX")
|
||||
// Values with commas should be quoted
|
||||
expected := `"Musk, Elon",0.85,"CEO, SpaceX"` + "\n"
|
||||
if result != expected {
|
||||
t.Errorf("expected %q, got %q", expected, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatCSVLine_QuoteInField(t *testing.T) {
|
||||
result := formatCSVLine("Elon Musk", "0.85", `CEO of "SpaceX"`)
|
||||
// Values with quotes should have quotes escaped
|
||||
expected := `Elon Musk,0.85,"CEO of ""SpaceX"""` + "\n"
|
||||
if result != expected {
|
||||
t.Errorf("expected %q, got %q", expected, result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatCSVLine_EmptyField(t *testing.T) {
|
||||
result := formatCSVLine("", "", "")
|
||||
if result != ",,\n" {
|
||||
t.Errorf("expected empty fields, got %q", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNumTokensFromString(t *testing.T) {
|
||||
s := "This is a test string with multiple words"
|
||||
tokens := NumTokensFromString(s)
|
||||
if tokens <= 0 {
|
||||
t.Errorf("expected positive token count, got %d", tokens)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,397 +0,0 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"ragflow/internal/engine"
|
||||
"ragflow/internal/engine/types"
|
||||
modelModule "ragflow/internal/entity/models"
|
||||
)
|
||||
|
||||
// NhopEntity describes one N-hop neighbour path in the knowledge graph.
type NhopEntity struct {
	// Path lists the entity names along the path.
	Path []string
	// Weights holds the pagerank weight of each hop.
	Weights []float64
}
|
||||
|
||||
// KGEntity represents a knowledge graph entity.
|
||||
type KGEntity struct {
|
||||
Name string // entity_kwd
|
||||
Type string // entity_type_kwd
|
||||
PageRank float64 // rank_flt
|
||||
Similarity float64 // _score
|
||||
Description string // content_with_weight
|
||||
NhopEnts []NhopEntity // n_hop_with_weight (parsed JSON)
|
||||
}
|
||||
|
||||
// Edge is a directed (from_entity, to_entity) pair. It is comparable and is
// therefore usable as a map key.
type Edge struct {
	From string
	To   string
}
|
||||
|
||||
// EdgeScore is the score accumulated for an edge during N-hop analysis.
type EdgeScore struct {
	Sim, PageRank float64
}
|
||||
|
||||
// ScoredEntity is an entity paired with its final score, ready for output.
type ScoredEntity struct {
	// Entity is the entity name.
	Entity string
	// Score is the entity's final ranking score.
	Score float64
	// Description is the text carried along for output.
	Description string
}
|
||||
|
||||
// ScoredRelation is a relation paired with its final score, ready for output.
type ScoredRelation struct {
	// From is the source entity name.
	From string
	// To is the target entity name.
	To string
	// Score is the relation's final ranking score.
	Score float64
	// Description is the text carried along for output.
	Description string
}
|
||||
|
||||
// KGRelation is a relation between two knowledge graph entities. The notes on
// each field name the search-result key the value is associated with.
type KGRelation struct {
	// From is the source entity (from_entity_kwd).
	From string
	// To is the target entity (to_entity_kwd).
	To string
	// Weight is the relation weight (weight_int).
	Weight int
	// Description is the relation description (content_with_weight).
	Description string
	// Sim is the score accumulated during pipeline scoring.
	Sim float64
	// PageRank is rank_flt, or weight_int converted to float64.
	PageRank float64
}
|
||||
|
||||
// KGCommunityReport is one community report of the knowledge graph. The notes
// on each field name the search-result key the value is read from.
type KGCommunityReport struct {
	// Title is the report title (docnm_kwd).
	Title string
	// Content is the report body (content_with_weight).
	Content string
	// Weight is the report weight (weight_flt).
	Weight float64
	// Entities is the entities value of the report (entities_kwd).
	Entities string
}
|
||||
|
||||
// buildKGDenseExpr computes the query vector and returns a MatchDenseExpr
|
||||
// for KG hybrid search. Returns nil if embModel or question is empty.
|
||||
func buildKGDenseExpr(embModel *modelModule.EmbeddingModel, question string, topN int) (*types.MatchDenseExpr, error) {
|
||||
if embModel == nil || question == "" {
|
||||
return nil, nil
|
||||
}
|
||||
embCfg := &modelModule.EmbeddingConfig{Dimension: 0}
|
||||
embeddings, err := embModel.ModelDriver.Embed(embModel.ModelName, []string{question}, embModel.APIConfig, embCfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("KG entity embed failed: %w", err)
|
||||
}
|
||||
if len(embeddings) == 0 || len(embeddings[0].Embedding) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
vector := embeddings[0].Embedding
|
||||
return &types.MatchDenseExpr{
|
||||
VectorColumnName: fmt.Sprintf("q_%d_vec", len(vector)),
|
||||
EmbeddingData: vector,
|
||||
EmbeddingDataType: "float",
|
||||
DistanceType: "cosine",
|
||||
TopN: topN,
|
||||
ExtraOptions: map[string]interface{}{"similarity": 0.3},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// buildHybridExpr returns MatchExprs for hybrid search (dense + text + fusion).
|
||||
func buildHybridExpr(dense *types.MatchDenseExpr, text *types.MatchTextExpr, topN int) []interface{} {
|
||||
return []interface{}{
|
||||
dense,
|
||||
text,
|
||||
&types.FusionExpr{
|
||||
Method: "weighted_sum",
|
||||
TopN: topN,
|
||||
FusionParams: map[string]interface{}{"weights": "0.05,0.95"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// buildEntitySearchRequest constructs a SearchRequest for KG entities.
|
||||
// dense may be nil for text-only search.
|
||||
func buildEntitySearchRequest(kbIDs []string, question string, dense *types.MatchDenseExpr, topN int) *types.SearchRequest {
|
||||
req := &types.SearchRequest{
|
||||
KbIDs: kbIDs,
|
||||
SelectFields: []string{"entity_kwd", "entity_type_kwd", "rank_flt", "content_with_weight"},
|
||||
Limit: topN,
|
||||
Filter: map[string]interface{}{"knowledge_graph_kwd": "entity"},
|
||||
}
|
||||
if question == "" {
|
||||
return req
|
||||
}
|
||||
textExpr := &types.MatchTextExpr{
|
||||
Fields: []string{"entity_kwd^10", "content_ltks^2"},
|
||||
MatchingText: question,
|
||||
TopN: topN,
|
||||
}
|
||||
if dense != nil {
|
||||
req.MatchExprs = buildHybridExpr(dense, textExpr, topN)
|
||||
req.RankFeature = map[string]float64{"pagerank_fea": 10.0}
|
||||
} else {
|
||||
req.MatchExprs = []interface{}{textExpr}
|
||||
}
|
||||
return req
|
||||
}
|
||||
|
||||
// buildEntityTypeSearchRequest constructs a SearchRequest for KG entities by type.
|
||||
func buildEntityTypeSearchRequest(kbIDs []string, typeKeywords []string, topN int) *types.SearchRequest {
|
||||
req := &types.SearchRequest{
|
||||
KbIDs: kbIDs,
|
||||
SelectFields: []string{"entity_kwd", "entity_type_kwd", "rank_flt", "content_with_weight"},
|
||||
Limit: topN,
|
||||
Filter: map[string]interface{}{
|
||||
"knowledge_graph_kwd": "entity",
|
||||
},
|
||||
}
|
||||
if len(typeKeywords) > 0 {
|
||||
filters := make([]interface{}, len(typeKeywords))
|
||||
for i, t := range typeKeywords {
|
||||
filters[i] = t
|
||||
}
|
||||
req.Filter["entity_type_kwd"] = filters
|
||||
}
|
||||
return req
|
||||
}
|
||||
|
||||
// buildRelationSearchRequest constructs a SearchRequest for KG relations.
|
||||
// dense may be nil for text-only search.
|
||||
func buildRelationSearchRequest(kbIDs []string, question string, dense *types.MatchDenseExpr, topN int) *types.SearchRequest {
|
||||
req := &types.SearchRequest{
|
||||
KbIDs: kbIDs,
|
||||
SelectFields: []string{"from_entity_kwd", "to_entity_kwd", "weight_int", "content_with_weight"},
|
||||
Limit: topN,
|
||||
Filter: map[string]interface{}{"knowledge_graph_kwd": "relation"},
|
||||
}
|
||||
if question != "" {
|
||||
textExpr := &types.MatchTextExpr{
|
||||
Fields: []string{"content_ltks"},
|
||||
MatchingText: question,
|
||||
TopN: topN,
|
||||
}
|
||||
if dense != nil {
|
||||
req.MatchExprs = buildHybridExpr(dense, textExpr, topN)
|
||||
} else {
|
||||
req.MatchExprs = []interface{}{textExpr}
|
||||
}
|
||||
}
|
||||
return req
|
||||
}
|
||||
|
||||
// buildCommunitySearchRequest constructs a SearchRequest for KG community reports.
|
||||
// Matches community reports whose entities_kwd contains any of the given entity names.
|
||||
func buildCommunitySearchRequest(kbIDs []string, entityNames []string, topN int) *types.SearchRequest {
|
||||
req := &types.SearchRequest{
|
||||
KbIDs: kbIDs,
|
||||
SelectFields: []string{"docnm_kwd", "content_with_weight", "weight_flt", "entities_kwd"},
|
||||
Limit: topN,
|
||||
Filter: map[string]interface{}{
|
||||
"knowledge_graph_kwd": "community_report",
|
||||
},
|
||||
OrderBy: (&types.OrderByExpr{}).Desc("weight_flt"),
|
||||
}
|
||||
if len(entityNames) > 0 {
|
||||
filters := make([]interface{}, len(entityNames))
|
||||
for i, name := range entityNames {
|
||||
filters[i] = name
|
||||
}
|
||||
req.Filter["entities_kwd"] = filters
|
||||
}
|
||||
return req
|
||||
}
|
||||
|
||||
// buildTypeSamplesSearchRequest constructs a SearchRequest for ty2ents data.
|
||||
func buildTypeSamplesSearchRequest(kbIDs []string) *types.SearchRequest {
|
||||
return &types.SearchRequest{
|
||||
KbIDs: kbIDs,
|
||||
SelectFields: []string{"content_with_weight"},
|
||||
Limit: 10000,
|
||||
Filter: map[string]interface{}{"knowledge_graph_kwd": "ty2ents"},
|
||||
}
|
||||
}
|
||||
|
||||
// ParseKGEntityChunks converts raw search result chunks into KGEntity slices.
|
||||
func ParseKGEntityChunks(chunks []map[string]interface{}) []KGEntity {
|
||||
var entities []KGEntity
|
||||
for _, chunk := range chunks {
|
||||
e := KGEntity{}
|
||||
if v, ok := chunk["entity_kwd"].(string); ok {
|
||||
e.Name = v
|
||||
} else if list, ok := chunk["entity_kwd"].([]interface{}); ok && len(list) > 0 {
|
||||
e.Name, _ = list[0].(string)
|
||||
}
|
||||
if e.Name == "" {
|
||||
continue
|
||||
}
|
||||
e.Type, _ = chunk["entity_type_kwd"].(string)
|
||||
e.Description, _ = chunk["content_with_weight"].(string)
|
||||
if v, ok := chunk["rank_flt"].(float64); ok {
|
||||
e.PageRank = v
|
||||
}
|
||||
if v, ok := chunk["_score"].(float64); ok {
|
||||
e.Similarity = v
|
||||
} else if v, ok := chunk["score"].(float64); ok {
|
||||
e.Similarity = v
|
||||
}
|
||||
entities = append(entities, e)
|
||||
}
|
||||
return entities
|
||||
}
|
||||
|
||||
// ParseKGRelationChunks converts raw search result chunks into KGRelation slices.
|
||||
func ParseKGRelationChunks(chunks []map[string]interface{}) []KGRelation {
|
||||
var relations []KGRelation
|
||||
for _, chunk := range chunks {
|
||||
r := KGRelation{}
|
||||
r.From, _ = chunk["from_entity_kwd"].(string)
|
||||
r.To, _ = chunk["to_entity_kwd"].(string)
|
||||
r.Description, _ = chunk["content_with_weight"].(string)
|
||||
if v, ok := chunk["weight_int"].(float64); ok {
|
||||
r.Weight = int(v)
|
||||
} else if v, ok := chunk["weight_int"].(int); ok {
|
||||
r.Weight = v
|
||||
}
|
||||
if r.From == "" || r.To == "" {
|
||||
continue
|
||||
}
|
||||
relations = append(relations, r)
|
||||
}
|
||||
return relations
|
||||
}
|
||||
|
||||
// ParseKGCommunityReportChunks converts raw search result chunks into KGCommunityReport slices.
|
||||
func ParseKGCommunityReportChunks(chunks []map[string]interface{}) []KGCommunityReport {
|
||||
var reports []KGCommunityReport
|
||||
for _, chunk := range chunks {
|
||||
r := KGCommunityReport{}
|
||||
r.Title, _ = chunk["docnm_kwd"].(string)
|
||||
r.Content, _ = chunk["content_with_weight"].(string)
|
||||
r.Entities, _ = chunk["entities_kwd"].(string)
|
||||
if v, ok := chunk["weight_flt"].(float64); ok {
|
||||
r.Weight = v
|
||||
}
|
||||
if r.Title == "" && r.Content == "" {
|
||||
continue
|
||||
}
|
||||
reports = append(reports, r)
|
||||
}
|
||||
return reports
|
||||
}
|
||||
|
||||
// ParseKGTypeSamplesChunks merges the type→entities JSON objects found in the
// content_with_weight field of each chunk into a single map.
//
// Chunks whose content is missing, not a string, empty, or not valid JSON of
// the expected shape are ignored. Entity lists for the same type are
// concatenated in chunk order. The returned map is never nil.
func ParseKGTypeSamplesChunks(chunks []map[string]interface{}) map[string][]string {
	merged := map[string][]string{}
	for _, chunk := range chunks {
		raw, isString := chunk["content_with_weight"].(string)
		if raw == "" || !isString {
			continue
		}

		var perChunk map[string][]string
		if json.Unmarshal([]byte(raw), &perChunk) != nil {
			continue
		}

		for typeName, names := range perChunk {
			merged[typeName] = append(merged[typeName], names...)
		}
	}
	return merged
}
|
||||
|
||||
// NhopEntityNames returns the distinct entity names found in an
// n_hop_with_weight JSON string, in order of first appearance.
//
// The expected JSON format is:
//
//	[{"path": ["A", "B", "C"], "weights": [0.8, 0.5]}, ...]
//
// It returns nil when the input cannot be decoded into that shape or when it
// contains no names.
func NhopEntityNames(nHopJSON string) []string {
	// Weights is decoded as well so that malformed weights still reject the
	// whole input.
	var hops []struct {
		Path    []string  `json:"path"`
		Weights []float64 `json:"weights"`
	}
	if json.Unmarshal([]byte(nHopJSON), &hops) != nil {
		return nil
	}

	var ordered []string
	visited := map[string]struct{}{}
	for i := range hops {
		for _, entity := range hops[i].Path {
			if _, dup := visited[entity]; dup {
				continue
			}
			visited[entity] = struct{}{}
			ordered = append(ordered, entity)
		}
	}
	return ordered
}
|
||||
|
||||
// SearchKGEntities searches for KG entities matching a question.
|
||||
func SearchKGEntities(ctx context.Context, docEngine engine.DocEngine, kbIDs []string, question string, embModel *modelModule.EmbeddingModel, topN int) ([]KGEntity, error) {
|
||||
dense, err := buildKGDenseExpr(embModel, question, topN)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req := buildEntitySearchRequest(kbIDs, question, dense, topN)
|
||||
result, err := docEngine.Search(ctx, req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("KG entity search failed: %w", err)
|
||||
}
|
||||
return ParseKGEntityChunks(result.Chunks), nil
|
||||
}
|
||||
|
||||
// SearchKGEntitiesByTypes searches for KG entities by type keywords.
|
||||
func SearchKGEntitiesByTypes(ctx context.Context, docEngine engine.DocEngine, kbIDs []string, typeKeywords []string, topN int) ([]KGEntity, error) {
|
||||
req := buildEntityTypeSearchRequest(kbIDs, typeKeywords, topN)
|
||||
result, err := docEngine.Search(ctx, req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("KG entity type search failed: %w", err)
|
||||
}
|
||||
return ParseKGEntityChunks(result.Chunks), nil
|
||||
}
|
||||
|
||||
// SearchKGRelations searches for KG relations matching a question.
|
||||
func SearchKGRelations(ctx context.Context, docEngine engine.DocEngine, kbIDs []string, question string, embModel *modelModule.EmbeddingModel, topN int) ([]KGRelation, error) {
|
||||
dense, err := buildKGDenseExpr(embModel, question, topN)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req := buildRelationSearchRequest(kbIDs, question, dense, topN)
|
||||
result, err := docEngine.Search(ctx, req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("KG relation search failed: %w", err)
|
||||
}
|
||||
return ParseKGRelationChunks(result.Chunks), nil
|
||||
}
|
||||
|
||||
// SearchKGCommunityReports searches for community reports related to given entities.
|
||||
func SearchKGCommunityReports(ctx context.Context, docEngine engine.DocEngine, kbIDs []string, entityNames []string, topN int) ([]KGCommunityReport, error) {
|
||||
req := buildCommunitySearchRequest(kbIDs, entityNames, topN)
|
||||
result, err := docEngine.Search(ctx, req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("KG community search failed: %w", err)
|
||||
}
|
||||
return ParseKGCommunityReportChunks(result.Chunks), nil
|
||||
}
|
||||
|
||||
// SearchKGTypeSamples retrieves the type→entities mapping from ES.
|
||||
func SearchKGTypeSamples(ctx context.Context, docEngine engine.DocEngine, kbIDs []string) (map[string][]string, error) {
|
||||
req := buildTypeSamplesSearchRequest(kbIDs)
|
||||
result, err := docEngine.Search(ctx, req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("KG type samples search failed: %w", err)
|
||||
}
|
||||
return ParseKGTypeSamplesChunks(result.Chunks), nil
|
||||
}
|
||||
Reference in New Issue
Block a user