feat: add KG entity/relation/community search functions (#15689)

## Summary

Knowledge Graph search functions for entity, relation, community report,
and type-samples retrieval. Uses DocEngine.SelectFields (PR #15684) for
KG-specific fields.

### Functions

| Function | Description |
|----------|-------------|
| `SearchKGEntities` | Hybrid search over KG entities (dense + text + fusion) |
| `SearchKGEntitiesByTypes` | Entity search filtered by `entity_type_kwd` |
| `SearchKGRelations` | Hybrid search over KG relations |
| `SearchKGCommunityReports` | Community report search by entity names |
| `SearchKGTypeSamples` | Type→entities mapping for query_rewrite |

### Internal helpers

| Helper | Description |
|--------|-------------|
| `buildHybridExpr` | Shared dense+text+fusion expression construction |
| `buildKGDenseExpr` | Wraps `Embed()` call for vector search |
| `Parse*` | Convert raw chunks to typed structs |

### Testing

35 tests (pure function + mock integration)

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
Jack
2026-06-05 13:23:04 +08:00
committed by GitHub
parent 4b2af1347c
commit e629c0203b
2 changed files with 806 additions and 0 deletions

View File

@@ -0,0 +1,362 @@
//
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
package service
import (
"context"
"encoding/json"
"fmt"
"ragflow/internal/engine"
"ragflow/internal/engine/types"
modelModule "ragflow/internal/entity/models"
)
// KGEntity represents a knowledge graph entity.
// Each field comment names the chunk key that ParseKGEntityChunks reads it from.
type KGEntity struct {
	Name        string  // entity_kwd; when the value is a list, its first element is used
	Type        string  // entity_type_kwd
	PageRank    float64 // rank_flt
	Similarity  float64 // _score, falling back to score when _score is absent
	Description string  // content_with_weight
}
// KGRelation represents a relation between two entities.
// Each field comment names the chunk key that ParseKGRelationChunks reads it from.
type KGRelation struct {
	From        string // from_entity_kwd
	To          string // to_entity_kwd
	Weight      int    // weight_int; accepted as either float64 or int in the raw chunk
	Description string // content_with_weight
}
// KGCommunityReport represents a community report.
// Each field comment names the chunk key that ParseKGCommunityReportChunks reads it from.
type KGCommunityReport struct {
	Title    string  // docnm_kwd
	Content  string  // content_with_weight
	Weight   float64 // weight_flt
	Entities string  // entities_kwd; only populated when the raw value is a string
}
// buildKGDenseExpr embeds question with embModel and wraps the resulting vector
// in a MatchDenseExpr for KG hybrid search.
//
// It returns (nil, nil), meaning "no dense leg available", when embModel is nil,
// question is empty, or the driver returns no usable vector; callers treat that
// as text-only search. A model without a driver is reported as an error rather
// than panicking on the nil method call.
//
// The helper is shared by entity and relation search, so its error text is
// deliberately not tied to either one.
func buildKGDenseExpr(embModel *modelModule.EmbeddingModel, question string, topN int) (*types.MatchDenseExpr, error) {
	if embModel == nil || question == "" {
		return nil, nil
	}
	if embModel.ModelDriver == nil {
		return nil, fmt.Errorf("KG query embed failed: embedding model has no driver")
	}

	// NOTE(review): Dimension 0 presumably asks the driver for its default
	// output size — confirm against the driver implementations.
	embCfg := &modelModule.EmbeddingConfig{Dimension: 0}
	embeddings, err := embModel.ModelDriver.Embed(embModel.ModelName, []string{question}, embModel.APIConfig, embCfg)
	if err != nil {
		return nil, fmt.Errorf("KG query embed failed: %w", err)
	}
	if len(embeddings) == 0 || len(embeddings[0].Embedding) == 0 {
		return nil, nil
	}

	vector := embeddings[0].Embedding
	return &types.MatchDenseExpr{
		// The vector column is selected by the dimensionality actually returned.
		VectorColumnName:  fmt.Sprintf("q_%d_vec", len(vector)),
		EmbeddingData:     vector,
		EmbeddingDataType: "float",
		DistanceType:      "cosine",
		TopN:              topN,
		ExtraOptions:      map[string]interface{}{"similarity": 0.3},
	}, nil
}
// buildHybridExpr assembles the three-element expression list used for hybrid
// search: the dense expression, the text expression, and a weighted-sum fusion
// expression limited to topN, in that order.
func buildHybridExpr(dense *types.MatchDenseExpr, text *types.MatchTextExpr, topN int) []interface{} {
	fusion := &types.FusionExpr{
		Method:       "weighted_sum",
		TopN:         topN,
		FusionParams: map[string]interface{}{"weights": "0.05,0.95"},
	}
	return []interface{}{dense, text, fusion}
}
// buildEntitySearchRequest builds the SearchRequest for KG entity lookup.
//
// With an empty question the request carries only the entity filter. Otherwise
// a text match over entity_kwd/content_ltks is added; when dense is non-nil the
// text match is combined with it via buildHybridExpr and a pagerank rank
// feature is attached.
func buildEntitySearchRequest(kbIDs []string, question string, dense *types.MatchDenseExpr, topN int) *types.SearchRequest {
	req := &types.SearchRequest{
		KbIDs:        kbIDs,
		SelectFields: []string{"entity_kwd", "entity_type_kwd", "rank_flt", "content_with_weight"},
		Limit:        topN,
		Filter:       map[string]interface{}{"knowledge_graph_kwd": "entity"},
	}
	if question != "" {
		text := &types.MatchTextExpr{
			Fields:       []string{"entity_kwd^10", "content_ltks^2"},
			MatchingText: question,
			TopN:         topN,
		}
		if dense == nil {
			// Text-only search: no fusion and no rank feature.
			req.MatchExprs = []interface{}{text}
		} else {
			req.MatchExprs = buildHybridExpr(dense, text, topN)
			req.RankFeature = map[string]float64{"pagerank_fea": 10.0}
		}
	}
	return req
}
// buildEntityTypeSearchRequest builds the SearchRequest that lists KG entities,
// optionally restricted to the given entity types. When typeKeywords is empty
// no entity_type_kwd filter is set.
func buildEntityTypeSearchRequest(kbIDs []string, typeKeywords []string, topN int) *types.SearchRequest {
	filter := map[string]interface{}{"knowledge_graph_kwd": "entity"}
	if len(typeKeywords) > 0 {
		// The filter value is stored as []interface{} rather than []string.
		wanted := make([]interface{}, 0, len(typeKeywords))
		for _, kw := range typeKeywords {
			wanted = append(wanted, kw)
		}
		filter["entity_type_kwd"] = wanted
	}
	return &types.SearchRequest{
		KbIDs:        kbIDs,
		SelectFields: []string{"entity_kwd", "entity_type_kwd", "rank_flt", "content_with_weight"},
		Limit:        topN,
		Filter:       filter,
	}
}
// buildRelationSearchRequest builds the SearchRequest for KG relation lookup.
//
// An empty question yields a filter-only request. Otherwise a text match over
// content_ltks is added, fused with dense via buildHybridExpr when dense is
// non-nil.
func buildRelationSearchRequest(kbIDs []string, question string, dense *types.MatchDenseExpr, topN int) *types.SearchRequest {
	req := &types.SearchRequest{
		KbIDs:        kbIDs,
		SelectFields: []string{"from_entity_kwd", "to_entity_kwd", "weight_int", "content_with_weight"},
		Limit:        topN,
		Filter:       map[string]interface{}{"knowledge_graph_kwd": "relation"},
	}
	if question == "" {
		return req
	}

	text := &types.MatchTextExpr{
		Fields:       []string{"content_ltks"},
		MatchingText: question,
		TopN:         topN,
	}
	if dense == nil {
		req.MatchExprs = []interface{}{text}
		return req
	}
	req.MatchExprs = buildHybridExpr(dense, text, topN)
	return req
}
// buildCommunitySearchRequest builds the SearchRequest for KG community
// reports, ordered by weight_flt descending. When entityNames is non-empty the
// names are added as an entities_kwd filter; otherwise only the
// community_report filter applies.
func buildCommunitySearchRequest(kbIDs []string, entityNames []string, topN int) *types.SearchRequest {
	filter := map[string]interface{}{"knowledge_graph_kwd": "community_report"}
	if len(entityNames) > 0 {
		// The filter value is stored as []interface{} rather than []string.
		wanted := make([]interface{}, 0, len(entityNames))
		for _, entity := range entityNames {
			wanted = append(wanted, entity)
		}
		filter["entities_kwd"] = wanted
	}

	req := &types.SearchRequest{
		KbIDs:        kbIDs,
		SelectFields: []string{"docnm_kwd", "content_with_weight", "weight_flt", "entities_kwd"},
		Limit:        topN,
		Filter:       filter,
	}
	req.OrderBy = (&types.OrderByExpr{}).Desc("weight_flt")
	return req
}
// buildTypeSamplesSearchRequest builds the SearchRequest that fetches the
// "ty2ents" chunks (type-to-entities samples) for the given knowledge bases.
func buildTypeSamplesSearchRequest(kbIDs []string) *types.SearchRequest {
	// Fixed upper bound on the number of ty2ents chunks fetched in one request.
	const sampleLimit = 10000

	req := &types.SearchRequest{KbIDs: kbIDs, Limit: sampleLimit}
	req.SelectFields = []string{"content_with_weight"}
	req.Filter = map[string]interface{}{"knowledge_graph_kwd": "ty2ents"}
	return req
}
// ParseKGEntityChunks turns raw search-result chunks into KGEntity values.
//
// The name comes from entity_kwd, which may be a string or a list (the first
// element is used). Chunks without a usable name are dropped. The similarity is
// read from _score and falls back to score. Fields of an unexpected type are
// left at their zero value. The result is nil when nothing is kept.
func ParseKGEntityChunks(chunks []map[string]interface{}) []KGEntity {
	var out []KGEntity
	for _, c := range chunks {
		var name string
		switch raw := c["entity_kwd"].(type) {
		case string:
			name = raw
		case []interface{}:
			if len(raw) > 0 {
				name, _ = raw[0].(string)
			}
		}
		if name == "" {
			continue
		}

		ent := KGEntity{Name: name}
		ent.Type, _ = c["entity_type_kwd"].(string)
		ent.Description, _ = c["content_with_weight"].(string)
		ent.PageRank, _ = c["rank_flt"].(float64)

		score, found := c["_score"].(float64)
		if !found {
			score, _ = c["score"].(float64)
		}
		ent.Similarity = score

		out = append(out, ent)
	}
	return out
}
// ParseKGRelationChunks turns raw search-result chunks into KGRelation values.
//
// Chunks missing either endpoint (from_entity_kwd / to_entity_kwd) are dropped.
// weight_int is accepted as float64 (truncated to int) or int; any other type
// leaves Weight at zero. The result is nil when nothing is kept.
func ParseKGRelationChunks(chunks []map[string]interface{}) []KGRelation {
	var out []KGRelation
	for _, c := range chunks {
		from, _ := c["from_entity_kwd"].(string)
		to, _ := c["to_entity_kwd"].(string)
		if from == "" || to == "" {
			continue
		}

		rel := KGRelation{From: from, To: to}
		rel.Description, _ = c["content_with_weight"].(string)
		switch w := c["weight_int"].(type) {
		case float64:
			rel.Weight = int(w)
		case int:
			rel.Weight = w
		}
		out = append(out, rel)
	}
	return out
}
// ParseKGCommunityReportChunks turns raw search-result chunks into
// KGCommunityReport values. A chunk is dropped only when both its title
// (docnm_kwd) and content (content_with_weight) are empty. Fields of an
// unexpected type are left at their zero value. The result is nil when nothing
// is kept.
func ParseKGCommunityReportChunks(chunks []map[string]interface{}) []KGCommunityReport {
	var out []KGCommunityReport
	for _, c := range chunks {
		title, _ := c["docnm_kwd"].(string)
		content, _ := c["content_with_weight"].(string)
		if title == "" && content == "" {
			continue
		}

		report := KGCommunityReport{Title: title, Content: content}
		report.Entities, _ = c["entities_kwd"].(string)
		report.Weight, _ = c["weight_flt"].(float64)
		out = append(out, report)
	}
	return out
}
// ParseKGTypeSamplesChunks merges the JSON objects stored in each chunk's
// content_with_weight field into a single type→entities map.
//
// Chunks whose content is missing, empty, not a string, or not decodable as a
// map of string lists are skipped. Entity lists for the same type are
// concatenated in chunk order without de-duplication. The returned map is
// always non-nil.
func ParseKGTypeSamplesChunks(chunks []map[string]interface{}) map[string][]string {
	merged := map[string][]string{}
	for i := range chunks {
		raw, _ := chunks[i]["content_with_weight"].(string)
		if raw == "" {
			continue
		}
		var perChunk map[string][]string
		if json.Unmarshal([]byte(raw), &perChunk) != nil {
			continue
		}
		for kind, names := range perChunk {
			merged[kind] = append(merged[kind], names...)
		}
	}
	return merged
}
// NhopEntityNames collects the distinct entity names found in an
// n_hop_with_weight JSON document of the form
//
//	[{"path": ["A", "B", "C"], "weights": [0.8, 0.5]}, ...]
//
// Names are returned in order of first appearance across all paths. The result
// is nil when the input cannot be decoded or contains no names.
func NhopEntityNames(nHopJSON string) []string {
	// weights is decoded but unused; keeping it means a malformed weights
	// array still rejects the whole document.
	var hops []struct {
		Path    []string  `json:"path"`
		Weights []float64 `json:"weights"`
	}
	if json.Unmarshal([]byte(nHopJSON), &hops) != nil {
		return nil
	}

	var ordered []string
	visited := map[string]bool{}
	for i := range hops {
		for _, entity := range hops[i].Path {
			if visited[entity] {
				continue
			}
			visited[entity] = true
			ordered = append(ordered, entity)
		}
	}
	return ordered
}
// SearchKGEntities runs an entity search for question against the given
// knowledge bases and returns the parsed entities. The dense leg is built with
// buildKGDenseExpr; when that yields nil the request is text-only. An embedding
// error is returned as-is, and a search error is wrapped.
func SearchKGEntities(ctx context.Context, docEngine engine.DocEngine, kbIDs []string, question string, embModel *modelModule.EmbeddingModel, topN int) ([]KGEntity, error) {
	denseExpr, embedErr := buildKGDenseExpr(embModel, question, topN)
	if embedErr != nil {
		return nil, embedErr
	}

	res, searchErr := docEngine.Search(ctx, buildEntitySearchRequest(kbIDs, question, denseExpr, topN))
	if searchErr != nil {
		return nil, fmt.Errorf("KG entity search failed: %w", searchErr)
	}
	return ParseKGEntityChunks(res.Chunks), nil
}
// SearchKGEntitiesByTypes returns up to topN KG entities, optionally filtered
// by the given entity types. A search error is wrapped and returned.
func SearchKGEntitiesByTypes(ctx context.Context, docEngine engine.DocEngine, kbIDs []string, typeKeywords []string, topN int) ([]KGEntity, error) {
	res, searchErr := docEngine.Search(ctx, buildEntityTypeSearchRequest(kbIDs, typeKeywords, topN))
	if searchErr != nil {
		return nil, fmt.Errorf("KG entity type search failed: %w", searchErr)
	}
	return ParseKGEntityChunks(res.Chunks), nil
}
// SearchKGRelations runs a relation search for question against the given
// knowledge bases and returns the parsed relations. The dense leg is built with
// buildKGDenseExpr; when that yields nil the request is text-only. An embedding
// error is returned as-is, and a search error is wrapped.
func SearchKGRelations(ctx context.Context, docEngine engine.DocEngine, kbIDs []string, question string, embModel *modelModule.EmbeddingModel, topN int) ([]KGRelation, error) {
	denseExpr, embedErr := buildKGDenseExpr(embModel, question, topN)
	if embedErr != nil {
		return nil, embedErr
	}

	res, searchErr := docEngine.Search(ctx, buildRelationSearchRequest(kbIDs, question, denseExpr, topN))
	if searchErr != nil {
		return nil, fmt.Errorf("KG relation search failed: %w", searchErr)
	}
	return ParseKGRelationChunks(res.Chunks), nil
}
// SearchKGCommunityReports returns up to topN community reports, optionally
// filtered by the given entity names. A search error is wrapped and returned.
func SearchKGCommunityReports(ctx context.Context, docEngine engine.DocEngine, kbIDs []string, entityNames []string, topN int) ([]KGCommunityReport, error) {
	res, searchErr := docEngine.Search(ctx, buildCommunitySearchRequest(kbIDs, entityNames, topN))
	if searchErr != nil {
		return nil, fmt.Errorf("KG community search failed: %w", searchErr)
	}
	return ParseKGCommunityReportChunks(res.Chunks), nil
}
// SearchKGTypeSamples fetches the ty2ents chunks for the given knowledge bases
// through the document engine and merges them into a type→entities map.
// A search error is wrapped and returned.
func SearchKGTypeSamples(ctx context.Context, docEngine engine.DocEngine, kbIDs []string) (map[string][]string, error) {
	res, searchErr := docEngine.Search(ctx, buildTypeSamplesSearchRequest(kbIDs))
	if searchErr != nil {
		return nil, fmt.Errorf("KG type samples search failed: %w", searchErr)
	}
	return ParseKGTypeSamplesChunks(res.Chunks), nil
}

View File

@@ -0,0 +1,444 @@
//
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
package service
import (
"context"
"testing"
"ragflow/internal/engine"
"ragflow/internal/engine/types"
modelModule "ragflow/internal/entity/models"
)
// --- buildEntitySearchRequest ---
// TestBuildEntitySearchRequest_TextOnly checks that a question without a dense
// expression produces an entity-filtered request holding one text match.
func TestBuildEntitySearchRequest_TextOnly(t *testing.T) {
	got := buildEntitySearchRequest([]string{"kb1"}, "Elon Musk", nil, 10)
	if got == nil {
		t.Fatal("expected non-nil request")
	}
	if kind := got.Filter["knowledge_graph_kwd"]; kind != "entity" {
		t.Errorf("expected 'entity' filter, got %v", kind)
	}
	if got.Limit != 10 {
		t.Errorf("expected limit 10, got %d", got.Limit)
	}
	if n := len(got.MatchExprs); n != 1 {
		t.Fatalf("expected 1 MatchExpr (text only), got %d", n)
	}
	if _, isText := got.MatchExprs[0].(*types.MatchTextExpr); !isText {
		t.Error("expected MatchTextExpr")
	}
}
// TestBuildEntitySearchRequest_EmptyQuestion checks that an empty question adds
// no match expressions.
func TestBuildEntitySearchRequest_EmptyQuestion(t *testing.T) {
	exprs := buildEntitySearchRequest([]string{"kb1"}, "", nil, 10).MatchExprs
	if len(exprs) != 0 {
		t.Error("expected no MatchExprs for empty question")
	}
}
// TestBuildEntitySearchRequest_Hybrid checks that supplying a dense expression
// yields dense, text, and weighted-sum fusion expressions in that order.
func TestBuildEntitySearchRequest_Hybrid(t *testing.T) {
	vec := &types.MatchDenseExpr{
		VectorColumnName:  "q_768_vec",
		EmbeddingData:     []float64{0.1, 0.2, 0.3},
		EmbeddingDataType: "float",
		DistanceType:      "cosine",
		TopN:              10,
	}
	exprs := buildEntitySearchRequest([]string{"kb1"}, "test question", vec, 10).MatchExprs
	if len(exprs) != 3 {
		t.Fatalf("expected 3 MatchExprs (dense + text + fusion), got %d", len(exprs))
	}
	if _, isDense := exprs[0].(*types.MatchDenseExpr); !isDense {
		t.Error("expected MatchDenseExpr at [0]")
	}
	if _, isText := exprs[1].(*types.MatchTextExpr); !isText {
		t.Error("expected MatchTextExpr at [1]")
	}
	fusion, isFusion := exprs[2].(*types.FusionExpr)
	if !isFusion {
		t.Fatal("expected FusionExpr at [2]")
	}
	if method := fusion.Method; method != "weighted_sum" {
		t.Errorf("expected 'weighted_sum', got %q", method)
	}
}
// --- buildEntityTypeSearchRequest ---
// TestBuildEntityTypeSearchRequest_Basic checks the entity filter and that both
// requested types land in the entity_type_kwd filter.
func TestBuildEntityTypeSearchRequest_Basic(t *testing.T) {
	got := buildEntityTypeSearchRequest([]string{"kb1"}, []string{"PERSON", "ORGANIZATION"}, 10)
	if kind := got.Filter["knowledge_graph_kwd"]; kind != "entity" {
		t.Errorf("expected 'entity' filter, got %v", kind)
	}
	typeFilter, isList := got.Filter["entity_type_kwd"].([]interface{})
	if !isList || len(typeFilter) != 2 {
		t.Errorf("expected 2 entity_type filters, got %v", typeFilter)
	}
}
// TestBuildEntityTypeSearchRequest_EmptyTypes checks that no entity_type_kwd
// filter is set when no types are given.
func TestBuildEntityTypeSearchRequest_EmptyTypes(t *testing.T) {
	filter := buildEntityTypeSearchRequest([]string{"kb1"}, nil, 10).Filter
	_, present := filter["entity_type_kwd"]
	if present {
		t.Error("expected no entity_type_kwd filter for empty types")
	}
}
// --- buildRelationSearchRequest ---
// TestBuildRelationSearchRequest_Basic checks that relation requests carry the
// relation filter.
func TestBuildRelationSearchRequest_Basic(t *testing.T) {
	got := buildRelationSearchRequest([]string{"kb1"}, "test", nil, 10)
	if kind := got.Filter["knowledge_graph_kwd"]; kind != "relation" {
		t.Errorf("expected 'relation' filter, got %v", kind)
	}
}
// TestBuildRelationSearchRequest_EmptyQuestion checks that an empty question
// adds no match expressions.
func TestBuildRelationSearchRequest_EmptyQuestion(t *testing.T) {
	exprs := buildRelationSearchRequest([]string{"kb1"}, "", nil, 10).MatchExprs
	if len(exprs) != 0 {
		t.Error("expected no MatchExprs for empty question")
	}
}
// TestBuildRelationSearchRequest_Hybrid checks that supplying a dense
// expression yields three match expressions.
func TestBuildRelationSearchRequest_Hybrid(t *testing.T) {
	vec := &types.MatchDenseExpr{
		VectorColumnName: "q_768_vec",
		EmbeddingData:    []float64{0.1, 0.2},
		TopN:             5,
	}
	exprs := buildRelationSearchRequest([]string{"kb1"}, "test", vec, 5).MatchExprs
	if n := len(exprs); n != 3 {
		t.Fatalf("expected 3 MatchExprs (dense + text + fusion), got %d", n)
	}
}
// --- buildCommunitySearchRequest ---
// TestBuildCommunitySearchRequest_Basic checks the community_report filter and
// that an ordering is set.
func TestBuildCommunitySearchRequest_Basic(t *testing.T) {
	got := buildCommunitySearchRequest([]string{"kb1"}, []string{"Elon Musk"}, 5)
	if kind := got.Filter["knowledge_graph_kwd"]; kind != "community_report" {
		t.Errorf("expected 'community_report' filter, got %v", kind)
	}
	if got.OrderBy == nil {
		t.Error("expected OrderBy")
	}
}
// TestBuildCommunitySearchRequest_EmptyNames checks that no entities_kwd filter
// is set when no entity names are given.
func TestBuildCommunitySearchRequest_EmptyNames(t *testing.T) {
	filter := buildCommunitySearchRequest([]string{"kb1"}, nil, 5).Filter
	_, present := filter["entities_kwd"]
	if present {
		t.Error("expected no entities_kwd filter for empty names")
	}
}
// --- buildTypeSamplesSearchRequest ---
// TestBuildTypeSamplesSearchRequest checks the ty2ents filter and the fixed
// result limit.
func TestBuildTypeSamplesSearchRequest(t *testing.T) {
	got := buildTypeSamplesSearchRequest([]string{"kb1"})
	if kind := got.Filter["knowledge_graph_kwd"]; kind != "ty2ents" {
		t.Errorf("expected 'ty2ents' filter, got %v", kind)
	}
	if limit := got.Limit; limit != 10000 {
		t.Errorf("expected 10000, got %d", limit)
	}
}
// --- ParseKGEntityChunks ---
// TestParseKGEntityChunks_Basic checks that name, type, page rank, and score
// are copied from a fully populated chunk.
func TestParseKGEntityChunks_Basic(t *testing.T) {
	raw := map[string]interface{}{
		"entity_kwd":          "Elon Musk",
		"entity_type_kwd":     "PERSON",
		"rank_flt":            0.9,
		"_score":              0.85,
		"content_with_weight": "Founder of SpaceX",
	}
	entities := ParseKGEntityChunks([]map[string]interface{}{raw})
	if len(entities) != 1 {
		t.Fatalf("expected 1, got %d", len(entities))
	}
	got := entities[0]
	matches := got.Name == "Elon Musk" && got.Type == "PERSON" && got.PageRank == 0.9 && got.Similarity == 0.85
	if !matches {
		t.Errorf("unexpected entity fields: %+v", got)
	}
}
// TestParseKGEntityChunks_List checks that a list-valued entity_kwd contributes
// its first element as the entity name.
func TestParseKGEntityChunks_List(t *testing.T) {
	chunks := []map[string]interface{}{
		{"entity_kwd": []interface{}{"Elon Musk", "elon_musk"}},
	}
	entities := ParseKGEntityChunks(chunks)
	// Check the length first so an empty result fails the test instead of
	// panicking on entities[0] while formatting the failure message.
	if len(entities) != 1 {
		t.Fatalf("expected 1 entity, got %d", len(entities))
	}
	if entities[0].Name != "Elon Musk" {
		t.Errorf("expected first list element, got %q", entities[0].Name)
	}
}
// TestParseKGEntityChunks_EmptyName checks that a chunk without entity_kwd is
// dropped.
func TestParseKGEntityChunks_EmptyName(t *testing.T) {
	nameless := map[string]interface{}{"entity_type_kwd": "PERSON"}
	entities := ParseKGEntityChunks([]map[string]interface{}{nameless})
	if len(entities) != 0 {
		t.Error("expected 0 for missing name")
	}
}
// TestParseKGEntityChunks_ScoreFallback checks that the similarity is taken
// from "score" when "_score" is absent.
func TestParseKGEntityChunks_ScoreFallback(t *testing.T) {
	chunks := []map[string]interface{}{{"entity_kwd": "Test", "score": 0.75}}
	entities := ParseKGEntityChunks(chunks)
	// Guard the index so a regression fails the test instead of panicking.
	if len(entities) != 1 {
		t.Fatalf("expected 1 entity, got %d", len(entities))
	}
	if entities[0].Similarity != 0.75 {
		t.Error("expected 0.75 from score field")
	}
}
// TestParseKGEntityChunks_NilInput checks that nil input yields no entities.
func TestParseKGEntityChunks_NilInput(t *testing.T) {
	entities := ParseKGEntityChunks(nil)
	if len(entities) != 0 {
		t.Error("expected 0 for nil input")
	}
}
// --- ParseKGRelationChunks ---
// TestParseKGRelationChunks_Basic checks that endpoints and a float64 weight
// are parsed from a fully populated chunk.
func TestParseKGRelationChunks_Basic(t *testing.T) {
	chunks := []map[string]interface{}{
		{"from_entity_kwd": "Elon Musk", "to_entity_kwd": "SpaceX", "weight_int": float64(5), "content_with_weight": "Founder"},
	}
	relations := ParseKGRelationChunks(chunks)
	// Check the length first so an empty result fails the test instead of
	// panicking on relations[0] while formatting the failure message.
	if len(relations) != 1 {
		t.Fatalf("expected 1 relation, got %d", len(relations))
	}
	if relations[0].From != "Elon Musk" || relations[0].Weight != 5 {
		t.Errorf("unexpected: %+v", relations[0])
	}
}
// TestParseKGRelationChunks_IntWeight checks that a weight supplied as a Go int
// (rather than float64) is accepted.
func TestParseKGRelationChunks_IntWeight(t *testing.T) {
	chunks := []map[string]interface{}{{"from_entity_kwd": "A", "to_entity_kwd": "B", "weight_int": 3}}
	relations := ParseKGRelationChunks(chunks)
	// Guard the index so a regression fails the test instead of panicking.
	if len(relations) != 1 {
		t.Fatalf("expected 1 relation, got %d", len(relations))
	}
	if relations[0].Weight != 3 {
		t.Error("expected weight 3")
	}
}
// TestParseKGRelationChunks_EmptyFrom checks that a relation without a source
// entity is dropped.
func TestParseKGRelationChunks_EmptyFrom(t *testing.T) {
	onlyTarget := map[string]interface{}{"to_entity_kwd": "B"}
	relations := ParseKGRelationChunks([]map[string]interface{}{onlyTarget})
	if len(relations) != 0 {
		t.Error("expected 0 for missing from")
	}
}
// TestParseKGRelationChunks_NilInput checks that nil input yields no relations.
func TestParseKGRelationChunks_NilInput(t *testing.T) {
	relations := ParseKGRelationChunks(nil)
	if len(relations) != 0 {
		t.Error("expected 0 for nil")
	}
}
// --- ParseKGCommunityReportChunks ---
// TestParseKGCommunityReportChunks_Basic checks that title and weight are
// parsed from a fully populated chunk.
func TestParseKGCommunityReportChunks_Basic(t *testing.T) {
	chunks := []map[string]interface{}{
		{"docnm_kwd": "Report 1", "content_with_weight": "content", "weight_flt": 0.95, "entities_kwd": "A, B"},
	}
	reports := ParseKGCommunityReportChunks(chunks)
	// Check the length first so an empty result fails the test instead of
	// panicking on reports[0] while formatting the failure message.
	if len(reports) != 1 {
		t.Fatalf("expected 1 report, got %d", len(reports))
	}
	if reports[0].Title != "Report 1" || reports[0].Weight != 0.95 {
		t.Errorf("unexpected: %+v", reports[0])
	}
}
// TestParseKGCommunityReportChunks_EmptyTitle checks that a chunk with neither
// title nor content is dropped.
func TestParseKGCommunityReportChunks_EmptyTitle(t *testing.T) {
	weightOnly := map[string]interface{}{"weight_flt": 0.5}
	reports := ParseKGCommunityReportChunks([]map[string]interface{}{weightOnly})
	if len(reports) != 0 {
		t.Error("expected 0 for empty title and content")
	}
}
// TestParseKGCommunityReportChunks_NilInput checks that nil input yields no
// reports.
func TestParseKGCommunityReportChunks_NilInput(t *testing.T) {
	reports := ParseKGCommunityReportChunks(nil)
	if len(reports) != 0 {
		t.Error("expected 0 for nil")
	}
}
// --- ParseKGTypeSamplesChunks ---
// TestParseKGTypeSamplesChunks_ValidJSON checks that a JSON object of
// type→entities is decoded into the result map.
func TestParseKGTypeSamplesChunks_ValidJSON(t *testing.T) {
	payload := `{"PERSON": ["Elon Musk", "Einstein"], "ORGANIZATION": ["SpaceX"]}`
	got := ParseKGTypeSamplesChunks([]map[string]interface{}{
		{"content_with_weight": payload},
	})
	if len(got) != 2 {
		t.Fatalf("expected 2 types, got %d: %v", len(got), got)
	}
	if people := got["PERSON"]; len(people) != 2 || people[0] != "Elon Musk" {
		t.Errorf("expected PERSON entities, got %v", people)
	}
	if orgs := got["ORGANIZATION"]; len(orgs) != 1 || orgs[0] != "SpaceX" {
		t.Errorf("expected ORGANIZATION entities, got %v", orgs)
	}
}
// TestParseKGTypeSamplesChunks_InvalidJSON checks that undecodable content is
// skipped rather than reported.
func TestParseKGTypeSamplesChunks_InvalidJSON(t *testing.T) {
	broken := map[string]interface{}{"content_with_weight": "not json"}
	got := ParseKGTypeSamplesChunks([]map[string]interface{}{broken})
	if len(got) != 0 {
		t.Error("expected empty for invalid JSON")
	}
}
// TestParseKGTypeSamplesChunks_Empty checks that nil input yields an empty map.
func TestParseKGTypeSamplesChunks_Empty(t *testing.T) {
	if got := ParseKGTypeSamplesChunks(nil); len(got) != 0 {
		t.Error("expected empty for nil")
	}
}
// --- NhopEntityNames ---
// TestNhopEntityNames_ValidJSON checks that names from all paths are returned
// once each, in order of first appearance, which is the ordering NhopEntityNames
// documents.
func TestNhopEntityNames_ValidJSON(t *testing.T) {
	input := `[{"path": ["A", "B", "C"], "weights": [0.8, 0.5]}, {"path": ["C", "D"], "weights": [0.3]}]`
	names := NhopEntityNames(input)
	want := []string{"A", "B", "C", "D"}
	if len(names) != len(want) {
		t.Fatalf("expected 4 unique names, got %d: %v", len(names), names)
	}
	// Assert the order too, not just the count.
	for i, w := range want {
		if names[i] != w {
			t.Errorf("names[%d] = %q, want %q", i, names[i], w)
		}
	}
}
// TestNhopEntityNames_Dedup checks that a name shared by two paths is reported
// only once.
func TestNhopEntityNames_Dedup(t *testing.T) {
	const input = `[{"path": ["A", "B"], "weights": [0.9]}, {"path": ["A", "C"], "weights": [0.8]}]`
	got := NhopEntityNames(input)
	if n := len(got); n != 3 {
		t.Errorf("expected 3 unique names (A,B,C), got %d: %v", n, got)
	}
}
// TestNhopEntityNames_InvalidJSON checks that undecodable input yields nil.
func TestNhopEntityNames_InvalidJSON(t *testing.T) {
	if got := NhopEntityNames("not json"); got != nil {
		t.Error("expected nil for invalid JSON")
	}
}
// --- Mock engine ---
// mockKGEngine is a test double for engine.DocEngine. It embeds the interface
// and overrides only Search.
type mockKGEngine struct {
	engine.DocEngine
	// searchFunc, when set, handles Search calls.
	searchFunc func(ctx context.Context, req *types.SearchRequest) (*types.SearchResult, error)
}

// Search delegates to searchFunc when one is configured and otherwise returns
// an empty result with no error.
func (m *mockKGEngine) Search(ctx context.Context, req *types.SearchRequest) (*types.SearchResult, error) {
	if m.searchFunc == nil {
		return &types.SearchResult{}, nil
	}
	return m.searchFunc(ctx, req)
}
// --- Mock model driver for Embed tests ---
// fakeEmbedDriver is a test double for modelModule.ModelDriver that answers
// every Embed call with a fixed vector.
type fakeEmbedDriver struct {
	modelModule.ModelDriver
	name   string    // value reported by Name
	vector []float64 // embedding returned by Embed
}

// Name returns the configured driver name.
func (f *fakeEmbedDriver) Name() string {
	return f.name
}

// Embed ignores its arguments and returns a single embedding holding the
// configured vector.
func (f *fakeEmbedDriver) Embed(modelName *string, texts []string, apiConfig *modelModule.APIConfig, config *modelModule.EmbeddingConfig) ([]modelModule.EmbeddingData, error) {
	canned := modelModule.EmbeddingData{Embedding: f.vector}
	return []modelModule.EmbeddingData{canned}, nil
}
func TestBuildKGDenseExpr_WithModel(t *testing.T) {
embModel := &modelModule.EmbeddingModel{
ModelName: strPtr("test-model"),
ModelDriver: &fakeEmbedDriver{
name: "test",
vector: []float64{0.1, 0.2, 0.3},
},
APIConfig: &modelModule.APIConfig{},
}
dense, err := buildKGDenseExpr(embModel, "test question", 10)
if err != nil {
t.Fatalf("buildKGDenseExpr failed: %v", err)
}
if dense == nil {
t.Fatal("expected non-nil MatchDenseExpr")
}
if dense.VectorColumnName != "q_3_vec" {
t.Errorf("expected 'q_3_vec', got %q", dense.VectorColumnName)
}
if dense.TopN != 10 {
t.Errorf("expected TopN 10, got %d", dense.TopN)
}
}
// TestBuildKGDenseExpr_NilModel checks that a nil model yields neither an
// expression nor an error.
func TestBuildKGDenseExpr_NilModel(t *testing.T) {
	dense, err := buildKGDenseExpr(nil, "test", 10)
	if dense == nil && err == nil {
		return
	}
	t.Errorf("expected nil,nil for nil model, got dense=%v err=%v", dense, err)
}
func TestBuildKGDenseExpr_EmptyQuestion(t *testing.T) {
dense, err := buildKGDenseExpr(&modelModule.EmbeddingModel{}, "", 10)
if dense != nil || err != nil {
t.Errorf("expected nil,nil for empty question, got dense=%v err=%v", dense, err)
}
}
// --- Search integration with mock ---
func TestSearchKGEntities_WithMock(t *testing.T) {
mock := &mockKGEngine{
searchFunc: func(ctx context.Context, req *types.SearchRequest) (*types.SearchResult, error) {
if req.Filter["knowledge_graph_kwd"] != "entity" {
t.Error("expected entity filter")
}
return &types.SearchResult{
Chunks: []map[string]interface{}{
{"entity_kwd": "Elon Musk", "entity_type_kwd": "PERSON"},
},
}, nil
},
}
entities, err := SearchKGEntities(context.Background(), mock, []string{"kb1"}, "Elon", nil, 10)
if err != nil {
t.Fatalf("SearchKGEntities failed: %v", err)
}
if len(entities) != 1 || entities[0].Name != "Elon Musk" {
t.Errorf("expected [Elon Musk], got %v", entities)
}
}
func TestSearchKGEntitiesByTypes_WithMock(t *testing.T) {
mock := &mockKGEngine{
searchFunc: func(ctx context.Context, req *types.SearchRequest) (*types.SearchResult, error) {
return &types.SearchResult{
Chunks: []map[string]interface{}{
{"entity_kwd": "SpaceX", "entity_type_kwd": "ORGANIZATION"},
},
}, nil
},
}
entities, err := SearchKGEntitiesByTypes(context.Background(), mock, []string{"kb1"}, []string{"ORGANIZATION"}, 10)
if err != nil {
t.Fatalf("SearchKGEntitiesByTypes failed: %v", err)
}
if len(entities) != 1 || entities[0].Type != "ORGANIZATION" {
t.Errorf("expected ORGANIZATION, got %v", entities)
}
}
func TestSearchKGTypeSamples_WithMock(t *testing.T) {
mock := &mockKGEngine{}
samples, err := SearchKGTypeSamples(context.Background(), mock, []string{"kb1"})
if err != nil {
t.Fatalf("SearchKGTypeSamples failed: %v", err)
}
if samples == nil {
samples = map[string][]string{}
}
if len(samples) != 0 {
t.Errorf("expected empty, got %d", len(samples))
}
}