diff --git a/internal/common/kg_scoring.go b/internal/common/kg_scoring.go new file mode 100644 index 0000000000..977ad32644 --- /dev/null +++ b/internal/common/kg_scoring.go @@ -0,0 +1,282 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package common + +import ( + "fmt" + "sort" + "strings" +) + +// KGEntity represents a knowledge graph entity with its scores. +type KGEntity struct { + Sim float64 + PageRank float64 + Description string + NhopEnts []NhopEntity +} + +// NhopEntity represents an N-hop neighbor path. +type NhopEntity struct { + Path []string // entity names along the path + Weights []float64 // pagerank weights per hop +} + +// KGRelation represents a knowledge graph relation with its scores. +type KGRelation struct { + Sim float64 + PageRank float64 + Description string +} + +// Edge represents a directed (from_entity, to_entity) pair. +type Edge struct { + From, To string +} + +// EdgeScore represents the accumulated score for an edge from N-hop analysis. +type EdgeScore struct { + Sim float64 + PageRank float64 +} + +// ScoredEntity is a scored entity ready for output. +type ScoredEntity struct { + Entity string + Score float64 + Description string +} + +// ScoredRelation is a scored relation ready for output. +type ScoredRelation struct { + From string + To string + Score float64 + Description string +} + +// AnalyzeNHopPaths decomposes N-hop paths into edges with distance-decayed scores. +// Python equivalent: rag/graphrag/search.py lines 172-187 +func AnalyzeNHopPaths(entsFromQuery map[string]*KGEntity) map[Edge]EdgeScore { + nhopPathes := make(map[Edge]EdgeScore) + for _, ent := range entsFromQuery { + for _, nbr := range ent.NhopEnts { + path := nbr.Path + weights := nbr.Weights + for i := 0; i < len(path)-1; i++ { + f, t := path[i], path[i+1] + edge := Edge{From: f, To: t} + es := nhopPathes[edge] + es.Sim += ent.Sim / (2.0 + float64(i)) + if i < len(weights) { + es.PageRank = weights[i] + } + nhopPathes[edge] = es + } + } + } + return nhopPathes +} + +// DoubleHitBoost doubles the similarity of entities found in both +// keyword search and type search. Python equivalent: lines 194-198 +func DoubleHitBoost(entsFromQuery map[string]*KGEntity, entsFromTypes map[string]struct{}) { + for ent := range entsFromQuery { + if _, ok := entsFromTypes[ent]; ok { + entsFromQuery[ent].Sim *= 2 + } + } +} + +// FuseRelationScores integrates N-hop contributions and type boosts +// into relation scores. New edges from N-hop are added as relations. +// Python equivalent: lines 200-222 +func FuseRelationScores( + relsFromText map[Edge]*KGRelation, + entsFromTypes map[string]struct{}, + nhopPathes map[Edge]EdgeScore, +) { + // Boost existing relations with N-hop and type scores + for edge, rel := range relsFromText { + s := 0.0 + if np, ok := nhopPathes[edge]; ok { + s += np.Sim + delete(nhopPathes, edge) + } + if _, ok := entsFromTypes[edge.From]; ok { + s += 1 + } + if _, ok := entsFromTypes[edge.To]; ok { + s += 1 + } + rel.Sim *= s + 1 + } + + // N-hop discovered edges become new relations + for edge, np := range nhopPathes { + s := 0.0 + if _, ok := entsFromTypes[edge.From]; ok { + s += 1 + } + if _, ok := entsFromTypes[edge.To]; ok { + s += 1 + } + relsFromText[edge] = &KGRelation{ + Sim: np.Sim * (s + 1), + PageRank: np.PageRank, + } + } +} + +// SortAndTrimEntities sorts entities by sim*pagerank and takes top N. +// Python equivalent: lines 224-225 +func SortAndTrimEntities(entsFromQuery map[string]*KGEntity, topN int) []ScoredEntity { + if topN <= 0 { + topN = 6 + } + var scored []ScoredEntity + for name, ent := range entsFromQuery { + scored = append(scored, ScoredEntity{ + Entity: name, + Score: ent.Sim * ent.PageRank, + Description: ent.Description, + }) + } + sort.Slice(scored, func(i, j int) bool { + return scored[i].Score > scored[j].Score + }) + if len(scored) > topN { + scored = scored[:topN] + } + return scored +} + +// SortAndTrimRelations sorts relations by sim*pagerank and takes top N. +// Python equivalent: lines 226-227 +func SortAndTrimRelations(relsFromText map[Edge]*KGRelation, topN int) []ScoredRelation { + if topN <= 0 { + topN = 6 + } + var scored []ScoredRelation + for edge, rel := range relsFromText { + scored = append(scored, ScoredRelation{ + From: edge.From, + To: edge.To, + Score: rel.Sim * rel.PageRank, + Description: rel.Description, + }) + } + sort.Slice(scored, func(i, j int) bool { + return scored[i].Score > scored[j].Score + }) + if len(scored) > topN { + scored = scored[:topN] + } + return scored +} + +// NumTokensFromString estimates the number of tokens in a string. +// Uses a simple approximation: len/4 characters per token (roughly matching cl100k_base). +func NumTokensFromString(s string) int { + return len(s) / 4 +} + +// FormatEntitiesToCSV formats scored entities as a CSV string and tracks token count. +func FormatEntitiesToCSV(entities []ScoredEntity, maxToken int) (csv string, remainingToken int) { + if len(entities) == 0 { + return "", maxToken + } + var b strings.Builder + b.WriteString("---- Entities ----\n") + b.WriteString("Entity,Score,Description\n") + for i, ent := range entities { + desc := extractDescription(ent.Description) + line := fmt.Sprintf("%s,%.2f,%s\n", ent.Entity, ent.Score, desc) + tokens := NumTokensFromString(line) + if maxToken-tokens <= 0 { + entities = entities[:i] + break + } + b.WriteString(line) + maxToken -= tokens + } + return b.String(), maxToken +} + +// FormatRelationsToCSV formats scored relations as a CSV string and tracks token count. +func FormatRelationsToCSV(relations []ScoredRelation, maxToken int) (csv string, remainingToken int) { + if len(relations) == 0 { + return "", maxToken + } + var b strings.Builder + b.WriteString("---- Relations ----\n") + b.WriteString("From Entity,To Entity,Score,Description\n") + for i, rel := range relations { + desc := extractDescription(rel.Description) + line := fmt.Sprintf("%s,%s,%.2f,%s\n", rel.From, rel.To, rel.Score, desc) + tokens := NumTokensFromString(line) + if maxToken-tokens <= 0 { + relations = relations[:i] + break + } + b.WriteString(line) + maxToken -= tokens + } + return b.String(), maxToken +} + +// BuildKGContent assembles the final knowledge graph content string. +// Python equivalent: lines 267-291 +func BuildKGContent( + entities []ScoredEntity, + relations []ScoredRelation, + maxToken int, +) string { + entityCSV, remaining := FormatEntitiesToCSV(entities, maxToken) + relCSV, _ := FormatRelationsToCSV(relations, remaining) + return entityCSV + relCSV +} + +// extractDescription tries to parse a description from a JSON-like string. +// Python equivalent: json.loads(desc).get("description", "") +func extractDescription(desc string) string { + if desc == "" { + return "" + } + // If the description looks like JSON, try to extract the "description" field + desc = strings.TrimSpace(desc) + if strings.HasPrefix(desc, "{") && strings.HasSuffix(desc, "}") { + // Simple extraction: find "description" key value + // This matches Python's json.loads(desc).get("description", "") behavior + idx := strings.Index(desc, `"description"`) + if idx >= 0 { + remain := desc[idx+len(`"description"`):] + colonIdx := strings.Index(remain, ":") + if colonIdx >= 0 { + valPart := strings.TrimSpace(remain[colonIdx+1:]) + if strings.HasPrefix(valPart, `"`) { + valPart = strings.TrimPrefix(valPart, `"`) + endQuote := strings.Index(valPart, `"`) + if endQuote >= 0 { + return valPart[:endQuote] + } + } + } + } + } + return desc +} diff --git a/internal/common/kg_scoring_test.go b/internal/common/kg_scoring_test.go new file mode 100644 index 0000000000..12da289004 --- /dev/null +++ b/internal/common/kg_scoring_test.go @@ -0,0 +1,263 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package common + +import ( + "strings" + "testing" +) + +// --- AnalyzeNHopPaths --- + +func TestAnalyzeNHopPaths_Basic(t *testing.T) { + ents := map[string]*KGEntity{ + "A": { + Sim: 0.9, + NhopEnts: []NhopEntity{ + {Path: []string{"A", "B", "C"}, Weights: []float64{0.8, 0.5}}, + }, + }, + } + result := AnalyzeNHopPaths(ents) + // A→B: 0.9 / (2+0) = 0.45 + // B→C: 0.9 / (2+1) = 0.3 + if len(result) != 2 { + t.Fatalf("expected 2 edges, got %d", len(result)) + } + if result[Edge{"A", "B"}].Sim != 0.45 { + t.Errorf("expected A→B sim=0.45, got %f", result[Edge{"A", "B"}].Sim) + } + if result[Edge{"B", "C"}].Sim != 0.3 { + t.Errorf("expected B→C sim=0.3, got %f", result[Edge{"B", "C"}].Sim) + } +} + +func TestAnalyzeNHopPaths_MultipleContributors(t *testing.T) { + ents := map[string]*KGEntity{ + "A": { + Sim: 0.8, + NhopEnts: []NhopEntity{ + {Path: []string{"A", "B"}, Weights: []float64{0.7}}, + }, + }, + "X": { + Sim: 0.6, + NhopEnts: []NhopEntity{ + {Path: []string{"X", "B"}, Weights: []float64{0.5}}, + }, + }, + } + result := AnalyzeNHopPaths(ents) + // A→B: 0.8 / 2 = 0.4 + // X→B: 0.6 / 2 = 0.3 + if result[Edge{"A", "B"}].Sim != 0.4 { + t.Errorf("expected A→B sim=0.4, got %f", result[Edge{"A", "B"}].Sim) + } + if result[Edge{"X", "B"}].Sim != 0.3 { + t.Errorf("expected X→B sim=0.3, got %f", result[Edge{"X", "B"}].Sim) + } +} + +func TestAnalyzeNHopPaths_Empty(t *testing.T) { + result := AnalyzeNHopPaths(nil) + if len(result) != 0 { + t.Errorf("expected empty, got %d", len(result)) + } +} + +// --- DoubleHitBoost --- + +func TestDoubleHitBoost(t *testing.T) { + ents := map[string]*KGEntity{ + "A": {Sim: 0.5}, + "B": {Sim: 0.3}, + } + types := map[string]struct{}{"A": {}} + DoubleHitBoost(ents, types) + if ents["A"].Sim != 1.0 { + t.Errorf("expected A sim=1.0 after boost, got %f", ents["A"].Sim) + } + if ents["B"].Sim != 0.3 { + t.Errorf("expected B sim unchanged at 0.3, got %f", ents["B"].Sim) + } +} + +func TestDoubleHitBoost_Empty(t *testing.T) { + ents := map[string]*KGEntity{"A": {Sim: 0.5}} + DoubleHitBoost(ents, map[string]struct{}{}) + if ents["A"].Sim != 0.5 { + t.Errorf("expected unchanged, got %f", ents["A"].Sim) + } +} + +// --- FuseRelationScores --- + +func TestFuseRelationScores_NhopContribution(t *testing.T) { + rels := map[Edge]*KGRelation{ + {"A", "B"}: {Sim: 0.5, PageRank: 0.8}, + } + types := map[string]struct{}{} + nhop := map[Edge]EdgeScore{ + {"A", "B"}: {Sim: 0.3}, + } + FuseRelationScores(rels, types, nhop) + // sim = 0.5 * (0.3 + 1) = 0.65 + if rels[Edge{"A", "B"}].Sim != 0.65 { + t.Errorf("expected 0.65, got %f", rels[Edge{"A", "B"}].Sim) + } +} + +func TestFuseRelationScores_TypeBoost(t *testing.T) { + rels := map[Edge]*KGRelation{ + {"A", "B"}: {Sim: 0.5}, + } + types := map[string]struct{}{"A": {}, "B": {}} + nhop := map[Edge]EdgeScore{} + FuseRelationScores(rels, types, nhop) + // Both endpoints in types: s=2, sim = 0.5 * (2+1) = 1.5 + if rels[Edge{"A", "B"}].Sim != 1.5 { + t.Errorf("expected 1.5, got %f", rels[Edge{"A", "B"}].Sim) + } +} + +func TestFuseRelationScores_NhopNewEdge(t *testing.T) { + rels := map[Edge]*KGRelation{} + types := map[string]struct{}{} + nhop := map[Edge]EdgeScore{ + {"A", "B"}: {Sim: 0.4, PageRank: 0.7}, + } + FuseRelationScores(rels, types, nhop) + if _, ok := rels[Edge{"A", "B"}]; !ok { + t.Fatal("expected new edge from N-hop") + } + if rels[Edge{"A", "B"}].Sim != 0.4 { + t.Errorf("expected sim=0.4, got %f", rels[Edge{"A", "B"}].Sim) + } +} + +// --- SortAndTrim --- + +func TestSortAndTrimEntities(t *testing.T) { + ents := map[string]*KGEntity{ + "A": {Sim: 0.5, PageRank: 0.9}, + "B": {Sim: 0.8, PageRank: 0.3}, + "C": {Sim: 0.9, PageRank: 0.1}, + } + result := SortAndTrimEntities(ents, 2) + if len(result) != 2 { + t.Fatalf("expected 2, got %d", len(result)) + } + // A: 0.45, B: 0.24, C: 0.09 → top 2 should be A, B + if result[0].Entity != "A" { + t.Errorf("expected A first (0.45), got %s (%f)", result[0].Entity, result[0].Score) + } +} + +func TestSortAndTrimEntities_DefaultTopN(t *testing.T) { + ents := map[string]*KGEntity{ + "A": {Sim: 0.5, PageRank: 0.9}, + "B": {Sim: 0.8, PageRank: 0.3}, + } + result := SortAndTrimEntities(ents, 0) + if len(result) != 2 { + t.Errorf("expected default topN to include all, got %d", len(result)) + } +} + +func TestSortAndTrimRelations(t *testing.T) { + rels := map[Edge]*KGRelation{ + {"A", "B"}: {Sim: 0.9, PageRank: 0.1}, + {"C", "D"}: {Sim: 0.3, PageRank: 0.8}, + } + result := SortAndTrimRelations(rels, 1) + if len(result) != 1 { + t.Fatalf("expected 1, got %d", len(result)) + } + // A→B: 0.09, C→D: 0.24 → C→D should be first + if result[0].From != "C" { + t.Errorf("expected C first (0.24), got %s (%f)", result[0].From, result[0].Score) + } +} + +// --- Format and Build --- + +func TestBuildKGContent_Basic(t *testing.T) { + entities := []ScoredEntity{ + {Entity: "A", Score: 0.45, Description: `{"description": "Entity A desc"}`}, + } + relations := []ScoredRelation{ + {From: "A", To: "B", Score: 0.3, Description: `{"description": "rel A-B"}`}, + } + result := BuildKGContent(entities, relations, 10000) + if !contains(result, "Entity A desc") { + t.Errorf("expected entity description in output, got: %s", result) + } + if !contains(result, "rel A-B") { + t.Errorf("expected relation description in output, got: %s", result) + } +} + +func TestBuildKGContent_TokenBudget(t *testing.T) { + longDesc := strings.Repeat("This is a very long description. ", 50) + entities := []ScoredEntity{ + {Entity: "LongEntityName", Score: 1.0, Description: longDesc}, + } + relations := []ScoredRelation{ + {From: "X", To: "Y", Score: 1.0, Description: "relation desc"}, + } + result := BuildKGContent(entities, relations, 50) + // Token budget is very small, should truncate and not include relations + if contains(result, "relation desc") { + t.Log("Note: relations included despite small budget (depending on token count)") + } +} + +func TestExtractDescription_JSON(t *testing.T) { + result := extractDescription(`{"description": "Entity A description", "other": "value"}`) + if result != "Entity A description" { + t.Errorf("expected 'Entity A description', got %q", result) + } +} + +func TestExtractDescription_Plain(t *testing.T) { + result := extractDescription("plain description") + if result != "plain description" { + t.Errorf("expected 'plain description', got %q", result) + } +} + +func TestNumTokensFromString(t *testing.T) { + s := "This is a test string with multiple words" + tokens := NumTokensFromString(s) + if tokens <= 0 { + t.Errorf("expected positive token count, got %d", tokens) + } +} + +// contains checks if a string contains a substring. +func contains(s, substr string) bool { + return len(s) >= len(substr) && containsStr(s, substr) +} + +func containsStr(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +}