Refa: GraphRAG to use async chat methods instead of thread pool execution (#14002)

### What problem does this PR solve?

GraphRAG _async_chat.

### Type of change

- [x] Refactoring
- [x] Performance Improvement


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

* **Refactor**
* Unified chat calls to an async invocation across extractors, improving
timeout handling and ensuring task IDs propagate reliably.
* **Tests**
* Added and expanded unit tests and mocks to cover extractor behavior,
timeout scenarios, and safe test-package imports, reducing regression
risk.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
Yongteng Lei
2026-04-09 19:57:35 +08:00
committed by GitHub
parent c2ce49e037
commit b33d2fdea5
9 changed files with 150 additions and 36 deletions

View File

@@ -32,7 +32,6 @@ from rag.graphrag.utils import perform_variable_replacements, chat_limiter, Grap
from api.db.services.task_service import has_canceled
from common.exceptions import TaskCanceledException
from common.misc_utils import thread_pool_exec
DEFAULT_RECORD_DELIMITER = "##"
DEFAULT_ENTITY_INDEX_DELIMITER = "<|>"
@@ -213,21 +212,15 @@ class EntityResolution(Extractor):
timeout_seconds = 280 if os.environ.get("ENABLE_TIMEOUT_ASSERTION") else 1000000000
try:
response = await asyncio.wait_for(
thread_pool_exec(
self._chat,
text,
[{"role": "user", "content": "Output:"}],
{},
task_id
),
self._async_chat(text, [{"role": "user", "content": "Output:"}], {}, task_id),
timeout=timeout_seconds,
)
except asyncio.TimeoutError:
logging.warning("_resolve_candidate._chat timeout, skipping...")
logging.warning("_resolve_candidate._async_chat timeout, skipping...")
return
except Exception as e:
logging.error(f"_resolve_candidate._chat failed: {e}")
logging.error(f"_resolve_candidate._async_chat failed: {e}")
return
logging.debug(f"_resolve_candidate chat prompt: {text}\nchat response: {response}")

View File

@@ -1,7 +1,6 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
from common.misc_utils import thread_pool_exec
"""
Reference:
@@ -20,7 +19,6 @@ import pandas as pd
from api.db.services.task_service import has_canceled
from common.exceptions import TaskCanceledException
from common.connection_utils import timeout
from rag.graphrag.general import leiden
from rag.graphrag.general.community_report_prompt import COMMUNITY_REPORT_PROMPT
from rag.graphrag.general.extractor import Extractor
@@ -65,7 +63,6 @@ class CommunityReportsExtractor(Extractor):
res_str = []
res_dict = []
over, token_count = 0, 0
@timeout(120)
async def extract_community_report(community):
nonlocal res_str, res_dict, over, token_count
if task_id:
@@ -104,12 +101,12 @@ class CommunityReportsExtractor(Extractor):
async with chat_limiter:
try:
timeout = 180 if enable_timeout_assertion else 1000000000
response = await asyncio.wait_for(thread_pool_exec(self._chat,text,[{"role": "user", "content": "Output:"}],{},task_id),timeout=timeout)
response = await asyncio.wait_for(self._async_chat(text, [{"role": "user", "content": "Output:"}], {}, task_id), timeout=timeout)
except asyncio.TimeoutError:
logging.warning("extract_community_report._chat timeout, skipping...")
logging.warning("extract_community_report._async_chat timeout, skipping...")
return
except Exception as e:
logging.error(f"extract_community_report._chat failed: {e}")
logging.error(f"extract_community_report._async_chat failed: {e}")
return
token_count += num_tokens_from_string(text + response)
response = re.sub(r"^[^\{]*", "", response)

View File

@@ -24,7 +24,6 @@ from typing import Callable
import networkx as nx
from api.db.services.task_service import has_canceled
from common.connection_utils import timeout
from common.token_utils import truncate
from rag.graphrag.general.graph_prompt import SUMMARIZE_DESCRIPTIONS_PROMPT
from rag.graphrag.utils import (
@@ -74,11 +73,10 @@ class Extractor:
def _is_truncated_cache(response):
return len((response or "").strip()) <= 1
@timeout(60 * 20)
def _chat(self, system, history, gen_conf={}, task_id=""):
async def _async_chat(self, system, history, gen_conf={}, task_id=""):
hist = deepcopy(history)
conf = deepcopy(gen_conf)
response = get_llm_cache(self._llm.llm_name, system, hist, conf)
response = await thread_pool_exec(get_llm_cache, self._llm.llm_name, system, hist, conf)
response = self._normalize_response_text(response)
if self._is_truncated_cache(response):
response = ""
@@ -88,18 +86,24 @@ class Extractor:
response = ""
for attempt in range(3):
if task_id:
if has_canceled(task_id):
if await thread_pool_exec(has_canceled, task_id):
logging.info(f"Task {task_id} cancelled during entity resolution candidate processing.")
raise TaskCanceledException(f"Task {task_id} was cancelled")
try:
response = asyncio.run(self._llm.async_chat(system_msg[0]["content"], hist, conf))
response = await asyncio.wait_for(
self._llm.async_chat(system_msg[0]["content"], hist, conf),
timeout=60 * 20,
)
response = self._normalize_response_text(response)
response = re.sub(r"^.*</think>", "", response, flags=re.DOTALL)
if response.find("**ERROR**") >= 0:
raise Exception(response)
if not self._is_truncated_cache(response):
set_llm_cache(self._llm.llm_name, system, response, history, gen_conf)
await thread_pool_exec(set_llm_cache, self._llm.llm_name, system, response, history, gen_conf)
break
except asyncio.TimeoutError:
logging.warning("_async_chat timed out after 20 minutes")
raise # timeout is not a transient error; do not retry
except Exception as e:
logging.exception(e)
if attempt == 2:
@@ -357,5 +361,5 @@ class Extractor:
raise TaskCanceledException(f"Task {task_id} was cancelled during summary handling")
async with chat_limiter:
summary = await thread_pool_exec(self._chat, "", [{"role": "user", "content": use_prompt}], {}, task_id)
summary = await self._async_chat("", [{"role": "user", "content": use_prompt}], {}, task_id)
return summary

View File

@@ -1,8 +1,6 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
from common.misc_utils import thread_pool_exec
"""
Reference:
- [graphrag](https://github.com/microsoft/graphrag)
@@ -109,7 +107,7 @@ class GraphExtractor(Extractor):
}
hint_prompt = perform_variable_replacements(self._extraction_prompt, variables=variables)
async with chat_limiter:
response = await thread_pool_exec(self._chat,hint_prompt,[{"role": "user", "content": "Output:"}],{},task_id)
response = await self._async_chat(hint_prompt, [{"role": "user", "content": "Output:"}], {}, task_id)
token_count += num_tokens_from_string(hint_prompt + response)
results = response or ""
@@ -119,7 +117,7 @@ class GraphExtractor(Extractor):
for i in range(self._max_gleanings):
history.append({"role": "user", "content": CONTINUE_PROMPT})
async with chat_limiter:
response = await thread_pool_exec(self._chat, "", history, {})
response = await self._async_chat("", history, {}, task_id)
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response)
results += response or ""
@@ -129,7 +127,7 @@ class GraphExtractor(Extractor):
history.append({"role": "assistant", "content": response})
history.append({"role": "user", "content": LOOP_PROMPT})
async with chat_limiter:
continuation = await thread_pool_exec(self._chat, "", history)
continuation = await self._async_chat("", history, {}, task_id)
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response)
if continuation != "Y":
break

