#  Copyright (c) 2024 Microsoft Corporation.
#  Licensed under the MIT License
"""
Reference:
 - [graphrag](https://github.com/microsoft/graphrag)
 - [LightRag](https://github.com/HKUDS/LightRAG)
"""

import asyncio
import dataclasses
import html
import json
import logging
import os
import re
import time
from collections import defaultdict
from copy import deepcopy
from hashlib import md5
from typing import Any, Callable, Set, Tuple

import networkx as nx
import numpy as np
import xxhash
from networkx.readwrite import json_graph

from common.misc_utils import get_uuid, thread_pool_exec
from common.connection_utils import timeout
from common.asyncio_utils import LoopLocalSemaphore
from rag.nlp import rag_tokenizer, search
from rag.utils.redis_conn import REDIS_CONN
from common import settings
from common.doc_store.doc_store_base import OrderByExpr

# NOTE(review): with an empty separator, merged node/edge descriptions are
# concatenated with no delimiter at all. Upstream GraphRAG-style code uses a
# visible marker such as "<SEP>" -- confirm the empty string is intentional.
GRAPH_FIELD_SEP = ""

ErrorHandlerFn = Callable[[BaseException | None, str | None, dict | None], None]

chat_limiter = LoopLocalSemaphore(int(os.environ.get("MAX_CONCURRENT_CHATS", 10)))

# Doc-store insert batching for GraphRAG subgraph/node/edge/community_report
# chunks. Defaults (64 docs per batch, up to 4 batches in flight) mirror the
# regular ingest pipeline in document_service.py while still keeping the total
# number of simultaneous requests to ES/Infinity bounded. Override with
# GRAPHRAG_INSERT_BULK_SIZE and GRAPHRAG_INSERT_CONCURRENCY.
_INSERT_BULK_SIZE = max(1, int(os.environ.get("GRAPHRAG_INSERT_BULK_SIZE", 64)))
_INSERT_CONCURRENCY = max(1, int(os.environ.get("GRAPHRAG_INSERT_CONCURRENCY", 4)))


async def insert_chunks_bounded(chunks, tenant_id, kb_id, *, callback=None, label="Insert chunks"):
    """Insert ``chunks`` into the doc store in batches with bounded concurrency and retries.

    Batch size is controlled by ``GRAPHRAG_INSERT_BULK_SIZE`` (default 64) and
    the number of batches in flight by ``GRAPHRAG_INSERT_CONCURRENCY``
    (default 4). Each batch is attempted up to 3 times with exponential
    backoff (1s, 2s) between attempts.

    Raises the first unrecoverable error. ``asyncio.gather`` does *not* cancel
    sibling awaitables when one of them fails, so the remaining batch tasks
    are cancelled and drained explicitly before the error is re-raised.

    ``callback``, when given, is called as ``callback(msg=...)`` roughly every
    100 inserted chunks and once more when all chunks are done.
    """
    if not chunks:
        return

    enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
    sem = asyncio.Semaphore(_INSERT_CONCURRENCY)
    total = len(chunks)
    progress = {"done": 0, "next_report": 100}
    progress_lock = asyncio.Lock()

    async def _one(offset: int) -> None:
        batch = chunks[offset : offset + _INSERT_BULK_SIZE]
        timeout_s = 3 if enable_timeout_assertion else 30000000
        max_retries = 3
        async with sem:
            for attempt in range(max_retries):
                try:
                    result = await asyncio.wait_for(
                        thread_pool_exec(
                            settings.docStoreConn.insert,
                            batch,
                            search.index_name(tenant_id),
                            kb_id,
                        ),
                        timeout=timeout_s,
                    )
                    # A truthy result is treated as an error report from the doc store.
                    if result:
                        raise Exception(f"Insert chunk error: {result}, please check log file and Elasticsearch/Infinity status!")
                    break
                except asyncio.TimeoutError:
                    if attempt < max_retries - 1:
                        wait = 2 ** attempt
                        logging.warning(f"Insert batch at offset {offset}/{total} attempt {attempt + 1} timed out, retrying in {wait}s")
                        await asyncio.sleep(wait)
                    else:
                        raise
                except asyncio.CancelledError:
                    # Never retry on cancellation; let it propagate.
                    raise
                except Exception as e:
                    if attempt < max_retries - 1:
                        wait = 2 ** attempt
                        logging.warning(f"Insert batch at offset {offset}/{total} attempt {attempt + 1} failed: {e}, retrying in {wait}s")
                        await asyncio.sleep(wait)
                    else:
                        raise
        if callback:
            async with progress_lock:
                progress["done"] += len(batch)
                if progress["done"] >= progress["next_report"] or progress["done"] == total:
                    callback(msg=f"{label}: {progress['done']}/{total}")
                    progress["next_report"] = progress["done"] + 100

    tasks = [asyncio.create_task(_one(o)) for o in range(0, total, _INSERT_BULK_SIZE)]
    try:
        await asyncio.gather(*tasks)
    except Exception:
        # gather() propagates the first failure but leaves the other tasks
        # running; cancel and drain them so nothing keeps inserting behind
        # the caller's back.
        for t in tasks:
            t.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)
        raise


