mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
## Summary

This PR adds checkpoint/resume support for the GraphRAG `extract_community` and `resolve_entities` stages.

The implementation stores successful intermediate results in the document store so interrupted ingestion can resume without repeating already-completed LLM work. Checkpoints are loaded before each stage, reused when available, saved after successful batch/community processing, and cleaned up after the stage completes successfully.

## Related Issue

Closes: #15518

## Change Type

- [x] Feature
- [x] Bug fix
- [x] Test
- [ ] Refactor
- [ ] Documentation
- [ ] Breaking change

## Real Behavior Proof

Validation commands run locally:

```bash
uv run python -m py_compile \
  rag/graphrag/checkpoints.py \
  rag/graphrag/general/community_reports_extractor.py \
  rag/graphrag/entity_resolution.py \
  rag/graphrag/general/index.py \
  test/unit_test/rag/graphrag/test_checkpoints.py
```

Result:

```text
Passed
```

```bash
uv run pytest test/unit_test/rag/graphrag/test_checkpoints.py
```

Result:

```text
4 passed
```

```bash
uv run pytest \
  test/unit_test/rag/graphrag/test_phase_markers.py \
  test/unit_test/rag/graphrag/test_graphrag_utils.py \
  test/unit_test/rag/graphrag/test_checkpoints.py
```

Result:

```text
95 passed
```

```bash
git diff --check
```

Result:

```text
Passed
```

## Checklist

- [x] Implemented checkpoint/resume support for `extract_community`.
- [x] Implemented checkpoint/resume support for `resolve_entities`.
- [x] Avoided touching unrelated API behavior.
- [x] Added unit tests for the new checkpoint helper logic.
- [x] Verified Python syntax compilation.
- [x] Ran related GraphRAG unit tests successfully.
- [x] Ran `git diff --check`.
- [ ] Ran full project test suite.

---------

Co-authored-by: Wang Qi <wangq8@outlook.com>
135 lines
5.9 KiB
Python
135 lines
5.9 KiB
Python
#
|
|
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
|
|
from __future__ import annotations
|
|
|
|
import hashlib
|
|
import json
|
|
import logging
|
|
from typing import Any
|
|
|
|
from common.misc_utils import thread_pool_exec
|
|
from rag.utils.redis_conn import REDIS_CONN
|
|
|
|
|
|
# Checkpoint type names. Presumably passed as the `checkpoint_type` argument of
# the load/save/cleanup functions in this module -- confirm against callers.
COMMUNITY_CHECKPOINT = "graphrag_checkpoint_community"
|
|
RESOLUTION_CHECKPOINT = "graphrag_checkpoint_resolution"
|
|
# Fallback SSCAN `count` hint used when no positive page size is supplied.
CHECKPOINT_PAGE_SIZE = 1000
|
|
# Expiry in seconds (7 days) applied to each checkpoint value and to the index set on save.
CHECKPOINT_TTL_SECONDS = 7 * 24 * 3600
|
|
|
|
|
|
def stable_checkpoint_key(*parts: Any) -> str:
    """Return a deterministic SHA-256 hex digest identifying ``parts``.

    The parts are serialized to compact JSON (mapping keys sorted, no
    whitespace, non-ASCII characters kept verbatim) and the UTF-8 bytes of
    that text are hashed, so equal inputs always produce the same key.

    Args:
        *parts: JSON-serializable values that together identify a unit of work.
            Order is significant.

    Returns:
        The 64-character lowercase hexadecimal digest.

    Raises:
        TypeError: If a part cannot be serialized by ``json.dumps``.
    """
    payload = json.dumps(parts, ensure_ascii=False, sort_keys=True, separators=(",", ":"))
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()
|
|
|
|
|
|
def community_checkpoint_key(level: str, community_id: str, nodes: list[str]) -> str:
    """Build the checkpoint key for one community.

    Args:
        level: Community level; coerced with ``str()`` before hashing.
        community_id: Community identifier; coerced with ``str()`` before hashing.
        nodes: Member node names. They are sorted first, so the key does not
            depend on the order in which the nodes are listed.

    Returns:
        The hex digest produced by ``stable_checkpoint_key``.
    """
    return stable_checkpoint_key("community", str(level), str(community_id), sorted(nodes))
|
|
|
|
|
|
def resolution_checkpoint_key(entity_type: str, pairs: list[tuple[str, str]]) -> str:
    """Build the checkpoint key for one batch of candidate entity pairs.

    Each pair is sorted internally and the list of pairs is sorted as a whole,
    so neither the orientation of a pair nor the order of the batch changes
    the resulting key.

    Args:
        entity_type: Entity type the batch belongs to.
        pairs: Two-element pairs of entity names.

    Returns:
        The hex digest produced by ``stable_checkpoint_key``.
    """
    normalized_pairs = sorted([sorted([a, b]) for a, b in pairs])
    return stable_checkpoint_key("resolution", entity_type, normalized_pairs)
