diff --git a/rag/raptor.py b/rag/raptor.py index a7f2c782d3..d39964f70f 100644 --- a/rag/raptor.py +++ b/rag/raptor.py @@ -235,7 +235,14 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval: def _get_optimal_clusters(self, embeddings: np.ndarray, random_state: int, task_id: str = ""): """Choose the GMM cluster count with the lowest BIC score.""" max_clusters = min(self._max_cluster, len(embeddings)) - n_clusters = np.arange(1, max_clusters) + if max_clusters <= 1: + logging.info( + "RAPTOR GMM: _get_optimal_clusters returning 1 (max_clusters=%s, embeddings=%d)", + max_clusters, + len(embeddings), + ) + return 1 + n_clusters = np.arange(1, max_clusters + 1) bics = [] for n in n_clusters: self._check_task_canceled(task_id, "get optimal clusters") @@ -244,7 +251,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval: gm.fit(embeddings) bics.append(gm.bic(embeddings)) optimal_clusters = n_clusters[np.argmin(bics)] - return optimal_clusters + return int(optimal_clusters) def _get_clusters_ahc(self, embeddings: np.ndarray, task_id: str = "") -> np.ndarray: """Cluster embeddings with Ward-linkage AHC and a dendrogram gap heuristic.""" diff --git a/test/unit_test/rag/test_raptor_psi_tree_builder.py b/test/unit_test/rag/test_raptor_psi_tree_builder.py index 1d0af20d96..5590d928f2 100644 --- a/test/unit_test/rag/test_raptor_psi_tree_builder.py +++ b/test/unit_test/rag/test_raptor_psi_tree_builder.py @@ -207,6 +207,46 @@ def test_unknown_clustering_method_is_rejected(raptor_module): _make_raptor(raptor_module, clustering_method="psi") +@pytest.mark.p2 +def test_get_optimal_clusters_handles_max_cluster_equal_one(raptor_module): + raptor = _make_raptor(raptor_module, max_cluster=1) + + optimal = raptor._get_optimal_clusters( + np.array([[0.0, 0.0], [0.1, 0.0], [1.0, 1.0], [1.1, 1.1]]), + random_state=0, + ) + + assert optimal == 1 + + +@pytest.mark.p2 +def test_get_optimal_clusters_evaluates_upper_bound_candidate(monkeypatch, raptor_module): + raptor = _make_raptor(raptor_module, max_cluster=3) + evaluated = [] + + class RecordingGaussianMixture: + def __init__(self, n_components, random_state=None): + self.n_components = n_components + evaluated.append(n_components) + + def fit(self, embeddings): + return self + + def bic(self, embeddings): + scores = {1: 30.0, 2: 20.0, 3: 10.0} + return scores[self.n_components] + + monkeypatch.setattr(raptor_module, "GaussianMixture", RecordingGaussianMixture) + + optimal = raptor._get_optimal_clusters( + np.array([[0.0, 0.0], [0.1, 0.0], [1.0, 1.0], [1.1, 1.1]]), + random_state=0, + ) + + assert optimal == 3 + assert evaluated == [1, 2, 3] + + def test_psi_tree_builder_ranks_all_leaf_pairs_by_original_cosine_similarity(raptor_module): raptor = _make_raptor(raptor_module) leaves = [