mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-07-03 01:01:56 +08:00
## Problem When building or updating a knowledge graph with a large number of entities and edges, `set_graph()` in `rag/graphrag/utils.py` creates one `asyncio` task per entity and one per edge, each calling `embd_mdl.encode([single_name])` — a single-item HTTP request to the embedding server. For a graph with 17,000+ nodes and edges (real case reported in #16205), this generates **34,000+ individual embedding API round-trips** instead of ~266 batched calls at `_INSERT_BULK_SIZE=128` (note: the in-code default is 64, which roughly doubles the batched-call counts quoted here). The asyncio gather over thousands of tasks makes the embedding server the bottleneck; under load, a single slow or failed call aborts all remaining tasks, causing the pipeline to stall and never complete. Closes #16205. Related: #15921. ## Root Cause ```python # Before (in set_graph, node loop): tasks = [asyncio.create_task(graph_node_to_chunk(n, ...)) for n in nodes] # Each task calls embd_mdl.encode([single_name]) — 1 HTTP call per node ``` `graph_node_to_chunk` checks the embed cache first, but the cache is cold on first build, so every task makes a live API call. ## Fix Pre-warm the embedding cache with batched calls before spawning tasks. Each pre-warm batch calls `embd_mdl.encode(batch)` once (up to `_INSERT_BULK_SIZE` texts per batch), populating the cache. Then every individual task hits the cache and makes zero embedding API calls. - Only encodes names not already in cache (no-op on warm cache / small incremental updates) - Uses existing project idioms: `thread_pool_exec`, `chat_limiter`, `_INSERT_BULK_SIZE`, `get_embed_cache`, `set_embed_cache` - Mirrors the `ENABLE_TIMEOUT_ASSERTION` timeout pattern from `graph_node_to_chunk` - Zero behavior change: per-task encode logic remains as a correct fallback ## Result | Graph size | Before | After | |---|---|---| | 17,576 edges | ~17,576 embedding calls → stall | ~138 batched calls | | 17,509 nodes | ~17,509 embedding calls → stall | ~137 batched calls | | **Total** | **~35,000 calls** | **~275 calls** | --------- Co-authored-by: Oti_B <oti@mac.speedport.ip>
946 lines
37 KiB
Python
946 lines
37 KiB
Python
# Copyright (c) 2024 Microsoft Corporation.
|
|
# Licensed under the MIT License
|
|
|
|
from common.misc_utils import thread_pool_exec
|
|
|
|
"""
|
|
Reference:
|
|
- [graphrag](https://github.com/microsoft/graphrag)
|
|
- [LightRag](https://github.com/HKUDS/LightRAG)
|
|
"""
|
|
|
|
import asyncio
|
|
import dataclasses
|
|
import html
|
|
import json
|
|
import logging
|
|
import os
|
|
import re
|
|
import time
|
|
from collections import defaultdict
|
|
from copy import deepcopy
|
|
from hashlib import md5
|
|
from typing import Any, Callable, Set, Tuple
|
|
|
|
import networkx as nx
|
|
import numpy as np
|
|
import xxhash
|
|
from networkx.readwrite import json_graph
|
|
|
|
from common.misc_utils import get_uuid
|
|
from common.connection_utils import timeout
|
|
from common.asyncio_utils import LoopLocalSemaphore
|
|
from rag.nlp import rag_tokenizer, search
|
|
from rag.utils.redis_conn import REDIS_CONN
|
|
from common import settings
|
|
from common.doc_store.doc_store_base import OrderByExpr
|
|
|
|
# Separator placed between descriptions that graph_merge concatenates for the
# same node or edge.
GRAPH_FIELD_SEP = "<SEP>"

# Callback type: takes an optional exception, an optional string and an
# optional dict, and returns nothing.
ErrorHandlerFn = Callable[[BaseException | None, str | None, dict | None], None]

# Semaphore used in this module to bound concurrent embedding and doc-store
# calls; its size comes from MAX_CONCURRENT_CHATS (default 10).
chat_limiter = LoopLocalSemaphore(int(os.getenv("MAX_CONCURRENT_CHATS", "10")))

# Doc-store insert batching for GraphRAG graph/subgraph/entity/relation and
# community_report chunks. The defaults (64 chunks per batch, at most 4 batches
# in flight) follow the regular ingest pipeline in document_service.py while
# keeping the number of simultaneous ES/Infinity requests bounded. Override
# them with GRAPHRAG_INSERT_BULK_SIZE and GRAPHRAG_INSERT_CONCURRENCY; both are
# clamped to at least 1. set_graph also reuses _INSERT_BULK_SIZE as the batch
# size for embedding-cache pre-warming.
_INSERT_BULK_SIZE = max(int(os.getenv("GRAPHRAG_INSERT_BULK_SIZE", "64")), 1)
_INSERT_CONCURRENCY = max(int(os.getenv("GRAPHRAG_INSERT_CONCURRENCY", "4")), 1)
|
|
|
|
|
|
async def insert_chunks_bounded(chunks, tenant_id, kb_id, *, callback=None, label="Insert chunks"):
    """Insert ``chunks`` into the doc store in batches with bounded concurrency and retries.

    Batch size is controlled by ``GRAPHRAG_INSERT_BULK_SIZE`` (default 64) and
    the number of batches in flight by ``GRAPHRAG_INSERT_CONCURRENCY``
    (default 4). Each batch is attempted up to 3 times with exponential
    backoff (1s, then 2s) between attempts. A non-empty return value from
    ``docStoreConn.insert`` is treated as a failed attempt.

    Args:
        chunks: List of chunk dicts to insert. An empty list is a no-op.
        tenant_id: Tenant whose index (``search.index_name(tenant_id)``) is written.
        kb_id: Knowledge-base id passed through to ``docStoreConn.insert``.
        callback: Optional progress callback, invoked as ``callback(msg=...)``
            roughly every 100 inserted chunks and when all chunks are done.
        label: Prefix for the progress messages.

    Raises:
        The first unrecoverable batch error (after retries). Before it
        propagates, all other batch tasks are cancelled and awaited, so no
        further batch is started. ``asyncio.gather`` does not do this on its
        own. A call already handed to ``thread_pool_exec`` may still finish in
        its worker thread.
    """
    if not chunks:
        return
    # ENABLE_TIMEOUT_ASSERTION switches to a short timeout; otherwise the
    # timeout is effectively unbounded.
    timeout_s = 3 if os.environ.get("ENABLE_TIMEOUT_ASSERTION") else 30000000
    max_retries = 3
    sem = asyncio.Semaphore(_INSERT_CONCURRENCY)
    total = len(chunks)
    progress = {"done": 0, "next_report": 100}
    progress_lock = asyncio.Lock()

    async def _one(offset: int) -> None:
        """Insert the batch starting at *offset*, retrying on timeout/error."""
        batch = chunks[offset : offset + _INSERT_BULK_SIZE]
        async with sem:
            for attempt in range(max_retries):
                try:
                    result = await asyncio.wait_for(
                        thread_pool_exec(
                            settings.docStoreConn.insert,
                            batch,
                            search.index_name(tenant_id),
                            kb_id,
                        ),
                        timeout=timeout_s,
                    )
                    if result:
                        raise Exception(f"Insert chunk error: {result}, please check log file and Elasticsearch/Infinity status!")
                    break
                except asyncio.TimeoutError:
                    if attempt < max_retries - 1:
                        wait = 2 ** attempt
                        logging.warning(f"Insert batch at offset {offset}/{total} attempt {attempt + 1} timed out, retrying in {wait}s")
                        await asyncio.sleep(wait)
                    else:
                        raise
                except asyncio.CancelledError:
                    raise
                except Exception as e:
                    if attempt < max_retries - 1:
                        wait = 2 ** attempt
                        logging.warning(f"Insert batch at offset {offset}/{total} attempt {attempt + 1} failed: {e}, retrying in {wait}s")
                        await asyncio.sleep(wait)
                    else:
                        raise
        if callback:
            # The lock keeps the done/next_report bookkeeping consistent across batches.
            async with progress_lock:
                progress["done"] += len(batch)
                if progress["done"] >= progress["next_report"] or progress["done"] == total:
                    callback(msg=f"{label}: {progress['done']}/{total}")
                    progress["next_report"] = progress["done"] + 100

    tasks = [asyncio.create_task(_one(offset)) for offset in range(0, total, _INSERT_BULK_SIZE)]
    try:
        await asyncio.gather(*tasks)
    except Exception:
        # gather() leaves the sibling tasks running when one fails; cancel and
        # drain them so nothing keeps inserting after the caller sees the error.
        for task in tasks:
            task.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)
        raise
