mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 15:31:05 +08:00
Refactor: enhance graphrag - part 2 (#14972)
### What problem does this PR solve? 1. expose batch_chunk_token_size for configuration 2. retrieve chunks when build subgraph for the doc, not retreive all docs chunks at the begining 3. get all chunks for a document, used to be hard coded 10000 4. delete not used method run_graphrag ### Type of change - [x] New Feature (non-breaking change which adds functionality) - [x] Refactoring Follow on: #14617
This commit is contained in:
@@ -54,6 +54,22 @@ from common import settings
|
||||
from common.doc_store.doc_store_base import OrderByExpr
|
||||
|
||||
|
||||
DEFAULT_GRAPHRAG_BATCH_CHUNK_TOKEN_SIZE = 4096
|
||||
|
||||
|
||||
def _positive_int_config(config: dict, key: str, default: int) -> int:
|
||||
value = config.get(key, default)
|
||||
try:
|
||||
value = int(value)
|
||||
except (TypeError, ValueError):
|
||||
logging.warning("Invalid GraphRAG config %s=%r, using default %s", key, value, default)
|
||||
return default
|
||||
if value < 512 or value > 8196:
|
||||
logging.warning("Invalid GraphRAG config %s=%r, using default %s", key, value, default)
|
||||
return default
|
||||
return value
|
||||
|
||||
|
||||
def _select_extractor(graphrag_config: dict):
|
||||
"""Return the extractor class matching ``graphrag_config["method"]``.
|
||||
|
||||
@@ -121,100 +137,6 @@ async def load_subgraph_from_store(tenant_id: str, kb_id: str, doc_id: str):
|
||||
return None
|
||||
|
||||
|
||||
async def run_graphrag(
|
||||
row: dict,
|
||||
language,
|
||||
with_resolution: bool,
|
||||
with_community: bool,
|
||||
chat_model,
|
||||
embedding_model,
|
||||
callback,
|
||||
):
|
||||
enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
|
||||
start = asyncio.get_running_loop().time()
|
||||
tenant_id, kb_id, doc_id = row["tenant_id"], str(row["kb_id"]), row["doc_id"]
|
||||
chunks = []
|
||||
for d in settings.retriever.chunk_list(doc_id, tenant_id, [kb_id], max_count=10000, fields=["content_with_weight", "doc_id"], sort_by_position=True):
|
||||
chunks.append(d["content_with_weight"])
|
||||
|
||||
timeout_sec = max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000
|
||||
|
||||
try:
|
||||
subgraph = await asyncio.wait_for(
|
||||
generate_subgraph(
|
||||
_select_extractor(row["kb_parser_config"].get("graphrag", {})),
|
||||
tenant_id,
|
||||
kb_id,
|
||||
doc_id,
|
||||
chunks,
|
||||
language,
|
||||
row["kb_parser_config"]["graphrag"].get("entity_types", []),
|
||||
chat_model,
|
||||
embedding_model,
|
||||
callback,
|
||||
),
|
||||
timeout=timeout_sec,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logging.error("generate_subgraph timeout")
|
||||
raise
|
||||
|
||||
if not subgraph:
|
||||
return
|
||||
|
||||
graphrag_task_lock = RedisDistributedLock(f"graphrag_task_{kb_id}", lock_value=doc_id, timeout=1200)
|
||||
await graphrag_task_lock.spin_acquire()
|
||||
callback(msg=f"run_graphrag {doc_id} graphrag_task_lock acquired")
|
||||
|
||||
try:
|
||||
subgraph_nodes = set(subgraph.nodes())
|
||||
new_graph = await merge_subgraph(
|
||||
tenant_id,
|
||||
kb_id,
|
||||
doc_id,
|
||||
subgraph,
|
||||
embedding_model,
|
||||
callback,
|
||||
)
|
||||
assert new_graph is not None
|
||||
|
||||
if not with_resolution and not with_community:
|
||||
return
|
||||
|
||||
if with_resolution:
|
||||
await graphrag_task_lock.spin_acquire()
|
||||
callback(msg=f"run_graphrag {doc_id} graphrag_task_lock acquired")
|
||||
await resolve_entities(
|
||||
new_graph,
|
||||
subgraph_nodes,
|
||||
tenant_id,
|
||||
kb_id,
|
||||
doc_id,
|
||||
chat_model,
|
||||
embedding_model,
|
||||
callback,
|
||||
task_id=row["id"],
|
||||
)
|
||||
if with_community:
|
||||
await graphrag_task_lock.spin_acquire()
|
||||
callback(msg=f"run_graphrag {doc_id} graphrag_task_lock acquired")
|
||||
await extract_community(
|
||||
new_graph,
|
||||
tenant_id,
|
||||
kb_id,
|
||||
doc_id,
|
||||
chat_model,
|
||||
embedding_model,
|
||||
callback,
|
||||
task_id=row["id"],
|
||||
)
|
||||
finally:
|
||||
graphrag_task_lock.release()
|
||||
now = asyncio.get_running_loop().time()
|
||||
callback(msg=f"GraphRAG for doc {doc_id} done in {now - start:.2f} seconds.")
|
||||
return
|
||||
|
||||
|
||||
async def run_graphrag_for_kb(
|
||||
row: dict,
|
||||
doc_ids: list[str],
|
||||
@@ -232,6 +154,8 @@ async def run_graphrag_for_kb(
|
||||
enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
|
||||
start = asyncio.get_running_loop().time()
|
||||
fields_for_chunks = ["content_with_weight", "doc_id"]
|
||||
graphrag_config = kb_parser_config.get("graphrag", {})
|
||||
batch_chunk_token_size = _positive_int_config(graphrag_config, "batch_chunk_token_size", DEFAULT_GRAPHRAG_BATCH_CHUNK_TOKEN_SIZE)
|
||||
|
||||
if not doc_ids:
|
||||
logging.info(f"Fetching all docs for {kb_id}")
|
||||
@@ -259,21 +183,20 @@ async def run_graphrag_for_kb(
|
||||
chunks = []
|
||||
current_chunk = ""
|
||||
|
||||
# DEBUG: Obtener todos los chunks primero
|
||||
raw_chunks = list(settings.retriever.chunk_list(
|
||||
doc_id,
|
||||
tenant_id,
|
||||
[kb_id],
|
||||
max_count=10000, # FIX: Aumentar límite para procesar todos los chunks
|
||||
fields=fields_for_chunks,
|
||||
sort_by_position=True,
|
||||
retrieve_all=True
|
||||
))
|
||||
|
||||
callback(msg=f"[DEBUG] chunk_list() returned {len(raw_chunks)} raw chunks for doc {doc_id}")
|
||||
callback(msg=f"[GraphRAG] chunk_list returned {len(raw_chunks)} raw chunks for doc:{doc_id}")
|
||||
|
||||
for d in raw_chunks:
|
||||
content = d["content_with_weight"]
|
||||
if num_tokens_from_string(current_chunk + content) < 4096:
|
||||
if num_tokens_from_string(current_chunk + content) < batch_chunk_token_size:
|
||||
current_chunk += content
|
||||
else:
|
||||
if current_chunk:
|
||||
@@ -285,16 +208,7 @@ async def run_graphrag_for_kb(
|
||||
|
||||
return chunks
|
||||
|
||||
all_doc_chunks: dict[str, list[str]] = {}
|
||||
total_chunks = 0
|
||||
for doc_id in doc_ids:
|
||||
chunks = load_doc_chunks(doc_id)
|
||||
all_doc_chunks[doc_id] = chunks
|
||||
total_chunks += len(chunks)
|
||||
|
||||
if total_chunks == 0:
|
||||
callback(msg=f"[GraphRAG] kb:{kb_id} has no available chunks in all documents, skip.")
|
||||
return {"ok_docs": [], "failed_docs": doc_ids, "total_docs": len(doc_ids), "total_chunks": 0, "seconds": 0.0}
|
||||
|
||||
semaphore = asyncio.Semaphore(max_parallel_docs)
|
||||
|
||||
@@ -302,18 +216,13 @@ async def run_graphrag_for_kb(
|
||||
failed_docs: list[tuple[str, str]] = [] # (doc_id, error)
|
||||
|
||||
async def build_one(doc_id: str):
|
||||
nonlocal total_chunks
|
||||
|
||||
if has_canceled(row["id"]):
|
||||
callback(msg=f"Task {row['id']} cancelled, stopping execution.")
|
||||
raise TaskCanceledException(f"Task {row['id']} was cancelled")
|
||||
|
||||
chunks = all_doc_chunks.get(doc_id, [])
|
||||
if not chunks:
|
||||
callback(msg=f"[GraphRAG] doc:{doc_id} has no available chunks, skip generation.")
|
||||
return
|
||||
|
||||
kg_extractor = _select_extractor(kb_parser_config.get("graphrag", {}))
|
||||
|
||||
deadline = max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000
|
||||
kg_extractor = _select_extractor(graphrag_config)
|
||||
|
||||
async with semaphore:
|
||||
# CHECKPOINT: bounded by semaphore so doc-store lookups respect max_parallel_docs
|
||||
@@ -323,6 +232,13 @@ async def run_graphrag_for_kb(
|
||||
callback(msg=f"[GraphRAG] doc:{doc_id} subgraph found in store, skipping LLM extraction.")
|
||||
return
|
||||
try:
|
||||
chunks = load_doc_chunks(doc_id)
|
||||
total_chunks += len(chunks)
|
||||
if not chunks:
|
||||
callback(msg=f"[GraphRAG] doc:{doc_id} has no available chunks, skip generation.")
|
||||
return
|
||||
|
||||
deadline = max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000
|
||||
msg = f"[GraphRAG] build_subgraph doc:{doc_id}"
|
||||
callback(msg=f"{msg} start (chunks={len(chunks)}, timeout={deadline}s)")
|
||||
|
||||
@@ -373,6 +289,10 @@ async def run_graphrag_for_kb(
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
raise
|
||||
|
||||
if total_chunks == 0 and not subgraphs:
|
||||
callback(msg=f"[GraphRAG] kb:{kb_id} has no available chunks in all documents, skip.")
|
||||
return {"ok_docs": [], "failed_docs": [(doc_id, "no available chunks") for doc_id in doc_ids], "total_docs": len(doc_ids), "total_chunks": 0, "seconds": 0.0}
|
||||
|
||||
if has_canceled(row["id"]):
|
||||
callback(msg=f"Task {row['id']} cancelled after document processing.")
|
||||
raise TaskCanceledException(f"Task {row['id']} was cancelled")
|
||||
|
||||
Reference in New Issue
Block a user