|
|
|
|
|
|
def _checkpoint_index_key(tenant_id: str, kb_id: str, checkpoint_type: str) -> str:
    """Return the Redis key of the set that indexes all checkpoint keys for one tenant/kb/type."""
    return f"graphrag:checkpoint:{tenant_id}:{kb_id}:{checkpoint_type}:keys"
|
|
|
|
|
|
def _checkpoint_data_key(tenant_id: str, kb_id: str, checkpoint_type: str, checkpoint_key: str) -> str:
    """Return the Redis key under which a single checkpoint payload is stored."""
    return f"graphrag:checkpoint:{tenant_id}:{kb_id}:{checkpoint_type}:{checkpoint_key}"
|
|
|
|
|
|
def _decode_redis_value(value: Any) -> Any:
    """Decode ``bytes`` as UTF-8 text; return any other value unchanged.

    Raises:
        UnicodeDecodeError: If ``value`` is ``bytes`` that are not valid UTF-8.
    """
    if isinstance(value, bytes):
        return value.decode("utf-8")
    return value
|
|
|
|
|
|
def _checkpoint_page_size(page_size: int | None) -> int:
|
|
return page_size if page_size and page_size > 0 else CHECKPOINT_PAGE_SIZE
|
|
|
|
|
|
def _iter_checkpoint_keys(index_key: str, page_size: int | None):
    """Return an iterator over the members of the checkpoint index set.

    Uses the raw client exposed as ``REDIS_CONN.REDIS`` and its ``sscan_iter``
    method so the set is walked in pages instead of being fetched at once.

    Args:
        index_key: Redis key of the index set.
        page_size: SSCAN ``count`` hint; non-positive or ``None`` selects the
            module default.

    Returns:
        Whatever ``sscan_iter`` returns. NOTE(review): in redis-py this is a
        lazy generator, so Redis errors surface while iterating, not here --
        confirm for the client in use.

    Raises:
        RuntimeError: If no raw client is available or it lacks ``sscan_iter``.
    """
    redis_client = getattr(REDIS_CONN, "REDIS", None)
    if redis_client is None or not hasattr(redis_client, "sscan_iter"):
        raise RuntimeError("Redis SSCAN is unavailable for GraphRAG checkpoint index iteration")
    return redis_client.sscan_iter(index_key, count=_checkpoint_page_size(page_size))
|
|
|
|
|
|
def _load_checkpoints_sync(tenant_id: str, kb_id: str, checkpoint_type: str, page_size: int | None) -> dict[str, Any]:
    """Load every stored checkpoint of one type for a knowledge base (blocking).

    Walks the index set, fetches each referenced payload and JSON-decodes it.
    The function is best-effort and never raises: a payload that is missing,
    empty or unparsable is skipped, and a failure of the index scan itself is
    logged and ends the load with whatever was collected up to that point.

    Args:
        tenant_id: Tenant owning the knowledge base.
        kb_id: Knowledge base identifier.
        checkpoint_type: Checkpoint namespace.
        page_size: SSCAN ``count`` hint; ``None`` selects the module default.

    Returns:
        Mapping of checkpoint key to decoded payload; possibly empty or partial.
    """
    checkpoints: dict[str, Any] = {}
    index_key = _checkpoint_index_key(tenant_id, kb_id, checkpoint_type)
    try:
        # The key iterator is lazy, so Redis errors are raised while looping
        # rather than when the iterator is created; the loop therefore has to
        # live inside this handler as well.
        for raw_key in _iter_checkpoint_keys(index_key, page_size):
            checkpoint_key = raw_key
            try:
                checkpoint_key = _decode_redis_value(raw_key)
                value = REDIS_CONN.get(_checkpoint_data_key(tenant_id, kb_id, checkpoint_type, checkpoint_key))
                value = _decode_redis_value(value)
                if not value:
                    # Indexed but no payload (e.g. the value expired): nothing to reuse.
                    continue
                checkpoints[checkpoint_key] = json.loads(value)
            except Exception:
                # One bad entry must not discard the remaining checkpoints.
                logging.exception("Failed to parse GraphRAG checkpoint type=%s kb=%s key=%s", checkpoint_type, kb_id, checkpoint_key)
    except Exception:
        logging.exception("Failed to load GraphRAG checkpoint index type=%s kb=%s", checkpoint_type, kb_id)
        # Entries already loaded are individually valid, so keep them.
        return checkpoints

    logging.info("Loaded %d GraphRAG checkpoints type=%s kb=%s", len(checkpoints), checkpoint_type, kb_id)
    return checkpoints