@dataclasses.dataclass
class GraphChange:
    """Sets of node names / edge pairs removed or added-or-updated by a graph update."""

    removed_nodes: Set[str] = dataclasses.field(default_factory=set)
    added_updated_nodes: Set[str] = dataclasses.field(default_factory=set)
    removed_edges: Set[Tuple[str, str]] = dataclasses.field(default_factory=set)
    added_updated_edges: Set[Tuple[str, str]] = dataclasses.field(default_factory=set)


def perform_variable_replacements(input: str, history: list[dict] | None = None, variables: dict | None = None) -> str:
    """Perform variable replacements on the input string and in a chat log.

    Every ``{key}`` placeholder is replaced by ``str(value)`` for each entry
    of *variables*. System-role entries of *history* are rewritten in place.
    """
    if history is None:
        history = []
    if variables is None:
        variables = {}

    def replace_all(text: str) -> str:
        replaced = text
        for k, v in variables.items():
            replaced = replaced.replace(f"{{{k}}}", str(v))
        return replaced

    result = replace_all(input)
    for entry in history:
        if entry.get("role") == "system":
            entry["content"] = replace_all(entry.get("content") or "")

    return result


def clean_str(input: Any) -> str:
    """Clean an input string by removing HTML escapes, control characters, and other unwanted characters."""
    # If we get non-string input, just give it back
    if not isinstance(input, str):
        return input

    result = html.unescape(input.strip())
    # https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python
    return re.sub(r"[\"\x00-\x1f\x7f-\x9f]", "", result)


def dict_has_keys_with_types(data: dict, expected_fields: list[tuple[str, type]]) -> bool:
    """Return True if the given dictionary has the given keys with the given types."""
    for field, field_type in expected_fields:
        if field not in data:
            return False

        value = data[field]
        if not isinstance(value, field_type):
            return False
    return True


def get_llm_cache(llmnm, txt, history, genconf):
    """Return a cached LLM completion for the given model/text/history/config, or None on miss."""
    hasher = xxhash.xxh64()
    hasher.update((str(llmnm) + str(txt) + str(history) + str(genconf)).encode("utf-8"))

    k = hasher.hexdigest()
    cached = REDIS_CONN.get(k)
    if not cached:
        return None
    return cached


def set_llm_cache(llmnm, txt, v, history, genconf):
    """Store an LLM completion *v* in Redis keyed by a hash of model/text/history/config."""
    hasher = xxhash.xxh64()
    hasher.update((str(llmnm) + str(txt) + str(history) + str(genconf)).encode("utf-8"))
    k = hasher.hexdigest()
    # Third argument is passed through to REDIS_CONN.set (24 * 3600).
    REDIS_CONN.set(k, v.encode("utf-8"), 24 * 3600)


def _embed_cache_key(llmnm, txt) -> str:
    """Return the Redis key used by the embed cache for a model name / text pair."""
    hasher = xxhash.xxh64()
    hasher.update(str(llmnm).encode("utf-8"))
    hasher.update(str(txt).encode("utf-8"))
    return hasher.hexdigest()


def get_embed_cache(llmnm, txt):
    """Return a cached embedding vector (numpy array) for *llmnm*/*txt*, or None on miss."""
    cached = REDIS_CONN.get(_embed_cache_key(llmnm, txt))
    if not cached:
        return None
    return np.array(json.loads(cached))


def set_embed_cache(llmnm, txt, arr):
    """Store embedding *arr* in Redis for the given model name and input text."""
    k = _embed_cache_key(llmnm, txt)
    arr = json.dumps(arr.tolist() if isinstance(arr, np.ndarray) else arr)
    REDIS_CONN.set(k, arr.encode("utf-8"), 24 * 3600)


def _batch_embed_cache_misses(llmnm: str, keys: list) -> "list[bool]":
    """Return a boolean miss-mask for *keys* using a single MGET round-trip.

    Avoids per-item REDIS_CONN.get() calls (which would block the event loop
    when called from an async context) by issuing one batched MGET instead.
    """
    if not keys:
        return []
    hashes = [_embed_cache_key(llmnm, key) for key in keys]
    return [v is None for v in REDIS_CONN.mget(hashes)]


def _write_embed_cache_batch(llmnm: str, keys: list, embeddings) -> None:
    """Write a batch of embeddings to the Redis embed cache synchronously.

    Intended for use with thread_pool_exec so that the synchronous Redis SET
    calls do not block the event loop.
    """
    for key, ebd in zip(keys, embeddings):
        set_embed_cache(llmnm, key, ebd)


def get_tags_from_cache(kb_ids):
    """Return cached tag data for the given kb_ids from Redis, or None on miss."""
    hasher = xxhash.xxh64()
    hasher.update(str(kb_ids).encode("utf-8"))

    k = hasher.hexdigest()
    cached = REDIS_CONN.get(k)
    if not cached:
        return None
    return cached


