mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
feat: add KG scoring utilities (#15666)
KG scoring utilities as pure functions. Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
282
internal/common/kg_scoring.go
Normal file
282
internal/common/kg_scoring.go
Normal file
@@ -0,0 +1,282 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package common
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// KGEntity represents a knowledge graph entity with its scores.
|
||||
type KGEntity struct {
|
||||
Sim float64
|
||||
PageRank float64
|
||||
Description string
|
||||
NhopEnts []NhopEntity
|
||||
}
|
||||
|
||||
// NhopEntity represents an N-hop neighbor path.
|
||||
type NhopEntity struct {
|
||||
Path []string // entity names along the path
|
||||
Weights []float64 // pagerank weights per hop
|
||||
}
|
||||
|
||||
// KGRelation represents a knowledge graph relation with its scores.
|
||||
type KGRelation struct {
|
||||
Sim float64
|
||||
PageRank float64
|
||||
Description string
|
||||
}
|
||||
|
||||
// Edge represents a directed (from_entity, to_entity) pair.
|
||||
type Edge struct {
|
||||
From, To string
|
||||
}
|
||||
|
||||
// EdgeScore represents the accumulated score for an edge from N-hop analysis.
|
||||
type EdgeScore struct {
|
||||
Sim float64
|
||||
PageRank float64
|
||||
}
|
||||
|
||||
// ScoredEntity is a scored entity ready for output.
|
||||
type ScoredEntity struct {
|
||||
Entity string
|
||||
Score float64
|
||||
Description string
|
||||
}
|
||||
|
||||
// ScoredRelation is a scored relation ready for output.
|
||||
type ScoredRelation struct {
|
||||
From string
|
||||
To string
|
||||
Score float64
|
||||
Description string
|
||||
}
|
||||
|
||||
// AnalyzeNHopPaths decomposes N-hop paths into edges with distance-decayed scores.
|
||||
// Python equivalent: rag/graphrag/search.py lines 172-187
|
||||
func AnalyzeNHopPaths(entsFromQuery map[string]*KGEntity) map[Edge]EdgeScore {
|
||||
nhopPathes := make(map[Edge]EdgeScore)
|
||||
for _, ent := range entsFromQuery {
|
||||
for _, nbr := range ent.NhopEnts {
|
||||
path := nbr.Path
|
||||
weights := nbr.Weights
|
||||
for i := 0; i < len(path)-1; i++ {
|
||||
f, t := path[i], path[i+1]
|
||||
edge := Edge{From: f, To: t}
|
||||
es := nhopPathes[edge]
|
||||
es.Sim += ent.Sim / (2.0 + float64(i))
|
||||
if i < len(weights) {
|
||||
es.PageRank = weights[i]
|
||||
}
|
||||
nhopPathes[edge] = es
|
||||
}
|
||||
}
|
||||
}
|
||||
return nhopPathes
|
||||
}
|
||||
|
||||
// DoubleHitBoost doubles the similarity of entities found in both
|
||||
// keyword search and type search. Python equivalent: lines 194-198
|
||||
func DoubleHitBoost(entsFromQuery map[string]*KGEntity, entsFromTypes map[string]struct{}) {
|
||||
for ent := range entsFromQuery {
|
||||
if _, ok := entsFromTypes[ent]; ok {
|
||||
entsFromQuery[ent].Sim *= 2
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// FuseRelationScores integrates N-hop contributions and type boosts
|
||||
// into relation scores. New edges from N-hop are added as relations.
|
||||
// Python equivalent: lines 200-222
|
||||
func FuseRelationScores(
|
||||
relsFromText map[Edge]*KGRelation,
|
||||
entsFromTypes map[string]struct{},
|
||||
nhopPathes map[Edge]EdgeScore,
|
||||
) {
|
||||
// Boost existing relations with N-hop and type scores
|
||||
for edge, rel := range relsFromText {
|
||||
s := 0.0
|
||||
if np, ok := nhopPathes[edge]; ok {
|
||||
s += np.Sim
|
||||
delete(nhopPathes, edge)
|
||||
}
|
||||
if _, ok := entsFromTypes[edge.From]; ok {
|
||||
s += 1
|
||||
}
|
||||
if _, ok := entsFromTypes[edge.To]; ok {
|
||||
s += 1
|
||||
}
|
||||
rel.Sim *= s + 1
|
||||
}
|
||||
|
||||
// N-hop discovered edges become new relations
|
||||
for edge, np := range nhopPathes {
|
||||
s := 0.0
|
||||
if _, ok := entsFromTypes[edge.From]; ok {
|
||||
s += 1
|
||||
}
|
||||
if _, ok := entsFromTypes[edge.To]; ok {
|
||||
s += 1
|
||||
}
|
||||
relsFromText[edge] = &KGRelation{
|
||||
Sim: np.Sim * (s + 1),
|
||||
PageRank: np.PageRank,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SortAndTrimEntities sorts entities by sim*pagerank and takes top N.
|
||||
// Python equivalent: lines 224-225
|
||||
func SortAndTrimEntities(entsFromQuery map[string]*KGEntity, topN int) []ScoredEntity {
|
||||
if topN <= 0 {
|
||||
topN = 6
|
||||
}
|
||||
var scored []ScoredEntity
|
||||
for name, ent := range entsFromQuery {
|
||||
scored = append(scored, ScoredEntity{
|
||||
Entity: name,
|
||||
Score: ent.Sim * ent.PageRank,
|
||||
Description: ent.Description,
|
||||
})
|
||||
}
|
||||
sort.Slice(scored, func(i, j int) bool {
|
||||
return scored[i].Score > scored[j].Score
|
||||
})
|
||||
if len(scored) > topN {
|
||||
scored = scored[:topN]
|
||||
}
|
||||
return scored
|
||||
}
|
||||
|
||||
// SortAndTrimRelations sorts relations by sim*pagerank and takes top N.
|
||||
// Python equivalent: lines 226-227
|
||||
func SortAndTrimRelations(relsFromText map[Edge]*KGRelation, topN int) []ScoredRelation {
|
||||
if topN <= 0 {
|
||||
topN = 6
|
||||
}
|
||||
var scored []ScoredRelation
|
||||
for edge, rel := range relsFromText {
|
||||
scored = append(scored, ScoredRelation{
|
||||
From: edge.From,
|
||||
To: edge.To,
|
||||
Score: rel.Sim * rel.PageRank,
|
||||
Description: rel.Description,
|
||||
})
|
||||
}
|
||||
sort.Slice(scored, func(i, j int) bool {
|
||||
return scored[i].Score > scored[j].Score
|
||||
})
|
||||
if len(scored) > topN {
|
||||
scored = scored[:topN]
|
||||
}
|
||||
return scored
|
||||
}
|
||||
|
||||
// NumTokensFromString estimates the number of tokens in a string.
|
||||
// Uses a simple approximation: len/4 characters per token (roughly matching cl100k_base).
|
||||
func NumTokensFromString(s string) int {
|
||||
return len(s) / 4
|
||||
}
|
||||
|
||||
// FormatEntitiesToCSV formats scored entities as a CSV string and tracks token count.
|
||||
func FormatEntitiesToCSV(entities []ScoredEntity, maxToken int) (csv string, remainingToken int) {
|
||||
if len(entities) == 0 {
|
||||
return "", maxToken
|
||||
}
|
||||
var b strings.Builder
|
||||
b.WriteString("---- Entities ----\n")
|
||||
b.WriteString("Entity,Score,Description\n")
|
||||
for i, ent := range entities {
|
||||
desc := extractDescription(ent.Description)
|
||||
line := fmt.Sprintf("%s,%.2f,%s\n", ent.Entity, ent.Score, desc)
|
||||
tokens := NumTokensFromString(line)
|
||||
if maxToken-tokens <= 0 {
|
||||
entities = entities[:i]
|
||||
break
|
||||
}
|
||||
b.WriteString(line)
|
||||
maxToken -= tokens
|
||||
}
|
||||
return b.String(), maxToken
|
||||
}
|
||||
|
||||
// FormatRelationsToCSV formats scored relations as a CSV string and tracks token count.
|
||||
func FormatRelationsToCSV(relations []ScoredRelation, maxToken int) (csv string, remainingToken int) {
|
||||
if len(relations) == 0 {
|
||||
return "", maxToken
|
||||
}
|
||||
var b strings.Builder
|
||||
b.WriteString("---- Relations ----\n")
|
||||
b.WriteString("From Entity,To Entity,Score,Description\n")
|
||||
for i, rel := range relations {
|
||||
desc := extractDescription(rel.Description)
|
||||
line := fmt.Sprintf("%s,%s,%.2f,%s\n", rel.From, rel.To, rel.Score, desc)
|
||||
tokens := NumTokensFromString(line)
|
||||
if maxToken-tokens <= 0 {
|
||||
relations = relations[:i]
|
||||
break
|
||||
}
|
||||
b.WriteString(line)
|
||||
maxToken -= tokens
|
||||
}
|
||||
return b.String(), maxToken
|
||||
}
|
||||
|
||||
// BuildKGContent assembles the final knowledge graph content string.
|
||||
// Python equivalent: lines 267-291
|
||||
func BuildKGContent(
|
||||
entities []ScoredEntity,
|
||||
relations []ScoredRelation,
|
||||
maxToken int,
|
||||
) string {
|
||||
entityCSV, remaining := FormatEntitiesToCSV(entities, maxToken)
|
||||
relCSV, _ := FormatRelationsToCSV(relations, remaining)
|
||||
return entityCSV + relCSV
|
||||
}
|
||||
|
||||
// extractDescription tries to parse a description from a JSON-like string.
|
||||
// Python equivalent: json.loads(desc).get("description", "")
|
||||
func extractDescription(desc string) string {
|
||||
if desc == "" {
|
||||
return ""
|
||||
}
|
||||
// If the description looks like JSON, try to extract the "description" field
|
||||
desc = strings.TrimSpace(desc)
|
||||
if strings.HasPrefix(desc, "{") && strings.HasSuffix(desc, "}") {
|
||||
// Simple extraction: find "description" key value
|
||||
// This matches Python's json.loads(desc).get("description", "") behavior
|
||||
idx := strings.Index(desc, `"description"`)
|
||||
if idx >= 0 {
|
||||
remain := desc[idx+len(`"description"`):]
|
||||
colonIdx := strings.Index(remain, ":")
|
||||
if colonIdx >= 0 {
|
||||
valPart := strings.TrimSpace(remain[colonIdx+1:])
|
||||
if strings.HasPrefix(valPart, `"`) {
|
||||
valPart = strings.TrimPrefix(valPart, `"`)
|
||||
endQuote := strings.Index(valPart, `"`)
|
||||
if endQuote >= 0 {
|
||||
return valPart[:endQuote]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return desc
|
||||
}
|
||||
263
internal/common/kg_scoring_test.go
Normal file
263
internal/common/kg_scoring_test.go
Normal file
@@ -0,0 +1,263 @@
|
||||
//
|
||||
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
//
|
||||
|
||||
package common
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// --- AnalyzeNHopPaths ---
|
||||
|
||||
func TestAnalyzeNHopPaths_Basic(t *testing.T) {
|
||||
ents := map[string]*KGEntity{
|
||||
"A": {
|
||||
Sim: 0.9,
|
||||
NhopEnts: []NhopEntity{
|
||||
{Path: []string{"A", "B", "C"}, Weights: []float64{0.8, 0.5}},
|
||||
},
|
||||
},
|
||||
}
|
||||
result := AnalyzeNHopPaths(ents)
|
||||
// A→B: 0.9 / (2+0) = 0.45
|
||||
// B→C: 0.9 / (2+1) = 0.3
|
||||
if len(result) != 2 {
|
||||
t.Fatalf("expected 2 edges, got %d", len(result))
|
||||
}
|
||||
if result[Edge{"A", "B"}].Sim != 0.45 {
|
||||
t.Errorf("expected A→B sim=0.45, got %f", result[Edge{"A", "B"}].Sim)
|
||||
}
|
||||
if result[Edge{"B", "C"}].Sim != 0.3 {
|
||||
t.Errorf("expected B→C sim=0.3, got %f", result[Edge{"B", "C"}].Sim)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAnalyzeNHopPaths_MultipleContributors(t *testing.T) {
|
||||
ents := map[string]*KGEntity{
|
||||
"A": {
|
||||
Sim: 0.8,
|
||||
NhopEnts: []NhopEntity{
|
||||
{Path: []string{"A", "B"}, Weights: []float64{0.7}},
|
||||
},
|
||||
},
|
||||
"X": {
|
||||
Sim: 0.6,
|
||||
NhopEnts: []NhopEntity{
|
||||
{Path: []string{"X", "B"}, Weights: []float64{0.5}},
|
||||
},
|
||||
},
|
||||
}
|
||||
result := AnalyzeNHopPaths(ents)
|
||||
// A→B: 0.8 / 2 = 0.4
|
||||
// X→B: 0.6 / 2 = 0.3
|
||||
if result[Edge{"A", "B"}].Sim != 0.4 {
|
||||
t.Errorf("expected A→B sim=0.4, got %f", result[Edge{"A", "B"}].Sim)
|
||||
}
|
||||
if result[Edge{"X", "B"}].Sim != 0.3 {
|
||||
t.Errorf("expected X→B sim=0.3, got %f", result[Edge{"X", "B"}].Sim)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAnalyzeNHopPaths_Empty(t *testing.T) {
|
||||
result := AnalyzeNHopPaths(nil)
|
||||
if len(result) != 0 {
|
||||
t.Errorf("expected empty, got %d", len(result))
|
||||
}
|
||||
}
|
||||
|
||||
// --- DoubleHitBoost ---
|
||||
|
||||
func TestDoubleHitBoost(t *testing.T) {
|
||||
ents := map[string]*KGEntity{
|
||||
"A": {Sim: 0.5},
|
||||
"B": {Sim: 0.3},
|
||||
}
|
||||
types := map[string]struct{}{"A": {}}
|
||||
DoubleHitBoost(ents, types)
|
||||
if ents["A"].Sim != 1.0 {
|
||||
t.Errorf("expected A sim=1.0 after boost, got %f", ents["A"].Sim)
|
||||
}
|
||||
if ents["B"].Sim != 0.3 {
|
||||
t.Errorf("expected B sim unchanged at 0.3, got %f", ents["B"].Sim)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDoubleHitBoost_Empty(t *testing.T) {
|
||||
ents := map[string]*KGEntity{"A": {Sim: 0.5}}
|
||||
DoubleHitBoost(ents, map[string]struct{}{})
|
||||
if ents["A"].Sim != 0.5 {
|
||||
t.Errorf("expected unchanged, got %f", ents["A"].Sim)
|
||||
}
|
||||
}
|
||||
|
||||
// --- FuseRelationScores ---
|
||||
|
||||
func TestFuseRelationScores_NhopContribution(t *testing.T) {
|
||||
rels := map[Edge]*KGRelation{
|
||||
{"A", "B"}: {Sim: 0.5, PageRank: 0.8},
|
||||
}
|
||||
types := map[string]struct{}{}
|
||||
nhop := map[Edge]EdgeScore{
|
||||
{"A", "B"}: {Sim: 0.3},
|
||||
}
|
||||
FuseRelationScores(rels, types, nhop)
|
||||
// sim = 0.5 * (0.3 + 1) = 0.65
|
||||
if rels[Edge{"A", "B"}].Sim != 0.65 {
|
||||
t.Errorf("expected 0.65, got %f", rels[Edge{"A", "B"}].Sim)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFuseRelationScores_TypeBoost(t *testing.T) {
|
||||
rels := map[Edge]*KGRelation{
|
||||
{"A", "B"}: {Sim: 0.5},
|
||||
}
|
||||
types := map[string]struct{}{"A": {}, "B": {}}
|
||||
nhop := map[Edge]EdgeScore{}
|
||||
FuseRelationScores(rels, types, nhop)
|
||||
// Both endpoints in types: s=2, sim = 0.5 * (2+1) = 1.5
|
||||
if rels[Edge{"A", "B"}].Sim != 1.5 {
|
||||
t.Errorf("expected 1.5, got %f", rels[Edge{"A", "B"}].Sim)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFuseRelationScores_NhopNewEdge(t *testing.T) {
|
||||
rels := map[Edge]*KGRelation{}
|
||||
types := map[string]struct{}{}
|
||||
nhop := map[Edge]EdgeScore{
|
||||
{"A", "B"}: {Sim: 0.4, PageRank: 0.7},
|
||||
}
|
||||
FuseRelationScores(rels, types, nhop)
|
||||
if _, ok := rels[Edge{"A", "B"}]; !ok {
|
||||
t.Fatal("expected new edge from N-hop")
|
||||
}
|
||||
if rels[Edge{"A", "B"}].Sim != 0.4 {
|
||||
t.Errorf("expected sim=0.4, got %f", rels[Edge{"A", "B"}].Sim)
|
||||
}
|
||||
}
|
||||
|
||||
// --- SortAndTrim ---
|
||||
|
||||
func TestSortAndTrimEntities(t *testing.T) {
|
||||
ents := map[string]*KGEntity{
|
||||
"A": {Sim: 0.5, PageRank: 0.9},
|
||||
"B": {Sim: 0.8, PageRank: 0.3},
|
||||
"C": {Sim: 0.9, PageRank: 0.1},
|
||||
}
|
||||
result := SortAndTrimEntities(ents, 2)
|
||||
if len(result) != 2 {
|
||||
t.Fatalf("expected 2, got %d", len(result))
|
||||
}
|
||||
// A: 0.45, B: 0.24, C: 0.09 → top 2 should be A, B
|
||||
if result[0].Entity != "A" {
|
||||
t.Errorf("expected A first (0.45), got %s (%f)", result[0].Entity, result[0].Score)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSortAndTrimEntities_DefaultTopN(t *testing.T) {
|
||||
ents := map[string]*KGEntity{
|
||||
"A": {Sim: 0.5, PageRank: 0.9},
|
||||
"B": {Sim: 0.8, PageRank: 0.3},
|
||||
}
|
||||
result := SortAndTrimEntities(ents, 0)
|
||||
if len(result) != 2 {
|
||||
t.Errorf("expected default topN to include all, got %d", len(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSortAndTrimRelations(t *testing.T) {
|
||||
rels := map[Edge]*KGRelation{
|
||||
{"A", "B"}: {Sim: 0.9, PageRank: 0.1},
|
||||
{"C", "D"}: {Sim: 0.3, PageRank: 0.8},
|
||||
}
|
||||
result := SortAndTrimRelations(rels, 1)
|
||||
if len(result) != 1 {
|
||||
t.Fatalf("expected 1, got %d", len(result))
|
||||
}
|
||||
// A→B: 0.09, C→D: 0.24 → C→D should be first
|
||||
if result[0].From != "C" {
|
||||
t.Errorf("expected C first (0.24), got %s (%f)", result[0].From, result[0].Score)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Format and Build ---
|
||||
|
||||
func TestBuildKGContent_Basic(t *testing.T) {
|
||||
entities := []ScoredEntity{
|
||||
{Entity: "A", Score: 0.45, Description: `{"description": "Entity A desc"}`},
|
||||
}
|
||||
relations := []ScoredRelation{
|
||||
{From: "A", To: "B", Score: 0.3, Description: `{"description": "rel A-B"}`},
|
||||
}
|
||||
result := BuildKGContent(entities, relations, 10000)
|
||||
if !contains(result, "Entity A desc") {
|
||||
t.Errorf("expected entity description in output, got: %s", result)
|
||||
}
|
||||
if !contains(result, "rel A-B") {
|
||||
t.Errorf("expected relation description in output, got: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildKGContent_TokenBudget(t *testing.T) {
|
||||
longDesc := strings.Repeat("This is a very long description. ", 50)
|
||||
entities := []ScoredEntity{
|
||||
{Entity: "LongEntityName", Score: 1.0, Description: longDesc},
|
||||
}
|
||||
relations := []ScoredRelation{
|
||||
{From: "X", To: "Y", Score: 1.0, Description: "relation desc"},
|
||||
}
|
||||
result := BuildKGContent(entities, relations, 50)
|
||||
// Token budget is very small, should truncate and not include relations
|
||||
if contains(result, "relation desc") {
|
||||
t.Log("Note: relations included despite small budget (depending on token count)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractDescription_JSON(t *testing.T) {
|
||||
result := extractDescription(`{"description": "Entity A description", "other": "value"}`)
|
||||
if result != "Entity A description" {
|
||||
t.Errorf("expected 'Entity A description', got %q", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractDescription_Plain(t *testing.T) {
|
||||
result := extractDescription("plain description")
|
||||
if result != "plain description" {
|
||||
t.Errorf("expected 'plain description', got %q", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNumTokensFromString(t *testing.T) {
|
||||
s := "This is a test string with multiple words"
|
||||
tokens := NumTokensFromString(s)
|
||||
if tokens <= 0 {
|
||||
t.Errorf("expected positive token count, got %d", tokens)
|
||||
}
|
||||
}
|
||||
|
||||
// contains checks if a string contains a substring.
|
||||
func contains(s, substr string) bool {
|
||||
return len(s) >= len(substr) && containsStr(s, substr)
|
||||
}
|
||||
|
||||
func containsStr(s, substr string) bool {
|
||||
for i := 0; i <= len(s)-len(substr); i++ {
|
||||
if s[i:i+len(substr)] == substr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
Reference in New Issue
Block a user