|
|
|
|
|
|
@dataclasses.dataclass
class GraphChange:
    """Incremental change set of a knowledge graph.

    graph_merge fills the ``added_updated_*`` sets. set_graph turns the added
    or updated items into new entity/relation chunks and deletes the chunks of
    the removed ones. Every field defaults to a fresh empty set per instance.
    """

    # Entity (node) names whose entity chunks should be deleted.
    removed_nodes: set[str] = dataclasses.field(default_factory=set)
    # Entity names that were added or whose attributes changed.
    added_updated_nodes: set[str] = dataclasses.field(default_factory=set)
    # (from, to) entity pairs whose relation chunks should be deleted.
    removed_edges: set[tuple[str, str]] = dataclasses.field(default_factory=set)
    # (from, to) entity pairs that were added or changed.
    added_updated_edges: set[tuple[str, str]] = dataclasses.field(default_factory=set)
|
|
|
|
|
|
def perform_variable_replacements(input: str, history: list[dict] | None = None, variables: dict | None = None) -> str:
|
|
"""Perform variable replacements on the input string and in a chat log."""
|
|
if history is None:
|
|
history = []
|
|
if variables is None:
|
|
variables = {}
|
|
result = input
|
|
|
|
def replace_all(input: str) -> str:
|
|
result = input
|
|
for k, v in variables.items():
|
|
result = result.replace(f"{{{k}}}", str(v))
|
|
return result
|
|
|
|
result = replace_all(result)
|
|
for i, entry in enumerate(history):
|
|
if entry.get("role") == "system":
|
|
entry["content"] = replace_all(entry.get("content") or "")
|
|
|
|
return result
|
|
|
|
|
|
def clean_str(input: Any) -> str:
    """Strip, HTML-unescape, and remove double quotes and control characters.

    Whitespace is stripped first, then HTML entities are unescaped. Finally,
    ``"`` and the C0/DEL/C1 control ranges (``\\x00-\\x1f``, ``\\x7f-\\x9f``)
    are removed. Non-string input is returned unchanged.
    """
    if isinstance(input, str):
        unescaped = html.unescape(input.strip())
        # https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python
        return re.sub(r"[\"\x00-\x1f\x7f-\x9f]", "", unescaped)
    return input
|
|
|
|
|
|
def dict_has_keys_with_types(data: dict, expected_fields: list[tuple[str, type]]) -> bool:
    """Check that *data* contains every ``(key, type)`` pair in *expected_fields*.

    Returns False at the first key that is missing or whose value is not an
    instance of the expected type; True otherwise (including for an empty spec).
    """
    return all(
        name in data and isinstance(data[name], expected_type)
        for name, expected_type in expected_fields
    )
|
|
|
|
|
|
def get_llm_cache(llmnm, txt, history, genconf):
    """Look up a cached LLM completion.

    The Redis key is the xxh64 hex digest of
    ``str(llmnm) + str(txt) + str(history) + str(genconf)``, the same key that
    ``set_llm_cache`` writes.

    Returns:
        The raw value from ``REDIS_CONN.get``, or ``None`` when nothing (or an
        empty value) is stored.
    """
    key_material = "".join(str(part) for part in (llmnm, txt, history, genconf))
    cached = REDIS_CONN.get(xxhash.xxh64(key_material.encode("utf-8")).hexdigest())
    return cached if cached else None
|
|
|
|
|
|
def set_llm_cache(llmnm, txt, v, history, genconf):
    """Cache LLM completion *v* (a str) in Redis for 24 hours.

    The key is the xxh64 hex digest of
    ``str(llmnm) + str(txt) + str(history) + str(genconf)``, matching
    ``get_llm_cache``.
    """
    key_material = "".join(str(part) for part in (llmnm, txt, history, genconf))
    key = xxhash.xxh64(key_material.encode("utf-8")).hexdigest()
    REDIS_CONN.set(key, v.encode("utf-8"), 24 * 3600)
|
|
|
|
|
|
def get_embed_cache(llmnm, txt):
    """Fetch a cached embedding for model *llmnm* and text *txt*.

    The key is the xxh64 hex digest of ``str(llmnm)`` followed by ``str(txt)``,
    as written by ``set_embed_cache``.

    Returns:
        The JSON-decoded vector as a numpy array, or ``None`` on a miss (or an
        empty stored value).
    """
    hasher = xxhash.xxh64(str(llmnm).encode("utf-8"))
    hasher.update(str(txt).encode("utf-8"))
    raw = REDIS_CONN.get(hasher.hexdigest())
    if raw:
        return np.array(json.loads(raw))
    return None
|
|
|
|
|
|
def set_embed_cache(llmnm, txt, arr):
    """Store embedding *arr* for model *llmnm* and text *txt* in Redis for 24 hours.

    numpy arrays are converted to nested lists before JSON encoding; other
    values are JSON-encoded as-is. The key matches ``get_embed_cache``.
    """
    hasher = xxhash.xxh64(str(llmnm).encode("utf-8"))
    hasher.update(str(txt).encode("utf-8"))
    payload = arr.tolist() if isinstance(arr, np.ndarray) else arr
    REDIS_CONN.set(hasher.hexdigest(), json.dumps(payload).encode("utf-8"), 24 * 3600)