|
|
|
|
|
|
async def load_checkpoints(tenant_id: str, kb_id: str, checkpoint_type: str, *, page_size: int | None = None) -> dict[str, Any]:
    """Load all checkpoints of one type without blocking the event loop.

    Delegates to ``_load_checkpoints_sync`` through ``thread_pool_exec``.

    Args:
        tenant_id: Tenant owning the knowledge base.
        kb_id: Knowledge base identifier.
        checkpoint_type: Checkpoint namespace.
        page_size: Optional SSCAN ``count`` hint; ``None`` selects the module default.

    Returns:
        Mapping of checkpoint key to decoded payload.
    """
    return await thread_pool_exec(_load_checkpoints_sync, tenant_id, kb_id, checkpoint_type, page_size)
|
|
|
|
|
|
async def save_checkpoint(tenant_id: str, kb_id: str, checkpoint_type: str, checkpoint_key: str, payload: Any) -> bool:
    """Persist one checkpoint payload and register it in the index set.

    The payload is JSON-encoded in the calling task, then the value write, the
    index-set insert and the index expiry refresh are sent as one
    ``pipeline(transaction=True)``. The blocking Redis round trip runs through
    ``thread_pool_exec`` (as ``load_checkpoints`` does) so the event loop is
    not stalled. The function is best-effort and never raises.

    Args:
        tenant_id: Tenant owning the knowledge base.
        kb_id: Knowledge base identifier.
        checkpoint_type: Checkpoint namespace.
        checkpoint_key: Key identifying the unit of work.
        payload: JSON-serializable result to store.

    Returns:
        ``True`` if the checkpoint was written; ``False`` if the Redis client
        is unavailable, the payload cannot be serialized, or the write failed.
    """
    index_key = _checkpoint_index_key(tenant_id, kb_id, checkpoint_type)
    data_key = _checkpoint_data_key(tenant_id, kb_id, checkpoint_type, checkpoint_key)

    def _write(serialized: str) -> bool:
        # Blocking part: executed in a worker thread.
        redis_client = getattr(REDIS_CONN, "REDIS", None)
        if redis_client is None or not hasattr(redis_client, "pipeline"):
            logging.warning("GraphRAG checkpoint Redis client unavailable type=%s kb=%s key=%s", checkpoint_type, kb_id, checkpoint_key)
            return False
        pipeline = redis_client.pipeline(transaction=True)
        pipeline.set(data_key, serialized, ex=CHECKPOINT_TTL_SECONDS)
        pipeline.sadd(index_key, checkpoint_key)
        # Refresh the index expiry on every save so it outlives its newest entry.
        pipeline.expire(index_key, CHECKPOINT_TTL_SECONDS)
        pipeline.execute()
        return True

    try:
        # Serialize before handing off so the stored value is a snapshot of
        # the payload as it is at call time.
        serialized = json.dumps(payload, ensure_ascii=False)
        saved = await thread_pool_exec(_write, serialized)
    except Exception:
        logging.exception("Failed to save GraphRAG checkpoint type=%s kb=%s key=%s", checkpoint_type, kb_id, checkpoint_key)
        return False

    if saved:
        logging.info("Saved GraphRAG checkpoint type=%s kb=%s key=%s", checkpoint_type, kb_id, checkpoint_key)
    return saved
|
|
|
|
|
|
async def cleanup_checkpoints(tenant_id: str, kb_id: str, checkpoint_type: str, *, page_size: int | None = None) -> bool:
    """Delete every checkpoint of one type for a knowledge base, then its index set.

    The scan-and-delete loop is blocking Redis I/O, so it runs through
    ``thread_pool_exec`` (as ``load_checkpoints`` does) instead of on the event
    loop. The function is best-effort and never raises.

    Args:
        tenant_id: Tenant owning the knowledge base.
        kb_id: Knowledge base identifier.
        checkpoint_type: Checkpoint namespace.
        page_size: Optional SSCAN ``count`` hint; ``None`` selects the module default.

    Returns:
        ``True`` if the scan and deletes ran without raising, ``False`` otherwise.
    """
    index_key = _checkpoint_index_key(tenant_id, kb_id, checkpoint_type)

    def _purge() -> int:
        # Blocking part: executed in a worker thread.
        removed = 0
        for raw_key in _iter_checkpoint_keys(index_key, page_size):
            checkpoint_key = _decode_redis_value(raw_key)
            REDIS_CONN.delete(_checkpoint_data_key(tenant_id, kb_id, checkpoint_type, checkpoint_key))
            removed += 1
        # Drop the index last so a failure above leaves the remaining entries discoverable.
        REDIS_CONN.delete(index_key)
        return removed

    try:
        cleaned_count = await thread_pool_exec(_purge)
    except Exception:
        logging.exception("Failed to cleanup GraphRAG checkpoints type=%s kb=%s", checkpoint_type, kb_id)
        return False

    logging.info("Cleaned up %d GraphRAG checkpoints type=%s kb=%s", cleaned_count, checkpoint_type, kb_id)
    return True
|