fix(rag/raptor): handle max_cluster edge case in GMM cluster selection (#16199)

### What problem does this PR solve?
`_get_optimal_clusters` in `rag/raptor.py` had two edge-case issues in
GMM cluster-count selection:
1. It used `np.arange(1, max_clusters)`, which never evaluates the
upper-bound candidate (`max_clusters`).
2. When effective `max_clusters` becomes `1`, the candidate list was
empty and `argmin` crashed.

This PR makes candidate evaluation inclusive (`1..max_clusters`) and
guards the single-cluster case by returning `1` directly.

### Type of change
- [x] Bug Fix (non-breaking change which fixes an issue)

### Validation
- `pytest test/unit_test/rag/test_raptor_psi_tree_builder.py
--config-file pyproject.toml -q`
- `ruff check rag/raptor.py
test/unit_test/rag/test_raptor_psi_tree_builder.py`

### Tests added
- Regression test for `max_cluster == 1` path (no crash, returns 1)
- Regression test verifying upper-bound candidate is evaluated and can
be selected

_AI-assistance disclosure: parts of this change (bug triage and test
scaffolding) were drafted with AI assistance and fully reviewed and
verified by me._

---------

Co-authored-by: Harsh Kashyap <harshkashyap@Harshs-MacBook-Pro.local>
Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
Harsh Kashyap
2026-06-23 06:07:26 -07:00
committed by GitHub
parent 706e0d2d06
commit b4a8a90c73
2 changed files with 49 additions and 2 deletions

View File

@@ -235,7 +235,14 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
def _get_optimal_clusters(self, embeddings: np.ndarray, random_state: int, task_id: str = ""):
"""Choose the GMM cluster count with the lowest BIC score."""
max_clusters = min(self._max_cluster, len(embeddings))
n_clusters = np.arange(1, max_clusters)
if max_clusters <= 1:
logging.info(
"RAPTOR GMM: _get_optimal_clusters returning 1 (max_clusters=%s, embeddings=%d)",
max_clusters,
len(embeddings),
)
return 1
n_clusters = np.arange(1, max_clusters + 1)
bics = []
for n in n_clusters:
self._check_task_canceled(task_id, "get optimal clusters")
@@ -244,7 +251,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
gm.fit(embeddings)
bics.append(gm.bic(embeddings))
optimal_clusters = n_clusters[np.argmin(bics)]
return optimal_clusters
return int(optimal_clusters)
def _get_clusters_ahc(self, embeddings: np.ndarray, task_id: str = "") -> np.ndarray:
"""Cluster embeddings with Ward-linkage AHC and a dendrogram gap heuristic."""