def set_tags_to_cache(kb_ids, tags):
    """Persist tag data for *kb_ids* in Redis."""
    hasher = xxhash.xxh64()
    hasher.update(str(kb_ids).encode("utf-8"))

    k = hasher.hexdigest()
    REDIS_CONN.set(k, json.dumps(tags).encode("utf-8"), 600)


def tidy_graph(graph: nx.Graph, callback, check_attribute: bool = True):
    """
    Ensure all nodes and edges in the graph have some essential attribute.

    When *check_attribute* is true, nodes and edges lacking "description" or
    "source_id" are removed. Every edge is given an empty "keywords" list if
    it has none. *callback*, when truthy, is told how many items were purged.
    """

    def is_valid_item(node_attrs: dict) -> bool:
        return all(attr in node_attrs for attr in ("description", "source_id"))

    if check_attribute:
        # Collect first, remove afterwards: the graph must not be mutated
        # while its node view is being iterated.
        purged_nodes = [node for node, node_attrs in graph.nodes(data=True) if not is_valid_item(node_attrs)]
        for node in purged_nodes:
            graph.remove_node(node)
        if purged_nodes and callback:
            callback(msg=f"Purged {len(purged_nodes)} nodes from graph due to missing essential attributes.")

    purged_edges = []
    for source, target, attr in graph.edges(data=True):
        if check_attribute:
            if not is_valid_item(attr):
                purged_edges.append((source, target))
        if "keywords" not in attr:
            attr["keywords"] = []
    for source, target in purged_edges:
        graph.remove_edge(source, target)
    if purged_edges and callback:
        callback(msg=f"Purged {len(purged_edges)} edges from graph due to missing essential attributes.")


def get_from_to(node1, node2):
    """Return a canonical (lesser, greater) node pair for consistent undirected edge keying."""
    if node1 < node2:
        return (node1, node2)
    else:
        return (node2, node1)


def graph_merge(g1: nx.Graph, g2: nx.Graph, change: GraphChange):
    """Merge graph g2 into g1 in place.

    Every node and edge of g2 is recorded in *change* as added/updated.
    Returns g1.
    """
    for node_name, attr in g2.nodes(data=True):
        change.added_updated_nodes.add(node_name)
        if not g1.has_node(node_name):
            g1.add_node(node_name, **attr)
            continue
        node = g1.nodes[node_name]
        node["description"] += GRAPH_FIELD_SEP + attr["description"]
        # A node's source_id indicates which chunks it came from.
        node["source_id"] += attr["source_id"]

    for source, target, attr in g2.edges(data=True):
        change.added_updated_edges.add(get_from_to(source, target))
        edge = g1.get_edge_data(source, target)
        if edge is None:
            g1.add_edge(source, target, **attr)
            continue
        edge["weight"] += attr.get("weight", 0)
        edge["description"] += GRAPH_FIELD_SEP + attr["description"]
        edge["keywords"] += attr["keywords"]
        # A edge's source_id indicates which chunks it came from.
        edge["source_id"] += attr["source_id"]

    # Refresh every node's "rank" with its degree in the merged graph.
    for node_degree in g1.degree:
        g1.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])
    # A graph's source_id indicates which documents it came from.
    if "source_id" not in g1.graph:
        g1.graph["source_id"] = []
    g1.graph["source_id"] += g2.graph.get("source_id", [])
    return g1


def compute_args_hash(*args):
    """Return a hex MD5 digest of the string representation of *args* (used as a cache key)."""
    return md5(str(args).encode()).hexdigest()


def handle_single_entity_extraction(
    record_attributes: list[str],
    chunk_key: str,
):
    """Parse one entity record from LLM output and return a node-attribute dict, or None."""
    if len(record_attributes) < 4 or record_attributes[0] != '"entity"':
        return None
    # add this record as a node in the G
    entity_name = clean_str(record_attributes[1].upper())
    if not entity_name.strip():
        return None
    entity_type = clean_str(record_attributes[2].upper())
    entity_description = clean_str(record_attributes[3])
    entity_source_id = chunk_key
    return dict(
        entity_name=entity_name.upper(),
        entity_type=entity_type.upper(),
        description=entity_description,
        source_id=entity_source_id,
    )


def handle_single_relationship_extraction(record_attributes: list[str], chunk_key: str):
    """Parse one relationship record from LLM output and return an edge-attribute dict, or None."""
    if len(record_attributes) < 5 or record_attributes[0] != '"relationship"':
        return None
    # add this record as edge
    source = clean_str(record_attributes[1].upper())
    target = clean_str(record_attributes[2].upper())
    edge_description = clean_str(record_attributes[3])

    edge_keywords = clean_str(record_attributes[4])
    edge_source_id = chunk_key
    # The last field is the weight when it looks numeric; otherwise default to 1.0.
    weight = float(record_attributes[-1]) if is_float_regex(record_attributes[-1]) else 1.0
    # Endpoints are stored in sorted order so (A, B) and (B, A) map to the same edge.
    pair = sorted([source.upper(), target.upper()])
    return dict(
        src_id=pair[0],
        tgt_id=pair[1],
        weight=weight,
        description=edge_description,
        keywords=edge_keywords,
        source_id=edge_source_id,
        metadata={"created_at": time.time()},
    )


