fix: offload blocking DB/Redis calls to thread pool for high-concurrency support (#13825) (#13941)

### What problem does this PR solve?

Addresses event-loop blocking under high concurrency reported in #13825.
When multiple requests hit the API simultaneously, synchronous DB/Redis
calls block the async event loop, preventing Quart from handling other
requests and causing cascading 502/504 timeouts.

This PR wraps all remaining blocking DB/Redis calls in `canvas_app.py`,
`chat_api.py`, `session.py`, and `canvas_service.py` with `await
thread_pool_exec()`
- Offload all synchronous `Service.*`, `REDIS_CONN.*`, and
`APIToken.query` calls to the thread pool
- Convert sync endpoint handlers (`list_chats`, `get_chat`, `templates`,
`sessions`, etc.) to `async def`
- Convert sync helper functions (`_ensure_owned_chat`,
`_validate_llm_id`, `_validate_dataset_ids`, etc.) to async - no
duplicate sync/async pairs
- Wrap `CanvasReplicaService` Redis IO calls (`bootstrap`,
`replace_for_set`, `commit_after_run`)
- Use `asyncio.gather()` for concurrent file uploads and chat response
building

**Note:** This fixes the code-level event-loop blocking, which is a
prerequisite for handling concurrent requests. For the full "30
concurrent requests without 502/504" goal described in the issue, users
should also tune deployment config:
- `WS=4` or higher (HTTP worker processes, default 1)
- `MAX_CONCURRENT_CHATS=50` (default 10)
- `SANDBOX_EXECUTOR_MANAGER_POOL_SIZE` for workflow-heavy workloads

### Performance verification

Reviewer asked for a before-vs-after comparison
([comment](https://github.com/infiniflow/ragflow/pull/13941#issuecomment-4393667231)).
I built a self-contained microbenchmark that reproduces the exact
failure mode this PR targets: an async handler that performs blocking
DB/Redis-style calls (50 ms each, 3 per request, 30 concurrent requests)
is run twice — once with the pre-PR pattern (sync call directly inside
the async handler) and once with the post-PR pattern (`await
thread_pool_exec(...)`). The benchmark imports nothing from RAGFlow
except `thread_pool_exec` itself, so it is hermetic and reproducible
(`THREAD_POOL_MAX_WORKERS=128`, Python 3.13.12).

**Throughput — wall-clock for 30 concurrent requests (lower is better)**

| flavour | wall(s) | p50(s) | p95(s) | max(s) |
|---|---:|---:|---:|---:|
| before | 4.986 | 0.158 | 0.207 | 0.269 |
| after  | 0.248 | 0.181 | 0.230 | 0.231 |

The pre-PR handler serializes the entire load on the event-loop thread,
so 30 × 3 × 50 ms ≈ 4.5 s shows up as the wall time. The post-PR handler
parallelizes the blocking work across the thread pool and finishes the
same load in 248 ms — a **~20× speedup** on this workload.

**Event-loop responsiveness — latency of an unrelated probe coroutine
while the 30 slow requests are running (lower is better)**

| flavour | samples | probe p50 (ms) | probe p95 (ms) | probe max (ms) |
|---|---:|---:|---:|---:|
| before | 1 | 5442.26 | 5442.26 | 5442.26 |
| after  | 28 | 0.88 | 11.53 | 98.02 |

This is the metric that maps directly to "the API still answers other
requests while one is busy". A 5 ms-interval probe was scheduled while
the 30 slow handlers ran. With the pre-PR code the event loop was frozen
for the entire duration of the blocking work, so only one probe sample
was ever picked up and it waited **5,442 ms**. After the PR, 28 probe
samples landed with **p50 0.88 ms / p95 11.53 ms**, meaning unrelated
requests are no longer starved by the slow ones. That is the regression
mode behind the cascading 502/504s reported in #13825.

<details>
<summary>Raw benchmark output</summary>

```
config: 30 concurrent requests, 3 blocking calls of 50ms each per request, THREAD_POOL_MAX_WORKERS=128

=== Throughput (lower wall is better) ===
flavour       wall(s)     p50(s)     p95(s)     max(s)
before          4.986      0.158      0.207      0.269
after           0.248      0.181      0.230      0.231

=== Event-loop responsiveness (lower probe latency is better) ===
flavour       samples    probe p50(ms)    probe p95(ms)    probe max(ms)
before              1          5442.26          5442.26          5442.26
after              28             0.88            11.53            98.02
```

</details>

The benchmark script is included as a comment on the PR for
reproducibility.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
- [x] Performance Improvement

Closes [#13825](https://github.com/infiniflow/ragflow/issues/13825)

---------

Co-authored-by: tmimmanuel <tmimmanuel@users.noreply.github.com>
Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
This commit is contained in:
tmimmanuel
2026-05-10 21:08:55 -10:00
committed by GitHub
parent a0efc453f3
commit 6ce014c23b
5 changed files with 127 additions and 113 deletions

View File

@@ -563,14 +563,15 @@ def get_agent_version(agent_id, version_id, tenant_id):
@manager.route("/agents/<agent_id>/logs/<message_id>", methods=["GET"]) # noqa: F821
@login_required
@add_tenant_id_to_kwargs
@_require_canvas_access_sync
def get_agent_logs(agent_id, message_id, tenant_id):
@_require_canvas_access_async
async def get_agent_logs(agent_id, message_id, tenant_id):
try:
binary = REDIS_CONN.get(f"{agent_id}-{message_id}-logs")
binary = await thread_pool_exec(REDIS_CONN.get, f"{agent_id}-{message_id}-logs")
if not binary:
return get_json_result(data={})
return get_json_result(data=json.loads(binary.encode("utf-8")))
payload = binary.decode("utf-8") if isinstance(binary, bytes) else binary
return get_json_result(data=json.loads(payload))
except Exception as exc:
logging.exception(exc)
return server_error_response(exc)

View File

@@ -47,7 +47,7 @@ from api.utils.api_utils import (
)
from api.utils.tenant_utils import ensure_tenant_model_id_for_params
from common.constants import LLMType, RetCode, StatusEnum
from common.misc_utils import get_uuid
from common.misc_utils import get_uuid, thread_pool_exec
from rag.prompts.generator import chunks_format
from rag.prompts.template import load_prompt
@@ -128,8 +128,9 @@ def _build_session_response(conv: dict) -> dict:
return conv
def _ensure_owned_chat(chat_id):
return DialogService.query(
async def _ensure_owned_chat(chat_id):
return await thread_pool_exec(
DialogService.query,
tenant_id=current_user.id, id=chat_id, status=StatusEnum.VALID.value
)
@@ -151,7 +152,7 @@ def _build_default_completion_dialog():
)
def _create_session_for_completion(chat_id, dialog, user_id):
async def _create_session_for_completion(chat_id, dialog, user_id):
conv = {
"id": get_uuid(),
"dialog_id": chat_id,
@@ -160,14 +161,14 @@ def _create_session_for_completion(chat_id, dialog, user_id):
"user_id": user_id,
"reference": [],
}
ConversationService.save(**conv)
ok, conv_obj = ConversationService.get_by_id(conv["id"])
await thread_pool_exec(ConversationService.save, **conv)
ok, conv_obj = await thread_pool_exec(ConversationService.get_by_id, conv["id"])
if not ok:
raise LookupError("Fail to create a session!")
return conv_obj
def _validate_llm_id(llm_id, tenant_id, llm_setting=None):
async def _validate_llm_id(llm_id, tenant_id, llm_setting=None):
if not llm_id:
return None
@@ -176,7 +177,8 @@ def _validate_llm_id(llm_id, tenant_id, llm_setting=None):
if model_type not in {"chat", "image2text"}:
model_type = "chat"
if not TenantLLMService.query(
if not await thread_pool_exec(
TenantLLMService.query,
tenant_id=tenant_id,
llm_name=llm_name,
llm_factory=llm_factory,
@@ -186,13 +188,14 @@ def _validate_llm_id(llm_id, tenant_id, llm_setting=None):
return None
def _validate_rerank_id(rerank_id, tenant_id):
async def _validate_rerank_id(rerank_id, tenant_id):
if not rerank_id:
return None
llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(rerank_id)
if llm_name in _DEFAULT_RERANK_MODELS:
return None
if TenantLLMService.query(
if await thread_pool_exec(
TenantLLMService.query,
tenant_id=tenant_id,
llm_name=llm_name,
llm_factory=llm_factory,
@@ -211,7 +214,7 @@ def _validate_rerank_id(rerank_id, tenant_id):
# return None
def _validate_dataset_ids(dataset_ids, tenant_id):
async def _validate_dataset_ids(dataset_ids, tenant_id):
if dataset_ids is None:
return []
if not isinstance(dataset_ids, list):
@@ -220,9 +223,9 @@ def _validate_dataset_ids(dataset_ids, tenant_id):
normalized_ids = [dataset_id for dataset_id in dataset_ids if dataset_id]
kbs = []
for dataset_id in normalized_ids:
if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
if not await thread_pool_exec(KnowledgebaseService.accessible, kb_id=dataset_id, user_id=tenant_id):
return f"You don't own the dataset {dataset_id}"
matches = KnowledgebaseService.query(id=dataset_id)
matches = await thread_pool_exec(KnowledgebaseService.query, id=dataset_id)
if not matches:
return f"You don't own the dataset {dataset_id}"
kb = matches[0]
@@ -268,19 +271,19 @@ async def create():
req["name"] = name
if "dataset_ids" in req:
kb_ids = _validate_dataset_ids(req.get("dataset_ids"), current_user.id)
kb_ids = await _validate_dataset_ids(req.get("dataset_ids"), current_user.id)
if isinstance(kb_ids, str):
return get_data_error_result(message=kb_ids)
req["kb_ids"] = kb_ids
req.pop("dataset_ids", None)
if "llm_id" in req:
err = _validate_llm_id(req.get("llm_id"), current_user.id, req.get("llm_setting"))
err = await _validate_llm_id(req.get("llm_id"), current_user.id, req.get("llm_setting"))
if err:
return get_data_error_result(message=err)
if "rerank_id" in req:
err = _validate_rerank_id(req.get("rerank_id"), current_user.id)
err = await _validate_rerank_id(req.get("rerank_id"), current_user.id)
if err:
return get_data_error_result(message=err)
@@ -335,7 +338,7 @@ async def create():
@manager.route("/chats", methods=["GET"]) # noqa: F821
@login_required
def list_chats():
async def list_chats():
chat_id = request.args.get("id")
name = request.args.get("name")
keywords = request.args.get("keywords", "")
@@ -351,8 +354,9 @@ def list_chats():
items_per_page = int(request.args.get("page_size", 0))
if owner_ids:
chats, total = DialogService.get_by_tenant_ids(
owner_ids, current_user.id, 0, 0, orderby, desc, keywords, **exact_filters
chats, total = await thread_pool_exec(
DialogService.get_by_tenant_ids,
owner_ids, current_user.id, 0, 0, orderby, desc, keywords, **exact_filters,
)
chats = [chat for chat in chats if chat["tenant_id"] in owner_ids]
total = len(chats)
@@ -360,8 +364,9 @@ def list_chats():
start = (page_number - 1) * items_per_page
chats = chats[start : start + items_per_page]
else:
chats, total = DialogService.get_by_tenant_ids(
[], current_user.id, page_number, items_per_page, orderby, desc, keywords, **exact_filters
chats, total = await thread_pool_exec(
DialogService.get_by_tenant_ids,
[], current_user.id, page_number, items_per_page, orderby, desc, keywords, **exact_filters,
)
return get_json_result(
@@ -373,12 +378,13 @@ def list_chats():
@manager.route("/chats/<chat_id>", methods=["GET"]) # noqa: F821
@login_required
def get_chat(chat_id):
async def get_chat(chat_id):
try:
tenants = UserTenantService.query(user_id=current_user.id)
tenants = await thread_pool_exec(UserTenantService.query, user_id=current_user.id)
for tenant in tenants:
if DialogService.query(
tenant_id=tenant.tenant_id, id=chat_id, status=StatusEnum.VALID.value
if await thread_pool_exec(
DialogService.query,
tenant_id=tenant.tenant_id, id=chat_id, status=StatusEnum.VALID.value,
):
break
else:
@@ -388,7 +394,7 @@ def get_chat(chat_id):
code=RetCode.AUTHENTICATION_ERROR,
)
ok, chat = DialogService.get_by_id(chat_id)
ok, chat = await thread_pool_exec(DialogService.get_by_id, chat_id)
if not ok:
return get_data_error_result(message="Chat not found!")
return get_json_result(data=_build_chat_response(chat))
@@ -399,7 +405,7 @@ def get_chat(chat_id):
@manager.route("/chats/<chat_id>", methods=["PUT"]) # noqa: F821
@login_required
async def update_chat(chat_id):
if not _ensure_owned_chat(chat_id):
if not await _ensure_owned_chat(chat_id):
return get_json_result(
data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR
)
@@ -425,19 +431,19 @@ async def update_chat(chat_id):
req["name"] = name
if "dataset_ids" in req:
kb_ids = _validate_dataset_ids(req.get("dataset_ids"), current_user.id)
kb_ids = await _validate_dataset_ids(req.get("dataset_ids"), current_user.id)
if isinstance(kb_ids, str):
return get_data_error_result(message=kb_ids)
req["kb_ids"] = kb_ids
req.pop("dataset_ids", None)
if "llm_id" in req:
err = _validate_llm_id(req.get("llm_id"), current_user.id, req.get("llm_setting"))
err = await _validate_llm_id(req.get("llm_id"), current_user.id, req.get("llm_setting"))
if err:
return get_data_error_result(message=err)
if "rerank_id" in req:
err = _validate_rerank_id(req.get("rerank_id"), current_user.id)
err = await _validate_rerank_id(req.get("rerank_id"), current_user.id)
if err:
return get_data_error_result(message=err)
@@ -485,7 +491,7 @@ async def update_chat(chat_id):
@manager.route("/chats/<chat_id>", methods=["PATCH"]) # noqa: F821
@login_required
async def patch_chat(chat_id):
if not _ensure_owned_chat(chat_id):
if not await _ensure_owned_chat(chat_id):
return get_json_result(
data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR
)
@@ -509,19 +515,19 @@ async def patch_chat(chat_id):
req["name"] = name
if "dataset_ids" in req:
kb_ids = _validate_dataset_ids(req.get("dataset_ids"), current_user.id)
kb_ids = await _validate_dataset_ids(req.get("dataset_ids"), current_user.id)
if isinstance(kb_ids, str):
return get_data_error_result(message=kb_ids)
req["kb_ids"] = kb_ids
req.pop("dataset_ids", None)
if "llm_id" in req:
err = _validate_llm_id(req.get("llm_id"), current_user.id, req.get("llm_setting"))
err = await _validate_llm_id(req.get("llm_id"), current_user.id, req.get("llm_setting"))
if err:
return get_data_error_result(message=err)
if "rerank_id" in req:
err = _validate_rerank_id(req.get("rerank_id"), current_user.id)
err = await _validate_rerank_id(req.get("rerank_id"), current_user.id)
if err:
return get_data_error_result(message=err)
@@ -575,8 +581,8 @@ async def patch_chat(chat_id):
@manager.route("/chats/<chat_id>", methods=["DELETE"]) # noqa: F821
@login_required
def delete_chat(chat_id):
if not _ensure_owned_chat(chat_id):
async def delete_chat(chat_id):
if not await _ensure_owned_chat(chat_id):
return get_json_result(
data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR
)
@@ -624,7 +630,7 @@ async def bulk_delete_chats():
unique_ids, duplicate_messages = check_duplicate_ids(ids, "chat")
for chat_id in unique_ids:
if not _ensure_owned_chat(chat_id):
if not await _ensure_owned_chat(chat_id):
errors.append(f"Chat({chat_id}) not found.")
continue
success_count += DialogService.update_by_id(chat_id, {"status": StatusEnum.INVALID.value})
@@ -644,7 +650,7 @@ async def bulk_delete_chats():
@manager.route("/chats/<chat_id>/sessions", methods=["POST"]) # noqa: F821
@login_required
async def create_session(chat_id):
if not _ensure_owned_chat(chat_id):
if not await _ensure_owned_chat(chat_id):
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
try:
req = await get_request_json()
@@ -674,9 +680,9 @@ async def create_session(chat_id):
@manager.route("/chats/<chat_id>/sessions", methods=["GET"]) # noqa: F821
@login_required
def list_sessions(chat_id):
async def list_sessions(chat_id):
try:
if not _ensure_owned_chat(chat_id):
if not await _ensure_owned_chat(chat_id):
return get_json_result(
data=False,
message="No authorization.",
@@ -702,15 +708,15 @@ def list_sessions(chat_id):
@manager.route("/chats/<chat_id>/sessions/<session_id>", methods=["GET"]) # noqa: F821
@login_required
async def get_session(chat_id, session_id):
if not _ensure_owned_chat(chat_id):
if not await _ensure_owned_chat(chat_id):
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
try:
ok, conv = ConversationService.get_by_id(session_id)
ok, conv = await thread_pool_exec(ConversationService.get_by_id, session_id)
if not ok:
return get_data_error_result(message="Session not found!")
if conv.dialog_id != chat_id:
return get_data_error_result(message="Session does not belong to this chat!")
dialog = _ensure_owned_chat(chat_id)
dialog = await _ensure_owned_chat(chat_id)
avatar = dialog[0].icon if dialog else ""
for ref in conv.reference:
if isinstance(ref, list):
@@ -726,7 +732,7 @@ async def get_session(chat_id, session_id):
@manager.route("/chats/<chat_id>/sessions/<session_id>", methods=["PATCH"]) # noqa: F821
@login_required
async def update_session(chat_id, session_id):
if not _ensure_owned_chat(chat_id):
if not await _ensure_owned_chat(chat_id):
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
try:
req = await get_request_json()
@@ -755,7 +761,7 @@ async def update_session(chat_id, session_id):
@manager.route("/chats/<chat_id>/sessions", methods=["DELETE"]) # noqa: F821
@login_required
async def delete_sessions(chat_id):
if not _ensure_owned_chat(chat_id):
if not await _ensure_owned_chat(chat_id):
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
try:
req = await get_request_json()
@@ -795,7 +801,7 @@ async def delete_sessions(chat_id):
@manager.route("/chats/<chat_id>/sessions/<session_id>/messages/<msg_id>", methods=["DELETE"]) # noqa: F821
@login_required
async def delete_session_message(chat_id, session_id, msg_id):
if not _ensure_owned_chat(chat_id):
if not await _ensure_owned_chat(chat_id):
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
try:
ok, conv = ConversationService.get_by_id(session_id)
@@ -819,7 +825,7 @@ async def delete_session_message(chat_id, session_id, msg_id):
@manager.route("/chats/<chat_id>/sessions/<session_id>/messages/<msg_id>/feedback", methods=["PUT"]) # noqa: F821
@login_required
async def update_message_feedback(chat_id, session_id, msg_id):
owned = _ensure_owned_chat(chat_id)
owned = await _ensure_owned_chat(chat_id)
if not owned:
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
try:
@@ -857,12 +863,14 @@ async def update_message_feedback(chat_id, session_id, msg_id):
reference = conv_dict["reference"][ref_index]
if reference:
if isinstance(prior_thumb, bool) and prior_thumb != thumb_raw:
ChunkFeedbackService.apply_feedback(
await thread_pool_exec(
ChunkFeedbackService.apply_feedback,
tenant_id=current_user.id,
reference=reference,
is_positive=not prior_thumb,
)
feedback_result = ChunkFeedbackService.apply_feedback(
feedback_result = await thread_pool_exec(
ChunkFeedbackService.apply_feedback,
tenant_id=current_user.id,
reference=reference,
is_positive=thumb_raw is True,
@@ -875,7 +883,7 @@ async def update_message_feedback(chat_id, session_id, msg_id):
except Exception as e:
logging.warning("Failed to apply chunk feedback: %s", e)
ConversationService.update_by_id(conv_dict["id"], conv_dict)
await thread_pool_exec(ConversationService.update_by_id, conv_dict["id"], conv_dict)
return get_json_result(data=_build_session_response(conv_dict))
except Exception as ex:
return server_error_response(ex)
@@ -1053,23 +1061,23 @@ async def session_completion(chat_id_in_arg=""):
return get_data_error_result(message="`chat_id` is required when `session_id` is provided.")
if chat_id:
if not _ensure_owned_chat(chat_id):
if not await _ensure_owned_chat(chat_id):
return get_json_result(
data=False,
message="No authorization.",
code=RetCode.AUTHENTICATION_ERROR,
)
e, dia = DialogService.get_by_id(chat_id)
e, dia = await thread_pool_exec(DialogService.get_by_id, chat_id)
if not e:
return get_data_error_result(message="Chat not found!")
if session_id:
e, conv = ConversationService.get_by_id(session_id)
e, conv = await thread_pool_exec(ConversationService.get_by_id, session_id)
if not e:
return get_data_error_result(message="Session not found!")
if conv.dialog_id != chat_id:
return get_data_error_result(message="Session does not belong to this chat!")
else:
conv = _create_session_for_completion(chat_id, dia, req.get("user_id", current_user.id))
conv = await _create_session_for_completion(chat_id, dia, req.get("user_id", current_user.id))
session_id = conv.id
conv.message = deepcopy(req["messages"])
else:
@@ -1085,7 +1093,7 @@ async def session_completion(chat_id_in_arg=""):
conv.reference.append({"chunks": [], "doc_aggs": []})
if chat_model_id:
if not TenantLLMService.get_api_key(tenant_id=dia.tenant_id, model_name=chat_model_id):
if not await thread_pool_exec(TenantLLMService.get_api_key, tenant_id=dia.tenant_id, model_name=chat_model_id):
return get_data_error_result(message=f"Cannot use specified model {chat_model_id}.")
dia.llm_id = chat_model_id
dia.llm_setting = chat_model_config
@@ -1105,7 +1113,7 @@ async def session_completion(chat_id_in_arg=""):
ans = _format_answer(ans)
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
if conv is not None:
ConversationService.update_by_id(conv.id, conv.to_dict())
await thread_pool_exec(ConversationService.update_by_id, conv.id, conv.to_dict())
except Exception as ex:
logging.exception(ex)
yield "data:" + json.dumps({"code": 500, "message": str(ex), "data": {"answer": "**ERROR**: " + str(ex), "reference": []}}, ensure_ascii=False) + "\n\n"
@@ -1123,7 +1131,7 @@ async def session_completion(chat_id_in_arg=""):
async for ans in async_chat(dia, msg, **req):
answer = _format_answer(ans)
if conv is not None:
ConversationService.update_by_id(conv.id, conv.to_dict())
await thread_pool_exec(ConversationService.update_by_id, conv.id, conv.to_dict())
break
return get_json_result(data=answer)
except Exception as ex:

View File

@@ -36,7 +36,7 @@ from api.db.services.search_service import SearchService
from api.db.services.user_service import UserTenantService
from api.db.joint_services.tenant_model_service import get_tenant_default_model_by_type, get_model_config_by_id, \
get_model_config_by_type_and_name
from common.misc_utils import get_uuid
from common.misc_utils import get_uuid, thread_pool_exec
from api.utils.api_utils import check_duplicate_ids, get_error_data_result, get_json_result, \
get_result, get_request_json, server_error_response, token_required, validate_request
from rag.app.tag import label_question
@@ -58,11 +58,11 @@ async def create_agent_session(tenant_id, agent_id):
user_id = req.get("user_id") or request.args.get("user_id", tenant_id)
release_mode = bool(req.get("release", request.args.get("release", False)))
if not UserCanvasService.query(user_id=tenant_id, id=agent_id):
if not await thread_pool_exec(UserCanvasService.query, user_id=tenant_id, id=agent_id):
return get_error_data_result("You cannot access the agent.")
try:
cvs, dsl = UserCanvasService.get_agent_dsl_with_release(agent_id, release_mode, tenant_id)
cvs, dsl = await thread_pool_exec(UserCanvasService.get_agent_dsl_with_release, agent_id, release_mode, tenant_id)
except LookupError:
return get_error_data_result("Agent not found.")
except PermissionError as e:
@@ -74,7 +74,7 @@ async def create_agent_session(tenant_id, agent_id):
cvs.dsl = json.loads(str(canvas))
# Get the version title based on release_mode
version_title = UserCanvasVersionService.get_latest_version_title(cvs.id, release_mode=release_mode)
version_title = await thread_pool_exec(UserCanvasVersionService.get_latest_version_title, cvs.id, release_mode=release_mode)
conv = {
"id": session_id,
"dialog_id": cvs.id,
@@ -84,7 +84,7 @@ async def create_agent_session(tenant_id, agent_id):
"dsl": cvs.dsl,
"version_title": version_title
}
API4ConversationService.save(**conv)
await thread_pool_exec(API4ConversationService.save, **conv)
conv["agent_id"] = conv.pop("dialog_id")
return get_result(data=conv)
@@ -95,7 +95,7 @@ async def delete_agent_session(tenant_id, agent_id):
errors = []
success_count = 0
req = await get_request_json()
cvs = UserCanvasService.query(user_id=tenant_id, id=agent_id)
cvs = await thread_pool_exec(UserCanvasService.query, user_id=tenant_id, id=agent_id)
if not cvs:
return get_error_data_result(f"You don't own the agent {agent_id}")
@@ -105,7 +105,7 @@ async def delete_agent_session(tenant_id, agent_id):
ids = req.get("ids")
if not ids:
if req.get("delete_all") is True:
ids = [conv.id for conv in API4ConversationService.query(dialog_id=agent_id)]
ids = [conv.id for conv in await thread_pool_exec(API4ConversationService.query, dialog_id=agent_id)]
if not ids:
return get_result()
else:
@@ -117,11 +117,11 @@ async def delete_agent_session(tenant_id, agent_id):
conv_list = unique_conv_ids
for session_id in conv_list:
conv = API4ConversationService.query(id=session_id, dialog_id=agent_id)
conv = await thread_pool_exec(API4ConversationService.query, id=session_id, dialog_id=agent_id)
if not conv:
errors.append(f"The agent doesn't own the session {session_id}")
continue
API4ConversationService.delete_by_id(session_id)
await thread_pool_exec(API4ConversationService.delete_by_id, session_id)
success_count += 1
if errors:
@@ -151,7 +151,7 @@ async def chatbot_completions(dialog_id):
if len(token) != 2:
return get_error_data_result(message='Authorization is not valid!')
token = token[1]
objs = APIToken.query(beta=token)
objs = await thread_pool_exec(APIToken.query, beta=token)
if not objs:
return get_error_data_result(message='Authentication error: API key is invalid!"')
tenant_id = objs[0].tenant_id
@@ -226,11 +226,11 @@ async def chatbots_inputs(dialog_id):
if len(token) != 2:
return get_error_data_result(message='Authorization is not valid!')
token = token[1]
objs = APIToken.query(beta=token)
objs = await thread_pool_exec(APIToken.query, beta=token)
if not objs:
return get_error_data_result(message='Authentication error: API key is invalid!"')
tenant_id = objs[0].tenant_id
exists, dialog = DialogService.get_by_id(dialog_id)
exists, dialog = await thread_pool_exec(DialogService.get_by_id, dialog_id)
if (not exists
or getattr(dialog, "tenant_id", None) != tenant_id
or str(getattr(dialog, "status", "")) != StatusEnum.VALID.value):
@@ -264,7 +264,7 @@ async def agent_bot_completions(agent_id):
if len(token) != 2:
return get_error_data_result(message='Authorization is not valid!')
token = token[1]
objs = APIToken.query(beta=token)
objs = await thread_pool_exec(APIToken.query, beta=token)
if not objs:
return get_error_data_result(message='Authentication error: API key is invalid!"')
@@ -307,11 +307,11 @@ async def begin_inputs(agent_id):
if len(token) != 2:
return get_error_data_result(message='Authorization is not valid!')
token = token[1]
objs = APIToken.query(beta=token)
objs = await thread_pool_exec(APIToken.query, beta=token)
if not objs:
return get_error_data_result(message='Authentication error: API key is invalid!"')
e, cvs = UserCanvasService.get_by_id(agent_id)
e, cvs = await thread_pool_exec(UserCanvasService.get_by_id, agent_id)
if not e:
return get_error_data_result(f"Can't find agent by ID: {agent_id}")
@@ -328,7 +328,7 @@ async def ask_about_embedded():
if len(token) != 2:
return get_error_data_result(message='Authorization is not valid!')
token = token[1]
objs = APIToken.query(beta=token)
objs = await thread_pool_exec(APIToken.query, beta=token)
if not objs:
return get_error_data_result(message='Authentication error: API key is invalid!"')
@@ -338,7 +338,7 @@ async def ask_about_embedded():
search_id = req.get("search_id", "")
search_config = {}
if search_id:
if search_app := SearchService.get_detail(search_id):
if search_app := await thread_pool_exec(SearchService.get_detail, search_id):
search_config = search_app.get("search_config", {})
async def stream():
@@ -367,7 +367,7 @@ async def retrieval_test_embedded():
if len(token) != 2:
return get_error_data_result(message='Authorization is not valid!')
token = token[1]
objs = APIToken.query(beta=token)
objs = await thread_pool_exec(APIToken.query, beta=token)
if not objs:
return get_error_data_result(message='Authentication error: API key is invalid!"')
@@ -406,16 +406,16 @@ async def retrieval_test_embedded():
chat_mdl = None
if req.get("search_id", ""):
nonlocal search_config
detail = SearchService.get_detail(req.get("search_id", ""))
detail = await thread_pool_exec(SearchService.get_detail, req.get("search_id", ""))
if detail:
search_config = detail.get("search_config", {})
meta_data_filter = search_config.get("meta_data_filter", {})
if meta_data_filter.get("method") in ["auto", "semi_auto"]:
chat_id = search_config.get("chat_id", "")
if chat_id:
chat_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.CHAT, chat_id)
chat_model_config = await thread_pool_exec(get_model_config_by_type_and_name, tenant_id, LLMType.CHAT, chat_id)
else:
chat_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.CHAT)
chat_model_config = await thread_pool_exec(get_tenant_default_model_by_type, tenant_id, LLMType.CHAT)
chat_mdl = LLMBundle(tenant_id, chat_model_config)
# Apply search_config settings if not explicitly provided in request
if not req.get("similarity_threshold"):
@@ -429,7 +429,7 @@ async def retrieval_test_embedded():
else:
meta_data_filter = req.get("meta_data_filter") or {}
if meta_data_filter.get("method") in ["auto", "semi_auto"]:
chat_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.CHAT)
chat_model_config = await thread_pool_exec(get_tenant_default_model_by_type, tenant_id, LLMType.CHAT)
chat_mdl = LLMBundle(tenant_id, chat_model_config)
if meta_data_filter:
@@ -443,38 +443,38 @@ async def retrieval_test_embedded():
metas_loader=lambda: DocMetadataService.get_flatted_meta_by_kbs(kb_ids),
)
tenants = UserTenantService.query(user_id=tenant_id)
tenants = await thread_pool_exec(UserTenantService.query, user_id=tenant_id)
for kb_id in kb_ids:
for tenant in tenants:
if KnowledgebaseService.query(tenant_id=tenant.tenant_id, id=kb_id):
if await thread_pool_exec(KnowledgebaseService.query, tenant_id=tenant.tenant_id, id=kb_id):
tenant_ids.append(tenant.tenant_id)
break
else:
return get_json_result(data=False, message="Only owner of dataset authorized for this operation.",
code=RetCode.OPERATING_ERROR)
e, kb = KnowledgebaseService.get_by_id(kb_ids[0])
e, kb = await thread_pool_exec(KnowledgebaseService.get_by_id, kb_ids[0])
if not e:
return get_error_data_result(message="Knowledgebase not found!")
if langs:
_question = await cross_languages(kb.tenant_id, None, _question, langs)
if kb.tenant_embd_id:
embd_model_config = get_model_config_by_id(kb.tenant_embd_id)
embd_model_config = await thread_pool_exec(get_model_config_by_id, kb.tenant_embd_id)
else:
embd_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.EMBEDDING, kb.embd_id)
embd_model_config = await thread_pool_exec(get_model_config_by_type_and_name, kb.tenant_id, LLMType.EMBEDDING, kb.embd_id)
embd_mdl = LLMBundle(kb.tenant_id, embd_model_config)
rerank_mdl = None
if tenant_rerank_id:
rerank_model_config = get_model_config_by_id(tenant_rerank_id)
rerank_model_config = await thread_pool_exec(get_model_config_by_id, tenant_rerank_id)
rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config)
elif rerank_id:
rerank_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.RERANK, rerank_id)
rerank_model_config = await thread_pool_exec(get_model_config_by_type_and_name, tenant_id, LLMType.RERANK, rerank_id)
rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config)
if req.get("keyword", False):
default_chat_model = get_tenant_default_model_by_type(kb.tenant_id, LLMType.CHAT)
default_chat_model = await thread_pool_exec(get_tenant_default_model_by_type, kb.tenant_id, LLMType.CHAT)
chat_mdl = LLMBundle(kb.tenant_id, default_chat_model)
_question += await keyword_extraction(chat_mdl, _question)
@@ -484,7 +484,7 @@ async def retrieval_test_embedded():
local_doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels
)
if use_kg:
default_chat_model = get_tenant_default_model_by_type(kb.tenant_id, LLMType.CHAT)
default_chat_model = await thread_pool_exec(get_tenant_default_model_by_type, kb.tenant_id, LLMType.CHAT)
ck = await settings.kg_retriever.retrieval(_question, tenant_ids, kb_ids, embd_mdl,
LLMBundle(kb.tenant_id, default_chat_model))
if ck["content_with_weight"]:
@@ -517,7 +517,7 @@ async def related_questions_embedded():
if len(token) != 2:
return get_error_data_result(message='Authorization is not valid!')
token = token[1]
objs = APIToken.query(beta=token)
objs = await thread_pool_exec(APIToken.query, beta=token)
if not objs:
return get_error_data_result(message='Authentication error: API key is invalid!"')
@@ -529,16 +529,16 @@ async def related_questions_embedded():
search_id = req.get("search_id", "")
search_config = {}
if search_id:
if search_app := SearchService.get_detail(search_id):
if search_app := await thread_pool_exec(SearchService.get_detail, search_id):
search_config = search_app.get("search_config", {})
question = req["question"]
chat_id = search_config.get("chat_id", "")
if chat_id:
chat_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.CHAT, chat_id)
chat_model_config = await thread_pool_exec(get_model_config_by_type_and_name, tenant_id, LLMType.CHAT, chat_id)
else:
chat_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.CHAT)
chat_model_config = await thread_pool_exec(get_tenant_default_model_by_type, tenant_id, LLMType.CHAT)
chat_mdl = LLMBundle(tenant_id, chat_model_config)
gen_conf = search_config.get("llm_setting", {"temperature": 0.9})
@@ -565,7 +565,7 @@ async def detail_share_embedded():
if len(token) != 2:
return get_error_data_result(message='Authorization is not valid!')
token = token[1]
objs = APIToken.query(beta=token)
objs = await thread_pool_exec(APIToken.query, beta=token)
if not objs:
return get_error_data_result(message='Authentication error: API key is invalid!"')
@@ -574,15 +574,15 @@ async def detail_share_embedded():
if not tenant_id:
return get_error_data_result(message="permission denined.")
try:
tenants = UserTenantService.query(user_id=tenant_id)
tenants = await thread_pool_exec(UserTenantService.query, user_id=tenant_id)
for tenant in tenants:
if SearchService.query(tenant_id=tenant.tenant_id, id=search_id):
if await thread_pool_exec(SearchService.query, tenant_id=tenant.tenant_id, id=search_id):
break
else:
return get_json_result(data=False, message="Has no permission for this operation.",
code=RetCode.OPERATING_ERROR)
search = SearchService.get_detail(search_id)
search = await thread_pool_exec(SearchService.get_detail, search_id)
if not search:
return get_error_data_result(message="Can't find this Search App!")
return get_json_result(data=search)
@@ -597,7 +597,7 @@ async def mindmap():
if len(token) != 2:
return get_error_data_result(message='Authorization is not valid!')
token = token[1]
objs = APIToken.query(beta=token)
objs = await thread_pool_exec(APIToken.query, beta=token)
if not objs:
return get_error_data_result(message='Authentication error: API key is invalid!"')
@@ -605,7 +605,7 @@ async def mindmap():
req = await get_request_json()
search_id = req.get("search_id", "")
search_app = SearchService.get_detail(search_id) if search_id else {}
search_app = await thread_pool_exec(SearchService.get_detail, search_id) if search_id else {}
mind_map =await gen_mindmap(req["question"], req["kb_ids"], tenant_id, search_app.get("search_config", {}))
if "error" in mind_map:

View File

@@ -23,7 +23,7 @@ from api.db.db_models import DB, CanvasTemplate, User, UserCanvas, API4Conversat
from api.db.services.api_service import API4ConversationService
from api.db.services.common_service import CommonService
from api.db.services.user_canvas_version import UserCanvasVersionService
from common.misc_utils import get_uuid
from common.misc_utils import get_uuid, thread_pool_exec
from api.utils.api_utils import get_data_openai
import tiktoken
from peewee import fn
@@ -245,7 +245,7 @@ async def completion(tenant_id, agent_id, session_id=None, **kwargs):
release_mode = str(kwargs.get("release", "")).strip().lower()
if session_id:
e, conv = API4ConversationService.get_by_id(session_id)
e, conv = await thread_pool_exec(API4ConversationService.get_by_id, session_id)
if not e:
raise LookupError("Session not found!")
if not conv.message:
@@ -254,15 +254,15 @@ async def completion(tenant_id, agent_id, session_id=None, **kwargs):
conv.dsl = json.dumps(conv.dsl, ensure_ascii=False)
canvas = Canvas(conv.dsl, tenant_id, agent_id, canvas_id=agent_id, custom_header=custom_header)
else:
cvs, dsl = UserCanvasService.get_agent_dsl_with_release(agent_id, release_mode=release_mode == "true", tenant_id=tenant_id)
cvs, dsl = await thread_pool_exec(UserCanvasService.get_agent_dsl_with_release, agent_id, release_mode=release_mode == "true", tenant_id=tenant_id)
session_id = get_uuid()
canvas = Canvas(dsl, tenant_id, agent_id, canvas_id=cvs.id, custom_header=custom_header)
canvas.reset()
# Get the version title based on release_mode
version_title = UserCanvasVersionService.get_latest_version_title(cvs.id, release_mode=release_mode == "true")
version_title = await thread_pool_exec(UserCanvasVersionService.get_latest_version_title, cvs.id, release_mode=release_mode == "true")
conv = {"id": session_id, "dialog_id": cvs.id, "user_id": user_id, "message": [], "source": "agent", "dsl": dsl, "reference": [], "version_title": version_title}
API4ConversationService.save(**conv)
await thread_pool_exec(API4ConversationService.save, **conv)
conv = API4Conversation(**conv)
message_id = str(uuid4())
@@ -288,7 +288,7 @@ async def completion(tenant_id, agent_id, session_id=None, **kwargs):
conv.errors = canvas.error
conv.dsl = str(canvas)
conv = conv.to_dict()
API4ConversationService.append_message(conv["id"], conv)
await thread_pool_exec(API4ConversationService.append_message, conv["id"], conv)
async def completion_openai(tenant_id, agent_id, question, session_id=None, stream=True, **kwargs):

View File

@@ -218,6 +218,11 @@ def _load_chat_module(monkeypatch):
misc_utils_mod = ModuleType("common.misc_utils")
misc_utils_mod.get_uuid = lambda: "generated-chat-id"
async def _thread_pool_exec(func, *args, **kwargs):
return func(*args, **kwargs)
misc_utils_mod.thread_pool_exec = _thread_pool_exec
monkeypatch.setitem(sys.modules, "common.misc_utils", misc_utils_mod)
dialog_service_mod = ModuleType("api.db.services.dialog_service")
@@ -808,7 +813,7 @@ def test_list_chats_returns_old_business_fields(monkeypatch):
)
monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _id: (True, _DummyKB()))
res = module.list_chats.__wrapped__()
res = _run(module.list_chats.__wrapped__())
assert res["code"] == 0
chat = res["data"]["chats"][0]
@@ -851,7 +856,7 @@ def test_list_chats_keeps_zero_pagination_semantics(monkeypatch):
monkeypatch.setattr(module.DialogService, "get_by_tenant_ids", _get_by_tenant_ids)
monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _id: (True, _DummyKB()))
res = module.list_chats.__wrapped__()
res = _run(module.list_chats.__wrapped__())
assert res["code"] == 0
assert calls[-1] == (0, 0)
@@ -874,7 +879,7 @@ def test_list_chats_keeps_zero_pagination_semantics(monkeypatch):
),
)
res = module.list_chats.__wrapped__()
res = _run(module.list_chats.__wrapped__())
assert res["code"] == 0
assert calls[-1] == (0, 2)
@@ -962,7 +967,7 @@ def test_chat_session_list_projection_unit(monkeypatch):
],
)
res = module.list_sessions.__wrapped__("chat-1")
res = _run(module.list_sessions.__wrapped__("chat-1"))
assert res["data"][0]["chat_id"] == "chat-1"
assert res["data"][0]["messages"][0]["content"] == "hello"
@@ -983,7 +988,7 @@ def test_chat_session_list_projection_unit(monkeypatch):
)
),
)
res = module.list_sessions.__wrapped__("chat-1")
res = _run(module.list_sessions.__wrapped__("chat-1"))
assert res["data"] == []