mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 15:31:05 +08:00
fix(rag/raptor): handle max_cluster edge case in GMM cluster selection (#16199)
### What problem does this PR solve?

`_get_optimal_clusters` in `rag/raptor.py` had two edge-case issues in GMM cluster-count selection:

1. It used `np.arange(1, max_clusters)`, which never evaluates the upper-bound candidate (`max_clusters`).
2. When the effective `max_clusters` became `1`, the candidate list was empty and `argmin` crashed.

This PR makes candidate evaluation inclusive (`1..max_clusters`) and guards the single-cluster case by returning `1` directly.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

### Validation

- `pytest test/unit_test/rag/test_raptor_psi_tree_builder.py --config-file pyproject.toml -q`
- `ruff check rag/raptor.py test/unit_test/rag/test_raptor_psi_tree_builder.py`

### Tests added

- Regression test for `max_cluster == 1` path (no crash, returns 1)
- Regression test verifying upper-bound candidate is evaluated and can be selected

_AI-assistance disclosure: parts of this change (bug triage and test scaffolding) were drafted with AI assistance and fully reviewed and verified by me._

---------

Co-authored-by: Harsh Kashyap <harshkashyap@Harshs-MacBook-Pro.local>
Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -235,7 +235,14 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
|
||||
def _get_optimal_clusters(self, embeddings: np.ndarray, random_state: int, task_id: str = ""):
|
||||
"""Choose the GMM cluster count with the lowest BIC score."""
|
||||
max_clusters = min(self._max_cluster, len(embeddings))
|
||||
n_clusters = np.arange(1, max_clusters)
|
||||
if max_clusters <= 1:
|
||||
logging.info(
|
||||
"RAPTOR GMM: _get_optimal_clusters returning 1 (max_clusters=%s, embeddings=%d)",
|
||||
max_clusters,
|
||||
len(embeddings),
|
||||
)
|
||||
return 1
|
||||
n_clusters = np.arange(1, max_clusters + 1)
|
||||
bics = []
|
||||
for n in n_clusters:
|
||||
self._check_task_canceled(task_id, "get optimal clusters")
|
||||
@@ -244,7 +251,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
|
||||
gm.fit(embeddings)
|
||||
bics.append(gm.bic(embeddings))
|
||||
optimal_clusters = n_clusters[np.argmin(bics)]
|
||||
return optimal_clusters
|
||||
return int(optimal_clusters)
|
||||
|
||||
def _get_clusters_ahc(self, embeddings: np.ndarray, task_id: str = "") -> np.ndarray:
|
||||
"""Cluster embeddings with Ward-linkage AHC and a dendrogram gap heuristic."""
|
||||
|
||||
Reference in New Issue
Block a user