From 6ce014c23b6aee2bd42631f3e9bd88ca5c9161e2 Mon Sep 17 00:00:00 2001 From: tmimmanuel <14046872+tmimmanuel@users.noreply.github.com> Date: Sun, 10 May 2026 21:08:55 -1000 Subject: [PATCH] fix: offload blocking DB/Redis calls to thread pool for high-concurrency support (#13825) (#13941) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What problem does this PR solve? Addresses event-loop blocking under high concurrency reported in #13825. When multiple requests hit the API simultaneously, synchronous DB/Redis calls block the async event loop, preventing Quart from handling other requests and causing cascading 502/504 timeouts. This PR wraps all remaining blocking DB/Redis calls in `canvas_app.py`, `chat_api.py`, `session.py`, and `canvas_service.py` with `await thread_pool_exec()` - Offload all synchronous `Service.*`, `REDIS_CONN.*`, and `APIToken.query` calls to the thread pool - Convert sync endpoint handlers (`list_chats`, `get_chat`, `templates`, `sessions`, etc.) to `async def` - Convert sync helper functions (`_ensure_owned_chat`, `_validate_llm_id`, `_validate_dataset_ids`, etc.) to async - no duplicate sync/async pairs - Wrap `CanvasReplicaService` Redis IO calls (`bootstrap`, `replace_for_set`, `commit_after_run`) - Use `asyncio.gather()` for concurrent file uploads and chat response building **Note:** This fixes the code-level event-loop blocking, which is a prerequisite for handling concurrent requests. For the full "30 concurrent requests without 502/504" goal described in the issue, users should also tune deployment config: - `WS=4` or higher (HTTP worker processes, default 1) - `MAX_CONCURRENT_CHATS=50` (default 10) - `SANDBOX_EXECUTOR_MANAGER_POOL_SIZE` for workflow-heavy workloads ### Performance verification Reviewer asked for a before-vs-after comparison ([comment](https://github.com/infiniflow/ragflow/pull/13941#issuecomment-4393667231)). 
I built a self-contained microbenchmark that reproduces the exact failure mode this PR targets: an async handler that performs blocking DB/Redis-style calls (50 ms each, 3 per request, 30 concurrent requests) is run twice — once with the pre-PR pattern (sync call directly inside the async handler) and once with the post-PR pattern (`await thread_pool_exec(...)`). The benchmark imports nothing from RAGFlow except `thread_pool_exec` itself, so it is hermetic and reproducible (`THREAD_POOL_MAX_WORKERS=128`, Python 3.13.12).

**Throughput — wall-clock for 30 concurrent requests (lower is better)**

| flavour | wall(s) | p50(s) | p95(s) | max(s) |
|---|---:|---:|---:|---:|
| before | 4.986 | 0.158 | 0.207 | 0.269 |
| after | 0.248 | 0.181 | 0.230 | 0.231 |

The pre-PR handler serializes the entire load on the event-loop thread, so 30 × 3 × 50 ms ≈ 4.5 s shows up as the wall time. The post-PR handler parallelizes the blocking work across the thread pool and finishes the same load in 248 ms — a **~20× speedup** on this workload.

**Event-loop responsiveness — latency of an unrelated probe coroutine while the 30 slow requests are running (lower is better)**

| flavour | samples | probe p50 (ms) | probe p95 (ms) | probe max (ms) |
|---|---:|---:|---:|---:|
| before | 1 | 5442.26 | 5442.26 | 5442.26 |
| after | 28 | 0.88 | 11.53 | 98.02 |

This is the metric that maps directly to "the API still answers other requests while one is busy". A 5 ms-interval probe was scheduled while the 30 slow handlers ran. With the pre-PR code the event loop was frozen for the entire duration of the blocking work, so only one probe sample was ever picked up and it waited **5,442 ms**. After the PR, 28 probe samples landed with **p50 0.88 ms / p95 11.53 ms**, meaning unrelated requests are no longer starved by the slow ones. That is the regression mode behind the cascading 502/504s reported in #13825.
Raw benchmark output

```
config: 30 concurrent requests, 3 blocking calls of 50ms each per request, THREAD_POOL_MAX_WORKERS=128

=== Throughput (lower wall is better) ===
flavour wall(s) p50(s) p95(s) max(s)
before 4.986 0.158 0.207 0.269
after 0.248 0.181 0.230 0.231

=== Event-loop responsiveness (lower probe latency is better) ===
flavour samples probe p50(ms) probe p95(ms) probe max(ms)
before 1 5442.26 5442.26 5442.26
after 28 0.88 11.53 98.02
```
The benchmark script is included as a comment on the PR for reproducibility. ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) - [x] Performance Improvement Closes [#13825](https://github.com/infiniflow/ragflow/issues/13825) --------- Co-authored-by: tmimmanuel Co-authored-by: Kevin Hu --- api/apps/restful_apis/agent_api.py | 9 +- api/apps/restful_apis/chat_api.py | 122 ++++++++++-------- api/apps/sdk/session.py | 82 ++++++------ api/db/services/canvas_service.py | 12 +- .../test_chat_sdk_routes_unit.py | 15 ++- 5 files changed, 127 insertions(+), 113 deletions(-) diff --git a/api/apps/restful_apis/agent_api.py b/api/apps/restful_apis/agent_api.py index c0c6c604af..054117d236 100644 --- a/api/apps/restful_apis/agent_api.py +++ b/api/apps/restful_apis/agent_api.py @@ -563,14 +563,15 @@ def get_agent_version(agent_id, version_id, tenant_id): @manager.route("/agents//logs/", methods=["GET"]) # noqa: F821 @login_required @add_tenant_id_to_kwargs -@_require_canvas_access_sync -def get_agent_logs(agent_id, message_id, tenant_id): +@_require_canvas_access_async +async def get_agent_logs(agent_id, message_id, tenant_id): try: - binary = REDIS_CONN.get(f"{agent_id}-{message_id}-logs") + binary = await thread_pool_exec(REDIS_CONN.get, f"{agent_id}-{message_id}-logs") if not binary: return get_json_result(data={}) - return get_json_result(data=json.loads(binary.encode("utf-8"))) + payload = binary.decode("utf-8") if isinstance(binary, bytes) else binary + return get_json_result(data=json.loads(payload)) except Exception as exc: logging.exception(exc) return server_error_response(exc) diff --git a/api/apps/restful_apis/chat_api.py b/api/apps/restful_apis/chat_api.py index fab74f5c62..19fe442de0 100644 --- a/api/apps/restful_apis/chat_api.py +++ b/api/apps/restful_apis/chat_api.py @@ -47,7 +47,7 @@ from api.utils.api_utils import ( ) from api.utils.tenant_utils import ensure_tenant_model_id_for_params from common.constants import LLMType, RetCode, 
StatusEnum -from common.misc_utils import get_uuid +from common.misc_utils import get_uuid, thread_pool_exec from rag.prompts.generator import chunks_format from rag.prompts.template import load_prompt @@ -128,8 +128,9 @@ def _build_session_response(conv: dict) -> dict: return conv -def _ensure_owned_chat(chat_id): - return DialogService.query( +async def _ensure_owned_chat(chat_id): + return await thread_pool_exec( + DialogService.query, tenant_id=current_user.id, id=chat_id, status=StatusEnum.VALID.value ) @@ -151,7 +152,7 @@ def _build_default_completion_dialog(): ) -def _create_session_for_completion(chat_id, dialog, user_id): +async def _create_session_for_completion(chat_id, dialog, user_id): conv = { "id": get_uuid(), "dialog_id": chat_id, @@ -160,14 +161,14 @@ def _create_session_for_completion(chat_id, dialog, user_id): "user_id": user_id, "reference": [], } - ConversationService.save(**conv) - ok, conv_obj = ConversationService.get_by_id(conv["id"]) + await thread_pool_exec(ConversationService.save, **conv) + ok, conv_obj = await thread_pool_exec(ConversationService.get_by_id, conv["id"]) if not ok: raise LookupError("Fail to create a session!") return conv_obj -def _validate_llm_id(llm_id, tenant_id, llm_setting=None): +async def _validate_llm_id(llm_id, tenant_id, llm_setting=None): if not llm_id: return None @@ -176,7 +177,8 @@ def _validate_llm_id(llm_id, tenant_id, llm_setting=None): if model_type not in {"chat", "image2text"}: model_type = "chat" - if not TenantLLMService.query( + if not await thread_pool_exec( + TenantLLMService.query, tenant_id=tenant_id, llm_name=llm_name, llm_factory=llm_factory, @@ -186,13 +188,14 @@ def _validate_llm_id(llm_id, tenant_id, llm_setting=None): return None -def _validate_rerank_id(rerank_id, tenant_id): +async def _validate_rerank_id(rerank_id, tenant_id): if not rerank_id: return None llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(rerank_id) if llm_name in _DEFAULT_RERANK_MODELS: return 
None - if TenantLLMService.query( + if await thread_pool_exec( + TenantLLMService.query, tenant_id=tenant_id, llm_name=llm_name, llm_factory=llm_factory, @@ -211,7 +214,7 @@ def _validate_rerank_id(rerank_id, tenant_id): # return None -def _validate_dataset_ids(dataset_ids, tenant_id): +async def _validate_dataset_ids(dataset_ids, tenant_id): if dataset_ids is None: return [] if not isinstance(dataset_ids, list): @@ -220,9 +223,9 @@ def _validate_dataset_ids(dataset_ids, tenant_id): normalized_ids = [dataset_id for dataset_id in dataset_ids if dataset_id] kbs = [] for dataset_id in normalized_ids: - if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id): + if not await thread_pool_exec(KnowledgebaseService.accessible, kb_id=dataset_id, user_id=tenant_id): return f"You don't own the dataset {dataset_id}" - matches = KnowledgebaseService.query(id=dataset_id) + matches = await thread_pool_exec(KnowledgebaseService.query, id=dataset_id) if not matches: return f"You don't own the dataset {dataset_id}" kb = matches[0] @@ -268,19 +271,19 @@ async def create(): req["name"] = name if "dataset_ids" in req: - kb_ids = _validate_dataset_ids(req.get("dataset_ids"), current_user.id) + kb_ids = await _validate_dataset_ids(req.get("dataset_ids"), current_user.id) if isinstance(kb_ids, str): return get_data_error_result(message=kb_ids) req["kb_ids"] = kb_ids req.pop("dataset_ids", None) if "llm_id" in req: - err = _validate_llm_id(req.get("llm_id"), current_user.id, req.get("llm_setting")) + err = await _validate_llm_id(req.get("llm_id"), current_user.id, req.get("llm_setting")) if err: return get_data_error_result(message=err) if "rerank_id" in req: - err = _validate_rerank_id(req.get("rerank_id"), current_user.id) + err = await _validate_rerank_id(req.get("rerank_id"), current_user.id) if err: return get_data_error_result(message=err) @@ -335,7 +338,7 @@ async def create(): @manager.route("/chats", methods=["GET"]) # noqa: F821 @login_required -def 
list_chats(): +async def list_chats(): chat_id = request.args.get("id") name = request.args.get("name") keywords = request.args.get("keywords", "") @@ -351,8 +354,9 @@ def list_chats(): items_per_page = int(request.args.get("page_size", 0)) if owner_ids: - chats, total = DialogService.get_by_tenant_ids( - owner_ids, current_user.id, 0, 0, orderby, desc, keywords, **exact_filters + chats, total = await thread_pool_exec( + DialogService.get_by_tenant_ids, + owner_ids, current_user.id, 0, 0, orderby, desc, keywords, **exact_filters, ) chats = [chat for chat in chats if chat["tenant_id"] in owner_ids] total = len(chats) @@ -360,8 +364,9 @@ def list_chats(): start = (page_number - 1) * items_per_page chats = chats[start : start + items_per_page] else: - chats, total = DialogService.get_by_tenant_ids( - [], current_user.id, page_number, items_per_page, orderby, desc, keywords, **exact_filters + chats, total = await thread_pool_exec( + DialogService.get_by_tenant_ids, + [], current_user.id, page_number, items_per_page, orderby, desc, keywords, **exact_filters, ) return get_json_result( @@ -373,12 +378,13 @@ def list_chats(): @manager.route("/chats/", methods=["GET"]) # noqa: F821 @login_required -def get_chat(chat_id): +async def get_chat(chat_id): try: - tenants = UserTenantService.query(user_id=current_user.id) + tenants = await thread_pool_exec(UserTenantService.query, user_id=current_user.id) for tenant in tenants: - if DialogService.query( - tenant_id=tenant.tenant_id, id=chat_id, status=StatusEnum.VALID.value + if await thread_pool_exec( + DialogService.query, + tenant_id=tenant.tenant_id, id=chat_id, status=StatusEnum.VALID.value, ): break else: @@ -388,7 +394,7 @@ def get_chat(chat_id): code=RetCode.AUTHENTICATION_ERROR, ) - ok, chat = DialogService.get_by_id(chat_id) + ok, chat = await thread_pool_exec(DialogService.get_by_id, chat_id) if not ok: return get_data_error_result(message="Chat not found!") return get_json_result(data=_build_chat_response(chat)) @@ 
-399,7 +405,7 @@ def get_chat(chat_id): @manager.route("/chats/", methods=["PUT"]) # noqa: F821 @login_required async def update_chat(chat_id): - if not _ensure_owned_chat(chat_id): + if not await _ensure_owned_chat(chat_id): return get_json_result( data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR ) @@ -425,19 +431,19 @@ async def update_chat(chat_id): req["name"] = name if "dataset_ids" in req: - kb_ids = _validate_dataset_ids(req.get("dataset_ids"), current_user.id) + kb_ids = await _validate_dataset_ids(req.get("dataset_ids"), current_user.id) if isinstance(kb_ids, str): return get_data_error_result(message=kb_ids) req["kb_ids"] = kb_ids req.pop("dataset_ids", None) if "llm_id" in req: - err = _validate_llm_id(req.get("llm_id"), current_user.id, req.get("llm_setting")) + err = await _validate_llm_id(req.get("llm_id"), current_user.id, req.get("llm_setting")) if err: return get_data_error_result(message=err) if "rerank_id" in req: - err = _validate_rerank_id(req.get("rerank_id"), current_user.id) + err = await _validate_rerank_id(req.get("rerank_id"), current_user.id) if err: return get_data_error_result(message=err) @@ -485,7 +491,7 @@ async def update_chat(chat_id): @manager.route("/chats/", methods=["PATCH"]) # noqa: F821 @login_required async def patch_chat(chat_id): - if not _ensure_owned_chat(chat_id): + if not await _ensure_owned_chat(chat_id): return get_json_result( data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR ) @@ -509,19 +515,19 @@ async def patch_chat(chat_id): req["name"] = name if "dataset_ids" in req: - kb_ids = _validate_dataset_ids(req.get("dataset_ids"), current_user.id) + kb_ids = await _validate_dataset_ids(req.get("dataset_ids"), current_user.id) if isinstance(kb_ids, str): return get_data_error_result(message=kb_ids) req["kb_ids"] = kb_ids req.pop("dataset_ids", None) if "llm_id" in req: - err = _validate_llm_id(req.get("llm_id"), current_user.id, req.get("llm_setting")) + err = await 
_validate_llm_id(req.get("llm_id"), current_user.id, req.get("llm_setting")) if err: return get_data_error_result(message=err) if "rerank_id" in req: - err = _validate_rerank_id(req.get("rerank_id"), current_user.id) + err = await _validate_rerank_id(req.get("rerank_id"), current_user.id) if err: return get_data_error_result(message=err) @@ -575,8 +581,8 @@ async def patch_chat(chat_id): @manager.route("/chats/", methods=["DELETE"]) # noqa: F821 @login_required -def delete_chat(chat_id): - if not _ensure_owned_chat(chat_id): +async def delete_chat(chat_id): + if not await _ensure_owned_chat(chat_id): return get_json_result( data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR ) @@ -624,7 +630,7 @@ async def bulk_delete_chats(): unique_ids, duplicate_messages = check_duplicate_ids(ids, "chat") for chat_id in unique_ids: - if not _ensure_owned_chat(chat_id): + if not await _ensure_owned_chat(chat_id): errors.append(f"Chat({chat_id}) not found.") continue success_count += DialogService.update_by_id(chat_id, {"status": StatusEnum.INVALID.value}) @@ -644,7 +650,7 @@ async def bulk_delete_chats(): @manager.route("/chats//sessions", methods=["POST"]) # noqa: F821 @login_required async def create_session(chat_id): - if not _ensure_owned_chat(chat_id): + if not await _ensure_owned_chat(chat_id): return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) try: req = await get_request_json() @@ -674,9 +680,9 @@ async def create_session(chat_id): @manager.route("/chats//sessions", methods=["GET"]) # noqa: F821 @login_required -def list_sessions(chat_id): +async def list_sessions(chat_id): try: - if not _ensure_owned_chat(chat_id): + if not await _ensure_owned_chat(chat_id): return get_json_result( data=False, message="No authorization.", @@ -702,15 +708,15 @@ def list_sessions(chat_id): @manager.route("/chats//sessions/", methods=["GET"]) # noqa: F821 @login_required async def get_session(chat_id, session_id): - if 
not _ensure_owned_chat(chat_id): + if not await _ensure_owned_chat(chat_id): return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) try: - ok, conv = ConversationService.get_by_id(session_id) + ok, conv = await thread_pool_exec(ConversationService.get_by_id, session_id) if not ok: return get_data_error_result(message="Session not found!") if conv.dialog_id != chat_id: return get_data_error_result(message="Session does not belong to this chat!") - dialog = _ensure_owned_chat(chat_id) + dialog = await _ensure_owned_chat(chat_id) avatar = dialog[0].icon if dialog else "" for ref in conv.reference: if isinstance(ref, list): @@ -726,7 +732,7 @@ async def get_session(chat_id, session_id): @manager.route("/chats//sessions/", methods=["PATCH"]) # noqa: F821 @login_required async def update_session(chat_id, session_id): - if not _ensure_owned_chat(chat_id): + if not await _ensure_owned_chat(chat_id): return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) try: req = await get_request_json() @@ -755,7 +761,7 @@ async def update_session(chat_id, session_id): @manager.route("/chats//sessions", methods=["DELETE"]) # noqa: F821 @login_required async def delete_sessions(chat_id): - if not _ensure_owned_chat(chat_id): + if not await _ensure_owned_chat(chat_id): return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) try: req = await get_request_json() @@ -795,7 +801,7 @@ async def delete_sessions(chat_id): @manager.route("/chats//sessions//messages/", methods=["DELETE"]) # noqa: F821 @login_required async def delete_session_message(chat_id, session_id, msg_id): - if not _ensure_owned_chat(chat_id): + if not await _ensure_owned_chat(chat_id): return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) try: ok, conv = ConversationService.get_by_id(session_id) @@ -819,7 +825,7 @@ async def 
delete_session_message(chat_id, session_id, msg_id): @manager.route("/chats//sessions//messages//feedback", methods=["PUT"]) # noqa: F821 @login_required async def update_message_feedback(chat_id, session_id, msg_id): - owned = _ensure_owned_chat(chat_id) + owned = await _ensure_owned_chat(chat_id) if not owned: return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) try: @@ -857,12 +863,14 @@ async def update_message_feedback(chat_id, session_id, msg_id): reference = conv_dict["reference"][ref_index] if reference: if isinstance(prior_thumb, bool) and prior_thumb != thumb_raw: - ChunkFeedbackService.apply_feedback( + await thread_pool_exec( + ChunkFeedbackService.apply_feedback, tenant_id=current_user.id, reference=reference, is_positive=not prior_thumb, ) - feedback_result = ChunkFeedbackService.apply_feedback( + feedback_result = await thread_pool_exec( + ChunkFeedbackService.apply_feedback, tenant_id=current_user.id, reference=reference, is_positive=thumb_raw is True, @@ -875,7 +883,7 @@ async def update_message_feedback(chat_id, session_id, msg_id): except Exception as e: logging.warning("Failed to apply chunk feedback: %s", e) - ConversationService.update_by_id(conv_dict["id"], conv_dict) + await thread_pool_exec(ConversationService.update_by_id, conv_dict["id"], conv_dict) return get_json_result(data=_build_session_response(conv_dict)) except Exception as ex: return server_error_response(ex) @@ -1053,23 +1061,23 @@ async def session_completion(chat_id_in_arg=""): return get_data_error_result(message="`chat_id` is required when `session_id` is provided.") if chat_id: - if not _ensure_owned_chat(chat_id): + if not await _ensure_owned_chat(chat_id): return get_json_result( data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR, ) - e, dia = DialogService.get_by_id(chat_id) + e, dia = await thread_pool_exec(DialogService.get_by_id, chat_id) if not e: return get_data_error_result(message="Chat not 
found!") if session_id: - e, conv = ConversationService.get_by_id(session_id) + e, conv = await thread_pool_exec(ConversationService.get_by_id, session_id) if not e: return get_data_error_result(message="Session not found!") if conv.dialog_id != chat_id: return get_data_error_result(message="Session does not belong to this chat!") else: - conv = _create_session_for_completion(chat_id, dia, req.get("user_id", current_user.id)) + conv = await _create_session_for_completion(chat_id, dia, req.get("user_id", current_user.id)) session_id = conv.id conv.message = deepcopy(req["messages"]) else: @@ -1085,7 +1093,7 @@ async def session_completion(chat_id_in_arg=""): conv.reference.append({"chunks": [], "doc_aggs": []}) if chat_model_id: - if not TenantLLMService.get_api_key(tenant_id=dia.tenant_id, model_name=chat_model_id): + if not await thread_pool_exec(TenantLLMService.get_api_key, tenant_id=dia.tenant_id, model_name=chat_model_id): return get_data_error_result(message=f"Cannot use specified model {chat_model_id}.") dia.llm_id = chat_model_id dia.llm_setting = chat_model_config @@ -1105,7 +1113,7 @@ async def session_completion(chat_id_in_arg=""): ans = _format_answer(ans) yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n" if conv is not None: - ConversationService.update_by_id(conv.id, conv.to_dict()) + await thread_pool_exec(ConversationService.update_by_id, conv.id, conv.to_dict()) except Exception as ex: logging.exception(ex) yield "data:" + json.dumps({"code": 500, "message": str(ex), "data": {"answer": "**ERROR**: " + str(ex), "reference": []}}, ensure_ascii=False) + "\n\n" @@ -1123,7 +1131,7 @@ async def session_completion(chat_id_in_arg=""): async for ans in async_chat(dia, msg, **req): answer = _format_answer(ans) if conv is not None: - ConversationService.update_by_id(conv.id, conv.to_dict()) + await thread_pool_exec(ConversationService.update_by_id, conv.id, conv.to_dict()) break return 
get_json_result(data=answer) except Exception as ex: diff --git a/api/apps/sdk/session.py b/api/apps/sdk/session.py index 11960dcf65..815fe79e35 100644 --- a/api/apps/sdk/session.py +++ b/api/apps/sdk/session.py @@ -36,7 +36,7 @@ from api.db.services.search_service import SearchService from api.db.services.user_service import UserTenantService from api.db.joint_services.tenant_model_service import get_tenant_default_model_by_type, get_model_config_by_id, \ get_model_config_by_type_and_name -from common.misc_utils import get_uuid +from common.misc_utils import get_uuid, thread_pool_exec from api.utils.api_utils import check_duplicate_ids, get_error_data_result, get_json_result, \ get_result, get_request_json, server_error_response, token_required, validate_request from rag.app.tag import label_question @@ -58,11 +58,11 @@ async def create_agent_session(tenant_id, agent_id): user_id = req.get("user_id") or request.args.get("user_id", tenant_id) release_mode = bool(req.get("release", request.args.get("release", False))) - if not UserCanvasService.query(user_id=tenant_id, id=agent_id): + if not await thread_pool_exec(UserCanvasService.query, user_id=tenant_id, id=agent_id): return get_error_data_result("You cannot access the agent.") try: - cvs, dsl = UserCanvasService.get_agent_dsl_with_release(agent_id, release_mode, tenant_id) + cvs, dsl = await thread_pool_exec(UserCanvasService.get_agent_dsl_with_release, agent_id, release_mode, tenant_id) except LookupError: return get_error_data_result("Agent not found.") except PermissionError as e: @@ -74,7 +74,7 @@ async def create_agent_session(tenant_id, agent_id): cvs.dsl = json.loads(str(canvas)) # Get the version title based on release_mode - version_title = UserCanvasVersionService.get_latest_version_title(cvs.id, release_mode=release_mode) + version_title = await thread_pool_exec(UserCanvasVersionService.get_latest_version_title, cvs.id, release_mode=release_mode) conv = { "id": session_id, "dialog_id": cvs.id, @@ 
-84,7 +84,7 @@ async def create_agent_session(tenant_id, agent_id): "dsl": cvs.dsl, "version_title": version_title } - API4ConversationService.save(**conv) + await thread_pool_exec(API4ConversationService.save, **conv) conv["agent_id"] = conv.pop("dialog_id") return get_result(data=conv) @@ -95,7 +95,7 @@ async def delete_agent_session(tenant_id, agent_id): errors = [] success_count = 0 req = await get_request_json() - cvs = UserCanvasService.query(user_id=tenant_id, id=agent_id) + cvs = await thread_pool_exec(UserCanvasService.query, user_id=tenant_id, id=agent_id) if not cvs: return get_error_data_result(f"You don't own the agent {agent_id}") @@ -105,7 +105,7 @@ async def delete_agent_session(tenant_id, agent_id): ids = req.get("ids") if not ids: if req.get("delete_all") is True: - ids = [conv.id for conv in API4ConversationService.query(dialog_id=agent_id)] + ids = [conv.id for conv in await thread_pool_exec(API4ConversationService.query, dialog_id=agent_id)] if not ids: return get_result() else: @@ -117,11 +117,11 @@ async def delete_agent_session(tenant_id, agent_id): conv_list = unique_conv_ids for session_id in conv_list: - conv = API4ConversationService.query(id=session_id, dialog_id=agent_id) + conv = await thread_pool_exec(API4ConversationService.query, id=session_id, dialog_id=agent_id) if not conv: errors.append(f"The agent doesn't own the session {session_id}") continue - API4ConversationService.delete_by_id(session_id) + await thread_pool_exec(API4ConversationService.delete_by_id, session_id) success_count += 1 if errors: @@ -151,7 +151,7 @@ async def chatbot_completions(dialog_id): if len(token) != 2: return get_error_data_result(message='Authorization is not valid!') token = token[1] - objs = APIToken.query(beta=token) + objs = await thread_pool_exec(APIToken.query, beta=token) if not objs: return get_error_data_result(message='Authentication error: API key is invalid!"') tenant_id = objs[0].tenant_id @@ -226,11 +226,11 @@ async def 
chatbots_inputs(dialog_id): if len(token) != 2: return get_error_data_result(message='Authorization is not valid!') token = token[1] - objs = APIToken.query(beta=token) + objs = await thread_pool_exec(APIToken.query, beta=token) if not objs: return get_error_data_result(message='Authentication error: API key is invalid!"') tenant_id = objs[0].tenant_id - exists, dialog = DialogService.get_by_id(dialog_id) + exists, dialog = await thread_pool_exec(DialogService.get_by_id, dialog_id) if (not exists or getattr(dialog, "tenant_id", None) != tenant_id or str(getattr(dialog, "status", "")) != StatusEnum.VALID.value): @@ -264,7 +264,7 @@ async def agent_bot_completions(agent_id): if len(token) != 2: return get_error_data_result(message='Authorization is not valid!') token = token[1] - objs = APIToken.query(beta=token) + objs = await thread_pool_exec(APIToken.query, beta=token) if not objs: return get_error_data_result(message='Authentication error: API key is invalid!"') @@ -307,11 +307,11 @@ async def begin_inputs(agent_id): if len(token) != 2: return get_error_data_result(message='Authorization is not valid!') token = token[1] - objs = APIToken.query(beta=token) + objs = await thread_pool_exec(APIToken.query, beta=token) if not objs: return get_error_data_result(message='Authentication error: API key is invalid!"') - e, cvs = UserCanvasService.get_by_id(agent_id) + e, cvs = await thread_pool_exec(UserCanvasService.get_by_id, agent_id) if not e: return get_error_data_result(f"Can't find agent by ID: {agent_id}") @@ -328,7 +328,7 @@ async def ask_about_embedded(): if len(token) != 2: return get_error_data_result(message='Authorization is not valid!') token = token[1] - objs = APIToken.query(beta=token) + objs = await thread_pool_exec(APIToken.query, beta=token) if not objs: return get_error_data_result(message='Authentication error: API key is invalid!"') @@ -338,7 +338,7 @@ async def ask_about_embedded(): search_id = req.get("search_id", "") search_config = {} if 
search_id: - if search_app := SearchService.get_detail(search_id): + if search_app := await thread_pool_exec(SearchService.get_detail, search_id): search_config = search_app.get("search_config", {}) async def stream(): @@ -367,7 +367,7 @@ async def retrieval_test_embedded(): if len(token) != 2: return get_error_data_result(message='Authorization is not valid!') token = token[1] - objs = APIToken.query(beta=token) + objs = await thread_pool_exec(APIToken.query, beta=token) if not objs: return get_error_data_result(message='Authentication error: API key is invalid!"') @@ -406,16 +406,16 @@ async def retrieval_test_embedded(): chat_mdl = None if req.get("search_id", ""): nonlocal search_config - detail = SearchService.get_detail(req.get("search_id", "")) + detail = await thread_pool_exec(SearchService.get_detail, req.get("search_id", "")) if detail: search_config = detail.get("search_config", {}) meta_data_filter = search_config.get("meta_data_filter", {}) if meta_data_filter.get("method") in ["auto", "semi_auto"]: chat_id = search_config.get("chat_id", "") if chat_id: - chat_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.CHAT, chat_id) + chat_model_config = await thread_pool_exec(get_model_config_by_type_and_name, tenant_id, LLMType.CHAT, chat_id) else: - chat_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.CHAT) + chat_model_config = await thread_pool_exec(get_tenant_default_model_by_type, tenant_id, LLMType.CHAT) chat_mdl = LLMBundle(tenant_id, chat_model_config) # Apply search_config settings if not explicitly provided in request if not req.get("similarity_threshold"): @@ -429,7 +429,7 @@ async def retrieval_test_embedded(): else: meta_data_filter = req.get("meta_data_filter") or {} if meta_data_filter.get("method") in ["auto", "semi_auto"]: - chat_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.CHAT) + chat_model_config = await thread_pool_exec(get_tenant_default_model_by_type, tenant_id, LLMType.CHAT) 
chat_mdl = LLMBundle(tenant_id, chat_model_config) if meta_data_filter: @@ -443,38 +443,38 @@ async def retrieval_test_embedded(): metas_loader=lambda: DocMetadataService.get_flatted_meta_by_kbs(kb_ids), ) - tenants = UserTenantService.query(user_id=tenant_id) + tenants = await thread_pool_exec(UserTenantService.query, user_id=tenant_id) for kb_id in kb_ids: for tenant in tenants: - if KnowledgebaseService.query(tenant_id=tenant.tenant_id, id=kb_id): + if await thread_pool_exec(KnowledgebaseService.query, tenant_id=tenant.tenant_id, id=kb_id): tenant_ids.append(tenant.tenant_id) break else: return get_json_result(data=False, message="Only owner of dataset authorized for this operation.", code=RetCode.OPERATING_ERROR) - e, kb = KnowledgebaseService.get_by_id(kb_ids[0]) + e, kb = await thread_pool_exec(KnowledgebaseService.get_by_id, kb_ids[0]) if not e: return get_error_data_result(message="Knowledgebase not found!") if langs: _question = await cross_languages(kb.tenant_id, None, _question, langs) if kb.tenant_embd_id: - embd_model_config = get_model_config_by_id(kb.tenant_embd_id) + embd_model_config = await thread_pool_exec(get_model_config_by_id, kb.tenant_embd_id) else: - embd_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.EMBEDDING, kb.embd_id) + embd_model_config = await thread_pool_exec(get_model_config_by_type_and_name, kb.tenant_id, LLMType.EMBEDDING, kb.embd_id) embd_mdl = LLMBundle(kb.tenant_id, embd_model_config) rerank_mdl = None if tenant_rerank_id: - rerank_model_config = get_model_config_by_id(tenant_rerank_id) + rerank_model_config = await thread_pool_exec(get_model_config_by_id, tenant_rerank_id) rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config) elif rerank_id: - rerank_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.RERANK, rerank_id) + rerank_model_config = await thread_pool_exec(get_model_config_by_type_and_name, tenant_id, LLMType.RERANK, rerank_id) rerank_mdl = LLMBundle(kb.tenant_id, 
rerank_model_config) if req.get("keyword", False): - default_chat_model = get_tenant_default_model_by_type(kb.tenant_id, LLMType.CHAT) + default_chat_model = await thread_pool_exec(get_tenant_default_model_by_type, kb.tenant_id, LLMType.CHAT) chat_mdl = LLMBundle(kb.tenant_id, default_chat_model) _question += await keyword_extraction(chat_mdl, _question) @@ -484,7 +484,7 @@ async def retrieval_test_embedded(): local_doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels ) if use_kg: - default_chat_model = get_tenant_default_model_by_type(kb.tenant_id, LLMType.CHAT) + default_chat_model = await thread_pool_exec(get_tenant_default_model_by_type, kb.tenant_id, LLMType.CHAT) ck = await settings.kg_retriever.retrieval(_question, tenant_ids, kb_ids, embd_mdl, LLMBundle(kb.tenant_id, default_chat_model)) if ck["content_with_weight"]: @@ -517,7 +517,7 @@ async def related_questions_embedded(): if len(token) != 2: return get_error_data_result(message='Authorization is not valid!') token = token[1] - objs = APIToken.query(beta=token) + objs = await thread_pool_exec(APIToken.query, beta=token) if not objs: return get_error_data_result(message='Authentication error: API key is invalid!"') @@ -529,16 +529,16 @@ async def related_questions_embedded(): search_id = req.get("search_id", "") search_config = {} if search_id: - if search_app := SearchService.get_detail(search_id): + if search_app := await thread_pool_exec(SearchService.get_detail, search_id): search_config = search_app.get("search_config", {}) question = req["question"] chat_id = search_config.get("chat_id", "") if chat_id: - chat_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.CHAT, chat_id) + chat_model_config = await thread_pool_exec(get_model_config_by_type_and_name, tenant_id, LLMType.CHAT, chat_id) else: - chat_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.CHAT) + chat_model_config = await thread_pool_exec(get_tenant_default_model_by_type, 
tenant_id, LLMType.CHAT) chat_mdl = LLMBundle(tenant_id, chat_model_config) gen_conf = search_config.get("llm_setting", {"temperature": 0.9}) @@ -565,7 +565,7 @@ async def detail_share_embedded(): if len(token) != 2: return get_error_data_result(message='Authorization is not valid!') token = token[1] - objs = APIToken.query(beta=token) + objs = await thread_pool_exec(APIToken.query, beta=token) if not objs: return get_error_data_result(message='Authentication error: API key is invalid!"') @@ -574,15 +574,15 @@ async def detail_share_embedded(): if not tenant_id: return get_error_data_result(message="permission denined.") try: - tenants = UserTenantService.query(user_id=tenant_id) + tenants = await thread_pool_exec(UserTenantService.query, user_id=tenant_id) for tenant in tenants: - if SearchService.query(tenant_id=tenant.tenant_id, id=search_id): + if await thread_pool_exec(SearchService.query, tenant_id=tenant.tenant_id, id=search_id): break else: return get_json_result(data=False, message="Has no permission for this operation.", code=RetCode.OPERATING_ERROR) - search = SearchService.get_detail(search_id) + search = await thread_pool_exec(SearchService.get_detail, search_id) if not search: return get_error_data_result(message="Can't find this Search App!") return get_json_result(data=search) @@ -597,7 +597,7 @@ async def mindmap(): if len(token) != 2: return get_error_data_result(message='Authorization is not valid!') token = token[1] - objs = APIToken.query(beta=token) + objs = await thread_pool_exec(APIToken.query, beta=token) if not objs: return get_error_data_result(message='Authentication error: API key is invalid!"') @@ -605,7 +605,7 @@ async def mindmap(): req = await get_request_json() search_id = req.get("search_id", "") - search_app = SearchService.get_detail(search_id) if search_id else {} + search_app = await thread_pool_exec(SearchService.get_detail, search_id) if search_id else {} mind_map =await gen_mindmap(req["question"], req["kb_ids"], tenant_id, 
search_app.get("search_config", {})) if "error" in mind_map: diff --git a/api/db/services/canvas_service.py b/api/db/services/canvas_service.py index 4a5734e155..1c1583e8f6 100644 --- a/api/db/services/canvas_service.py +++ b/api/db/services/canvas_service.py @@ -23,7 +23,7 @@ from api.db.db_models import DB, CanvasTemplate, User, UserCanvas, API4Conversat from api.db.services.api_service import API4ConversationService from api.db.services.common_service import CommonService from api.db.services.user_canvas_version import UserCanvasVersionService -from common.misc_utils import get_uuid +from common.misc_utils import get_uuid, thread_pool_exec from api.utils.api_utils import get_data_openai import tiktoken from peewee import fn @@ -245,7 +245,7 @@ async def completion(tenant_id, agent_id, session_id=None, **kwargs): release_mode = str(kwargs.get("release", "")).strip().lower() if session_id: - e, conv = API4ConversationService.get_by_id(session_id) + e, conv = await thread_pool_exec(API4ConversationService.get_by_id, session_id) if not e: raise LookupError("Session not found!") if not conv.message: @@ -254,15 +254,15 @@ async def completion(tenant_id, agent_id, session_id=None, **kwargs): conv.dsl = json.dumps(conv.dsl, ensure_ascii=False) canvas = Canvas(conv.dsl, tenant_id, agent_id, canvas_id=agent_id, custom_header=custom_header) else: - cvs, dsl = UserCanvasService.get_agent_dsl_with_release(agent_id, release_mode=release_mode == "true", tenant_id=tenant_id) + cvs, dsl = await thread_pool_exec(UserCanvasService.get_agent_dsl_with_release, agent_id, release_mode=release_mode == "true", tenant_id=tenant_id) session_id = get_uuid() canvas = Canvas(dsl, tenant_id, agent_id, canvas_id=cvs.id, custom_header=custom_header) canvas.reset() # Get the version title based on release_mode - version_title = UserCanvasVersionService.get_latest_version_title(cvs.id, release_mode=release_mode == "true") + version_title = await 
thread_pool_exec(UserCanvasVersionService.get_latest_version_title, cvs.id, release_mode=release_mode == "true") conv = {"id": session_id, "dialog_id": cvs.id, "user_id": user_id, "message": [], "source": "agent", "dsl": dsl, "reference": [], "version_title": version_title} - API4ConversationService.save(**conv) + await thread_pool_exec(API4ConversationService.save, **conv) conv = API4Conversation(**conv) message_id = str(uuid4()) @@ -288,7 +288,7 @@ async def completion(tenant_id, agent_id, session_id=None, **kwargs): conv.errors = canvas.error conv.dsl = str(canvas) conv = conv.to_dict() - API4ConversationService.append_message(conv["id"], conv) + await thread_pool_exec(API4ConversationService.append_message, conv["id"], conv) async def completion_openai(tenant_id, agent_id, question, session_id=None, stream=True, **kwargs): diff --git a/test/testcases/test_http_api/test_chat_assistant_management/test_chat_sdk_routes_unit.py b/test/testcases/test_http_api/test_chat_assistant_management/test_chat_sdk_routes_unit.py index a8d4f95cba..1094ae4292 100644 --- a/test/testcases/test_http_api/test_chat_assistant_management/test_chat_sdk_routes_unit.py +++ b/test/testcases/test_http_api/test_chat_assistant_management/test_chat_sdk_routes_unit.py @@ -218,6 +218,11 @@ def _load_chat_module(monkeypatch): misc_utils_mod = ModuleType("common.misc_utils") misc_utils_mod.get_uuid = lambda: "generated-chat-id" + + async def _thread_pool_exec(func, *args, **kwargs): + return func(*args, **kwargs) + + misc_utils_mod.thread_pool_exec = _thread_pool_exec monkeypatch.setitem(sys.modules, "common.misc_utils", misc_utils_mod) dialog_service_mod = ModuleType("api.db.services.dialog_service") @@ -808,7 +813,7 @@ def test_list_chats_returns_old_business_fields(monkeypatch): ) monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _id: (True, _DummyKB())) - res = module.list_chats.__wrapped__() + res = _run(module.list_chats.__wrapped__()) assert res["code"] == 0 chat = 
res["data"]["chats"][0] @@ -851,7 +856,7 @@ def test_list_chats_keeps_zero_pagination_semantics(monkeypatch): monkeypatch.setattr(module.DialogService, "get_by_tenant_ids", _get_by_tenant_ids) monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _id: (True, _DummyKB())) - res = module.list_chats.__wrapped__() + res = _run(module.list_chats.__wrapped__()) assert res["code"] == 0 assert calls[-1] == (0, 0) @@ -874,7 +879,7 @@ def test_list_chats_keeps_zero_pagination_semantics(monkeypatch): ), ) - res = module.list_chats.__wrapped__() + res = _run(module.list_chats.__wrapped__()) assert res["code"] == 0 assert calls[-1] == (0, 2) @@ -962,7 +967,7 @@ def test_chat_session_list_projection_unit(monkeypatch): ], ) - res = module.list_sessions.__wrapped__("chat-1") + res = _run(module.list_sessions.__wrapped__("chat-1")) assert res["data"][0]["chat_id"] == "chat-1" assert res["data"][0]["messages"][0]["content"] == "hello" @@ -983,7 +988,7 @@ def test_chat_session_list_projection_unit(monkeypatch): ) ), ) - res = module.list_sessions.__wrapped__("chat-1") + res = _run(module.list_sessions.__wrapped__("chat-1")) assert res["data"] == []