|
|
|
|
|
|
def _batch_embed_cache_misses(llmnm: str, keys: list) -> "list[bool]":
|
|
"""Return a boolean miss-mask for *keys* using a single MGET round-trip.
|
|
|
|
Avoids per-item REDIS_CONN.get() calls (which would block the event loop
|
|
when called from an async context) by issuing one batched MGET instead.
|
|
"""
|
|
if not keys:
|
|
return []
|
|
hashes = []
|
|
for key in keys:
|
|
h = xxhash.xxh64()
|
|
h.update(str(llmnm).encode("utf-8"))
|
|
h.update(str(key).encode("utf-8"))
|
|
hashes.append(h.hexdigest())
|
|
return [v is None for v in REDIS_CONN.mget(hashes)]
|
|
|
|
|
|
def _write_embed_cache_batch(llmnm: str, keys: list, embeddings) -> None:
|
|
"""Write a batch of embeddings to the Redis embed cache synchronously.
|
|
|
|
Intended for use with thread_pool_exec so that the synchronous Redis SET
|
|
calls do not block the event loop.
|
|
"""
|
|
for key, ebd in zip(keys, embeddings):
|
|
set_embed_cache(llmnm, key, ebd)
|
|
|
|
|
|
def get_tags_from_cache(kb_ids):
    """Fetch cached tag data for *kb_ids* from Redis.

    The key is the xxh64 hex digest of ``str(kb_ids)``.

    Returns:
        The raw stored value, or ``None`` on a miss (or an empty stored value).
    """
    key = xxhash.xxh64(str(kb_ids).encode("utf-8")).hexdigest()
    cached = REDIS_CONN.get(key)
    return cached if cached else None
|
|
|
|
|
|
def set_tags_to_cache(kb_ids, tags):
    """Store *tags* (JSON-encoded) for *kb_ids* in Redis with a 600-second TTL."""
    key = xxhash.xxh64(str(kb_ids).encode("utf-8")).hexdigest()
    payload = json.dumps(tags).encode("utf-8")
    REDIS_CONN.set(key, payload, 600)
|
|
|
|
|
|
def tidy_graph(graph: nx.Graph, callback, check_attribute: bool = True):
    """Ensure all nodes and edges in *graph* have the essential attributes (in place).

    When *check_attribute* is true, nodes and then edges lacking
    ``description`` or ``source_id`` are removed. Each purge is reported via
    ``callback(msg=...)`` when a callback is given and something was removed.
    Independently of *check_attribute*, every edge without a ``keywords``
    attribute gets an empty list.
    """
    required = ("description", "source_id")

    def _has_required(attrs: dict) -> bool:
        return all(name in attrs for name in required)

    if check_attribute:
        bad_nodes = [n for n, attrs in graph.nodes(data=True) if not _has_required(attrs)]
        for n in bad_nodes:
            graph.remove_node(n)
        if bad_nodes and callback:
            callback(msg=f"Purged {len(bad_nodes)} nodes from graph due to missing essential attributes.")

    bad_edges = []
    for src, dst, attrs in graph.edges(data=True):
        if check_attribute and not _has_required(attrs):
            bad_edges.append((src, dst))
        attrs.setdefault("keywords", [])
    for src, dst in bad_edges:
        graph.remove_edge(src, dst)
    if bad_edges and callback:
        callback(msg=f"Purged {len(bad_edges)} edges from graph due to missing essential attributes.")
|
|
|
|
|
|
def get_from_to(node1, node2):
    """Order an undirected node pair canonically as ``(smaller, larger)``.

    Used to key undirected edges consistently regardless of the order in
    which their endpoints are given.
    """
    return (node1, node2) if node1 < node2 else (node2, node1)
|
|
|
|
|
|
def graph_merge(g1: nx.Graph, g2: nx.Graph, change: GraphChange):
    """Merge graph *g2* into *g1* in place and return *g1*.

    * Every node/edge of *g2* is recorded in ``change.added_updated_nodes`` /
      ``change.added_updated_edges`` (edges as canonical ``get_from_to`` pairs).
    * New nodes and edges are copied into *g1* with their attributes.
    * For a node already in *g1*, the description is appended after
      ``GRAPH_FIELD_SEP`` and ``source_id`` (the chunks it came from) is
      extended.
    * For an edge already in *g1*, the same happens; in addition ``weight`` is
      summed (missing weight in *g2* counts as 0) and ``keywords`` is extended.
    * Afterwards every node's ``rank`` is set to its degree in *g1*, and
      ``g1.graph["source_id"]`` (the documents the graph came from) is
      extended with *g2*'s.
    """
    for name, incoming in g2.nodes(data=True):
        change.added_updated_nodes.add(name)
        if g1.has_node(name):
            existing = g1.nodes[name]
            existing["description"] += GRAPH_FIELD_SEP + incoming["description"]
            existing["source_id"] += incoming["source_id"]
        else:
            g1.add_node(name, **incoming)

    for src, dst, incoming in g2.edges(data=True):
        change.added_updated_edges.add(get_from_to(src, dst))
        existing = g1.get_edge_data(src, dst)
        if existing is None:
            g1.add_edge(src, dst, **incoming)
            continue
        existing["weight"] += incoming.get("weight", 0)
        existing["description"] += GRAPH_FIELD_SEP + incoming["description"]
        existing["keywords"] += incoming["keywords"]
        existing["source_id"] += incoming["source_id"]

    for name, degree in g1.degree:
        g1.nodes[str(name)]["rank"] = int(degree)

    g1.graph.setdefault("source_id", [])
    g1.graph["source_id"] += g2.graph.get("source_id", [])
    return g1
|
|
|
|
|
|
def compute_args_hash(*args):
    """Hash ``str(args)`` with MD5 and return the hex digest, for use as a cache key."""
    digest = md5()
    digest.update(str(args).encode())
    return digest.hexdigest()
|
|
|
|
|
|
def handle_single_entity_extraction(record_attributes: list[str], chunk_key: str):
    """Parse one ``"entity"`` record from LLM extraction output into a node-attribute dict.

    Args:
        record_attributes: Split record fields; field 0 must be the literal
            ``'"entity"'``, followed by name, type and description.
        chunk_key: Id of the chunk the record came from (becomes ``source_id``).

    Returns:
        ``{"entity_name", "entity_type", "description", "source_id"}`` with name
        and type upper-cased and all text passed through ``clean_str``, or
        ``None`` if the record is too short, is not an entity, or has an
        empty name.
    """
    is_entity_record = len(record_attributes) >= 4 and record_attributes[0] == '"entity"'
    if not is_entity_record:
        return None
    name = clean_str(record_attributes[1].upper())
    if not name.strip():
        return None
    return {
        "entity_name": name.upper(),
        "entity_type": clean_str(record_attributes[2].upper()).upper(),
        "description": clean_str(record_attributes[3]),
        "source_id": chunk_key,
    }
|
|
|
|
|
|
def handle_single_relationship_extraction(record_attributes: list[str], chunk_key: str):
    """Parse one ``"relationship"`` record from LLM extraction output into an edge-attribute dict.

    Args:
        record_attributes: Split record fields; field 0 must be the literal
            ``'"relationship"'``, followed by source, target, description and
            keywords. The last field is used as the weight when it looks
            numeric (with exactly five fields this is the keywords field).
        chunk_key: Id of the chunk the record came from (becomes ``source_id``).

    Returns:
        A dict with ``src_id``/``tgt_id`` (upper-cased endpoints in sorted
        order), ``weight`` (default 1.0), ``description``, ``keywords``,
        ``source_id`` and ``metadata.created_at`` (current time), or ``None``
        if the record is too short or not a relationship.
    """
    if len(record_attributes) < 5:
        return None
    if record_attributes[0] != '"relationship"':
        return None
    src = clean_str(record_attributes[1].upper())
    tgt = clean_str(record_attributes[2].upper())
    description = clean_str(record_attributes[3])
    keywords = clean_str(record_attributes[4])
    last_field = record_attributes[-1]
    weight = float(last_field) if is_float_regex(last_field) else 1.0
    low, high = sorted((src.upper(), tgt.upper()))
    return {
        "src_id": low,
        "tgt_id": high,
        "weight": weight,
        "description": description,
        "keywords": keywords,
        "source_id": chunk_key,
        "metadata": {"created_at": time.time()},
    }
|
|
|
|
|
|
def pack_user_ass_to_openai_messages(*args: str):
    """Wrap *args* as OpenAI chat messages, alternating ``user`` / ``assistant`` roles.

    The first argument becomes a ``user`` message, the second an
    ``assistant`` message, and so on.
    """
    roles = ("user", "assistant")
    messages = []
    for idx, text in enumerate(args):
        messages.append({"role": roles[idx % 2], "content": text})
    return messages
|
|
|
|
|
|
def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]:
    """Split *content* on any of *markers*, returning stripped non-empty pieces.

    With no markers, ``[content]`` is returned unchanged (not stripped).
    Markers are matched literally (regex-escaped).
    """
    if not markers:
        return [content]
    pattern = "|".join(map(re.escape, markers))
    stripped = [piece.strip() for piece in re.split(pattern, content)]
    return [piece for piece in stripped if piece]
|
|
|
|
|
|
def is_float_regex(value):
    """Report whether string *value* looks like a signed decimal number (e.g. ``-1``, ``.5``, ``+2.25``).

    A trailing newline is tolerated because ``$`` also matches before it.
    """
    match = re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value)
    return match is not None
|
|
|
|
|
|
def chunk_id(chunk):
    """Derive a deterministic xxh64 hex id from ``content_with_weight`` + ``kb_id`` of *chunk*."""
    key_source = chunk["content_with_weight"] + chunk["kb_id"]
    hasher = xxhash.xxh64()
    hasher.update(key_source.encode("utf-8"))
    return hasher.hexdigest()
|
|
|
|
|
|
async def graph_node_to_chunk(kb_id, embd_mdl, ent_name, meta, chunks, nhop_neighbors=None):
    """Build the doc-store chunk for entity *ent_name* and append it to *chunks*.

    The chunk carries the entity's JSON metadata, its tokenised description,
    ``rank_flt`` (from ``meta["pagerank"]``, default 0) and the serialised
    n-hop neighbour paths, plus a ``q_<dim>_vec`` embedding of the entity
    name. The embedding is read from the Redis embed cache when present.
    Otherwise the name is encoded under ``chat_limiter`` and written back to
    the cache.

    Args:
        kb_id: Knowledge-base id stored on the chunk.
        embd_mdl: Embedding model exposing ``llm_name`` and ``encode``.
        ent_name: Entity (node) name.
        meta: Node attribute dict; ``entity_type``, ``description`` and
            ``source_id`` are required, ``pagerank`` is optional.
        chunks: List the finished chunk is appended to.
        nhop_neighbors: Optional n-hop path list stored as ``n_hop_with_weight``.
    """
    short_timeouts = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
    chunk = dict(
        id=get_uuid(),
        important_kwd=[ent_name],
        title_tks=rag_tokenizer.tokenize(ent_name),
        entity_kwd=ent_name,
        knowledge_graph_kwd="entity",
        entity_type_kwd=meta["entity_type"],
        content_with_weight=json.dumps(meta, ensure_ascii=False),
        content_ltks=rag_tokenizer.tokenize(meta["description"]),
        source_id=meta["source_id"],
        # pagerank drives the P(E|Q) = pagerank * sim ranking in KGSearch; the
        # n-hop neighbour paths feed its relation-enrichment step. Both are read
        # back as `rank_flt` / `n_hop_with_weight` in rag/graphrag/search.py.
        rank_flt=float(meta.get("pagerank", 0) or 0),
        n_hop_with_weight=json.dumps(nhop_neighbors or [], ensure_ascii=False),
        kb_id=kb_id,
        available_int=0,
    )
    chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])

    vector = get_embed_cache(embd_mdl.llm_name, ent_name)
    if vector is None:
        limit_s = 3 if short_timeouts else 30000000
        async with chat_limiter:
            encoded, _ = await asyncio.wait_for(
                thread_pool_exec(embd_mdl.encode, [ent_name]),
                timeout=limit_s,
            )
        vector = encoded[0]
        set_embed_cache(embd_mdl.llm_name, ent_name, vector)
    assert vector is not None
    chunk["q_%d_vec" % len(vector)] = vector
    chunks.append(chunk)
|
|
|
|
|
|
@timeout(3, 3)
async def get_relation(tenant_id, kb_id, from_ent_name, to_ent_name, size=1):
    """Retrieve stored relation (edge) metadata among the given entity names.

    Args:
        tenant_id: Tenant whose index is searched.
        kb_id: Knowledge-base id, or a list of ids.
        from_ent_name: Entity name or iterable of names. A caller-supplied
            list is never modified.
        to_ent_name: Entity name or iterable of names.
        size: Maximum number of hits requested.

    Returns:
        With ``size == 1``, the first relation whose JSON parses (as a dict).
        Otherwise a list of all parseable relations. In both cases an empty
        list is returned when nothing parseable is found.
    """
    # Build a fresh list so a caller's list is not extended in place (and so
    # tuples or other iterables are accepted too).
    ents = [from_ent_name] if isinstance(from_ent_name, str) else list(from_ent_name)
    ents.extend([to_ent_name] if isinstance(to_ent_name, str) else to_ent_name)
    ents = list(set(ents))
    conds = {"fields": ["content_with_weight"], "size": size, "from_entity_kwd": ents, "to_entity_kwd": ents, "knowledge_graph_kwd": ["relation"]}
    es_res = await settings.retriever.search(conds, search.index_name(tenant_id), [kb_id] if isinstance(kb_id, str) else kb_id)
    res = []
    for hit_id in es_res.ids:
        try:
            relation = json.loads(es_res.field[hit_id]["content_with_weight"])
        except Exception:
            # Skip hits with missing or malformed content.
            continue
        if size == 1:
            return relation
        res.append(relation)
    return res
|
|
|
|
|
|
async def graph_edge_to_chunk(kb_id, embd_mdl, from_ent_name, to_ent_name, meta, chunks):
    """Convert a graph edge (relation) into an embeddable chunk and append it to *chunks*.

    The embedding is looked up in, and stored to, the Redis embed cache under
    the key ``"<from>-><to>"``. The text actually encoded on a miss is
    ``"<from>-><to>: <description>"``.

    Args:
        kb_id: Knowledge-base id stored on the chunk.
        embd_mdl: Embedding model exposing ``llm_name`` and ``encode``.
        from_ent_name: Source entity name.
        to_ent_name: Target entity name.
        meta: Edge attribute dict; ``description``, ``keywords``,
            ``source_id`` and ``weight`` are required.
        chunks: List the finished chunk is appended to.
    """
    enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
    chunk = {
        "id": get_uuid(),
        "from_entity_kwd": from_ent_name,
        "to_entity_kwd": to_ent_name,
        "knowledge_graph_kwd": "relation",
        "content_with_weight": json.dumps(meta, ensure_ascii=False),
        "content_ltks": rag_tokenizer.tokenize(meta["description"]),
        "important_kwd": meta["keywords"],
        "source_id": meta["source_id"],
        "weight_int": int(meta["weight"]),
        "kb_id": kb_id,
        "available_int": 0,
    }
    chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
    txt = f"{from_ent_name}->{to_ent_name}"
    # NOTE(review): the cache key omits the description. After a merge that
    # changes the description, a cached vector (24h TTL in set_embed_cache) for
    # the old text is reused. The key must stay in sync with set_graph's edge
    # pre-warm, so any change has to be made in both places.
    ebd = get_embed_cache(embd_mdl.llm_name, txt)
    if ebd is None:
        # Same timeout policy as graph_node_to_chunk and set_graph's pre-warm:
        # short under ENABLE_TIMEOUT_ASSERTION, otherwise effectively unbounded.
        timeout_s = 3 if enable_timeout_assertion else 30000000
        async with chat_limiter:
            ebd, _ = await asyncio.wait_for(
                thread_pool_exec(embd_mdl.encode, [f"{txt}: {meta['description']}"]),
                timeout=timeout_s,
            )
        ebd = ebd[0]
        set_embed_cache(embd_mdl.llm_name, txt, ebd)
    assert ebd is not None
    chunk["q_%d_vec" % len(ebd)] = ebd
    chunks.append(chunk)
|
|
|
|
|
|
async def does_graph_contains(tenant_id, kb_id, doc_id):
    """Tell whether *doc_id* is a source document of the stored graph for *kb_id*.

    Fetches the ``source_id`` field of (at most) one non-removed ``graph``
    chunk and checks membership.
    """
    wanted_fields = ["source_id"]
    graph_filter = {
        "knowledge_graph_kwd": ["graph"],
        "removed_kwd": "N",
    }
    res = await thread_pool_exec(
        settings.docStoreConn.search,
        wanted_fields, [], graph_filter, [], OrderByExpr(),
        0, 1, search.index_name(tenant_id), [kb_id],
    )
    by_chunk = settings.docStoreConn.get_fields(res, wanted_fields)
    source_docs = set()
    # Only the last hit's source_id list is kept; the search is limited to one hit.
    for hit_key in by_chunk.keys():
        source_docs = set(by_chunk[hit_key]["source_id"])
    return doc_id in source_docs
|
|
|
|
|
|
async def get_graph_doc_ids(tenant_id, kb_id) -> list[str]:
    """Return the ``source_id`` document list of the stored, non-removed graph for *kb_id*.

    Returns an empty list when no graph chunk is found.
    """
    conds = {"fields": ["source_id"], "removed_kwd": "N", "size": 1, "knowledge_graph_kwd": ["graph"]}
    res = await settings.retriever.search(conds, search.index_name(tenant_id), [kb_id])
    if res.total == 0:
        return []
    doc_ids = []
    # size=1, so at most one hit; the last one wins.
    for hit_id in res.ids:
        doc_ids = res.field[hit_id]["source_id"]
    return doc_ids
|
|
|
|
|
|
async def get_graph(tenant_id, kb_id, exclude_rebuild=None):
    """Load the knowledge graph for *kb_id* from the document store.

    For the first usable ``graph`` chunk:
      * if it is not marked removed (``removed_kwd == "N"``), it is decoded from
        node-link JSON. When the stored graph lacks ``source_id``, the chunk's
        ``source_id`` field is used;
      * otherwise the graph is rebuilt via ``rebuild_graph(tenant_id, kb_id,
        exclude_rebuild)``.

    A hit that fails to decode or rebuild is logged and the next hit is tried.

    Returns:
        The graph, or ``None`` when there is no graph chunk or none is usable.
    """
    conds = {"fields": ["content_with_weight", "removed_kwd", "source_id"], "size": 1, "knowledge_graph_kwd": ["graph"]}
    res = await settings.retriever.search(conds, search.index_name(tenant_id), [kb_id])
    if res.total == 0:
        return None
    for hit_id in res.ids:
        try:
            fields = res.field[hit_id]
            if fields["removed_kwd"] == "N":
                g = json_graph.node_link_graph(json.loads(fields["content_with_weight"]), edges="edges")
                if "source_id" not in g.graph:
                    g.graph["source_id"] = fields["source_id"]
            else:
                g = await rebuild_graph(tenant_id, kb_id, exclude_rebuild)
            return g
        except Exception:
            # Previously swallowed silently, which hid decode/rebuild failures
            # behind a plain "no graph" result.
            logging.warning("get_graph: failed to load graph chunk %s for kb %s; trying next hit", hit_id, kb_id, exc_info=True)
            continue
    return None
|
|
|
|
|
|
async def _prewarm_embed_cache(embd_mdl, items, key_of, text_of, *, kind, label, callback):
    """Best-effort bulk fill of the Redis embed cache ahead of per-item embedding.

    Without this, set_graph would issue one ``embd_mdl.encode([text])`` call per
    node or edge (17k+ round-trips for a 17k-node graph). Here the cache is
    probed with one MGET, and only the misses are encoded, ``_INSERT_BULK_SIZE``
    texts per ``encode`` call. N calls therefore become
    ceil(N / _INSERT_BULK_SIZE).

    Args:
        embd_mdl: Embedding model; ``llm_name`` namespaces the cache and
            ``encode(texts)`` returns ``(vectors, _)``.
        items: Objects to pre-warm (node names, or ``(from, to, attrs)`` edges).
        key_of: ``item -> cache key``. Must match the key the per-item
            converter looks up.
        text_of: ``item -> text to encode``. Must match the text the per-item
            converter would encode. Only called for cache misses.
        kind: ``"node"`` or ``"edge"``; used in log messages.
        label: Noun phrase for the progress message, e.g. ``"entity names"``.
        callback: Optional progress callback invoked as ``callback(msg=...)``.

    Failures (cache probe, encode error, timeout) are logged and end the
    pre-warm early instead of aborting set_graph. The per-item converters
    still embed anything that is not cached, so a failed pre-warm costs speed,
    not correctness. Task cancellation is not intercepted.
    """
    llm_name = embd_mdl.llm_name
    keys = [key_of(item) for item in items]
    try:
        misses = await thread_pool_exec(_batch_embed_cache_misses, llm_name, keys) if keys else []
    except Exception as e:
        logging.warning("set_graph %s pre-warm: embed-cache lookup failed, skipping pre-warm: %s", kind, e)
        return
    pending = [(key, item) for key, item, miss in zip(keys, items, misses) if miss]
    logging.debug("set_graph %s pre-warm: %d %ss, %d cache misses", kind, len(keys), kind, len(pending))
    if not pending:
        return

    wait_s = 3 if os.environ.get("ENABLE_TIMEOUT_ASSERTION") else 30000000
    n_batches = (len(pending) + _INSERT_BULK_SIZE - 1) // _INSERT_BULK_SIZE
    embedded = 0
    for batch_no, offset in enumerate(range(0, len(pending), _INSERT_BULK_SIZE), start=1):
        batch = pending[offset : offset + _INSERT_BULK_SIZE]
        try:
            batch_keys = [key for key, _ in batch]
            batch_texts = [text_of(item) for _, item in batch]
            async with chat_limiter:
                vectors, _ = await asyncio.wait_for(
                    thread_pool_exec(embd_mdl.encode, batch_texts),
                    timeout=wait_s,
                )
            await thread_pool_exec(_write_embed_cache_batch, llm_name, batch_keys, vectors)
        except Exception as e:
            # A batch can fail where single items would not (provider batch or
            # token limits, or a 3s timeout under ENABLE_TIMEOUT_ASSERTION that
            # now covers a whole batch). Stop here and let the per-item path
            # handle the rest.
            logging.warning(
                "set_graph %s pre-warm: batch %d/%d failed, falling back to per-item embedding: %s",
                kind, batch_no, n_batches, e,
            )
            break
        embedded += len(batch)
        logging.debug("set_graph %s pre-warm: wrote batch %d/%d (%d items)", kind, batch_no, n_batches, len(batch))
    if callback:
        callback(msg=f"Batch-embedded {embedded} {label} ({n_batches} batches of {_INSERT_BULK_SIZE}).")


async def _gather_or_cancel(tasks, error_prefix):
    """Await *tasks*. On the first failure, log ``"<error_prefix>: <error>"``, cancel and drain the rest, then re-raise.

    ``asyncio.gather`` alone leaves sibling tasks running when one fails.
    """
    try:
        await asyncio.gather(*tasks, return_exceptions=False)
    except Exception as e:
        logging.error(f"{error_prefix}: {e}")
        for t in tasks:
            t.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)
        raise


async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, change: GraphChange, callback):
    """Persist a knowledge-graph snapshot and its incremental change to the doc store.

    Steps:
      1. Build the ``graph`` chunk and one ``subgraph`` chunk per source document.
      2. Pre-warm the embed cache in bulk (best-effort). Then build entity chunks
         for ``change.added_updated_nodes`` and relation chunks for
         ``change.added_updated_edges``; edges with no data in *graph* are skipped.
      3. Only after every chunk is built: delete the old graph/subgraph chunks,
         the entity chunks of ``change.removed_nodes`` and the relation chunks of
         ``change.removed_edges``, then insert the new chunks.

    Building before deleting means a crash during embedding leaves the old
    graph and per-document subgraph checkpoints intact, so the pipeline can
    resume without re-running earlier phases.

    Args:
        tenant_id: Tenant whose index is written.
        kb_id: Knowledge-base id.
        embd_mdl: Embedding model exposing ``llm_name`` and ``encode``.
        graph: Full merged graph; ``graph.graph["source_id"]`` lists its documents.
        change: Nodes and edges added, updated or removed since the last save.
        callback: Optional progress callback invoked as ``callback(msg=...)``.

    Raises:
        The first failure from chunk building, deletion or insertion. Sibling
        tasks of the failed step are cancelled first.
    """
    start = asyncio.get_running_loop().time()

    # ── 1. graph + per-document subgraph chunks ────────────────────────────────
    chunks = [
        {
            "id": get_uuid(),
            "content_with_weight": json.dumps(nx.node_link_data(graph, edges="edges"), ensure_ascii=False),
            "knowledge_graph_kwd": "graph",
            "kb_id": kb_id,
            "source_id": graph.graph.get("source_id", []),
            "available_int": 0,
            "removed_kwd": "N",
        }
    ]
    for source in graph.graph["source_id"]:
        subgraph = graph.subgraph([n for n in graph.nodes if source in graph.nodes[n]["source_id"]]).copy()
        subgraph.graph["source_id"] = [source]
        for n in subgraph.nodes:
            subgraph.nodes[n]["source_id"] = [source]
        chunks.append(
            {
                "id": get_uuid(),
                "content_with_weight": json.dumps(nx.node_link_data(subgraph, edges="edges"), ensure_ascii=False),
                "knowledge_graph_kwd": "subgraph",
                "kb_id": kb_id,
                "source_id": [source],
                "available_int": 0,
                "removed_kwd": "N",
            }
        )

    # ── 2a. entity chunks ──────────────────────────────────────────────────────
    # Cache key and encoded text are both the entity name (as in graph_node_to_chunk).
    await _prewarm_embed_cache(
        embd_mdl,
        list(change.added_updated_nodes),
        key_of=lambda name: name,
        text_of=lambda name: name,
        kind="node",
        label="entity names",
        callback=callback,
    )
    tasks = []
    for ii, node in enumerate(change.added_updated_nodes):
        node_attrs = graph.nodes[node]
        nhop_neighbors = n_neighbor(graph, node)
        tasks.append(asyncio.create_task(
            graph_node_to_chunk(kb_id, embd_mdl, node, node_attrs, chunks, nhop_neighbors)
        ))
        if ii % 100 == 9 and callback:
            callback(msg=f"Get embedding of nodes: {ii}/{len(change.added_updated_nodes)}")
    await _gather_or_cancel(tasks, "Error in get_embedding_of_nodes")

    # ── 2b. relation chunks ────────────────────────────────────────────────────
    # Cache key = "A->B" and encoded text = "A->B: <description>", matching
    # graph_edge_to_chunk's lookup key and encode text.
    live_edges = []
    for from_node, to_node in change.added_updated_edges:
        edge_attrs = graph.get_edge_data(from_node, to_node)
        if edge_attrs:
            live_edges.append((from_node, to_node, edge_attrs))

    def _edge_key(edge):
        return f"{edge[0]}->{edge[1]}"

    def _edge_text(edge):
        return f"{edge[0]}->{edge[1]}: {edge[2]['description']}"

    await _prewarm_embed_cache(
        embd_mdl,
        live_edges,
        key_of=_edge_key,
        text_of=_edge_text,
        kind="edge",
        label="edge descriptions",
        callback=callback,
    )
    tasks = []
    for ii, (from_node, to_node) in enumerate(change.added_updated_edges):
        edge_attrs = graph.get_edge_data(from_node, to_node)
        if not edge_attrs:
            continue
        tasks.append(asyncio.create_task(
            graph_edge_to_chunk(kb_id, embd_mdl, from_node, to_node, edge_attrs, chunks)
        ))
        if ii % 100 == 9 and callback:
            callback(msg=f"Get embedding of edges: {ii}/{len(change.added_updated_edges)}")
    await _gather_or_cancel(tasks, "Error in get_embedding_of_edges")

    now = asyncio.get_running_loop().time()
    if callback:
        callback(msg=f"set_graph converted graph change to {len(chunks)} chunks in {now - start:.2f}s.")
    start = now

    # ── 3. replace old data ────────────────────────────────────────────────────
    # Deleting only after chunks are built ensures that a crash during embedding
    # generation above does not destroy the old graph/subgraph checkpoints.
    await thread_pool_exec(
        settings.docStoreConn.delete,
        {"knowledge_graph_kwd": ["graph", "subgraph"]},
        search.index_name(tenant_id),
        kb_id
    )

    if change.removed_nodes:
        BATCH_SIZE = 100
        sorted_nodes = sorted(change.removed_nodes)
        for i in range(0, len(sorted_nodes), BATCH_SIZE):
            batch = sorted_nodes[i:i + BATCH_SIZE]
            await thread_pool_exec(
                settings.docStoreConn.delete,
                {"knowledge_graph_kwd": ["entity"], "entity_kwd": batch},
                search.index_name(tenant_id),
                kb_id
            )

    if change.removed_edges:

        async def del_edges(from_node, to_node):
            # Up to 3 attempts with 1s/2s backoff per relation delete.
            max_retries = 3
            for attempt in range(max_retries):
                try:
                    async with chat_limiter:
                        await thread_pool_exec(
                            settings.docStoreConn.delete,
                            {"knowledge_graph_kwd": ["relation"], "from_entity_kwd": from_node, "to_entity_kwd": to_node},
                            search.index_name(tenant_id),
                            kb_id
                        )
                    return
                except Exception as e:
                    if attempt < max_retries - 1:
                        wait = 2 ** attempt
                        logging.warning(f"del_edges({from_node}, {to_node}) attempt {attempt + 1} failed: {e}, retrying in {wait}s")
                        await asyncio.sleep(wait)
                    else:
                        raise

        tasks = [asyncio.create_task(del_edges(from_node, to_node)) for from_node, to_node in change.removed_edges]
        await _gather_or_cancel(tasks, "Error while deleting edges")

    del_now = asyncio.get_running_loop().time()
    if callback:
        callback(msg=f"set_graph removed {len(change.removed_nodes)} nodes and {len(change.removed_edges)} edges from index in {del_now - start:.2f}s.")
    start = del_now

    await insert_chunks_bounded(chunks, tenant_id, kb_id, callback=callback, label="Insert chunks")
    now = asyncio.get_running_loop().time()
    if callback:
        callback(msg=f"set_graph added/updated {len(change.added_updated_nodes)} nodes and {len(change.added_updated_edges)} edges from index in {now - start:.2f}s.")
|
|
|
|
|
|
def is_continuous_subsequence(subseq, seq):
    """Return True if *subseq* appears as a contiguous sub-path within tuple *seq*.

    Every element of *subseq* must occur in *seq*, in order and back-to-back.
    For example ``("b", "c")`` is a contiguous sub-path of ``("a", "b", "c")``,
    but ``("a", "c")`` and ``("c", "b")`` are not.

    :param subseq: candidate sub-path (typically a two-node ``(u, v)`` edge).
    :param seq: the path (tuple or list of nodes) to search in.
    :returns: True when some window of ``len(subseq)`` consecutive elements of
        *seq* equals *subseq*. An empty *subseq* trivially matches.
    """
    target = tuple(subseq)
    width = len(target)
    if width == 0:
        return True
    seq = tuple(seq)
    # Compare whole windows. The previous implementation only checked that the
    # first element was followed by the last one, which is wrong for anything
    # other than two-element sub-paths.
    for start in range(len(seq) - width + 1):
        if seq[start] == target[0] and seq[start:start + width] == target:
            return True
    return False
|
|
|
|
|
|
def merge_tuples(list1, list2):
    """Extend each path tuple in *list1* by one hop using the edges in *list2*.

    For every path in *list1*:

    * if its last node already occurs earlier in the path (a cycle), the path
      is emitted unchanged;
    * otherwise, for each edge in *list2* whose first node is the path's last
      node, the path extended by ``edge[1:]`` is emitted, unless that edge
      (in either direction) already occurs as a contiguous step of the path;
    * a path that was not extended by any edge is emitted unchanged.

    :returns: a new list of paths, in the order they were produced.
    """
    merged = []
    for path in list1:
        tail = path[-1]
        if tail in path[:-1]:
            merged.append(path)
            continue

        extended = False
        candidates = [edge for edge in list2 if edge[0] == tail]
        for edge in candidates:
            reversed_edge = (edge[1], edge[0])
            # Skip edges already walked (forwards or backwards) along this path.
            if is_continuous_subsequence(edge, path) or is_continuous_subsequence(reversed_edge, path):
                continue
            extended = True
            merged.append(path + edge[1:])

        if not extended:
            merged.append(path)
    return merged
