feat(raptor): add Psi tree builder with original-space ranking and safe migration (#14679)

### What problem does this PR solve?

Closes #14674.

This PR improves RAPTOR configuration and tree construction while
preserving the existing RAPTOR behavior as the default.

RAPTOR currently builds summary layers with the original UMAP + GMM
clustering path. This PR keeps that default path, and adds:

- A hidden backend tree-builder option:
  - `tree_builder="raptor"`: default, existing RAPTOR behavior.
- `tree_builder="psi"`: rank-aware Psi-style tree builder using original
embedding-space cosine ranking.
- A user-facing clustering method option for the default RAPTOR builder:
  - `clustering_method="gmm"`: existing default.
- `clustering_method="ahc"`: agglomerative hierarchical clustering path.
- A RAPTOR UI setting for `Clustering method` and `Max cluster`.

### What changed

#### Backend

- Added `tree_builder` support for RAPTOR/Psi.
- Added `clustering_method` support for GMM/AHC.
- Kept existing RAPTOR + GMM as the default.
- Added Psi tree building from original-space cosine similarity.
- Added bucketed Psi building controls for large inputs:
  - `raptor.ext.psi_exact_max_leaves`
  - `raptor.ext.psi_bucket_size`
- Added method-aware RAPTOR summary metadata using existing
`extra.raptor_method`.
- Avoided adding a dedicated DB schema field for experimental method
tracking.
- Added cleanup/migration logic to avoid mixing stale RAPTOR summary
trees.
- Added defensive checks for Psi tree construction and summary failures.

#### Frontend/UI

- Added `Clustering method` in RAPTOR settings with `GMM` and `AHC`.
- Added/kept `Max cluster` in RAPTOR settings.
- Enlarged max cluster UI limit to `1024`, matching backend validation.
- Kept AHC editable even when a RAPTOR task has already finished.
- Fixed the UI save payload so `clustering_method` and `tree_builder`
are serialized through `parser_config.raptor.ext`, avoiding backend
validation errors for extra top-level RAPTOR fields.

Example saved RAPTOR config:

```json
{
  "raptor": {
    "max_cluster": 317,
    "ext": {
      "clustering_method": "ahc",
      "tree_builder": "raptor"
    }
  }
}

Co-authored-by: CaptainTimon <CaptainTimon@users.noreply.github.com>
This commit is contained in:
CaptainTimon
2026-05-11 15:42:31 -10:00
committed by GitHub
parent 415169d497
commit 2717ee283f
21 changed files with 1722 additions and 140 deletions

View File

@@ -14,11 +14,13 @@
# limitations under the License.
#
import asyncio
from dataclasses import dataclass, field
import logging
import re
import numpy as np
import umap
from sklearn.cluster import AgglomerativeClustering
from sklearn.mixture import GaussianMixture
from api.db.services.task_service import has_canceled
@@ -33,9 +35,127 @@ from rag.graphrag.utils import (
set_llm_cache,
)
from common.misc_utils import thread_pool_exec
from rag.utils.raptor_utils import (
AHC_CLUSTERING_METHOD,
GMM_CLUSTERING_METHOD,
PSI_TREE_BUILDER,
RAPTOR_TREE_BUILDER,
SUPPORTED_CLUSTERING_METHODS,
SUPPORTED_TREE_BUILDERS,
)
@dataclass
class _PsiTreeNode:
"""Node used to represent the in-memory Psi merge tree."""
index: int
text: str = ""
embedding: np.ndarray | None = None
children: list["_PsiTreeNode"] = field(default_factory=list)
parent: "_PsiTreeNode | None" = None
class _PsiUnionFind:
"""Build parent links for the Psi merge tree from ranked leaf pairs."""
def __init__(self, n: int):
"""Initialize the union-find state for n leaf nodes."""
self._rank = [0 for _ in range(n)]
self._parent_chains = [[] for _ in range(n)]
self._node_ids = [[i] for i in range(n)]
self._tree = [-1 for _ in range(max(1, 2 * n - 1))]
self._next_id = n
@staticmethod
def _ordered_extend(target: list[int], values: list[int]):
"""Append unseen values while preserving their original order."""
for value in values:
if value not in target:
target.append(value)
def _find(self, i: int) -> list[int]:
"""Return the parent chain for a leaf, extending it lazily."""
chain = self._parent_chains[i]
if not chain or (len(chain) == 1 and chain[0] == i):
return [i]
if chain[0] == i:
self._ordered_extend(chain, self._find(chain[1]))
else:
self._ordered_extend(chain, self._find(chain[0]))
return chain
def _rank_bisect_right(self, chain: list[int], rank: int) -> int:
"""Return the first chain index whose rank is greater than rank."""
idx = 0
while idx < len(chain) and self._rank[chain[idx]] <= rank:
idx += 1
return idx
def _build(self, i: int, j: int, insert_point: int | None = None):
"""Record a merge edge in the compact parent array."""
if insert_point is not None:
parent_ids = self._node_ids[insert_point]
parent_rank_idx = self._rank[i] + 1
if parent_rank_idx >= len(parent_ids):
logging.warning(
"RAPTOR Psi union fallback: rank index %d is out of bounds for node %d with %d parent ids",
parent_rank_idx,
insert_point,
len(parent_ids),
)
parent_rank_idx = len(parent_ids) - 1
self._tree[self._node_ids[i][-1]] = parent_ids[parent_rank_idx]
return
self._tree[self._node_ids[i][-1]] = self._next_id
self._tree[self._node_ids[j][-1]] = self._next_id
self._node_ids[i].append(self._next_id)
self._next_id += 1
def union(self, i: int, j: int) -> bool:
"""Merge two ranked leaves and return whether a new edge was added."""
root_i = self._find(i)[-1]
root_j = self._find(j)[-1]
if root_i == root_j:
return False
if self._rank[root_i] < self._rank[root_j]:
if not self._parent_chains[root_j]:
self._parent_chains[root_j].append(root_j)
chain = self._parent_chains[j]
higher_rank_idx = self._rank_bisect_right(chain, self._rank[root_i])
if higher_rank_idx >= len(chain):
higher_rank_idx = len(chain) - 1
insert_point = chain[higher_rank_idx]
self._ordered_extend(self._parent_chains[root_i], chain[higher_rank_idx:])
self._build(root_i, root_j, insert_point=insert_point)
elif self._rank[root_i] > self._rank[root_j]:
if not self._parent_chains[root_i]:
self._parent_chains[root_i].append(root_i)
chain = self._parent_chains[i]
higher_rank_idx = self._rank_bisect_right(chain, self._rank[root_j])
if higher_rank_idx >= len(chain):
higher_rank_idx = len(chain) - 1
insert_point = chain[higher_rank_idx]
self._ordered_extend(self._parent_chains[root_j], chain[higher_rank_idx:])
self._build(root_j, root_i, insert_point=insert_point)
else:
if not self._parent_chains[root_i]:
self._parent_chains[root_i].append(root_i)
self._ordered_extend(self._parent_chains[root_j], self._parent_chains[i][-1:])
self._rank[root_i] += 1
self._build(root_i, root_j)
return True
@property
def tree(self) -> list[int]:
"""Return the compact child-to-parent array for constructed nodes."""
return self._tree[:self._next_id]
class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
"""Build RAPTOR summary layers with the classic or Psi tree strategy."""
def __init__(
self,
max_cluster,
@@ -45,7 +165,12 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
max_token=512,
threshold=0.1,
max_errors=3,
tree_builder=RAPTOR_TREE_BUILDER,
clustering_method=GMM_CLUSTERING_METHOD,
psi_exact_max_leaves=4096,
psi_bucket_size=1024,
):
"""Configure RAPTOR summarization, clustering, and Psi limits."""
self._max_cluster = max_cluster
self._llm_model = llm_model
self._embd_model = embd_model
@@ -54,8 +179,17 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
self._max_token = max_token
self._max_errors = max(1, max_errors)
self._error_count = 0
self._tree_builder = tree_builder or RAPTOR_TREE_BUILDER
if self._tree_builder not in SUPPORTED_TREE_BUILDERS:
raise ValueError(f"Unsupported RAPTOR tree builder: {self._tree_builder}")
self._clustering_method = clustering_method or GMM_CLUSTERING_METHOD
if self._clustering_method not in SUPPORTED_CLUSTERING_METHODS:
raise ValueError(f"Unsupported RAPTOR clustering method: {self._clustering_method}")
self._psi_exact_max_leaves = max(2, int(psi_exact_max_leaves or 4096))
self._psi_bucket_size = min(max(2, int(psi_bucket_size or 1024)), self._psi_exact_max_leaves)
def _check_task_canceled(self, task_id: str, message: str = ""):
"""Raise if the current document task was canceled."""
if task_id and has_canceled(task_id):
log_msg = f"Task {task_id} cancelled during RAPTOR {message}."
logging.info(log_msg)
@@ -63,6 +197,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
@timeout(60 * 20)
async def _chat(self, system, history, gen_conf):
"""Call the configured LLM with caching and short retries."""
cached = await thread_pool_exec(get_llm_cache, self._llm_model.llm_name, system, history, gen_conf)
if cached:
return cached
@@ -86,6 +221,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
@timeout(20)
async def _embedding_encode(self, txt):
"""Encode text with the configured embedding model and cache result."""
response = await thread_pool_exec(get_embed_cache, self._embd_model.llm_name, txt)
if response is not None:
return response
@@ -97,6 +233,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
return embds
def _get_optimal_clusters(self, embeddings: np.ndarray, random_state: int, task_id: str = ""):
"""Choose the GMM cluster count with the lowest BIC score."""
max_clusters = min(self._max_cluster, len(embeddings))
n_clusters = np.arange(1, max_clusters)
bics = []
@@ -109,57 +246,422 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
optimal_clusters = n_clusters[np.argmin(bics)]
return optimal_clusters
def _get_clusters_ahc(self, embeddings: np.ndarray, task_id: str = "") -> np.ndarray:
"""Cluster embeddings with Ward-linkage AHC and a dendrogram gap heuristic."""
n = len(embeddings)
if n <= 1:
return np.zeros(n, dtype=int)
if n == 2:
return np.arange(n)
self._check_task_canceled(task_id, "_get_clusters_ahc dendrogram")
full_clust = AgglomerativeClustering(
n_clusters=None,
distance_threshold=0,
compute_distances=True,
linkage="ward",
)
full_clust.fit(embeddings)
distances = full_clust.distances_
if len(distances) > 1:
gaps = np.diff(distances)
max_gap_idx = int(np.argmax(gaps))
n_clusters = max(1, min(n - max_gap_idx - 1, self._max_cluster))
else:
n_clusters = max(1, min(n, self._max_cluster))
if n_clusters <= 1:
logging.info("RAPTOR AHC: _get_clusters_ahc selected one cluster for %d embeddings", n)
return np.zeros(n, dtype=int)
logging.info("RAPTOR AHC: _get_clusters_ahc selected n_clusters=%d for %d embeddings", n_clusters, n)
self._check_task_canceled(task_id, "_get_clusters_ahc fit")
clustering = AgglomerativeClustering(n_clusters=n_clusters, linkage="ward")
return clustering.fit_predict(embeddings)
def _adjust_tree_nodes(self, embeddings: np.ndarray, labels: np.ndarray, max_iter: int = 5) -> np.ndarray:
"""Refine AHC assignments by reassigning nodes to nearest centroids."""
labels = labels.copy()
for _ in range(max_iter):
unique_labels = np.unique(labels)
if len(unique_labels) <= 1:
return labels
centroids = np.stack([embeddings[labels == lbl].mean(axis=0) for lbl in unique_labels])
diffs = embeddings[:, np.newaxis, :] - centroids[np.newaxis, :, :]
sq_dists = (diffs**2).sum(axis=2)
new_label_indices = np.argmin(sq_dists, axis=1)
new_labels = unique_labels[new_label_indices]
if np.array_equal(new_labels, labels):
break
unique_new = np.unique(new_labels)
remap = {old: new for new, old in enumerate(unique_new)}
labels = np.array([remap[int(lbl)] for lbl in new_labels])
return labels
@timeout(60 * 20)
async def _summarize_texts(self, texts: list[str], callback=None, task_id: str = ""):
"""Summarize a cluster and return text plus embedding when successful."""
self._check_task_canceled(task_id, "summarization")
len_per_chunk = int((self._llm_model.max_length - self._max_token) / len(texts))
cluster_content = "\n".join([truncate(t, max(1, len_per_chunk)) for t in texts])
try:
async with chat_limiter:
self._check_task_canceled(task_id, "before LLM call")
cnt = await self._chat(
"You're a helpful assistant.",
[
{
"role": "user",
"content": self._prompt.format(cluster_content=cluster_content),
}
],
{"max_tokens": max(self._max_token, 512)}, # fix issue: #10235
)
cnt = re.sub(
"(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)",
"",
cnt,
)
logging.debug(f"SUM: {cnt}")
self._check_task_canceled(task_id, "before embedding")
embds = await self._embedding_encode(cnt)
return cnt, embds
except TaskCanceledException:
raise
except Exception as exc:
self._error_count += 1
warn_msg = f"[RAPTOR] Skip cluster ({len(texts)} chunks) due to error: {exc}"
logging.warning(warn_msg)
if callback:
callback(msg=warn_msg)
if self._error_count >= self._max_errors:
raise RuntimeError(f"RAPTOR aborted after {self._error_count} errors. Last error: {exc}") from exc
return None
@staticmethod
def _root(node: _PsiTreeNode) -> _PsiTreeNode:
"""Return the current root for a Psi tree node."""
while node.parent is not None:
node = node.parent
return node
def _rank_leaf_pairs(self, leaves: list[_PsiTreeNode]) -> np.ndarray:
"""Rank all leaf pairs by original embedding-space cosine similarity."""
node_embeddings = np.asarray([leaf.embedding for leaf in leaves], dtype=np.float64)
node_embeddings = self._normalize_embeddings(node_embeddings)
similarities = node_embeddings @ node_embeddings.T
lower = np.tril_indices(len(leaves), -1)
ordered = np.argsort(similarities[lower], axis=0)[::-1]
return np.stack([lower[0][ordered], lower[1][ordered]], axis=-1)
@staticmethod
def _normalize_embeddings(node_embeddings: np.ndarray) -> np.ndarray:
"""Normalize embeddings for cosine operations while tolerating zero vectors."""
node_embeddings = np.asarray(node_embeddings, dtype=np.float64)
norms = np.linalg.norm(node_embeddings, axis=1, keepdims=True)
return node_embeddings / np.maximum(norms, 1e-12)
def _split_psi_buckets(self, nodes: list[_PsiTreeNode]) -> list[list[_PsiTreeNode]]:
"""Split large Psi inputs so exact pair ranking is bounded per bucket."""
if len(nodes) <= self._psi_bucket_size:
return [nodes]
node_embeddings = self._normalize_embeddings(np.asarray([node.embedding for node in nodes], dtype=np.float64))
groups = [np.arange(len(nodes), dtype=int)]
buckets = []
while groups:
group = np.asarray(groups.pop(), dtype=int)
if len(group) <= self._psi_bucket_size:
buckets.append(group.tolist())
continue
fanout = min(max(2, int(np.ceil(len(group) / self._psi_bucket_size))), len(group), 32)
group_embeddings = node_embeddings[group]
center_idx = np.linspace(0, len(group_embeddings) - 1, num=fanout, dtype=int)
centers = group_embeddings[center_idx].copy()
for _ in range(5):
labels = np.argmax(group_embeddings @ centers.T, axis=1)
for center_id in range(fanout):
mask = labels == center_id
if not np.any(mask):
continue
center = group_embeddings[mask].mean(axis=0)
norm = np.linalg.norm(center)
centers[center_id] = center / norm if norm > 0 else center
labels = np.argmax(group_embeddings @ centers.T, axis=1)
split_groups = [group[labels == center_id].tolist() for center_id in range(fanout)]
split_groups = [bucket for bucket in split_groups if bucket]
if len(split_groups) <= 1:
split_groups = [
group[start:start + self._psi_bucket_size].tolist()
for start in range(0, len(group), self._psi_bucket_size)
]
groups.extend(split_groups)
buckets = [bucket for bucket in buckets if bucket]
buckets.sort(key=lambda bucket: (len(bucket), bucket[0]))
return [[nodes[idx] for idx in bucket] for bucket in buckets]
def _assign_prototype_embeddings(self, node: _PsiTreeNode) -> np.ndarray:
"""Assign mean child embeddings to internal Psi nodes for bucket-level ranking."""
if not node.children:
return np.asarray(node.embedding, dtype=np.float64)
embeddings = np.asarray([self._assign_prototype_embeddings(child) for child in node.children], dtype=np.float64)
node.embedding = embeddings.mean(axis=0)
return node.embedding
@staticmethod
def _iter_nodes(root: _PsiTreeNode):
"""Yield nodes in a Psi tree using a stack traversal."""
stack = [root]
while stack:
node = stack.pop()
yield node
stack.extend(node.children)
def _create_psi_parent(self, index: int, children: list[_PsiTreeNode]) -> _PsiTreeNode:
"""Create a parent node and attach the provided children to it."""
parent = _PsiTreeNode(index=index, children=children)
for child in children:
child.parent = parent
return parent
def _rebalance_psi_tree(self, root: _PsiTreeNode, next_index: int) -> tuple[_PsiTreeNode, int]:
"""Group oversized Psi tree nodes so fanout stays within max_cluster."""
max_children = max(2, int(self._max_cluster or 2))
def rebalance(node: _PsiTreeNode):
"""Recursively group children when a Psi node exceeds fanout."""
nonlocal next_index
for child in list(node.children):
rebalance(child)
while len(node.children) > max_children:
original_children = len(node.children)
grouped_children = []
for start in range(0, len(node.children), max_children):
batch = node.children[start:start + max_children]
if len(batch) == 1:
grouped_children.append(batch[0])
batch[0].parent = node
else:
grouped_children.append(self._create_psi_parent(next_index, batch))
grouped_children[-1].parent = node
next_index += 1
node.children = grouped_children
logging.info(
"RAPTOR Psi rebalance: node=%s children=%d grouped_to=%d max_cluster=%d",
node.index,
original_children,
len(grouped_children),
max_children,
)
rebalance(root)
return self._root(root), next_index
def _build_exact_psi_structure(
self,
nodes: list[_PsiTreeNode],
next_index: int,
task_id: str = "",
) -> tuple[_PsiTreeNode, int, int]:
"""Build an exact Psi subtree for a bounded node set."""
if len(nodes) == 1:
return nodes[0], next_index, 0
ranked_pairs = self._rank_leaf_pairs(nodes)
union_find = _PsiUnionFind(len(nodes))
merges = 0
for left_idx, right_idx in ranked_pairs:
self._check_task_canceled(task_id, "Psi tree construction")
if union_find.union(int(left_idx), int(right_idx)):
merges += 1
if merges == len(nodes) - 1:
break
local_nodes = {idx: node for idx, node in enumerate(nodes)}
tree = union_find.tree
children_by_parent = {}
for child_idx, parent_idx in enumerate(tree):
if child_idx not in local_nodes:
local_nodes[child_idx] = _PsiTreeNode(index=next_index)
next_index += 1
if parent_idx == -1:
continue
children_by_parent.setdefault(parent_idx, []).append(child_idx)
if parent_idx not in local_nodes:
local_nodes[parent_idx] = _PsiTreeNode(index=next_index)
next_index += 1
for parent_idx, child_indices in children_by_parent.items():
parent = local_nodes[parent_idx]
parent.children = [local_nodes[child_idx] for child_idx in child_indices]
for child in parent.children:
child.parent = parent
roots = [local_nodes[idx] for idx, parent_idx in enumerate(tree) if parent_idx == -1 and idx in local_nodes]
root = max(roots, key=lambda node: node.index)
return root, next_index, merges
def _build_bucketed_psi_structure(
self,
nodes: list[_PsiTreeNode],
next_index: int,
task_id: str = "",
) -> tuple[_PsiTreeNode, int, int]:
"""Build large Psi trees by exact-ranking bounded buckets, then bucket roots."""
buckets = self._split_psi_buckets(nodes)
logging.info(
"RAPTOR Psi bucketed build: nodes=%d buckets=%d bucket_size=%d exact_max_leaves=%d",
len(nodes),
len(buckets),
self._psi_bucket_size,
self._psi_exact_max_leaves,
)
bucket_roots = []
merges = 0
for bucket in buckets:
bucket_root, next_index, bucket_merges = self._build_psi_structure_from_nodes(bucket, next_index, task_id)
self._assign_prototype_embeddings(bucket_root)
bucket_roots.append(bucket_root)
merges += bucket_merges
if len(bucket_roots) == 1:
return bucket_roots[0], next_index, merges
root, next_index, root_merges = self._build_psi_structure_from_nodes(bucket_roots, next_index, task_id)
return root, next_index, merges + root_merges
def _build_psi_structure_from_nodes(
self,
nodes: list[_PsiTreeNode],
next_index: int,
task_id: str = "",
) -> tuple[_PsiTreeNode, int, int]:
"""Build Psi structure exactly for small sets and bucket large sets."""
if len(nodes) <= self._psi_exact_max_leaves:
return self._build_exact_psi_structure(nodes, next_index, task_id)
return self._build_bucketed_psi_structure(nodes, next_index, task_id)
def _build_psi_structure(self, chunks, task_id: str = "") -> tuple[_PsiTreeNode, list[_PsiTreeNode]]:
"""Build the Psi merge tree from original chunk embeddings."""
leaves = [
_PsiTreeNode(index=i, text=text, embedding=np.asarray(embd))
for i, (text, embd) in enumerate(chunks)
]
if len(leaves) == 1:
return leaves[0], leaves
root, next_index, merges = self._build_psi_structure_from_nodes(leaves, len(leaves), task_id)
root, _ = self._rebalance_psi_tree(root, next_index)
logging.info(
"RAPTOR Psi tree built: leaves=%d merges=%d root_fanout=%d",
len(leaves),
merges,
len(root.children),
)
return root, leaves
@staticmethod
def _psi_layers(root: _PsiTreeNode) -> dict[int, list[_PsiTreeNode]]:
"""Collect non-leaf Psi nodes by height for bottom-up summarization."""
layers = {}
def height(node: _PsiTreeNode) -> int:
"""Return node height while collecting internal nodes by layer."""
if not node.children:
return 0
node_height = max(height(child) for child in node.children) + 1
layers.setdefault(node_height, []).append(node)
return node_height
height(root)
return layers
async def _build_psi_layers(self, chunks, callback=None, task_id: str = ""):
"""Materialize Psi tree layers as summary chunks."""
layers = [(0, len(chunks))]
root, _ = self._build_psi_structure(chunks, task_id=task_id)
for layer_idx, (_, nodes) in enumerate(sorted(self._psi_layers(root).items()), start=1):
layer_start = len(chunks)
async def summarize_node(node: _PsiTreeNode):
"""Summarize one Psi internal node if its children have text."""
texts = [child.text for child in node.children if child.text]
if not texts:
logging.warning("RAPTOR Psi node %s skipped because it has no child text to summarize", node.index)
return None
result = await self._summarize_texts(texts, callback, task_id)
if result is None:
logging.warning("RAPTOR Psi node %s skipped because summarization failed", node.index)
return None
node.text, node.embedding = result
return node
tasks = [asyncio.create_task(summarize_node(node)) for node in nodes]
try:
summarized_nodes = await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:
logging.error(f"Error in RAPTOR Psi tree processing: {e}")
for task in tasks:
task.cancel()
await asyncio.gather(*tasks, return_exceptions=True)
raise
summarized_nodes = [node for node in summarized_nodes if node is not None]
for node in summarized_nodes:
chunks.append((node.text, node.embedding))
if len(chunks) > layer_start:
layers.append((layer_start, len(chunks)))
logging.info(
"RAPTOR Psi layer materialized: layer=%d nodes=%d summaries=%d",
layer_idx,
len(nodes),
len(chunks) - layer_start,
)
if callback:
callback(msg="Build one Psi-RAG layer: {} -> {}".format(len(nodes), len(chunks) - layer_start))
else:
logging.warning("RAPTOR Psi layer %d produced no summaries; stopping materialization", layer_idx)
break
return chunks, layers
async def __call__(self, chunks, random_state, callback=None, task_id: str = ""):
"""Build summary chunks and layer boundaries for RAPTOR retrieval."""
if len(chunks) <= 1:
return [], []
chunks = [(s, a) for s, a in chunks if s and a is not None and len(a) > 0]
if len(chunks) <= 1:
return chunks, [(0, len(chunks))]
if self._tree_builder == PSI_TREE_BUILDER:
logging.info("RAPTOR: using %s tree builder for %d chunks", self._tree_builder, len(chunks))
return await self._build_psi_layers(chunks, callback, task_id)
layers = [(0, len(chunks))]
start, end = 0, len(chunks)
@timeout(60 * 20)
async def summarize(ck_idx: list[int]):
"""Summarize one classic RAPTOR cluster into the chunk list."""
nonlocal chunks
self._check_task_canceled(task_id, "summarization")
texts = [chunks[i][0] for i in ck_idx]
len_per_chunk = int((self._llm_model.max_length - self._max_token) / len(texts))
cluster_content = "\n".join([truncate(t, max(1, len_per_chunk)) for t in texts])
try:
async with chat_limiter:
self._check_task_canceled(task_id, "before LLM call")
cnt = await self._chat(
"You're a helpful assistant.",
[
{
"role": "user",
"content": self._prompt.format(cluster_content=cluster_content),
}
],
{"max_tokens": max(self._max_token, 512)}, # fix issue: #10235
)
cnt = re.sub(
"(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)",
"",
cnt,
)
logging.debug(f"SUM: {cnt}")
self._check_task_canceled(task_id, "before embedding")
embds = await self._embedding_encode(cnt)
chunks.append((cnt, embds))
except TaskCanceledException:
raise
except Exception as exc:
self._error_count += 1
warn_msg = f"[RAPTOR] Skip cluster ({len(ck_idx)} chunks) due to error: {exc}"
logging.warning(warn_msg)
if callback:
callback(msg=warn_msg)
if self._error_count >= self._max_errors:
raise RuntimeError(f"RAPTOR aborted after {self._error_count} errors. Last error: {exc}") from exc
result = await self._summarize_texts(texts, callback, task_id)
if result is not None:
chunks.append(result)
while end - start > 1:
self._check_task_canceled(task_id, "layer processing")
@@ -167,8 +669,12 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
embeddings = [embd for _, embd in chunks[start:end]]
if len(embeddings) == 2:
await summarize([start, start + 1])
produced = len(chunks) - end
if produced == 0:
logging.warning("RAPTOR layer produced no summaries; stopping materialization")
break
if callback:
callback(msg="Cluster one layer: {} -> {}".format(end - start, len(chunks) - end))
callback(msg="Cluster one layer: {} -> {}".format(end - start, produced))
layers.append((end, len(chunks)))
start = end
end = len(chunks)
@@ -180,15 +686,37 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
n_components=min(12, len(embeddings) - 2),
metric="cosine",
).fit_transform(embeddings)
n_clusters = self._get_optimal_clusters(reduced_embeddings, random_state, task_id=task_id)
if self._clustering_method == AHC_CLUSTERING_METHOD:
logging.info("RAPTOR: using clustering_method=%s before _get_clusters_ahc", self._clustering_method)
raw_labels = self._get_clusters_ahc(reduced_embeddings, task_id=task_id)
raw_cluster_count = np.unique(raw_labels).size
logging.info("RAPTOR AHC: _get_clusters_ahc produced n_clusters=%d", raw_cluster_count)
if raw_cluster_count > 1:
adjusted = self._adjust_tree_nodes(reduced_embeddings, raw_labels)
adjusted_cluster_count = np.unique(adjusted).size
logging.info("RAPTOR AHC: _adjust_tree_nodes adjusted n_clusters=%d", adjusted_cluster_count)
else:
adjusted = raw_labels
logging.warning("RAPTOR AHC: _adjust_tree_nodes skipped because _get_clusters_ahc returned one cluster")
unique_labels = np.unique(adjusted)
label_map = {old: idx for idx, old in enumerate(unique_labels)}
lbls = [label_map[int(lbl)] for lbl in adjusted]
n_clusters = len(unique_labels)
else:
n_clusters = self._get_optimal_clusters(reduced_embeddings, random_state, task_id=task_id)
if n_clusters == 1:
lbls = [0 for _ in range(len(reduced_embeddings))]
else:
gm = GaussianMixture(n_components=n_clusters, random_state=random_state)
gm.fit(reduced_embeddings)
probs = gm.predict_proba(reduced_embeddings)
lbls = [np.where(prob > self._threshold)[0] for prob in probs]
lbls = [lbl[0] if isinstance(lbl, np.ndarray) else lbl for lbl in lbls]
if n_clusters == 1:
lbls = [0 for _ in range(len(reduced_embeddings))]
else:
gm = GaussianMixture(n_components=n_clusters, random_state=random_state)
gm.fit(reduced_embeddings)
probs = gm.predict_proba(reduced_embeddings)
lbls = [np.where(prob > self._threshold)[0] for prob in probs]
lbls = [lbl[0] if isinstance(lbl, np.ndarray) else lbl for lbl in lbls]
lbls = [int(lbl[0]) if isinstance(lbl, np.ndarray) else int(lbl) for lbl in lbls]
tasks = []
for c in range(n_clusters):
@@ -205,10 +733,21 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
await asyncio.gather(*tasks, return_exceptions=True)
raise
assert len(chunks) - end == n_clusters, "{} vs. {}".format(len(chunks) - end, n_clusters)
produced = len(chunks) - end
assert produced <= n_clusters, "{} vs. {}".format(produced, n_clusters)
if produced < n_clusters:
logging.warning(
"RAPTOR layer produced %d/%d cluster summaries; skipped %d cluster(s) due to errors",
produced,
n_clusters,
n_clusters - produced,
)
if produced == 0:
logging.warning("RAPTOR layer produced no summaries; stopping materialization")
break
layers.append((end, len(chunks)))
if callback:
callback(msg="Cluster one layer: {} -> {}".format(end - start, len(chunks) - end))
callback(msg="Cluster one layer: {} -> {}".format(end - start, produced))
start = end
end = len(chunks)

View File

@@ -36,7 +36,15 @@ from api.db.joint_services.memory_message_service import handle_save_to_memory_t
from common.connection_utils import timeout
from common.metadata_utils import turn2jsonschema, update_metadata_to
from rag.utils.base64_image import image2id
from rag.utils.raptor_utils import should_skip_raptor, get_skip_reason
from rag.utils.raptor_utils import (
collect_raptor_chunk_ids,
collect_raptor_methods,
get_raptor_clustering_method,
get_raptor_tree_builder,
get_skip_reason,
make_raptor_summary_chunk_id,
should_skip_raptor,
)
from common.log_utils import init_root_logger
from common.config_utils import show_configs
from rag.graphrag.general.index import run_graphrag_for_kb
@@ -70,7 +78,10 @@ from api.db.db_models import close_connection
from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, \
email, tag
from rag.nlp import search, rag_tokenizer, add_positions
from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
from rag.raptor import (
RAPTOR_TREE_BUILDER,
RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor,
)
from common.token_utils import num_tokens_from_string, truncate
from rag.utils.redis_conn import REDIS_CONN, RedisDistributedLock
from rag.graphrag.utils import chat_limiter
@@ -817,61 +828,160 @@ async def run_dataflow(task: dict):
dsl=str(pipeline))
async def has_raptor_chunks(doc_id: str, tenant_id: str, kb_id: str) -> bool:
"""Return True if RAPTOR chunks already exist for doc_id in the doc store.
RAPTOR_METHOD_SEARCH_LIMIT = 10000
Queries directly for raptor_kwd="raptor" rows so a non-RAPTOR leading
chunk cannot produce a false-negative result. Uses thread_pool_exec so
the blocking doc-store call does not stall the event loop.
"""
async def get_raptor_chunk_field_map(doc_id: str, tenant_id: str, kb_id: str) -> dict:
"""Return stored RAPTOR marker fields for a document."""
from common.doc_store.doc_store_base import OrderByExpr
from rag.nlp import search as nlp_search
try:
condition = {"doc_id": doc_id, "raptor_kwd": ["raptor"]}
async def search_fields(fields: list[str], condition: dict, order_by=None):
"""Search chunk fields in the current knowledge base."""
res = await thread_pool_exec(
settings.docStoreConn.search,
["raptor_kwd"], [], condition, [], OrderByExpr(),
0, 1, nlp_search.index_name(tenant_id), [kb_id]
fields, [], condition, [], order_by or OrderByExpr(),
0, RAPTOR_METHOD_SEARCH_LIMIT, nlp_search.index_name(tenant_id), [kb_id]
)
field_map = settings.docStoreConn.get_fields(res, ["raptor_kwd"])
found = bool(field_map)
if found:
return settings.docStoreConn.get_fields(res, fields)
primary = await search_fields(["raptor_kwd", "extra"], {"doc_id": doc_id, "raptor_kwd": ["raptor"]})
if collect_raptor_chunk_ids(primary):
return primary
try:
return await search_fields(
["raptor_kwd", "extra"],
{"doc_id": doc_id},
OrderByExpr().desc("create_timestamp_flt"),
)
except Exception:
logging.debug("RAPTOR fallback method lookup with extra field failed for doc %s", doc_id, exc_info=True)
return primary
async def get_raptor_chunk_methods(doc_id: str, tenant_id: str, kb_id: str) -> set[str]:
"""Return the RAPTOR tree builders already stored for doc_id.
Queries directly for raptor_kwd="raptor" rows so a non-RAPTOR leading
chunk cannot produce a false-negative result. Legacy summary chunks that
do not have method metadata are treated as the original RAPTOR builder.
"""
try:
field_map = await get_raptor_chunk_field_map(doc_id, tenant_id, kb_id)
methods = collect_raptor_methods(field_map)
if methods:
logging.info(
"Checkpoint hit: RAPTOR chunks for doc %s (tenant=%s kb=%s) already exist",
doc_id, tenant_id, kb_id,
"Checkpoint hit: RAPTOR chunks for doc %s (tenant=%s kb=%s methods=%s) already exist",
doc_id, tenant_id, kb_id, sorted(methods),
)
else:
logging.info(
"Checkpoint miss: no RAPTOR chunks for doc %s (tenant=%s kb=%s)",
doc_id, tenant_id, kb_id,
)
return found
return methods
except Exception:
logging.exception("Failed to check RAPTOR chunks for doc %s", doc_id)
return False
raise
async def has_raptor_chunks(doc_id: str, tenant_id: str, kb_id: str, tree_builder: str = RAPTOR_TREE_BUILDER) -> bool:
"""Return whether doc_id already has summaries for tree_builder."""
methods = await get_raptor_chunk_methods(doc_id, tenant_id, kb_id)
return tree_builder in methods
async def delete_raptor_chunks(doc_id: str, tenant_id: str, kb_id: str, keep_method: str | None = None):
"""Delete RAPTOR summaries for doc_id, optionally preserving one method."""
from rag.nlp import search as nlp_search
if keep_method is None:
logging.info(
"delete_raptor_chunks: removing all RAPTOR summaries (doc=%s tenant=%s kb=%s)",
doc_id, tenant_id, kb_id,
)
await thread_pool_exec(
settings.docStoreConn.delete,
{"doc_id": doc_id, "raptor_kwd": ["raptor"]},
nlp_search.index_name(tenant_id),
kb_id,
)
return 0
field_map = await get_raptor_chunk_field_map(doc_id, tenant_id, kb_id)
chunk_ids = collect_raptor_chunk_ids(field_map, exclude_methods={keep_method})
if not chunk_ids:
logging.debug(
"delete_raptor_chunks: no stale RAPTOR chunks to remove (doc=%s tenant=%s kb=%s keep=%s)",
doc_id, tenant_id, kb_id, keep_method,
)
return 0
logging.info(
"delete_raptor_chunks: removing %d stale RAPTOR chunks (doc=%s tenant=%s kb=%s keep=%s)",
len(chunk_ids), doc_id, tenant_id, kb_id, keep_method,
)
await thread_pool_exec(
settings.docStoreConn.delete,
{"id": list(chunk_ids)},
nlp_search.index_name(tenant_id),
kb_id,
)
return len(chunk_ids)
@timeout(3600)
async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_size, callback=None, doc_ids=[]):
"""Generate RAPTOR summaries for selected documents in a knowledge base."""
fake_doc_id = GRAPH_RAPTOR_FAKE_DOC_ID
raptor_config = kb_parser_config.get("raptor", {})
raptor_ext_config = raptor_config.get("ext") or {}
tree_builder = get_raptor_tree_builder(raptor_config)
clustering_method = get_raptor_clustering_method(raptor_config)
vctr_nm = "q_%d_vec" % vector_size
res = []
tk_count = 0
cleanup_raptor_chunks = []
max_errors = int(os.environ.get("RAPTOR_MAX_ERRORS", 3))
doc_name_by_id = {}
doc_info_by_id = {}
for doc_id in set(doc_ids):
ok, source_doc = DocumentService.get_by_id(doc_id)
if not ok or not source_doc:
continue
source_name = getattr(source_doc, "name", "")
if source_name:
doc_name_by_id[doc_id] = source_name
doc_info_by_id[doc_id] = {
"name": getattr(source_doc, "name", ""),
"type": getattr(source_doc, "type", ""),
"parser_id": getattr(source_doc, "parser_id", ""),
"parser_config": getattr(source_doc, "parser_config", {}) or {},
}
def schedule_raptor_cleanup(doc_id: str, keep_method: str | None = None):
"""Queue stale RAPTOR summaries for deletion after successful insert."""
cleanup_plan = (doc_id, keep_method)
if cleanup_plan not in cleanup_raptor_chunks:
cleanup_raptor_chunks.append(cleanup_plan)
def skip_raptor_doc(doc_id: str) -> bool:
"""Return whether RAPTOR should be skipped for this source document."""
doc_info = doc_info_by_id.get(doc_id, {})
file_type = doc_info.get("type") or row.get("type", "")
parser_id = doc_info.get("parser_id") or row.get("parser_id", "")
parser_config = doc_info.get("parser_config") or row.get("parser_config", {})
if should_skip_raptor(file_type, parser_id, parser_config, raptor_config):
skip_reason = get_skip_reason(file_type, parser_id, parser_config)
doc_name = doc_info.get("name") or doc_id
logging.info("Skipping Raptor for document %s: %s", doc_name, skip_reason)
callback(msg=f"[RAPTOR] doc:{doc_id} skipped: {skip_reason}")
return True
return False
async def generate(chunks, did):
"""Run RAPTOR and append generated summary chunks for one doc id."""
nonlocal tk_count, res
logging.info("RAPTOR: using tree_builder=%s clustering_method=%s for doc %s", tree_builder, clustering_method, did)
raptor = Raptor(
raptor_config.get("max_cluster", 64),
chat_mdl,
@@ -880,16 +990,21 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si
raptor_config["max_token"],
raptor_config["threshold"],
max_errors=max_errors,
tree_builder=tree_builder,
clustering_method=clustering_method,
psi_exact_max_leaves=raptor_ext_config.get("psi_exact_max_leaves", 4096),
psi_bucket_size=raptor_ext_config.get("psi_bucket_size", 1024),
)
original_length = len(chunks)
chunks, layers = await raptor(chunks, kb_parser_config["raptor"]["random_seed"], callback, row["id"])
effective_doc_name = row["name"] if did == fake_doc_id else doc_name_by_id.get(did, row["name"])
effective_doc_name = row["name"] if did == fake_doc_id else doc_info_by_id.get(did, {}).get("name") or row["name"]
doc = {
"doc_id": did,
"kb_id": [str(row["kb_id"])],
"docnm_kwd": effective_doc_name,
"title_tks": rag_tokenizer.tokenize(effective_doc_name),
"raptor_kwd": "raptor"
"raptor_kwd": "raptor",
"extra": {"raptor_method": tree_builder},
}
if row["pagerank"]:
doc[PAGERANK_FLD] = int(row["pagerank"])
@@ -906,7 +1021,7 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si
for idx, (content, vctr) in enumerate(chunks[original_length:], start=original_length):
d = copy.deepcopy(doc)
d["id"] = xxhash.xxh64((content + str(fake_doc_id)).encode("utf-8")).hexdigest()
d["id"] = make_raptor_summary_chunk_id(content, did)
d["create_time"] = str(datetime.now()).replace("T", " ")[:19]
d["create_timestamp_flt"] = datetime.now().timestamp()
d[vctr_nm] = vctr.tolist()
@@ -918,12 +1033,28 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si
tk_count += num_tokens_from_string(content)
if raptor_config.get("scope", "file") == "file":
dataset_methods = await get_raptor_chunk_methods(fake_doc_id, row["tenant_id"], row["kb_id"])
remove_dataset_summaries = bool(dataset_methods)
has_file_level_target = False
if dataset_methods:
callback(msg="[RAPTOR] will remove dataset-level summaries after file-level summaries are available.")
for x, doc_id in enumerate(doc_ids):
# CHECKPOINT: skip docs that already have RAPTOR chunks in the doc store
if await has_raptor_chunks(doc_id, row["tenant_id"], row["kb_id"]):
callback(msg=f"[RAPTOR] doc:{doc_id} already has RAPTOR chunks, skipping.")
if skip_raptor_doc(doc_id):
callback(prog=(x + 1.) / len(doc_ids))
continue
# CHECKPOINT: skip docs that already have RAPTOR chunks in the doc store
existing_methods = await get_raptor_chunk_methods(doc_id, row["tenant_id"], row["kb_id"])
if tree_builder in existing_methods:
has_file_level_target = True
if existing_methods != {tree_builder}:
schedule_raptor_cleanup(doc_id, tree_builder)
callback(msg=f"[RAPTOR] doc:{doc_id} will remove old RAPTOR summaries after insert.")
callback(msg=f"[RAPTOR] doc:{doc_id} already has {tree_builder} RAPTOR chunks, skipping.")
callback(prog=(x + 1.) / len(doc_ids))
continue
if existing_methods:
callback(msg=f"[RAPTOR] doc:{doc_id} will migrate RAPTOR summaries to {tree_builder} after insert.")
chunks = []
skipped_chunks = 0
@@ -945,12 +1076,52 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si
callback(msg=f"[WARN] No valid chunks with vectors found for doc {doc_id}, skipping")
continue
before_generate = len(res)
await generate(chunks, doc_id)
if len(res) > before_generate:
has_file_level_target = True
if existing_methods:
schedule_raptor_cleanup(doc_id, tree_builder)
callback(prog=(x + 1.) / len(doc_ids))
if remove_dataset_summaries:
if has_file_level_target:
schedule_raptor_cleanup(fake_doc_id)
else:
callback(msg="[RAPTOR] kept dataset-level summaries because no file-level summaries were built.")
else:
migrated_file_docs = 0
file_cleanup_doc_ids = []
skipped_doc_ids = set()
for doc_id in set(doc_ids):
if skip_raptor_doc(doc_id):
skipped_doc_ids.add(doc_id)
continue
existing_methods = await get_raptor_chunk_methods(doc_id, row["tenant_id"], row["kb_id"])
if existing_methods:
file_cleanup_doc_ids.append(doc_id)
migrated_file_docs += 1
if migrated_file_docs:
callback(msg=f"[RAPTOR] will remove file-level summaries for {migrated_file_docs} docs after dataset-level build succeeds.")
existing_methods = await get_raptor_chunk_methods(fake_doc_id, row["tenant_id"], row["kb_id"])
if tree_builder in existing_methods:
if existing_methods != {tree_builder}:
schedule_raptor_cleanup(fake_doc_id, tree_builder)
callback(msg="[RAPTOR] will remove old dataset-level RAPTOR summaries after insert.")
for doc_id in file_cleanup_doc_ids:
schedule_raptor_cleanup(doc_id)
callback(msg=f"[RAPTOR] dataset-level {tree_builder} summaries already exist, skipping.")
return res, tk_count, cleanup_raptor_chunks
migrate_dataset_summaries = bool(existing_methods)
if migrate_dataset_summaries:
callback(msg=f"[RAPTOR] will migrate dataset-level RAPTOR summaries to {tree_builder} after insert.")
chunks = []
skipped_chunks = 0
for doc_id in doc_ids:
if doc_id in skipped_doc_ids:
continue
for d in settings.retriever.chunk_list(doc_id, row["tenant_id"], [str(row["kb_id"])],
fields=["content_with_weight", vctr_nm],
sort_by_position=True):
@@ -965,13 +1136,22 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si
callback(msg=f"[WARN] Skipped {skipped_chunks} chunks without vector field '{vctr_nm}'. Consider re-parsing documents with the current embedding model.")
if not chunks:
if skipped_doc_ids and len(skipped_doc_ids) == len(set(doc_ids)):
callback(msg="[RAPTOR] all documents were skipped by RAPTOR auto-disable rules.")
return res, tk_count, cleanup_raptor_chunks
logging.error(f"RAPTOR: No valid chunks with vectors found in any document for kb {row['kb_id']}")
callback(msg=f"[ERROR] No valid chunks with vectors found. Please ensure documents are parsed with the current embedding model (vector size: {vector_size}).")
return res, tk_count
return res, tk_count, cleanup_raptor_chunks
before_generate = len(res)
await generate(chunks, fake_doc_id)
if len(res) > before_generate:
for doc_id in file_cleanup_doc_ids:
schedule_raptor_cleanup(doc_id)
if migrate_dataset_summaries:
schedule_raptor_cleanup(fake_doc_id, tree_builder)
return res, tk_count
return res, tk_count, cleanup_raptor_chunks
async def delete_image(kb_id, chunk_id):
@@ -1029,6 +1209,29 @@ async def insert_chunks(task_id, task_tenant_id, task_dataset_id, chunks, progre
search.index_name(task_tenant_id), task_dataset_id, )
task_canceled = has_canceled(task_id)
if task_canceled:
# Roll back partial RAPTOR summary inserts so the next run is not
# mistaken for a completed checkpoint by get_raptor_chunk_methods.
raptor_ids_to_rollback = [
c["id"] for c in chunks[:b + settings.DOC_BULK_SIZE]
if c.get("raptor_kwd") == "raptor"
]
if raptor_ids_to_rollback:
try:
await thread_pool_exec(
settings.docStoreConn.delete,
{"id": raptor_ids_to_rollback},
search.index_name(task_tenant_id),
task_dataset_id,
)
logging.info(
"insert_chunks: rolled back %d partial RAPTOR chunks after cancellation (task=%s)",
len(raptor_ids_to_rollback), task_id,
)
except Exception:
logging.exception(
"insert_chunks: failed to roll back partial RAPTOR chunks after cancellation (task=%s)",
task_id,
)
progress_callback(-1, msg="Task has been canceled.")
return False
if b % 128 == 0:
@@ -1088,6 +1291,7 @@ async def do_handle_task(task):
task_parser_config = task["parser_config"]
task_start_ts = timer()
toc_thread = None
raptor_cleanup_chunks = []
# prepare the progress callback function
progress_callback = partial(set_progress, task_id, task_from_page, task_to_page)
@@ -1135,7 +1339,9 @@ async def do_handle_task(task):
"threshold": 0.1,
"max_cluster": 64,
"random_seed": 0,
"scope": "file"
"scope": "file",
"clustering_method": "gmm",
"tree_builder": "raptor",
},
}
)
@@ -1143,23 +1349,12 @@ async def do_handle_task(task):
progress_callback(prog=-1.0, msg="Internal error: Invalid RAPTOR configuration")
return
# Check if Raptor should be skipped for structured data
file_type = task.get("type", "")
parser_id = task.get("parser_id", "")
raptor_config = kb_parser_config.get("raptor", {})
if should_skip_raptor(file_type, parser_id, task_parser_config, raptor_config):
skip_reason = get_skip_reason(file_type, parser_id, task_parser_config)
logging.info(f"Skipping Raptor for document {task_document_name}: {skip_reason}")
progress_callback(prog=1.0, msg=f"Raptor skipped: {skip_reason}")
return
# bind LLM for raptor
chat_model_config = get_model_config_by_type_and_name(task_tenant_id, LLMType.CHAT, kb_task_llm_id)
chat_model = LLMBundle(task_tenant_id, chat_model_config, lang=task_language)
# run RAPTOR
async with kg_limiter:
chunks, token_count = await run_raptor_for_kb(
chunks, token_count, raptor_cleanup_chunks = await run_raptor_for_kb(
row=task,
kb_parser_config=kb_parser_config,
chat_mdl=chat_model,
@@ -1268,6 +1463,18 @@ async def do_handle_task(task):
progress_callback(-1, msg="Task has been canceled.")
return
if raptor_cleanup_chunks:
cleaned_chunks = 0
for cleanup_doc_id, keep_method in raptor_cleanup_chunks:
cleaned_chunks += await delete_raptor_chunks(
cleanup_doc_id,
task_tenant_id,
task_dataset_id,
keep_method=keep_method,
)
if cleaned_chunks:
progress_callback(msg=f"Cleaned up {cleaned_chunks} stale RAPTOR chunks.")
logging.info(
"Indexing doc({}), page({}-{}), chunks({}), elapsed: {:.2f}".format(
task_document_name, task_from_page, task_to_page, len(chunks), timer() - start_ts

View File

@@ -46,6 +46,8 @@ column_order_id = Column("_order_id", Integer, nullable=True, comment="chunk ord
column_group_id = Column("group_id", String(256), nullable=True, comment="group id for external retrieval")
column_mom_id = Column("mom_id", String(256), nullable=True, comment="parent chunk id")
column_chunk_data = Column("chunk_data", JSON, nullable=True, comment="table parser row data")
column_raptor_kwd = Column("raptor_kwd", String(256), nullable=True, comment="RAPTOR summary marker")
column_raptor_layer_int = Column("raptor_layer_int", Integer, nullable=True, comment="RAPTOR summary layer")
column_definitions: list[Column] = [
Column("id", String(256), primary_key=True, comment="chunk id"),
@@ -86,6 +88,8 @@ column_definitions: list[Column] = [
Column("rank_flt", Double, nullable=True, comment="rank of this entity"),
Column("removed_kwd", String(256), nullable=True, index=True, server_default="'N'",
comment="whether it has been deleted"),
column_raptor_kwd,
column_raptor_layer_int,
column_chunk_data,
Column("metadata", JSON, nullable=True, comment="metadata for this chunk"),
Column("extra", JSON, nullable=True, comment="extra information of non-general chunk"),
@@ -127,7 +131,14 @@ FTS_COLUMNS_TKS: list[str] = [
]
# Extra columns to add after table creation (for migration)
EXTRA_COLUMNS: list[Column] = [column_order_id, column_group_id, column_mom_id, column_chunk_data]
EXTRA_COLUMNS: list[Column] = [
column_order_id,
column_group_id,
column_mom_id,
column_chunk_data,
column_raptor_kwd,
column_raptor_layer_int,
]
class SearchResult(BaseModel):

View File

@@ -18,15 +18,111 @@
Utility functions for Raptor processing decisions.
"""
import json
import logging
from typing import Optional
import xxhash
RAPTOR_TREE_BUILDER = "raptor"
PSI_TREE_BUILDER = "psi"
SUPPORTED_TREE_BUILDERS = {RAPTOR_TREE_BUILDER, PSI_TREE_BUILDER}
GMM_CLUSTERING_METHOD = "gmm"
AHC_CLUSTERING_METHOD = "ahc"
SUPPORTED_CLUSTERING_METHODS = {GMM_CLUSTERING_METHOD, AHC_CLUSTERING_METHOD}
# File extensions for structured data types
EXCEL_EXTENSIONS = {".xls", ".xlsx", ".xlsm", ".xlsb"}
CSV_EXTENSIONS = {".csv", ".tsv"}
STRUCTURED_EXTENSIONS = EXCEL_EXTENSIONS | CSV_EXTENSIONS
def get_raptor_tree_builder(raptor_config: dict | None) -> str:
"""Return the configured RAPTOR tree builder with legacy ext fallback."""
raptor_config = raptor_config or {}
ext = raptor_config.get("ext") or {}
tree_builder = ext.get("tree_builder") or raptor_config.get("tree_builder") or RAPTOR_TREE_BUILDER
if tree_builder not in SUPPORTED_TREE_BUILDERS:
raise ValueError(f"Unsupported RAPTOR tree builder: {tree_builder}")
return tree_builder
def get_raptor_clustering_method(raptor_config: dict | None) -> str:
"""Return the configured RAPTOR clustering method with legacy ext fallback."""
raptor_config = raptor_config or {}
ext = raptor_config.get("ext") or {}
clustering_method = ext.get("clustering_method") or raptor_config.get("clustering_method") or GMM_CLUSTERING_METHOD
if clustering_method not in SUPPORTED_CLUSTERING_METHODS:
raise ValueError(f"Unsupported RAPTOR clustering method: {clustering_method}")
return clustering_method
def _as_extra_dict(extra) -> dict:
"""Normalize a chunk extra payload into a dictionary."""
if isinstance(extra, dict):
return extra
if isinstance(extra, str) and extra:
try:
parsed = json.loads(extra)
except json.JSONDecodeError:
logging.warning(
"Ignoring malformed RAPTOR extra payload while collecting chunk metadata: %s",
extra[:200],
exc_info=True,
)
return {}
return parsed if isinstance(parsed, dict) else {}
return {}
def _has_raptor_marker(marker) -> bool:
"""Return whether a chunk marker identifies a RAPTOR summary chunk."""
if isinstance(marker, list):
return any(str(item) == RAPTOR_TREE_BUILDER for item in marker)
return str(marker) == RAPTOR_TREE_BUILDER
def _raptor_methods_from_fields(fields: dict, extra: dict | None = None) -> set[str]:
"""Read RAPTOR builder methods from stored chunk fields."""
extra = extra if extra is not None else _as_extra_dict(fields.get("extra"))
method = extra.get("raptor_method") or RAPTOR_TREE_BUILDER
if isinstance(method, list):
return {str(item) for item in method if item}
return {str(method)} if method else set()
def collect_raptor_methods(field_map: dict) -> set[str]:
"""Collect tree-builder methods from RAPTOR summary chunk fields."""
methods = set()
for fields in field_map.values():
extra = _as_extra_dict(fields.get("extra"))
marker = fields.get("raptor_kwd") or extra.get("raptor_kwd")
if not _has_raptor_marker(marker):
continue
methods.update(_raptor_methods_from_fields(fields, extra))
return methods
def collect_raptor_chunk_ids(field_map: dict, exclude_methods: set[str] | None = None) -> set[str]:
"""Collect RAPTOR summary chunk IDs, optionally excluding some methods."""
chunk_ids = set()
exclude_methods = exclude_methods or set()
for chunk_id, fields in field_map.items():
extra = _as_extra_dict(fields.get("extra"))
marker = fields.get("raptor_kwd") or extra.get("raptor_kwd")
if _has_raptor_marker(marker):
if _raptor_methods_from_fields(fields, extra).issubset(exclude_methods):
continue
chunk_ids.add(chunk_id)
return chunk_ids
def make_raptor_summary_chunk_id(content: str, doc_id: str) -> str:
"""Build the stable ID used for generated RAPTOR summary chunks."""
return xxhash.xxh64((content + str(doc_id)).encode("utf-8")).hexdigest()
def is_structured_file_type(file_type: Optional[str]) -> bool:
"""
Check if a file type is structured data (Excel, CSV, etc.)