Refa: GraphRAG to use async chat methods instead of thread pool execution (#14002)

### What problem does this PR solve?

GraphRAG _async_chat.

### Type of change

- [x] Refactoring
- [x] Performance Improvement


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **Refactor**
* Unified chat calls to an async invocation across extractors, improving
timeout handling and ensuring task IDs propagate reliably.
* **Tests**
* Added and expanded unit tests and mocks to cover extractor behavior,
timeout scenarios, and safe test-package imports, reducing regression
risk.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
Yongteng Lei
2026-04-09 19:57:35 +08:00
committed by GitHub
parent c2ce49e037
commit b33d2fdea5
9 changed files with 150 additions and 36 deletions

View File

@@ -32,7 +32,6 @@ from rag.graphrag.utils import perform_variable_replacements, chat_limiter, Grap
from api.db.services.task_service import has_canceled
from common.exceptions import TaskCanceledException
from common.misc_utils import thread_pool_exec
DEFAULT_RECORD_DELIMITER = "##"
DEFAULT_ENTITY_INDEX_DELIMITER = "<|>"
@@ -213,21 +212,15 @@ class EntityResolution(Extractor):
timeout_seconds = 280 if os.environ.get("ENABLE_TIMEOUT_ASSERTION") else 1000000000
try:
response = await asyncio.wait_for(
thread_pool_exec(
self._chat,
text,
[{"role": "user", "content": "Output:"}],
{},
task_id
),
self._async_chat(text, [{"role": "user", "content": "Output:"}], {}, task_id),
timeout=timeout_seconds,
)
except asyncio.TimeoutError:
logging.warning("_resolve_candidate._chat timeout, skipping...")
logging.warning("_resolve_candidate._async_chat timeout, skipping...")
return
except Exception as e:
logging.error(f"_resolve_candidate._chat failed: {e}")
logging.error(f"_resolve_candidate._async_chat failed: {e}")
return
logging.debug(f"_resolve_candidate chat prompt: {text}\nchat response: {response}")

View File

@@ -1,7 +1,6 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
from common.misc_utils import thread_pool_exec
"""
Reference:
@@ -20,7 +19,6 @@ import pandas as pd
from api.db.services.task_service import has_canceled
from common.exceptions import TaskCanceledException
from common.connection_utils import timeout
from rag.graphrag.general import leiden
from rag.graphrag.general.community_report_prompt import COMMUNITY_REPORT_PROMPT
from rag.graphrag.general.extractor import Extractor
@@ -65,7 +63,6 @@ class CommunityReportsExtractor(Extractor):
res_str = []
res_dict = []
over, token_count = 0, 0
@timeout(120)
async def extract_community_report(community):
nonlocal res_str, res_dict, over, token_count
if task_id:
@@ -104,12 +101,12 @@ class CommunityReportsExtractor(Extractor):
async with chat_limiter:
try:
timeout = 180 if enable_timeout_assertion else 1000000000
response = await asyncio.wait_for(thread_pool_exec(self._chat,text,[{"role": "user", "content": "Output:"}],{},task_id),timeout=timeout)
response = await asyncio.wait_for(self._async_chat(text, [{"role": "user", "content": "Output:"}], {}, task_id), timeout=timeout)
except asyncio.TimeoutError:
logging.warning("extract_community_report._chat timeout, skipping...")
logging.warning("extract_community_report._async_chat timeout, skipping...")
return
except Exception as e:
logging.error(f"extract_community_report._chat failed: {e}")
logging.error(f"extract_community_report._async_chat failed: {e}")
return
token_count += num_tokens_from_string(text + response)
response = re.sub(r"^[^\{]*", "", response)

View File

@@ -24,7 +24,6 @@ from typing import Callable
import networkx as nx
from api.db.services.task_service import has_canceled
from common.connection_utils import timeout
from common.token_utils import truncate
from rag.graphrag.general.graph_prompt import SUMMARIZE_DESCRIPTIONS_PROMPT
from rag.graphrag.utils import (
@@ -74,11 +73,10 @@ class Extractor:
def _is_truncated_cache(response):
return len((response or "").strip()) <= 1
@timeout(60 * 20)
def _chat(self, system, history, gen_conf={}, task_id=""):
async def _async_chat(self, system, history, gen_conf={}, task_id=""):
hist = deepcopy(history)
conf = deepcopy(gen_conf)
response = get_llm_cache(self._llm.llm_name, system, hist, conf)
response = await thread_pool_exec(get_llm_cache, self._llm.llm_name, system, hist, conf)
response = self._normalize_response_text(response)
if self._is_truncated_cache(response):
response = ""
@@ -88,18 +86,24 @@ class Extractor:
response = ""
for attempt in range(3):
if task_id:
if has_canceled(task_id):
if await thread_pool_exec(has_canceled, task_id):
logging.info(f"Task {task_id} cancelled during entity resolution candidate processing.")
raise TaskCanceledException(f"Task {task_id} was cancelled")
try:
response = asyncio.run(self._llm.async_chat(system_msg[0]["content"], hist, conf))
response = await asyncio.wait_for(
self._llm.async_chat(system_msg[0]["content"], hist, conf),
timeout=60 * 20,
)
response = self._normalize_response_text(response)
response = re.sub(r"^.*</think>", "", response, flags=re.DOTALL)
if response.find("**ERROR**") >= 0:
raise Exception(response)
if not self._is_truncated_cache(response):
set_llm_cache(self._llm.llm_name, system, response, history, gen_conf)
await thread_pool_exec(set_llm_cache, self._llm.llm_name, system, response, history, gen_conf)
break
except asyncio.TimeoutError:
logging.warning("_async_chat timed out after 20 minutes")
raise # timeout is not a transient error; do not retry
except Exception as e:
logging.exception(e)
if attempt == 2:
@@ -357,5 +361,5 @@ class Extractor:
raise TaskCanceledException(f"Task {task_id} was cancelled during summary handling")
async with chat_limiter:
summary = await thread_pool_exec(self._chat, "", [{"role": "user", "content": use_prompt}], {}, task_id)
summary = await self._async_chat("", [{"role": "user", "content": use_prompt}], {}, task_id)
return summary

View File

@@ -1,8 +1,6 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
from common.misc_utils import thread_pool_exec
"""
Reference:
- [graphrag](https://github.com/microsoft/graphrag)
@@ -109,7 +107,7 @@ class GraphExtractor(Extractor):
}
hint_prompt = perform_variable_replacements(self._extraction_prompt, variables=variables)
async with chat_limiter:
response = await thread_pool_exec(self._chat,hint_prompt,[{"role": "user", "content": "Output:"}],{},task_id)
response = await self._async_chat(hint_prompt, [{"role": "user", "content": "Output:"}], {}, task_id)
token_count += num_tokens_from_string(hint_prompt + response)
results = response or ""
@@ -119,7 +117,7 @@ class GraphExtractor(Extractor):
for i in range(self._max_gleanings):
history.append({"role": "user", "content": CONTINUE_PROMPT})
async with chat_limiter:
response = await thread_pool_exec(self._chat, "", history, {})
response = await self._async_chat("", history, {}, task_id)
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response)
results += response or ""
@@ -129,7 +127,7 @@ class GraphExtractor(Extractor):
history.append({"role": "assistant", "content": response})
history.append({"role": "user", "content": LOOP_PROMPT})
async with chat_limiter:
continuation = await thread_pool_exec(self._chat, "", history)
continuation = await self._async_chat("", history, {}, task_id)
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response)
if continuation != "Y":
break

