mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
feat: add KG entity/relation/community search functions (#15689)
## Summary

Knowledge Graph search functions for entity, relation, community report, and type-samples retrieval. Uses DocEngine.SelectFields (PR #15684) for KG-specific fields.

### Functions

| Function | Description |
|----------|-------------|
| `SearchKGEntities` | Hybrid search over KG entities (dense + text + fusion) |
| `SearchKGEntitiesByTypes` | Entity search filtered by `entity_type_kwd` |
| `SearchKGRelations` | Hybrid search over KG relations |
| `SearchKGCommunityReports` | Community report search by entity names |
| `SearchKGTypeSamples` | Type→entities mapping for query_rewrite |

### Internal helpers

| Helper | Description |
|--------|-------------|
| `buildHybridExpr` | Shared dense+text+fusion expression construction |
| `buildKGDenseExpr` | Wraps `Embed()` call for vector search |
| `Parse*` | Convert raw chunks to typed structs |

### Testing

35 tests (pure function + mock integration)

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
362
internal/service/kg_search.go
Normal file
362
internal/service/kg_search.go
Normal file
@@ -0,0 +1,362 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"ragflow/internal/engine"
|
||||
"ragflow/internal/engine/types"
|
||||
modelModule "ragflow/internal/entity/models"
|
||||
)
|
||||
|
||||
// KGEntity is a single knowledge-graph entity decoded from a search hit.
// The trailing comment on each field names the chunk key it is read from.
type KGEntity struct {
	Name        string  // "entity_kwd"
	Type        string  // "entity_type_kwd"
	PageRank    float64 // "rank_flt"
	Similarity  float64 // "_score", falling back to "score"
	Description string  // "content_with_weight"
}
|
||||
|
||||
// KGRelation is a directed edge between two knowledge-graph entities.
// The trailing comment on each field names the chunk key it is read from.
type KGRelation struct {
	From        string // "from_entity_kwd"
	To          string // "to_entity_kwd"
	Weight      int    // "weight_int"
	Description string // "content_with_weight"
}
|
||||
|
||||
// KGCommunityReport is one community report row of the knowledge graph.
// The trailing comment on each field names the chunk key it is read from.
type KGCommunityReport struct {
	Title    string  // "docnm_kwd"
	Content  string  // "content_with_weight"
	Weight   float64 // "weight_flt"
	Entities string  // "entities_kwd"
}
|
||||
|
||||
// buildKGDenseExpr computes the query vector and returns a MatchDenseExpr
|
||||
// for KG hybrid search. Returns nil if embModel or question is empty.
|
||||
func buildKGDenseExpr(embModel *modelModule.EmbeddingModel, question string, topN int) (*types.MatchDenseExpr, error) {
|
||||
if embModel == nil || question == "" {
|
||||
return nil, nil
|
||||
}
|
||||
embCfg := &modelModule.EmbeddingConfig{Dimension: 0}
|
||||
embeddings, err := embModel.ModelDriver.Embed(embModel.ModelName, []string{question}, embModel.APIConfig, embCfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("KG entity embed failed: %w", err)
|
||||
}
|
||||
if len(embeddings) == 0 || len(embeddings[0].Embedding) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
vector := embeddings[0].Embedding
|
||||
return &types.MatchDenseExpr{
|
||||
VectorColumnName: fmt.Sprintf("q_%d_vec", len(vector)),
|
||||
EmbeddingData: vector,
|
||||
EmbeddingDataType: "float",
|
||||
DistanceType: "cosine",
|
||||
TopN: topN,
|
||||
ExtraOptions: map[string]interface{}{"similarity": 0.3},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// buildHybridExpr returns MatchExprs for hybrid search (dense + text + fusion).
|
||||
func buildHybridExpr(dense *types.MatchDenseExpr, text *types.MatchTextExpr, topN int) []interface{} {
|
||||
return []interface{}{
|
||||
dense,
|
||||
text,
|
||||
&types.FusionExpr{
|
||||
Method: "weighted_sum",
|
||||
TopN: topN,
|
||||
FusionParams: map[string]interface{}{"weights": "0.05,0.95"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// buildEntitySearchRequest constructs a SearchRequest for KG entities.
|
||||
// dense may be nil for text-only search.
|
||||
func buildEntitySearchRequest(kbIDs []string, question string, dense *types.MatchDenseExpr, topN int) *types.SearchRequest {
|
||||
req := &types.SearchRequest{
|
||||
KbIDs: kbIDs,
|
||||
SelectFields: []string{"entity_kwd", "entity_type_kwd", "rank_flt", "content_with_weight"},
|
||||
Limit: topN,
|
||||
Filter: map[string]interface{}{"knowledge_graph_kwd": "entity"},
|
||||
}
|
||||
if question == "" {
|
||||
return req
|
||||
}
|
||||
textExpr := &types.MatchTextExpr{
|
||||
Fields: []string{"entity_kwd^10", "content_ltks^2"},
|
||||
MatchingText: question,
|
||||
TopN: topN,
|
||||
}
|
||||
if dense != nil {
|
||||
req.MatchExprs = buildHybridExpr(dense, textExpr, topN)
|
||||
req.RankFeature = map[string]float64{"pagerank_fea": 10.0}
|
||||
} else {
|
||||
req.MatchExprs = []interface{}{textExpr}
|
||||
}
|
||||
return req
|
||||
}
|
||||
|
||||
// buildEntityTypeSearchRequest constructs a SearchRequest for KG entities by type.
|
||||
func buildEntityTypeSearchRequest(kbIDs []string, typeKeywords []string, topN int) *types.SearchRequest {
|
||||
req := &types.SearchRequest{
|
||||
KbIDs: kbIDs,
|
||||
SelectFields: []string{"entity_kwd", "entity_type_kwd", "rank_flt", "content_with_weight"},
|
||||
Limit: topN,
|
||||
Filter: map[string]interface{}{
|
||||
"knowledge_graph_kwd": "entity",
|
||||
},
|
||||
}
|
||||
if len(typeKeywords) > 0 {
|
||||
filters := make([]interface{}, len(typeKeywords))
|
||||
for i, t := range typeKeywords {
|
||||
filters[i] = t
|
||||
}
|
||||
req.Filter["entity_type_kwd"] = filters
|
||||
}
|
||||
return req
|
||||
}
|
||||
|
||||
// buildRelationSearchRequest constructs a SearchRequest for KG relations.
|
||||
// dense may be nil for text-only search.
|
||||
func buildRelationSearchRequest(kbIDs []string, question string, dense *types.MatchDenseExpr, topN int) *types.SearchRequest {
|
||||
req := &types.SearchRequest{
|
||||
KbIDs: kbIDs,
|
||||
SelectFields: []string{"from_entity_kwd", "to_entity_kwd", "weight_int", "content_with_weight"},
|
||||
Limit: topN,
|
||||
Filter: map[string]interface{}{"knowledge_graph_kwd": "relation"},
|
||||
}
|
||||
if question != "" {
|
||||
textExpr := &types.MatchTextExpr{
|
||||
Fields: []string{"content_ltks"},
|
||||
MatchingText: question,
|
||||
TopN: topN,
|
||||
}
|
||||
if dense != nil {
|
||||
req.MatchExprs = buildHybridExpr(dense, textExpr, topN)
|
||||
} else {
|
||||
req.MatchExprs = []interface{}{textExpr}
|
||||
}
|
||||
}
|
||||
return req
|
||||
}
|
||||
|
||||
// buildCommunitySearchRequest constructs a SearchRequest for KG community reports.
|
||||
// Matches community reports whose entities_kwd contains any of the given entity names.
|
||||
func buildCommunitySearchRequest(kbIDs []string, entityNames []string, topN int) *types.SearchRequest {
|
||||
req := &types.SearchRequest{
|
||||
KbIDs: kbIDs,
|
||||
SelectFields: []string{"docnm_kwd", "content_with_weight", "weight_flt", "entities_kwd"},
|
||||
Limit: topN,
|
||||
Filter: map[string]interface{}{
|
||||
"knowledge_graph_kwd": "community_report",
|
||||
},
|
||||
OrderBy: (&types.OrderByExpr{}).Desc("weight_flt"),
|
||||
}
|
||||
if len(entityNames) > 0 {
|
||||
filters := make([]interface{}, len(entityNames))
|
||||
for i, name := range entityNames {
|
||||
filters[i] = name
|
||||
}
|
||||
req.Filter["entities_kwd"] = filters
|
||||
}
|
||||
return req
|
||||
}
|
||||
|
||||
// buildTypeSamplesSearchRequest constructs a SearchRequest for ty2ents data.
|
||||
func buildTypeSamplesSearchRequest(kbIDs []string) *types.SearchRequest {
|
||||
return &types.SearchRequest{
|
||||
KbIDs: kbIDs,
|
||||
SelectFields: []string{"content_with_weight"},
|
||||
Limit: 10000,
|
||||
Filter: map[string]interface{}{"knowledge_graph_kwd": "ty2ents"},
|
||||
}
|
||||
}
|
||||
|
||||
// ParseKGEntityChunks converts raw search result chunks into KGEntity slices.
|
||||
func ParseKGEntityChunks(chunks []map[string]interface{}) []KGEntity {
|
||||
var entities []KGEntity
|
||||
for _, chunk := range chunks {
|
||||
e := KGEntity{}
|
||||
if v, ok := chunk["entity_kwd"].(string); ok {
|
||||
e.Name = v
|
||||
} else if list, ok := chunk["entity_kwd"].([]interface{}); ok && len(list) > 0 {
|
||||
e.Name, _ = list[0].(string)
|
||||
}
|
||||
if e.Name == "" {
|
||||
continue
|
||||
}
|
||||
e.Type, _ = chunk["entity_type_kwd"].(string)
|
||||
e.Description, _ = chunk["content_with_weight"].(string)
|
||||
if v, ok := chunk["rank_flt"].(float64); ok {
|
||||
e.PageRank = v
|
||||
}
|
||||
if v, ok := chunk["_score"].(float64); ok {
|
||||
e.Similarity = v
|
||||
} else if v, ok := chunk["score"].(float64); ok {
|
||||
e.Similarity = v
|
||||
}
|
||||
entities = append(entities, e)
|
||||
}
|
||||
return entities
|
||||
}
|
||||
|
||||
// ParseKGRelationChunks converts raw search result chunks into KGRelation slices.
|
||||
func ParseKGRelationChunks(chunks []map[string]interface{}) []KGRelation {
|
||||
var relations []KGRelation
|
||||
for _, chunk := range chunks {
|
||||
r := KGRelation{}
|
||||
r.From, _ = chunk["from_entity_kwd"].(string)
|
||||
r.To, _ = chunk["to_entity_kwd"].(string)
|
||||
r.Description, _ = chunk["content_with_weight"].(string)
|
||||
if v, ok := chunk["weight_int"].(float64); ok {
|
||||
r.Weight = int(v)
|
||||
} else if v, ok := chunk["weight_int"].(int); ok {
|
||||
r.Weight = v
|
||||
}
|
||||
if r.From == "" || r.To == "" {
|
||||
continue
|
||||
}
|
||||
relations = append(relations, r)
|
||||
}
|
||||
return relations
|
||||
}
|
||||
|
||||
// ParseKGCommunityReportChunks converts raw search result chunks into KGCommunityReport slices.
|
||||
func ParseKGCommunityReportChunks(chunks []map[string]interface{}) []KGCommunityReport {
|
||||
var reports []KGCommunityReport
|
||||
for _, chunk := range chunks {
|
||||
r := KGCommunityReport{}
|
||||
r.Title, _ = chunk["docnm_kwd"].(string)
|
||||
r.Content, _ = chunk["content_with_weight"].(string)
|
||||
r.Entities, _ = chunk["entities_kwd"].(string)
|
||||
if v, ok := chunk["weight_flt"].(float64); ok {
|
||||
r.Weight = v
|
||||
}
|
||||
if r.Title == "" && r.Content == "" {
|
||||
continue
|
||||
}
|
||||
reports = append(reports, r)
|
||||
}
|
||||
return reports
|
||||
}
|
||||
|
||||
// ParseKGTypeSamplesChunks merges the type→entities JSON objects stored in
// each chunk's content_with_weight into one map.
//
// Chunks whose content is missing, not a string, empty, or not decodable as a
// JSON object of string lists are skipped. Entities for a type that occurs in
// several chunks are concatenated in chunk order, without de-duplication.
// The returned map is never nil.
func ParseKGTypeSamplesChunks(chunks []map[string]interface{}) map[string][]string {
	merged := make(map[string][]string)
	for _, c := range chunks {
		raw, _ := c["content_with_weight"].(string)
		if raw == "" {
			continue
		}

		var decoded map[string][]string
		if json.Unmarshal([]byte(raw), &decoded) != nil {
			// Best effort: one malformed chunk must not discard the others.
			continue
		}
		for typ, names := range decoded {
			merged[typ] = append(merged[typ], names...)
		}
	}
	return merged
}
|
||||
|
||||
// NhopEntityNames extracts the unique entity names from an n_hop_with_weight
// JSON string of the form
//
//	[{"path": ["A", "B", "C"], "weights": [0.8, 0.5]}, ...]
//
// Names are returned in order of first appearance with duplicates removed.
// Only the "path" arrays are decoded; "weights" (and any other key) is
// ignored, so an oddly typed weights array cannot discard otherwise valid
// paths. The result is nil when the input is not valid JSON of that shape or
// contains no names.
func NhopEntityNames(nHopJSON string) []string {
	// Only the path is needed here; undeclared keys are skipped by the decoder.
	type nhopItem struct {
		Path []string `json:"path"`
	}

	var data []nhopItem
	if err := json.Unmarshal([]byte(nHopJSON), &data); err != nil {
		return nil
	}

	seen := make(map[string]struct{})
	var names []string
	for _, item := range data {
		for _, name := range item.Path {
			if _, dup := seen[name]; dup {
				continue
			}
			seen[name] = struct{}{}
			names = append(names, name)
		}
	}
	return names
}
|
||||
|
||||
// SearchKGEntities searches for KG entities matching a question.
|
||||
func SearchKGEntities(ctx context.Context, docEngine engine.DocEngine, kbIDs []string, question string, embModel *modelModule.EmbeddingModel, topN int) ([]KGEntity, error) {
|
||||
dense, err := buildKGDenseExpr(embModel, question, topN)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req := buildEntitySearchRequest(kbIDs, question, dense, topN)
|
||||
result, err := docEngine.Search(ctx, req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("KG entity search failed: %w", err)
|
||||
}
|
||||
return ParseKGEntityChunks(result.Chunks), nil
|
||||
}
|
||||
|
||||
// SearchKGEntitiesByTypes searches for KG entities by type keywords.
|
||||
func SearchKGEntitiesByTypes(ctx context.Context, docEngine engine.DocEngine, kbIDs []string, typeKeywords []string, topN int) ([]KGEntity, error) {
|
||||
req := buildEntityTypeSearchRequest(kbIDs, typeKeywords, topN)
|
||||
result, err := docEngine.Search(ctx, req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("KG entity type search failed: %w", err)
|
||||
}
|
||||
return ParseKGEntityChunks(result.Chunks), nil
|
||||
}
|
||||
|
||||
// SearchKGRelations searches for KG relations matching a question.
|
||||
func SearchKGRelations(ctx context.Context, docEngine engine.DocEngine, kbIDs []string, question string, embModel *modelModule.EmbeddingModel, topN int) ([]KGRelation, error) {
|
||||
dense, err := buildKGDenseExpr(embModel, question, topN)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req := buildRelationSearchRequest(kbIDs, question, dense, topN)
|
||||
result, err := docEngine.Search(ctx, req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("KG relation search failed: %w", err)
|
||||
}
|
||||
return ParseKGRelationChunks(result.Chunks), nil
|
||||
}
|
||||
|
||||
// SearchKGCommunityReports searches for community reports related to given entities.
|
||||
func SearchKGCommunityReports(ctx context.Context, docEngine engine.DocEngine, kbIDs []string, entityNames []string, topN int) ([]KGCommunityReport, error) {
|
||||
req := buildCommunitySearchRequest(kbIDs, entityNames, topN)
|
||||
result, err := docEngine.Search(ctx, req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("KG community search failed: %w", err)
|
||||
}
|
||||
return ParseKGCommunityReportChunks(result.Chunks), nil
|
||||
}
|
||||
|
||||
// SearchKGTypeSamples retrieves the type→entities mapping from ES.
|
||||
func SearchKGTypeSamples(ctx context.Context, docEngine engine.DocEngine, kbIDs []string) (map[string][]string, error) {
|
||||
req := buildTypeSamplesSearchRequest(kbIDs)
|
||||
result, err := docEngine.Search(ctx, req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("KG type samples search failed: %w", err)
|
||||
}
|
||||
return ParseKGTypeSamplesChunks(result.Chunks), nil
|
||||
}
|
||||
444
internal/service/kg_search_test.go
Normal file
444
internal/service/kg_search_test.go
Normal file
@@ -0,0 +1,444 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"ragflow/internal/engine"
|
||||
"ragflow/internal/engine/types"
|
||||
modelModule "ragflow/internal/entity/models"
|
||||
)
|
||||
|
||||
// --- buildEntitySearchRequest ---
|
||||
|
||||
func TestBuildEntitySearchRequest_TextOnly(t *testing.T) {
|
||||
req := buildEntitySearchRequest([]string{"kb1"}, "Elon Musk", nil, 10)
|
||||
if req == nil {
|
||||
t.Fatal("expected non-nil request")
|
||||
}
|
||||
if req.Filter["knowledge_graph_kwd"] != "entity" {
|
||||
t.Errorf("expected 'entity' filter, got %v", req.Filter["knowledge_graph_kwd"])
|
||||
}
|
||||
if req.Limit != 10 {
|
||||
t.Errorf("expected limit 10, got %d", req.Limit)
|
||||
}
|
||||
if len(req.MatchExprs) != 1 {
|
||||
t.Fatalf("expected 1 MatchExpr (text only), got %d", len(req.MatchExprs))
|
||||
}
|
||||
if _, ok := req.MatchExprs[0].(*types.MatchTextExpr); !ok {
|
||||
t.Error("expected MatchTextExpr")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildEntitySearchRequest_EmptyQuestion(t *testing.T) {
|
||||
req := buildEntitySearchRequest([]string{"kb1"}, "", nil, 10)
|
||||
if len(req.MatchExprs) != 0 {
|
||||
t.Error("expected no MatchExprs for empty question")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildEntitySearchRequest_Hybrid(t *testing.T) {
|
||||
dense := &types.MatchDenseExpr{
|
||||
VectorColumnName: "q_768_vec",
|
||||
EmbeddingData: []float64{0.1, 0.2, 0.3},
|
||||
EmbeddingDataType: "float",
|
||||
DistanceType: "cosine",
|
||||
TopN: 10,
|
||||
}
|
||||
req := buildEntitySearchRequest([]string{"kb1"}, "test question", dense, 10)
|
||||
if len(req.MatchExprs) != 3 {
|
||||
t.Fatalf("expected 3 MatchExprs (dense + text + fusion), got %d", len(req.MatchExprs))
|
||||
}
|
||||
if _, ok := req.MatchExprs[0].(*types.MatchDenseExpr); !ok {
|
||||
t.Error("expected MatchDenseExpr at [0]")
|
||||
}
|
||||
if _, ok := req.MatchExprs[1].(*types.MatchTextExpr); !ok {
|
||||
t.Error("expected MatchTextExpr at [1]")
|
||||
}
|
||||
fusion, ok := req.MatchExprs[2].(*types.FusionExpr)
|
||||
if !ok {
|
||||
t.Fatal("expected FusionExpr at [2]")
|
||||
}
|
||||
if fusion.Method != "weighted_sum" {
|
||||
t.Errorf("expected 'weighted_sum', got %q", fusion.Method)
|
||||
}
|
||||
}
|
||||
|
||||
// --- buildEntityTypeSearchRequest ---
|
||||
|
||||
func TestBuildEntityTypeSearchRequest_Basic(t *testing.T) {
|
||||
req := buildEntityTypeSearchRequest([]string{"kb1"}, []string{"PERSON", "ORGANIZATION"}, 10)
|
||||
if req.Filter["knowledge_graph_kwd"] != "entity" {
|
||||
t.Errorf("expected 'entity' filter, got %v", req.Filter["knowledge_graph_kwd"])
|
||||
}
|
||||
filter, ok := req.Filter["entity_type_kwd"].([]interface{})
|
||||
if !ok || len(filter) != 2 {
|
||||
t.Errorf("expected 2 entity_type filters, got %v", filter)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildEntityTypeSearchRequest_EmptyTypes(t *testing.T) {
|
||||
req := buildEntityTypeSearchRequest([]string{"kb1"}, nil, 10)
|
||||
if _, ok := req.Filter["entity_type_kwd"]; ok {
|
||||
t.Error("expected no entity_type_kwd filter for empty types")
|
||||
}
|
||||
}
|
||||
|
||||
// --- buildRelationSearchRequest ---
|
||||
|
||||
func TestBuildRelationSearchRequest_Basic(t *testing.T) {
|
||||
req := buildRelationSearchRequest([]string{"kb1"}, "test", nil, 10)
|
||||
if req.Filter["knowledge_graph_kwd"] != "relation" {
|
||||
t.Errorf("expected 'relation' filter, got %v", req.Filter["knowledge_graph_kwd"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildRelationSearchRequest_EmptyQuestion(t *testing.T) {
|
||||
req := buildRelationSearchRequest([]string{"kb1"}, "", nil, 10)
|
||||
if len(req.MatchExprs) != 0 {
|
||||
t.Error("expected no MatchExprs for empty question")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildRelationSearchRequest_Hybrid(t *testing.T) {
|
||||
dense := &types.MatchDenseExpr{
|
||||
VectorColumnName: "q_768_vec",
|
||||
EmbeddingData: []float64{0.1, 0.2},
|
||||
TopN: 5,
|
||||
}
|
||||
req := buildRelationSearchRequest([]string{"kb1"}, "test", dense, 5)
|
||||
if len(req.MatchExprs) != 3 {
|
||||
t.Fatalf("expected 3 MatchExprs (dense + text + fusion), got %d", len(req.MatchExprs))
|
||||
}
|
||||
}
|
||||
|
||||
// --- buildCommunitySearchRequest ---
|
||||
|
||||
func TestBuildCommunitySearchRequest_Basic(t *testing.T) {
|
||||
req := buildCommunitySearchRequest([]string{"kb1"}, []string{"Elon Musk"}, 5)
|
||||
if req.Filter["knowledge_graph_kwd"] != "community_report" {
|
||||
t.Errorf("expected 'community_report' filter, got %v", req.Filter["knowledge_graph_kwd"])
|
||||
}
|
||||
if req.OrderBy == nil {
|
||||
t.Error("expected OrderBy")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCommunitySearchRequest_EmptyNames(t *testing.T) {
|
||||
req := buildCommunitySearchRequest([]string{"kb1"}, nil, 5)
|
||||
if _, ok := req.Filter["entities_kwd"]; ok {
|
||||
t.Error("expected no entities_kwd filter for empty names")
|
||||
}
|
||||
}
|
||||
|
||||
// --- buildTypeSamplesSearchRequest ---
|
||||
|
||||
func TestBuildTypeSamplesSearchRequest(t *testing.T) {
|
||||
req := buildTypeSamplesSearchRequest([]string{"kb1"})
|
||||
if req.Filter["knowledge_graph_kwd"] != "ty2ents" {
|
||||
t.Errorf("expected 'ty2ents' filter, got %v", req.Filter["knowledge_graph_kwd"])
|
||||
}
|
||||
if req.Limit != 10000 {
|
||||
t.Errorf("expected 10000, got %d", req.Limit)
|
||||
}
|
||||
}
|
||||
|
||||
// --- ParseKGEntityChunks ---
|
||||
|
||||
func TestParseKGEntityChunks_Basic(t *testing.T) {
|
||||
chunks := []map[string]interface{}{
|
||||
{"entity_kwd": "Elon Musk", "entity_type_kwd": "PERSON", "rank_flt": 0.9, "_score": 0.85, "content_with_weight": "Founder of SpaceX"},
|
||||
}
|
||||
entities := ParseKGEntityChunks(chunks)
|
||||
if len(entities) != 1 {
|
||||
t.Fatalf("expected 1, got %d", len(entities))
|
||||
}
|
||||
if entities[0].Name != "Elon Musk" || entities[0].Type != "PERSON" || entities[0].PageRank != 0.9 || entities[0].Similarity != 0.85 {
|
||||
t.Errorf("unexpected entity fields: %+v", entities[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseKGEntityChunks_List(t *testing.T) {
|
||||
chunks := []map[string]interface{}{
|
||||
{"entity_kwd": []interface{}{"Elon Musk", "elon_musk"}},
|
||||
}
|
||||
entities := ParseKGEntityChunks(chunks)
|
||||
if len(entities) != 1 || entities[0].Name != "Elon Musk" {
|
||||
t.Errorf("expected first list element, got %q", entities[0].Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseKGEntityChunks_EmptyName(t *testing.T) {
|
||||
chunks := []map[string]interface{}{{"entity_type_kwd": "PERSON"}}
|
||||
if len(ParseKGEntityChunks(chunks)) != 0 {
|
||||
t.Error("expected 0 for missing name")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseKGEntityChunks_ScoreFallback(t *testing.T) {
|
||||
chunks := []map[string]interface{}{{"entity_kwd": "Test", "score": 0.75}}
|
||||
if ParseKGEntityChunks(chunks)[0].Similarity != 0.75 {
|
||||
t.Error("expected 0.75 from score field")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseKGEntityChunks_NilInput(t *testing.T) {
|
||||
if len(ParseKGEntityChunks(nil)) != 0 {
|
||||
t.Error("expected 0 for nil input")
|
||||
}
|
||||
}
|
||||
|
||||
// --- ParseKGRelationChunks ---
|
||||
|
||||
func TestParseKGRelationChunks_Basic(t *testing.T) {
|
||||
chunks := []map[string]interface{}{
|
||||
{"from_entity_kwd": "Elon Musk", "to_entity_kwd": "SpaceX", "weight_int": float64(5), "content_with_weight": "Founder"},
|
||||
}
|
||||
relations := ParseKGRelationChunks(chunks)
|
||||
if len(relations) != 1 || relations[0].From != "Elon Musk" || relations[0].Weight != 5 {
|
||||
t.Errorf("unexpected: %+v", relations[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseKGRelationChunks_IntWeight(t *testing.T) {
|
||||
chunks := []map[string]interface{}{{"from_entity_kwd": "A", "to_entity_kwd": "B", "weight_int": 3}}
|
||||
if ParseKGRelationChunks(chunks)[0].Weight != 3 {
|
||||
t.Error("expected weight 3")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseKGRelationChunks_EmptyFrom(t *testing.T) {
|
||||
if len(ParseKGRelationChunks([]map[string]interface{}{{"to_entity_kwd": "B"}})) != 0 {
|
||||
t.Error("expected 0 for missing from")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseKGRelationChunks_NilInput(t *testing.T) {
|
||||
if len(ParseKGRelationChunks(nil)) != 0 {
|
||||
t.Error("expected 0 for nil")
|
||||
}
|
||||
}
|
||||
|
||||
// --- ParseKGCommunityReportChunks ---
|
||||
|
||||
func TestParseKGCommunityReportChunks_Basic(t *testing.T) {
|
||||
chunks := []map[string]interface{}{
|
||||
{"docnm_kwd": "Report 1", "content_with_weight": "content", "weight_flt": 0.95, "entities_kwd": "A, B"},
|
||||
}
|
||||
reports := ParseKGCommunityReportChunks(chunks)
|
||||
if len(reports) != 1 || reports[0].Title != "Report 1" || reports[0].Weight != 0.95 {
|
||||
t.Errorf("unexpected: %+v", reports[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseKGCommunityReportChunks_EmptyTitle(t *testing.T) {
|
||||
if len(ParseKGCommunityReportChunks([]map[string]interface{}{{"weight_flt": 0.5}})) != 0 {
|
||||
t.Error("expected 0 for empty title and content")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseKGCommunityReportChunks_NilInput(t *testing.T) {
|
||||
if len(ParseKGCommunityReportChunks(nil)) != 0 {
|
||||
t.Error("expected 0 for nil")
|
||||
}
|
||||
}
|
||||
|
||||
// --- ParseKGTypeSamplesChunks ---
|
||||
|
||||
func TestParseKGTypeSamplesChunks_ValidJSON(t *testing.T) {
|
||||
chunks := []map[string]interface{}{
|
||||
{"content_with_weight": `{"PERSON": ["Elon Musk", "Einstein"], "ORGANIZATION": ["SpaceX"]}`},
|
||||
}
|
||||
result := ParseKGTypeSamplesChunks(chunks)
|
||||
if len(result) != 2 {
|
||||
t.Fatalf("expected 2 types, got %d: %v", len(result), result)
|
||||
}
|
||||
if len(result["PERSON"]) != 2 || result["PERSON"][0] != "Elon Musk" {
|
||||
t.Errorf("expected PERSON entities, got %v", result["PERSON"])
|
||||
}
|
||||
if len(result["ORGANIZATION"]) != 1 || result["ORGANIZATION"][0] != "SpaceX" {
|
||||
t.Errorf("expected ORGANIZATION entities, got %v", result["ORGANIZATION"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseKGTypeSamplesChunks_InvalidJSON(t *testing.T) {
|
||||
chunks := []map[string]interface{}{
|
||||
{"content_with_weight": "not json"},
|
||||
}
|
||||
result := ParseKGTypeSamplesChunks(chunks)
|
||||
if len(result) != 0 {
|
||||
t.Error("expected empty for invalid JSON")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseKGTypeSamplesChunks_Empty(t *testing.T) {
|
||||
result := ParseKGTypeSamplesChunks(nil)
|
||||
if len(result) != 0 {
|
||||
t.Error("expected empty for nil")
|
||||
}
|
||||
}
|
||||
|
||||
// --- NhopEntityNames ---
|
||||
|
||||
func TestNhopEntityNames_ValidJSON(t *testing.T) {
|
||||
input := `[{"path": ["A", "B", "C"], "weights": [0.8, 0.5]}, {"path": ["C", "D"], "weights": [0.3]}]`
|
||||
names := NhopEntityNames(input)
|
||||
if len(names) != 4 {
|
||||
t.Fatalf("expected 4 unique names, got %d: %v", len(names), names)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNhopEntityNames_Dedup(t *testing.T) {
|
||||
input := `[{"path": ["A", "B"], "weights": [0.9]}, {"path": ["A", "C"], "weights": [0.8]}]`
|
||||
names := NhopEntityNames(input)
|
||||
if len(names) != 3 {
|
||||
t.Errorf("expected 3 unique names (A,B,C), got %d: %v", len(names), names)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNhopEntityNames_InvalidJSON(t *testing.T) {
|
||||
result := NhopEntityNames("not json")
|
||||
if result != nil {
|
||||
t.Error("expected nil for invalid JSON")
|
||||
}
|
||||
}
|
||||
|
||||
// --- Mock engine ---
|
||||
|
||||
type mockKGEngine struct {
|
||||
engine.DocEngine
|
||||
searchFunc func(ctx context.Context, req *types.SearchRequest) (*types.SearchResult, error)
|
||||
}
|
||||
|
||||
func (m *mockKGEngine) Search(ctx context.Context, req *types.SearchRequest) (*types.SearchResult, error) {
|
||||
if m.searchFunc != nil {
|
||||
return m.searchFunc(ctx, req)
|
||||
}
|
||||
return &types.SearchResult{}, nil
|
||||
}
|
||||
|
||||
// --- Mock model driver for Embed tests ---
|
||||
|
||||
type fakeEmbedDriver struct {
|
||||
modelModule.ModelDriver
|
||||
name string
|
||||
vector []float64
|
||||
}
|
||||
|
||||
func (f *fakeEmbedDriver) Name() string { return f.name }
|
||||
|
||||
func (f *fakeEmbedDriver) Embed(modelName *string, texts []string, apiConfig *modelModule.APIConfig, config *modelModule.EmbeddingConfig) ([]modelModule.EmbeddingData, error) {
|
||||
return []modelModule.EmbeddingData{{Embedding: f.vector}}, nil
|
||||
}
|
||||
|
||||
func TestBuildKGDenseExpr_WithModel(t *testing.T) {
|
||||
embModel := &modelModule.EmbeddingModel{
|
||||
ModelName: strPtr("test-model"),
|
||||
ModelDriver: &fakeEmbedDriver{
|
||||
name: "test",
|
||||
vector: []float64{0.1, 0.2, 0.3},
|
||||
},
|
||||
APIConfig: &modelModule.APIConfig{},
|
||||
}
|
||||
dense, err := buildKGDenseExpr(embModel, "test question", 10)
|
||||
if err != nil {
|
||||
t.Fatalf("buildKGDenseExpr failed: %v", err)
|
||||
}
|
||||
if dense == nil {
|
||||
t.Fatal("expected non-nil MatchDenseExpr")
|
||||
}
|
||||
if dense.VectorColumnName != "q_3_vec" {
|
||||
t.Errorf("expected 'q_3_vec', got %q", dense.VectorColumnName)
|
||||
}
|
||||
if dense.TopN != 10 {
|
||||
t.Errorf("expected TopN 10, got %d", dense.TopN)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildKGDenseExpr_NilModel(t *testing.T) {
|
||||
dense, err := buildKGDenseExpr(nil, "test", 10)
|
||||
if dense != nil || err != nil {
|
||||
t.Errorf("expected nil,nil for nil model, got dense=%v err=%v", dense, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildKGDenseExpr_EmptyQuestion(t *testing.T) {
|
||||
dense, err := buildKGDenseExpr(&modelModule.EmbeddingModel{}, "", 10)
|
||||
if dense != nil || err != nil {
|
||||
t.Errorf("expected nil,nil for empty question, got dense=%v err=%v", dense, err)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Search integration with mock ---
|
||||
|
||||
func TestSearchKGEntities_WithMock(t *testing.T) {
|
||||
mock := &mockKGEngine{
|
||||
searchFunc: func(ctx context.Context, req *types.SearchRequest) (*types.SearchResult, error) {
|
||||
if req.Filter["knowledge_graph_kwd"] != "entity" {
|
||||
t.Error("expected entity filter")
|
||||
}
|
||||
return &types.SearchResult{
|
||||
Chunks: []map[string]interface{}{
|
||||
{"entity_kwd": "Elon Musk", "entity_type_kwd": "PERSON"},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
entities, err := SearchKGEntities(context.Background(), mock, []string{"kb1"}, "Elon", nil, 10)
|
||||
if err != nil {
|
||||
t.Fatalf("SearchKGEntities failed: %v", err)
|
||||
}
|
||||
if len(entities) != 1 || entities[0].Name != "Elon Musk" {
|
||||
t.Errorf("expected [Elon Musk], got %v", entities)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchKGEntitiesByTypes_WithMock(t *testing.T) {
|
||||
mock := &mockKGEngine{
|
||||
searchFunc: func(ctx context.Context, req *types.SearchRequest) (*types.SearchResult, error) {
|
||||
return &types.SearchResult{
|
||||
Chunks: []map[string]interface{}{
|
||||
{"entity_kwd": "SpaceX", "entity_type_kwd": "ORGANIZATION"},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
entities, err := SearchKGEntitiesByTypes(context.Background(), mock, []string{"kb1"}, []string{"ORGANIZATION"}, 10)
|
||||
if err != nil {
|
||||
t.Fatalf("SearchKGEntitiesByTypes failed: %v", err)
|
||||
}
|
||||
if len(entities) != 1 || entities[0].Type != "ORGANIZATION" {
|
||||
t.Errorf("expected ORGANIZATION, got %v", entities)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchKGTypeSamples_WithMock(t *testing.T) {
|
||||
mock := &mockKGEngine{}
|
||||
samples, err := SearchKGTypeSamples(context.Background(), mock, []string{"kb1"})
|
||||
if err != nil {
|
||||
t.Fatalf("SearchKGTypeSamples failed: %v", err)
|
||||
}
|
||||
if samples == nil {
|
||||
samples = map[string][]string{}
|
||||
}
|
||||
if len(samples) != 0 {
|
||||
t.Errorf("expected empty, got %d", len(samples))
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user