Files
ragflow/rag/graphrag/checkpoints.py
Jonathan Chang c586292993 feat: Implement checkpoint/resume support for GraphRAG community extraction and entity resolution (#15523)
## Summary

This PR adds checkpoint/resume support for the GraphRAG
`extract_community` and `resolve_entities` stages.

The implementation stores successful intermediate results in Redis (with a
7-day expiry) so interrupted ingestion can resume without repeating
already-completed LLM work. Checkpoints are loaded before each stage,
reused when available, saved after successful batch/community
processing, and cleaned up after the stage completes successfully.
## Related Issue
Closes: #15518
## Change Type
- [x] Feature
- [x] Bug fix
- [x] Test
- [ ] Refactor
- [ ] Documentation
- [ ] Breaking change
## Real Behavior Proof

Validation commands run locally:

```bash
uv run python -m py_compile \
  rag/graphrag/checkpoints.py \
  rag/graphrag/general/community_reports_extractor.py \
  rag/graphrag/entity_resolution.py \
  rag/graphrag/general/index.py \
  test/unit_test/rag/graphrag/test_checkpoints.py
```
Result:

```text
Passed
```

```bash
uv run pytest test/unit_test/rag/graphrag/test_checkpoints.py
```
Result:

```text
4 passed
```

```bash
uv run pytest \
  test/unit_test/rag/graphrag/test_phase_markers.py \
  test/unit_test/rag/graphrag/test_graphrag_utils.py \
  test/unit_test/rag/graphrag/test_checkpoints.py
```
Result:

```text
95 passed
```

```bash
git diff --check
```
Result:

```text
Passed
```

## Checklist

- [x] Implemented checkpoint/resume support for `extract_community`.
- [x] Implemented checkpoint/resume support for `resolve_entities`.
- [x] Avoided touching unrelated API behavior.
- [x] Added unit tests for the new checkpoint helper logic.
- [x] Verified Python syntax compilation.
- [x] Ran related GraphRAG unit tests successfully.
- [x] Ran `git diff --check`.
- [ ] Ran full project test suite.

---------

Co-authored-by: Wang Qi <wangq8@outlook.com>
2026-06-09 15:34:47 +08:00

135 lines
5.9 KiB
Python

#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from __future__ import annotations
import hashlib
import json
import logging
from typing import Any
from common.misc_utils import thread_pool_exec
from rag.utils.redis_conn import REDIS_CONN
# Checkpoint type names for the two GraphRAG stages handled by this module
# (presumably passed as ``checkpoint_type`` by the callers - confirm there).
COMMUNITY_CHECKPOINT = "graphrag_checkpoint_community"
RESOLUTION_CHECKPOINT = "graphrag_checkpoint_resolution"
# Default SSCAN ``count`` hint used when no positive page size is supplied.
CHECKPOINT_PAGE_SIZE = 1000
# Expiry (7 days, in seconds) applied to checkpoint data keys and the index set.
CHECKPOINT_TTL_SECONDS = 7 * 24 * 3600
def stable_checkpoint_key(*parts: Any) -> str:
    """Return a deterministic SHA-256 hex digest identifying ``parts``.

    The parts are serialized as compact JSON (sorted dictionary keys, no
    whitespace, non-ASCII characters kept as-is), encoded as UTF-8 and hashed,
    so equal inputs always map to the same 64-character key.

    Args:
        *parts: JSON-serializable values that together identify a checkpoint.

    Returns:
        The hexadecimal SHA-256 digest of the canonical JSON form.

    Raises:
        TypeError: If a part cannot be serialized by ``json.dumps``.
    """
    canonical = json.dumps(
        parts,
        ensure_ascii=False,
        sort_keys=True,
        separators=(",", ":"),
    )
    digest = hashlib.sha256()
    digest.update(canonical.encode("utf-8"))
    return digest.hexdigest()
def community_checkpoint_key(level: str, community_id: str, nodes: list[str]) -> str:
    """Build the checkpoint key for one community at one hierarchy level.

    The level and community id are coerced to strings and the node names are
    sorted, so the key does not depend on the order in which nodes are listed.

    Args:
        level: Community hierarchy level.
        community_id: Identifier of the community within that level.
        nodes: Names of the nodes belonging to the community.

    Returns:
        The stable hash produced by ``stable_checkpoint_key``.
    """
    ordered_nodes = sorted(nodes)
    level_label = str(level)
    community_label = str(community_id)
    return stable_checkpoint_key("community", level_label, community_label, ordered_nodes)
def resolution_checkpoint_key(entity_type: str, pairs: list[tuple[str, str]]) -> str:
    """Build the checkpoint key for one batch of candidate entity pairs.

    Each pair is sorted internally and the list of pairs is then sorted, so
    the key is independent of both the order inside a pair and the order of
    the pairs in the batch.

    Args:
        entity_type: Entity type the candidate pairs belong to.
        pairs: Two-element pairs of entity names.

    Returns:
        The stable hash produced by ``stable_checkpoint_key``.
    """
    canonical_pairs = []
    for first, second in pairs:
        canonical_pairs.append(sorted((first, second)))
    canonical_pairs.sort()
    return stable_checkpoint_key("resolution", entity_type, canonical_pairs)
def _checkpoint_index_key(tenant_id: str, kb_id: str, checkpoint_type: str) -> str:
return f"graphrag:checkpoint:{tenant_id}:{kb_id}:{checkpoint_type}:keys"
def _checkpoint_data_key(tenant_id: str, kb_id: str, checkpoint_type: str, checkpoint_key: str) -> str:
return f"graphrag:checkpoint:{tenant_id}:{kb_id}:{checkpoint_type}:{checkpoint_key}"
def _decode_redis_value(value: Any) -> Any:
if isinstance(value, bytes):
return value.decode("utf-8")
return value
def _checkpoint_page_size(page_size: int | None) -> int:
return page_size if page_size and page_size > 0 else CHECKPOINT_PAGE_SIZE
def _iter_checkpoint_keys(index_key: str, page_size: int | None):
    """Return an SSCAN-based iterator over the members of the index set.

    Args:
        index_key: Redis key of the set holding the checkpoint keys.
        page_size: Requested SSCAN ``count`` hint; normalized by
            ``_checkpoint_page_size``.

    Returns:
        Whatever ``sscan_iter`` of the underlying Redis client returns.

    Raises:
        RuntimeError: If ``REDIS_CONN`` exposes no ``REDIS`` client or that
            client has no ``sscan_iter`` method.
    """
    client = getattr(REDIS_CONN, "REDIS", None)
    if client is not None and hasattr(client, "sscan_iter"):
        scan_count = _checkpoint_page_size(page_size)
        return client.sscan_iter(index_key, count=scan_count)
    raise RuntimeError("Redis SSCAN is unavailable for GraphRAG checkpoint index iteration")
def _load_checkpoints_sync(tenant_id: str, kb_id: str, checkpoint_type: str, page_size: int | None) -> dict[str, Any]:
    """Load every stored checkpoint of one type for a knowledge base.

    Blocking helper meant to be run off the event loop. Loading is best
    effort: failures are logged and never raised.

    Args:
        tenant_id: Tenant owning the knowledge base.
        kb_id: Knowledge base identifier.
        checkpoint_type: Checkpoint type name.
        page_size: SSCAN ``count`` hint; ``None`` or non-positive values use
            the module default.

    Returns:
        A mapping from checkpoint key to its decoded JSON payload. Index
        entries whose payload is missing or empty are skipped, as are entries
        that cannot be read or parsed. If scanning the index fails part-way,
        the checkpoints loaded up to that point are returned.
    """
    checkpoints: dict[str, Any] = {}
    index_key = _checkpoint_index_key(tenant_id, kb_id, checkpoint_type)
    try:
        # redis-py's sscan_iter is a generator: the SSCAN round-trips happen
        # while iterating, not when the iterator is created. The loop therefore
        # has to sit inside this guard, otherwise a Redis error in the middle
        # of the scan would escape instead of being handled as best effort.
        for raw_key in _iter_checkpoint_keys(index_key, page_size):
            checkpoint_key = _decode_redis_value(raw_key)
            try:
                value = REDIS_CONN.get(_checkpoint_data_key(tenant_id, kb_id, checkpoint_type, checkpoint_key))
                value = _decode_redis_value(value)
                if not value:
                    # Index entry without a payload (e.g. the data key is gone).
                    continue
                checkpoints[checkpoint_key] = json.loads(value)
            except Exception:
                # One unreadable checkpoint must not discard the others.
                logging.exception("Failed to parse GraphRAG checkpoint type=%s kb=%s key=%s", checkpoint_type, kb_id, checkpoint_key)
    except Exception:
        logging.exception("Failed to load GraphRAG checkpoint index type=%s kb=%s", checkpoint_type, kb_id)
        return checkpoints
    logging.info("Loaded %d GraphRAG checkpoints type=%s kb=%s", len(checkpoints), checkpoint_type, kb_id)
    return checkpoints
async def load_checkpoints(tenant_id: str, kb_id: str, checkpoint_type: str, *, page_size: int | None = None) -> dict[str, Any]:
    """Load all checkpoints of one type without blocking the event loop.

    Delegates to ``_load_checkpoints_sync`` through ``thread_pool_exec``.

    Args:
        tenant_id: Tenant owning the knowledge base.
        kb_id: Knowledge base identifier.
        checkpoint_type: Checkpoint type name.
        page_size: Optional SSCAN ``count`` hint for scanning the index.

    Returns:
        A mapping from checkpoint key to its decoded JSON payload.
    """
    loaded = await thread_pool_exec(
        _load_checkpoints_sync,
        tenant_id,
        kb_id,
        checkpoint_type,
        page_size,
    )
    return loaded
def _save_checkpoint_sync(tenant_id: str, kb_id: str, checkpoint_type: str, checkpoint_key: str, payload: Any) -> bool:
    """Blocking implementation of ``save_checkpoint``; never raises."""
    index_key = _checkpoint_index_key(tenant_id, kb_id, checkpoint_type)
    data_key = _checkpoint_data_key(tenant_id, kb_id, checkpoint_type, checkpoint_key)
    try:
        redis_client = getattr(REDIS_CONN, "REDIS", None)
        if redis_client is None or not hasattr(redis_client, "pipeline"):
            logging.warning("GraphRAG checkpoint Redis client unavailable type=%s kb=%s key=%s", checkpoint_type, kb_id, checkpoint_key)
            return False
        # Write the payload and register it in the index in one transactional
        # pipeline so the index and the data key are updated together.
        pipeline = redis_client.pipeline(transaction=True)
        pipeline.set(data_key, json.dumps(payload, ensure_ascii=False), ex=CHECKPOINT_TTL_SECONDS)
        pipeline.sadd(index_key, checkpoint_key)
        # Every save refreshes the expiry of the index set.
        pipeline.expire(index_key, CHECKPOINT_TTL_SECONDS)
        pipeline.execute()
        logging.info("Saved GraphRAG checkpoint type=%s kb=%s key=%s", checkpoint_type, kb_id, checkpoint_key)
        return True
    except Exception:
        logging.exception("Failed to save GraphRAG checkpoint type=%s kb=%s key=%s", checkpoint_type, kb_id, checkpoint_key)
        return False


async def save_checkpoint(tenant_id: str, kb_id: str, checkpoint_type: str, checkpoint_key: str, payload: Any) -> bool:
    """Persist one checkpoint payload and register it in the checkpoint index.

    The Redis calls are blocking, so they are executed through
    ``thread_pool_exec`` (the same way ``load_checkpoints`` works) instead of
    running on the event loop. Saving is best effort: failures are logged and
    reported through the return value, never raised.

    Args:
        tenant_id: Tenant owning the knowledge base.
        kb_id: Knowledge base identifier.
        checkpoint_type: Checkpoint type name.
        checkpoint_key: Key identifying this checkpoint.
        payload: JSON-serializable result to store.

    Returns:
        ``True`` if the checkpoint was written, ``False`` if the Redis client
        is unavailable or the write failed.
    """
    try:
        return await thread_pool_exec(_save_checkpoint_sync, tenant_id, kb_id, checkpoint_type, checkpoint_key, payload)
    except Exception:
        # Keep the never-raise contract even if the hand-off to the pool fails.
        logging.exception("Failed to save GraphRAG checkpoint type=%s kb=%s key=%s", checkpoint_type, kb_id, checkpoint_key)
        return False
def _cleanup_checkpoints_sync(tenant_id: str, kb_id: str, checkpoint_type: str, page_size: int | None) -> bool:
    """Blocking implementation of ``cleanup_checkpoints``; never raises."""
    index_key = _checkpoint_index_key(tenant_id, kb_id, checkpoint_type)
    try:
        cleaned_count = 0
        for raw_key in _iter_checkpoint_keys(index_key, page_size):
            checkpoint_key = _decode_redis_value(raw_key)
            REDIS_CONN.delete(_checkpoint_data_key(tenant_id, kb_id, checkpoint_type, checkpoint_key))
            cleaned_count += 1
        # Drop the index only after the data keys it lists have been deleted.
        REDIS_CONN.delete(index_key)
        logging.info("Cleaned up %d GraphRAG checkpoints type=%s kb=%s", cleaned_count, checkpoint_type, kb_id)
        return True
    except Exception:
        logging.exception("Failed to cleanup GraphRAG checkpoints type=%s kb=%s", checkpoint_type, kb_id)
        return False


async def cleanup_checkpoints(tenant_id: str, kb_id: str, checkpoint_type: str, *, page_size: int | None = None) -> bool:
    """Delete every checkpoint of one type for a knowledge base, then its index.

    Scanning the index and deleting one key per checkpoint are blocking Redis
    calls, so the work is executed through ``thread_pool_exec`` (the same way
    ``load_checkpoints`` works) instead of running on the event loop. Cleanup
    is best effort: failures are logged and reported through the return value.

    Args:
        tenant_id: Tenant owning the knowledge base.
        kb_id: Knowledge base identifier.
        checkpoint_type: Checkpoint type name.
        page_size: Optional SSCAN ``count`` hint for scanning the index.

    Returns:
        ``True`` if the scan and deletions completed without raising,
        ``False`` otherwise.
    """
    try:
        return await thread_pool_exec(_cleanup_checkpoints_sync, tenant_id, kb_id, checkpoint_type, page_size)
    except Exception:
        # Keep the never-raise contract even if the hand-off to the pool fails.
        logging.exception("Failed to cleanup GraphRAG checkpoints type=%s kb=%s", checkpoint_type, kb_id)
        return False