mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 15:31:05 +08:00
Refa: GraphRAG to use async chat methods instead of thread pool execution (#14002)
### What problem does this PR solve?

GraphRAG extractors now await `_async_chat` directly instead of running the synchronous `_chat` through `thread_pool_exec`.

### Type of change

- [x] Refactoring
- [x] Performance Improvement

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

* **Refactor**
  * Unified chat calls to an async invocation across extractors, improving timeout handling and ensuring task IDs propagate reliably.
* **Tests**
  * Added and expanded unit tests and mocks to cover extractor behavior, timeout scenarios, and safe test-package imports, reducing regression risk.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->
This commit is contained in:
@@ -32,7 +32,6 @@ from rag.graphrag.utils import perform_variable_replacements, chat_limiter, Grap
|
||||
from api.db.services.task_service import has_canceled
|
||||
from common.exceptions import TaskCanceledException
|
||||
|
||||
from common.misc_utils import thread_pool_exec
|
||||
|
||||
DEFAULT_RECORD_DELIMITER = "##"
|
||||
DEFAULT_ENTITY_INDEX_DELIMITER = "<|>"
|
||||
@@ -213,21 +212,15 @@ class EntityResolution(Extractor):
|
||||
timeout_seconds = 280 if os.environ.get("ENABLE_TIMEOUT_ASSERTION") else 1000000000
|
||||
try:
|
||||
response = await asyncio.wait_for(
|
||||
thread_pool_exec(
|
||||
self._chat,
|
||||
text,
|
||||
[{"role": "user", "content": "Output:"}],
|
||||
{},
|
||||
task_id
|
||||
),
|
||||
self._async_chat(text, [{"role": "user", "content": "Output:"}], {}, task_id),
|
||||
timeout=timeout_seconds,
|
||||
)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logging.warning("_resolve_candidate._chat timeout, skipping...")
|
||||
logging.warning("_resolve_candidate._async_chat timeout, skipping...")
|
||||
return
|
||||
except Exception as e:
|
||||
logging.error(f"_resolve_candidate._chat failed: {e}")
|
||||
logging.error(f"_resolve_candidate._async_chat failed: {e}")
|
||||
return
|
||||
|
||||
logging.debug(f"_resolve_candidate chat prompt: {text}\nchat response: {response}")
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
from common.misc_utils import thread_pool_exec
|
||||
|
||||
"""
|
||||
Reference:
|
||||
@@ -20,7 +19,6 @@ import pandas as pd
|
||||
|
||||
from api.db.services.task_service import has_canceled
|
||||
from common.exceptions import TaskCanceledException
|
||||
from common.connection_utils import timeout
|
||||
from rag.graphrag.general import leiden
|
||||
from rag.graphrag.general.community_report_prompt import COMMUNITY_REPORT_PROMPT
|
||||
from rag.graphrag.general.extractor import Extractor
|
||||
@@ -65,7 +63,6 @@ class CommunityReportsExtractor(Extractor):
|
||||
res_str = []
|
||||
res_dict = []
|
||||
over, token_count = 0, 0
|
||||
@timeout(120)
|
||||
async def extract_community_report(community):
|
||||
nonlocal res_str, res_dict, over, token_count
|
||||
if task_id:
|
||||
@@ -104,12 +101,12 @@ class CommunityReportsExtractor(Extractor):
|
||||
async with chat_limiter:
|
||||
try:
|
||||
timeout = 180 if enable_timeout_assertion else 1000000000
|
||||
response = await asyncio.wait_for(thread_pool_exec(self._chat,text,[{"role": "user", "content": "Output:"}],{},task_id),timeout=timeout)
|
||||
response = await asyncio.wait_for(self._async_chat(text, [{"role": "user", "content": "Output:"}], {}, task_id), timeout=timeout)
|
||||
except asyncio.TimeoutError:
|
||||
logging.warning("extract_community_report._chat timeout, skipping...")
|
||||
logging.warning("extract_community_report._async_chat timeout, skipping...")
|
||||
return
|
||||
except Exception as e:
|
||||
logging.error(f"extract_community_report._chat failed: {e}")
|
||||
logging.error(f"extract_community_report._async_chat failed: {e}")
|
||||
return
|
||||
token_count += num_tokens_from_string(text + response)
|
||||
response = re.sub(r"^[^\{]*", "", response)
|
||||
|
||||
@@ -24,7 +24,6 @@ from typing import Callable
|
||||
import networkx as nx
|
||||
|
||||
from api.db.services.task_service import has_canceled
|
||||
from common.connection_utils import timeout
|
||||
from common.token_utils import truncate
|
||||
from rag.graphrag.general.graph_prompt import SUMMARIZE_DESCRIPTIONS_PROMPT
|
||||
from rag.graphrag.utils import (
|
||||
@@ -74,11 +73,10 @@ class Extractor:
|
||||
def _is_truncated_cache(response):
|
||||
return len((response or "").strip()) <= 1
|
||||
|
||||
@timeout(60 * 20)
|
||||
def _chat(self, system, history, gen_conf={}, task_id=""):
|
||||
async def _async_chat(self, system, history, gen_conf={}, task_id=""):
|
||||
hist = deepcopy(history)
|
||||
conf = deepcopy(gen_conf)
|
||||
response = get_llm_cache(self._llm.llm_name, system, hist, conf)
|
||||
response = await thread_pool_exec(get_llm_cache, self._llm.llm_name, system, hist, conf)
|
||||
response = self._normalize_response_text(response)
|
||||
if self._is_truncated_cache(response):
|
||||
response = ""
|
||||
@@ -88,18 +86,24 @@ class Extractor:
|
||||
response = ""
|
||||
for attempt in range(3):
|
||||
if task_id:
|
||||
if has_canceled(task_id):
|
||||
if await thread_pool_exec(has_canceled, task_id):
|
||||
logging.info(f"Task {task_id} cancelled during entity resolution candidate processing.")
|
||||
raise TaskCanceledException(f"Task {task_id} was cancelled")
|
||||
try:
|
||||
response = asyncio.run(self._llm.async_chat(system_msg[0]["content"], hist, conf))
|
||||
response = await asyncio.wait_for(
|
||||
self._llm.async_chat(system_msg[0]["content"], hist, conf),
|
||||
timeout=60 * 20,
|
||||
)
|
||||
response = self._normalize_response_text(response)
|
||||
response = re.sub(r"^.*</think>", "", response, flags=re.DOTALL)
|
||||
if response.find("**ERROR**") >= 0:
|
||||
raise Exception(response)
|
||||
if not self._is_truncated_cache(response):
|
||||
set_llm_cache(self._llm.llm_name, system, response, history, gen_conf)
|
||||
await thread_pool_exec(set_llm_cache, self._llm.llm_name, system, response, history, gen_conf)
|
||||
break
|
||||
except asyncio.TimeoutError:
|
||||
logging.warning("_async_chat timed out after 20 minutes")
|
||||
raise # timeout is not a transient error; do not retry
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
if attempt == 2:
|
||||
@@ -357,5 +361,5 @@ class Extractor:
|
||||
raise TaskCanceledException(f"Task {task_id} was cancelled during summary handling")
|
||||
|
||||
async with chat_limiter:
|
||||
summary = await thread_pool_exec(self._chat, "", [{"role": "user", "content": use_prompt}], {}, task_id)
|
||||
summary = await self._async_chat("", [{"role": "user", "content": use_prompt}], {}, task_id)
|
||||
return summary
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
from common.misc_utils import thread_pool_exec
|
||||
|
||||
"""
|
||||
Reference:
|
||||
- [graphrag](https://github.com/microsoft/graphrag)
|
||||
@@ -109,7 +107,7 @@ class GraphExtractor(Extractor):
|
||||
}
|
||||
hint_prompt = perform_variable_replacements(self._extraction_prompt, variables=variables)
|
||||
async with chat_limiter:
|
||||
response = await thread_pool_exec(self._chat,hint_prompt,[{"role": "user", "content": "Output:"}],{},task_id)
|
||||
response = await self._async_chat(hint_prompt, [{"role": "user", "content": "Output:"}], {}, task_id)
|
||||
token_count += num_tokens_from_string(hint_prompt + response)
|
||||
|
||||
results = response or ""
|
||||
@@ -119,7 +117,7 @@ class GraphExtractor(Extractor):
|
||||
for i in range(self._max_gleanings):
|
||||
history.append({"role": "user", "content": CONTINUE_PROMPT})
|
||||
async with chat_limiter:
|
||||
response = await thread_pool_exec(self._chat, "", history, {})
|
||||
response = await self._async_chat("", history, {}, task_id)
|
||||
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response)
|
||||
results += response or ""
|
||||
|
||||
@@ -129,7 +127,7 @@ class GraphExtractor(Extractor):
|
||||
history.append({"role": "assistant", "content": response})
|
||||
history.append({"role": "user", "content": LOOP_PROMPT})
|
||||
async with chat_limiter:
|
||||
continuation = await thread_pool_exec(self._chat, "", history)
|
||||
continuation = await self._async_chat("", history, {}, task_id)
|
||||
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response)
|
||||
if continuation != "Y":
|
||||
break
|
||||
|
||||
@@ -29,7 +29,6 @@ import markdown_to_json
|
||||
from functools import reduce
|
||||
from common.token_utils import num_tokens_from_string
|
||||
|
||||
from common.misc_utils import thread_pool_exec
|
||||
|
||||
@dataclass
|
||||
class MindMapResult:
|
||||
@@ -186,7 +185,7 @@ class MindMapExtractor(Extractor):
|
||||
}
|
||||
text = perform_variable_replacements(self._mind_map_prompt, variables=variables)
|
||||
async with chat_limiter:
|
||||
response = await thread_pool_exec(self._chat,text,[{"role": "user", "content": "Output:"}],{})
|
||||
response = await self._async_chat(text, [{"role": "user", "content": "Output:"}], {})
|
||||
response = re.sub(r"```[^\n]*", "", response)
|
||||
logging.debug(response)
|
||||
logging.debug(self._todict(markdown_to_json.dictify(response)))
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
|
||||
from common.misc_utils import thread_pool_exec
|
||||
|
||||
"""
|
||||
Reference:
|
||||
- [graphrag](https://github.com/microsoft/graphrag)
|
||||
@@ -83,12 +81,12 @@ class GraphExtractor(Extractor):
|
||||
if self.callback:
|
||||
self.callback(msg=f"Start processing for {chunk_key}: {content[:25]}...")
|
||||
async with chat_limiter:
|
||||
final_result = await thread_pool_exec(self._chat,"",[{"role": "user", "content": hint_prompt}],gen_conf,task_id)
|
||||
final_result = await self._async_chat("", [{"role": "user", "content": hint_prompt}], gen_conf, task_id)
|
||||
token_count += num_tokens_from_string(hint_prompt + final_result)
|
||||
history = pack_user_ass_to_openai_messages(hint_prompt, final_result, self._continue_prompt)
|
||||
for now_glean_index in range(self._max_gleanings):
|
||||
async with chat_limiter:
|
||||
glean_result = await thread_pool_exec(self._chat,"",history,gen_conf,task_id)
|
||||
glean_result = await self._async_chat("", history, gen_conf, task_id)
|
||||
history.extend([{"role": "assistant", "content": glean_result}])
|
||||
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + hint_prompt + self._continue_prompt)
|
||||
final_result += glean_result
|
||||
@@ -97,7 +95,7 @@ class GraphExtractor(Extractor):
|
||||
|
||||
history.extend([{"role": "user", "content": self._if_loop_prompt}])
|
||||
async with chat_limiter:
|
||||
if_loop_result = await thread_pool_exec(self._chat,"",history,gen_conf,task_id)
|
||||
if_loop_result = await self._async_chat("", history, gen_conf, task_id)
|
||||
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + if_loop_result + self._if_loop_prompt)
|
||||
if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
|
||||
if if_loop_result != "yes":
|
||||
|
||||
Reference in New Issue
Block a user