feat: Implement checkpoint/resume support for GraphRAG community extraction and entity resolution (#15523)

## Summary

This PR adds checkpoint/resume support for the GraphRAG
`extract_community` and `resolve_entities` stages.

The implementation stores successful intermediate results in Redis
(with a 7-day TTL) so interrupted ingestion can resume without repeating
already-completed LLM work. Checkpoints are loaded before each stage,
reused when available, saved after each successfully processed
batch/community, and cleaned up after the stage completes successfully.
## Related Issue
Closes: #15518
## Change Type
- [x] Feature
- [x] Bug fix
- [x] Test
- [ ] Refactor
- [ ] Documentation
- [ ] Breaking change
## Real Behavior Proof

Validation commands run locally:

```bash
uv run python -m py_compile \
  rag/graphrag/checkpoints.py \
  rag/graphrag/general/community_reports_extractor.py \
  rag/graphrag/entity_resolution.py \
  rag/graphrag/general/index.py \
  test/unit_test/rag/graphrag/test_checkpoints.py
```
Result:

```text
Passed
```

```bash
uv run pytest test/unit_test/rag/graphrag/test_checkpoints.py
```
Result:

```text
4 passed
```

```bash
uv run pytest \
  test/unit_test/rag/graphrag/test_phase_markers.py \
  test/unit_test/rag/graphrag/test_graphrag_utils.py \
  test/unit_test/rag/graphrag/test_checkpoints.py
```
Result:

```text
95 passed
```

```bash
git diff --check
```
Result:

```text
Passed
```

## Checklist

- [x] Implemented checkpoint/resume support for `extract_community`.
- [x] Implemented checkpoint/resume support for `resolve_entities`.
- [x] Avoided touching unrelated API behavior.
- [x] Added unit tests for the new checkpoint helper logic.
- [x] Verified Python syntax compilation.
- [x] Ran related GraphRAG unit tests successfully.
- [x] Ran `git diff --check`.
- [ ] Ran full project test suite.

---------

Co-authored-by: Wang Qi <wangq8@outlook.com>
This commit is contained in:
Jonathan Chang
2026-06-09 14:34:47 +07:00
committed by GitHub
parent d02eb6b596
commit c586292993
5 changed files with 411 additions and 15 deletions

134
rag/graphrag/checkpoints.py Normal file
View File

@@ -0,0 +1,134 @@
#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from __future__ import annotations
import hashlib
import json
import logging
from typing import Any
from common.misc_utils import thread_pool_exec
from rag.utils.redis_conn import REDIS_CONN
# Checkpoint type identifiers. Each one is embedded in the Redis key names
# built by _checkpoint_index_key() and _checkpoint_data_key(), which keeps the
# community-report and entity-resolution checkpoints in separate key spaces.
COMMUNITY_CHECKPOINT = "graphrag_checkpoint_community"
RESOLUTION_CHECKPOINT = "graphrag_checkpoint_resolution"
# Default ``count`` argument passed to sscan_iter() when the caller does not
# supply a positive page size.
CHECKPOINT_PAGE_SIZE = 1000
# Expiry, in seconds (7 days), applied to every checkpoint data key and to the
# index set when a checkpoint is saved.
CHECKPOINT_TTL_SECONDS = 7 * 24 * 3600
def stable_checkpoint_key(*parts: Any) -> str:
    """Return a deterministic SHA-256 hex digest identifying ``parts``.

    The positional arguments are serialized as one compact JSON array
    (dictionary keys sorted, non-ASCII characters kept verbatim) and the
    UTF-8 encoding of that text is hashed.
    """
    serialized = json.dumps(
        parts,
        ensure_ascii=False,
        sort_keys=True,
        separators=(",", ":"),
    )
    digest = hashlib.sha256(serialized.encode("utf-8"))
    return digest.hexdigest()
def community_checkpoint_key(level: str, community_id: str, nodes: list[str]) -> str:
    """Build the checkpoint key for one community report.

    ``level`` and ``community_id`` are coerced to strings and ``nodes`` is
    sorted, so the same community yields the same key regardless of the
    order in which its member nodes are listed.
    """
    level_text = str(level)
    community_text = str(community_id)
    ordered_nodes = sorted(nodes)
    return stable_checkpoint_key("community", level_text, community_text, ordered_nodes)
def resolution_checkpoint_key(entity_type: str, pairs: list[tuple[str, str]]) -> str:
    """Build the checkpoint key for one batch of entity-resolution candidates.

    Each pair is sorted internally and the list of pairs is then sorted, so
    neither the order of the pairs nor the order of the two names inside a
    pair affects the resulting key.
    """
    canonical_pairs = [sorted((left, right)) for left, right in pairs]
    canonical_pairs.sort()
    return stable_checkpoint_key("resolution", entity_type, canonical_pairs)
def _checkpoint_index_key(tenant_id: str, kb_id: str, checkpoint_type: str) -> str:
    """Return the Redis key of the set listing every checkpoint key of one stage."""
    scope = f"graphrag:checkpoint:{tenant_id}:{kb_id}:{checkpoint_type}"
    return f"{scope}:keys"
def _checkpoint_data_key(tenant_id: str, kb_id: str, checkpoint_type: str, checkpoint_key: str) -> str:
    """Return the Redis key under which a single checkpoint payload is stored."""
    scope = f"graphrag:checkpoint:{tenant_id}:{kb_id}:{checkpoint_type}"
    return f"{scope}:{checkpoint_key}"
def _decode_redis_value(value: Any) -> Any:
    """Decode ``bytes`` as UTF-8 text; return any other value unchanged."""
    return value.decode("utf-8") if isinstance(value, bytes) else value
def _checkpoint_page_size(page_size: int | None) -> int:
return page_size if page_size and page_size > 0 else CHECKPOINT_PAGE_SIZE
def _iter_checkpoint_keys(index_key: str, page_size: int | None):
    """Return an iterator over the members of the checkpoint index set.

    Raises:
        RuntimeError: if ``REDIS_CONN`` exposes no ``REDIS`` client or that
            client has no ``sscan_iter`` method.
    """
    client = getattr(REDIS_CONN, "REDIS", None)
    scan_supported = client is not None and hasattr(client, "sscan_iter")
    if not scan_supported:
        raise RuntimeError("Redis SSCAN is unavailable for GraphRAG checkpoint index iteration")
    scan_count = _checkpoint_page_size(page_size)
    return client.sscan_iter(index_key, count=scan_count)
def _load_checkpoints_sync(tenant_id: str, kb_id: str, checkpoint_type: str, page_size: int | None) -> dict[str, Any]:
    """Load every stored checkpoint of one stage into a dictionary (blocking).

    Args:
        tenant_id: Tenant component of the Redis key names.
        kb_id: Knowledge-base component of the Redis key names.
        checkpoint_type: Stage identifier, e.g. ``COMMUNITY_CHECKPOINT``.
        page_size: ``count`` argument for the index scan; ``None`` or a
            non-positive value selects the module default.

    Returns:
        A mapping of checkpoint key to decoded JSON payload. Loading is
        best-effort: an entry that cannot be read or parsed is logged and
        skipped. If the index scan itself fails, the entries loaded so far
        (possibly none) are returned. This function never raises.
    """
    checkpoints: dict[str, Any] = {}
    index_key = _checkpoint_index_key(tenant_id, kb_id, checkpoint_type)
    try:
        # The key iterator is lazy (redis-py implements sscan_iter as a
        # generator), so Redis is only contacted while iterating. The whole
        # loop therefore has to sit inside this handler, not just the call
        # that creates the iterator.
        for raw_key in _iter_checkpoint_keys(index_key, page_size):
            checkpoint_key = raw_key
            try:
                checkpoint_key = _decode_redis_value(raw_key)
                value = _decode_redis_value(
                    REDIS_CONN.get(_checkpoint_data_key(tenant_id, kb_id, checkpoint_type, checkpoint_key))
                )
                if not value:
                    # Index entry without data: nothing to load for this key.
                    continue
                checkpoints[checkpoint_key] = json.loads(value)
            except Exception:
                logging.exception("Failed to parse GraphRAG checkpoint type=%s kb=%s key=%s", checkpoint_type, kb_id, checkpoint_key)
    except Exception:
        logging.exception("Failed to load GraphRAG checkpoint index type=%s kb=%s", checkpoint_type, kb_id)
        return checkpoints
    logging.info("Loaded %d GraphRAG checkpoints type=%s kb=%s", len(checkpoints), checkpoint_type, kb_id)
    return checkpoints
async def load_checkpoints(tenant_id: str, kb_id: str, checkpoint_type: str, *, page_size: int | None = None) -> dict[str, Any]:
    """Load the checkpoints of one stage without blocking the event loop.

    The synchronous loader is handed to ``thread_pool_exec`` and its result,
    a mapping of checkpoint key to payload, is returned.
    """
    loaded = await thread_pool_exec(
        _load_checkpoints_sync,
        tenant_id,
        kb_id,
        checkpoint_type,
        page_size,
    )
    return loaded
def _write_checkpoint_sync(index_key: str, data_key: str, checkpoint_key: str, serialized: str) -> bool:
    """Write one serialized checkpoint and register it in the index (blocking).

    Returns ``False`` when no pipeline-capable Redis client is available and
    ``True`` once the pipeline has executed. Redis errors propagate to the
    caller.
    """
    redis_client = getattr(REDIS_CONN, "REDIS", None)
    if redis_client is None or not hasattr(redis_client, "pipeline"):
        return False
    # One transactional pipeline, so the payload, its index entry and the
    # refreshed index expiry are sent to Redis together.
    pipeline = redis_client.pipeline(transaction=True)
    pipeline.set(data_key, serialized, ex=CHECKPOINT_TTL_SECONDS)
    pipeline.sadd(index_key, checkpoint_key)
    pipeline.expire(index_key, CHECKPOINT_TTL_SECONDS)
    pipeline.execute()
    return True


async def save_checkpoint(tenant_id: str, kb_id: str, checkpoint_type: str, checkpoint_key: str, payload: Any) -> bool:
    """Persist one checkpoint payload; best-effort, never raises.

    The payload is serialized to JSON on the calling task, so the worker
    thread never touches caller-owned objects. The blocking Redis round trip
    then runs through ``thread_pool_exec``, as ``load_checkpoints`` does, so
    the event loop stays free while concurrent batches save.

    Args:
        tenant_id: Tenant component of the Redis key names.
        kb_id: Knowledge-base component of the Redis key names.
        checkpoint_type: Stage identifier, e.g. ``RESOLUTION_CHECKPOINT``.
        checkpoint_key: Key identifying the unit of work being checkpointed.
        payload: JSON-serializable result to store.

    Returns:
        ``True`` if the checkpoint was written, ``False`` if Redis is
        unavailable or any error occurred (the error is logged).
    """
    index_key = _checkpoint_index_key(tenant_id, kb_id, checkpoint_type)
    data_key = _checkpoint_data_key(tenant_id, kb_id, checkpoint_type, checkpoint_key)
    try:
        serialized = json.dumps(payload, ensure_ascii=False)
        written = await thread_pool_exec(_write_checkpoint_sync, index_key, data_key, checkpoint_key, serialized)
        if not written:
            logging.warning("GraphRAG checkpoint Redis client unavailable type=%s kb=%s key=%s", checkpoint_type, kb_id, checkpoint_key)
            return False
        logging.info("Saved GraphRAG checkpoint type=%s kb=%s key=%s", checkpoint_type, kb_id, checkpoint_key)
        return True
    except Exception:
        logging.exception("Failed to save GraphRAG checkpoint type=%s kb=%s key=%s", checkpoint_type, kb_id, checkpoint_key)
        return False
def _delete_checkpoints_sync(tenant_id: str, kb_id: str, checkpoint_type: str, index_key: str, page_size: int | None) -> int:
    """Delete every indexed checkpoint payload, then the index (blocking).

    Returns the number of index entries visited. Errors propagate to the
    caller.
    """
    cleaned_count = 0
    for raw_key in _iter_checkpoint_keys(index_key, page_size):
        checkpoint_key = _decode_redis_value(raw_key)
        REDIS_CONN.delete(_checkpoint_data_key(tenant_id, kb_id, checkpoint_type, checkpoint_key))
        cleaned_count += 1
    # The index is removed last, so the payload keys stay discoverable until
    # every one of them has been visited.
    REDIS_CONN.delete(index_key)
    return cleaned_count


async def cleanup_checkpoints(tenant_id: str, kb_id: str, checkpoint_type: str, *, page_size: int | None = None) -> bool:
    """Remove all checkpoints of one stage; best-effort, never raises.

    The scan-and-delete loop issues one Redis call per checkpoint, so it runs
    through ``thread_pool_exec``, as ``load_checkpoints`` does, rather than
    blocking the event loop.

    Args:
        tenant_id: Tenant component of the Redis key names.
        kb_id: Knowledge-base component of the Redis key names.
        checkpoint_type: Stage identifier whose checkpoints are removed.
        page_size: ``count`` argument for the index scan; ``None`` or a
            non-positive value selects the module default.

    Returns:
        ``True`` when the cleanup ran to completion, ``False`` if any error
        occurred (the error is logged).
    """
    index_key = _checkpoint_index_key(tenant_id, kb_id, checkpoint_type)
    try:
        cleaned_count = await thread_pool_exec(
            _delete_checkpoints_sync, tenant_id, kb_id, checkpoint_type, index_key, page_size
        )
        logging.info("Cleaned up %d GraphRAG checkpoints type=%s kb=%s", cleaned_count, checkpoint_type, kb_id)
        return True
    except Exception:
        logging.exception("Failed to cleanup GraphRAG checkpoints type=%s kb=%s", checkpoint_type, kb_id)
        return False

View File

@@ -19,7 +19,7 @@ import itertools
import os
import re
from dataclasses import dataclass
from typing import Any, Callable
from typing import Any, Awaitable, Callable
import networkx as nx
@@ -27,6 +27,7 @@ from rag.graphrag.general.extractor import Extractor
from rag.nlp import is_english
import editdistance
from rag.graphrag.entity_resolution_prompt import ENTITY_RESOLUTION_PROMPT
from rag.graphrag.checkpoints import resolution_checkpoint_key
from rag.llm.chat_model import Base as CompletionLLM
from rag.graphrag.utils import perform_variable_replacements, chat_limiter, GraphChange
from api.db.services.task_service import has_canceled
@@ -71,7 +72,9 @@ class EntityResolution(Extractor):
subgraph_nodes: set[str],
prompt_variables: dict[str, Any] | None = None,
callback: Callable | None = None,
task_id: str = "") -> EntityResolutionResult:
task_id: str = "",
checkpoints: dict[str, Any] | None = None,
save_checkpoint: Callable[[str, Any], Awaitable[bool]] | None = None) -> EntityResolutionResult:
"""Call method definition."""
if prompt_variables is None:
prompt_variables = {}
@@ -106,19 +109,35 @@ class EntityResolution(Extractor):
resolution_batch_size = 100
max_concurrent_tasks = 5
semaphore = asyncio.Semaphore(max_concurrent_tasks)
checkpoints = checkpoints or {}
async def limited_resolve_candidate(candidate_batch, result_set, result_lock):
nonlocal remain_candidates_to_resolve, callback
async with semaphore:
try:
checkpoint_key = resolution_checkpoint_key(candidate_batch[0], candidate_batch[1])
checkpoint = checkpoints.get(checkpoint_key)
if isinstance(checkpoint, list):
async with result_lock:
for pair in checkpoint:
if isinstance(pair, (list, tuple)) and len(pair) == 2:
result_set.add((pair[0], pair[1]))
remain_candidates_to_resolve -= len(candidate_batch[1])
callback(
msg=f"Replayed {len(candidate_batch[1])} resolved pairs from checkpoint, "
f"{remain_candidates_to_resolve} remain."
)
return
enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
timeout_sec = 280 if enable_timeout_assertion else 1_000_000_000
try:
await asyncio.wait_for(
selected_pairs = await asyncio.wait_for(
self._resolve_candidate(candidate_batch, result_set, result_lock, task_id),
timeout=timeout_sec
)
if selected_pairs is not None and save_checkpoint:
await save_checkpoint(checkpoint_key, [list(pair) for pair in selected_pairs])
remain_candidates_to_resolve -= len(candidate_batch[1])
callback(
msg=f"Resolved {len(candidate_batch[1])} pairs, "
@@ -219,10 +238,10 @@ class EntityResolution(Extractor):
except asyncio.TimeoutError:
logging.warning("_resolve_candidate._async_chat timeout, skipping...")
return
return None
except Exception as e:
logging.error(f"_resolve_candidate._async_chat failed: {e}")
return
return None
logging.debug(f"_resolve_candidate chat prompt: {text}\nchat response: {response}")
result = self._process_results(len(candidate_resolution_i[1]), response,
@@ -232,9 +251,11 @@ class EntityResolution(Extractor):
DEFAULT_ENTITY_INDEX_DELIMITER),
self.prompt_variables.get(self._resolution_result_delimiter_key,
DEFAULT_RESOLUTION_RESULT_DELIMITER))
selected_pairs = [candidate_resolution_i[1][result_i[0] - 1] for result_i in result]
async with resolution_result_lock:
for result_i in result:
resolution_result.add(candidate_resolution_i[1][result_i[0] - 1])
for pair in selected_pairs:
resolution_result.add(pair)
return selected_pairs
def _process_results(
self,
@@ -288,4 +309,3 @@ class EntityResolution(Extractor):
return len(a & b) > 1
return len(a & b)*1./max_l >= 0.8

View File

@@ -12,7 +12,7 @@ import logging
import json
import os
import re
from typing import Callable
from typing import Any, Awaitable, Callable
from dataclasses import dataclass
import networkx as nx
import pandas as pd
@@ -23,6 +23,7 @@ from rag.graphrag.general import leiden
from rag.graphrag.general.community_report_prompt import COMMUNITY_REPORT_PROMPT
from rag.graphrag.general.extractor import Extractor
from rag.graphrag.general.leiden import add_community_info2graph
from rag.graphrag.checkpoints import community_checkpoint_key
from rag.llm.chat_model import Base as CompletionLLM
from rag.graphrag.utils import perform_variable_replacements, dict_has_keys_with_types, chat_limiter
from common.token_utils import num_tokens_from_string
@@ -53,7 +54,14 @@ class CommunityReportsExtractor(Extractor):
self._extraction_prompt = COMMUNITY_REPORT_PROMPT
self._max_report_length = max_report_length or 1500
async def __call__(self, graph: nx.Graph, callback: Callable | None = None, task_id: str = ""):
async def __call__(
self,
graph: nx.Graph,
callback: Callable | None = None,
task_id: str = "",
checkpoints: dict[str, Any] | None = None,
save_checkpoint: Callable[[str, Any], Awaitable[bool]] | None = None,
):
enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
for node_degree in graph.degree:
graph.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])
@@ -63,7 +71,9 @@ class CommunityReportsExtractor(Extractor):
res_str = []
res_dict = []
over, token_count = 0, 0
async def extract_community_report(community):
checkpoints = checkpoints or {}
async def extract_community_report(level, community):
nonlocal res_str, res_dict, over, token_count
if task_id:
if has_canceled(task_id):
@@ -75,6 +85,19 @@ class CommunityReportsExtractor(Extractor):
ents = cm["nodes"]
if len(ents) < 2:
return
checkpoint_key = community_checkpoint_key(str(level), str(cm_id), list(ents))
checkpoint = checkpoints.get(checkpoint_key)
if isinstance(checkpoint, dict):
response = checkpoint.get("structured_output")
output = checkpoint.get("output")
if isinstance(response, dict) and isinstance(output, str):
add_community_info2graph(graph, response.get("entities", ents), response.get("title", ""))
res_str.append(output)
res_dict.append(response)
over += 1
if callback:
callback(msg=f"Communities: {over}/{total}, used tokens: {token_count}")
return
ent_list = [{"entity": ent, "description": graph.nodes[ent]["description"]} for ent in ents]
ent_df = pd.DataFrame(ent_list)
@@ -131,7 +154,10 @@ class CommunityReportsExtractor(Extractor):
response["weight"] = weight
response["entities"] = ents
add_community_info2graph(graph, ents, response["title"])
res_str.append(self._get_text_output(response))
output = self._get_text_output(response)
if save_checkpoint:
await save_checkpoint(checkpoint_key, {"structured_output": response, "output": output})
res_str.append(output)
res_dict.append(response)
over += 1
if callback:
@@ -145,7 +171,7 @@ class CommunityReportsExtractor(Extractor):
if task_id and has_canceled(task_id):
logging.info(f"Task {task_id} cancelled before community processing.")
raise TaskCanceledException(f"Task {task_id} was cancelled")
tasks.append(asyncio.create_task(extract_community_report(community)))
tasks.append(asyncio.create_task(extract_community_report(level, community)))
try:
await asyncio.gather(*tasks, return_exceptions=False)
except Exception as e:

View File

@@ -24,6 +24,13 @@ from api.db.services.task_service import has_canceled
from common.exceptions import TaskCanceledException
from common.connection_utils import timeout
from rag.graphrag.entity_resolution import EntityResolution
from rag.graphrag.checkpoints import (
COMMUNITY_CHECKPOINT,
RESOLUTION_CHECKPOINT,
cleanup_checkpoints,
load_checkpoints,
save_checkpoint,
)
from rag.graphrag.general.community_reports_extractor import CommunityReportsExtractor
from rag.graphrag.general.extractor import Extractor
from rag.graphrag.general.graph_extractor import GraphExtractor as GeneralKGExt
@@ -763,10 +770,22 @@ async def resolve_entities(
_has_cancel_and_exit(task_id, f"Task {task_id} cancelled during entity resolution.", callback)
start = asyncio.get_running_loop().time()
checkpoints = await load_checkpoints(tenant_id, kb_id, RESOLUTION_CHECKPOINT)
async def save_resolution_checkpoint(checkpoint_key: str, payload):
return await save_checkpoint(tenant_id, kb_id, RESOLUTION_CHECKPOINT, checkpoint_key, payload)
er = EntityResolution(
llm_bdl,
)
reso = await er(graph, subgraph_nodes, callback=callback, task_id=task_id)
reso = await er(
graph,
subgraph_nodes,
callback=callback,
task_id=task_id,
checkpoints=checkpoints,
save_checkpoint=save_resolution_checkpoint,
)
graph = reso.graph
change = reso.change
callback(msg=f"Graph resolution removed {len(change.removed_nodes)} nodes and {len(change.removed_edges)} edges.")
@@ -776,6 +795,7 @@ async def resolve_entities(
_has_cancel_and_exit(task_id, f"Task {task_id} cancelled before saving resolved graph.", callback)
await set_graph(tenant_id, kb_id, embed_bdl, graph, change, callback)
await cleanup_checkpoints(tenant_id, kb_id, RESOLUTION_CHECKPOINT)
now = asyncio.get_running_loop().time()
callback(msg=f"Graph resolution done in {now - start:.2f}s.")
@@ -794,10 +814,21 @@ async def extract_community(
_has_cancel_and_exit(task_id, f"Task {task_id} cancelled before community extraction.", callback)
start = asyncio.get_running_loop().time()
checkpoints = await load_checkpoints(tenant_id, kb_id, COMMUNITY_CHECKPOINT)
async def save_community_checkpoint(checkpoint_key: str, payload):
return await save_checkpoint(tenant_id, kb_id, COMMUNITY_CHECKPOINT, checkpoint_key, payload)
ext = CommunityReportsExtractor(
llm_bdl,
)
cr = await ext(graph, callback=callback, task_id=task_id)
cr = await ext(
graph,
callback=callback,
task_id=task_id,
checkpoints=checkpoints,
save_checkpoint=save_community_checkpoint,
)
_has_cancel_and_exit(task_id, f"Task {task_id} cancelled during community extraction.", callback)
@@ -881,6 +912,7 @@ async def extract_community(
logging.exception("Failed to prune %d stale community reports for kb %s", len(stale_ids), kb_id)
_has_cancel_and_exit(task_id, f"Task {task_id} cancelled after community indexing.", callback)
await cleanup_checkpoints(tenant_id, kb_id, COMMUNITY_CHECKPOINT)
now = asyncio.get_running_loop().time()
callback(msg=f"Graph indexed {len(cr.structured_output)} communities in {now - start:.2f}s.")

View File

@@ -0,0 +1,184 @@
#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import pytest
from rag.graphrag import checkpoints
class _FakeRedisClient:
    """In-memory stand-in for the raw client exposed as ``REDIS_CONN.REDIS``."""

    def __init__(self, conn):
        self.conn = conn
        # key -> TTL most recently passed to expire()
        self.expirations = {}
        # one (key, count) tuple per consumed sscan_iter() call
        self.scan_counts = []

    def expire(self, key, ttl):
        """Record the TTL requested for ``key``."""
        self.expirations.update({key: ttl})
        return True

    def pipeline(self, transaction=True):
        """Return a fake pipeline; the code under test must ask for a transaction."""
        assert transaction is True
        return _FakeRedisPipeline(self.conn)

    def sscan_iter(self, key, count=None):
        """Lazily yield the members of the fake set stored under ``key``."""
        self.scan_counts.append((key, count))
        members = self.conn.sets.get(key, set())
        for member in members:
            yield member
class _FakeRedisPipeline:
    """Fake pipeline that queues commands and applies them on ``execute()``."""

    def __init__(self, conn):
        self.conn = conn
        # queued command tuples, applied in order by execute()
        self.commands = []

    def _queue(self, *command):
        """Append one command tuple and return the pipeline for chaining."""
        self.commands.append(command)
        return self

    def set(self, key, value, ex=None):
        return self._queue("set", key, value, ex)

    def sadd(self, key, member):
        return self._queue("sadd", key, member)

    def expire(self, key, ttl):
        return self._queue("expire", key, ttl)

    def execute(self):
        """Apply the queued commands to the fake connection.

        Raises ``RuntimeError`` before applying anything when the connection
        is flagged with ``fail_pipeline``.
        """
        if self.conn.fail_pipeline:
            raise RuntimeError("redis transaction failed")
        for command in self.commands:
            op, args = command[0], command[1:]
            if op == "set" and len(args) == 3:
                key, value, ttl = args
                self.conn.values[key] = value
                if ttl is not None:
                    self.conn.REDIS.expire(key, ttl)
            elif op == "sadd" and len(args) == 2:
                key, member = args
                self.conn.sets.setdefault(key, set()).add(member)
            elif op == "expire" and len(args) == 2:
                key, ttl = args
                self.conn.REDIS.expire(key, ttl)
        return [True for _ in self.commands]
class _FakeRedisConn:
    """In-memory replacement for the ``REDIS_CONN`` wrapper used by checkpoints."""

    def __init__(self):
        self.values = {}
        self.sets = {}
        self.REDIS = _FakeRedisClient(self)
        # failure switches that individual tests can flip
        self.fail_set = False
        self.fail_pipeline = False

    def get(self, key):
        return self.values.get(key)

    def set(self, key, value, exp=3600):
        if not self.fail_set:
            self.values[key] = value
            self.REDIS.expire(key, exp)
            return True
        return False

    def sadd(self, key, member):
        members = self.sets.setdefault(key, set())
        members.add(member)
        return True

    def smembers(self, key):
        # Guard: the checkpoint module is expected to scan, never fetch the whole set.
        raise AssertionError("checkpoint code must use sscan_iter instead of smembers")

    def delete(self, key):
        """Drop ``key`` from both the string store and the set store."""
        for store in (self.values, self.sets):
            store.pop(key, None)
        return True
@pytest.fixture
def fake_redis(monkeypatch):
    """Swap ``checkpoints.REDIS_CONN`` for an in-memory fake and return the fake."""
    stub = _FakeRedisConn()
    monkeypatch.setattr(checkpoints, "REDIS_CONN", stub)
    return stub
@pytest.mark.p1
def test_checkpoint_keys_are_stable():
    """Keys must not depend on node order, pair order, or order inside a pair."""
    community_key = checkpoints.community_checkpoint_key
    assert community_key("1", "2", ["B", "A"]) == community_key("1", "2", ["A", "B"])

    resolution_key = checkpoints.resolution_checkpoint_key
    pairs = [("alpha", "alfa"), ("beta", "bata")]
    baseline = resolution_key("entity", pairs)

    # Same pairs, listed in the opposite order.
    assert baseline == resolution_key("entity", pairs[::-1])

    # Same pairs, with the two names swapped inside each pair.
    swapped_inside = [("alfa", "alpha"), ("bata", "beta")]
    assert baseline == resolution_key("entity", swapped_inside)
@pytest.mark.p1
@pytest.mark.asyncio
async def test_load_checkpoints_reads_redis_index(fake_redis, monkeypatch):
    """Loading returns only the requested knowledge base's entries, via the thread pool."""
    stage = checkpoints.COMMUNITY_CHECKPOINT
    seeded = (("kb-1", "k1", 1), ("kb-1", "k2", 2), ("kb-2", "k3", 3))
    for kb, key, number in seeded:
        await checkpoints.save_checkpoint("tenant-1", kb, stage, key, {"value": number})

    recorded = []

    async def _recording_thread_pool_exec(func, *args, **kwargs):
        # Record the hand-off, then run the function inline.
        recorded.append((func, args, kwargs))
        return func(*args, **kwargs)

    monkeypatch.setattr(checkpoints, "thread_pool_exec", _recording_thread_pool_exec)

    result = await checkpoints.load_checkpoints("tenant-1", "kb-1", stage, page_size=1)

    assert result == {"k1": {"value": 1}, "k2": {"value": 2}}
    expected_call = (
        checkpoints._load_checkpoints_sync,
        ("tenant-1", "kb-1", stage, 1),
        {},
    )
    assert recorded == [expected_call]
    expected_scan = (
        "graphrag:checkpoint:tenant-1:kb-1:graphrag_checkpoint_community:keys",
        1,
    )
    assert fake_redis.REDIS.scan_counts[-1] == expected_scan
@pytest.mark.p2
@pytest.mark.asyncio
async def test_save_checkpoint_degrades_on_redis_failure(fake_redis):
    """A failing pipeline makes save_checkpoint return False and store nothing."""
    fake_redis.fail_pipeline = True

    outcome = await checkpoints.save_checkpoint(
        "tenant-1",
        "kb-1",
        checkpoints.RESOLUTION_CHECKPOINT,
        "key-1",
        {"ok": True},
    )

    assert outcome is False
    assert fake_redis.values == {}
    assert fake_redis.sets == {}
@pytest.mark.p2
@pytest.mark.asyncio
async def test_cleanup_checkpoints_deletes_redis_stage_keys(fake_redis):
    """Cleanup removes one stage's checkpoints and leaves the other stage intact."""
    resolution = checkpoints.RESOLUTION_CHECKPOINT
    community = checkpoints.COMMUNITY_CHECKPOINT
    seeded = ((resolution, "k1", 1), (resolution, "k2", 2), (community, "k3", 3))
    for stage, key, number in seeded:
        await checkpoints.save_checkpoint("tenant-1", "kb-1", stage, key, {"value": number})

    outcome = await checkpoints.cleanup_checkpoints("tenant-1", "kb-1", resolution, page_size=1)

    assert outcome is True
    remaining_resolution = await checkpoints.load_checkpoints("tenant-1", "kb-1", resolution)
    assert remaining_resolution == {}
    remaining_community = await checkpoints.load_checkpoints("tenant-1", "kb-1", community)
    assert remaining_community == {"k3": {"value": 3}}
    expected_scan = (
        "graphrag:checkpoint:tenant-1:kb-1:graphrag_checkpoint_resolution:keys",
        1,
    )
    assert expected_scan in fake_redis.REDIS.scan_counts