View File

@@ -29,7 +29,6 @@ import markdown_to_json
from functools import reduce
from common.token_utils import num_tokens_from_string
from common.misc_utils import thread_pool_exec
@dataclass
class MindMapResult:
@@ -186,7 +185,7 @@ class MindMapExtractor(Extractor):
}
text = perform_variable_replacements(self._mind_map_prompt, variables=variables)
async with chat_limiter:
response = await thread_pool_exec(self._chat,text,[{"role": "user", "content": "Output:"}],{})
response = await self._async_chat(text, [{"role": "user", "content": "Output:"}], {})
response = re.sub(r"```[^\n]*", "", response)
logging.debug(response)
logging.debug(self._todict(markdown_to_json.dictify(response)))

View File

@@ -1,8 +1,6 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
from common.misc_utils import thread_pool_exec
"""
Reference:
- [graphrag](https://github.com/microsoft/graphrag)
@@ -83,12 +81,12 @@ class GraphExtractor(Extractor):
if self.callback:
self.callback(msg=f"Start processing for {chunk_key}: {content[:25]}...")
async with chat_limiter:
final_result = await thread_pool_exec(self._chat,"",[{"role": "user", "content": hint_prompt}],gen_conf,task_id)
final_result = await self._async_chat("", [{"role": "user", "content": hint_prompt}], gen_conf, task_id)
token_count += num_tokens_from_string(hint_prompt + final_result)
history = pack_user_ass_to_openai_messages(hint_prompt, final_result, self._continue_prompt)
for now_glean_index in range(self._max_gleanings):
async with chat_limiter:
glean_result = await thread_pool_exec(self._chat,"",history,gen_conf,task_id)
glean_result = await self._async_chat("", history, gen_conf, task_id)
history.extend([{"role": "assistant", "content": glean_result}])
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + hint_prompt + self._continue_prompt)
final_result += glean_result
@@ -97,7 +95,7 @@ class GraphExtractor(Extractor):
history.extend([{"role": "user", "content": self._if_loop_prompt}])
async with chat_limiter:
if_loop_result = await thread_pool_exec(self._chat,"",history,gen_conf,task_id)
if_loop_result = await self._async_chat("", history, gen_conf, task_id)
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + if_loop_result + self._if_loop_prompt)
if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
if if_loop_result != "yes":

View File

@@ -28,6 +28,10 @@ _modules_to_mock = [
"common.settings",
"common.doc_store",
"common.doc_store.doc_store_base",
"api.db.services",
"api.db.services.task_service",
"rag.graphrag.general.leiden",
"rag.llm.chat_model",
"rag.nlp",
"rag.nlp.search",
"rag.nlp.rag_tokenizer",
@@ -40,3 +44,7 @@ for mod_name in _modules_to_mock:
# Ensure `from common.connection_utils import timeout` returns a no-op decorator
sys.modules["common.connection_utils"].timeout = lambda *a, **kw: (lambda fn: fn)
sys.modules["api.db.services.task_service"].has_canceled = lambda *_a, **_kw: False
sys.modules["rag.graphrag.general.leiden"].run = lambda *_a, **_kw: {}
sys.modules["rag.graphrag.general.leiden"].add_community_info2graph = lambda *_a, **_kw: None
sys.modules["rag.llm.chat_model"].Base = object

View File

@@ -0,0 +1,96 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
from types import SimpleNamespace
import networkx as nx
import pytest
import rag.graphrag.general.community_reports_extractor as community_reports_module
from rag.graphrag.general.community_reports_extractor import CommunityReportsExtractor
from rag.graphrag.general.graph_extractor import GraphExtractor
def _build_llm_stub():
return SimpleNamespace(llm_name="test-llm", max_length=4096)
class TestGraphExtractor:
@pytest.mark.p2
@pytest.mark.asyncio
async def test_process_single_content_passes_task_id_to_gleaning_calls(self, monkeypatch):
extractor = GraphExtractor(_build_llm_stub(), entity_types=["person"])
extractor.callback = None
seen_task_ids = []
responses = iter(["seed-response", "glean-response", "N"])
async def fake_async_chat(_system, _history, _gen_conf=None, task_id=""):
seen_task_ids.append(task_id)
return next(responses)
monkeypatch.setattr(extractor, "_async_chat", fake_async_chat)
monkeypatch.setattr(extractor, "_entities_and_relations", lambda *_args, **_kwargs: ({}, {}))
out_results = []
await extractor._process_single_content(("chunk-1", "alpha beta"), 0, 1, out_results, task_id="task-123")
assert seen_task_ids == ["task-123", "task-123", "task-123"]
class TestCommunityReportsExtractor:
@pytest.mark.p2
@pytest.mark.asyncio
async def test_call_does_not_use_outer_timeout_shorter_than_llm_timeout(self, monkeypatch):
extractor = CommunityReportsExtractor(_build_llm_stub())
graph = nx.Graph()
graph.add_node("A", description="alpha")
graph.add_node("B", description="beta")
graph.add_edge("A", "B", description="related")
monkeypatch.setenv("ENABLE_TIMEOUT_ASSERTION", "1")
original_wait_for = asyncio.wait_for
def fake_timeout(_seconds, _attempts=2, **_kwargs):
def decorator(fn):
async def wrapper(*args, **kwargs):
return await original_wait_for(fn(*args, **kwargs), timeout=0.01)
return wrapper
return decorator
async def slow_async_chat(*_args, **_kwargs):
await asyncio.sleep(0.02)
return (
'{"title":"Community","summary":"Summary","findings":[],'
'"rating":1.0,"rating_explanation":"Clear"}'
)
monkeypatch.setattr(community_reports_module, "timeout", fake_timeout, raising=False)
monkeypatch.setattr(
community_reports_module.leiden,
"run",
lambda *_args, **_kwargs: {0: {"0": {"weight": 1.0, "nodes": ["A", "B"]}}},
)
monkeypatch.setattr(community_reports_module, "add_community_info2graph", lambda *_args, **_kwargs: None)
monkeypatch.setattr(extractor, "_async_chat", slow_async_chat)
result = await extractor(graph)
assert len(result.structured_output) == 1
assert result.structured_output[0]["title"] == "Community"

View File

@@ -0,0 +1,21 @@
import importlib
import sys
from types import ModuleType
import pytest
pytestmark = pytest.mark.p2
def test_chunk_feedback_package_import_is_safe_when_common_is_shadowed(monkeypatch):
shadow_common = ModuleType("common")
monkeypatch.setitem(sys.modules, "common", shadow_common)
monkeypatch.delitem(
sys.modules,
"test.testcases.test_web_api.test_chunk_feedback",
raising=False,
)
module = importlib.import_module("test.testcases.test_web_api.test_chunk_feedback")
assert module is not None