From e629c0203b3f63050c19002ee35bfb5dfc65c2cd Mon Sep 17 00:00:00 2001 From: Jack Date: Fri, 5 Jun 2026 13:23:04 +0800 Subject: [PATCH] feat: add KG entity/relation/community search functions (#15689) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary Knowledge Graph search functions for entity, relation, community report, and type-samples retrieval. Uses DocEngine.SelectFields (PR #15684) for KG-specific fields. ### Functions | Function | Description | |----------|-------------| | `SearchKGEntities` | Hybrid search over KG entities (dense + text + fusion) | | `SearchKGEntitiesByTypes` | Entity search filtered by `entity_type_kwd` | | `SearchKGRelations` | Hybrid search over KG relations | | `SearchKGCommunityReports` | Community report search by entity names | | `SearchKGTypeSamples` | Type→entities mapping for query_rewrite | ### Internal helpers | Helper | Description | |--------|-------------| | `buildHybridExpr` | Shared dense+text+fusion expression construction | | `buildKGDenseExpr` | Wraps `Embed()` call for vector search | | `Parse*` | Convert raw chunks to typed structs | ### Testing 35 tests (pure function + mock integration) Co-Authored-By: Claude Opus 4.8 --------- Co-authored-by: Claude Opus 4.8 --- internal/service/kg_search.go | 362 +++++++++++++++++++++++ internal/service/kg_search_test.go | 444 +++++++++++++++++++++++++++++ 2 files changed, 806 insertions(+) create mode 100644 internal/service/kg_search.go create mode 100644 internal/service/kg_search_test.go diff --git a/internal/service/kg_search.go b/internal/service/kg_search.go new file mode 100644 index 0000000000..a5373eb650 --- /dev/null +++ b/internal/service/kg_search.go @@ -0,0 +1,362 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//

package service

import (
	"context"
	"encoding/json"
	"fmt"

	"ragflow/internal/engine"
	"ragflow/internal/engine/types"
	modelModule "ragflow/internal/entity/models"
)

// KGEntity represents a knowledge graph entity.
//
// Every field is filled from one key of a raw search-result chunk; the key is
// named next to the field. Values are produced by ParseKGEntityChunks.
type KGEntity struct {
	Name        string  // entity_kwd; chunks without a name are skipped by the parser
	Type        string  // entity_type_kwd
	PageRank    float64 // rank_flt
	Similarity  float64 // _score, or score when _score is not usable
	Description string  // content_with_weight
}

// KGRelation represents a relation between two entities.
//
// Values are produced by ParseKGRelationChunks, which skips chunks that lack
// either endpoint.
type KGRelation struct {
	From        string // from_entity_kwd
	To          string // to_entity_kwd
	Weight      int    // weight_int
	Description string // content_with_weight
}

// KGCommunityReport represents a community report.
//
// Values are produced by ParseKGCommunityReportChunks, which skips chunks
// that have neither a title nor content.
type KGCommunityReport struct {
	Title    string  // docnm_kwd
	Content  string  // content_with_weight
	Weight   float64 // weight_flt
	Entities string  // entities_kwd
}

// buildKGDenseExpr computes the query vector and returns a MatchDenseExpr
// for KG hybrid search. Returns nil if embModel or question is empty.
+func buildKGDenseExpr(embModel *modelModule.EmbeddingModel, question string, topN int) (*types.MatchDenseExpr, error) { + if embModel == nil || question == "" { + return nil, nil + } + embCfg := &modelModule.EmbeddingConfig{Dimension: 0} + embeddings, err := embModel.ModelDriver.Embed(embModel.ModelName, []string{question}, embModel.APIConfig, embCfg) + if err != nil { + return nil, fmt.Errorf("KG entity embed failed: %w", err) + } + if len(embeddings) == 0 || len(embeddings[0].Embedding) == 0 { + return nil, nil + } + vector := embeddings[0].Embedding + return &types.MatchDenseExpr{ + VectorColumnName: fmt.Sprintf("q_%d_vec", len(vector)), + EmbeddingData: vector, + EmbeddingDataType: "float", + DistanceType: "cosine", + TopN: topN, + ExtraOptions: map[string]interface{}{"similarity": 0.3}, + }, nil +} + +// buildHybridExpr returns MatchExprs for hybrid search (dense + text + fusion). +func buildHybridExpr(dense *types.MatchDenseExpr, text *types.MatchTextExpr, topN int) []interface{} { + return []interface{}{ + dense, + text, + &types.FusionExpr{ + Method: "weighted_sum", + TopN: topN, + FusionParams: map[string]interface{}{"weights": "0.05,0.95"}, + }, + } +} + +// buildEntitySearchRequest constructs a SearchRequest for KG entities. +// dense may be nil for text-only search. 
+func buildEntitySearchRequest(kbIDs []string, question string, dense *types.MatchDenseExpr, topN int) *types.SearchRequest { + req := &types.SearchRequest{ + KbIDs: kbIDs, + SelectFields: []string{"entity_kwd", "entity_type_kwd", "rank_flt", "content_with_weight"}, + Limit: topN, + Filter: map[string]interface{}{"knowledge_graph_kwd": "entity"}, + } + if question == "" { + return req + } + textExpr := &types.MatchTextExpr{ + Fields: []string{"entity_kwd^10", "content_ltks^2"}, + MatchingText: question, + TopN: topN, + } + if dense != nil { + req.MatchExprs = buildHybridExpr(dense, textExpr, topN) + req.RankFeature = map[string]float64{"pagerank_fea": 10.0} + } else { + req.MatchExprs = []interface{}{textExpr} + } + return req +} + +// buildEntityTypeSearchRequest constructs a SearchRequest for KG entities by type. +func buildEntityTypeSearchRequest(kbIDs []string, typeKeywords []string, topN int) *types.SearchRequest { + req := &types.SearchRequest{ + KbIDs: kbIDs, + SelectFields: []string{"entity_kwd", "entity_type_kwd", "rank_flt", "content_with_weight"}, + Limit: topN, + Filter: map[string]interface{}{ + "knowledge_graph_kwd": "entity", + }, + } + if len(typeKeywords) > 0 { + filters := make([]interface{}, len(typeKeywords)) + for i, t := range typeKeywords { + filters[i] = t + } + req.Filter["entity_type_kwd"] = filters + } + return req +} + +// buildRelationSearchRequest constructs a SearchRequest for KG relations. +// dense may be nil for text-only search. 
+func buildRelationSearchRequest(kbIDs []string, question string, dense *types.MatchDenseExpr, topN int) *types.SearchRequest { + req := &types.SearchRequest{ + KbIDs: kbIDs, + SelectFields: []string{"from_entity_kwd", "to_entity_kwd", "weight_int", "content_with_weight"}, + Limit: topN, + Filter: map[string]interface{}{"knowledge_graph_kwd": "relation"}, + } + if question != "" { + textExpr := &types.MatchTextExpr{ + Fields: []string{"content_ltks"}, + MatchingText: question, + TopN: topN, + } + if dense != nil { + req.MatchExprs = buildHybridExpr(dense, textExpr, topN) + } else { + req.MatchExprs = []interface{}{textExpr} + } + } + return req +} + +// buildCommunitySearchRequest constructs a SearchRequest for KG community reports. +// Matches community reports whose entities_kwd contains any of the given entity names. +func buildCommunitySearchRequest(kbIDs []string, entityNames []string, topN int) *types.SearchRequest { + req := &types.SearchRequest{ + KbIDs: kbIDs, + SelectFields: []string{"docnm_kwd", "content_with_weight", "weight_flt", "entities_kwd"}, + Limit: topN, + Filter: map[string]interface{}{ + "knowledge_graph_kwd": "community_report", + }, + OrderBy: (&types.OrderByExpr{}).Desc("weight_flt"), + } + if len(entityNames) > 0 { + filters := make([]interface{}, len(entityNames)) + for i, name := range entityNames { + filters[i] = name + } + req.Filter["entities_kwd"] = filters + } + return req +} + +// buildTypeSamplesSearchRequest constructs a SearchRequest for ty2ents data. +func buildTypeSamplesSearchRequest(kbIDs []string) *types.SearchRequest { + return &types.SearchRequest{ + KbIDs: kbIDs, + SelectFields: []string{"content_with_weight"}, + Limit: 10000, + Filter: map[string]interface{}{"knowledge_graph_kwd": "ty2ents"}, + } +} + +// ParseKGEntityChunks converts raw search result chunks into KGEntity slices. 
+func ParseKGEntityChunks(chunks []map[string]interface{}) []KGEntity { + var entities []KGEntity + for _, chunk := range chunks { + e := KGEntity{} + if v, ok := chunk["entity_kwd"].(string); ok { + e.Name = v + } else if list, ok := chunk["entity_kwd"].([]interface{}); ok && len(list) > 0 { + e.Name, _ = list[0].(string) + } + if e.Name == "" { + continue + } + e.Type, _ = chunk["entity_type_kwd"].(string) + e.Description, _ = chunk["content_with_weight"].(string) + if v, ok := chunk["rank_flt"].(float64); ok { + e.PageRank = v + } + if v, ok := chunk["_score"].(float64); ok { + e.Similarity = v + } else if v, ok := chunk["score"].(float64); ok { + e.Similarity = v + } + entities = append(entities, e) + } + return entities +} + +// ParseKGRelationChunks converts raw search result chunks into KGRelation slices. +func ParseKGRelationChunks(chunks []map[string]interface{}) []KGRelation { + var relations []KGRelation + for _, chunk := range chunks { + r := KGRelation{} + r.From, _ = chunk["from_entity_kwd"].(string) + r.To, _ = chunk["to_entity_kwd"].(string) + r.Description, _ = chunk["content_with_weight"].(string) + if v, ok := chunk["weight_int"].(float64); ok { + r.Weight = int(v) + } else if v, ok := chunk["weight_int"].(int); ok { + r.Weight = v + } + if r.From == "" || r.To == "" { + continue + } + relations = append(relations, r) + } + return relations +} + +// ParseKGCommunityReportChunks converts raw search result chunks into KGCommunityReport slices. 
+func ParseKGCommunityReportChunks(chunks []map[string]interface{}) []KGCommunityReport { + var reports []KGCommunityReport + for _, chunk := range chunks { + r := KGCommunityReport{} + r.Title, _ = chunk["docnm_kwd"].(string) + r.Content, _ = chunk["content_with_weight"].(string) + r.Entities, _ = chunk["entities_kwd"].(string) + if v, ok := chunk["weight_flt"].(float64); ok { + r.Weight = v + } + if r.Title == "" && r.Content == "" { + continue + } + reports = append(reports, r) + } + return reports +} + +// ParseKGTypeSamplesChunks converts raw search result chunks into a type→entities map. +func ParseKGTypeSamplesChunks(chunks []map[string]interface{}) map[string][]string { + result := make(map[string][]string) + for _, chunk := range chunks { + content, ok := chunk["content_with_weight"].(string) + if !ok || content == "" { + continue + } + var typeMap map[string][]string + if err := json.Unmarshal([]byte(content), &typeMap); err != nil { + continue + } + for typ, entities := range typeMap { + result[typ] = append(result[typ], entities...) + } + } + return result +} + +// NhopEntityNames extracts unique entity names from n_hop_with_weight JSON string. +// The JSON format is: [{"path": ["A", "B", "C"], "weights": [0.8, 0.5]}, ...] +// Returns entity names in order of first appearance, with duplicates removed. +func NhopEntityNames(nHopJSON string) []string { + type nhopItem struct { + Path []string `json:"path"` + Weights []float64 `json:"weights"` + } + var data []nhopItem + if err := json.Unmarshal([]byte(nHopJSON), &data); err != nil { + return nil + } + seen := make(map[string]struct{}) + var names []string + for _, item := range data { + for _, name := range item.Path { + if _, ok := seen[name]; !ok { + seen[name] = struct{}{} + names = append(names, name) + } + } + } + return names +} + +// SearchKGEntities searches for KG entities matching a question. 
+func SearchKGEntities(ctx context.Context, docEngine engine.DocEngine, kbIDs []string, question string, embModel *modelModule.EmbeddingModel, topN int) ([]KGEntity, error) { + dense, err := buildKGDenseExpr(embModel, question, topN) + if err != nil { + return nil, err + } + req := buildEntitySearchRequest(kbIDs, question, dense, topN) + result, err := docEngine.Search(ctx, req) + if err != nil { + return nil, fmt.Errorf("KG entity search failed: %w", err) + } + return ParseKGEntityChunks(result.Chunks), nil +} + +// SearchKGEntitiesByTypes searches for KG entities by type keywords. +func SearchKGEntitiesByTypes(ctx context.Context, docEngine engine.DocEngine, kbIDs []string, typeKeywords []string, topN int) ([]KGEntity, error) { + req := buildEntityTypeSearchRequest(kbIDs, typeKeywords, topN) + result, err := docEngine.Search(ctx, req) + if err != nil { + return nil, fmt.Errorf("KG entity type search failed: %w", err) + } + return ParseKGEntityChunks(result.Chunks), nil +} + +// SearchKGRelations searches for KG relations matching a question. +func SearchKGRelations(ctx context.Context, docEngine engine.DocEngine, kbIDs []string, question string, embModel *modelModule.EmbeddingModel, topN int) ([]KGRelation, error) { + dense, err := buildKGDenseExpr(embModel, question, topN) + if err != nil { + return nil, err + } + req := buildRelationSearchRequest(kbIDs, question, dense, topN) + result, err := docEngine.Search(ctx, req) + if err != nil { + return nil, fmt.Errorf("KG relation search failed: %w", err) + } + return ParseKGRelationChunks(result.Chunks), nil +} + +// SearchKGCommunityReports searches for community reports related to given entities. 
+func SearchKGCommunityReports(ctx context.Context, docEngine engine.DocEngine, kbIDs []string, entityNames []string, topN int) ([]KGCommunityReport, error) { + req := buildCommunitySearchRequest(kbIDs, entityNames, topN) + result, err := docEngine.Search(ctx, req) + if err != nil { + return nil, fmt.Errorf("KG community search failed: %w", err) + } + return ParseKGCommunityReportChunks(result.Chunks), nil +} + +// SearchKGTypeSamples retrieves the type→entities mapping from ES. +func SearchKGTypeSamples(ctx context.Context, docEngine engine.DocEngine, kbIDs []string) (map[string][]string, error) { + req := buildTypeSamplesSearchRequest(kbIDs) + result, err := docEngine.Search(ctx, req) + if err != nil { + return nil, fmt.Errorf("KG type samples search failed: %w", err) + } + return ParseKGTypeSamplesChunks(result.Chunks), nil +} diff --git a/internal/service/kg_search_test.go b/internal/service/kg_search_test.go new file mode 100644 index 0000000000..79801b0df3 --- /dev/null +++ b/internal/service/kg_search_test.go @@ -0,0 +1,444 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// + +package service + +import ( + "context" + "testing" + + "ragflow/internal/engine" + "ragflow/internal/engine/types" + modelModule "ragflow/internal/entity/models" +) + +// --- buildEntitySearchRequest --- + +func TestBuildEntitySearchRequest_TextOnly(t *testing.T) { + req := buildEntitySearchRequest([]string{"kb1"}, "Elon Musk", nil, 10) + if req == nil { + t.Fatal("expected non-nil request") + } + if req.Filter["knowledge_graph_kwd"] != "entity" { + t.Errorf("expected 'entity' filter, got %v", req.Filter["knowledge_graph_kwd"]) + } + if req.Limit != 10 { + t.Errorf("expected limit 10, got %d", req.Limit) + } + if len(req.MatchExprs) != 1 { + t.Fatalf("expected 1 MatchExpr (text only), got %d", len(req.MatchExprs)) + } + if _, ok := req.MatchExprs[0].(*types.MatchTextExpr); !ok { + t.Error("expected MatchTextExpr") + } +} + +func TestBuildEntitySearchRequest_EmptyQuestion(t *testing.T) { + req := buildEntitySearchRequest([]string{"kb1"}, "", nil, 10) + if len(req.MatchExprs) != 0 { + t.Error("expected no MatchExprs for empty question") + } +} + +func TestBuildEntitySearchRequest_Hybrid(t *testing.T) { + dense := &types.MatchDenseExpr{ + VectorColumnName: "q_768_vec", + EmbeddingData: []float64{0.1, 0.2, 0.3}, + EmbeddingDataType: "float", + DistanceType: "cosine", + TopN: 10, + } + req := buildEntitySearchRequest([]string{"kb1"}, "test question", dense, 10) + if len(req.MatchExprs) != 3 { + t.Fatalf("expected 3 MatchExprs (dense + text + fusion), got %d", len(req.MatchExprs)) + } + if _, ok := req.MatchExprs[0].(*types.MatchDenseExpr); !ok { + t.Error("expected MatchDenseExpr at [0]") + } + if _, ok := req.MatchExprs[1].(*types.MatchTextExpr); !ok { + t.Error("expected MatchTextExpr at [1]") + } + fusion, ok := req.MatchExprs[2].(*types.FusionExpr) + if !ok { + t.Fatal("expected FusionExpr at [2]") + } + if fusion.Method != "weighted_sum" { + t.Errorf("expected 'weighted_sum', got %q", fusion.Method) + } +} + +// --- buildEntityTypeSearchRequest --- + +func 
TestBuildEntityTypeSearchRequest_Basic(t *testing.T) { + req := buildEntityTypeSearchRequest([]string{"kb1"}, []string{"PERSON", "ORGANIZATION"}, 10) + if req.Filter["knowledge_graph_kwd"] != "entity" { + t.Errorf("expected 'entity' filter, got %v", req.Filter["knowledge_graph_kwd"]) + } + filter, ok := req.Filter["entity_type_kwd"].([]interface{}) + if !ok || len(filter) != 2 { + t.Errorf("expected 2 entity_type filters, got %v", filter) + } +} + +func TestBuildEntityTypeSearchRequest_EmptyTypes(t *testing.T) { + req := buildEntityTypeSearchRequest([]string{"kb1"}, nil, 10) + if _, ok := req.Filter["entity_type_kwd"]; ok { + t.Error("expected no entity_type_kwd filter for empty types") + } +} + +// --- buildRelationSearchRequest --- + +func TestBuildRelationSearchRequest_Basic(t *testing.T) { + req := buildRelationSearchRequest([]string{"kb1"}, "test", nil, 10) + if req.Filter["knowledge_graph_kwd"] != "relation" { + t.Errorf("expected 'relation' filter, got %v", req.Filter["knowledge_graph_kwd"]) + } +} + +func TestBuildRelationSearchRequest_EmptyQuestion(t *testing.T) { + req := buildRelationSearchRequest([]string{"kb1"}, "", nil, 10) + if len(req.MatchExprs) != 0 { + t.Error("expected no MatchExprs for empty question") + } +} + +func TestBuildRelationSearchRequest_Hybrid(t *testing.T) { + dense := &types.MatchDenseExpr{ + VectorColumnName: "q_768_vec", + EmbeddingData: []float64{0.1, 0.2}, + TopN: 5, + } + req := buildRelationSearchRequest([]string{"kb1"}, "test", dense, 5) + if len(req.MatchExprs) != 3 { + t.Fatalf("expected 3 MatchExprs (dense + text + fusion), got %d", len(req.MatchExprs)) + } +} + +// --- buildCommunitySearchRequest --- + +func TestBuildCommunitySearchRequest_Basic(t *testing.T) { + req := buildCommunitySearchRequest([]string{"kb1"}, []string{"Elon Musk"}, 5) + if req.Filter["knowledge_graph_kwd"] != "community_report" { + t.Errorf("expected 'community_report' filter, got %v", req.Filter["knowledge_graph_kwd"]) + } + if req.OrderBy == nil { 
+ t.Error("expected OrderBy") + } +} + +func TestBuildCommunitySearchRequest_EmptyNames(t *testing.T) { + req := buildCommunitySearchRequest([]string{"kb1"}, nil, 5) + if _, ok := req.Filter["entities_kwd"]; ok { + t.Error("expected no entities_kwd filter for empty names") + } +} + +// --- buildTypeSamplesSearchRequest --- + +func TestBuildTypeSamplesSearchRequest(t *testing.T) { + req := buildTypeSamplesSearchRequest([]string{"kb1"}) + if req.Filter["knowledge_graph_kwd"] != "ty2ents" { + t.Errorf("expected 'ty2ents' filter, got %v", req.Filter["knowledge_graph_kwd"]) + } + if req.Limit != 10000 { + t.Errorf("expected 10000, got %d", req.Limit) + } +} + +// --- ParseKGEntityChunks --- + +func TestParseKGEntityChunks_Basic(t *testing.T) { + chunks := []map[string]interface{}{ + {"entity_kwd": "Elon Musk", "entity_type_kwd": "PERSON", "rank_flt": 0.9, "_score": 0.85, "content_with_weight": "Founder of SpaceX"}, + } + entities := ParseKGEntityChunks(chunks) + if len(entities) != 1 { + t.Fatalf("expected 1, got %d", len(entities)) + } + if entities[0].Name != "Elon Musk" || entities[0].Type != "PERSON" || entities[0].PageRank != 0.9 || entities[0].Similarity != 0.85 { + t.Errorf("unexpected entity fields: %+v", entities[0]) + } +} + +func TestParseKGEntityChunks_List(t *testing.T) { + chunks := []map[string]interface{}{ + {"entity_kwd": []interface{}{"Elon Musk", "elon_musk"}}, + } + entities := ParseKGEntityChunks(chunks) + if len(entities) != 1 || entities[0].Name != "Elon Musk" { + t.Errorf("expected first list element, got %q", entities[0].Name) + } +} + +func TestParseKGEntityChunks_EmptyName(t *testing.T) { + chunks := []map[string]interface{}{{"entity_type_kwd": "PERSON"}} + if len(ParseKGEntityChunks(chunks)) != 0 { + t.Error("expected 0 for missing name") + } +} + +func TestParseKGEntityChunks_ScoreFallback(t *testing.T) { + chunks := []map[string]interface{}{{"entity_kwd": "Test", "score": 0.75}} + if ParseKGEntityChunks(chunks)[0].Similarity != 0.75 { + 
t.Error("expected 0.75 from score field") + } +} + +func TestParseKGEntityChunks_NilInput(t *testing.T) { + if len(ParseKGEntityChunks(nil)) != 0 { + t.Error("expected 0 for nil input") + } +} + +// --- ParseKGRelationChunks --- + +func TestParseKGRelationChunks_Basic(t *testing.T) { + chunks := []map[string]interface{}{ + {"from_entity_kwd": "Elon Musk", "to_entity_kwd": "SpaceX", "weight_int": float64(5), "content_with_weight": "Founder"}, + } + relations := ParseKGRelationChunks(chunks) + if len(relations) != 1 || relations[0].From != "Elon Musk" || relations[0].Weight != 5 { + t.Errorf("unexpected: %+v", relations[0]) + } +} + +func TestParseKGRelationChunks_IntWeight(t *testing.T) { + chunks := []map[string]interface{}{{"from_entity_kwd": "A", "to_entity_kwd": "B", "weight_int": 3}} + if ParseKGRelationChunks(chunks)[0].Weight != 3 { + t.Error("expected weight 3") + } +} + +func TestParseKGRelationChunks_EmptyFrom(t *testing.T) { + if len(ParseKGRelationChunks([]map[string]interface{}{{"to_entity_kwd": "B"}})) != 0 { + t.Error("expected 0 for missing from") + } +} + +func TestParseKGRelationChunks_NilInput(t *testing.T) { + if len(ParseKGRelationChunks(nil)) != 0 { + t.Error("expected 0 for nil") + } +} + +// --- ParseKGCommunityReportChunks --- + +func TestParseKGCommunityReportChunks_Basic(t *testing.T) { + chunks := []map[string]interface{}{ + {"docnm_kwd": "Report 1", "content_with_weight": "content", "weight_flt": 0.95, "entities_kwd": "A, B"}, + } + reports := ParseKGCommunityReportChunks(chunks) + if len(reports) != 1 || reports[0].Title != "Report 1" || reports[0].Weight != 0.95 { + t.Errorf("unexpected: %+v", reports[0]) + } +} + +func TestParseKGCommunityReportChunks_EmptyTitle(t *testing.T) { + if len(ParseKGCommunityReportChunks([]map[string]interface{}{{"weight_flt": 0.5}})) != 0 { + t.Error("expected 0 for empty title and content") + } +} + +func TestParseKGCommunityReportChunks_NilInput(t *testing.T) { + if len(ParseKGCommunityReportChunks(nil)) 
!= 0 { + t.Error("expected 0 for nil") + } +} + +// --- ParseKGTypeSamplesChunks --- + +func TestParseKGTypeSamplesChunks_ValidJSON(t *testing.T) { + chunks := []map[string]interface{}{ + {"content_with_weight": `{"PERSON": ["Elon Musk", "Einstein"], "ORGANIZATION": ["SpaceX"]}`}, + } + result := ParseKGTypeSamplesChunks(chunks) + if len(result) != 2 { + t.Fatalf("expected 2 types, got %d: %v", len(result), result) + } + if len(result["PERSON"]) != 2 || result["PERSON"][0] != "Elon Musk" { + t.Errorf("expected PERSON entities, got %v", result["PERSON"]) + } + if len(result["ORGANIZATION"]) != 1 || result["ORGANIZATION"][0] != "SpaceX" { + t.Errorf("expected ORGANIZATION entities, got %v", result["ORGANIZATION"]) + } +} + +func TestParseKGTypeSamplesChunks_InvalidJSON(t *testing.T) { + chunks := []map[string]interface{}{ + {"content_with_weight": "not json"}, + } + result := ParseKGTypeSamplesChunks(chunks) + if len(result) != 0 { + t.Error("expected empty for invalid JSON") + } +} + +func TestParseKGTypeSamplesChunks_Empty(t *testing.T) { + result := ParseKGTypeSamplesChunks(nil) + if len(result) != 0 { + t.Error("expected empty for nil") + } +} + +// --- NhopEntityNames --- + +func TestNhopEntityNames_ValidJSON(t *testing.T) { + input := `[{"path": ["A", "B", "C"], "weights": [0.8, 0.5]}, {"path": ["C", "D"], "weights": [0.3]}]` + names := NhopEntityNames(input) + if len(names) != 4 { + t.Fatalf("expected 4 unique names, got %d: %v", len(names), names) + } +} + +func TestNhopEntityNames_Dedup(t *testing.T) { + input := `[{"path": ["A", "B"], "weights": [0.9]}, {"path": ["A", "C"], "weights": [0.8]}]` + names := NhopEntityNames(input) + if len(names) != 3 { + t.Errorf("expected 3 unique names (A,B,C), got %d: %v", len(names), names) + } +} + +func TestNhopEntityNames_InvalidJSON(t *testing.T) { + result := NhopEntityNames("not json") + if result != nil { + t.Error("expected nil for invalid JSON") + } +} + +// --- Mock engine --- + +type mockKGEngine struct { + 
engine.DocEngine + searchFunc func(ctx context.Context, req *types.SearchRequest) (*types.SearchResult, error) +} + +func (m *mockKGEngine) Search(ctx context.Context, req *types.SearchRequest) (*types.SearchResult, error) { + if m.searchFunc != nil { + return m.searchFunc(ctx, req) + } + return &types.SearchResult{}, nil +} + +// --- Mock model driver for Embed tests --- + +type fakeEmbedDriver struct { + modelModule.ModelDriver + name string + vector []float64 +} + +func (f *fakeEmbedDriver) Name() string { return f.name } + +func (f *fakeEmbedDriver) Embed(modelName *string, texts []string, apiConfig *modelModule.APIConfig, config *modelModule.EmbeddingConfig) ([]modelModule.EmbeddingData, error) { + return []modelModule.EmbeddingData{{Embedding: f.vector}}, nil +} + +func TestBuildKGDenseExpr_WithModel(t *testing.T) { + embModel := &modelModule.EmbeddingModel{ + ModelName: strPtr("test-model"), + ModelDriver: &fakeEmbedDriver{ + name: "test", + vector: []float64{0.1, 0.2, 0.3}, + }, + APIConfig: &modelModule.APIConfig{}, + } + dense, err := buildKGDenseExpr(embModel, "test question", 10) + if err != nil { + t.Fatalf("buildKGDenseExpr failed: %v", err) + } + if dense == nil { + t.Fatal("expected non-nil MatchDenseExpr") + } + if dense.VectorColumnName != "q_3_vec" { + t.Errorf("expected 'q_3_vec', got %q", dense.VectorColumnName) + } + if dense.TopN != 10 { + t.Errorf("expected TopN 10, got %d", dense.TopN) + } +} + +func TestBuildKGDenseExpr_NilModel(t *testing.T) { + dense, err := buildKGDenseExpr(nil, "test", 10) + if dense != nil || err != nil { + t.Errorf("expected nil,nil for nil model, got dense=%v err=%v", dense, err) + } +} + +func TestBuildKGDenseExpr_EmptyQuestion(t *testing.T) { + dense, err := buildKGDenseExpr(&modelModule.EmbeddingModel{}, "", 10) + if dense != nil || err != nil { + t.Errorf("expected nil,nil for empty question, got dense=%v err=%v", dense, err) + } +} + +// --- Search integration with mock --- + +func 
TestSearchKGEntities_WithMock(t *testing.T) { + mock := &mockKGEngine{ + searchFunc: func(ctx context.Context, req *types.SearchRequest) (*types.SearchResult, error) { + if req.Filter["knowledge_graph_kwd"] != "entity" { + t.Error("expected entity filter") + } + return &types.SearchResult{ + Chunks: []map[string]interface{}{ + {"entity_kwd": "Elon Musk", "entity_type_kwd": "PERSON"}, + }, + }, nil + }, + } + entities, err := SearchKGEntities(context.Background(), mock, []string{"kb1"}, "Elon", nil, 10) + if err != nil { + t.Fatalf("SearchKGEntities failed: %v", err) + } + if len(entities) != 1 || entities[0].Name != "Elon Musk" { + t.Errorf("expected [Elon Musk], got %v", entities) + } +} + +func TestSearchKGEntitiesByTypes_WithMock(t *testing.T) { + mock := &mockKGEngine{ + searchFunc: func(ctx context.Context, req *types.SearchRequest) (*types.SearchResult, error) { + return &types.SearchResult{ + Chunks: []map[string]interface{}{ + {"entity_kwd": "SpaceX", "entity_type_kwd": "ORGANIZATION"}, + }, + }, nil + }, + } + entities, err := SearchKGEntitiesByTypes(context.Background(), mock, []string{"kb1"}, []string{"ORGANIZATION"}, 10) + if err != nil { + t.Fatalf("SearchKGEntitiesByTypes failed: %v", err) + } + if len(entities) != 1 || entities[0].Type != "ORGANIZATION" { + t.Errorf("expected ORGANIZATION, got %v", entities) + } +} + +func TestSearchKGTypeSamples_WithMock(t *testing.T) { + mock := &mockKGEngine{} + samples, err := SearchKGTypeSamples(context.Background(), mock, []string{"kb1"}) + if err != nil { + t.Fatalf("SearchKGTypeSamples failed: %v", err) + } + if samples == nil { + samples = map[string][]string{} + } + if len(samples) != 0 { + t.Errorf("expected empty, got %d", len(samples)) + } +} +