mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
feat: Implement checkpoint/resume support for GraphRAG community extraction and entity resolution (#15523)
## Summary

This PR adds checkpoint/resume support for the GraphRAG `extract_community` and `resolve_entities` stages.

The implementation stores successful intermediate results in Redis so interrupted ingestion can resume without repeating already-completed LLM work. Checkpoints are loaded before each stage, reused when available, saved after successful batch/community processing, and cleaned up after the stage completes successfully.

## Related Issue

Closes: #15518

## Change Type

- [x] Feature
- [x] Bug fix
- [x] Test
- [ ] Refactor
- [ ] Documentation
- [ ] Breaking change

## Real Behavior Proof

Validation commands run locally:

```bash
uv run python -m py_compile \
  rag/graphrag/checkpoints.py \
  rag/graphrag/general/community_reports_extractor.py \
  rag/graphrag/entity_resolution.py \
  rag/graphrag/general/index.py \
  test/unit_test/rag/graphrag/test_checkpoints.py
```

Result:

```text
Passed
```

```bash
uv run pytest test/unit_test/rag/graphrag/test_checkpoints.py
```

Result:

```text
4 passed
```

```bash
uv run pytest \
  test/unit_test/rag/graphrag/test_phase_markers.py \
  test/unit_test/rag/graphrag/test_graphrag_utils.py \
  test/unit_test/rag/graphrag/test_checkpoints.py
```

Result:

```text
95 passed
```

```bash
git diff --check
```

Result:

```text
Passed
```

## Checklist

- [x] Implemented checkpoint/resume support for `extract_community`.
- [x] Implemented checkpoint/resume support for `resolve_entities`.
- [x] Avoided touching unrelated API behavior.
- [x] Added unit tests for the new checkpoint helper logic.
- [x] Verified Python syntax compilation.
- [x] Ran related GraphRAG unit tests successfully.
- [x] Ran `git diff --check`.
- [ ] Ran full project test suite.

---------

Co-authored-by: Wang Qi <wangq8@outlook.com>
This commit is contained in:
134
rag/graphrag/checkpoints.py
Normal file
134
rag/graphrag/checkpoints.py
Normal file
@@ -0,0 +1,134 @@
|
||||
#
|
||||
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from common.misc_utils import thread_pool_exec
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
|
||||
|
||||
COMMUNITY_CHECKPOINT = "graphrag_checkpoint_community"
|
||||
RESOLUTION_CHECKPOINT = "graphrag_checkpoint_resolution"
|
||||
CHECKPOINT_PAGE_SIZE = 1000
|
||||
CHECKPOINT_TTL_SECONDS = 7 * 24 * 3600
|
||||
|
||||
|
||||
def stable_checkpoint_key(*parts: Any) -> str:
    """Return a deterministic SHA-256 hex digest for ``parts``.

    The positional arguments are serialised as one compact JSON array
    (mapping keys sorted, non-ASCII characters kept as-is), encoded as
    UTF-8 and hashed. Equal inputs therefore always yield the same key.

    Raises:
        TypeError: if any part is not JSON-serialisable.
    """
    canonical = json.dumps(
        parts,
        ensure_ascii=False,
        sort_keys=True,
        separators=(",", ":"),
    )
    digest = hashlib.sha256()
    digest.update(canonical.encode("utf-8"))
    return digest.hexdigest()
|
||||
|
||||
|
||||
def community_checkpoint_key(level: str, community_id: str, nodes: list[str]) -> str:
    """Build the checkpoint key for one community report.

    The key covers the stringified level, the stringified community id and
    the community's node names. Node names are sorted first, so the order in
    which the caller lists them does not affect the key.
    """
    ordered_nodes = list(nodes)
    ordered_nodes.sort()
    key_parts = ("community", str(level), str(community_id), ordered_nodes)
    return stable_checkpoint_key(*key_parts)
|
||||
|
||||
|
||||
def resolution_checkpoint_key(entity_type: str, pairs: list[tuple[str, str]]) -> str:
    """Build the checkpoint key for one entity-resolution candidate batch.

    Each pair is sorted internally and the list of pairs is sorted as a
    whole, so neither the orientation of a pair nor the order of the pairs
    changes the resulting key.
    """
    canonical_pairs = []
    for left, right in pairs:
        canonical_pairs.append(sorted((left, right)))
    canonical_pairs.sort()
    return stable_checkpoint_key("resolution", entity_type, canonical_pairs)
|
||||
|
||||
|
||||
def _checkpoint_index_key(tenant_id: str, kb_id: str, checkpoint_type: str) -> str:
    """Return the Redis key of the set that indexes one stage's checkpoint keys."""
    index_key = f"graphrag:checkpoint:{tenant_id}:{kb_id}:{checkpoint_type}:keys"
    return index_key
|
||||
|
||||
|
||||
def _checkpoint_data_key(tenant_id: str, kb_id: str, checkpoint_type: str, checkpoint_key: str) -> str:
    """Return the Redis key under which a single checkpoint payload is stored."""
    data_key = f"graphrag:checkpoint:{tenant_id}:{kb_id}:{checkpoint_type}:{checkpoint_key}"
    return data_key
|
||||
|
||||
|
||||
def _decode_redis_value(value: Any) -> Any:
    """Decode ``bytes`` as UTF-8 text; return any other value unchanged."""
    return value.decode("utf-8") if isinstance(value, bytes) else value
|
||||
|
||||
|
||||
def _checkpoint_page_size(page_size: int | None) -> int:
|
||||
return page_size if page_size and page_size > 0 else CHECKPOINT_PAGE_SIZE
|
||||
|
||||
|
||||
def _iter_checkpoint_keys(index_key: str, page_size: int | None):
    """Return an SSCAN-based iterator over the members of the index set.

    Args:
        index_key: Redis key of the checkpoint index set.
        page_size: SSCAN ``count`` hint; normalised by ``_checkpoint_page_size``.

    Raises:
        RuntimeError: if ``REDIS_CONN`` exposes no ``REDIS`` client, or that
            client has no ``sscan_iter`` method.
    """
    client = getattr(REDIS_CONN, "REDIS", None)
    supports_sscan = client is not None and hasattr(client, "sscan_iter")
    if not supports_sscan:
        raise RuntimeError("Redis SSCAN is unavailable for GraphRAG checkpoint index iteration")
    scan_count = _checkpoint_page_size(page_size)
    return client.sscan_iter(index_key, count=scan_count)
|
||||
|
||||
|
||||
def _load_checkpoints_sync(tenant_id: str, kb_id: str, checkpoint_type: str, page_size: int | None) -> dict[str, Any]:
    """Load every stored checkpoint of one stage from Redis (blocking).

    Walks the stage's index set, fetches each referenced payload and parses
    it as JSON. Loading is best effort: a payload that is missing, empty or
    unparsable is skipped, and a failure while walking the index stops the
    walk and returns whatever was loaded up to that point.

    Args:
        tenant_id: Tenant segment of the Redis keys.
        kb_id: Knowledge-base segment of the Redis keys.
        checkpoint_type: Stage identifier, e.g. ``COMMUNITY_CHECKPOINT``.
        page_size: SSCAN ``count`` hint; ``None`` or non-positive uses the default.

    Returns:
        Mapping of checkpoint key to the decoded JSON payload.
    """
    checkpoints: dict[str, Any] = {}
    index_key = _checkpoint_index_key(tenant_id, kb_id, checkpoint_type)
    try:
        # The key iterator may only contact Redis once iteration starts, so
        # the whole loop (not just iterator creation) must sit inside this
        # handler; otherwise a Redis error mid-scan would escape and abort
        # the stage instead of degrading to "fewer checkpoints".
        for raw_key in _iter_checkpoint_keys(index_key, page_size):
            checkpoint_key = _decode_redis_value(raw_key)
            try:
                value = REDIS_CONN.get(_checkpoint_data_key(tenant_id, kb_id, checkpoint_type, checkpoint_key))
                value = _decode_redis_value(value)
                if not value:
                    # Index entry without a payload (e.g. the data key is gone).
                    continue
                checkpoints[checkpoint_key] = json.loads(value)
            except Exception:
                logging.exception("Failed to parse GraphRAG checkpoint type=%s kb=%s key=%s", checkpoint_type, kb_id, checkpoint_key)
    except Exception:
        logging.exception("Failed to load GraphRAG checkpoint index type=%s kb=%s", checkpoint_type, kb_id)
        return checkpoints

    logging.info("Loaded %d GraphRAG checkpoints type=%s kb=%s", len(checkpoints), checkpoint_type, kb_id)
    return checkpoints
|
||||
|
||||
|
||||
async def load_checkpoints(tenant_id: str, kb_id: str, checkpoint_type: str, *, page_size: int | None = None) -> dict[str, Any]:
    """Load all checkpoints of one stage without blocking the event loop.

    Delegates to ``_load_checkpoints_sync`` through ``thread_pool_exec`` and
    returns its mapping of checkpoint key to payload.
    """
    loader_args = (tenant_id, kb_id, checkpoint_type, page_size)
    loaded = await thread_pool_exec(_load_checkpoints_sync, *loader_args)
    return loaded
|
||||
|
||||
|
||||
def _save_checkpoint_sync(index_key: str, data_key: str, checkpoint_key: str, serialized: str) -> bool:
    """Write one serialized checkpoint and index it in a single Redis transaction (blocking).

    Returns ``False`` without touching Redis when no pipeline-capable client
    is available, ``True`` once the transaction has executed. Redis errors
    propagate to the caller.
    """
    redis_client = getattr(REDIS_CONN, "REDIS", None)
    if redis_client is None or not hasattr(redis_client, "pipeline"):
        return False
    pipeline = redis_client.pipeline(transaction=True)
    pipeline.set(data_key, serialized, ex=CHECKPOINT_TTL_SECONDS)
    pipeline.sadd(index_key, checkpoint_key)
    # Refresh the index TTL on every save so it does not expire before its entries.
    pipeline.expire(index_key, CHECKPOINT_TTL_SECONDS)
    pipeline.execute()
    return True


async def save_checkpoint(tenant_id: str, kb_id: str, checkpoint_type: str, checkpoint_key: str, payload: Any) -> bool:
    """Persist one checkpoint payload; best effort, never raises on storage errors.

    The payload is serialized to JSON on the calling coroutine (so the stored
    value is a snapshot taken at call time) and the Redis round trip runs via
    ``thread_pool_exec``, matching ``load_checkpoints``, so concurrent
    coroutines are not stalled by blocking network I/O.

    Args:
        tenant_id: Tenant segment of the Redis keys.
        kb_id: Knowledge-base segment of the Redis keys.
        checkpoint_type: Stage identifier, e.g. ``RESOLUTION_CHECKPOINT``.
        checkpoint_key: Key identifying this unit of work within the stage.
        payload: JSON-serialisable checkpoint content.

    Returns:
        ``True`` if the checkpoint was written, ``False`` if the Redis client
        is unavailable, the payload cannot be serialized, or Redis failed.
    """
    index_key = _checkpoint_index_key(tenant_id, kb_id, checkpoint_type)
    data_key = _checkpoint_data_key(tenant_id, kb_id, checkpoint_type, checkpoint_key)
    try:
        serialized = json.dumps(payload, ensure_ascii=False)
        saved = await thread_pool_exec(_save_checkpoint_sync, index_key, data_key, checkpoint_key, serialized)
    except Exception:
        logging.exception("Failed to save GraphRAG checkpoint type=%s kb=%s key=%s", checkpoint_type, kb_id, checkpoint_key)
        return False

    if not saved:
        logging.warning("GraphRAG checkpoint Redis client unavailable type=%s kb=%s key=%s", checkpoint_type, kb_id, checkpoint_key)
        return False

    logging.info("Saved GraphRAG checkpoint type=%s kb=%s key=%s", checkpoint_type, kb_id, checkpoint_key)
    return True
|
||||
|
||||
|
||||
def _cleanup_checkpoints_sync(tenant_id: str, kb_id: str, checkpoint_type: str, page_size: int | None) -> int:
    """Delete every indexed checkpoint payload of one stage, then the index itself (blocking).

    Returns the number of index entries processed. Errors propagate to the caller.
    """
    index_key = _checkpoint_index_key(tenant_id, kb_id, checkpoint_type)
    cleaned_count = 0
    for raw_key in _iter_checkpoint_keys(index_key, page_size):
        checkpoint_key = _decode_redis_value(raw_key)
        REDIS_CONN.delete(_checkpoint_data_key(tenant_id, kb_id, checkpoint_type, checkpoint_key))
        cleaned_count += 1
    # The index is removed last so an interrupted cleanup can be retried
    # against the entries that are still listed.
    REDIS_CONN.delete(index_key)
    return cleaned_count


async def cleanup_checkpoints(tenant_id: str, kb_id: str, checkpoint_type: str, *, page_size: int | None = None) -> bool:
    """Remove all checkpoints of one stage; best effort, never raises on storage errors.

    The scan-and-delete loop runs via ``thread_pool_exec``, matching
    ``load_checkpoints``, so the per-key Redis round trips do not block the
    event loop.

    Args:
        tenant_id: Tenant segment of the Redis keys.
        kb_id: Knowledge-base segment of the Redis keys.
        checkpoint_type: Stage identifier whose checkpoints are removed.
        page_size: SSCAN ``count`` hint; ``None`` or non-positive uses the default.

    Returns:
        ``True`` when cleanup completed, ``False`` if any step failed.
    """
    try:
        cleaned_count = await thread_pool_exec(_cleanup_checkpoints_sync, tenant_id, kb_id, checkpoint_type, page_size)
    except Exception:
        logging.exception("Failed to cleanup GraphRAG checkpoints type=%s kb=%s", checkpoint_type, kb_id)
        return False

    logging.info("Cleaned up %d GraphRAG checkpoints type=%s kb=%s", cleaned_count, checkpoint_type, kb_id)
    return True
|
||||
@@ -19,7 +19,7 @@ import itertools
|
||||
import os
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable
|
||||
from typing import Any, Awaitable, Callable
|
||||
|
||||
import networkx as nx
|
||||
|
||||
@@ -27,6 +27,7 @@ from rag.graphrag.general.extractor import Extractor
|
||||
from rag.nlp import is_english
|
||||
import editdistance
|
||||
from rag.graphrag.entity_resolution_prompt import ENTITY_RESOLUTION_PROMPT
|
||||
from rag.graphrag.checkpoints import resolution_checkpoint_key
|
||||
from rag.llm.chat_model import Base as CompletionLLM
|
||||
from rag.graphrag.utils import perform_variable_replacements, chat_limiter, GraphChange
|
||||
from api.db.services.task_service import has_canceled
|
||||
@@ -71,7 +72,9 @@ class EntityResolution(Extractor):
|
||||
subgraph_nodes: set[str],
|
||||
prompt_variables: dict[str, Any] | None = None,
|
||||
callback: Callable | None = None,
|
||||
task_id: str = "") -> EntityResolutionResult:
|
||||
task_id: str = "",
|
||||
checkpoints: dict[str, Any] | None = None,
|
||||
save_checkpoint: Callable[[str, Any], Awaitable[bool]] | None = None) -> EntityResolutionResult:
|
||||
"""Call method definition."""
|
||||
if prompt_variables is None:
|
||||
prompt_variables = {}
|
||||
@@ -106,19 +109,35 @@ class EntityResolution(Extractor):
|
||||
resolution_batch_size = 100
|
||||
max_concurrent_tasks = 5
|
||||
semaphore = asyncio.Semaphore(max_concurrent_tasks)
|
||||
checkpoints = checkpoints or {}
|
||||
|
||||
async def limited_resolve_candidate(candidate_batch, result_set, result_lock):
|
||||
nonlocal remain_candidates_to_resolve, callback
|
||||
async with semaphore:
|
||||
try:
|
||||
checkpoint_key = resolution_checkpoint_key(candidate_batch[0], candidate_batch[1])
|
||||
checkpoint = checkpoints.get(checkpoint_key)
|
||||
if isinstance(checkpoint, list):
|
||||
async with result_lock:
|
||||
for pair in checkpoint:
|
||||
if isinstance(pair, (list, tuple)) and len(pair) == 2:
|
||||
result_set.add((pair[0], pair[1]))
|
||||
remain_candidates_to_resolve -= len(candidate_batch[1])
|
||||
callback(
|
||||
msg=f"Replayed {len(candidate_batch[1])} resolved pairs from checkpoint, "
|
||||
f"{remain_candidates_to_resolve} remain."
|
||||
)
|
||||
return
|
||||
enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
|
||||
timeout_sec = 280 if enable_timeout_assertion else 1_000_000_000
|
||||
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
selected_pairs = await asyncio.wait_for(
|
||||
self._resolve_candidate(candidate_batch, result_set, result_lock, task_id),
|
||||
timeout=timeout_sec
|
||||
)
|
||||
if selected_pairs is not None and save_checkpoint:
|
||||
await save_checkpoint(checkpoint_key, [list(pair) for pair in selected_pairs])
|
||||
remain_candidates_to_resolve -= len(candidate_batch[1])
|
||||
callback(
|
||||
msg=f"Resolved {len(candidate_batch[1])} pairs, "
|
||||
@@ -219,10 +238,10 @@ class EntityResolution(Extractor):
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logging.warning("_resolve_candidate._async_chat timeout, skipping...")
|
||||
return
|
||||
return None
|
||||
except Exception as e:
|
||||
logging.error(f"_resolve_candidate._async_chat failed: {e}")
|
||||
return
|
||||
return None
|
||||
|
||||
logging.debug(f"_resolve_candidate chat prompt: {text}\nchat response: {response}")
|
||||
result = self._process_results(len(candidate_resolution_i[1]), response,
|
||||
@@ -232,9 +251,11 @@ class EntityResolution(Extractor):
|
||||
DEFAULT_ENTITY_INDEX_DELIMITER),
|
||||
self.prompt_variables.get(self._resolution_result_delimiter_key,
|
||||
DEFAULT_RESOLUTION_RESULT_DELIMITER))
|
||||
selected_pairs = [candidate_resolution_i[1][result_i[0] - 1] for result_i in result]
|
||||
async with resolution_result_lock:
|
||||
for result_i in result:
|
||||
resolution_result.add(candidate_resolution_i[1][result_i[0] - 1])
|
||||
for pair in selected_pairs:
|
||||
resolution_result.add(pair)
|
||||
return selected_pairs
|
||||
|
||||
def _process_results(
|
||||
self,
|
||||
@@ -288,4 +309,3 @@ class EntityResolution(Extractor):
|
||||
return len(a & b) > 1
|
||||
|
||||
return len(a & b)*1./max_l >= 0.8
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ import logging
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from typing import Callable
|
||||
from typing import Any, Awaitable, Callable
|
||||
from dataclasses import dataclass
|
||||
import networkx as nx
|
||||
import pandas as pd
|
||||
@@ -23,6 +23,7 @@ from rag.graphrag.general import leiden
|
||||
from rag.graphrag.general.community_report_prompt import COMMUNITY_REPORT_PROMPT
|
||||
from rag.graphrag.general.extractor import Extractor
|
||||
from rag.graphrag.general.leiden import add_community_info2graph
|
||||
from rag.graphrag.checkpoints import community_checkpoint_key
|
||||
from rag.llm.chat_model import Base as CompletionLLM
|
||||
from rag.graphrag.utils import perform_variable_replacements, dict_has_keys_with_types, chat_limiter
|
||||
from common.token_utils import num_tokens_from_string
|
||||
@@ -53,7 +54,14 @@ class CommunityReportsExtractor(Extractor):
|
||||
self._extraction_prompt = COMMUNITY_REPORT_PROMPT
|
||||
self._max_report_length = max_report_length or 1500
|
||||
|
||||
async def __call__(self, graph: nx.Graph, callback: Callable | None = None, task_id: str = ""):
|
||||
async def __call__(
|
||||
self,
|
||||
graph: nx.Graph,
|
||||
callback: Callable | None = None,
|
||||
task_id: str = "",
|
||||
checkpoints: dict[str, Any] | None = None,
|
||||
save_checkpoint: Callable[[str, Any], Awaitable[bool]] | None = None,
|
||||
):
|
||||
enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION")
|
||||
for node_degree in graph.degree:
|
||||
graph.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])
|
||||
@@ -63,7 +71,9 @@ class CommunityReportsExtractor(Extractor):
|
||||
res_str = []
|
||||
res_dict = []
|
||||
over, token_count = 0, 0
|
||||
async def extract_community_report(community):
|
||||
checkpoints = checkpoints or {}
|
||||
|
||||
async def extract_community_report(level, community):
|
||||
nonlocal res_str, res_dict, over, token_count
|
||||
if task_id:
|
||||
if has_canceled(task_id):
|
||||
@@ -75,6 +85,19 @@ class CommunityReportsExtractor(Extractor):
|
||||
ents = cm["nodes"]
|
||||
if len(ents) < 2:
|
||||
return
|
||||
checkpoint_key = community_checkpoint_key(str(level), str(cm_id), list(ents))
|
||||
checkpoint = checkpoints.get(checkpoint_key)
|
||||
if isinstance(checkpoint, dict):
|
||||
response = checkpoint.get("structured_output")
|
||||
output = checkpoint.get("output")
|
||||
if isinstance(response, dict) and isinstance(output, str):
|
||||
add_community_info2graph(graph, response.get("entities", ents), response.get("title", ""))
|
||||
res_str.append(output)
|
||||
res_dict.append(response)
|
||||
over += 1
|
||||
if callback:
|
||||
callback(msg=f"Communities: {over}/{total}, used tokens: {token_count}")
|
||||
return
|
||||
ent_list = [{"entity": ent, "description": graph.nodes[ent]["description"]} for ent in ents]
|
||||
ent_df = pd.DataFrame(ent_list)
|
||||
|
||||
@@ -131,7 +154,10 @@ class CommunityReportsExtractor(Extractor):
|
||||
response["weight"] = weight
|
||||
response["entities"] = ents
|
||||
add_community_info2graph(graph, ents, response["title"])
|
||||
res_str.append(self._get_text_output(response))
|
||||
output = self._get_text_output(response)
|
||||
if save_checkpoint:
|
||||
await save_checkpoint(checkpoint_key, {"structured_output": response, "output": output})
|
||||
res_str.append(output)
|
||||
res_dict.append(response)
|
||||
over += 1
|
||||
if callback:
|
||||
@@ -145,7 +171,7 @@ class CommunityReportsExtractor(Extractor):
|
||||
if task_id and has_canceled(task_id):
|
||||
logging.info(f"Task {task_id} cancelled before community processing.")
|
||||
raise TaskCanceledException(f"Task {task_id} was cancelled")
|
||||
tasks.append(asyncio.create_task(extract_community_report(community)))
|
||||
tasks.append(asyncio.create_task(extract_community_report(level, community)))
|
||||
try:
|
||||
await asyncio.gather(*tasks, return_exceptions=False)
|
||||
except Exception as e:
|
||||
|
||||
@@ -24,6 +24,13 @@ from api.db.services.task_service import has_canceled
|
||||
from common.exceptions import TaskCanceledException
|
||||
from common.connection_utils import timeout
|
||||
from rag.graphrag.entity_resolution import EntityResolution
|
||||
from rag.graphrag.checkpoints import (
|
||||
COMMUNITY_CHECKPOINT,
|
||||
RESOLUTION_CHECKPOINT,
|
||||
cleanup_checkpoints,
|
||||
load_checkpoints,
|
||||
save_checkpoint,
|
||||
)
|
||||
from rag.graphrag.general.community_reports_extractor import CommunityReportsExtractor
|
||||
from rag.graphrag.general.extractor import Extractor
|
||||
from rag.graphrag.general.graph_extractor import GraphExtractor as GeneralKGExt
|
||||
@@ -763,10 +770,22 @@ async def resolve_entities(
|
||||
_has_cancel_and_exit(task_id, f"Task {task_id} cancelled during entity resolution.", callback)
|
||||
|
||||
start = asyncio.get_running_loop().time()
|
||||
checkpoints = await load_checkpoints(tenant_id, kb_id, RESOLUTION_CHECKPOINT)
|
||||
|
||||
async def save_resolution_checkpoint(checkpoint_key: str, payload):
|
||||
return await save_checkpoint(tenant_id, kb_id, RESOLUTION_CHECKPOINT, checkpoint_key, payload)
|
||||
|
||||
er = EntityResolution(
|
||||
llm_bdl,
|
||||
)
|
||||
reso = await er(graph, subgraph_nodes, callback=callback, task_id=task_id)
|
||||
reso = await er(
|
||||
graph,
|
||||
subgraph_nodes,
|
||||
callback=callback,
|
||||
task_id=task_id,
|
||||
checkpoints=checkpoints,
|
||||
save_checkpoint=save_resolution_checkpoint,
|
||||
)
|
||||
graph = reso.graph
|
||||
change = reso.change
|
||||
callback(msg=f"Graph resolution removed {len(change.removed_nodes)} nodes and {len(change.removed_edges)} edges.")
|
||||
@@ -776,6 +795,7 @@ async def resolve_entities(
|
||||
|
||||
_has_cancel_and_exit(task_id, f"Task {task_id} cancelled before saving resolved graph.", callback)
|
||||
await set_graph(tenant_id, kb_id, embed_bdl, graph, change, callback)
|
||||
await cleanup_checkpoints(tenant_id, kb_id, RESOLUTION_CHECKPOINT)
|
||||
now = asyncio.get_running_loop().time()
|
||||
callback(msg=f"Graph resolution done in {now - start:.2f}s.")
|
||||
|
||||
@@ -794,10 +814,21 @@ async def extract_community(
|
||||
_has_cancel_and_exit(task_id, f"Task {task_id} cancelled before community extraction.", callback)
|
||||
|
||||
start = asyncio.get_running_loop().time()
|
||||
checkpoints = await load_checkpoints(tenant_id, kb_id, COMMUNITY_CHECKPOINT)
|
||||
|
||||
async def save_community_checkpoint(checkpoint_key: str, payload):
|
||||
return await save_checkpoint(tenant_id, kb_id, COMMUNITY_CHECKPOINT, checkpoint_key, payload)
|
||||
|
||||
ext = CommunityReportsExtractor(
|
||||
llm_bdl,
|
||||
)
|
||||
cr = await ext(graph, callback=callback, task_id=task_id)
|
||||
cr = await ext(
|
||||
graph,
|
||||
callback=callback,
|
||||
task_id=task_id,
|
||||
checkpoints=checkpoints,
|
||||
save_checkpoint=save_community_checkpoint,
|
||||
)
|
||||
|
||||
_has_cancel_and_exit(task_id, f"Task {task_id} cancelled during community extraction.", callback)
|
||||
|
||||
@@ -881,6 +912,7 @@ async def extract_community(
|
||||
logging.exception("Failed to prune %d stale community reports for kb %s", len(stale_ids), kb_id)
|
||||
|
||||
_has_cancel_and_exit(task_id, f"Task {task_id} cancelled after community indexing.", callback)
|
||||
await cleanup_checkpoints(tenant_id, kb_id, COMMUNITY_CHECKPOINT)
|
||||
|
||||
now = asyncio.get_running_loop().time()
|
||||
callback(msg=f"Graph indexed {len(cr.structured_output)} communities in {now - start:.2f}s.")
|
||||
|
||||
184
test/unit_test/rag/graphrag/test_checkpoints.py
Normal file
184
test/unit_test/rag/graphrag/test_checkpoints.py
Normal file
@@ -0,0 +1,184 @@
|
||||
#
|
||||
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import pytest
|
||||
|
||||
from rag.graphrag import checkpoints
|
||||
|
||||
|
||||
class _FakeRedisClient:
    """In-memory stand-in for the raw Redis client exposed as ``REDIS_CONN.REDIS``.

    Records expirations and SSCAN calls so tests can assert on them; set data
    is read from the owning fake connection.
    """

    def __init__(self, conn):
        self.conn = conn
        # key -> last TTL passed to expire()
        self.expirations = {}
        # (key, count) for every sscan_iter iteration that was started
        self.scan_counts = []

    def expire(self, key, ttl):
        """Remember the TTL requested for ``key``."""
        self.expirations.update({key: ttl})
        return True

    def pipeline(self, transaction=True):
        """Return a fake pipeline; only transactional pipelines are expected."""
        assert transaction is True
        return _FakeRedisPipeline(self.conn)

    def sscan_iter(self, key, count=None):
        """Lazily yield the members of the set at ``key``, recording the call."""
        self.scan_counts.append((key, count))
        for member in self.conn.sets.get(key, set()):
            yield member
|
||||
|
||||
|
||||
class _FakeRedisPipeline:
    """Fake transactional pipeline: queues commands and applies them on ``execute``.

    Nothing is written to the fake connection until ``execute`` runs, and
    nothing at all is written when the connection is flagged to fail.
    """

    def __init__(self, conn):
        self.conn = conn
        self.commands = []

    def set(self, key, value, ex=None):
        """Queue a SET with an optional expiry; returns ``self`` for chaining."""
        self.commands.append(("set", key, value, ex))
        return self

    def sadd(self, key, member):
        """Queue an SADD; returns ``self`` for chaining."""
        self.commands.append(("sadd", key, member))
        return self

    def expire(self, key, ttl):
        """Queue an EXPIRE; returns ``self`` for chaining."""
        self.commands.append(("expire", key, ttl))
        return self

    def execute(self):
        """Apply all queued commands, or raise if the connection simulates failure."""
        if self.conn.fail_pipeline:
            raise RuntimeError("redis transaction failed")
        for name, key, *operands in self.commands:
            if name == "set":
                value, ttl = operands
                self.conn.values[key] = value
                if ttl is not None:
                    self.conn.REDIS.expire(key, ttl)
            elif name == "sadd":
                (member,) = operands
                self.conn.sets.setdefault(key, set()).add(member)
            elif name == "expire":
                (ttl,) = operands
                self.conn.REDIS.expire(key, ttl)
        return [True for _ in self.commands]
|
||||
|
||||
|
||||
class _FakeRedisConn:
    """In-memory stand-in for the ``REDIS_CONN`` wrapper used by the checkpoint module.

    String values and sets are kept in plain dictionaries; ``fail_set`` and
    ``fail_pipeline`` let tests simulate write failures.
    """

    def __init__(self):
        self.values = {}
        self.sets = {}
        self.fail_set = False
        self.fail_pipeline = False
        self.REDIS = _FakeRedisClient(self)

    def get(self, key):
        """Return the stored value for ``key`` or ``None``."""
        return self.values.get(key)

    def set(self, key, value, exp=3600):
        """Store ``value`` with an expiry unless ``fail_set`` is on."""
        if not self.fail_set:
            self.values[key] = value
            self.REDIS.expire(key, exp)
            return True
        return False

    def sadd(self, key, member):
        """Add ``member`` to the set at ``key``, creating the set if needed."""
        members = self.sets.setdefault(key, set())
        members.add(member)
        return True

    def smembers(self, key):
        """Always fail: the code under test is required to scan, not SMEMBERS."""
        raise AssertionError("checkpoint code must use sscan_iter instead of smembers")

    def delete(self, key):
        """Drop ``key`` from both the value store and the set store."""
        for store in (self.values, self.sets):
            store.pop(key, None)
        return True
|
||||
|
||||
|
||||
@pytest.fixture
def fake_redis(monkeypatch):
    """Swap the checkpoint module's ``REDIS_CONN`` for an in-memory fake and return it."""
    redis_double = _FakeRedisConn()
    monkeypatch.setattr(checkpoints, "REDIS_CONN", redis_double)
    return redis_double
|
||||
|
||||
|
||||
@pytest.mark.p1
def test_checkpoint_keys_are_stable():
    """Keys must not depend on node order, pair order, or the orientation of a pair."""
    # Community keys ignore the order in which nodes are listed.
    shuffled_nodes_key = checkpoints.community_checkpoint_key("1", "2", ["B", "A"])
    sorted_nodes_key = checkpoints.community_checkpoint_key("1", "2", ["A", "B"])
    assert shuffled_nodes_key == sorted_nodes_key

    pairs = [("alpha", "alfa"), ("beta", "bata")]
    baseline_key = checkpoints.resolution_checkpoint_key("entity", pairs)

    # Resolution keys ignore the order of the pairs ...
    reversed_pairs = pairs[::-1]
    assert baseline_key == checkpoints.resolution_checkpoint_key("entity", reversed_pairs)

    # ... and which element of a pair comes first.
    internally_reversed_pairs = [("alfa", "alpha"), ("bata", "beta")]
    assert baseline_key == checkpoints.resolution_checkpoint_key("entity", internally_reversed_pairs)
|
||||
|
||||
|
||||
@pytest.mark.p1
@pytest.mark.asyncio
async def test_load_checkpoints_reads_redis_index(fake_redis, monkeypatch):
    """Loading returns only the requested kb's checkpoints and scans via the thread pool."""
    community = checkpoints.COMMUNITY_CHECKPOINT
    seeded = (("kb-1", "k1", 1), ("kb-1", "k2", 2), ("kb-2", "k3", 3))
    for kb_id, key, value in seeded:
        await checkpoints.save_checkpoint("tenant-1", kb_id, community, key, {"value": value})

    recorded_calls = []

    async def _inline_thread_pool_exec(func, *args, **kwargs):
        # Run the offloaded function inline while recording how it was invoked.
        recorded_calls.append((func, args, kwargs))
        return func(*args, **kwargs)

    monkeypatch.setattr(checkpoints, "thread_pool_exec", _inline_thread_pool_exec)

    loaded = await checkpoints.load_checkpoints("tenant-1", "kb-1", community, page_size=1)

    assert loaded == {"k1": {"value": 1}, "k2": {"value": 2}}

    expected_call = (
        checkpoints._load_checkpoints_sync,
        ("tenant-1", "kb-1", community, 1),
        {},
    )
    assert recorded_calls == [expected_call]

    # The caller's page size must reach SSCAN as its count hint.
    expected_scan = (
        "graphrag:checkpoint:tenant-1:kb-1:graphrag_checkpoint_community:keys",
        1,
    )
    assert fake_redis.REDIS.scan_counts[-1] == expected_scan
|
||||
|
||||
|
||||
@pytest.mark.p2
@pytest.mark.asyncio
async def test_save_checkpoint_degrades_on_redis_failure(fake_redis):
    """A failing Redis transaction yields ``False`` and leaves no partial state behind."""
    fake_redis.fail_pipeline = True

    outcome = await checkpoints.save_checkpoint(
        "tenant-1",
        "kb-1",
        checkpoints.RESOLUTION_CHECKPOINT,
        "key-1",
        {"ok": True},
    )

    assert outcome is False
    assert fake_redis.values == {}
    assert fake_redis.sets == {}
|
||||
|
||||
|
||||
@pytest.mark.p2
@pytest.mark.asyncio
async def test_cleanup_checkpoints_deletes_redis_stage_keys(fake_redis):
    """Cleanup removes one stage's checkpoints and leaves other stages untouched."""
    resolution = checkpoints.RESOLUTION_CHECKPOINT
    community = checkpoints.COMMUNITY_CHECKPOINT
    seeded = ((resolution, "k1", 1), (resolution, "k2", 2), (community, "k3", 3))
    for stage, key, value in seeded:
        await checkpoints.save_checkpoint("tenant-1", "kb-1", stage, key, {"value": value})

    cleaned = await checkpoints.cleanup_checkpoints("tenant-1", "kb-1", resolution, page_size=1)
    assert cleaned is True

    remaining_resolution = await checkpoints.load_checkpoints("tenant-1", "kb-1", resolution)
    assert remaining_resolution == {}

    remaining_community = await checkpoints.load_checkpoints("tenant-1", "kb-1", community)
    assert remaining_community == {"k3": {"value": 3}}

    # Cleanup must have scanned the resolution index with the caller's page size.
    expected_scan = (
        "graphrag:checkpoint:tenant-1:kb-1:graphrag_checkpoint_resolution:keys",
        1,
    )
    assert expected_scan in fake_redis.REDIS.scan_counts
|
||||
Reference in New Issue
Block a user