Refactor: enhance graphrag - part 2 (#14972)

### What problem does this PR solve?
1. expose batch_chunk_token_size for configuration
2. retrieve chunks when build subgraph for the doc, not retreive all
docs chunks at the begining
3. get all chunks for a document, used to be hard coded 10000
4. delete not used method run_graphrag

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
- [x] Refactoring

Follow on: #14617
This commit is contained in:
Wang Qi
2026-05-18 16:10:21 +08:00
committed by GitHub
parent b12eaee38b
commit 13b422037f
15 changed files with 82 additions and 118 deletions

View File

@@ -54,6 +54,22 @@ from common import settings
from common.doc_store.doc_store_base import OrderByExpr
DEFAULT_GRAPHRAG_BATCH_CHUNK_TOKEN_SIZE = 4096
def _positive_int_config(config: dict, key: str, default: int) -> int:
value = config.get(key, default)
try:
value = int(value)
except (TypeError, ValueError):
logging.warning("Invalid GraphRAG config %s=%r, using default %s", key, value, default)
return default
if value < 512 or value > 8196:
logging.warning("Invalid GraphRAG config %s=%r, using default %s", key, value, default)
return default
return value
def _select_extractor(graphrag_config: dict):
"""Return the extractor class matching ``graphrag_config["method"]``.
@@ -121,100 +137,6 @@ async def load_subgraph_from_store(tenant_id: str, kb_id: str, doc_id: str):
return None
async def run_graphrag(
row: dict,
language,
with_resolution: bool,
with_community: bool,
chat_model,
embedding_model,
callback,
):
enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
start = asyncio.get_running_loop().time()
tenant_id, kb_id, doc_id = row["tenant_id"], str(row["kb_id"]), row["doc_id"]
chunks = []
for d in settings.retriever.chunk_list(doc_id, tenant_id, [kb_id], max_count=10000, fields=["content_with_weight", "doc_id"], sort_by_position=True):
chunks.append(d["content_with_weight"])
timeout_sec = max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000
try:
subgraph = await asyncio.wait_for(
generate_subgraph(
_select_extractor(row["kb_parser_config"].get("graphrag", {})),
tenant_id,
kb_id,
doc_id,
chunks,
language,
row["kb_parser_config"]["graphrag"].get("entity_types", []),
chat_model,
embedding_model,
callback,
),
timeout=timeout_sec,
)
except asyncio.TimeoutError:
logging.error("generate_subgraph timeout")
raise
if not subgraph:
return
graphrag_task_lock = RedisDistributedLock(f"graphrag_task_{kb_id}", lock_value=doc_id, timeout=1200)
await graphrag_task_lock.spin_acquire()
callback(msg=f"run_graphrag {doc_id} graphrag_task_lock acquired")
try:
subgraph_nodes = set(subgraph.nodes())
new_graph = await merge_subgraph(
tenant_id,
kb_id,
doc_id,
subgraph,
embedding_model,
callback,
)
assert new_graph is not None
if not with_resolution and not with_community:
return
if with_resolution:
await graphrag_task_lock.spin_acquire()
callback(msg=f"run_graphrag {doc_id} graphrag_task_lock acquired")
await resolve_entities(
new_graph,
subgraph_nodes,
tenant_id,
kb_id,
doc_id,
chat_model,
embedding_model,
callback,
task_id=row["id"],
)
if with_community:
await graphrag_task_lock.spin_acquire()
callback(msg=f"run_graphrag {doc_id} graphrag_task_lock acquired")
await extract_community(
new_graph,
tenant_id,
kb_id,
doc_id,
chat_model,
embedding_model,
callback,
task_id=row["id"],
)
finally:
graphrag_task_lock.release()
now = asyncio.get_running_loop().time()
callback(msg=f"GraphRAG for doc {doc_id} done in {now - start:.2f} seconds.")
return
async def run_graphrag_for_kb(
row: dict,
doc_ids: list[str],
@@ -232,6 +154,8 @@ async def run_graphrag_for_kb(
enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
start = asyncio.get_running_loop().time()
fields_for_chunks = ["content_with_weight", "doc_id"]
graphrag_config = kb_parser_config.get("graphrag", {})
batch_chunk_token_size = _positive_int_config(graphrag_config, "batch_chunk_token_size", DEFAULT_GRAPHRAG_BATCH_CHUNK_TOKEN_SIZE)
if not doc_ids:
logging.info(f"Fetching all docs for {kb_id}")
@@ -259,21 +183,20 @@ async def run_graphrag_for_kb(
chunks = []
current_chunk = ""
# DEBUG: Obtener todos los chunks primero
raw_chunks = list(settings.retriever.chunk_list(
doc_id,
tenant_id,
[kb_id],
max_count=10000, # FIX: Aumentar límite para procesar todos los chunks
fields=fields_for_chunks,
sort_by_position=True,
retrieve_all=True
))
callback(msg=f"[DEBUG] chunk_list() returned {len(raw_chunks)} raw chunks for doc {doc_id}")
callback(msg=f"[GraphRAG] chunk_list returned {len(raw_chunks)} raw chunks for doc:{doc_id}")
for d in raw_chunks:
content = d["content_with_weight"]
if num_tokens_from_string(current_chunk + content) < 4096:
if num_tokens_from_string(current_chunk + content) < batch_chunk_token_size:
current_chunk += content
else:
if current_chunk:
@@ -285,16 +208,7 @@ async def run_graphrag_for_kb(
return chunks
all_doc_chunks: dict[str, list[str]] = {}
total_chunks = 0
for doc_id in doc_ids:
chunks = load_doc_chunks(doc_id)
all_doc_chunks[doc_id] = chunks
total_chunks += len(chunks)
if total_chunks == 0:
callback(msg=f"[GraphRAG] kb:{kb_id} has no available chunks in all documents, skip.")
return {"ok_docs": [], "failed_docs": doc_ids, "total_docs": len(doc_ids), "total_chunks": 0, "seconds": 0.0}
semaphore = asyncio.Semaphore(max_parallel_docs)
@@ -302,18 +216,13 @@ async def run_graphrag_for_kb(
failed_docs: list[tuple[str, str]] = [] # (doc_id, error)
async def build_one(doc_id: str):
nonlocal total_chunks
if has_canceled(row["id"]):
callback(msg=f"Task {row['id']} cancelled, stopping execution.")
raise TaskCanceledException(f"Task {row['id']} was cancelled")
chunks = all_doc_chunks.get(doc_id, [])
if not chunks:
callback(msg=f"[GraphRAG] doc:{doc_id} has no available chunks, skip generation.")
return
kg_extractor = _select_extractor(kb_parser_config.get("graphrag", {}))
deadline = max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000
kg_extractor = _select_extractor(graphrag_config)
async with semaphore:
# CHECKPOINT: bounded by semaphore so doc-store lookups respect max_parallel_docs
@@ -323,6 +232,13 @@ async def run_graphrag_for_kb(
callback(msg=f"[GraphRAG] doc:{doc_id} subgraph found in store, skipping LLM extraction.")
return
try:
chunks = load_doc_chunks(doc_id)
total_chunks += len(chunks)
if not chunks:
callback(msg=f"[GraphRAG] doc:{doc_id} has no available chunks, skip generation.")
return
deadline = max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000
msg = f"[GraphRAG] build_subgraph doc:{doc_id}"
callback(msg=f"{msg} start (chunks={len(chunks)}, timeout={deadline}s)")
@@ -373,6 +289,10 @@ async def run_graphrag_for_kb(
await asyncio.gather(*tasks, return_exceptions=True)
raise
if total_chunks == 0 and not subgraphs:
callback(msg=f"[GraphRAG] kb:{kb_id} has no available chunks in all documents, skip.")
return {"ok_docs": [], "failed_docs": [(doc_id, "no available chunks") for doc_id in doc_ids], "total_docs": len(doc_ids), "total_chunks": 0, "seconds": 0.0}
if has_canceled(row["id"]):
callback(msg=f"Task {row['id']} cancelled after document processing.")
raise TaskCanceledException(f"Task {row['id']} was cancelled")