def pack_user_ass_to_openai_messages(*args: str):
    """Interleave *args* as alternating user/assistant messages in OpenAI chat format."""
    roles = ["user", "assistant"]
    return [{"role": roles[i % 2], "content": content} for i, content in enumerate(args)]


def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]:
    """Split a string by multiple markers"""
    if not markers:
        return [content]
    results = re.split("|".join(re.escape(marker) for marker in markers), content)
    return [r.strip() for r in results if r.strip()]


def is_float_regex(value):
    """Return True if *value* is a string representation of a float or integer."""
    return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value))


def chunk_id(chunk):
    """Return a deterministic hex ID for *chunk* derived from its content and kb_id."""
    return xxhash.xxh64((chunk["content_with_weight"] + chunk["kb_id"]).encode("utf-8")).hexdigest()


async def graph_node_to_chunk(kb_id, embd_mdl, ent_name, meta, chunks, nhop_neighbors=None):
    """Convert a graph node (entity) to an embeddable chunk and append it to *chunks*."""
    enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
    chunk = {
        "id": get_uuid(),
        "important_kwd": [ent_name],
        "title_tks": rag_tokenizer.tokenize(ent_name),
        "entity_kwd": ent_name,
        "knowledge_graph_kwd": "entity",
        "entity_type_kwd": meta["entity_type"],
        "content_with_weight": json.dumps(meta, ensure_ascii=False),
        "content_ltks": rag_tokenizer.tokenize(meta["description"]),
        "source_id": meta["source_id"],
        # pagerank drives the P(E|Q) = pagerank * sim ranking in KGSearch; the
        # n-hop neighbour paths feed its relation-enrichment step. Both are read
        # back as `rank_flt` / `n_hop_with_weight` in rag/graphrag/search.py.
        "rank_flt": float(meta.get("pagerank", 0) or 0),
        "n_hop_with_weight": json.dumps(nhop_neighbors or [], ensure_ascii=False),
        "kb_id": kb_id,
        "available_int": 0,
    }
    chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
    ebd = get_embed_cache(embd_mdl.llm_name, ent_name)
    if ebd is None:
        async with chat_limiter:
            # Named timeout_s so the imported `timeout` decorator is not shadowed.
            timeout_s = 3 if enable_timeout_assertion else 30000000
            ebd, _ = await asyncio.wait_for(
                thread_pool_exec(embd_mdl.encode, [ent_name]),
                timeout=timeout_s
            )
        ebd = ebd[0]
        set_embed_cache(embd_mdl.llm_name, ent_name, ebd)
    assert ebd is not None
    chunk["q_%d_vec" % len(ebd)] = ebd
    chunks.append(chunk)


@timeout(3, 3)
async def get_relation(tenant_id, kb_id, from_ent_name, to_ent_name, size=1):
    """Retrieve edge metadata between entity names from the document store.

    *from_ent_name* / *to_ent_name* may each be a single name or a list of
    names; the caller's lists are not modified. Returns the first parsed
    relation when ``size == 1`` (or an empty list if none parses), otherwise
    a list of parsed relations.
    """
    # Build fresh lists so a list argument is never mutated in place.
    from_ents = [from_ent_name] if isinstance(from_ent_name, str) else list(from_ent_name)
    to_ents = [to_ent_name] if isinstance(to_ent_name, str) else list(to_ent_name)
    ents = list(set(from_ents + to_ents))
    conds = {"fields": ["content_with_weight"], "size": size, "from_entity_kwd": ents, "to_entity_kwd": ents, "knowledge_graph_kwd": ["relation"]}
    res = []
    es_res = await settings.retriever.search(conds, search.index_name(tenant_id), [kb_id] if isinstance(kb_id, str) else kb_id)
    for rid in es_res.ids:
        try:
            if size == 1:
                return json.loads(es_res.field[rid]["content_with_weight"])
            res.append(json.loads(es_res.field[rid]["content_with_weight"]))
        except Exception as e:
            # Best effort: skip unreadable records, but leave a trace.
            logging.warning(f"get_relation: skipping unreadable relation chunk {rid}: {e}")
            continue
    return res


async def graph_edge_to_chunk(kb_id, embd_mdl, from_ent_name, to_ent_name, meta, chunks):
    """Convert a graph edge (relation) to an embeddable chunk and append it to *chunks*."""
    enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
    chunk = {
        "id": get_uuid(),
        "from_entity_kwd": from_ent_name,
        "to_entity_kwd": to_ent_name,
        "knowledge_graph_kwd": "relation",
        "content_with_weight": json.dumps(meta, ensure_ascii=False),
        "content_ltks": rag_tokenizer.tokenize(meta["description"]),
        "important_kwd": meta["keywords"],
        "source_id": meta["source_id"],
        "weight_int": int(meta["weight"]),
        "kb_id": kb_id,
        "available_int": 0,
    }
    chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
    # The cache key is "A->B"; the text actually embedded also carries the description.
    txt = f"{from_ent_name}->{to_ent_name}"
    ebd = get_embed_cache(embd_mdl.llm_name, txt)
    if ebd is None:
        async with chat_limiter:
            timeout_s = 3 if enable_timeout_assertion else 300000000
            ebd, _ = await asyncio.wait_for(
                thread_pool_exec(
                    embd_mdl.encode,
                    [txt + f": {meta['description']}"]
                ),
                timeout=timeout_s
            )
        ebd = ebd[0]
        set_embed_cache(embd_mdl.llm_name, txt, ebd)
    assert ebd is not None
    chunk["q_%d_vec" % len(ebd)] = ebd
    chunks.append(chunk)