|
|
|
|
|
|
def n_neighbor(graph: nx.Graph, node, n_hop: int = 2):
    """Enumerate paths of up to ``n_hop`` edges starting at ``node`` together
    with the edge weight along each step.

    Returns a list of ``{"path": (n0, n1, ...), "weights": [w0, w1, ...]}``
    dicts (``len(weights) == len(path) - 1``). This is the structure consumed
    by :class:`rag.graphrag.search.KGSearch` for n-hop relation enrichment and
    is stored per entity chunk as ``n_hop_with_weight``.

    Paths start from the edges incident to ``node`` and are grown one hop per
    round with :func:`merge_tuples`. A path that cannot be extended (a dead end,
    or one that has closed a cycle) stays at its shorter length. An edge
    without a ``weight`` attribute contributes ``0``. Returns ``[]`` when
    ``node`` has no incident edges.
    """
    paths = list(graph.edges(node))
    if not paths:
        return []

    for _ in range(n_hop - 1):
        # Paths are immutable tuples and merge_tuples does not mutate its
        # inputs, so the previous frontier can be consumed directly without a
        # deepcopy.
        frontier, paths = paths, []
        for pair in frontier:
            paths.extend(merge_tuples([pair], list(graph.edges(pair[-1]))))

    def _weight(u, v):
        # Look up only the edges on each path. Building a weight map of every
        # edge in the graph on each call made this O(E) per entity chunk.
        data = graph.get_edge_data(u, v)
        if data is None:
            data = graph.get_edge_data(v, u) or {}
        return data.get("weight", 0)

    return [
        {"path": path, "weights": [_weight(u, v) for u, v in zip(path, path[1:])]}
        for path in paths
    ]
|
|
|
|
|
|
async def get_entity_type2samples(idxnms, kb_ids: list):
    """Return a mapping of entity type → sample entity names fetched from the document store.

    Searches the indices *idxnms* for up to 10000 chunks with
    ``knowledge_graph_kwd == "ty2ents"`` belonging to *kb_ids*. Each chunk's
    ``content_with_weight`` is parsed as JSON and iterated with ``.items()``.
    Each value is appended (via ``list.extend``) to the list for its key, so
    samples of one type from several chunks are concatenated.

    Chunks with an empty payload are skipped. Chunks whose payload fails to
    parse are logged and skipped instead of aborting the whole lookup.

    :returns: a ``defaultdict(list)`` keyed by entity type.
    """
    es_res = await settings.retriever.search(
        {"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids, "size": 10000, "fields": ["content_with_weight"]},
        idxnms,
        kb_ids,
    )

    res = defaultdict(list)
    for chunk_id in es_res.ids:
        smp = es_res.field[chunk_id].get("content_with_weight")
        if not smp:
            continue
        try:
            smp = json.loads(smp)
        except Exception as e:
            logging.exception(e)
            # Without this skip, the raw string fell through to `.items()` below
            # and raised AttributeError, losing every other chunk's samples.
            continue

        for ty, ents in smp.items():
            res[ty].extend(ents)
    return res
|
|
|
|
|
|
def flat_uniq_list(arr, key):
    """Gather ``item[key]`` from every dict in *arr* and return the distinct values.

    List values are flattened one level; any other value is taken as a single
    element. The order of the returned list is unspecified because
    de-duplication goes through a ``set``.
    """
    flattened = []
    for item in arr:
        value = item[key]
        flattened += value if isinstance(value, list) else [value]
    return list(set(flattened))
|
|
|
|
|
|
def _merge_subgraph_into(graph, next_graph):
    """Merge *next_graph* into *graph* in place, concatenating ``source_id`` lists.

    Node/edge attribute precedence matches ``nx.compose(graph, next_graph)``:
    for items present in both graphs, *next_graph*'s attributes win. The
    exception is node-level and graph-level ``source_id``, which become the
    concatenation (existing ids first). Working in place avoids copying the
    whole accumulated graph for every subgraph, as repeated ``compose`` did.
    """
    # Both concatenations must be computed before update() overwrites the
    # values with next_graph's own.
    merged_source = {
        n: graph.nodes[n]["source_id"] + next_graph.nodes[n]["source_id"]
        for n in graph.nodes & next_graph.nodes
    }
    if "source_id" in graph.graph:
        graph_source = graph.graph["source_id"] + next_graph.graph["source_id"]
    else:
        graph_source = next_graph.graph["source_id"]

    graph.update(next_graph)
    nx.set_node_attributes(graph, merged_source, "source_id")
    graph.graph["source_id"] = graph_source


async def rebuild_graph(tenant_id, kb_id, exclude_rebuild=None):
    """Reconstruct the full knowledge-graph for *kb_id* from its stored subgraph chunks.

    Pages through chunks with ``knowledge_graph_kwd == "subgraph"`` in the
    tenant's index, 256 per page for at most 1024 pages. Each chunk's
    ``content_with_weight`` is decoded as node-link JSON (edges under the
    ``"edges"`` key) and merged into a single ``nx.Graph``. ``source_id``
    lists of shared nodes and of the graph itself are concatenated.

    :param tenant_id: tenant whose index is searched.
    :param kb_id: knowledge base to rebuild.
    :param exclude_rebuild: a list of ids, or a single id. A subgraph chunk is
        skipped when its ``source_id`` contains any of them (presumably
        document ids; verify against callers).
    :returns: the merged graph with its graph-level ``source_id`` sorted, or
        ``None`` if the merged graph has no nodes.
    """
    graph = nx.Graph()
    flds = ["knowledge_graph_kwd", "content_with_weight", "source_id"]
    bs = 256
    max_pages = 1024
    for offset in range(0, max_pages * bs, bs):
        es_res = await thread_pool_exec(
            settings.docStoreConn.search,
            flds, [], {"kb_id": kb_id, "knowledge_graph_kwd": ["subgraph"]},
            [], OrderByExpr(), offset, bs, search.index_name(tenant_id), [kb_id]
        )
        es_res = settings.docStoreConn.get_fields(es_res, flds)
        if len(es_res) == 0:
            break

        for d in es_res.values():
            assert d["knowledge_graph_kwd"] == "subgraph"
            source_ids = d["source_id"]
            if isinstance(exclude_rebuild, list):
                if any(n in source_ids for n in exclude_rebuild):
                    continue
            elif exclude_rebuild in source_ids:
                continue

            next_graph = json_graph.node_link_graph(json.loads(d["content_with_weight"]), edges="edges")
            _merge_subgraph_into(graph, next_graph)
    else:
        # Every page came back non-empty: there may be more subgraph chunks
        # than the scan cap allows, so the rebuilt graph could be incomplete.
        logging.warning(
            f"rebuild_graph: reached the {max_pages * bs} subgraph-chunk scan limit "
            f"for kb {kb_id}; the rebuilt graph may be incomplete."
        )

    if len(graph.nodes) == 0:
        return None
    graph.graph["source_id"] = sorted(graph.graph["source_id"])
    return graph
|