From b4a8a90c73d8553a7937f4bc9de50c0a6f480bff Mon Sep 17 00:00:00 2001 From: Harsh Kashyap Date: Tue, 23 Jun 2026 06:07:26 -0700 Subject: [PATCH] fix(rag/raptor): handle max_cluster edge case in GMM cluster selection (#16199) ### What problem does this PR solve? `_get_optimal_clusters` in `rag/raptor.py` had two edge-case issues in GMM cluster-count selection: 1. It used `np.arange(1, max_clusters)`, which never evaluates the upper-bound candidate (`max_clusters`). 2. When effective `max_clusters` becomes `1`, the candidate list was empty and `argmin` crashed. This PR makes candidate evaluation inclusive (`1..max_clusters`) and guards the single-cluster case by returning `1` directly. ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) ### Validation - `pytest test/unit_test/rag/test_raptor_psi_tree_builder.py --config-file pyproject.toml -q` - `ruff check rag/raptor.py test/unit_test/rag/test_raptor_psi_tree_builder.py` ### Tests added - Regression test for `max_cluster == 1` path (no crash, returns 1) - Regression test verifying upper-bound candidate is evaluated and can be selected _AI-assistance disclosure: parts of this change (bug triage and test scaffolding) were drafted with AI assistance and fully reviewed and verified by me._ --------- Co-authored-by: Harsh Kashyap Co-authored-by: Cursor --- rag/raptor.py | 11 ++++- .../rag/test_raptor_psi_tree_builder.py | 40 +++++++++++++++++++ 2 files changed, 49 insertions(+), 2 deletions(-) diff --git a/rag/raptor.py b/rag/raptor.py index a7f2c782d3..d39964f70f 100644 --- a/rag/raptor.py +++ b/rag/raptor.py @@ -235,7 +235,14 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval: def _get_optimal_clusters(self, embeddings: np.ndarray, random_state: int, task_id: str = ""): """Choose the GMM cluster count with the lowest BIC score.""" max_clusters = min(self._max_cluster, len(embeddings)) - n_clusters = np.arange(1, max_clusters) + if max_clusters <= 1: + logging.info( + "RAPTOR GMM: _get_optimal_clusters returning 1 (max_clusters=%s, embeddings=%d)", + max_clusters, + len(embeddings), + ) + return 1 + n_clusters = np.arange(1, max_clusters + 1) bics = [] for n in n_clusters: self._check_task_canceled(task_id, "get optimal clusters") @@ -244,7 +251,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval: gm.fit(embeddings) bics.append(gm.bic(embeddings)) optimal_clusters = n_clusters[np.argmin(bics)] - return optimal_clusters + return int(optimal_clusters) def _get_clusters_ahc(self, embeddings: np.ndarray, task_id: str = "") -> np.ndarray: """Cluster embeddings with Ward-linkage AHC and a dendrogram gap heuristic.""" diff --git a/test/unit_test/rag/test_raptor_psi_tree_builder.py b/test/unit_test/rag/test_raptor_psi_tree_builder.py index 1d0af20d96..5590d928f2 100644 --- a/test/unit_test/rag/test_raptor_psi_tree_builder.py +++ b/test/unit_test/rag/test_raptor_psi_tree_builder.py @@ -207,6 +207,46 @@ def test_unknown_clustering_method_is_rejected(raptor_module): _make_raptor(raptor_module, clustering_method="psi") +@pytest.mark.p2 +def test_get_optimal_clusters_handles_max_cluster_equal_one(raptor_module): + raptor = _make_raptor(raptor_module, max_cluster=1) + + optimal = raptor._get_optimal_clusters( + np.array([[0.0, 0.0], [0.1, 0.0], [1.0, 1.0], [1.1, 1.1]]), + random_state=0, + ) + + assert optimal == 1 + + +@pytest.mark.p2 +def test_get_optimal_clusters_evaluates_upper_bound_candidate(monkeypatch, raptor_module): + raptor = _make_raptor(raptor_module, max_cluster=3) + evaluated = [] + + class RecordingGaussianMixture: + def __init__(self, n_components, random_state=None): + self.n_components = n_components + evaluated.append(n_components) + + def fit(self, embeddings): + return self + + def bic(self, embeddings): + scores = {1: 30.0, 2: 20.0, 3: 10.0} + return scores[self.n_components] + + monkeypatch.setattr(raptor_module, "GaussianMixture", RecordingGaussianMixture) + + optimal = raptor._get_optimal_clusters( + np.array([[0.0, 0.0], [0.1, 0.0], [1.0, 1.0], [1.1, 1.1]]), + random_state=0, + ) + + assert optimal == 3 + assert evaluated == [1, 2, 3] + + def test_psi_tree_builder_ranks_all_leaf_pairs_by_original_cosine_similarity(raptor_module): raptor = _make_raptor(raptor_module) leaves = [