async def does_graph_contains(tenant_id, kb_id, doc_id):
    """Return True if *doc_id* is recorded as a source document in the stored graph for *kb_id*."""
    # Get doc_ids of graph
    fields = ["source_id"]
    condition = {
        "knowledge_graph_kwd": ["graph"],
        "removed_kwd": "N",
    }
    res = await thread_pool_exec(
        settings.docStoreConn.search,
        fields, [], condition, [], OrderByExpr(),
        0, 1, search.index_name(tenant_id), [kb_id]
    )
    fields2 = settings.docStoreConn.get_fields(res, fields)
    graph_doc_ids = set()
    # The search is limited to one hit; the last hit (if any) wins.
    for cid in fields2.keys():
        graph_doc_ids = set(fields2[cid]["source_id"])
    return doc_id in graph_doc_ids


async def get_graph_doc_ids(tenant_id, kb_id) -> list[str]:
    """Return the list of document IDs referenced by the stored graph for *kb_id*."""
    conds = {"fields": ["source_id"], "removed_kwd": "N", "size": 1, "knowledge_graph_kwd": ["graph"]}
    res = await settings.retriever.search(conds, search.index_name(tenant_id), [kb_id])
    doc_ids = []
    if res.total == 0:
        return doc_ids
    for rid in res.ids:
        doc_ids = res.field[rid]["source_id"]
    return doc_ids


async def get_graph(tenant_id, kb_id, exclude_rebuild=None):
    """Load the knowledge-graph for *kb_id* from the document store, rebuilding if marked removed.

    Returns the graph, or None when no graph chunk exists or none could be
    loaded.
    """
    conds = {"fields": ["content_with_weight", "removed_kwd", "source_id"], "size": 1, "knowledge_graph_kwd": ["graph"]}
    res = await settings.retriever.search(conds, search.index_name(tenant_id), [kb_id])
    if not res.total == 0:
        for rid in res.ids:
            try:
                if res.field[rid]["removed_kwd"] == "N":
                    g = json_graph.node_link_graph(json.loads(res.field[rid]["content_with_weight"]), edges="edges")
                    if "source_id" not in g.graph:
                        g.graph["source_id"] = res.field[rid]["source_id"]
                else:
                    g = await rebuild_graph(tenant_id, kb_id, exclude_rebuild)
                return g
            except Exception:
                # Best effort: try the next hit, but do not hide the failure.
                logging.exception(f"get_graph: failed to load graph chunk {rid} for kb {kb_id}")
                continue
    result = None
    return result


async def set_graph(tenant_id: str, kb_id: str, embd_mdl, graph: nx.Graph, change: GraphChange, callback):
    """Persist a knowledge-graph snapshot to the document store.

    Converts *graph* nodes and edges to embedding chunks, pre-warms the Redis
    embed cache for all cache-miss entities/relations in bulk before spawning
    per-item tasks, then deletes the old graph/subgraph chunks (and removed
    entities/relations) and inserts the new chunks.
    """
    start = asyncio.get_running_loop().time()

    # Build all new chunks first (graph, subgraphs, node/edge embeddings) before
    # deleting anything. This ensures that if embedding generation or any other
    # step crashes, the old graph and per-doc subgraph checkpoints remain intact
    # so the pipeline can resume without re-running earlier phases.
    graph_source_ids = graph.graph.get("source_id", [])
    chunks = [
        {
            "id": get_uuid(),
            "content_with_weight": json.dumps(nx.node_link_data(graph, edges="edges"), ensure_ascii=False),
            "knowledge_graph_kwd": "graph",
            "kb_id": kb_id,
            "source_id": graph_source_ids,
            "available_int": 0,
            "removed_kwd": "N",
        }
    ]

    # generate updated subgraphs
    # (uses the same .get() as above so a graph without "source_id" simply
    # yields no subgraphs instead of raising KeyError)
    for source in graph_source_ids:
        subgraph = graph.subgraph([n for n in graph.nodes if source in graph.nodes[n]["source_id"]]).copy()
        subgraph.graph["source_id"] = [source]
        for n in subgraph.nodes:
            subgraph.nodes[n]["source_id"] = [source]
        chunks.append(
            {
                "id": get_uuid(),
                "content_with_weight": json.dumps(nx.node_link_data(subgraph, edges="edges"), ensure_ascii=False),
                "knowledge_graph_kwd": "subgraph",
                "kb_id": kb_id,
                "source_id": [source],
                "available_int": 0,
                "removed_kwd": "N",
            }
        )

    # -- batch pre-warm entity embeddings ------------------------------------
    # Without this, set_graph spawns one asyncio task per entity, each calling
    # embd_mdl.encode([single_name]). For 17 k+ nodes that is 17 k round-trips.
    # Pre-warming the cache here collapses N calls to ceil(N/_INSERT_BULK_SIZE).
    _node_list = list(change.added_updated_nodes)
    _node_misses = await thread_pool_exec(
        _batch_embed_cache_misses, embd_mdl.llm_name, _node_list
    )
    _uncached_node_names = [n for n, miss in zip(_node_list, _node_misses) if miss]
    logging.debug(
        "set_graph node pre-warm: %d nodes, %d cache misses",
        len(_node_list), len(_uncached_node_names),
    )
    if _uncached_node_names:
        _enable_ta = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
        _timeout = 3 if _enable_ta else 30000000
        for _i in range(0, len(_uncached_node_names), _INSERT_BULK_SIZE):
            _batch = _uncached_node_names[_i : _i + _INSERT_BULK_SIZE]
            async with chat_limiter:
                _ebds, _ = await asyncio.wait_for(
                    thread_pool_exec(embd_mdl.encode, _batch),
                    timeout=_timeout,
                )
            await thread_pool_exec(_write_embed_cache_batch, embd_mdl.llm_name, _batch, _ebds)
            logging.debug(
                "set_graph node pre-warm: wrote batch %d/%d (%d items)",
                _i // _INSERT_BULK_SIZE + 1,
                (len(_uncached_node_names) + _INSERT_BULK_SIZE - 1) // _INSERT_BULK_SIZE,
                len(_batch),
            )
        if callback:
            callback(msg=f"Batch-embedded {len(_uncached_node_names)} entity names "
                         f"({(len(_uncached_node_names) + _INSERT_BULK_SIZE - 1) // _INSERT_BULK_SIZE} "
                         f"batches of {_INSERT_BULK_SIZE}).")
    # -- end batch pre-warm --------------------------------------------------

    tasks = []
    for ii, node in enumerate(change.added_updated_nodes):
        node_attrs = graph.nodes[node]
        nhop_neighbors = n_neighbor(graph, node)
        tasks.append(asyncio.create_task(
            graph_node_to_chunk(kb_id, embd_mdl, node, node_attrs, chunks, nhop_neighbors)
        ))
        if ii % 100 == 9 and callback:
            callback(msg=f"Get embedding of nodes: {ii}/{len(change.added_updated_nodes)}")
    try:
        await asyncio.gather(*tasks, return_exceptions=False)
    except Exception as e:
        logging.error(f"Error in get_embedding_of_nodes: {e}")
        # gather() does not cancel siblings on failure; do it explicitly.
        for t in tasks:
            t.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)
        raise

    # -- batch pre-warm edge embeddings --------------------------------------
    # Mirror of the node pre-warm above for relation chunks.
    #   Cache key    = "A->B"                 (matches graph_edge_to_chunk lookup key)
    #   Encoded text = "A->B: <description>"  (matches graph_edge_to_chunk encode text)
    _all_edge_data = [
        (_fn, _tn, graph.get_edge_data(_fn, _tn))
        for _fn, _tn in change.added_updated_edges
    ]
    # Drop edges that no longer exist in the graph (or carry no attributes).
    _all_edge_data = [(f, t, a) for f, t, a in _all_edge_data if a]
    _edge_lookup_keys = [f"{f}->{t}" for f, t, _ in _all_edge_data]
    _edge_misses = await thread_pool_exec(
        _batch_embed_cache_misses, embd_mdl.llm_name, _edge_lookup_keys
    ) if _all_edge_data else []
    _uncached_edge_items = [item for item, miss in zip(_all_edge_data, _edge_misses) if miss]
    logging.debug(
        "set_graph edge pre-warm: %d edges, %d cache misses",
        len(_all_edge_data), len(_uncached_edge_items),
    )
    if _uncached_edge_items:
        _edge_keys = [f"{f}->{t}" for f, t, _ in _uncached_edge_items]
        _edge_texts = [f"{f}->{t}: {a['description']}" for f, t, a in _uncached_edge_items]
        _enable_ta = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
        _timeout = 3 if _enable_ta else 30000000
        for _i in range(0, len(_edge_texts), _INSERT_BULK_SIZE):
            _btexts = _edge_texts[_i : _i + _INSERT_BULK_SIZE]
            _bkeys = _edge_keys[_i : _i + _INSERT_BULK_SIZE]
            async with chat_limiter:
                _ebds, _ = await asyncio.wait_for(
                    thread_pool_exec(embd_mdl.encode, _btexts),
                    timeout=_timeout,
                )
            await thread_pool_exec(_write_embed_cache_batch, embd_mdl.llm_name, _bkeys, _ebds)
            logging.debug(
                "set_graph edge pre-warm: wrote batch %d/%d (%d items)",
                _i // _INSERT_BULK_SIZE + 1,
                (len(_uncached_edge_items) + _INSERT_BULK_SIZE - 1) // _INSERT_BULK_SIZE,
                len(_btexts),
            )
        if callback:
            callback(msg=f"Batch-embedded {len(_uncached_edge_items)} edge descriptions "
                         f"({(len(_uncached_edge_items) + _INSERT_BULK_SIZE - 1) // _INSERT_BULK_SIZE} "
                         f"batches of {_INSERT_BULK_SIZE}).")
    # -- end batch pre-warm --------------------------------------------------

    tasks = []
    for ii, (from_node, to_node) in enumerate(change.added_updated_edges):
        edge_attrs = graph.get_edge_data(from_node, to_node)
        if not edge_attrs:
            continue
        tasks.append(asyncio.create_task(
            graph_edge_to_chunk(kb_id, embd_mdl, from_node, to_node, edge_attrs, chunks)
        ))
        if ii % 100 == 9 and callback:
            callback(msg=f"Get embedding of edges: {ii}/{len(change.added_updated_edges)}")
    try:
        await asyncio.gather(*tasks, return_exceptions=False)
    except Exception as e:
        logging.error(f"Error in get_embedding_of_edges: {e}")
        for t in tasks:
            t.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)
        raise

    now = asyncio.get_running_loop().time()
    if callback:
        callback(msg=f"set_graph converted graph change to {len(chunks)} chunks in {now - start:.2f}s.")
    start = now

    # All new chunks are ready. Now delete old data and insert the new data.
    # Deleting only after chunks are built ensures that a crash during embedding
    # generation above does not destroy the old graph/subgraph checkpoints.
    await thread_pool_exec(
        settings.docStoreConn.delete,
        {"knowledge_graph_kwd": ["graph", "subgraph"]},
        search.index_name(tenant_id),
        kb_id
    )

    if change.removed_nodes:
        BATCH_SIZE = 100
        sorted_nodes = sorted(change.removed_nodes)
        for i in range(0, len(sorted_nodes), BATCH_SIZE):
            batch = sorted_nodes[i:i + BATCH_SIZE]
            await thread_pool_exec(
                settings.docStoreConn.delete,
                {"knowledge_graph_kwd": ["entity"], "entity_kwd": batch},
                search.index_name(tenant_id),
                kb_id
            )

    if change.removed_edges:

        async def del_edges(from_node, to_node):
            max_retries = 3
            for attempt in range(max_retries):
                try:
                    async with chat_limiter:
                        await thread_pool_exec(
                            settings.docStoreConn.delete,
                            {"knowledge_graph_kwd": ["relation"], "from_entity_kwd": from_node, "to_entity_kwd": to_node},
                            search.index_name(tenant_id),
                            kb_id
                        )
                    return
                except Exception as e:
                    if attempt < max_retries - 1:
                        wait = 2 ** attempt
                        logging.warning(f"del_edges({from_node}, {to_node}) attempt {attempt + 1} failed: {e}, retrying in {wait}s")
                        await asyncio.sleep(wait)
                    else:
                        raise

        tasks = []
        for from_node, to_node in change.removed_edges:
            tasks.append(asyncio.create_task(del_edges(from_node, to_node)))

        try:
            await asyncio.gather(*tasks, return_exceptions=False)
        except Exception as e:
            logging.error(f"Error while deleting edges: {e}")
            for t in tasks:
                t.cancel()
            await asyncio.gather(*tasks, return_exceptions=True)
            raise

    del_now = asyncio.get_running_loop().time()
    if callback:
        callback(msg=f"set_graph removed {len(change.removed_nodes)} nodes and {len(change.removed_edges)} edges from index in {del_now - start:.2f}s.")
    start = del_now

    await insert_chunks_bounded(chunks, tenant_id, kb_id, callback=callback, label="Insert chunks")

    now = asyncio.get_running_loop().time()
    if callback:
        callback(msg=f"set_graph added/updated {len(change.added_updated_nodes)} nodes and {len(change.added_updated_edges)} edges from index in {now - start:.2f}s.")


