mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
### What problem does this PR solve? Addresses event-loop blocking under high concurrency reported in #13825. When multiple requests hit the API simultaneously, synchronous DB/Redis calls block the async event loop, preventing Quart from handling other requests and causing cascading 502/504 timeouts. This PR wraps all remaining blocking DB/Redis calls in `canvas_app.py`, `chat_api.py`, `session.py`, and `canvas_service.py` with `await thread_pool_exec()` - Offload all synchronous `Service.*`, `REDIS_CONN.*`, and `APIToken.query` calls to the thread pool - Convert sync endpoint handlers (`list_chats`, `get_chat`, `templates`, `sessions`, etc.) to `async def` - Convert sync helper functions (`_ensure_owned_chat`, `_validate_llm_id`, `_validate_dataset_ids`, etc.) to async - no duplicate sync/async pairs - Wrap `CanvasReplicaService` Redis IO calls (`bootstrap`, `replace_for_set`, `commit_after_run`) - Use `asyncio.gather()` for concurrent file uploads and chat response building **Note:** This fixes the code-level event-loop blocking, which is a prerequisite for handling concurrent requests. For the full "30 concurrent requests without 502/504" goal described in the issue, users should also tune deployment config: - `WS=4` or higher (HTTP worker processes, default 1) - `MAX_CONCURRENT_CHATS=50` (default 10) - `SANDBOX_EXECUTOR_MANAGER_POOL_SIZE` for workflow-heavy workloads ### Performance verification Reviewer asked for a before-vs-after comparison ([comment](https://github.com/infiniflow/ragflow/pull/13941#issuecomment-4393667231)). I built a self-contained microbenchmark that reproduces the exact failure mode this PR targets: an async handler that performs blocking DB/Redis-style calls (50 ms each, 3 per request, 30 concurrent requests) is run twice — once with the pre-PR pattern (sync call directly inside the async handler) and once with the post-PR pattern (`await thread_pool_exec(...)`). The benchmark imports nothing from RAGFlow except `thread_pool_exec` itself, so it is hermetic and reproducible (`THREAD_POOL_MAX_WORKERS=128`, Python 3.13.12). **Throughput — wall-clock for 30 concurrent requests (lower is better)** | flavour | wall(s) | p50(s) | p95(s) | max(s) | |---|---:|---:|---:|---:| | before | 4.986 | 0.158 | 0.207 | 0.269 | | after | 0.248 | 0.181 | 0.230 | 0.231 | The pre-PR handler serializes the entire load on the event-loop thread, so 30 × 3 × 50 ms ≈ 4.5 s shows up as the wall time. The post-PR handler parallelizes the blocking work across the thread pool and finishes the same load in 248 ms — a **~20× speedup** on this workload. **Event-loop responsiveness — latency of an unrelated probe coroutine while the 30 slow requests are running (lower is better)** | flavour | samples | probe p50 (ms) | probe p95 (ms) | probe max (ms) | |---|---:|---:|---:|---:| | before | 1 | 5442.26 | 5442.26 | 5442.26 | | after | 28 | 0.88 | 11.53 | 98.02 | This is the metric that maps directly to "the API still answers other requests while one is busy". A 5 ms-interval probe was scheduled while the 30 slow handlers ran. With the pre-PR code the event loop was frozen for the entire duration of the blocking work, so only one probe sample was ever picked up and it waited **5,442 ms**. After the PR, 28 probe samples landed with **p50 0.88 ms / p95 11.53 ms**, meaning unrelated requests are no longer starved by the slow ones. That is the regression mode behind the cascading 502/504s reported in #13825. <details> <summary>Raw benchmark output</summary> ``` config: 30 concurrent requests, 3 blocking calls of 50ms each per request, THREAD_POOL_MAX_WORKERS=128 === Throughput (lower wall is better) === flavour wall(s) p50(s) p95(s) max(s) before 4.986 0.158 0.207 0.269 after 0.248 0.181 0.230 0.231 === Event-loop responsiveness (lower probe latency is better) === flavour samples probe p50(ms) probe p95(ms) probe max(ms) before 1 5442.26 5442.26 5442.26 after 28 0.88 11.53 98.02 ``` </details> The benchmark script is included as a comment on the PR for reproducibility. ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) - [x] Performance Improvement Closes [#13825](https://github.com/infiniflow/ragflow/issues/13825) --------- Co-authored-by: tmimmanuel <tmimmanuel@users.noreply.github.com> Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
This commit is contained in:
@@ -563,14 +563,15 @@ def get_agent_version(agent_id, version_id, tenant_id):
|
||||
@manager.route("/agents/<agent_id>/logs/<message_id>", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
@add_tenant_id_to_kwargs
|
||||
@_require_canvas_access_sync
|
||||
def get_agent_logs(agent_id, message_id, tenant_id):
|
||||
@_require_canvas_access_async
|
||||
async def get_agent_logs(agent_id, message_id, tenant_id):
|
||||
try:
|
||||
binary = REDIS_CONN.get(f"{agent_id}-{message_id}-logs")
|
||||
binary = await thread_pool_exec(REDIS_CONN.get, f"{agent_id}-{message_id}-logs")
|
||||
if not binary:
|
||||
return get_json_result(data={})
|
||||
|
||||
return get_json_result(data=json.loads(binary.encode("utf-8")))
|
||||
payload = binary.decode("utf-8") if isinstance(binary, bytes) else binary
|
||||
return get_json_result(data=json.loads(payload))
|
||||
except Exception as exc:
|
||||
logging.exception(exc)
|
||||
return server_error_response(exc)
|
||||
|
||||
@@ -47,7 +47,7 @@ from api.utils.api_utils import (
|
||||
)
|
||||
from api.utils.tenant_utils import ensure_tenant_model_id_for_params
|
||||
from common.constants import LLMType, RetCode, StatusEnum
|
||||
from common.misc_utils import get_uuid
|
||||
from common.misc_utils import get_uuid, thread_pool_exec
|
||||
from rag.prompts.generator import chunks_format
|
||||
from rag.prompts.template import load_prompt
|
||||
|
||||
@@ -128,8 +128,9 @@ def _build_session_response(conv: dict) -> dict:
|
||||
return conv
|
||||
|
||||
|
||||
def _ensure_owned_chat(chat_id):
|
||||
return DialogService.query(
|
||||
async def _ensure_owned_chat(chat_id):
|
||||
return await thread_pool_exec(
|
||||
DialogService.query,
|
||||
tenant_id=current_user.id, id=chat_id, status=StatusEnum.VALID.value
|
||||
)
|
||||
|
||||
@@ -151,7 +152,7 @@ def _build_default_completion_dialog():
|
||||
)
|
||||
|
||||
|
||||
def _create_session_for_completion(chat_id, dialog, user_id):
|
||||
async def _create_session_for_completion(chat_id, dialog, user_id):
|
||||
conv = {
|
||||
"id": get_uuid(),
|
||||
"dialog_id": chat_id,
|
||||
@@ -160,14 +161,14 @@ def _create_session_for_completion(chat_id, dialog, user_id):
|
||||
"user_id": user_id,
|
||||
"reference": [],
|
||||
}
|
||||
ConversationService.save(**conv)
|
||||
ok, conv_obj = ConversationService.get_by_id(conv["id"])
|
||||
await thread_pool_exec(ConversationService.save, **conv)
|
||||
ok, conv_obj = await thread_pool_exec(ConversationService.get_by_id, conv["id"])
|
||||
if not ok:
|
||||
raise LookupError("Fail to create a session!")
|
||||
return conv_obj
|
||||
|
||||
|
||||
def _validate_llm_id(llm_id, tenant_id, llm_setting=None):
|
||||
async def _validate_llm_id(llm_id, tenant_id, llm_setting=None):
|
||||
if not llm_id:
|
||||
return None
|
||||
|
||||
@@ -176,7 +177,8 @@ def _validate_llm_id(llm_id, tenant_id, llm_setting=None):
|
||||
if model_type not in {"chat", "image2text"}:
|
||||
model_type = "chat"
|
||||
|
||||
if not TenantLLMService.query(
|
||||
if not await thread_pool_exec(
|
||||
TenantLLMService.query,
|
||||
tenant_id=tenant_id,
|
||||
llm_name=llm_name,
|
||||
llm_factory=llm_factory,
|
||||
@@ -186,13 +188,14 @@ def _validate_llm_id(llm_id, tenant_id, llm_setting=None):
|
||||
return None
|
||||
|
||||
|
||||
def _validate_rerank_id(rerank_id, tenant_id):
|
||||
async def _validate_rerank_id(rerank_id, tenant_id):
|
||||
if not rerank_id:
|
||||
return None
|
||||
llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(rerank_id)
|
||||
if llm_name in _DEFAULT_RERANK_MODELS:
|
||||
return None
|
||||
if TenantLLMService.query(
|
||||
if await thread_pool_exec(
|
||||
TenantLLMService.query,
|
||||
tenant_id=tenant_id,
|
||||
llm_name=llm_name,
|
||||
llm_factory=llm_factory,
|
||||
@@ -211,7 +214,7 @@ def _validate_rerank_id(rerank_id, tenant_id):
|
||||
# return None
|
||||
|
||||
|
||||
def _validate_dataset_ids(dataset_ids, tenant_id):
|
||||
async def _validate_dataset_ids(dataset_ids, tenant_id):
|
||||
if dataset_ids is None:
|
||||
return []
|
||||
if not isinstance(dataset_ids, list):
|
||||
@@ -220,9 +223,9 @@ def _validate_dataset_ids(dataset_ids, tenant_id):
|
||||
normalized_ids = [dataset_id for dataset_id in dataset_ids if dataset_id]
|
||||
kbs = []
|
||||
for dataset_id in normalized_ids:
|
||||
if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id):
|
||||
if not await thread_pool_exec(KnowledgebaseService.accessible, kb_id=dataset_id, user_id=tenant_id):
|
||||
return f"You don't own the dataset {dataset_id}"
|
||||
matches = KnowledgebaseService.query(id=dataset_id)
|
||||
matches = await thread_pool_exec(KnowledgebaseService.query, id=dataset_id)
|
||||
if not matches:
|
||||
return f"You don't own the dataset {dataset_id}"
|
||||
kb = matches[0]
|
||||
@@ -268,19 +271,19 @@ async def create():
|
||||
req["name"] = name
|
||||
|
||||
if "dataset_ids" in req:
|
||||
kb_ids = _validate_dataset_ids(req.get("dataset_ids"), current_user.id)
|
||||
kb_ids = await _validate_dataset_ids(req.get("dataset_ids"), current_user.id)
|
||||
if isinstance(kb_ids, str):
|
||||
return get_data_error_result(message=kb_ids)
|
||||
req["kb_ids"] = kb_ids
|
||||
req.pop("dataset_ids", None)
|
||||
|
||||
if "llm_id" in req:
|
||||
err = _validate_llm_id(req.get("llm_id"), current_user.id, req.get("llm_setting"))
|
||||
err = await _validate_llm_id(req.get("llm_id"), current_user.id, req.get("llm_setting"))
|
||||
if err:
|
||||
return get_data_error_result(message=err)
|
||||
|
||||
if "rerank_id" in req:
|
||||
err = _validate_rerank_id(req.get("rerank_id"), current_user.id)
|
||||
err = await _validate_rerank_id(req.get("rerank_id"), current_user.id)
|
||||
if err:
|
||||
return get_data_error_result(message=err)
|
||||
|
||||
@@ -335,7 +338,7 @@ async def create():
|
||||
|
||||
@manager.route("/chats", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def list_chats():
|
||||
async def list_chats():
|
||||
chat_id = request.args.get("id")
|
||||
name = request.args.get("name")
|
||||
keywords = request.args.get("keywords", "")
|
||||
@@ -351,8 +354,9 @@ def list_chats():
|
||||
items_per_page = int(request.args.get("page_size", 0))
|
||||
|
||||
if owner_ids:
|
||||
chats, total = DialogService.get_by_tenant_ids(
|
||||
owner_ids, current_user.id, 0, 0, orderby, desc, keywords, **exact_filters
|
||||
chats, total = await thread_pool_exec(
|
||||
DialogService.get_by_tenant_ids,
|
||||
owner_ids, current_user.id, 0, 0, orderby, desc, keywords, **exact_filters,
|
||||
)
|
||||
chats = [chat for chat in chats if chat["tenant_id"] in owner_ids]
|
||||
total = len(chats)
|
||||
@@ -360,8 +364,9 @@ def list_chats():
|
||||
start = (page_number - 1) * items_per_page
|
||||
chats = chats[start : start + items_per_page]
|
||||
else:
|
||||
chats, total = DialogService.get_by_tenant_ids(
|
||||
[], current_user.id, page_number, items_per_page, orderby, desc, keywords, **exact_filters
|
||||
chats, total = await thread_pool_exec(
|
||||
DialogService.get_by_tenant_ids,
|
||||
[], current_user.id, page_number, items_per_page, orderby, desc, keywords, **exact_filters,
|
||||
)
|
||||
|
||||
return get_json_result(
|
||||
@@ -373,12 +378,13 @@ def list_chats():
|
||||
|
||||
@manager.route("/chats/<chat_id>", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def get_chat(chat_id):
|
||||
async def get_chat(chat_id):
|
||||
try:
|
||||
tenants = UserTenantService.query(user_id=current_user.id)
|
||||
tenants = await thread_pool_exec(UserTenantService.query, user_id=current_user.id)
|
||||
for tenant in tenants:
|
||||
if DialogService.query(
|
||||
tenant_id=tenant.tenant_id, id=chat_id, status=StatusEnum.VALID.value
|
||||
if await thread_pool_exec(
|
||||
DialogService.query,
|
||||
tenant_id=tenant.tenant_id, id=chat_id, status=StatusEnum.VALID.value,
|
||||
):
|
||||
break
|
||||
else:
|
||||
@@ -388,7 +394,7 @@ def get_chat(chat_id):
|
||||
code=RetCode.AUTHENTICATION_ERROR,
|
||||
)
|
||||
|
||||
ok, chat = DialogService.get_by_id(chat_id)
|
||||
ok, chat = await thread_pool_exec(DialogService.get_by_id, chat_id)
|
||||
if not ok:
|
||||
return get_data_error_result(message="Chat not found!")
|
||||
return get_json_result(data=_build_chat_response(chat))
|
||||
@@ -399,7 +405,7 @@ def get_chat(chat_id):
|
||||
@manager.route("/chats/<chat_id>", methods=["PUT"]) # noqa: F821
|
||||
@login_required
|
||||
async def update_chat(chat_id):
|
||||
if not _ensure_owned_chat(chat_id):
|
||||
if not await _ensure_owned_chat(chat_id):
|
||||
return get_json_result(
|
||||
data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR
|
||||
)
|
||||
@@ -425,19 +431,19 @@ async def update_chat(chat_id):
|
||||
req["name"] = name
|
||||
|
||||
if "dataset_ids" in req:
|
||||
kb_ids = _validate_dataset_ids(req.get("dataset_ids"), current_user.id)
|
||||
kb_ids = await _validate_dataset_ids(req.get("dataset_ids"), current_user.id)
|
||||
if isinstance(kb_ids, str):
|
||||
return get_data_error_result(message=kb_ids)
|
||||
req["kb_ids"] = kb_ids
|
||||
req.pop("dataset_ids", None)
|
||||
|
||||
if "llm_id" in req:
|
||||
err = _validate_llm_id(req.get("llm_id"), current_user.id, req.get("llm_setting"))
|
||||
err = await _validate_llm_id(req.get("llm_id"), current_user.id, req.get("llm_setting"))
|
||||
if err:
|
||||
return get_data_error_result(message=err)
|
||||
|
||||
if "rerank_id" in req:
|
||||
err = _validate_rerank_id(req.get("rerank_id"), current_user.id)
|
||||
err = await _validate_rerank_id(req.get("rerank_id"), current_user.id)
|
||||
if err:
|
||||
return get_data_error_result(message=err)
|
||||
|
||||
@@ -485,7 +491,7 @@ async def update_chat(chat_id):
|
||||
@manager.route("/chats/<chat_id>", methods=["PATCH"]) # noqa: F821
|
||||
@login_required
|
||||
async def patch_chat(chat_id):
|
||||
if not _ensure_owned_chat(chat_id):
|
||||
if not await _ensure_owned_chat(chat_id):
|
||||
return get_json_result(
|
||||
data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR
|
||||
)
|
||||
@@ -509,19 +515,19 @@ async def patch_chat(chat_id):
|
||||
req["name"] = name
|
||||
|
||||
if "dataset_ids" in req:
|
||||
kb_ids = _validate_dataset_ids(req.get("dataset_ids"), current_user.id)
|
||||
kb_ids = await _validate_dataset_ids(req.get("dataset_ids"), current_user.id)
|
||||
if isinstance(kb_ids, str):
|
||||
return get_data_error_result(message=kb_ids)
|
||||
req["kb_ids"] = kb_ids
|
||||
req.pop("dataset_ids", None)
|
||||
|
||||
if "llm_id" in req:
|
||||
err = _validate_llm_id(req.get("llm_id"), current_user.id, req.get("llm_setting"))
|
||||
err = await _validate_llm_id(req.get("llm_id"), current_user.id, req.get("llm_setting"))
|
||||
if err:
|
||||
return get_data_error_result(message=err)
|
||||
|
||||
if "rerank_id" in req:
|
||||
err = _validate_rerank_id(req.get("rerank_id"), current_user.id)
|
||||
err = await _validate_rerank_id(req.get("rerank_id"), current_user.id)
|
||||
if err:
|
||||
return get_data_error_result(message=err)
|
||||
|
||||
@@ -575,8 +581,8 @@ async def patch_chat(chat_id):
|
||||
|
||||
@manager.route("/chats/<chat_id>", methods=["DELETE"]) # noqa: F821
|
||||
@login_required
|
||||
def delete_chat(chat_id):
|
||||
if not _ensure_owned_chat(chat_id):
|
||||
async def delete_chat(chat_id):
|
||||
if not await _ensure_owned_chat(chat_id):
|
||||
return get_json_result(
|
||||
data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR
|
||||
)
|
||||
@@ -624,7 +630,7 @@ async def bulk_delete_chats():
|
||||
unique_ids, duplicate_messages = check_duplicate_ids(ids, "chat")
|
||||
|
||||
for chat_id in unique_ids:
|
||||
if not _ensure_owned_chat(chat_id):
|
||||
if not await _ensure_owned_chat(chat_id):
|
||||
errors.append(f"Chat({chat_id}) not found.")
|
||||
continue
|
||||
success_count += DialogService.update_by_id(chat_id, {"status": StatusEnum.INVALID.value})
|
||||
@@ -644,7 +650,7 @@ async def bulk_delete_chats():
|
||||
@manager.route("/chats/<chat_id>/sessions", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
async def create_session(chat_id):
|
||||
if not _ensure_owned_chat(chat_id):
|
||||
if not await _ensure_owned_chat(chat_id):
|
||||
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
|
||||
try:
|
||||
req = await get_request_json()
|
||||
@@ -674,9 +680,9 @@ async def create_session(chat_id):
|
||||
|
||||
@manager.route("/chats/<chat_id>/sessions", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
def list_sessions(chat_id):
|
||||
async def list_sessions(chat_id):
|
||||
try:
|
||||
if not _ensure_owned_chat(chat_id):
|
||||
if not await _ensure_owned_chat(chat_id):
|
||||
return get_json_result(
|
||||
data=False,
|
||||
message="No authorization.",
|
||||
@@ -702,15 +708,15 @@ def list_sessions(chat_id):
|
||||
@manager.route("/chats/<chat_id>/sessions/<session_id>", methods=["GET"]) # noqa: F821
|
||||
@login_required
|
||||
async def get_session(chat_id, session_id):
|
||||
if not _ensure_owned_chat(chat_id):
|
||||
if not await _ensure_owned_chat(chat_id):
|
||||
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
|
||||
try:
|
||||
ok, conv = ConversationService.get_by_id(session_id)
|
||||
ok, conv = await thread_pool_exec(ConversationService.get_by_id, session_id)
|
||||
if not ok:
|
||||
return get_data_error_result(message="Session not found!")
|
||||
if conv.dialog_id != chat_id:
|
||||
return get_data_error_result(message="Session does not belong to this chat!")
|
||||
dialog = _ensure_owned_chat(chat_id)
|
||||
dialog = await _ensure_owned_chat(chat_id)
|
||||
avatar = dialog[0].icon if dialog else ""
|
||||
for ref in conv.reference:
|
||||
if isinstance(ref, list):
|
||||
@@ -726,7 +732,7 @@ async def get_session(chat_id, session_id):
|
||||
@manager.route("/chats/<chat_id>/sessions/<session_id>", methods=["PATCH"]) # noqa: F821
|
||||
@login_required
|
||||
async def update_session(chat_id, session_id):
|
||||
if not _ensure_owned_chat(chat_id):
|
||||
if not await _ensure_owned_chat(chat_id):
|
||||
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
|
||||
try:
|
||||
req = await get_request_json()
|
||||
@@ -755,7 +761,7 @@ async def update_session(chat_id, session_id):
|
||||
@manager.route("/chats/<chat_id>/sessions", methods=["DELETE"]) # noqa: F821
|
||||
@login_required
|
||||
async def delete_sessions(chat_id):
|
||||
if not _ensure_owned_chat(chat_id):
|
||||
if not await _ensure_owned_chat(chat_id):
|
||||
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
|
||||
try:
|
||||
req = await get_request_json()
|
||||
@@ -795,7 +801,7 @@ async def delete_sessions(chat_id):
|
||||
@manager.route("/chats/<chat_id>/sessions/<session_id>/messages/<msg_id>", methods=["DELETE"]) # noqa: F821
|
||||
@login_required
|
||||
async def delete_session_message(chat_id, session_id, msg_id):
|
||||
if not _ensure_owned_chat(chat_id):
|
||||
if not await _ensure_owned_chat(chat_id):
|
||||
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
|
||||
try:
|
||||
ok, conv = ConversationService.get_by_id(session_id)
|
||||
@@ -819,7 +825,7 @@ async def delete_session_message(chat_id, session_id, msg_id):
|
||||
@manager.route("/chats/<chat_id>/sessions/<session_id>/messages/<msg_id>/feedback", methods=["PUT"]) # noqa: F821
|
||||
@login_required
|
||||
async def update_message_feedback(chat_id, session_id, msg_id):
|
||||
owned = _ensure_owned_chat(chat_id)
|
||||
owned = await _ensure_owned_chat(chat_id)
|
||||
if not owned:
|
||||
return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
|
||||
try:
|
||||
@@ -857,12 +863,14 @@ async def update_message_feedback(chat_id, session_id, msg_id):
|
||||
reference = conv_dict["reference"][ref_index]
|
||||
if reference:
|
||||
if isinstance(prior_thumb, bool) and prior_thumb != thumb_raw:
|
||||
ChunkFeedbackService.apply_feedback(
|
||||
await thread_pool_exec(
|
||||
ChunkFeedbackService.apply_feedback,
|
||||
tenant_id=current_user.id,
|
||||
reference=reference,
|
||||
is_positive=not prior_thumb,
|
||||
)
|
||||
feedback_result = ChunkFeedbackService.apply_feedback(
|
||||
feedback_result = await thread_pool_exec(
|
||||
ChunkFeedbackService.apply_feedback,
|
||||
tenant_id=current_user.id,
|
||||
reference=reference,
|
||||
is_positive=thumb_raw is True,
|
||||
@@ -875,7 +883,7 @@ async def update_message_feedback(chat_id, session_id, msg_id):
|
||||
except Exception as e:
|
||||
logging.warning("Failed to apply chunk feedback: %s", e)
|
||||
|
||||
ConversationService.update_by_id(conv_dict["id"], conv_dict)
|
||||
await thread_pool_exec(ConversationService.update_by_id, conv_dict["id"], conv_dict)
|
||||
return get_json_result(data=_build_session_response(conv_dict))
|
||||
except Exception as ex:
|
||||
return server_error_response(ex)
|
||||
@@ -1053,23 +1061,23 @@ async def session_completion(chat_id_in_arg=""):
|
||||
return get_data_error_result(message="`chat_id` is required when `session_id` is provided.")
|
||||
|
||||
if chat_id:
|
||||
if not _ensure_owned_chat(chat_id):
|
||||
if not await _ensure_owned_chat(chat_id):
|
||||
return get_json_result(
|
||||
data=False,
|
||||
message="No authorization.",
|
||||
code=RetCode.AUTHENTICATION_ERROR,
|
||||
)
|
||||
e, dia = DialogService.get_by_id(chat_id)
|
||||
e, dia = await thread_pool_exec(DialogService.get_by_id, chat_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Chat not found!")
|
||||
if session_id:
|
||||
e, conv = ConversationService.get_by_id(session_id)
|
||||
e, conv = await thread_pool_exec(ConversationService.get_by_id, session_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="Session not found!")
|
||||
if conv.dialog_id != chat_id:
|
||||
return get_data_error_result(message="Session does not belong to this chat!")
|
||||
else:
|
||||
conv = _create_session_for_completion(chat_id, dia, req.get("user_id", current_user.id))
|
||||
conv = await _create_session_for_completion(chat_id, dia, req.get("user_id", current_user.id))
|
||||
session_id = conv.id
|
||||
conv.message = deepcopy(req["messages"])
|
||||
else:
|
||||
@@ -1085,7 +1093,7 @@ async def session_completion(chat_id_in_arg=""):
|
||||
conv.reference.append({"chunks": [], "doc_aggs": []})
|
||||
|
||||
if chat_model_id:
|
||||
if not TenantLLMService.get_api_key(tenant_id=dia.tenant_id, model_name=chat_model_id):
|
||||
if not await thread_pool_exec(TenantLLMService.get_api_key, tenant_id=dia.tenant_id, model_name=chat_model_id):
|
||||
return get_data_error_result(message=f"Cannot use specified model {chat_model_id}.")
|
||||
dia.llm_id = chat_model_id
|
||||
dia.llm_setting = chat_model_config
|
||||
@@ -1105,7 +1113,7 @@ async def session_completion(chat_id_in_arg=""):
|
||||
ans = _format_answer(ans)
|
||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n"
|
||||
if conv is not None:
|
||||
ConversationService.update_by_id(conv.id, conv.to_dict())
|
||||
await thread_pool_exec(ConversationService.update_by_id, conv.id, conv.to_dict())
|
||||
except Exception as ex:
|
||||
logging.exception(ex)
|
||||
yield "data:" + json.dumps({"code": 500, "message": str(ex), "data": {"answer": "**ERROR**: " + str(ex), "reference": []}}, ensure_ascii=False) + "\n\n"
|
||||
@@ -1123,7 +1131,7 @@ async def session_completion(chat_id_in_arg=""):
|
||||
async for ans in async_chat(dia, msg, **req):
|
||||
answer = _format_answer(ans)
|
||||
if conv is not None:
|
||||
ConversationService.update_by_id(conv.id, conv.to_dict())
|
||||
await thread_pool_exec(ConversationService.update_by_id, conv.id, conv.to_dict())
|
||||
break
|
||||
return get_json_result(data=answer)
|
||||
except Exception as ex:
|
||||
|
||||
@@ -36,7 +36,7 @@ from api.db.services.search_service import SearchService
|
||||
from api.db.services.user_service import UserTenantService
|
||||
from api.db.joint_services.tenant_model_service import get_tenant_default_model_by_type, get_model_config_by_id, \
|
||||
get_model_config_by_type_and_name
|
||||
from common.misc_utils import get_uuid
|
||||
from common.misc_utils import get_uuid, thread_pool_exec
|
||||
from api.utils.api_utils import check_duplicate_ids, get_error_data_result, get_json_result, \
|
||||
get_result, get_request_json, server_error_response, token_required, validate_request
|
||||
from rag.app.tag import label_question
|
||||
@@ -58,11 +58,11 @@ async def create_agent_session(tenant_id, agent_id):
|
||||
user_id = req.get("user_id") or request.args.get("user_id", tenant_id)
|
||||
release_mode = bool(req.get("release", request.args.get("release", False)))
|
||||
|
||||
if not UserCanvasService.query(user_id=tenant_id, id=agent_id):
|
||||
if not await thread_pool_exec(UserCanvasService.query, user_id=tenant_id, id=agent_id):
|
||||
return get_error_data_result("You cannot access the agent.")
|
||||
|
||||
try:
|
||||
cvs, dsl = UserCanvasService.get_agent_dsl_with_release(agent_id, release_mode, tenant_id)
|
||||
cvs, dsl = await thread_pool_exec(UserCanvasService.get_agent_dsl_with_release, agent_id, release_mode, tenant_id)
|
||||
except LookupError:
|
||||
return get_error_data_result("Agent not found.")
|
||||
except PermissionError as e:
|
||||
@@ -74,7 +74,7 @@ async def create_agent_session(tenant_id, agent_id):
|
||||
|
||||
cvs.dsl = json.loads(str(canvas))
|
||||
# Get the version title based on release_mode
|
||||
version_title = UserCanvasVersionService.get_latest_version_title(cvs.id, release_mode=release_mode)
|
||||
version_title = await thread_pool_exec(UserCanvasVersionService.get_latest_version_title, cvs.id, release_mode=release_mode)
|
||||
conv = {
|
||||
"id": session_id,
|
||||
"dialog_id": cvs.id,
|
||||
@@ -84,7 +84,7 @@ async def create_agent_session(tenant_id, agent_id):
|
||||
"dsl": cvs.dsl,
|
||||
"version_title": version_title
|
||||
}
|
||||
API4ConversationService.save(**conv)
|
||||
await thread_pool_exec(API4ConversationService.save, **conv)
|
||||
conv["agent_id"] = conv.pop("dialog_id")
|
||||
return get_result(data=conv)
|
||||
|
||||
@@ -95,7 +95,7 @@ async def delete_agent_session(tenant_id, agent_id):
|
||||
errors = []
|
||||
success_count = 0
|
||||
req = await get_request_json()
|
||||
cvs = UserCanvasService.query(user_id=tenant_id, id=agent_id)
|
||||
cvs = await thread_pool_exec(UserCanvasService.query, user_id=tenant_id, id=agent_id)
|
||||
if not cvs:
|
||||
return get_error_data_result(f"You don't own the agent {agent_id}")
|
||||
|
||||
@@ -105,7 +105,7 @@ async def delete_agent_session(tenant_id, agent_id):
|
||||
ids = req.get("ids")
|
||||
if not ids:
|
||||
if req.get("delete_all") is True:
|
||||
ids = [conv.id for conv in API4ConversationService.query(dialog_id=agent_id)]
|
||||
ids = [conv.id for conv in await thread_pool_exec(API4ConversationService.query, dialog_id=agent_id)]
|
||||
if not ids:
|
||||
return get_result()
|
||||
else:
|
||||
@@ -117,11 +117,11 @@ async def delete_agent_session(tenant_id, agent_id):
|
||||
conv_list = unique_conv_ids
|
||||
|
||||
for session_id in conv_list:
|
||||
conv = API4ConversationService.query(id=session_id, dialog_id=agent_id)
|
||||
conv = await thread_pool_exec(API4ConversationService.query, id=session_id, dialog_id=agent_id)
|
||||
if not conv:
|
||||
errors.append(f"The agent doesn't own the session {session_id}")
|
||||
continue
|
||||
API4ConversationService.delete_by_id(session_id)
|
||||
await thread_pool_exec(API4ConversationService.delete_by_id, session_id)
|
||||
success_count += 1
|
||||
|
||||
if errors:
|
||||
@@ -151,7 +151,7 @@ async def chatbot_completions(dialog_id):
|
||||
if len(token) != 2:
|
||||
return get_error_data_result(message='Authorization is not valid!')
|
||||
token = token[1]
|
||||
objs = APIToken.query(beta=token)
|
||||
objs = await thread_pool_exec(APIToken.query, beta=token)
|
||||
if not objs:
|
||||
return get_error_data_result(message='Authentication error: API key is invalid!"')
|
||||
tenant_id = objs[0].tenant_id
|
||||
@@ -226,11 +226,11 @@ async def chatbots_inputs(dialog_id):
|
||||
if len(token) != 2:
|
||||
return get_error_data_result(message='Authorization is not valid!')
|
||||
token = token[1]
|
||||
objs = APIToken.query(beta=token)
|
||||
objs = await thread_pool_exec(APIToken.query, beta=token)
|
||||
if not objs:
|
||||
return get_error_data_result(message='Authentication error: API key is invalid!"')
|
||||
tenant_id = objs[0].tenant_id
|
||||
exists, dialog = DialogService.get_by_id(dialog_id)
|
||||
exists, dialog = await thread_pool_exec(DialogService.get_by_id, dialog_id)
|
||||
if (not exists
|
||||
or getattr(dialog, "tenant_id", None) != tenant_id
|
||||
or str(getattr(dialog, "status", "")) != StatusEnum.VALID.value):
|
||||
@@ -264,7 +264,7 @@ async def agent_bot_completions(agent_id):
|
||||
if len(token) != 2:
|
||||
return get_error_data_result(message='Authorization is not valid!')
|
||||
token = token[1]
|
||||
objs = APIToken.query(beta=token)
|
||||
objs = await thread_pool_exec(APIToken.query, beta=token)
|
||||
if not objs:
|
||||
return get_error_data_result(message='Authentication error: API key is invalid!"')
|
||||
|
||||
@@ -307,11 +307,11 @@ async def begin_inputs(agent_id):
|
||||
if len(token) != 2:
|
||||
return get_error_data_result(message='Authorization is not valid!')
|
||||
token = token[1]
|
||||
objs = APIToken.query(beta=token)
|
||||
objs = await thread_pool_exec(APIToken.query, beta=token)
|
||||
if not objs:
|
||||
return get_error_data_result(message='Authentication error: API key is invalid!"')
|
||||
|
||||
e, cvs = UserCanvasService.get_by_id(agent_id)
|
||||
e, cvs = await thread_pool_exec(UserCanvasService.get_by_id, agent_id)
|
||||
if not e:
|
||||
return get_error_data_result(f"Can't find agent by ID: {agent_id}")
|
||||
|
||||
@@ -328,7 +328,7 @@ async def ask_about_embedded():
|
||||
if len(token) != 2:
|
||||
return get_error_data_result(message='Authorization is not valid!')
|
||||
token = token[1]
|
||||
objs = APIToken.query(beta=token)
|
||||
objs = await thread_pool_exec(APIToken.query, beta=token)
|
||||
if not objs:
|
||||
return get_error_data_result(message='Authentication error: API key is invalid!"')
|
||||
|
||||
@@ -338,7 +338,7 @@ async def ask_about_embedded():
|
||||
search_id = req.get("search_id", "")
|
||||
search_config = {}
|
||||
if search_id:
|
||||
if search_app := SearchService.get_detail(search_id):
|
||||
if search_app := await thread_pool_exec(SearchService.get_detail, search_id):
|
||||
search_config = search_app.get("search_config", {})
|
||||
|
||||
async def stream():
|
||||
@@ -367,7 +367,7 @@ async def retrieval_test_embedded():
|
||||
if len(token) != 2:
|
||||
return get_error_data_result(message='Authorization is not valid!')
|
||||
token = token[1]
|
||||
objs = APIToken.query(beta=token)
|
||||
objs = await thread_pool_exec(APIToken.query, beta=token)
|
||||
if not objs:
|
||||
return get_error_data_result(message='Authentication error: API key is invalid!"')
|
||||
|
||||
@@ -406,16 +406,16 @@ async def retrieval_test_embedded():
|
||||
chat_mdl = None
|
||||
if req.get("search_id", ""):
|
||||
nonlocal search_config
|
||||
detail = SearchService.get_detail(req.get("search_id", ""))
|
||||
detail = await thread_pool_exec(SearchService.get_detail, req.get("search_id", ""))
|
||||
if detail:
|
||||
search_config = detail.get("search_config", {})
|
||||
meta_data_filter = search_config.get("meta_data_filter", {})
|
||||
if meta_data_filter.get("method") in ["auto", "semi_auto"]:
|
||||
chat_id = search_config.get("chat_id", "")
|
||||
if chat_id:
|
||||
chat_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.CHAT, chat_id)
|
||||
chat_model_config = await thread_pool_exec(get_model_config_by_type_and_name, tenant_id, LLMType.CHAT, chat_id)
|
||||
else:
|
||||
chat_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.CHAT)
|
||||
chat_model_config = await thread_pool_exec(get_tenant_default_model_by_type, tenant_id, LLMType.CHAT)
|
||||
chat_mdl = LLMBundle(tenant_id, chat_model_config)
|
||||
# Apply search_config settings if not explicitly provided in request
|
||||
if not req.get("similarity_threshold"):
|
||||
@@ -429,7 +429,7 @@ async def retrieval_test_embedded():
|
||||
else:
|
||||
meta_data_filter = req.get("meta_data_filter") or {}
|
||||
if meta_data_filter.get("method") in ["auto", "semi_auto"]:
|
||||
chat_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.CHAT)
|
||||
chat_model_config = await thread_pool_exec(get_tenant_default_model_by_type, tenant_id, LLMType.CHAT)
|
||||
chat_mdl = LLMBundle(tenant_id, chat_model_config)
|
||||
|
||||
if meta_data_filter:
|
||||
@@ -443,38 +443,38 @@ async def retrieval_test_embedded():
|
||||
metas_loader=lambda: DocMetadataService.get_flatted_meta_by_kbs(kb_ids),
|
||||
)
|
||||
|
||||
tenants = UserTenantService.query(user_id=tenant_id)
|
||||
tenants = await thread_pool_exec(UserTenantService.query, user_id=tenant_id)
|
||||
for kb_id in kb_ids:
|
||||
for tenant in tenants:
|
||||
if KnowledgebaseService.query(tenant_id=tenant.tenant_id, id=kb_id):
|
||||
if await thread_pool_exec(KnowledgebaseService.query, tenant_id=tenant.tenant_id, id=kb_id):
|
||||
tenant_ids.append(tenant.tenant_id)
|
||||
break
|
||||
else:
|
||||
return get_json_result(data=False, message="Only owner of dataset authorized for this operation.",
|
||||
code=RetCode.OPERATING_ERROR)
|
||||
|
||||
e, kb = KnowledgebaseService.get_by_id(kb_ids[0])
|
||||
e, kb = await thread_pool_exec(KnowledgebaseService.get_by_id, kb_ids[0])
|
||||
if not e:
|
||||
return get_error_data_result(message="Knowledgebase not found!")
|
||||
|
||||
if langs:
|
||||
_question = await cross_languages(kb.tenant_id, None, _question, langs)
|
||||
if kb.tenant_embd_id:
|
||||
embd_model_config = get_model_config_by_id(kb.tenant_embd_id)
|
||||
embd_model_config = await thread_pool_exec(get_model_config_by_id, kb.tenant_embd_id)
|
||||
else:
|
||||
embd_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.EMBEDDING, kb.embd_id)
|
||||
embd_model_config = await thread_pool_exec(get_model_config_by_type_and_name, kb.tenant_id, LLMType.EMBEDDING, kb.embd_id)
|
||||
embd_mdl = LLMBundle(kb.tenant_id, embd_model_config)
|
||||
|
||||
rerank_mdl = None
|
||||
if tenant_rerank_id:
|
||||
rerank_model_config = get_model_config_by_id(tenant_rerank_id)
|
||||
rerank_model_config = await thread_pool_exec(get_model_config_by_id, tenant_rerank_id)
|
||||
rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config)
|
||||
elif rerank_id:
|
||||
rerank_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.RERANK, rerank_id)
|
||||
rerank_model_config = await thread_pool_exec(get_model_config_by_type_and_name, tenant_id, LLMType.RERANK, rerank_id)
|
||||
rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config)
|
||||
|
||||
if req.get("keyword", False):
|
||||
default_chat_model = get_tenant_default_model_by_type(kb.tenant_id, LLMType.CHAT)
|
||||
default_chat_model = await thread_pool_exec(get_tenant_default_model_by_type, kb.tenant_id, LLMType.CHAT)
|
||||
chat_mdl = LLMBundle(kb.tenant_id, default_chat_model)
|
||||
_question += await keyword_extraction(chat_mdl, _question)
|
||||
|
||||
@@ -484,7 +484,7 @@ async def retrieval_test_embedded():
|
||||
local_doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels
|
||||
)
|
||||
if use_kg:
|
||||
default_chat_model = get_tenant_default_model_by_type(kb.tenant_id, LLMType.CHAT)
|
||||
default_chat_model = await thread_pool_exec(get_tenant_default_model_by_type, kb.tenant_id, LLMType.CHAT)
|
||||
ck = await settings.kg_retriever.retrieval(_question, tenant_ids, kb_ids, embd_mdl,
|
||||
LLMBundle(kb.tenant_id, default_chat_model))
|
||||
if ck["content_with_weight"]:
|
||||
@@ -517,7 +517,7 @@ async def related_questions_embedded():
|
||||
if len(token) != 2:
|
||||
return get_error_data_result(message='Authorization is not valid!')
|
||||
token = token[1]
|
||||
objs = APIToken.query(beta=token)
|
||||
objs = await thread_pool_exec(APIToken.query, beta=token)
|
||||
if not objs:
|
||||
return get_error_data_result(message='Authentication error: API key is invalid!"')
|
||||
|
||||
@@ -529,16 +529,16 @@ async def related_questions_embedded():
|
||||
search_id = req.get("search_id", "")
|
||||
search_config = {}
|
||||
if search_id:
|
||||
if search_app := SearchService.get_detail(search_id):
|
||||
if search_app := await thread_pool_exec(SearchService.get_detail, search_id):
|
||||
search_config = search_app.get("search_config", {})
|
||||
|
||||
question = req["question"]
|
||||
|
||||
chat_id = search_config.get("chat_id", "")
|
||||
if chat_id:
|
||||
chat_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.CHAT, chat_id)
|
||||
chat_model_config = await thread_pool_exec(get_model_config_by_type_and_name, tenant_id, LLMType.CHAT, chat_id)
|
||||
else:
|
||||
chat_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.CHAT)
|
||||
chat_model_config = await thread_pool_exec(get_tenant_default_model_by_type, tenant_id, LLMType.CHAT)
|
||||
chat_mdl = LLMBundle(tenant_id, chat_model_config)
|
||||
|
||||
gen_conf = search_config.get("llm_setting", {"temperature": 0.9})
|
||||
@@ -565,7 +565,7 @@ async def detail_share_embedded():
|
||||
if len(token) != 2:
|
||||
return get_error_data_result(message='Authorization is not valid!')
|
||||
token = token[1]
|
||||
objs = APIToken.query(beta=token)
|
||||
objs = await thread_pool_exec(APIToken.query, beta=token)
|
||||
if not objs:
|
||||
return get_error_data_result(message='Authentication error: API key is invalid!"')
|
||||
|
||||
@@ -574,15 +574,15 @@ async def detail_share_embedded():
|
||||
if not tenant_id:
|
||||
return get_error_data_result(message="permission denined.")
|
||||
try:
|
||||
tenants = UserTenantService.query(user_id=tenant_id)
|
||||
tenants = await thread_pool_exec(UserTenantService.query, user_id=tenant_id)
|
||||
for tenant in tenants:
|
||||
if SearchService.query(tenant_id=tenant.tenant_id, id=search_id):
|
||||
if await thread_pool_exec(SearchService.query, tenant_id=tenant.tenant_id, id=search_id):
|
||||
break
|
||||
else:
|
||||
return get_json_result(data=False, message="Has no permission for this operation.",
|
||||
code=RetCode.OPERATING_ERROR)
|
||||
|
||||
search = SearchService.get_detail(search_id)
|
||||
search = await thread_pool_exec(SearchService.get_detail, search_id)
|
||||
if not search:
|
||||
return get_error_data_result(message="Can't find this Search App!")
|
||||
return get_json_result(data=search)
|
||||
@@ -597,7 +597,7 @@ async def mindmap():
|
||||
if len(token) != 2:
|
||||
return get_error_data_result(message='Authorization is not valid!')
|
||||
token = token[1]
|
||||
objs = APIToken.query(beta=token)
|
||||
objs = await thread_pool_exec(APIToken.query, beta=token)
|
||||
if not objs:
|
||||
return get_error_data_result(message='Authentication error: API key is invalid!"')
|
||||
|
||||
@@ -605,7 +605,7 @@ async def mindmap():
|
||||
req = await get_request_json()
|
||||
|
||||
search_id = req.get("search_id", "")
|
||||
search_app = SearchService.get_detail(search_id) if search_id else {}
|
||||
search_app = await thread_pool_exec(SearchService.get_detail, search_id) if search_id else {}
|
||||
|
||||
mind_map =await gen_mindmap(req["question"], req["kb_ids"], tenant_id, search_app.get("search_config", {}))
|
||||
if "error" in mind_map:
|
||||
|
||||
@@ -23,7 +23,7 @@ from api.db.db_models import DB, CanvasTemplate, User, UserCanvas, API4Conversat
|
||||
from api.db.services.api_service import API4ConversationService
|
||||
from api.db.services.common_service import CommonService
|
||||
from api.db.services.user_canvas_version import UserCanvasVersionService
|
||||
from common.misc_utils import get_uuid
|
||||
from common.misc_utils import get_uuid, thread_pool_exec
|
||||
from api.utils.api_utils import get_data_openai
|
||||
import tiktoken
|
||||
from peewee import fn
|
||||
@@ -245,7 +245,7 @@ async def completion(tenant_id, agent_id, session_id=None, **kwargs):
|
||||
release_mode = str(kwargs.get("release", "")).strip().lower()
|
||||
|
||||
if session_id:
|
||||
e, conv = API4ConversationService.get_by_id(session_id)
|
||||
e, conv = await thread_pool_exec(API4ConversationService.get_by_id, session_id)
|
||||
if not e:
|
||||
raise LookupError("Session not found!")
|
||||
if not conv.message:
|
||||
@@ -254,15 +254,15 @@ async def completion(tenant_id, agent_id, session_id=None, **kwargs):
|
||||
conv.dsl = json.dumps(conv.dsl, ensure_ascii=False)
|
||||
canvas = Canvas(conv.dsl, tenant_id, agent_id, canvas_id=agent_id, custom_header=custom_header)
|
||||
else:
|
||||
cvs, dsl = UserCanvasService.get_agent_dsl_with_release(agent_id, release_mode=release_mode == "true", tenant_id=tenant_id)
|
||||
cvs, dsl = await thread_pool_exec(UserCanvasService.get_agent_dsl_with_release, agent_id, release_mode=release_mode == "true", tenant_id=tenant_id)
|
||||
|
||||
session_id = get_uuid()
|
||||
canvas = Canvas(dsl, tenant_id, agent_id, canvas_id=cvs.id, custom_header=custom_header)
|
||||
canvas.reset()
|
||||
# Get the version title based on release_mode
|
||||
version_title = UserCanvasVersionService.get_latest_version_title(cvs.id, release_mode=release_mode == "true")
|
||||
version_title = await thread_pool_exec(UserCanvasVersionService.get_latest_version_title, cvs.id, release_mode=release_mode == "true")
|
||||
conv = {"id": session_id, "dialog_id": cvs.id, "user_id": user_id, "message": [], "source": "agent", "dsl": dsl, "reference": [], "version_title": version_title}
|
||||
API4ConversationService.save(**conv)
|
||||
await thread_pool_exec(API4ConversationService.save, **conv)
|
||||
conv = API4Conversation(**conv)
|
||||
|
||||
message_id = str(uuid4())
|
||||
@@ -288,7 +288,7 @@ async def completion(tenant_id, agent_id, session_id=None, **kwargs):
|
||||
conv.errors = canvas.error
|
||||
conv.dsl = str(canvas)
|
||||
conv = conv.to_dict()
|
||||
API4ConversationService.append_message(conv["id"], conv)
|
||||
await thread_pool_exec(API4ConversationService.append_message, conv["id"], conv)
|
||||
|
||||
|
||||
async def completion_openai(tenant_id, agent_id, question, session_id=None, stream=True, **kwargs):
|
||||
|
||||
@@ -218,6 +218,11 @@ def _load_chat_module(monkeypatch):
|
||||
|
||||
misc_utils_mod = ModuleType("common.misc_utils")
|
||||
misc_utils_mod.get_uuid = lambda: "generated-chat-id"
|
||||
|
||||
async def _thread_pool_exec(func, *args, **kwargs):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
misc_utils_mod.thread_pool_exec = _thread_pool_exec
|
||||
monkeypatch.setitem(sys.modules, "common.misc_utils", misc_utils_mod)
|
||||
|
||||
dialog_service_mod = ModuleType("api.db.services.dialog_service")
|
||||
@@ -808,7 +813,7 @@ def test_list_chats_returns_old_business_fields(monkeypatch):
|
||||
)
|
||||
monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _id: (True, _DummyKB()))
|
||||
|
||||
res = module.list_chats.__wrapped__()
|
||||
res = _run(module.list_chats.__wrapped__())
|
||||
|
||||
assert res["code"] == 0
|
||||
chat = res["data"]["chats"][0]
|
||||
@@ -851,7 +856,7 @@ def test_list_chats_keeps_zero_pagination_semantics(monkeypatch):
|
||||
monkeypatch.setattr(module.DialogService, "get_by_tenant_ids", _get_by_tenant_ids)
|
||||
monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _id: (True, _DummyKB()))
|
||||
|
||||
res = module.list_chats.__wrapped__()
|
||||
res = _run(module.list_chats.__wrapped__())
|
||||
|
||||
assert res["code"] == 0
|
||||
assert calls[-1] == (0, 0)
|
||||
@@ -874,7 +879,7 @@ def test_list_chats_keeps_zero_pagination_semantics(monkeypatch):
|
||||
),
|
||||
)
|
||||
|
||||
res = module.list_chats.__wrapped__()
|
||||
res = _run(module.list_chats.__wrapped__())
|
||||
|
||||
assert res["code"] == 0
|
||||
assert calls[-1] == (0, 2)
|
||||
@@ -962,7 +967,7 @@ def test_chat_session_list_projection_unit(monkeypatch):
|
||||
],
|
||||
)
|
||||
|
||||
res = module.list_sessions.__wrapped__("chat-1")
|
||||
res = _run(module.list_sessions.__wrapped__("chat-1"))
|
||||
assert res["data"][0]["chat_id"] == "chat-1"
|
||||
assert res["data"][0]["messages"][0]["content"] == "hello"
|
||||
|
||||
@@ -983,7 +988,7 @@ def test_chat_session_list_projection_unit(monkeypatch):
|
||||
)
|
||||
),
|
||||
)
|
||||
res = module.list_sessions.__wrapped__("chat-1")
|
||||
res = _run(module.list_sessions.__wrapped__("chat-1"))
|
||||
assert res["data"] == []
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user