View File

@@ -29,7 +29,6 @@ import markdown_to_json
from functools import reduce
from common.token_utils import num_tokens_from_string
from common.misc_utils import thread_pool_exec
@dataclass
class MindMapResult:
@@ -186,7 +185,7 @@ class MindMapExtractor(Extractor):
}
text = perform_variable_replacements(self._mind_map_prompt, variables=variables)
async with chat_limiter:
response = await thread_pool_exec(self._chat,text,[{"role": "user", "content": "Output:"}],{})
response = await self._async_chat(text, [{"role": "user", "content": "Output:"}], {})
response = re.sub(r"```[^\n]*", "", response)
logging.debug(response)
logging.debug(self._todict(markdown_to_json.dictify(response)))

View File

@@ -1,8 +1,6 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License
from common.misc_utils import thread_pool_exec
"""
Reference:
- [graphrag](https://github.com/microsoft/graphrag)
@@ -83,12 +81,12 @@ class GraphExtractor(Extractor):
if self.callback:
self.callback(msg=f"Start processing for {chunk_key}: {content[:25]}...")
async with chat_limiter:
final_result = await thread_pool_exec(self._chat,"",[{"role": "user", "content": hint_prompt}],gen_conf,task_id)
final_result = await self._async_chat("", [{"role": "user", "content": hint_prompt}], gen_conf, task_id)
token_count += num_tokens_from_string(hint_prompt + final_result)
history = pack_user_ass_to_openai_messages(hint_prompt, final_result, self._continue_prompt)
for now_glean_index in range(self._max_gleanings):
async with chat_limiter:
glean_result = await thread_pool_exec(self._chat,"",history,gen_conf,task_id)
glean_result = await self._async_chat("", history, gen_conf, task_id)
history.extend([{"role": "assistant", "content": glean_result}])
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + hint_prompt + self._continue_prompt)
final_result += glean_result
@@ -97,7 +95,7 @@ class GraphExtractor(Extractor):
history.extend([{"role": "user", "content": self._if_loop_prompt}])
async with chat_limiter:
if_loop_result = await thread_pool_exec(self._chat,"",history,gen_conf,task_id)
if_loop_result = await self._async_chat("", history, gen_conf, task_id)
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + if_loop_result + self._if_loop_prompt)
if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
if if_loop_result != "yes":