mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
### Summary Closes #15795 Knowledge-graph queries rank entities by `pagerank * sim` in `KGSearch`, but the entity chunks written at index time stopped carrying the values that ranking depends on. `graph_node_to_chunk` only stored `entity_type`, `description`, and `source_id`, dropping the node `pagerank` and the n-hop neighbour paths, while `search.py` still read them back as `rank_flt` and `n_hop_with_weight`. The producer of these fields, `update_nodes_pagerank_nhop_neighbour`, was removed in #6513, but the read side in `KGSearch` was never updated. The result is that on every knowledge-graph query: - `pagerank` resolves to `0`, so the `pagerank * sim` sort key is `0` for every entity and selection falls back to arbitrary order. - Every displayed entity score is `0.00`. - The n-hop relation-enrichment block is dead code because `n_hop_ents` is always empty, leaving `merge_tuples` and `is_continuous_subsequence` orphaned. This PR restores the missing index-time fields so the documented `P(E|Q) = pagerank * sim` ranking and the n-hop enrichment work again. What changed: - `graph_node_to_chunk` now writes `rank_flt` from the node pagerank and `n_hop_with_weight` from the recomputed n-hop neighbour paths. - Reintroduced the n-hop path computation (`n_neighbor`) in `rag/graphrag/utils.py`, reusing the previously orphaned `merge_tuples` / `is_continuous_subsequence` helpers, with a direction-agnostic edge-weight lookup for undirected graphs. `set_graph` computes the paths per added or updated node and passes them through. - `KGSearch` now selects `n_hop_with_weight` in the entity keyword search so Infinity and OceanBase return it (Elasticsearch and OpenSearch already read it from `_source`), and the read is hardened against missing keys or empty strings before `json.loads`. - Added the `n_hop_with_weight` column to OceanBase, including the `EXTRA_COLUMNS` migration entry so existing tables get it. The other engines already map both fields via dynamic templates or the Infinity mapping. Scope note: pagerank and n-hop are re-indexed for the added or updated nodes in each pass, consistent with the existing incremental indexing design. ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) ### Testing Added unit tests in `test/unit_test/rag/graphrag/test_graphrag_utils.py`: - `n_neighbor`: path and weight shape, one-hop vs two-hop, isolated nodes, missing weights, and direction-agnostic lookup. - `graph_node_to_chunk`: `rank_flt` populated from pagerank and defaulting to `0`, `n_hop_with_weight` serialized and defaulting to an empty list. ``` uv run pytest test/unit_test/rag/graphrag/ # 106 passed uv run ruff check rag/graphrag/ rag/utils/ob_conn.py ```
641 lines
22 KiB
Python
641 lines
22 KiB
Python
#
|
|
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
|
|
import json
|
|
from types import SimpleNamespace
|
|
|
|
import networkx as nx
|
|
import numpy as np
|
|
import pytest
|
|
|
|
import rag.graphrag.utils as graphrag_utils
|
|
from rag.graphrag.utils import (
|
|
GRAPH_FIELD_SEP,
|
|
GraphChange,
|
|
clean_str,
|
|
compute_args_hash,
|
|
dict_has_keys_with_types,
|
|
flat_uniq_list,
|
|
get_from_to,
|
|
graph_merge,
|
|
handle_single_entity_extraction,
|
|
handle_single_relationship_extraction,
|
|
is_continuous_subsequence,
|
|
is_float_regex,
|
|
merge_tuples,
|
|
n_neighbor,
|
|
pack_user_ass_to_openai_messages,
|
|
perform_variable_replacements,
|
|
split_string_by_multi_markers,
|
|
tidy_graph,
|
|
)
|
|
|
|
|
|
class TestCleanStr:
|
|
"""Tests for clean_str function."""
|
|
|
|
def test_basic_string(self):
|
|
assert clean_str("hello world") == "hello world"
|
|
|
|
def test_strips_whitespace(self):
|
|
assert clean_str(" hello ") == "hello"
|
|
|
|
def test_removes_html_escapes(self):
|
|
assert clean_str("& < >") == "& < >"
|
|
|
|
def test_removes_control_characters(self):
|
|
assert clean_str("hello\x00world") == "helloworld"
|
|
assert clean_str("test\x1f") == "test"
|
|
assert clean_str("\x7fdata") == "data"
|
|
|
|
def test_removes_double_quotes(self):
|
|
assert clean_str('"quoted"') == "quoted"
|
|
|
|
def test_non_string_passthrough(self):
|
|
assert clean_str(123) == 123
|
|
assert clean_str(None) is None
|
|
assert clean_str([1, 2]) == [1, 2]
|
|
|
|
def test_empty_string(self):
|
|
assert clean_str("") == ""
|
|
|
|
def test_combined_html_and_control(self):
|
|
assert clean_str(" &\x00test\x1f ") == "&test"
|
|
|
|
|
|
class TestDictHasKeysWithTypes:
|
|
"""Tests for dict_has_keys_with_types function."""
|
|
|
|
def test_matching_keys_and_types(self):
|
|
data = {"name": "Alice", "age": 30}
|
|
assert dict_has_keys_with_types(data, [("name", str), ("age", int)]) is True
|
|
|
|
def test_missing_key(self):
|
|
data = {"name": "Alice"}
|
|
assert dict_has_keys_with_types(data, [("name", str), ("age", int)]) is False
|
|
|
|
def test_wrong_type(self):
|
|
data = {"name": "Alice", "age": "thirty"}
|
|
assert dict_has_keys_with_types(data, [("name", str), ("age", int)]) is False
|
|
|
|
def test_empty_expected_fields(self):
|
|
assert dict_has_keys_with_types({"a": 1}, []) is True
|
|
|
|
def test_empty_data(self):
|
|
assert dict_has_keys_with_types({}, [("key", str)]) is False
|
|
|
|
def test_subclass_type_match(self):
|
|
assert dict_has_keys_with_types({"val": True}, [("val", int)]) is True
|
|
|
|
|
|
class TestPerformVariableReplacements:
|
|
"""Tests for perform_variable_replacements function."""
|
|
|
|
def test_simple_replacement(self):
|
|
result = perform_variable_replacements("Hello {name}!", variables={"name": "World"})
|
|
assert result == "Hello World!"
|
|
|
|
def test_multiple_replacements(self):
|
|
result = perform_variable_replacements(
|
|
"{greeting} {name}!",
|
|
variables={"greeting": "Hi", "name": "Alice"},
|
|
)
|
|
assert result == "Hi Alice!"
|
|
|
|
def test_no_variables(self):
|
|
result = perform_variable_replacements("No vars here")
|
|
assert result == "No vars here"
|
|
|
|
def test_empty_variables_dict(self):
|
|
result = perform_variable_replacements("{keep}", variables={})
|
|
assert result == "{keep}"
|
|
|
|
def test_history_system_message_replacement(self):
|
|
history = [
|
|
{"role": "system", "content": "You are {role}"},
|
|
{"role": "user", "content": "Hello {role}"},
|
|
]
|
|
perform_variable_replacements("input", history=history, variables={"role": "assistant"})
|
|
assert history[0]["content"] == "You are assistant"
|
|
assert history[1]["content"] == "Hello {role}"
|
|
|
|
def test_none_defaults(self):
|
|
result = perform_variable_replacements("text")
|
|
assert result == "text"
|
|
|
|
def test_non_string_variable_value(self):
|
|
result = perform_variable_replacements("count: {n}", variables={"n": 42})
|
|
assert result == "count: 42"
|
|
|
|
|
|
class TestGetFromTo:
|
|
"""Tests for get_from_to function."""
|
|
|
|
def test_ordered_pair(self):
|
|
assert get_from_to("A", "B") == ("A", "B")
|
|
|
|
def test_reversed_pair(self):
|
|
assert get_from_to("B", "A") == ("A", "B")
|
|
|
|
def test_equal_values(self):
|
|
assert get_from_to("X", "X") == ("X", "X")
|
|
|
|
def test_numeric_strings(self):
|
|
assert get_from_to("2", "1") == ("1", "2")
|
|
|
|
|
|
class TestComputeArgsHash:
|
|
"""Tests for compute_args_hash function."""
|
|
|
|
def test_deterministic(self):
|
|
h1 = compute_args_hash("a", "b", "c")
|
|
h2 = compute_args_hash("a", "b", "c")
|
|
assert h1 == h2
|
|
|
|
def test_different_args_different_hash(self):
|
|
h1 = compute_args_hash("a", "b")
|
|
h2 = compute_args_hash("a", "c")
|
|
assert h1 != h2
|
|
|
|
def test_returns_hex_string(self):
|
|
result = compute_args_hash("test")
|
|
assert isinstance(result, str)
|
|
assert len(result) == 32
|
|
int(result, 16)
|
|
|
|
def test_empty_args(self):
|
|
result = compute_args_hash()
|
|
assert isinstance(result, str)
|
|
|
|
|
|
class TestIsFloatRegex:
|
|
"""Tests for is_float_regex function."""
|
|
|
|
@pytest.mark.parametrize(
|
|
"value",
|
|
["1.0", "0.5", "100", "-3.14", "+2.7", ".5", "0"],
|
|
)
|
|
def test_valid_floats(self, value):
|
|
assert is_float_regex(value)
|
|
|
|
@pytest.mark.parametrize(
|
|
"value",
|
|
["abc", "", "1.2.3", "1e10", "inf", "NaN", " 1.0", "1.0 "],
|
|
)
|
|
def test_invalid_floats(self, value):
|
|
assert not is_float_regex(value)
|
|
|
|
|
|
class TestGraphChange:
|
|
"""Tests for GraphChange dataclass."""
|
|
|
|
def test_default_empty_sets(self):
|
|
change = GraphChange()
|
|
assert change.removed_nodes == set()
|
|
assert change.added_updated_nodes == set()
|
|
assert change.removed_edges == set()
|
|
assert change.added_updated_edges == set()
|
|
|
|
def test_mutable_default_independence(self):
|
|
c1 = GraphChange()
|
|
c2 = GraphChange()
|
|
c1.removed_nodes.add("A")
|
|
assert "A" not in c2.removed_nodes
|
|
|
|
|
|
class TestHandleSingleEntityExtraction:
|
|
"""Tests for handle_single_entity_extraction function."""
|
|
|
|
def test_valid_entity(self):
|
|
attrs = ['"entity"', "Alice", "Person", "A character"]
|
|
result = handle_single_entity_extraction(attrs, "chunk1")
|
|
assert result is not None
|
|
assert result["entity_name"] == "ALICE"
|
|
assert result["entity_type"] == "PERSON"
|
|
assert result["description"] == "A character"
|
|
assert result["source_id"] == "chunk1"
|
|
|
|
def test_not_entity_type(self):
|
|
attrs = ['"relationship"', "A", "B", "desc"]
|
|
assert handle_single_entity_extraction(attrs, "c1") is None
|
|
|
|
def test_too_few_attributes(self):
|
|
attrs = ['"entity"', "name", "type"]
|
|
assert handle_single_entity_extraction(attrs, "c1") is None
|
|
|
|
def test_empty_entity_name(self):
|
|
attrs = ['"entity"', '""', "Type", "Desc"]
|
|
assert handle_single_entity_extraction(attrs, "c1") is None
|
|
|
|
def test_entity_name_uppercased(self):
|
|
attrs = ['"entity"', "alice", "person", "desc"]
|
|
result = handle_single_entity_extraction(attrs, "c1")
|
|
assert result["entity_name"] == "ALICE"
|
|
assert result["entity_type"] == "PERSON"
|
|
|
|
|
|
class TestHandleSingleRelationshipExtraction:
|
|
"""Tests for handle_single_relationship_extraction function."""
|
|
|
|
def test_valid_relationship(self):
|
|
attrs = ['"relationship"', "Alice", "Bob", "friends with", "friendship", "2.0"]
|
|
result = handle_single_relationship_extraction(attrs, "chunk1")
|
|
assert result is not None
|
|
assert result["src_id"] == "ALICE"
|
|
assert result["tgt_id"] == "BOB"
|
|
assert result["weight"] == 2.0
|
|
assert result["description"] == "friends with"
|
|
assert result["keywords"] == "friendship"
|
|
assert result["source_id"] == "chunk1"
|
|
assert "created_at" in result["metadata"]
|
|
|
|
def test_not_relationship_type(self):
|
|
attrs = ['"entity"', "A", "B", "desc", "kw"]
|
|
assert handle_single_relationship_extraction(attrs, "c1") is None
|
|
|
|
def test_too_few_attributes(self):
|
|
attrs = ['"relationship"', "A", "B", "desc"]
|
|
assert handle_single_relationship_extraction(attrs, "c1") is None
|
|
|
|
def test_non_float_weight_defaults_to_one(self):
|
|
attrs = ['"relationship"', "A", "B", "desc", "kw", "not_a_number"]
|
|
result = handle_single_relationship_extraction(attrs, "c1")
|
|
assert result["weight"] == 1.0
|
|
|
|
def test_source_target_sorted(self):
|
|
attrs = ['"relationship"', "Zebra", "Apple", "desc", "kw", "1.0"]
|
|
result = handle_single_relationship_extraction(attrs, "c1")
|
|
assert result["src_id"] == "APPLE"
|
|
assert result["tgt_id"] == "ZEBRA"
|
|
|
|
|
|
class TestPackUserAssToOpenaiMessages:
|
|
"""Tests for pack_user_ass_to_openai_messages function."""
|
|
|
|
def test_single_message(self):
|
|
result = pack_user_ass_to_openai_messages("hello")
|
|
assert result == [{"role": "user", "content": "hello"}]
|
|
|
|
def test_alternating_roles(self):
|
|
result = pack_user_ass_to_openai_messages("q1", "a1", "q2")
|
|
assert result == [
|
|
{"role": "user", "content": "q1"},
|
|
{"role": "assistant", "content": "a1"},
|
|
{"role": "user", "content": "q2"},
|
|
]
|
|
|
|
def test_empty(self):
|
|
result = pack_user_ass_to_openai_messages()
|
|
assert result == []
|
|
|
|
|
|
class TestSplitStringByMultiMarkers:
|
|
"""Tests for split_string_by_multi_markers function."""
|
|
|
|
def test_single_marker(self):
|
|
result = split_string_by_multi_markers("a|b|c", ["|"])
|
|
assert result == ["a", "b", "c"]
|
|
|
|
def test_multiple_markers(self):
|
|
result = split_string_by_multi_markers("a|b;c", ["|", ";"])
|
|
assert result == ["a", "b", "c"]
|
|
|
|
def test_no_markers(self):
|
|
result = split_string_by_multi_markers("abc", [])
|
|
assert result == ["abc"]
|
|
|
|
def test_strips_whitespace(self):
|
|
result = split_string_by_multi_markers("a | b | c", ["|"])
|
|
assert result == ["a", "b", "c"]
|
|
|
|
def test_empty_segments_removed(self):
|
|
result = split_string_by_multi_markers("a||b", ["|"])
|
|
assert result == ["a", "b"]
|
|
|
|
def test_regex_special_chars_escaped(self):
|
|
result = split_string_by_multi_markers("a.b.c", ["."])
|
|
assert result == ["a", "b", "c"]
|
|
|
|
|
|
class TestGraphMerge:
|
|
"""Tests for graph_merge function."""
|
|
|
|
def _make_node(self, description="desc", source_id=None):
|
|
return {"description": description, "source_id": source_id or ["s1"]}
|
|
|
|
def _make_edge(self, weight=1.0, description="edge", keywords=None, source_id=None):
|
|
return {
|
|
"weight": weight,
|
|
"description": description,
|
|
"keywords": keywords or [],
|
|
"source_id": source_id or ["s1"],
|
|
}
|
|
|
|
def test_merge_disjoint_graphs(self):
|
|
g1 = nx.Graph()
|
|
g1.add_node("A", **self._make_node("A desc"))
|
|
g1.graph["source_id"] = ["doc1"]
|
|
|
|
g2 = nx.Graph()
|
|
g2.add_node("B", **self._make_node("B desc"))
|
|
g2.graph["source_id"] = ["doc2"]
|
|
|
|
change = GraphChange()
|
|
result = graph_merge(g1, g2, change)
|
|
|
|
assert result.has_node("A")
|
|
assert result.has_node("B")
|
|
assert "B" in change.added_updated_nodes
|
|
assert result.graph["source_id"] == ["doc1", "doc2"]
|
|
|
|
def test_merge_overlapping_nodes(self):
|
|
g1 = nx.Graph()
|
|
g1.add_node("A", description="first", source_id=["s1"])
|
|
g1.graph["source_id"] = ["doc1"]
|
|
|
|
g2 = nx.Graph()
|
|
g2.add_node("A", description="second", source_id=["s2"])
|
|
g2.graph["source_id"] = ["doc2"]
|
|
|
|
change = GraphChange()
|
|
graph_merge(g1, g2, change)
|
|
|
|
assert f"first{GRAPH_FIELD_SEP}second" == g1.nodes["A"]["description"]
|
|
assert g1.nodes["A"]["source_id"] == ["s1", "s2"]
|
|
|
|
def test_merge_overlapping_edges(self):
|
|
g1 = nx.Graph()
|
|
g1.add_node("A", **self._make_node())
|
|
g1.add_node("B", **self._make_node())
|
|
g1.add_edge("A", "B", **self._make_edge(weight=1.0, description="e1", keywords=["k1"], source_id=["s1"]))
|
|
g1.graph["source_id"] = ["doc1"]
|
|
|
|
g2 = nx.Graph()
|
|
g2.add_node("A", **self._make_node())
|
|
g2.add_node("B", **self._make_node())
|
|
g2.add_edge("A", "B", **self._make_edge(weight=2.0, description="e2", keywords=["k2"], source_id=["s2"]))
|
|
g2.graph["source_id"] = ["doc2"]
|
|
|
|
change = GraphChange()
|
|
graph_merge(g1, g2, change)
|
|
|
|
edge = g1.get_edge_data("A", "B")
|
|
assert edge["weight"] == 3.0
|
|
assert f"e1{GRAPH_FIELD_SEP}e2" == edge["description"]
|
|
assert edge["keywords"] == ["k1", "k2"]
|
|
assert edge["source_id"] == ["s1", "s2"]
|
|
|
|
def test_merge_tracks_changes(self):
|
|
g1 = nx.Graph()
|
|
g1.graph["source_id"] = []
|
|
|
|
g2 = nx.Graph()
|
|
g2.add_node("X", **self._make_node())
|
|
g2.add_node("Y", **self._make_node())
|
|
g2.add_edge("X", "Y", **self._make_edge())
|
|
g2.graph["source_id"] = ["doc1"]
|
|
|
|
change = GraphChange()
|
|
graph_merge(g1, g2, change)
|
|
|
|
assert {"X", "Y"} == change.added_updated_nodes
|
|
assert {("X", "Y")} == change.added_updated_edges
|
|
|
|
def test_merge_sets_rank(self):
|
|
g1 = nx.Graph()
|
|
g1.graph["source_id"] = []
|
|
|
|
g2 = nx.Graph()
|
|
g2.add_node("A", **self._make_node())
|
|
g2.add_node("B", **self._make_node())
|
|
g2.add_edge("A", "B", **self._make_edge())
|
|
g2.graph["source_id"] = ["doc1"]
|
|
|
|
change = GraphChange()
|
|
graph_merge(g1, g2, change)
|
|
|
|
assert g1.nodes["A"]["rank"] == 1
|
|
assert g1.nodes["B"]["rank"] == 1
|
|
|
|
|
|
class TestTidyGraph:
|
|
"""Tests for tidy_graph function."""
|
|
|
|
def test_removes_nodes_missing_attributes(self):
|
|
g = nx.Graph()
|
|
g.add_node("good", description="d", source_id="s")
|
|
g.add_node("bad")
|
|
messages = []
|
|
tidy_graph(g, lambda msg: messages.append(msg))
|
|
assert g.has_node("good")
|
|
assert not g.has_node("bad")
|
|
assert len(messages) == 1
|
|
|
|
def test_removes_edges_missing_attributes(self):
|
|
g = nx.Graph()
|
|
g.add_node("A", description="d", source_id="s")
|
|
g.add_node("B", description="d", source_id="s")
|
|
g.add_edge("A", "B")
|
|
messages = []
|
|
tidy_graph(g, lambda msg: messages.append(msg))
|
|
assert not g.has_edge("A", "B")
|
|
|
|
def test_adds_keywords_to_edges_without_it(self):
|
|
g = nx.Graph()
|
|
g.add_node("A", description="d", source_id="s")
|
|
g.add_node("B", description="d", source_id="s")
|
|
g.add_edge("A", "B", description="d", source_id="s")
|
|
tidy_graph(g, None)
|
|
assert g.edges["A", "B"]["keywords"] == []
|
|
|
|
def test_skip_attribute_check(self):
|
|
g = nx.Graph()
|
|
g.add_node("no_attrs")
|
|
g.add_edge("no_attrs", "no_attrs")
|
|
tidy_graph(g, None, check_attribute=False)
|
|
assert g.has_node("no_attrs")
|
|
|
|
def test_none_callback_no_error(self):
|
|
g = nx.Graph()
|
|
g.add_node("bad")
|
|
tidy_graph(g, None)
|
|
assert not g.has_node("bad")
|
|
|
|
|
|
class TestIsContinuousSubsequence:
|
|
"""Tests for is_continuous_subsequence function."""
|
|
|
|
def test_basic_match(self):
|
|
assert is_continuous_subsequence(("A", "B"), ("A", "B", "C")) is True
|
|
|
|
def test_no_match(self):
|
|
assert is_continuous_subsequence(("A", "C"), ("A", "B", "C")) is False
|
|
|
|
def test_at_end(self):
|
|
assert is_continuous_subsequence(("B", "C"), ("A", "B", "C")) is True
|
|
|
|
def test_single_element_sequence(self):
|
|
assert is_continuous_subsequence(("A", "B"), ("A",)) is False
|
|
|
|
|
|
class TestMergeTuples:
|
|
"""Tests for merge_tuples function."""
|
|
|
|
def test_basic_merge(self):
|
|
list1 = [("A", "B")]
|
|
list2 = [("B", "C")]
|
|
result = merge_tuples(list1, list2)
|
|
assert ("A", "B", "C") in result
|
|
|
|
def test_no_merge_possible(self):
|
|
list1 = [("A", "B")]
|
|
list2 = [("C", "D")]
|
|
result = merge_tuples(list1, list2)
|
|
assert ("A", "B") in result
|
|
|
|
def test_self_loop_kept(self):
|
|
list1 = [("A", "B", "A")]
|
|
list2 = []
|
|
result = merge_tuples(list1, list2)
|
|
assert ("A", "B", "A") in result
|
|
|
|
def test_empty_lists(self):
|
|
assert merge_tuples([], []) == []
|
|
|
|
|
|
@pytest.mark.p1
|
|
class TestNNeighbor:
|
|
"""Tests for n_neighbor function (n-hop neighbour path enumeration).
|
|
|
|
Regression coverage for the GraphRAG entity-ranking pipeline: the result
|
|
is serialized into each entity chunk as ``n_hop_with_weight`` and consumed
|
|
by KGSearch for n-hop relation enrichment.
|
|
"""
|
|
|
|
def _line_graph(self):
|
|
# A -1- B -2- C -3- D
|
|
g = nx.Graph()
|
|
g.add_edge("A", "B", weight=1)
|
|
g.add_edge("B", "C", weight=2)
|
|
g.add_edge("C", "D", weight=3)
|
|
return g
|
|
|
|
def test_isolated_node_returns_empty(self):
|
|
g = nx.Graph()
|
|
g.add_node("A")
|
|
assert n_neighbor(g, "A") == []
|
|
|
|
def test_result_shape(self):
|
|
nbrs = n_neighbor(self._line_graph(), "A")
|
|
assert isinstance(nbrs, list)
|
|
for nbr in nbrs:
|
|
assert set(nbr.keys()) == {"path", "weights"}
|
|
assert len(nbr["weights"]) == len(nbr["path"]) - 1
|
|
|
|
def test_two_hop_paths_and_weights(self):
|
|
# From A, 2-hop reaches the path A -> B -> C with weights [1, 2].
|
|
nbrs = n_neighbor(self._line_graph(), "A", n_hop=2)
|
|
paths = {tuple(n["path"]): n["weights"] for n in nbrs}
|
|
assert ("A", "B", "C") in paths
|
|
assert paths[("A", "B", "C")] == [1, 2]
|
|
|
|
def test_one_hop_only(self):
|
|
nbrs = n_neighbor(self._line_graph(), "A", n_hop=1)
|
|
paths = {tuple(n["path"]) for n in nbrs}
|
|
assert paths == {("A", "B")}
|
|
|
|
def test_missing_weight_defaults_to_zero(self):
|
|
g = nx.Graph()
|
|
g.add_edge("A", "B") # no weight attribute
|
|
nbrs = n_neighbor(g, "A", n_hop=1)
|
|
assert nbrs[0]["weights"] == [0]
|
|
|
|
def test_weight_lookup_is_direction_agnostic(self):
|
|
# Undirected graph: edge attributes may be keyed either way; the
|
|
# weight must still be recovered regardless of traversal direction.
|
|
nbrs = n_neighbor(self._line_graph(), "D", n_hop=1)
|
|
assert nbrs[0]["path"][0] == "D"
|
|
assert nbrs[0]["weights"] == [3]
|
|
|
|
|
|
@pytest.mark.p1
|
|
class TestGraphNodeToChunk:
|
|
"""Tests for graph_node_to_chunk field population.
|
|
|
|
Regression coverage for the dropped ranking fields: the entity chunk must
|
|
carry ``rank_flt`` (pagerank) and ``n_hop_with_weight`` so KGSearch's
|
|
``pagerank * sim`` ranking and n-hop enrichment are not permanently dead.
|
|
"""
|
|
|
|
@pytest.fixture
|
|
def fake_embd(self, monkeypatch):
|
|
# Skip the real encode/Redis path by returning a cached embedding.
|
|
monkeypatch.setattr(graphrag_utils, "get_embed_cache", lambda *_a, **_k: np.array([0.1, 0.2, 0.3]))
|
|
return graphrag_utils
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_writes_rank_flt_from_pagerank(self, fake_embd):
|
|
chunks = []
|
|
meta = {"entity_type": "PERSON", "description": "desc", "source_id": ["s1"], "pagerank": 0.42}
|
|
await fake_embd.graph_node_to_chunk("kb1", SimpleNamespace(llm_name="m"), "ALICE", meta, chunks)
|
|
assert len(chunks) == 1
|
|
assert chunks[0]["rank_flt"] == pytest.approx(0.42)
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_rank_flt_defaults_to_zero_without_pagerank(self, fake_embd):
|
|
chunks = []
|
|
meta = {"entity_type": "PERSON", "description": "desc", "source_id": ["s1"]}
|
|
await fake_embd.graph_node_to_chunk("kb1", SimpleNamespace(llm_name="m"), "ALICE", meta, chunks)
|
|
assert chunks[0]["rank_flt"] == 0.0
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_writes_n_hop_with_weight(self, fake_embd):
|
|
chunks = []
|
|
meta = {"entity_type": "PERSON", "description": "desc", "source_id": ["s1"], "pagerank": 0.1}
|
|
nhop = [{"path": ("ALICE", "BOB"), "weights": [3]}]
|
|
await fake_embd.graph_node_to_chunk("kb1", SimpleNamespace(llm_name="m"), "ALICE", meta, chunks, nhop)
|
|
stored = json.loads(chunks[0]["n_hop_with_weight"])
|
|
assert stored == [{"path": ["ALICE", "BOB"], "weights": [3]}]
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_n_hop_defaults_to_empty_list(self, fake_embd):
|
|
chunks = []
|
|
meta = {"entity_type": "PERSON", "description": "desc", "source_id": ["s1"]}
|
|
await fake_embd.graph_node_to_chunk("kb1", SimpleNamespace(llm_name="m"), "ALICE", meta, chunks)
|
|
assert json.loads(chunks[0]["n_hop_with_weight"]) == []
|
|
|
|
|
|
class TestFlatUniqList:
|
|
"""Tests for flat_uniq_list function."""
|
|
|
|
def test_flat_lists(self):
|
|
arr = [{"k": [1, 2]}, {"k": [2, 3]}]
|
|
result = flat_uniq_list(arr, "k")
|
|
assert set(result) == {1, 2, 3}
|
|
|
|
def test_scalar_values(self):
|
|
arr = [{"k": "a"}, {"k": "b"}, {"k": "a"}]
|
|
result = flat_uniq_list(arr, "k")
|
|
assert set(result) == {"a", "b"}
|
|
|
|
def test_empty_list(self):
|
|
assert flat_uniq_list([], "k") == []
|
|
|
|
def test_mixed_list_and_scalar(self):
|
|
arr = [{"k": [1, 2]}, {"k": 3}]
|
|
result = flat_uniq_list(arr, "k")
|
|
assert set(result) == {1, 2, 3}
|