mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
feat(raptor): add Psi tree builder with original-space ranking and safe migration (#14679)
### What problem does this PR solve? Closes #14674. This PR improves RAPTOR configuration and tree construction while preserving the existing RAPTOR behavior as the default. RAPTOR currently builds summary layers with the original UMAP + GMM clustering path. This PR keeps that default path, and adds: - A hidden backend tree-builder option: - `tree_builder="raptor"`: default, existing RAPTOR behavior. - `tree_builder="psi"`: rank-aware Psi-style tree builder using original embedding-space cosine ranking. - A user-facing clustering method option for the default RAPTOR builder: - `clustering_method="gmm"`: existing default. - `clustering_method="ahc"`: agglomerative hierarchical clustering path. - A RAPTOR UI setting for `Clustering method` and `Max cluster`. ### What changed #### Backend - Added `tree_builder` support for RAPTOR/Psi. - Added `clustering_method` support for GMM/AHC. - Kept existing RAPTOR + GMM as the default. - Added Psi tree building from original-space cosine similarity. - Added bucketed Psi building controls for large inputs: - `raptor.ext.psi_exact_max_leaves` - `raptor.ext.psi_bucket_size` - Added method-aware RAPTOR summary metadata using existing `extra.raptor_method`. - Avoided adding a dedicated DB schema field for experimental method tracking. - Added cleanup/migration logic to avoid mixing stale RAPTOR summary trees. - Added defensive checks for Psi tree construction and summary failures. #### Frontend/UI - Added `Clustering method` in RAPTOR settings with `GMM` and `AHC`. - Added/kept `Max cluster` in RAPTOR settings. - Enlarged max cluster UI limit to `1024`, matching backend validation. - Kept AHC editable even when a RAPTOR task has already finished. - Fixed the UI save payload so `clustering_method` and `tree_builder` are serialized through `parser_config.raptor.ext`, avoiding backend validation errors for extra top-level RAPTOR fields. Example saved RAPTOR config: ```json { "raptor": { "max_cluster": 317, "ext": { "clustering_method": "ahc", "tree_builder": "raptor" } } } Co-authored-by: CaptainTimon <CaptainTimon@users.noreply.github.com>
This commit is contained in:
637
rag/raptor.py
637
rag/raptor.py
@@ -14,11 +14,13 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
import asyncio
|
||||
from dataclasses import dataclass, field
|
||||
import logging
|
||||
import re
|
||||
|
||||
import numpy as np
|
||||
import umap
|
||||
from sklearn.cluster import AgglomerativeClustering
|
||||
from sklearn.mixture import GaussianMixture
|
||||
|
||||
from api.db.services.task_service import has_canceled
|
||||
@@ -33,9 +35,127 @@ from rag.graphrag.utils import (
|
||||
set_llm_cache,
|
||||
)
|
||||
from common.misc_utils import thread_pool_exec
|
||||
from rag.utils.raptor_utils import (
|
||||
AHC_CLUSTERING_METHOD,
|
||||
GMM_CLUSTERING_METHOD,
|
||||
PSI_TREE_BUILDER,
|
||||
RAPTOR_TREE_BUILDER,
|
||||
SUPPORTED_CLUSTERING_METHODS,
|
||||
SUPPORTED_TREE_BUILDERS,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _PsiTreeNode:
|
||||
"""Node used to represent the in-memory Psi merge tree."""
|
||||
|
||||
index: int
|
||||
text: str = ""
|
||||
embedding: np.ndarray | None = None
|
||||
children: list["_PsiTreeNode"] = field(default_factory=list)
|
||||
parent: "_PsiTreeNode | None" = None
|
||||
|
||||
|
||||
class _PsiUnionFind:
|
||||
"""Build parent links for the Psi merge tree from ranked leaf pairs."""
|
||||
|
||||
def __init__(self, n: int):
|
||||
"""Initialize the union-find state for n leaf nodes."""
|
||||
self._rank = [0 for _ in range(n)]
|
||||
self._parent_chains = [[] for _ in range(n)]
|
||||
self._node_ids = [[i] for i in range(n)]
|
||||
self._tree = [-1 for _ in range(max(1, 2 * n - 1))]
|
||||
self._next_id = n
|
||||
|
||||
@staticmethod
|
||||
def _ordered_extend(target: list[int], values: list[int]):
|
||||
"""Append unseen values while preserving their original order."""
|
||||
for value in values:
|
||||
if value not in target:
|
||||
target.append(value)
|
||||
|
||||
def _find(self, i: int) -> list[int]:
|
||||
"""Return the parent chain for a leaf, extending it lazily."""
|
||||
chain = self._parent_chains[i]
|
||||
if not chain or (len(chain) == 1 and chain[0] == i):
|
||||
return [i]
|
||||
if chain[0] == i:
|
||||
self._ordered_extend(chain, self._find(chain[1]))
|
||||
else:
|
||||
self._ordered_extend(chain, self._find(chain[0]))
|
||||
return chain
|
||||
|
||||
def _rank_bisect_right(self, chain: list[int], rank: int) -> int:
|
||||
"""Return the first chain index whose rank is greater than rank."""
|
||||
idx = 0
|
||||
while idx < len(chain) and self._rank[chain[idx]] <= rank:
|
||||
idx += 1
|
||||
return idx
|
||||
|
||||
def _build(self, i: int, j: int, insert_point: int | None = None):
|
||||
"""Record a merge edge in the compact parent array."""
|
||||
if insert_point is not None:
|
||||
parent_ids = self._node_ids[insert_point]
|
||||
parent_rank_idx = self._rank[i] + 1
|
||||
if parent_rank_idx >= len(parent_ids):
|
||||
logging.warning(
|
||||
"RAPTOR Psi union fallback: rank index %d is out of bounds for node %d with %d parent ids",
|
||||
parent_rank_idx,
|
||||
insert_point,
|
||||
len(parent_ids),
|
||||
)
|
||||
parent_rank_idx = len(parent_ids) - 1
|
||||
self._tree[self._node_ids[i][-1]] = parent_ids[parent_rank_idx]
|
||||
return
|
||||
self._tree[self._node_ids[i][-1]] = self._next_id
|
||||
self._tree[self._node_ids[j][-1]] = self._next_id
|
||||
self._node_ids[i].append(self._next_id)
|
||||
self._next_id += 1
|
||||
|
||||
def union(self, i: int, j: int) -> bool:
|
||||
"""Merge two ranked leaves and return whether a new edge was added."""
|
||||
root_i = self._find(i)[-1]
|
||||
root_j = self._find(j)[-1]
|
||||
if root_i == root_j:
|
||||
return False
|
||||
|
||||
if self._rank[root_i] < self._rank[root_j]:
|
||||
if not self._parent_chains[root_j]:
|
||||
self._parent_chains[root_j].append(root_j)
|
||||
chain = self._parent_chains[j]
|
||||
higher_rank_idx = self._rank_bisect_right(chain, self._rank[root_i])
|
||||
if higher_rank_idx >= len(chain):
|
||||
higher_rank_idx = len(chain) - 1
|
||||
insert_point = chain[higher_rank_idx]
|
||||
self._ordered_extend(self._parent_chains[root_i], chain[higher_rank_idx:])
|
||||
self._build(root_i, root_j, insert_point=insert_point)
|
||||
elif self._rank[root_i] > self._rank[root_j]:
|
||||
if not self._parent_chains[root_i]:
|
||||
self._parent_chains[root_i].append(root_i)
|
||||
chain = self._parent_chains[i]
|
||||
higher_rank_idx = self._rank_bisect_right(chain, self._rank[root_j])
|
||||
if higher_rank_idx >= len(chain):
|
||||
higher_rank_idx = len(chain) - 1
|
||||
insert_point = chain[higher_rank_idx]
|
||||
self._ordered_extend(self._parent_chains[root_j], chain[higher_rank_idx:])
|
||||
self._build(root_j, root_i, insert_point=insert_point)
|
||||
else:
|
||||
if not self._parent_chains[root_i]:
|
||||
self._parent_chains[root_i].append(root_i)
|
||||
self._ordered_extend(self._parent_chains[root_j], self._parent_chains[i][-1:])
|
||||
self._rank[root_i] += 1
|
||||
self._build(root_i, root_j)
|
||||
return True
|
||||
|
||||
@property
|
||||
def tree(self) -> list[int]:
|
||||
"""Return the compact child-to-parent array for constructed nodes."""
|
||||
return self._tree[:self._next_id]
|
||||
|
||||
|
||||
class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
|
||||
"""Build RAPTOR summary layers with the classic or Psi tree strategy."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_cluster,
|
||||
@@ -45,7 +165,12 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
|
||||
max_token=512,
|
||||
threshold=0.1,
|
||||
max_errors=3,
|
||||
tree_builder=RAPTOR_TREE_BUILDER,
|
||||
clustering_method=GMM_CLUSTERING_METHOD,
|
||||
psi_exact_max_leaves=4096,
|
||||
psi_bucket_size=1024,
|
||||
):
|
||||
"""Configure RAPTOR summarization, clustering, and Psi limits."""
|
||||
self._max_cluster = max_cluster
|
||||
self._llm_model = llm_model
|
||||
self._embd_model = embd_model
|
||||
@@ -54,8 +179,17 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
|
||||
self._max_token = max_token
|
||||
self._max_errors = max(1, max_errors)
|
||||
self._error_count = 0
|
||||
|
||||
self._tree_builder = tree_builder or RAPTOR_TREE_BUILDER
|
||||
if self._tree_builder not in SUPPORTED_TREE_BUILDERS:
|
||||
raise ValueError(f"Unsupported RAPTOR tree builder: {self._tree_builder}")
|
||||
self._clustering_method = clustering_method or GMM_CLUSTERING_METHOD
|
||||
if self._clustering_method not in SUPPORTED_CLUSTERING_METHODS:
|
||||
raise ValueError(f"Unsupported RAPTOR clustering method: {self._clustering_method}")
|
||||
self._psi_exact_max_leaves = max(2, int(psi_exact_max_leaves or 4096))
|
||||
self._psi_bucket_size = min(max(2, int(psi_bucket_size or 1024)), self._psi_exact_max_leaves)
|
||||
|
||||
def _check_task_canceled(self, task_id: str, message: str = ""):
|
||||
"""Raise if the current document task was canceled."""
|
||||
if task_id and has_canceled(task_id):
|
||||
log_msg = f"Task {task_id} cancelled during RAPTOR {message}."
|
||||
logging.info(log_msg)
|
||||
@@ -63,6 +197,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
|
||||
|
||||
@timeout(60 * 20)
|
||||
async def _chat(self, system, history, gen_conf):
|
||||
"""Call the configured LLM with caching and short retries."""
|
||||
cached = await thread_pool_exec(get_llm_cache, self._llm_model.llm_name, system, history, gen_conf)
|
||||
if cached:
|
||||
return cached
|
||||
@@ -86,6 +221,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
|
||||
|
||||
@timeout(20)
|
||||
async def _embedding_encode(self, txt):
|
||||
"""Encode text with the configured embedding model and cache result."""
|
||||
response = await thread_pool_exec(get_embed_cache, self._embd_model.llm_name, txt)
|
||||
if response is not None:
|
||||
return response
|
||||
@@ -97,6 +233,7 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
|
||||
return embds
|
||||
|
||||
def _get_optimal_clusters(self, embeddings: np.ndarray, random_state: int, task_id: str = ""):
|
||||
"""Choose the GMM cluster count with the lowest BIC score."""
|
||||
max_clusters = min(self._max_cluster, len(embeddings))
|
||||
n_clusters = np.arange(1, max_clusters)
|
||||
bics = []
|
||||
@@ -109,57 +246,422 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
|
||||
optimal_clusters = n_clusters[np.argmin(bics)]
|
||||
return optimal_clusters
|
||||
|
||||
def _get_clusters_ahc(self, embeddings: np.ndarray, task_id: str = "") -> np.ndarray:
|
||||
"""Cluster embeddings with Ward-linkage AHC and a dendrogram gap heuristic."""
|
||||
n = len(embeddings)
|
||||
if n <= 1:
|
||||
return np.zeros(n, dtype=int)
|
||||
if n == 2:
|
||||
return np.arange(n)
|
||||
|
||||
self._check_task_canceled(task_id, "_get_clusters_ahc dendrogram")
|
||||
full_clust = AgglomerativeClustering(
|
||||
n_clusters=None,
|
||||
distance_threshold=0,
|
||||
compute_distances=True,
|
||||
linkage="ward",
|
||||
)
|
||||
full_clust.fit(embeddings)
|
||||
|
||||
distances = full_clust.distances_
|
||||
if len(distances) > 1:
|
||||
gaps = np.diff(distances)
|
||||
max_gap_idx = int(np.argmax(gaps))
|
||||
n_clusters = max(1, min(n - max_gap_idx - 1, self._max_cluster))
|
||||
else:
|
||||
n_clusters = max(1, min(n, self._max_cluster))
|
||||
if n_clusters <= 1:
|
||||
logging.info("RAPTOR AHC: _get_clusters_ahc selected one cluster for %d embeddings", n)
|
||||
return np.zeros(n, dtype=int)
|
||||
|
||||
logging.info("RAPTOR AHC: _get_clusters_ahc selected n_clusters=%d for %d embeddings", n_clusters, n)
|
||||
self._check_task_canceled(task_id, "_get_clusters_ahc fit")
|
||||
clustering = AgglomerativeClustering(n_clusters=n_clusters, linkage="ward")
|
||||
return clustering.fit_predict(embeddings)
|
||||
|
||||
def _adjust_tree_nodes(self, embeddings: np.ndarray, labels: np.ndarray, max_iter: int = 5) -> np.ndarray:
|
||||
"""Refine AHC assignments by reassigning nodes to nearest centroids."""
|
||||
labels = labels.copy()
|
||||
for _ in range(max_iter):
|
||||
unique_labels = np.unique(labels)
|
||||
if len(unique_labels) <= 1:
|
||||
return labels
|
||||
centroids = np.stack([embeddings[labels == lbl].mean(axis=0) for lbl in unique_labels])
|
||||
diffs = embeddings[:, np.newaxis, :] - centroids[np.newaxis, :, :]
|
||||
sq_dists = (diffs**2).sum(axis=2)
|
||||
new_label_indices = np.argmin(sq_dists, axis=1)
|
||||
new_labels = unique_labels[new_label_indices]
|
||||
if np.array_equal(new_labels, labels):
|
||||
break
|
||||
unique_new = np.unique(new_labels)
|
||||
remap = {old: new for new, old in enumerate(unique_new)}
|
||||
labels = np.array([remap[int(lbl)] for lbl in new_labels])
|
||||
return labels
|
||||
|
||||
@timeout(60 * 20)
|
||||
async def _summarize_texts(self, texts: list[str], callback=None, task_id: str = ""):
|
||||
"""Summarize a cluster and return text plus embedding when successful."""
|
||||
self._check_task_canceled(task_id, "summarization")
|
||||
|
||||
len_per_chunk = int((self._llm_model.max_length - self._max_token) / len(texts))
|
||||
cluster_content = "\n".join([truncate(t, max(1, len_per_chunk)) for t in texts])
|
||||
try:
|
||||
async with chat_limiter:
|
||||
self._check_task_canceled(task_id, "before LLM call")
|
||||
|
||||
cnt = await self._chat(
|
||||
"You're a helpful assistant.",
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": self._prompt.format(cluster_content=cluster_content),
|
||||
}
|
||||
],
|
||||
{"max_tokens": max(self._max_token, 512)}, # fix issue: #10235
|
||||
)
|
||||
cnt = re.sub(
|
||||
"(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)",
|
||||
"",
|
||||
cnt,
|
||||
)
|
||||
logging.debug(f"SUM: {cnt}")
|
||||
|
||||
self._check_task_canceled(task_id, "before embedding")
|
||||
|
||||
embds = await self._embedding_encode(cnt)
|
||||
return cnt, embds
|
||||
except TaskCanceledException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
self._error_count += 1
|
||||
warn_msg = f"[RAPTOR] Skip cluster ({len(texts)} chunks) due to error: {exc}"
|
||||
logging.warning(warn_msg)
|
||||
if callback:
|
||||
callback(msg=warn_msg)
|
||||
if self._error_count >= self._max_errors:
|
||||
raise RuntimeError(f"RAPTOR aborted after {self._error_count} errors. Last error: {exc}") from exc
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _root(node: _PsiTreeNode) -> _PsiTreeNode:
|
||||
"""Return the current root for a Psi tree node."""
|
||||
while node.parent is not None:
|
||||
node = node.parent
|
||||
return node
|
||||
|
||||
def _rank_leaf_pairs(self, leaves: list[_PsiTreeNode]) -> np.ndarray:
|
||||
"""Rank all leaf pairs by original embedding-space cosine similarity."""
|
||||
node_embeddings = np.asarray([leaf.embedding for leaf in leaves], dtype=np.float64)
|
||||
node_embeddings = self._normalize_embeddings(node_embeddings)
|
||||
similarities = node_embeddings @ node_embeddings.T
|
||||
lower = np.tril_indices(len(leaves), -1)
|
||||
ordered = np.argsort(similarities[lower], axis=0)[::-1]
|
||||
return np.stack([lower[0][ordered], lower[1][ordered]], axis=-1)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_embeddings(node_embeddings: np.ndarray) -> np.ndarray:
|
||||
"""Normalize embeddings for cosine operations while tolerating zero vectors."""
|
||||
node_embeddings = np.asarray(node_embeddings, dtype=np.float64)
|
||||
norms = np.linalg.norm(node_embeddings, axis=1, keepdims=True)
|
||||
return node_embeddings / np.maximum(norms, 1e-12)
|
||||
|
||||
def _split_psi_buckets(self, nodes: list[_PsiTreeNode]) -> list[list[_PsiTreeNode]]:
|
||||
"""Split large Psi inputs so exact pair ranking is bounded per bucket."""
|
||||
if len(nodes) <= self._psi_bucket_size:
|
||||
return [nodes]
|
||||
|
||||
node_embeddings = self._normalize_embeddings(np.asarray([node.embedding for node in nodes], dtype=np.float64))
|
||||
groups = [np.arange(len(nodes), dtype=int)]
|
||||
buckets = []
|
||||
|
||||
while groups:
|
||||
group = np.asarray(groups.pop(), dtype=int)
|
||||
if len(group) <= self._psi_bucket_size:
|
||||
buckets.append(group.tolist())
|
||||
continue
|
||||
|
||||
fanout = min(max(2, int(np.ceil(len(group) / self._psi_bucket_size))), len(group), 32)
|
||||
group_embeddings = node_embeddings[group]
|
||||
center_idx = np.linspace(0, len(group_embeddings) - 1, num=fanout, dtype=int)
|
||||
centers = group_embeddings[center_idx].copy()
|
||||
|
||||
for _ in range(5):
|
||||
labels = np.argmax(group_embeddings @ centers.T, axis=1)
|
||||
for center_id in range(fanout):
|
||||
mask = labels == center_id
|
||||
if not np.any(mask):
|
||||
continue
|
||||
center = group_embeddings[mask].mean(axis=0)
|
||||
norm = np.linalg.norm(center)
|
||||
centers[center_id] = center / norm if norm > 0 else center
|
||||
|
||||
labels = np.argmax(group_embeddings @ centers.T, axis=1)
|
||||
split_groups = [group[labels == center_id].tolist() for center_id in range(fanout)]
|
||||
split_groups = [bucket for bucket in split_groups if bucket]
|
||||
if len(split_groups) <= 1:
|
||||
split_groups = [
|
||||
group[start:start + self._psi_bucket_size].tolist()
|
||||
for start in range(0, len(group), self._psi_bucket_size)
|
||||
]
|
||||
groups.extend(split_groups)
|
||||
|
||||
buckets = [bucket for bucket in buckets if bucket]
|
||||
buckets.sort(key=lambda bucket: (len(bucket), bucket[0]))
|
||||
return [[nodes[idx] for idx in bucket] for bucket in buckets]
|
||||
|
||||
def _assign_prototype_embeddings(self, node: _PsiTreeNode) -> np.ndarray:
|
||||
"""Assign mean child embeddings to internal Psi nodes for bucket-level ranking."""
|
||||
if not node.children:
|
||||
return np.asarray(node.embedding, dtype=np.float64)
|
||||
embeddings = np.asarray([self._assign_prototype_embeddings(child) for child in node.children], dtype=np.float64)
|
||||
node.embedding = embeddings.mean(axis=0)
|
||||
return node.embedding
|
||||
|
||||
@staticmethod
|
||||
def _iter_nodes(root: _PsiTreeNode):
|
||||
"""Yield nodes in a Psi tree using a stack traversal."""
|
||||
stack = [root]
|
||||
while stack:
|
||||
node = stack.pop()
|
||||
yield node
|
||||
stack.extend(node.children)
|
||||
|
||||
def _create_psi_parent(self, index: int, children: list[_PsiTreeNode]) -> _PsiTreeNode:
|
||||
"""Create a parent node and attach the provided children to it."""
|
||||
parent = _PsiTreeNode(index=index, children=children)
|
||||
for child in children:
|
||||
child.parent = parent
|
||||
return parent
|
||||
|
||||
def _rebalance_psi_tree(self, root: _PsiTreeNode, next_index: int) -> tuple[_PsiTreeNode, int]:
|
||||
"""Group oversized Psi tree nodes so fanout stays within max_cluster."""
|
||||
max_children = max(2, int(self._max_cluster or 2))
|
||||
|
||||
def rebalance(node: _PsiTreeNode):
|
||||
"""Recursively group children when a Psi node exceeds fanout."""
|
||||
nonlocal next_index
|
||||
|
||||
for child in list(node.children):
|
||||
rebalance(child)
|
||||
|
||||
while len(node.children) > max_children:
|
||||
original_children = len(node.children)
|
||||
grouped_children = []
|
||||
for start in range(0, len(node.children), max_children):
|
||||
batch = node.children[start:start + max_children]
|
||||
if len(batch) == 1:
|
||||
grouped_children.append(batch[0])
|
||||
batch[0].parent = node
|
||||
else:
|
||||
grouped_children.append(self._create_psi_parent(next_index, batch))
|
||||
grouped_children[-1].parent = node
|
||||
next_index += 1
|
||||
node.children = grouped_children
|
||||
logging.info(
|
||||
"RAPTOR Psi rebalance: node=%s children=%d grouped_to=%d max_cluster=%d",
|
||||
node.index,
|
||||
original_children,
|
||||
len(grouped_children),
|
||||
max_children,
|
||||
)
|
||||
|
||||
rebalance(root)
|
||||
return self._root(root), next_index
|
||||
|
||||
def _build_exact_psi_structure(
|
||||
self,
|
||||
nodes: list[_PsiTreeNode],
|
||||
next_index: int,
|
||||
task_id: str = "",
|
||||
) -> tuple[_PsiTreeNode, int, int]:
|
||||
"""Build an exact Psi subtree for a bounded node set."""
|
||||
if len(nodes) == 1:
|
||||
return nodes[0], next_index, 0
|
||||
|
||||
ranked_pairs = self._rank_leaf_pairs(nodes)
|
||||
union_find = _PsiUnionFind(len(nodes))
|
||||
merges = 0
|
||||
for left_idx, right_idx in ranked_pairs:
|
||||
self._check_task_canceled(task_id, "Psi tree construction")
|
||||
if union_find.union(int(left_idx), int(right_idx)):
|
||||
merges += 1
|
||||
if merges == len(nodes) - 1:
|
||||
break
|
||||
|
||||
local_nodes = {idx: node for idx, node in enumerate(nodes)}
|
||||
tree = union_find.tree
|
||||
children_by_parent = {}
|
||||
for child_idx, parent_idx in enumerate(tree):
|
||||
if child_idx not in local_nodes:
|
||||
local_nodes[child_idx] = _PsiTreeNode(index=next_index)
|
||||
next_index += 1
|
||||
if parent_idx == -1:
|
||||
continue
|
||||
children_by_parent.setdefault(parent_idx, []).append(child_idx)
|
||||
if parent_idx not in local_nodes:
|
||||
local_nodes[parent_idx] = _PsiTreeNode(index=next_index)
|
||||
next_index += 1
|
||||
|
||||
for parent_idx, child_indices in children_by_parent.items():
|
||||
parent = local_nodes[parent_idx]
|
||||
parent.children = [local_nodes[child_idx] for child_idx in child_indices]
|
||||
for child in parent.children:
|
||||
child.parent = parent
|
||||
|
||||
roots = [local_nodes[idx] for idx, parent_idx in enumerate(tree) if parent_idx == -1 and idx in local_nodes]
|
||||
root = max(roots, key=lambda node: node.index)
|
||||
return root, next_index, merges
|
||||
|
||||
def _build_bucketed_psi_structure(
|
||||
self,
|
||||
nodes: list[_PsiTreeNode],
|
||||
next_index: int,
|
||||
task_id: str = "",
|
||||
) -> tuple[_PsiTreeNode, int, int]:
|
||||
"""Build large Psi trees by exact-ranking bounded buckets, then bucket roots."""
|
||||
buckets = self._split_psi_buckets(nodes)
|
||||
logging.info(
|
||||
"RAPTOR Psi bucketed build: nodes=%d buckets=%d bucket_size=%d exact_max_leaves=%d",
|
||||
len(nodes),
|
||||
len(buckets),
|
||||
self._psi_bucket_size,
|
||||
self._psi_exact_max_leaves,
|
||||
)
|
||||
|
||||
bucket_roots = []
|
||||
merges = 0
|
||||
for bucket in buckets:
|
||||
bucket_root, next_index, bucket_merges = self._build_psi_structure_from_nodes(bucket, next_index, task_id)
|
||||
self._assign_prototype_embeddings(bucket_root)
|
||||
bucket_roots.append(bucket_root)
|
||||
merges += bucket_merges
|
||||
|
||||
if len(bucket_roots) == 1:
|
||||
return bucket_roots[0], next_index, merges
|
||||
|
||||
root, next_index, root_merges = self._build_psi_structure_from_nodes(bucket_roots, next_index, task_id)
|
||||
return root, next_index, merges + root_merges
|
||||
|
||||
def _build_psi_structure_from_nodes(
|
||||
self,
|
||||
nodes: list[_PsiTreeNode],
|
||||
next_index: int,
|
||||
task_id: str = "",
|
||||
) -> tuple[_PsiTreeNode, int, int]:
|
||||
"""Build Psi structure exactly for small sets and bucket large sets."""
|
||||
if len(nodes) <= self._psi_exact_max_leaves:
|
||||
return self._build_exact_psi_structure(nodes, next_index, task_id)
|
||||
return self._build_bucketed_psi_structure(nodes, next_index, task_id)
|
||||
|
||||
def _build_psi_structure(self, chunks, task_id: str = "") -> tuple[_PsiTreeNode, list[_PsiTreeNode]]:
|
||||
"""Build the Psi merge tree from original chunk embeddings."""
|
||||
leaves = [
|
||||
_PsiTreeNode(index=i, text=text, embedding=np.asarray(embd))
|
||||
for i, (text, embd) in enumerate(chunks)
|
||||
]
|
||||
if len(leaves) == 1:
|
||||
return leaves[0], leaves
|
||||
|
||||
root, next_index, merges = self._build_psi_structure_from_nodes(leaves, len(leaves), task_id)
|
||||
root, _ = self._rebalance_psi_tree(root, next_index)
|
||||
logging.info(
|
||||
"RAPTOR Psi tree built: leaves=%d merges=%d root_fanout=%d",
|
||||
len(leaves),
|
||||
merges,
|
||||
len(root.children),
|
||||
)
|
||||
return root, leaves
|
||||
|
||||
@staticmethod
|
||||
def _psi_layers(root: _PsiTreeNode) -> dict[int, list[_PsiTreeNode]]:
|
||||
"""Collect non-leaf Psi nodes by height for bottom-up summarization."""
|
||||
layers = {}
|
||||
|
||||
def height(node: _PsiTreeNode) -> int:
|
||||
"""Return node height while collecting internal nodes by layer."""
|
||||
if not node.children:
|
||||
return 0
|
||||
node_height = max(height(child) for child in node.children) + 1
|
||||
layers.setdefault(node_height, []).append(node)
|
||||
return node_height
|
||||
|
||||
height(root)
|
||||
return layers
|
||||
|
||||
async def _build_psi_layers(self, chunks, callback=None, task_id: str = ""):
|
||||
"""Materialize Psi tree layers as summary chunks."""
|
||||
layers = [(0, len(chunks))]
|
||||
root, _ = self._build_psi_structure(chunks, task_id=task_id)
|
||||
|
||||
for layer_idx, (_, nodes) in enumerate(sorted(self._psi_layers(root).items()), start=1):
|
||||
layer_start = len(chunks)
|
||||
|
||||
async def summarize_node(node: _PsiTreeNode):
|
||||
"""Summarize one Psi internal node if its children have text."""
|
||||
texts = [child.text for child in node.children if child.text]
|
||||
if not texts:
|
||||
logging.warning("RAPTOR Psi node %s skipped because it has no child text to summarize", node.index)
|
||||
return None
|
||||
result = await self._summarize_texts(texts, callback, task_id)
|
||||
if result is None:
|
||||
logging.warning("RAPTOR Psi node %s skipped because summarization failed", node.index)
|
||||
return None
|
||||
node.text, node.embedding = result
|
||||
return node
|
||||
|
||||
tasks = [asyncio.create_task(summarize_node(node)) for node in nodes]
|
||||
try:
|
||||
summarized_nodes = await asyncio.gather(*tasks, return_exceptions=False)
|
||||
except Exception as e:
|
||||
logging.error(f"Error in RAPTOR Psi tree processing: {e}")
|
||||
for task in tasks:
|
||||
task.cancel()
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
raise
|
||||
|
||||
summarized_nodes = [node for node in summarized_nodes if node is not None]
|
||||
for node in summarized_nodes:
|
||||
chunks.append((node.text, node.embedding))
|
||||
|
||||
if len(chunks) > layer_start:
|
||||
layers.append((layer_start, len(chunks)))
|
||||
logging.info(
|
||||
"RAPTOR Psi layer materialized: layer=%d nodes=%d summaries=%d",
|
||||
layer_idx,
|
||||
len(nodes),
|
||||
len(chunks) - layer_start,
|
||||
)
|
||||
if callback:
|
||||
callback(msg="Build one Psi-RAG layer: {} -> {}".format(len(nodes), len(chunks) - layer_start))
|
||||
else:
|
||||
logging.warning("RAPTOR Psi layer %d produced no summaries; stopping materialization", layer_idx)
|
||||
break
|
||||
|
||||
return chunks, layers
|
||||
|
||||
async def __call__(self, chunks, random_state, callback=None, task_id: str = ""):
|
||||
"""Build summary chunks and layer boundaries for RAPTOR retrieval."""
|
||||
if len(chunks) <= 1:
|
||||
return [], []
|
||||
chunks = [(s, a) for s, a in chunks if s and a is not None and len(a) > 0]
|
||||
if len(chunks) <= 1:
|
||||
return chunks, [(0, len(chunks))]
|
||||
if self._tree_builder == PSI_TREE_BUILDER:
|
||||
logging.info("RAPTOR: using %s tree builder for %d chunks", self._tree_builder, len(chunks))
|
||||
return await self._build_psi_layers(chunks, callback, task_id)
|
||||
|
||||
layers = [(0, len(chunks))]
|
||||
start, end = 0, len(chunks)
|
||||
|
||||
@timeout(60 * 20)
|
||||
async def summarize(ck_idx: list[int]):
|
||||
"""Summarize one classic RAPTOR cluster into the chunk list."""
|
||||
nonlocal chunks
|
||||
|
||||
self._check_task_canceled(task_id, "summarization")
|
||||
|
||||
texts = [chunks[i][0] for i in ck_idx]
|
||||
len_per_chunk = int((self._llm_model.max_length - self._max_token) / len(texts))
|
||||
cluster_content = "\n".join([truncate(t, max(1, len_per_chunk)) for t in texts])
|
||||
try:
|
||||
async with chat_limiter:
|
||||
self._check_task_canceled(task_id, "before LLM call")
|
||||
|
||||
cnt = await self._chat(
|
||||
"You're a helpful assistant.",
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": self._prompt.format(cluster_content=cluster_content),
|
||||
}
|
||||
],
|
||||
{"max_tokens": max(self._max_token, 512)}, # fix issue: #10235
|
||||
)
|
||||
cnt = re.sub(
|
||||
"(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)",
|
||||
"",
|
||||
cnt,
|
||||
)
|
||||
logging.debug(f"SUM: {cnt}")
|
||||
|
||||
self._check_task_canceled(task_id, "before embedding")
|
||||
|
||||
embds = await self._embedding_encode(cnt)
|
||||
chunks.append((cnt, embds))
|
||||
except TaskCanceledException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
self._error_count += 1
|
||||
warn_msg = f"[RAPTOR] Skip cluster ({len(ck_idx)} chunks) due to error: {exc}"
|
||||
logging.warning(warn_msg)
|
||||
if callback:
|
||||
callback(msg=warn_msg)
|
||||
if self._error_count >= self._max_errors:
|
||||
raise RuntimeError(f"RAPTOR aborted after {self._error_count} errors. Last error: {exc}") from exc
|
||||
result = await self._summarize_texts(texts, callback, task_id)
|
||||
if result is not None:
|
||||
chunks.append(result)
|
||||
|
||||
while end - start > 1:
|
||||
self._check_task_canceled(task_id, "layer processing")
|
||||
@@ -167,8 +669,12 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
|
||||
embeddings = [embd for _, embd in chunks[start:end]]
|
||||
if len(embeddings) == 2:
|
||||
await summarize([start, start + 1])
|
||||
produced = len(chunks) - end
|
||||
if produced == 0:
|
||||
logging.warning("RAPTOR layer produced no summaries; stopping materialization")
|
||||
break
|
||||
if callback:
|
||||
callback(msg="Cluster one layer: {} -> {}".format(end - start, len(chunks) - end))
|
||||
callback(msg="Cluster one layer: {} -> {}".format(end - start, produced))
|
||||
layers.append((end, len(chunks)))
|
||||
start = end
|
||||
end = len(chunks)
|
||||
@@ -180,15 +686,37 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
|
||||
n_components=min(12, len(embeddings) - 2),
|
||||
metric="cosine",
|
||||
).fit_transform(embeddings)
|
||||
n_clusters = self._get_optimal_clusters(reduced_embeddings, random_state, task_id=task_id)
|
||||
if self._clustering_method == AHC_CLUSTERING_METHOD:
|
||||
logging.info("RAPTOR: using clustering_method=%s before _get_clusters_ahc", self._clustering_method)
|
||||
raw_labels = self._get_clusters_ahc(reduced_embeddings, task_id=task_id)
|
||||
raw_cluster_count = np.unique(raw_labels).size
|
||||
logging.info("RAPTOR AHC: _get_clusters_ahc produced n_clusters=%d", raw_cluster_count)
|
||||
if raw_cluster_count > 1:
|
||||
adjusted = self._adjust_tree_nodes(reduced_embeddings, raw_labels)
|
||||
adjusted_cluster_count = np.unique(adjusted).size
|
||||
logging.info("RAPTOR AHC: _adjust_tree_nodes adjusted n_clusters=%d", adjusted_cluster_count)
|
||||
else:
|
||||
adjusted = raw_labels
|
||||
logging.warning("RAPTOR AHC: _adjust_tree_nodes skipped because _get_clusters_ahc returned one cluster")
|
||||
unique_labels = np.unique(adjusted)
|
||||
label_map = {old: idx for idx, old in enumerate(unique_labels)}
|
||||
lbls = [label_map[int(lbl)] for lbl in adjusted]
|
||||
n_clusters = len(unique_labels)
|
||||
else:
|
||||
n_clusters = self._get_optimal_clusters(reduced_embeddings, random_state, task_id=task_id)
|
||||
if n_clusters == 1:
|
||||
lbls = [0 for _ in range(len(reduced_embeddings))]
|
||||
else:
|
||||
gm = GaussianMixture(n_components=n_clusters, random_state=random_state)
|
||||
gm.fit(reduced_embeddings)
|
||||
probs = gm.predict_proba(reduced_embeddings)
|
||||
lbls = [np.where(prob > self._threshold)[0] for prob in probs]
|
||||
lbls = [lbl[0] if isinstance(lbl, np.ndarray) else lbl for lbl in lbls]
|
||||
|
||||
if n_clusters == 1:
|
||||
lbls = [0 for _ in range(len(reduced_embeddings))]
|
||||
else:
|
||||
gm = GaussianMixture(n_components=n_clusters, random_state=random_state)
|
||||
gm.fit(reduced_embeddings)
|
||||
probs = gm.predict_proba(reduced_embeddings)
|
||||
lbls = [np.where(prob > self._threshold)[0] for prob in probs]
|
||||
lbls = [lbl[0] if isinstance(lbl, np.ndarray) else lbl for lbl in lbls]
|
||||
lbls = [int(lbl[0]) if isinstance(lbl, np.ndarray) else int(lbl) for lbl in lbls]
|
||||
|
||||
tasks = []
|
||||
for c in range(n_clusters):
|
||||
@@ -205,10 +733,21 @@ class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval:
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
raise
|
||||
|
||||
assert len(chunks) - end == n_clusters, "{} vs. {}".format(len(chunks) - end, n_clusters)
|
||||
produced = len(chunks) - end
|
||||
assert produced <= n_clusters, "{} vs. {}".format(produced, n_clusters)
|
||||
if produced < n_clusters:
|
||||
logging.warning(
|
||||
"RAPTOR layer produced %d/%d cluster summaries; skipped %d cluster(s) due to errors",
|
||||
produced,
|
||||
n_clusters,
|
||||
n_clusters - produced,
|
||||
)
|
||||
if produced == 0:
|
||||
logging.warning("RAPTOR layer produced no summaries; stopping materialization")
|
||||
break
|
||||
layers.append((end, len(chunks)))
|
||||
if callback:
|
||||
callback(msg="Cluster one layer: {} -> {}".format(end - start, len(chunks) - end))
|
||||
callback(msg="Cluster one layer: {} -> {}".format(end - start, produced))
|
||||
start = end
|
||||
end = len(chunks)
|
||||
|
||||
|
||||
@@ -36,7 +36,15 @@ from api.db.joint_services.memory_message_service import handle_save_to_memory_t
|
||||
from common.connection_utils import timeout
|
||||
from common.metadata_utils import turn2jsonschema, update_metadata_to
|
||||
from rag.utils.base64_image import image2id
|
||||
from rag.utils.raptor_utils import should_skip_raptor, get_skip_reason
|
||||
from rag.utils.raptor_utils import (
|
||||
collect_raptor_chunk_ids,
|
||||
collect_raptor_methods,
|
||||
get_raptor_clustering_method,
|
||||
get_raptor_tree_builder,
|
||||
get_skip_reason,
|
||||
make_raptor_summary_chunk_id,
|
||||
should_skip_raptor,
|
||||
)
|
||||
from common.log_utils import init_root_logger
|
||||
from common.config_utils import show_configs
|
||||
from rag.graphrag.general.index import run_graphrag_for_kb
|
||||
@@ -70,7 +78,10 @@ from api.db.db_models import close_connection
|
||||
from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, \
|
||||
email, tag
|
||||
from rag.nlp import search, rag_tokenizer, add_positions
|
||||
from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor
|
||||
from rag.raptor import (
|
||||
RAPTOR_TREE_BUILDER,
|
||||
RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor,
|
||||
)
|
||||
from common.token_utils import num_tokens_from_string, truncate
|
||||
from rag.utils.redis_conn import REDIS_CONN, RedisDistributedLock
|
||||
from rag.graphrag.utils import chat_limiter
|
||||
@@ -817,61 +828,160 @@ async def run_dataflow(task: dict):
|
||||
dsl=str(pipeline))
|
||||
|
||||
|
||||
async def has_raptor_chunks(doc_id: str, tenant_id: str, kb_id: str) -> bool:
|
||||
"""Return True if RAPTOR chunks already exist for doc_id in the doc store.
|
||||
RAPTOR_METHOD_SEARCH_LIMIT = 10000
|
||||
|
||||
Queries directly for raptor_kwd="raptor" rows so a non-RAPTOR leading
|
||||
chunk cannot produce a false-negative result. Uses thread_pool_exec so
|
||||
the blocking doc-store call does not stall the event loop.
|
||||
"""
|
||||
|
||||
async def get_raptor_chunk_field_map(doc_id: str, tenant_id: str, kb_id: str) -> dict:
|
||||
"""Return stored RAPTOR marker fields for a document."""
|
||||
from common.doc_store.doc_store_base import OrderByExpr
|
||||
from rag.nlp import search as nlp_search
|
||||
try:
|
||||
condition = {"doc_id": doc_id, "raptor_kwd": ["raptor"]}
|
||||
|
||||
async def search_fields(fields: list[str], condition: dict, order_by=None):
|
||||
"""Search chunk fields in the current knowledge base."""
|
||||
res = await thread_pool_exec(
|
||||
settings.docStoreConn.search,
|
||||
["raptor_kwd"], [], condition, [], OrderByExpr(),
|
||||
0, 1, nlp_search.index_name(tenant_id), [kb_id]
|
||||
fields, [], condition, [], order_by or OrderByExpr(),
|
||||
0, RAPTOR_METHOD_SEARCH_LIMIT, nlp_search.index_name(tenant_id), [kb_id]
|
||||
)
|
||||
field_map = settings.docStoreConn.get_fields(res, ["raptor_kwd"])
|
||||
found = bool(field_map)
|
||||
if found:
|
||||
return settings.docStoreConn.get_fields(res, fields)
|
||||
|
||||
primary = await search_fields(["raptor_kwd", "extra"], {"doc_id": doc_id, "raptor_kwd": ["raptor"]})
|
||||
if collect_raptor_chunk_ids(primary):
|
||||
return primary
|
||||
|
||||
try:
|
||||
return await search_fields(
|
||||
["raptor_kwd", "extra"],
|
||||
{"doc_id": doc_id},
|
||||
OrderByExpr().desc("create_timestamp_flt"),
|
||||
)
|
||||
except Exception:
|
||||
logging.debug("RAPTOR fallback method lookup with extra field failed for doc %s", doc_id, exc_info=True)
|
||||
return primary
|
||||
|
||||
|
||||
async def get_raptor_chunk_methods(doc_id: str, tenant_id: str, kb_id: str) -> set[str]:
|
||||
"""Return the RAPTOR tree builders already stored for doc_id.
|
||||
|
||||
Queries directly for raptor_kwd="raptor" rows so a non-RAPTOR leading
|
||||
chunk cannot produce a false-negative result. Legacy summary chunks that
|
||||
do not have method metadata are treated as the original RAPTOR builder.
|
||||
"""
|
||||
try:
|
||||
field_map = await get_raptor_chunk_field_map(doc_id, tenant_id, kb_id)
|
||||
methods = collect_raptor_methods(field_map)
|
||||
if methods:
|
||||
logging.info(
|
||||
"Checkpoint hit: RAPTOR chunks for doc %s (tenant=%s kb=%s) already exist",
|
||||
doc_id, tenant_id, kb_id,
|
||||
"Checkpoint hit: RAPTOR chunks for doc %s (tenant=%s kb=%s methods=%s) already exist",
|
||||
doc_id, tenant_id, kb_id, sorted(methods),
|
||||
)
|
||||
else:
|
||||
logging.info(
|
||||
"Checkpoint miss: no RAPTOR chunks for doc %s (tenant=%s kb=%s)",
|
||||
doc_id, tenant_id, kb_id,
|
||||
)
|
||||
return found
|
||||
return methods
|
||||
except Exception:
|
||||
logging.exception("Failed to check RAPTOR chunks for doc %s", doc_id)
|
||||
return False
|
||||
raise
|
||||
|
||||
|
||||
async def has_raptor_chunks(doc_id: str, tenant_id: str, kb_id: str, tree_builder: str = RAPTOR_TREE_BUILDER) -> bool:
|
||||
"""Return whether doc_id already has summaries for tree_builder."""
|
||||
methods = await get_raptor_chunk_methods(doc_id, tenant_id, kb_id)
|
||||
return tree_builder in methods
|
||||
|
||||
|
||||
async def delete_raptor_chunks(doc_id: str, tenant_id: str, kb_id: str, keep_method: str | None = None):
|
||||
"""Delete RAPTOR summaries for doc_id, optionally preserving one method."""
|
||||
from rag.nlp import search as nlp_search
|
||||
|
||||
if keep_method is None:
|
||||
logging.info(
|
||||
"delete_raptor_chunks: removing all RAPTOR summaries (doc=%s tenant=%s kb=%s)",
|
||||
doc_id, tenant_id, kb_id,
|
||||
)
|
||||
await thread_pool_exec(
|
||||
settings.docStoreConn.delete,
|
||||
{"doc_id": doc_id, "raptor_kwd": ["raptor"]},
|
||||
nlp_search.index_name(tenant_id),
|
||||
kb_id,
|
||||
)
|
||||
return 0
|
||||
|
||||
field_map = await get_raptor_chunk_field_map(doc_id, tenant_id, kb_id)
|
||||
chunk_ids = collect_raptor_chunk_ids(field_map, exclude_methods={keep_method})
|
||||
if not chunk_ids:
|
||||
logging.debug(
|
||||
"delete_raptor_chunks: no stale RAPTOR chunks to remove (doc=%s tenant=%s kb=%s keep=%s)",
|
||||
doc_id, tenant_id, kb_id, keep_method,
|
||||
)
|
||||
return 0
|
||||
|
||||
logging.info(
|
||||
"delete_raptor_chunks: removing %d stale RAPTOR chunks (doc=%s tenant=%s kb=%s keep=%s)",
|
||||
len(chunk_ids), doc_id, tenant_id, kb_id, keep_method,
|
||||
)
|
||||
await thread_pool_exec(
|
||||
settings.docStoreConn.delete,
|
||||
{"id": list(chunk_ids)},
|
||||
nlp_search.index_name(tenant_id),
|
||||
kb_id,
|
||||
)
|
||||
return len(chunk_ids)
|
||||
|
||||
|
||||
@timeout(3600)
|
||||
async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_size, callback=None, doc_ids=[]):
|
||||
"""Generate RAPTOR summaries for selected documents in a knowledge base."""
|
||||
fake_doc_id = GRAPH_RAPTOR_FAKE_DOC_ID
|
||||
|
||||
raptor_config = kb_parser_config.get("raptor", {})
|
||||
raptor_ext_config = raptor_config.get("ext") or {}
|
||||
tree_builder = get_raptor_tree_builder(raptor_config)
|
||||
clustering_method = get_raptor_clustering_method(raptor_config)
|
||||
vctr_nm = "q_%d_vec" % vector_size
|
||||
|
||||
res = []
|
||||
tk_count = 0
|
||||
cleanup_raptor_chunks = []
|
||||
max_errors = int(os.environ.get("RAPTOR_MAX_ERRORS", 3))
|
||||
doc_name_by_id = {}
|
||||
doc_info_by_id = {}
|
||||
for doc_id in set(doc_ids):
|
||||
ok, source_doc = DocumentService.get_by_id(doc_id)
|
||||
if not ok or not source_doc:
|
||||
continue
|
||||
source_name = getattr(source_doc, "name", "")
|
||||
if source_name:
|
||||
doc_name_by_id[doc_id] = source_name
|
||||
doc_info_by_id[doc_id] = {
|
||||
"name": getattr(source_doc, "name", ""),
|
||||
"type": getattr(source_doc, "type", ""),
|
||||
"parser_id": getattr(source_doc, "parser_id", ""),
|
||||
"parser_config": getattr(source_doc, "parser_config", {}) or {},
|
||||
}
|
||||
|
||||
def schedule_raptor_cleanup(doc_id: str, keep_method: str | None = None):
|
||||
"""Queue stale RAPTOR summaries for deletion after successful insert."""
|
||||
cleanup_plan = (doc_id, keep_method)
|
||||
if cleanup_plan not in cleanup_raptor_chunks:
|
||||
cleanup_raptor_chunks.append(cleanup_plan)
|
||||
|
||||
def skip_raptor_doc(doc_id: str) -> bool:
|
||||
"""Return whether RAPTOR should be skipped for this source document."""
|
||||
doc_info = doc_info_by_id.get(doc_id, {})
|
||||
file_type = doc_info.get("type") or row.get("type", "")
|
||||
parser_id = doc_info.get("parser_id") or row.get("parser_id", "")
|
||||
parser_config = doc_info.get("parser_config") or row.get("parser_config", {})
|
||||
if should_skip_raptor(file_type, parser_id, parser_config, raptor_config):
|
||||
skip_reason = get_skip_reason(file_type, parser_id, parser_config)
|
||||
doc_name = doc_info.get("name") or doc_id
|
||||
logging.info("Skipping Raptor for document %s: %s", doc_name, skip_reason)
|
||||
callback(msg=f"[RAPTOR] doc:{doc_id} skipped: {skip_reason}")
|
||||
return True
|
||||
return False
|
||||
|
||||
async def generate(chunks, did):
|
||||
"""Run RAPTOR and append generated summary chunks for one doc id."""
|
||||
nonlocal tk_count, res
|
||||
logging.info("RAPTOR: using tree_builder=%s clustering_method=%s for doc %s", tree_builder, clustering_method, did)
|
||||
raptor = Raptor(
|
||||
raptor_config.get("max_cluster", 64),
|
||||
chat_mdl,
|
||||
@@ -880,16 +990,21 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si
|
||||
raptor_config["max_token"],
|
||||
raptor_config["threshold"],
|
||||
max_errors=max_errors,
|
||||
tree_builder=tree_builder,
|
||||
clustering_method=clustering_method,
|
||||
psi_exact_max_leaves=raptor_ext_config.get("psi_exact_max_leaves", 4096),
|
||||
psi_bucket_size=raptor_ext_config.get("psi_bucket_size", 1024),
|
||||
)
|
||||
original_length = len(chunks)
|
||||
chunks, layers = await raptor(chunks, kb_parser_config["raptor"]["random_seed"], callback, row["id"])
|
||||
effective_doc_name = row["name"] if did == fake_doc_id else doc_name_by_id.get(did, row["name"])
|
||||
effective_doc_name = row["name"] if did == fake_doc_id else doc_info_by_id.get(did, {}).get("name") or row["name"]
|
||||
doc = {
|
||||
"doc_id": did,
|
||||
"kb_id": [str(row["kb_id"])],
|
||||
"docnm_kwd": effective_doc_name,
|
||||
"title_tks": rag_tokenizer.tokenize(effective_doc_name),
|
||||
"raptor_kwd": "raptor"
|
||||
"raptor_kwd": "raptor",
|
||||
"extra": {"raptor_method": tree_builder},
|
||||
}
|
||||
if row["pagerank"]:
|
||||
doc[PAGERANK_FLD] = int(row["pagerank"])
|
||||
@@ -906,7 +1021,7 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si
|
||||
|
||||
for idx, (content, vctr) in enumerate(chunks[original_length:], start=original_length):
|
||||
d = copy.deepcopy(doc)
|
||||
d["id"] = xxhash.xxh64((content + str(fake_doc_id)).encode("utf-8")).hexdigest()
|
||||
d["id"] = make_raptor_summary_chunk_id(content, did)
|
||||
d["create_time"] = str(datetime.now()).replace("T", " ")[:19]
|
||||
d["create_timestamp_flt"] = datetime.now().timestamp()
|
||||
d[vctr_nm] = vctr.tolist()
|
||||
@@ -918,12 +1033,28 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si
|
||||
tk_count += num_tokens_from_string(content)
|
||||
|
||||
if raptor_config.get("scope", "file") == "file":
|
||||
dataset_methods = await get_raptor_chunk_methods(fake_doc_id, row["tenant_id"], row["kb_id"])
|
||||
remove_dataset_summaries = bool(dataset_methods)
|
||||
has_file_level_target = False
|
||||
if dataset_methods:
|
||||
callback(msg="[RAPTOR] will remove dataset-level summaries after file-level summaries are available.")
|
||||
|
||||
for x, doc_id in enumerate(doc_ids):
|
||||
# CHECKPOINT: skip docs that already have RAPTOR chunks in the doc store
|
||||
if await has_raptor_chunks(doc_id, row["tenant_id"], row["kb_id"]):
|
||||
callback(msg=f"[RAPTOR] doc:{doc_id} already has RAPTOR chunks, skipping.")
|
||||
if skip_raptor_doc(doc_id):
|
||||
callback(prog=(x + 1.) / len(doc_ids))
|
||||
continue
|
||||
# CHECKPOINT: skip docs that already have RAPTOR chunks in the doc store
|
||||
existing_methods = await get_raptor_chunk_methods(doc_id, row["tenant_id"], row["kb_id"])
|
||||
if tree_builder in existing_methods:
|
||||
has_file_level_target = True
|
||||
if existing_methods != {tree_builder}:
|
||||
schedule_raptor_cleanup(doc_id, tree_builder)
|
||||
callback(msg=f"[RAPTOR] doc:{doc_id} will remove old RAPTOR summaries after insert.")
|
||||
callback(msg=f"[RAPTOR] doc:{doc_id} already has {tree_builder} RAPTOR chunks, skipping.")
|
||||
callback(prog=(x + 1.) / len(doc_ids))
|
||||
continue
|
||||
if existing_methods:
|
||||
callback(msg=f"[RAPTOR] doc:{doc_id} will migrate RAPTOR summaries to {tree_builder} after insert.")
|
||||
|
||||
chunks = []
|
||||
skipped_chunks = 0
|
||||
@@ -945,12 +1076,52 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si
|
||||
callback(msg=f"[WARN] No valid chunks with vectors found for doc {doc_id}, skipping")
|
||||
continue
|
||||
|
||||
before_generate = len(res)
|
||||
await generate(chunks, doc_id)
|
||||
if len(res) > before_generate:
|
||||
has_file_level_target = True
|
||||
if existing_methods:
|
||||
schedule_raptor_cleanup(doc_id, tree_builder)
|
||||
callback(prog=(x + 1.) / len(doc_ids))
|
||||
|
||||
if remove_dataset_summaries:
|
||||
if has_file_level_target:
|
||||
schedule_raptor_cleanup(fake_doc_id)
|
||||
else:
|
||||
callback(msg="[RAPTOR] kept dataset-level summaries because no file-level summaries were built.")
|
||||
else:
|
||||
migrated_file_docs = 0
|
||||
file_cleanup_doc_ids = []
|
||||
skipped_doc_ids = set()
|
||||
for doc_id in set(doc_ids):
|
||||
if skip_raptor_doc(doc_id):
|
||||
skipped_doc_ids.add(doc_id)
|
||||
continue
|
||||
existing_methods = await get_raptor_chunk_methods(doc_id, row["tenant_id"], row["kb_id"])
|
||||
if existing_methods:
|
||||
file_cleanup_doc_ids.append(doc_id)
|
||||
migrated_file_docs += 1
|
||||
if migrated_file_docs:
|
||||
callback(msg=f"[RAPTOR] will remove file-level summaries for {migrated_file_docs} docs after dataset-level build succeeds.")
|
||||
|
||||
existing_methods = await get_raptor_chunk_methods(fake_doc_id, row["tenant_id"], row["kb_id"])
|
||||
if tree_builder in existing_methods:
|
||||
if existing_methods != {tree_builder}:
|
||||
schedule_raptor_cleanup(fake_doc_id, tree_builder)
|
||||
callback(msg="[RAPTOR] will remove old dataset-level RAPTOR summaries after insert.")
|
||||
for doc_id in file_cleanup_doc_ids:
|
||||
schedule_raptor_cleanup(doc_id)
|
||||
callback(msg=f"[RAPTOR] dataset-level {tree_builder} summaries already exist, skipping.")
|
||||
return res, tk_count, cleanup_raptor_chunks
|
||||
migrate_dataset_summaries = bool(existing_methods)
|
||||
if migrate_dataset_summaries:
|
||||
callback(msg=f"[RAPTOR] will migrate dataset-level RAPTOR summaries to {tree_builder} after insert.")
|
||||
|
||||
chunks = []
|
||||
skipped_chunks = 0
|
||||
for doc_id in doc_ids:
|
||||
if doc_id in skipped_doc_ids:
|
||||
continue
|
||||
for d in settings.retriever.chunk_list(doc_id, row["tenant_id"], [str(row["kb_id"])],
|
||||
fields=["content_with_weight", vctr_nm],
|
||||
sort_by_position=True):
|
||||
@@ -965,13 +1136,22 @@ async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_si
|
||||
callback(msg=f"[WARN] Skipped {skipped_chunks} chunks without vector field '{vctr_nm}'. Consider re-parsing documents with the current embedding model.")
|
||||
|
||||
if not chunks:
|
||||
if skipped_doc_ids and len(skipped_doc_ids) == len(set(doc_ids)):
|
||||
callback(msg="[RAPTOR] all documents were skipped by RAPTOR auto-disable rules.")
|
||||
return res, tk_count, cleanup_raptor_chunks
|
||||
logging.error(f"RAPTOR: No valid chunks with vectors found in any document for kb {row['kb_id']}")
|
||||
callback(msg=f"[ERROR] No valid chunks with vectors found. Please ensure documents are parsed with the current embedding model (vector size: {vector_size}).")
|
||||
return res, tk_count
|
||||
return res, tk_count, cleanup_raptor_chunks
|
||||
|
||||
before_generate = len(res)
|
||||
await generate(chunks, fake_doc_id)
|
||||
if len(res) > before_generate:
|
||||
for doc_id in file_cleanup_doc_ids:
|
||||
schedule_raptor_cleanup(doc_id)
|
||||
if migrate_dataset_summaries:
|
||||
schedule_raptor_cleanup(fake_doc_id, tree_builder)
|
||||
|
||||
return res, tk_count
|
||||
return res, tk_count, cleanup_raptor_chunks
|
||||
|
||||
|
||||
async def delete_image(kb_id, chunk_id):
|
||||
@@ -1029,6 +1209,29 @@ async def insert_chunks(task_id, task_tenant_id, task_dataset_id, chunks, progre
|
||||
search.index_name(task_tenant_id), task_dataset_id, )
|
||||
task_canceled = has_canceled(task_id)
|
||||
if task_canceled:
|
||||
# Roll back partial RAPTOR summary inserts so the next run is not
|
||||
# mistaken for a completed checkpoint by get_raptor_chunk_methods.
|
||||
raptor_ids_to_rollback = [
|
||||
c["id"] for c in chunks[:b + settings.DOC_BULK_SIZE]
|
||||
if c.get("raptor_kwd") == "raptor"
|
||||
]
|
||||
if raptor_ids_to_rollback:
|
||||
try:
|
||||
await thread_pool_exec(
|
||||
settings.docStoreConn.delete,
|
||||
{"id": raptor_ids_to_rollback},
|
||||
search.index_name(task_tenant_id),
|
||||
task_dataset_id,
|
||||
)
|
||||
logging.info(
|
||||
"insert_chunks: rolled back %d partial RAPTOR chunks after cancellation (task=%s)",
|
||||
len(raptor_ids_to_rollback), task_id,
|
||||
)
|
||||
except Exception:
|
||||
logging.exception(
|
||||
"insert_chunks: failed to roll back partial RAPTOR chunks after cancellation (task=%s)",
|
||||
task_id,
|
||||
)
|
||||
progress_callback(-1, msg="Task has been canceled.")
|
||||
return False
|
||||
if b % 128 == 0:
|
||||
@@ -1088,6 +1291,7 @@ async def do_handle_task(task):
|
||||
task_parser_config = task["parser_config"]
|
||||
task_start_ts = timer()
|
||||
toc_thread = None
|
||||
raptor_cleanup_chunks = []
|
||||
|
||||
# prepare the progress callback function
|
||||
progress_callback = partial(set_progress, task_id, task_from_page, task_to_page)
|
||||
@@ -1135,7 +1339,9 @@ async def do_handle_task(task):
|
||||
"threshold": 0.1,
|
||||
"max_cluster": 64,
|
||||
"random_seed": 0,
|
||||
"scope": "file"
|
||||
"scope": "file",
|
||||
"clustering_method": "gmm",
|
||||
"tree_builder": "raptor",
|
||||
},
|
||||
}
|
||||
)
|
||||
@@ -1143,23 +1349,12 @@ async def do_handle_task(task):
|
||||
progress_callback(prog=-1.0, msg="Internal error: Invalid RAPTOR configuration")
|
||||
return
|
||||
|
||||
# Check if Raptor should be skipped for structured data
|
||||
file_type = task.get("type", "")
|
||||
parser_id = task.get("parser_id", "")
|
||||
raptor_config = kb_parser_config.get("raptor", {})
|
||||
|
||||
if should_skip_raptor(file_type, parser_id, task_parser_config, raptor_config):
|
||||
skip_reason = get_skip_reason(file_type, parser_id, task_parser_config)
|
||||
logging.info(f"Skipping Raptor for document {task_document_name}: {skip_reason}")
|
||||
progress_callback(prog=1.0, msg=f"Raptor skipped: {skip_reason}")
|
||||
return
|
||||
|
||||
# bind LLM for raptor
|
||||
chat_model_config = get_model_config_by_type_and_name(task_tenant_id, LLMType.CHAT, kb_task_llm_id)
|
||||
chat_model = LLMBundle(task_tenant_id, chat_model_config, lang=task_language)
|
||||
# run RAPTOR
|
||||
async with kg_limiter:
|
||||
chunks, token_count = await run_raptor_for_kb(
|
||||
chunks, token_count, raptor_cleanup_chunks = await run_raptor_for_kb(
|
||||
row=task,
|
||||
kb_parser_config=kb_parser_config,
|
||||
chat_mdl=chat_model,
|
||||
@@ -1268,6 +1463,18 @@ async def do_handle_task(task):
|
||||
progress_callback(-1, msg="Task has been canceled.")
|
||||
return
|
||||
|
||||
if raptor_cleanup_chunks:
|
||||
cleaned_chunks = 0
|
||||
for cleanup_doc_id, keep_method in raptor_cleanup_chunks:
|
||||
cleaned_chunks += await delete_raptor_chunks(
|
||||
cleanup_doc_id,
|
||||
task_tenant_id,
|
||||
task_dataset_id,
|
||||
keep_method=keep_method,
|
||||
)
|
||||
if cleaned_chunks:
|
||||
progress_callback(msg=f"Cleaned up {cleaned_chunks} stale RAPTOR chunks.")
|
||||
|
||||
logging.info(
|
||||
"Indexing doc({}), page({}-{}), chunks({}), elapsed: {:.2f}".format(
|
||||
task_document_name, task_from_page, task_to_page, len(chunks), timer() - start_ts
|
||||
|
||||
@@ -46,6 +46,8 @@ column_order_id = Column("_order_id", Integer, nullable=True, comment="chunk ord
|
||||
column_group_id = Column("group_id", String(256), nullable=True, comment="group id for external retrieval")
|
||||
column_mom_id = Column("mom_id", String(256), nullable=True, comment="parent chunk id")
|
||||
column_chunk_data = Column("chunk_data", JSON, nullable=True, comment="table parser row data")
|
||||
column_raptor_kwd = Column("raptor_kwd", String(256), nullable=True, comment="RAPTOR summary marker")
|
||||
column_raptor_layer_int = Column("raptor_layer_int", Integer, nullable=True, comment="RAPTOR summary layer")
|
||||
|
||||
column_definitions: list[Column] = [
|
||||
Column("id", String(256), primary_key=True, comment="chunk id"),
|
||||
@@ -86,6 +88,8 @@ column_definitions: list[Column] = [
|
||||
Column("rank_flt", Double, nullable=True, comment="rank of this entity"),
|
||||
Column("removed_kwd", String(256), nullable=True, index=True, server_default="'N'",
|
||||
comment="whether it has been deleted"),
|
||||
column_raptor_kwd,
|
||||
column_raptor_layer_int,
|
||||
column_chunk_data,
|
||||
Column("metadata", JSON, nullable=True, comment="metadata for this chunk"),
|
||||
Column("extra", JSON, nullable=True, comment="extra information of non-general chunk"),
|
||||
@@ -127,7 +131,14 @@ FTS_COLUMNS_TKS: list[str] = [
|
||||
]
|
||||
|
||||
# Extra columns to add after table creation (for migration)
|
||||
EXTRA_COLUMNS: list[Column] = [column_order_id, column_group_id, column_mom_id, column_chunk_data]
|
||||
EXTRA_COLUMNS: list[Column] = [
|
||||
column_order_id,
|
||||
column_group_id,
|
||||
column_mom_id,
|
||||
column_chunk_data,
|
||||
column_raptor_kwd,
|
||||
column_raptor_layer_int,
|
||||
]
|
||||
|
||||
|
||||
class SearchResult(BaseModel):
|
||||
|
||||
@@ -18,15 +18,111 @@
|
||||
Utility functions for Raptor processing decisions.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import xxhash
|
||||
|
||||
RAPTOR_TREE_BUILDER = "raptor"
|
||||
PSI_TREE_BUILDER = "psi"
|
||||
SUPPORTED_TREE_BUILDERS = {RAPTOR_TREE_BUILDER, PSI_TREE_BUILDER}
|
||||
GMM_CLUSTERING_METHOD = "gmm"
|
||||
AHC_CLUSTERING_METHOD = "ahc"
|
||||
SUPPORTED_CLUSTERING_METHODS = {GMM_CLUSTERING_METHOD, AHC_CLUSTERING_METHOD}
|
||||
|
||||
# File extensions for structured data types
|
||||
EXCEL_EXTENSIONS = {".xls", ".xlsx", ".xlsm", ".xlsb"}
|
||||
CSV_EXTENSIONS = {".csv", ".tsv"}
|
||||
STRUCTURED_EXTENSIONS = EXCEL_EXTENSIONS | CSV_EXTENSIONS
|
||||
|
||||
|
||||
def get_raptor_tree_builder(raptor_config: dict | None) -> str:
|
||||
"""Return the configured RAPTOR tree builder with legacy ext fallback."""
|
||||
raptor_config = raptor_config or {}
|
||||
ext = raptor_config.get("ext") or {}
|
||||
tree_builder = ext.get("tree_builder") or raptor_config.get("tree_builder") or RAPTOR_TREE_BUILDER
|
||||
if tree_builder not in SUPPORTED_TREE_BUILDERS:
|
||||
raise ValueError(f"Unsupported RAPTOR tree builder: {tree_builder}")
|
||||
return tree_builder
|
||||
|
||||
|
||||
def get_raptor_clustering_method(raptor_config: dict | None) -> str:
|
||||
"""Return the configured RAPTOR clustering method with legacy ext fallback."""
|
||||
raptor_config = raptor_config or {}
|
||||
ext = raptor_config.get("ext") or {}
|
||||
clustering_method = ext.get("clustering_method") or raptor_config.get("clustering_method") or GMM_CLUSTERING_METHOD
|
||||
if clustering_method not in SUPPORTED_CLUSTERING_METHODS:
|
||||
raise ValueError(f"Unsupported RAPTOR clustering method: {clustering_method}")
|
||||
return clustering_method
|
||||
|
||||
|
||||
def _as_extra_dict(extra) -> dict:
|
||||
"""Normalize a chunk extra payload into a dictionary."""
|
||||
if isinstance(extra, dict):
|
||||
return extra
|
||||
if isinstance(extra, str) and extra:
|
||||
try:
|
||||
parsed = json.loads(extra)
|
||||
except json.JSONDecodeError:
|
||||
logging.warning(
|
||||
"Ignoring malformed RAPTOR extra payload while collecting chunk metadata: %s",
|
||||
extra[:200],
|
||||
exc_info=True,
|
||||
)
|
||||
return {}
|
||||
return parsed if isinstance(parsed, dict) else {}
|
||||
return {}
|
||||
|
||||
|
||||
def _has_raptor_marker(marker) -> bool:
|
||||
"""Return whether a chunk marker identifies a RAPTOR summary chunk."""
|
||||
if isinstance(marker, list):
|
||||
return any(str(item) == RAPTOR_TREE_BUILDER for item in marker)
|
||||
return str(marker) == RAPTOR_TREE_BUILDER
|
||||
|
||||
|
||||
def _raptor_methods_from_fields(fields: dict, extra: dict | None = None) -> set[str]:
|
||||
"""Read RAPTOR builder methods from stored chunk fields."""
|
||||
extra = extra if extra is not None else _as_extra_dict(fields.get("extra"))
|
||||
method = extra.get("raptor_method") or RAPTOR_TREE_BUILDER
|
||||
if isinstance(method, list):
|
||||
return {str(item) for item in method if item}
|
||||
return {str(method)} if method else set()
|
||||
|
||||
|
||||
def collect_raptor_methods(field_map: dict) -> set[str]:
|
||||
"""Collect tree-builder methods from RAPTOR summary chunk fields."""
|
||||
methods = set()
|
||||
for fields in field_map.values():
|
||||
extra = _as_extra_dict(fields.get("extra"))
|
||||
marker = fields.get("raptor_kwd") or extra.get("raptor_kwd")
|
||||
if not _has_raptor_marker(marker):
|
||||
continue
|
||||
|
||||
methods.update(_raptor_methods_from_fields(fields, extra))
|
||||
return methods
|
||||
|
||||
|
||||
def collect_raptor_chunk_ids(field_map: dict, exclude_methods: set[str] | None = None) -> set[str]:
|
||||
"""Collect RAPTOR summary chunk IDs, optionally excluding some methods."""
|
||||
chunk_ids = set()
|
||||
exclude_methods = exclude_methods or set()
|
||||
for chunk_id, fields in field_map.items():
|
||||
extra = _as_extra_dict(fields.get("extra"))
|
||||
marker = fields.get("raptor_kwd") or extra.get("raptor_kwd")
|
||||
if _has_raptor_marker(marker):
|
||||
if _raptor_methods_from_fields(fields, extra).issubset(exclude_methods):
|
||||
continue
|
||||
chunk_ids.add(chunk_id)
|
||||
return chunk_ids
|
||||
|
||||
|
||||
def make_raptor_summary_chunk_id(content: str, doc_id: str) -> str:
|
||||
"""Build the stable ID used for generated RAPTOR summary chunks."""
|
||||
return xxhash.xxh64((content + str(doc_id)).encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
def is_structured_file_type(file_type: Optional[str]) -> bool:
|
||||
"""
|
||||
Check if a file type is structured data (Excel, CSV, etc.)
|
||||
|
||||
Reference in New Issue
Block a user