def is_continuous_subsequence(subseq, seq):
    """Return True if *subseq* appears as a contiguous sub-path within tuple *seq*.

    Only the first and last elements of *subseq* are compared, against each
    adjacent pair of *seq*; in this module it is called with 2-tuples (edges).
    """

    def find_all_indexes(tup, value):
        indexes = []
        start = 0
        while True:
            try:
                index = tup.index(value, start)
                indexes.append(index)
                start = index + 1
            except ValueError:
                break
        return indexes

    index_list = find_all_indexes(seq, subseq[0])
    for idx in index_list:
        if idx != len(seq) - 1:
            if seq[idx + 1] == subseq[-1]:
                return True
    return False


def merge_tuples(list1, list2):
    """Extend each path tuple in *list1* by matching continuations found in *list2*."""
    result = []
    for tup in list1:
        last_element = tup[-1]
        if last_element in tup[:-1]:
            # The path already loops back onto itself; do not extend it further.
            result.append(tup)
        else:
            matching_tuples = [t for t in list2 if t[0] == last_element]
            already_match_flag = 0
            for match in matching_tuples:
                matchh = (match[1], match[0])
                # Skip edges already walked by this path, in either direction.
                if is_continuous_subsequence(match, tup) or is_continuous_subsequence(matchh, tup):
                    continue
                already_match_flag = 1
                merged_tuple = tup + match[1:]
                result.append(merged_tuple)
            if not already_match_flag:
                result.append(tup)
    return result


