mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-07-03 09:11:59 +08:00
## Summary

This PR adds checkpoint/resume support for the GraphRAG `extract_community` and `resolve_entities` stages. The implementation stores successful intermediate results in the document store so interrupted ingestion can resume without repeating already-completed LLM work. Checkpoints are loaded before each stage, reused when available, saved after successful batch/community processing, and cleaned up after the stage completes successfully.

## Related Issue

Closes: #15518

## Change Type

- [x] Feature
- [x] Bug fix
- [x] Test
- [ ] Refactor
- [ ] Documentation
- [ ] Breaking change

## Real Behavior Proof

Validation commands run locally:

```bash
uv run python -m py_compile \
  rag/graphrag/checkpoints.py \
  rag/graphrag/general/community_reports_extractor.py \
  rag/graphrag/entity_resolution.py \
  rag/graphrag/general/index.py \
  test/unit_test/rag/graphrag/test_checkpoints.py
```

Result:

```text
Passed
```

```bash
uv run pytest test/unit_test/rag/graphrag/test_checkpoints.py
```

Result:

```text
4 passed
```

```bash
uv run pytest \
  test/unit_test/rag/graphrag/test_phase_markers.py \
  test/unit_test/rag/graphrag/test_graphrag_utils.py \
  test/unit_test/rag/graphrag/test_checkpoints.py
```

Result:

```text
95 passed
```

```bash
git diff --check
```

Result:

```text
Passed
```

## Checklist

- [x] Implemented checkpoint/resume support for `extract_community`.
- [x] Implemented checkpoint/resume support for `resolve_entities`.
- [x] Avoided touching unrelated API behavior.
- [x] Added unit tests for the new checkpoint helper logic.
- [x] Verified Python syntax compilation.
- [x] Ran related GraphRAG unit tests successfully.
- [x] Ran `git diff --check`.
- [ ] Ran full project test suite.

---------

Co-authored-by: Wang Qi <wangq8@outlook.com>
210 lines
8.4 KiB
Python
210 lines
8.4 KiB
Python
# Copyright (c) 2024 Microsoft Corporation.
|
|
# Licensed under the MIT License
|
|
|
|
|
|
"""
|
|
Reference:
|
|
- [graphrag](https://github.com/microsoft/graphrag)
|
|
"""
|
|
|
|
import asyncio
|
|
import logging
|
|
import json
|
|
import os
|
|
import re
|
|
from typing import Any, Awaitable, Callable
|
|
from dataclasses import dataclass
|
|
import networkx as nx
|
|
import pandas as pd
|
|
|
|
from api.db.services.task_service import has_canceled
|
|
from common.exceptions import TaskCanceledException
|
|
from rag.graphrag.general import leiden
|
|
from rag.graphrag.general.community_report_prompt import COMMUNITY_REPORT_PROMPT
|
|
from rag.graphrag.general.extractor import Extractor
|
|
from rag.graphrag.general.leiden import add_community_info2graph
|
|
from rag.graphrag.checkpoints import community_checkpoint_key
|
|
from rag.llm.chat_model import Base as CompletionLLM
|
|
from rag.graphrag.utils import perform_variable_replacements, dict_has_keys_with_types, chat_limiter
|
|
from common.token_utils import num_tokens_from_string
|
|
|
|
@dataclass
class CommunityReportsResult:
    """Reports produced by one run of the community report extractor.

    The two lists are filled together, one entry per community that yielded a
    usable report, so position ``i`` in each list describes the same community.
    """

    # Rendered report text ("# title", summary, "## finding" sections).
    output: list[str]
    # Parsed JSON report per community (title, summary, findings, rating,
    # rating_explanation, plus the weight and entities attached by the extractor).
    structured_output: list[dict]
|
|
|
|
|
|
class CommunityReportsExtractor(Extractor):
    """Generate LLM-written reports for the Leiden communities of an entity graph.

    Each community with at least two entities is rendered as entity/relation
    CSV tables, sent to the LLM with ``self._extraction_prompt``, and the JSON
    reply is validated, attached to the graph via ``add_community_info2graph``
    and rendered to text. Results saved by a previous (interrupted) run are
    reused through per-community checkpoints so finished LLM calls are not
    repeated.
    """

    _extraction_prompt: str
    _output_formatter_prompt: str
    _max_report_length: int

    # Cap on relation rows included in a single community prompt.
    _MAX_RELATIONS_PER_COMMUNITY = 10000

    def __init__(
        self,
        llm_invoker: CompletionLLM,
        max_report_length: int | None = None,
    ):
        """Create the extractor.

        Args:
            llm_invoker: Chat model used to write the community reports.
            max_report_length: Stored as ``_max_report_length``; falls back to
                1500 when not given (or falsy). NOTE(review): no method of this
                class reads it — confirm whether callers still rely on it.
        """
        super().__init__(llm_invoker)
        self._llm = llm_invoker
        self._extraction_prompt = COMMUNITY_REPORT_PROMPT
        self._max_report_length = max_report_length or 1500

    async def __call__(
        self,
        graph: nx.Graph,
        callback: Callable | None = None,
        task_id: str = "",
        checkpoints: dict[str, Any] | None = None,
        save_checkpoint: Callable[[str, Any], Awaitable[bool]] | None = None,
    ):
        """Write a report for every community of ``graph`` that has 2+ entities.

        Args:
            graph: Entity graph. Node and edge ``description`` attributes are
                read to build the prompt. As a side effect every node gets a
                ``rank`` attribute (its degree) and community info is attached
                via ``add_community_info2graph``.
            callback: Optional progress reporter, invoked as ``callback(msg=...)``.
            task_id: When non-empty, the task is polled for cancellation.
            checkpoints: Results saved by an earlier run, keyed by
                ``community_checkpoint_key``; matching communities skip the LLM.
            save_checkpoint: Optional coroutine that persists each fresh result.
                A failed save is logged and does not abort the stage.

        Returns:
            CommunityReportsResult with one text and one structured report per
            community that produced a valid report.

        Raises:
            TaskCanceledException: If the task is cancelled during processing.
        """
        enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
        # Without the env flag the limit is effectively disabled (~31 years).
        timeout = 180 if enable_timeout_assertion else 1000000000

        for node, degree in graph.degree:
            graph.nodes[str(node)]["rank"] = int(degree)

        communities: dict[str, dict[str, list]] = leiden.run(graph, {})
        total = sum(len(comm) for comm in communities.values())
        res_str: list[str] = []
        res_dict: list[dict] = []
        over, token_count = 0, 0
        checkpoints = checkpoints or {}

        def record_report(output: str, report: dict) -> None:
            # Shared by the resume path and the fresh-LLM path so both keep the
            # two result lists aligned and report progress identically.
            nonlocal over
            res_str.append(output)
            res_dict.append(report)
            over += 1
            if callback:
                callback(msg=f"Communities: {over}/{total}, used tokens: {token_count}")

        async def extract_community_report(level, community):
            nonlocal token_count

            cm_id, cm = community
            weight = cm["weight"]
            ents = cm["nodes"]
            if len(ents) < 2:
                return

            checkpoint_key = community_checkpoint_key(str(level), str(cm_id), list(ents))
            checkpoint = checkpoints.get(checkpoint_key)
            if isinstance(checkpoint, dict):
                saved_report = checkpoint.get("structured_output")
                saved_output = checkpoint.get("output")
                if isinstance(saved_report, dict) and isinstance(saved_output, str):
                    # Resume: replay the saved report instead of calling the LLM again.
                    add_community_info2graph(graph, saved_report.get("entities", ents), saved_report.get("title", ""))
                    record_report(saved_output, saved_report)
                    return

            text = perform_variable_replacements(
                self._extraction_prompt,
                variables=self._community_prompt_variables(graph, ents),
            )
            async with chat_limiter:
                # Checked after acquiring the limiter: tasks may queue here for a
                # long time, and a cancellation issued meanwhile must stop them
                # before they spend an LLM call.
                if task_id and has_canceled(task_id):
                    logging.info(f"Task {task_id} cancelled during community report extraction.")
                    raise TaskCanceledException(f"Task {task_id} was cancelled")
                try:
                    response = await asyncio.wait_for(self._async_chat(text, [{"role": "user", "content": "Output:"}], {}, task_id), timeout=timeout)
                except asyncio.TimeoutError:
                    logging.warning("extract_community_report._async_chat timeout, skipping...")
                    return
                except TaskCanceledException:
                    # Must abort the whole stage, not be treated as a per-community failure.
                    raise
                except Exception as e:
                    logging.error(f"extract_community_report._async_chat failed: {e}")
                    return
            token_count += num_tokens_from_string(text + response)

            report = self._parse_report(response)
            if report is None:
                return
            report["weight"] = weight
            report["entities"] = ents
            add_community_info2graph(graph, ents, report["title"])
            output = self._get_text_output(report)
            if save_checkpoint:
                await self._save_checkpoint_safely(
                    save_checkpoint, checkpoint_key, {"structured_output": report, "output": output}
                )
            record_report(output, report)

        st = asyncio.get_running_loop().time()

        # Collect every job (checking for cancellation) BEFORE creating any task,
        # so a cancellation detected here cannot leave already-scheduled tasks
        # running in the background.
        jobs = []
        for level, comm in communities.items():
            logging.info(f"Level {level}: Community: {len(comm.keys())}")
            for community in comm.items():
                if task_id and has_canceled(task_id):
                    logging.info(f"Task {task_id} cancelled before community processing.")
                    raise TaskCanceledException(f"Task {task_id} was cancelled")
                jobs.append((level, community))

        tasks = [asyncio.create_task(extract_community_report(level, community)) for level, community in jobs]
        try:
            await asyncio.gather(*tasks, return_exceptions=False)
        except Exception as e:
            logging.error(f"Error in community processing: {e}")
            for t in tasks:
                t.cancel()
            await asyncio.gather(*tasks, return_exceptions=True)
            raise
        if callback:
            callback(msg=f"Community reports done in {asyncio.get_running_loop().time() - st:.2f}s, used tokens: {token_count}")

        return CommunityReportsResult(
            structured_output=res_dict,
            output=res_str,
        )

    def _community_prompt_variables(self, graph: nx.Graph, ents: list) -> dict[str, str]:
        """Render a community's entities and internal relations as CSV prompt variables.

        Relations are the edges between pairs of the community's entities, in
        pair order, capped at ``_MAX_RELATIONS_PER_COMMUNITY`` rows.
        """
        ent_df = pd.DataFrame([{"entity": ent, "description": graph.nodes[ent]["description"]} for ent in ents])

        limit = self._MAX_RELATIONS_PER_COMMUNITY
        rela_list = []
        for i, source in enumerate(ents):
            if len(rela_list) >= limit:
                break
            for target in ents[i + 1:]:
                if len(rela_list) >= limit:
                    break
                edge = graph.get_edge_data(source, target)
                if edge is None:
                    continue
                rela_list.append({"source": source, "target": target, "description": edge["description"]})
        rela_df = pd.DataFrame(rela_list)

        return {
            "entity_df": ent_df.to_csv(index_label="id"),
            "relation_df": rela_df.to_csv(index_label="id"),
        }

    @staticmethod
    def _parse_report(response: str) -> dict | None:
        """Extract and validate the JSON report from a raw LLM reply.

        Text before the first ``{`` and after the last ``}`` is dropped and
        doubled braces are collapsed. Returns ``None`` (after logging) when the
        payload is not valid JSON or lacks the required typed keys.
        """
        response = re.sub(r"^[^\{]*", "", response)
        response = re.sub(r"[^\}]*$", "", response)
        response = re.sub(r"\{\{", "{", response)
        response = re.sub(r"\}\}", "}", response)
        logging.debug(response)
        try:
            report = json.loads(response)
        except json.JSONDecodeError as e:
            logging.error(f"Failed to parse JSON response: {e}")
            logging.error(f"Response content: {response}")
            return None
        if not dict_has_keys_with_types(report, [
            ("title", str),
            ("summary", str),
            ("findings", list),
            ("rating", float),
            ("rating_explanation", str),
        ]):
            return None
        return report

    @staticmethod
    async def _save_checkpoint_safely(save_checkpoint, key: str, payload: dict) -> None:
        """Persist one checkpoint; a storage failure is logged, never fatal.

        Losing a checkpoint only means the community is recomputed on resume,
        whereas raising here would cancel every sibling task and discard the
        report that was just produced.
        """
        try:
            saved = await save_checkpoint(key, payload)
        except TaskCanceledException:
            raise
        except Exception as e:
            logging.warning(f"Failed to save community report checkpoint {key}: {e}")
            return
        if saved is False:
            logging.warning(f"Community report checkpoint {key} was not saved.")

    def _get_text_output(self, parsed_output: dict) -> str:
        """Render a parsed report as text: ``# title``, summary, then ``## finding`` sections.

        A finding may be a plain string (used as its heading) or a dict with
        ``summary``/``explanation``; missing fields render as empty text rather
        than the literal ``None``.
        """
        title = parsed_output.get("title", "Report")
        summary = parsed_output.get("summary", "")
        findings = parsed_output.get("findings", [])

        def finding_summary(finding) -> str:
            if isinstance(finding, dict):
                return finding.get("summary") or ""
            return str(finding)

        def finding_explanation(finding) -> str:
            if isinstance(finding, dict):
                return finding.get("explanation") or ""
            return ""

        report_sections = "\n\n".join(
            f"## {finding_summary(f)}\n\n{finding_explanation(f)}" for f in findings
        )
        return f"# {title}\n\n{summary}\n\n{report_sections}"
|