mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
feat(raptor): add Psi tree builder with original-space ranking and safe migration (#14679)
### What problem does this PR solve? Closes #14674. This PR improves RAPTOR configuration and tree construction while preserving the existing RAPTOR behavior as the default. RAPTOR currently builds summary layers with the original UMAP + GMM clustering path. This PR keeps that default path, and adds: - A hidden backend tree-builder option: - `tree_builder="raptor"`: default, existing RAPTOR behavior. - `tree_builder="psi"`: rank-aware Psi-style tree builder using original embedding-space cosine ranking. - A user-facing clustering method option for the default RAPTOR builder: - `clustering_method="gmm"`: existing default. - `clustering_method="ahc"`: agglomerative hierarchical clustering path. - A RAPTOR UI setting for `Clustering method` and `Max cluster`. ### What changed #### Backend - Added `tree_builder` support for RAPTOR/Psi. - Added `clustering_method` support for GMM/AHC. - Kept existing RAPTOR + GMM as the default. - Added Psi tree building from original-space cosine similarity. - Added bucketed Psi building controls for large inputs: - `raptor.ext.psi_exact_max_leaves` - `raptor.ext.psi_bucket_size` - Added method-aware RAPTOR summary metadata using existing `extra.raptor_method`. - Avoided adding a dedicated DB schema field for experimental method tracking. - Added cleanup/migration logic to avoid mixing stale RAPTOR summary trees. - Added defensive checks for Psi tree construction and summary failures. #### Frontend/UI - Added `Clustering method` in RAPTOR settings with `GMM` and `AHC`. - Added/kept `Max cluster` in RAPTOR settings. - Enlarged max cluster UI limit to `1024`, matching backend validation. - Kept AHC editable even when a RAPTOR task has already finished. - Fixed the UI save payload so `clustering_method` and `tree_builder` are serialized through `parser_config.raptor.ext`, avoiding backend validation errors for extra top-level RAPTOR fields. Example saved RAPTOR config: ```json { "raptor": { "max_cluster": 317, "ext": { "clustering_method": "ahc", "tree_builder": "raptor" } } } Co-authored-by: CaptainTimon <CaptainTimon@users.noreply.github.com>
This commit is contained in:
@@ -583,6 +583,10 @@ class TestDatasetUpdate:
|
||||
{"raptor": {"max_cluster": 512}},
|
||||
{"raptor": {"max_cluster": 1024}},
|
||||
{"raptor": {"random_seed": 0}},
|
||||
{"raptor": {"clustering_method": "gmm"}},
|
||||
{"raptor": {"clustering_method": "ahc"}},
|
||||
{"raptor": {"tree_builder": "raptor"}},
|
||||
{"raptor": {"tree_builder": "psi"}},
|
||||
],
|
||||
ids=[
|
||||
"auto_keywords_min",
|
||||
@@ -633,6 +637,10 @@ class TestDatasetUpdate:
|
||||
"raptor_max_cluster_mid",
|
||||
"raptor_max_cluster_max",
|
||||
"raptor_random_seed_min",
|
||||
"raptor_clustering_method_gmm",
|
||||
"raptor_clustering_method_ahc",
|
||||
"raptor_tree_builder_raptor",
|
||||
"raptor_tree_builder_psi",
|
||||
],
|
||||
)
|
||||
def test_parser_config(self, HttpApiAuth, add_dataset_func, parser_config):
|
||||
@@ -707,6 +715,10 @@ class TestDatasetUpdate:
|
||||
({"raptor": {"random_seed": -1}}, "Input should be greater than or equal to 0"),
|
||||
({"raptor": {"random_seed": 3.14}}, "Input should be a valid integer"),
|
||||
({"raptor": {"random_seed": "string"}}, "Input should be a valid integer"),
|
||||
({"raptor": {"clustering_method": "unknown"}}, "Input should be 'gmm' or 'ahc'"),
|
||||
({"raptor": {"clustering_method": None}}, "Input should be 'gmm' or 'ahc'"),
|
||||
({"raptor": {"tree_builder": "ahc"}}, "Input should be 'raptor' or 'psi'"),
|
||||
({"raptor": {"tree_builder": None}}, "Input should be 'raptor' or 'psi'"),
|
||||
({"delimiter": "a" * 65536}, "Parser config exceeds size limit (max 65,535 characters)"),
|
||||
],
|
||||
ids=[
|
||||
@@ -763,6 +775,10 @@ class TestDatasetUpdate:
|
||||
"raptor_random_seed_min_limit",
|
||||
"raptor_random_seed_float_not_allowed",
|
||||
"raptor_random_seed_type_invalid",
|
||||
"raptor_clustering_method_invalid",
|
||||
"raptor_clustering_method_none_invalid",
|
||||
"raptor_tree_builder_invalid",
|
||||
"raptor_tree_builder_none_invalid",
|
||||
"parser_config_type_invalid",
|
||||
],
|
||||
)
|
||||
|
||||
375
test/unit_test/rag/test_raptor_psi_tree_builder.py
Normal file
375
test/unit_test/rag/test_raptor_psi_tree_builder.py
Normal file
@@ -0,0 +1,375 @@
|
||||
#
|
||||
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import importlib
|
||||
import sys
|
||||
import types
|
||||
|
||||
import pytest
|
||||
|
||||
np = pytest.importorskip("numpy")
|
||||
|
||||
from api.utils.validation_utils import RaptorConfig
|
||||
from pydantic import ValidationError
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def raptor_module(monkeypatch):
|
||||
class TaskCanceledException(Exception):
|
||||
pass
|
||||
|
||||
class DummyLimiter:
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
class DummyGaussianMixture:
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def fit(self, embeddings):
|
||||
return self
|
||||
|
||||
def bic(self, embeddings):
|
||||
return 0
|
||||
|
||||
def predict_proba(self, embeddings):
|
||||
return np.ones((len(embeddings), 1))
|
||||
|
||||
class DummyAgglomerativeClustering:
|
||||
def __init__(self, n_clusters=None, distance_threshold=None, compute_distances=False, linkage="ward"):
|
||||
self.n_clusters = n_clusters
|
||||
self.distance_threshold = distance_threshold
|
||||
self.compute_distances = compute_distances
|
||||
self.linkage = linkage
|
||||
self.distances_ = np.array([0.1, 0.2, 1.0])
|
||||
|
||||
def fit(self, embeddings):
|
||||
self.labels_ = self.fit_predict(embeddings)
|
||||
return self
|
||||
|
||||
def fit_predict(self, embeddings):
|
||||
if self.n_clusters is None:
|
||||
return np.zeros(len(embeddings), dtype=int)
|
||||
return np.array([idx % self.n_clusters for idx in range(len(embeddings))])
|
||||
|
||||
class DummyUMAP:
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def fit_transform(self, embeddings):
|
||||
raise AssertionError("Psi tree builder must use original embeddings, not UMAP")
|
||||
|
||||
sklearn_module = types.ModuleType("sklearn")
|
||||
mixture_module = types.ModuleType("sklearn.mixture")
|
||||
mixture_module.GaussianMixture = DummyGaussianMixture
|
||||
cluster_module = types.ModuleType("sklearn.cluster")
|
||||
cluster_module.AgglomerativeClustering = DummyAgglomerativeClustering
|
||||
umap_module = types.ModuleType("umap")
|
||||
umap_module.UMAP = DummyUMAP
|
||||
task_service_module = types.ModuleType("api.db.services.task_service")
|
||||
task_service_module.has_canceled = lambda task_id: False
|
||||
connection_utils_module = types.ModuleType("common.connection_utils")
|
||||
connection_utils_module.timeout = lambda seconds: lambda fn: fn
|
||||
exceptions_module = types.ModuleType("common.exceptions")
|
||||
exceptions_module.TaskCanceledException = TaskCanceledException
|
||||
token_utils_module = types.ModuleType("common.token_utils")
|
||||
token_utils_module.truncate = lambda text, max_len: text[:max_len]
|
||||
graphrag_utils_module = types.ModuleType("rag.graphrag.utils")
|
||||
graphrag_utils_module.chat_limiter = DummyLimiter()
|
||||
graphrag_utils_module.get_embed_cache = lambda *args, **kwargs: None
|
||||
graphrag_utils_module.get_llm_cache = lambda *args, **kwargs: None
|
||||
graphrag_utils_module.set_embed_cache = lambda *args, **kwargs: None
|
||||
graphrag_utils_module.set_llm_cache = lambda *args, **kwargs: None
|
||||
|
||||
async def thread_pool_exec(fn, *args, **kwargs):
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
misc_utils_module = types.ModuleType("common.misc_utils")
|
||||
misc_utils_module.thread_pool_exec = thread_pool_exec
|
||||
|
||||
monkeypatch.setitem(sys.modules, "sklearn", sklearn_module)
|
||||
monkeypatch.setitem(sys.modules, "sklearn.mixture", mixture_module)
|
||||
monkeypatch.setitem(sys.modules, "sklearn.cluster", cluster_module)
|
||||
monkeypatch.setitem(sys.modules, "umap", umap_module)
|
||||
monkeypatch.setitem(sys.modules, "api.db.services.task_service", task_service_module)
|
||||
monkeypatch.setitem(sys.modules, "common.connection_utils", connection_utils_module)
|
||||
monkeypatch.setitem(sys.modules, "common.exceptions", exceptions_module)
|
||||
monkeypatch.setitem(sys.modules, "common.token_utils", token_utils_module)
|
||||
monkeypatch.setitem(sys.modules, "rag.graphrag.utils", graphrag_utils_module)
|
||||
monkeypatch.setitem(sys.modules, "common.misc_utils", misc_utils_module)
|
||||
monkeypatch.delitem(sys.modules, "rag.raptor", raising=False)
|
||||
module = importlib.import_module("rag.raptor")
|
||||
yield module
|
||||
monkeypatch.delitem(sys.modules, "rag.raptor", raising=False)
|
||||
|
||||
|
||||
class FakeChatModel:
|
||||
llm_name = "fake-chat"
|
||||
max_length = 4096
|
||||
|
||||
def __init__(self):
|
||||
self.calls = []
|
||||
|
||||
async def async_chat(self, system, history, gen_conf):
|
||||
self.calls.append(history[0]["content"])
|
||||
return f"summary-{len(self.calls)}"
|
||||
|
||||
|
||||
class FakeEmbeddingModel:
|
||||
llm_name = "fake-embedding"
|
||||
|
||||
def encode(self, texts):
|
||||
embeddings = []
|
||||
for text in texts:
|
||||
checksum = sum(ord(ch) for ch in text)
|
||||
embeddings.append(np.array([len(text), checksum % 17 + 1], dtype=float))
|
||||
return embeddings, len(texts)
|
||||
|
||||
|
||||
_DEFAULT_TREE_BUILDER = object()
|
||||
|
||||
|
||||
def _make_raptor(raptor_module, max_cluster=64, tree_builder=_DEFAULT_TREE_BUILDER, **kwargs):
|
||||
if tree_builder is _DEFAULT_TREE_BUILDER:
|
||||
kwargs["tree_builder"] = raptor_module.PSI_TREE_BUILDER
|
||||
else:
|
||||
kwargs["tree_builder"] = tree_builder
|
||||
return raptor_module.RecursiveAbstractiveProcessing4TreeOrganizedRetrieval(
|
||||
max_cluster,
|
||||
FakeChatModel(),
|
||||
FakeEmbeddingModel(),
|
||||
"{cluster_content}",
|
||||
max_token=32,
|
||||
threshold=0.1,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def _chunks():
|
||||
return [
|
||||
("alpha first", np.array([1.0, 0.0])),
|
||||
("alpha second", np.array([0.99, 0.01])),
|
||||
("alpha third", np.array([0.98, 0.02])),
|
||||
]
|
||||
|
||||
|
||||
def test_default_tree_builder_remains_original_raptor(raptor_module):
|
||||
raptor = _make_raptor(raptor_module, tree_builder=None)
|
||||
|
||||
assert raptor._tree_builder == raptor_module.RAPTOR_TREE_BUILDER
|
||||
|
||||
|
||||
def test_unknown_tree_builder_is_rejected(raptor_module):
|
||||
with pytest.raises(ValueError, match="Unsupported RAPTOR tree builder"):
|
||||
_make_raptor(raptor_module, tree_builder="ahc")
|
||||
|
||||
|
||||
def test_raptor_config_accepts_hidden_psi_tree_builder():
|
||||
assert RaptorConfig().tree_builder == "raptor"
|
||||
assert RaptorConfig().clustering_method == "gmm"
|
||||
assert RaptorConfig(clustering_method="ahc").clustering_method == "ahc"
|
||||
assert RaptorConfig(tree_builder="psi").tree_builder == "psi"
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
RaptorConfig(tree_builder="ahc")
|
||||
with pytest.raises(ValidationError):
|
||||
RaptorConfig(clustering_method="psi")
|
||||
|
||||
|
||||
def test_ahc_clustering_method_is_supported_in_original_tree_builder(raptor_module):
|
||||
raptor = _make_raptor(raptor_module, tree_builder=raptor_module.RAPTOR_TREE_BUILDER, clustering_method="ahc")
|
||||
|
||||
labels = raptor._get_clusters_ahc(np.array([[0.0, 0.0], [0.1, 0.0], [10.0, 10.0], [10.1, 10.0]]))
|
||||
|
||||
assert raptor._tree_builder == raptor_module.RAPTOR_TREE_BUILDER
|
||||
assert raptor._clustering_method == "ahc"
|
||||
assert len(labels) == 4
|
||||
|
||||
|
||||
def test_unknown_clustering_method_is_rejected(raptor_module):
|
||||
with pytest.raises(ValueError, match="Unsupported RAPTOR clustering method"):
|
||||
_make_raptor(raptor_module, clustering_method="psi")
|
||||
|
||||
|
||||
def test_psi_tree_builder_ranks_all_leaf_pairs_by_original_cosine_similarity(raptor_module):
|
||||
raptor = _make_raptor(raptor_module)
|
||||
leaves = [
|
||||
raptor_module._PsiTreeNode(index=0, embedding=np.array([1.0, 0.0])),
|
||||
raptor_module._PsiTreeNode(index=1, embedding=np.array([0.0, 1.0])),
|
||||
raptor_module._PsiTreeNode(index=2, embedding=np.array([0.99, 0.01])),
|
||||
raptor_module._PsiTreeNode(index=3, embedding=np.array([-1.0, 0.0])),
|
||||
]
|
||||
|
||||
ranked_pairs = raptor._rank_leaf_pairs(leaves)
|
||||
|
||||
assert len(ranked_pairs) == 6
|
||||
assert tuple(ranked_pairs[0]) == (2, 0)
|
||||
|
||||
|
||||
def test_psi_tree_builder_uses_cosine_similarity_not_vector_magnitude(raptor_module):
|
||||
raptor = _make_raptor(raptor_module)
|
||||
leaves = [
|
||||
raptor_module._PsiTreeNode(index=0, embedding=np.array([100.0, 0.0])),
|
||||
raptor_module._PsiTreeNode(index=1, embedding=np.array([1.0, 1.0])),
|
||||
raptor_module._PsiTreeNode(index=2, embedding=np.array([0.1, 0.0])),
|
||||
]
|
||||
|
||||
ranked_pairs = raptor._rank_leaf_pairs(leaves)
|
||||
|
||||
assert tuple(ranked_pairs[0]) == (2, 0)
|
||||
|
||||
|
||||
def test_psi_tree_builder_handles_zero_vectors_in_cosine_ranking(raptor_module):
|
||||
raptor = _make_raptor(raptor_module)
|
||||
leaves = [
|
||||
raptor_module._PsiTreeNode(index=0, embedding=np.array([0.0, 0.0])),
|
||||
raptor_module._PsiTreeNode(index=1, embedding=np.array([1.0, 0.0])),
|
||||
raptor_module._PsiTreeNode(index=2, embedding=np.array([0.9, 0.1])),
|
||||
]
|
||||
|
||||
ranked_pairs = raptor._rank_leaf_pairs(leaves)
|
||||
|
||||
assert tuple(ranked_pairs[0]) == (2, 1)
|
||||
|
||||
|
||||
def test_psi_tree_builder_collapses_leaf_into_ranked_pair_parent(raptor_module):
|
||||
raptor = _make_raptor(raptor_module, max_cluster=64)
|
||||
|
||||
root, leaves = raptor._build_psi_structure(_chunks())
|
||||
|
||||
assert len(root.children) == 3
|
||||
assert {child.index for child in root.children} == {0, 1, 2}
|
||||
assert all(leaf.parent is root for leaf in leaves)
|
||||
|
||||
|
||||
def test_psi_tree_builder_collapses_leaf_at_matching_rank(monkeypatch, raptor_module):
|
||||
raptor = _make_raptor(raptor_module, max_cluster=64)
|
||||
chunks = [
|
||||
("node 0", np.array([1.0, 0.0])),
|
||||
("node 1", np.array([0.9, 0.1])),
|
||||
("node 2", np.array([-1.0, 0.0])),
|
||||
("node 3", np.array([-0.9, -0.1])),
|
||||
("node 4", np.array([0.8, 0.2])),
|
||||
]
|
||||
monkeypatch.setattr(
|
||||
raptor,
|
||||
"_rank_leaf_pairs",
|
||||
lambda _leaves: np.array([[0, 1], [2, 3], [0, 2], [4, 0]]),
|
||||
)
|
||||
|
||||
root, leaves = raptor._build_psi_structure(chunks)
|
||||
|
||||
assert leaves[4].parent is leaves[0].parent
|
||||
assert leaves[4].parent is not root
|
||||
assert len(root.children) == 2
|
||||
|
||||
|
||||
def test_psi_union_find_clamps_out_of_bounds_parent_rank(caplog, raptor_module):
|
||||
union_find = raptor_module._PsiUnionFind(2)
|
||||
union_find._node_ids[1] = [1]
|
||||
union_find._rank[0] = 2
|
||||
|
||||
with caplog.at_level("WARNING"):
|
||||
union_find._build(0, 1, insert_point=1)
|
||||
|
||||
assert union_find.tree[0] == 1
|
||||
assert "rank index" in caplog.text
|
||||
|
||||
|
||||
def test_psi_tree_builder_rebalances_nodes_over_max_children(raptor_module):
|
||||
raptor = _make_raptor(raptor_module, max_cluster=2)
|
||||
|
||||
root, _ = raptor._build_psi_structure(_chunks())
|
||||
|
||||
assert all(len(node.children) <= 2 for node in raptor._iter_nodes(root))
|
||||
assert len(root.children) == 2
|
||||
assert any(child.children for child in root.children)
|
||||
|
||||
|
||||
def test_psi_tree_builder_uses_bucketed_structure_for_large_inputs(monkeypatch, raptor_module):
|
||||
chunks = [(f"node {idx}", np.array([float(idx), float(idx % 3 + 1)])) for idx in range(8)]
|
||||
raptor = _make_raptor(
|
||||
raptor_module,
|
||||
max_cluster=3,
|
||||
psi_exact_max_leaves=3,
|
||||
psi_bucket_size=2,
|
||||
)
|
||||
ranked_sizes = []
|
||||
original_rank = raptor._rank_leaf_pairs
|
||||
|
||||
def track_rank(nodes):
|
||||
ranked_sizes.append(len(nodes))
|
||||
return original_rank(nodes)
|
||||
|
||||
monkeypatch.setattr(raptor, "_rank_leaf_pairs", track_rank)
|
||||
|
||||
root, leaves = raptor._build_psi_structure(chunks)
|
||||
|
||||
assert len(leaves) == len(chunks)
|
||||
assert all(leaf.parent is not None for leaf in leaves)
|
||||
assert all(len(node.children) <= 3 for node in raptor._iter_nodes(root))
|
||||
assert max(ranked_sizes) <= 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_psi_tree_builder_materializes_rebalanced_summary_layers_without_umap(monkeypatch, raptor_module):
|
||||
def fail_umap(*args, **kwargs):
|
||||
raise AssertionError("Psi tree builder must use original embeddings, not UMAP")
|
||||
|
||||
monkeypatch.setattr(raptor_module.umap, "UMAP", fail_umap)
|
||||
raptor = _make_raptor(raptor_module, max_cluster=2)
|
||||
|
||||
chunks, layers = await raptor(_chunks(), random_state=0)
|
||||
|
||||
assert len(chunks) == 5
|
||||
assert layers == [(0, 3), (3, 4), (4, 5)]
|
||||
assert [chunk[0] for chunk in chunks[3:]] == ["summary-1", "summary-2"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_psi_tree_builder_skips_failed_node_summary(monkeypatch, raptor_module):
|
||||
raptor = _make_raptor(raptor_module, max_cluster=2)
|
||||
|
||||
async def fail_summary(*args, **kwargs):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(raptor, "_summarize_texts", fail_summary)
|
||||
|
||||
chunks, layers = await raptor(_chunks(), random_state=0)
|
||||
|
||||
assert len(chunks) == 3
|
||||
assert [chunk[0] for chunk in chunks] == [chunk[0] for chunk in _chunks()]
|
||||
assert layers == [(0, 3)]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_original_raptor_stops_when_transient_summary_fails(monkeypatch, raptor_module):
|
||||
raptor = _make_raptor(raptor_module, tree_builder=raptor_module.RAPTOR_TREE_BUILDER)
|
||||
|
||||
async def fail_summary(*args, **kwargs):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(raptor, "_summarize_texts", fail_summary)
|
||||
|
||||
input_chunks = _chunks()[:2]
|
||||
chunks, layers = await raptor(input_chunks, random_state=0)
|
||||
|
||||
assert len(chunks) == 2
|
||||
assert [chunk[0] for chunk in chunks] == [chunk[0] for chunk in input_chunks]
|
||||
assert layers == [(0, 2)]
|
||||
@@ -18,15 +18,22 @@
|
||||
Unit tests for Raptor utility functions.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
import pytest
|
||||
from rag.utils.raptor_utils import (
|
||||
CSV_EXTENSIONS,
|
||||
EXCEL_EXTENSIONS,
|
||||
STRUCTURED_EXTENSIONS,
|
||||
collect_raptor_chunk_ids,
|
||||
collect_raptor_methods,
|
||||
get_raptor_clustering_method,
|
||||
get_raptor_tree_builder,
|
||||
get_skip_reason,
|
||||
is_structured_file_type,
|
||||
is_tabular_pdf,
|
||||
make_raptor_summary_chunk_id,
|
||||
should_skip_raptor,
|
||||
get_skip_reason,
|
||||
EXCEL_EXTENSIONS,
|
||||
CSV_EXTENSIONS,
|
||||
STRUCTURED_EXTENSIONS
|
||||
)
|
||||
|
||||
|
||||
@@ -283,5 +290,117 @@ class TestIntegrationScenarios:
|
||||
assert should_skip_raptor(file_type, raptor_config=raptor_config) is False
|
||||
|
||||
|
||||
class TestRaptorTreeBuilderConfig:
|
||||
"""Test RAPTOR tree builder config resolution"""
|
||||
|
||||
def test_defaults_to_original_raptor_builder(self):
|
||||
assert get_raptor_tree_builder({}) == "raptor"
|
||||
assert get_raptor_tree_builder(None) == "raptor"
|
||||
|
||||
def test_reads_top_level_tree_builder(self):
|
||||
assert get_raptor_tree_builder({"tree_builder": "psi"}) == "psi"
|
||||
|
||||
def test_reads_legacy_ext_tree_builder(self):
|
||||
assert get_raptor_tree_builder({"ext": {"tree_builder": "psi"}}) == "psi"
|
||||
|
||||
def test_ext_tree_builder_overrides_stale_top_level_value(self):
|
||||
assert get_raptor_tree_builder({"tree_builder": "psi", "ext": {"tree_builder": "raptor"}}) == "raptor"
|
||||
|
||||
def test_rejects_unknown_tree_builder(self):
|
||||
with pytest.raises(ValueError, match="Unsupported RAPTOR tree builder"):
|
||||
get_raptor_tree_builder({"tree_builder": "ahc"})
|
||||
|
||||
|
||||
class TestRaptorClusteringMethodConfig:
|
||||
"""Test RAPTOR clustering method config resolution"""
|
||||
|
||||
def test_defaults_to_gmm(self):
|
||||
assert get_raptor_clustering_method({}) == "gmm"
|
||||
assert get_raptor_clustering_method(None) == "gmm"
|
||||
|
||||
def test_reads_top_level_clustering_method(self):
|
||||
assert get_raptor_clustering_method({"clustering_method": "gmm"}) == "gmm"
|
||||
assert get_raptor_clustering_method({"clustering_method": "ahc"}) == "ahc"
|
||||
|
||||
def test_reads_legacy_ext_clustering_method(self):
|
||||
assert get_raptor_clustering_method({"ext": {"clustering_method": "ahc"}}) == "ahc"
|
||||
|
||||
def test_ext_clustering_method_overrides_stale_top_level_value(self):
|
||||
assert get_raptor_clustering_method({"clustering_method": "gmm", "ext": {"clustering_method": "ahc"}}) == "ahc"
|
||||
|
||||
def test_rejects_unknown_clustering_method(self):
|
||||
with pytest.raises(ValueError, match="Unsupported RAPTOR clustering method"):
|
||||
get_raptor_clustering_method({"clustering_method": "unknown"})
|
||||
|
||||
|
||||
class TestRaptorMethodCollection:
|
||||
"""Test RAPTOR summary method extraction from doc-store fields"""
|
||||
|
||||
def test_legacy_summary_without_method_is_original_raptor(self):
|
||||
field_map = {"chunk_1": {"raptor_kwd": "raptor"}}
|
||||
|
||||
assert collect_raptor_methods(field_map) == {"raptor"}
|
||||
assert collect_raptor_chunk_ids(field_map) == {"chunk_1"}
|
||||
|
||||
def test_extra_method_is_preserved(self):
|
||||
field_map = {"chunk_1": {"raptor_kwd": "raptor", "extra": {"raptor_method": "psi"}}}
|
||||
|
||||
assert collect_raptor_methods(field_map) == {"psi"}
|
||||
assert collect_raptor_chunk_ids(field_map) == {"chunk_1"}
|
||||
|
||||
def test_extra_field_supports_oceanbase_legacy_rows(self):
|
||||
field_map = {
|
||||
"chunk_1": {
|
||||
"extra": {
|
||||
"raptor_kwd": "raptor",
|
||||
"raptor_method": "psi",
|
||||
}
|
||||
},
|
||||
"chunk_2": {
|
||||
"extra": "{\"raptor_kwd\": \"raptor\"}",
|
||||
},
|
||||
"chunk_3": {
|
||||
"extra": {"raptor_kwd": ""},
|
||||
},
|
||||
}
|
||||
|
||||
assert collect_raptor_methods(field_map) == {"psi", "raptor"}
|
||||
assert collect_raptor_chunk_ids(field_map) == {"chunk_1", "chunk_2"}
|
||||
|
||||
def test_non_raptor_rows_are_ignored(self):
|
||||
field_map = {
|
||||
"chunk_1": {"raptor_kwd": ""},
|
||||
"chunk_2": {"extra": {"raptor_kwd": "graph"}},
|
||||
"chunk_3": {},
|
||||
}
|
||||
|
||||
assert collect_raptor_methods(field_map) == set()
|
||||
assert collect_raptor_chunk_ids(field_map) == set()
|
||||
|
||||
def test_malformed_extra_payload_is_logged_and_ignored(self, caplog):
|
||||
field_map = {"chunk_1": {"extra": "{bad json"}}
|
||||
|
||||
with caplog.at_level(logging.WARNING):
|
||||
assert collect_raptor_methods(field_map) == set()
|
||||
assert collect_raptor_chunk_ids(field_map) == set()
|
||||
|
||||
assert "Ignoring malformed RAPTOR extra payload" in caplog.text
|
||||
|
||||
def test_chunk_id_collection_can_preserve_current_method(self):
|
||||
field_map = {
|
||||
"legacy": {"raptor_kwd": "raptor"},
|
||||
"old": {"raptor_kwd": "raptor", "extra": {"raptor_method": "raptor"}},
|
||||
"current": {"raptor_kwd": "raptor", "extra": {"raptor_method": "psi"}},
|
||||
}
|
||||
|
||||
assert collect_raptor_chunk_ids(field_map, exclude_methods={"psi"}) == {"legacy", "old"}
|
||||
assert collect_raptor_chunk_ids(field_map, exclude_methods={"raptor"}) == {"current"}
|
||||
|
||||
def test_summary_chunk_ids_include_real_document_id(self):
|
||||
content = "same generated summary"
|
||||
|
||||
assert make_raptor_summary_chunk_id(content, "doc-a") != make_raptor_summary_chunk_id(content, "doc-b")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
|
||||
Reference in New Issue
Block a user