def n_neighbor(graph: nx.Graph, node, n_hop: int = 2):
    """Enumerate paths of up to ``n_hop`` edges starting at ``node`` together
    with the edge weight along each step.

    Returns a list of ``{"path": (n0, n1, ...), "weights": [w0, w1, ...]}``
    dicts (``len(weights) == len(path) - 1``). This is the structure consumed
    by :class:`rag.graphrag.search.KGSearch` for n-hop relation enrichment and
    is stored per entity chunk as ``n_hop_with_weight``.
    """
    source_edge = list(graph.edges(node))
    if not source_edge:
        return []
    count = 1
    while count < n_hop:
        count += 1
        sc_edge = deepcopy(source_edge)
        source_edge = []
        for pair in sc_edge:
            append_edge = list(graph.edges(pair[-1]))
            for tuples in merge_tuples([pair], append_edge):
                source_edge.append(tuples)

    wts = nx.get_edge_attributes(graph, "weight")
    nbrs = []
    for path in source_edge:
        nbr = {"path": path, "weights": []}
        for i in range(len(path) - 1):
            f, t = path[i], path[i + 1]
            # Edge attributes are keyed in one orientation only; try both, default 0.
            w = wts.get((f, t))
            if w is None:
                w = wts.get((t, f), 0)
            nbr["weights"].append(w)
        nbrs.append(nbr)
    return nbrs


async def get_entity_type2samples(idxnms, kb_ids: list):
    """Return a mapping of entity type → sample entity names fetched from the document store.

    Records whose ``content_with_weight`` is empty or is not valid JSON are
    skipped (the parse failure is logged).
    """
    es_res = await settings.retriever.search({"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids, "size": 10000, "fields": ["content_with_weight"]}, idxnms, kb_ids)

    res = defaultdict(list)
    for rid in es_res.ids:
        smp = es_res.field[rid].get("content_with_weight")
        if not smp:
            continue
        try:
            smp = json.loads(smp)
        except Exception as e:
            logging.exception(e)
            # smp is still the raw string here; calling .items() on it would
            # raise AttributeError, so skip this record.
            continue

        for ty, ents in smp.items():
            res[ty].extend(ents)
    return res


def flat_uniq_list(arr, key):
    """Flatten and deduplicate the values at *key* across a list of dicts."""
    res = []
    for a in arr:
        a = a[key]
        if isinstance(a, list):
            res.extend(a)
        else:
            res.append(a)
    return list(set(res))


async def rebuild_graph(tenant_id, kb_id, exclude_rebuild=None):
    """Reconstruct the full knowledge-graph for *kb_id* from its stored subgraph chunks.

    Subgraphs whose ``source_id`` contains *exclude_rebuild* (a single id, or
    any id of a list) are left out. Returns None when the result has no nodes.
    """
    graph = nx.Graph()
    flds = ["knowledge_graph_kwd", "content_with_weight", "source_id"]
    bs = 256
    # Page through the subgraph chunks, bs at a time, up to 1024 pages.
    for i in range(0, 1024 * bs, bs):
        es_res = await thread_pool_exec(
            settings.docStoreConn.search,
            flds, [], {"kb_id": kb_id, "knowledge_graph_kwd": ["subgraph"]},
            [], OrderByExpr(), i, bs,
            search.index_name(tenant_id), [kb_id]
        )
        # tot = settings.docStoreConn.get_total(es_res)
        es_res = settings.docStoreConn.get_fields(es_res, flds)

        if len(es_res) == 0:
            break

        for _, d in es_res.items():
            assert d["knowledge_graph_kwd"] == "subgraph"
            if isinstance(exclude_rebuild, list):
                if any(n in d["source_id"] for n in exclude_rebuild):
                    continue
            elif exclude_rebuild in d["source_id"]:
                continue

            next_graph = json_graph.node_link_graph(json.loads(d["content_with_weight"]), edges="edges")
            merged_graph = nx.compose(graph, next_graph)
            # compose() keeps only one graph's attributes for shared nodes, so
            # their source_id lists are concatenated explicitly.
            merged_source = {n: graph.nodes[n]["source_id"] + next_graph.nodes[n]["source_id"] for n in graph.nodes & next_graph.nodes}
            nx.set_node_attributes(merged_graph, merged_source, "source_id")
            if "source_id" in graph.graph:
                merged_graph.graph["source_id"] = graph.graph["source_id"] + next_graph.graph["source_id"]
            else:
                merged_graph.graph["source_id"] = next_graph.graph["source_id"]
            graph = merged_graph

    if len(graph.nodes) == 0:
        return None
    graph.graph["source_id"] = sorted(graph.graph["source_id"])
    return graph