# # Copyright 2026 The InfiniFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import json import logging import math import os import re import tempfile from copy import deepcopy from types import SimpleNamespace from quart import Response, request from api.apps import current_user, login_required from api.apps.restful_apis._generation_params import merge_generation_config, pop_generation_config from api.db.joint_services.tenant_model_service import ( get_tenant_default_model_by_type, get_model_config_from_provider_instance, get_api_key, split_model_name ) from api.db.services.chunk_feedback_service import ChunkFeedbackService from api.db.services.conversation_service import ConversationService, structure_answer from api.db.services.dialog_service import DialogService, async_chat, gen_mindmap from api.db.services.knowledgebase_service import KnowledgebaseService from api.db.services.llm_service import LLMBundle from api.db.services.search_service import SearchService from api.db.services.user_service import TenantService, UserTenantService from api.utils.api_utils import ( check_duplicate_ids, get_data_error_result, get_json_result, get_request_json, server_error_response, validate_request, ) from api.utils.pagination_utils import validate_rest_api_page_size from common.constants import LLMType, RetCode, StatusEnum from common import settings from common.misc_utils import get_uuid, thread_pool_exec from rag.prompts.generator import chunks_format 
from rag.prompts.template import load_prompt def _sanitize_json_floats(obj): """Replace NaN/Infinity floats with None so the result is RFC 8259 JSON. `json.dumps` emits the literal tokens `NaN`/`Infinity` by default (allow_nan=True). Those tokens are valid Python JSON output but invalid per the JSON spec, and downstream proxies / Go consumers reject the response with `failed to encode response: json: unsupported value: NaN` (fixes #15245). Retrieval scores (similarity, vector_similarity, term_similarity) can become NaN when an aggregation runs over an empty set or when a similarity denominator is zero, so the chat completions stream is the realistic trigger. `isinstance(obj, float)` alone catches Python float and numpy.float64 (a float subclass) but misses numpy.float32 / numpy.float16 and any other duck-typed numeric. Probe via math.isnan/isinf in a try/except so any object math can evaluate gets sanitized — without changing upstream callers like chunks_format or rag/nlp/search.py. """ try: if math.isnan(obj) or math.isinf(obj): return None except TypeError: pass if isinstance(obj, dict): return {k: _sanitize_json_floats(v) for k, v in obj.items()} if isinstance(obj, list): return [_sanitize_json_floats(v) for v in obj] if isinstance(obj, tuple): return tuple(_sanitize_json_floats(v) for v in obj) return obj _DEFAULT_PROMPT_CONFIG = { "system": ( 'You are an intelligent assistant. Please summarize the content of the dataset to answer the question. ' 'Please list the data in the dataset and answer in detail. When all dataset content is irrelevant to the ' 'question, your answer must include the sentence "The answer you are looking for is not found in the dataset!" ' "Answers need to consider chat history.\n" " Here is the knowledge base:\n" " {knowledge}\n" " The above is the knowledge base." ), "prologue": "Hi! I'm your assistant. What can I do for you?", "parameters": [{"key": "knowledge", "optional": False}], "empty_response": "Sorry! 
No relevant content was found in the knowledge base!", "quote": True, "tts": False, "refine_multiturn": True, } _DEFAULT_DIRECT_CHAT_PROMPT_CONFIG = { "system": "", "prologue": "", "parameters": [], "empty_response": "", "quote": False, "tts": False, "refine_multiturn": True, } _DEFAULT_RERANK_MODELS = {"BAAI/bge-reranker-v2-m3", "maidalun1020/bce-reranker-base_v1"} _READONLY_FIELDS = {"id", "tenant_id", "created_by", "create_time", "create_date", "update_time", "update_date"} _PERSISTED_FIELDS = set(DialogService.model._meta.fields) def _build_chat_response(chat): data = chat.to_dict() if hasattr(chat, "to_dict") else dict(chat) kb_ids, kb_names = _resolve_kb_names(data.get("kb_ids", [])) data["dataset_ids"] = kb_ids data.pop("kb_ids", None) data["kb_names"] = kb_names return data def _resolve_kb_names(kb_ids): ids, names = [], [] for kb_id in kb_ids or []: ok, kb = KnowledgebaseService.get_by_id(kb_id) if not ok or kb.status != StatusEnum.VALID.value: continue ids.append(kb_id) names.append(kb.name) return ids, names def _has_knowledge_placeholder(prompt_config): return "{knowledge}" in (prompt_config or {}).get("system", "") def _validate_name(name, *, required=True): if name is None: if required: return None, "`name` is required." return None, None if not isinstance(name, str): return None, "Chat name must be a string." name = name.strip() if not name: return None, "`name` is required." if required else "`name` cannot be empty." if len(name.encode("utf-8")) > 255: return None, f"Chat name length is {len(name.encode('utf-8'))} which is larger than 255." 
return name, None def _build_session_response(conv: dict) -> dict: conv = dict(conv) conv["chat_id"] = conv.pop("dialog_id", conv.get("chat_id")) conv["messages"] = conv.pop("message", conv.get("messages", [])) return conv async def _ensure_owned_chat(chat_id): return await thread_pool_exec( DialogService.query, tenant_id=current_user.id, id=chat_id, status=StatusEnum.VALID.value ) def _build_default_completion_dialog(): return SimpleNamespace( tenant_id=current_user.id, llm_id="", tenant_llm_id=None, llm_setting={}, prompt_config=deepcopy(_DEFAULT_DIRECT_CHAT_PROMPT_CONFIG), kb_ids=[], top_n=6, top_k=1024, rerank_id="", similarity_threshold=0.1, vector_similarity_weight=0.3, meta_data_filter=None, ) async def _create_session_for_completion(chat_id, dialog, user_id): conv = { "id": get_uuid(), "dialog_id": chat_id, "name": "New session", "message": [{"role": "assistant", "content": dialog.prompt_config.get("prologue", "")}], "user_id": user_id, "reference": [], } await thread_pool_exec(ConversationService.save, **conv) ok, conv_obj = await thread_pool_exec(ConversationService.get_by_id, conv["id"]) if not ok: raise LookupError("Fail to create a session!") return conv_obj def _get_bool_request_flag(req, *names, default=False): for name in names: if name not in req: continue value = req.pop(name) if isinstance(value, str): return value.strip().lower() in {"1", "true", "yes", "on"} return bool(value) return default def _normalize_completion_messages(req): messages = req.get("messages") if messages is None: question = req.get("question") if question is None: return None, get_data_error_result( code=RetCode.ARGUMENT_ERROR, message="required argument are missing: messages", ) messages = [{"role": "user", "content": question}] if req.get("files"): messages[-1]["files"] = req["files"] if not isinstance(messages, list) or not messages: return None, get_data_error_result( code=RetCode.ARGUMENT_ERROR, message="`messages` must be a non-empty list.", ) for message in messages: 
if not isinstance(message, dict): return None, get_data_error_result( code=RetCode.ARGUMENT_ERROR, message="Every item in `messages` must be an object.", ) if "role" not in message or "content" not in message: return None, get_data_error_result( code=RetCode.ARGUMENT_ERROR, message="Every item in `messages` must include `role` and `content`.", ) msg = [] for m in messages: if m["role"] == "system": continue if m["role"] == "assistant" and not msg: continue msg.append(m) if not msg: return None, get_data_error_result( code=RetCode.ARGUMENT_ERROR, message="`messages` must contain a user message.", ) if msg[-1]["role"] != "user": return None, get_data_error_result( code=RetCode.ARGUMENT_ERROR, message="The last message must be from user.", ) if not msg[-1].get("id"): msg[-1]["id"] = get_uuid() # till now, message and msg are sharing the same copy return (messages, msg), None async def _validate_llm_id(llm_id, tenant_id, llm_setting=None): if not llm_id: return None conf_model_type = (llm_setting or {}).get("model_type") if isinstance(conf_model_type, str): model_type = conf_model_type if conf_model_type in {"chat", "image2text"} else "chat" elif isinstance(conf_model_type, list): model_type = "image2text" if "image2text" in conf_model_type else "chat" else: model_type = "chat" try: await thread_pool_exec( get_model_config_from_provider_instance, tenant_id=tenant_id, model_name=llm_id, model_type=model_type, ) except Exception as e: logging.error(f"Fail to get model config for {llm_id}: {e}") return f"`llm_id` {llm_id} doesn't exist" return None async def _validate_rerank_id(rerank_id, tenant_id): if not rerank_id: return None parts = rerank_id.split('@') llm_name = parts[0] if llm_name in _DEFAULT_RERANK_MODELS: return None try: await thread_pool_exec( get_model_config_from_provider_instance, tenant_id=tenant_id, model_name=rerank_id, model_type="rerank", ) except Exception as e: logging.error(f"Fail to get model config for {rerank_id}: {e}") return f"`rerank_id` 
{rerank_id} doesn't exist" return None # def _validate_prompt_config(prompt_config): # for parameter in prompt_config.get("parameters", []): # if parameter.get("optional"): # continue # if prompt_config.get("system", "").find("{%s}" % parameter["key"]) < 0: # return f"Parameter '{parameter['key']}' is not used" # return None async def _validate_dataset_ids(dataset_ids, tenant_id): if dataset_ids is None: return [] if not isinstance(dataset_ids, list): return "`dataset_ids` should be a list." normalized_ids = [dataset_id for dataset_id in dataset_ids if dataset_id] kbs = [] for dataset_id in normalized_ids: if not await thread_pool_exec(KnowledgebaseService.accessible, kb_id=dataset_id, user_id=tenant_id): return f"You don't own the dataset {dataset_id}" matches = await thread_pool_exec(KnowledgebaseService.query, id=dataset_id) if not matches: return f"You don't own the dataset {dataset_id}" kb = matches[0] if kb.chunk_num == 0: return f"The dataset {dataset_id} doesn't own parsed file" kbs.append(kb) embd_ids = [split_model_name(kb.embd_id)[0] for kb in kbs] if len(set(embd_ids)) > 1: return f'Datasets use different embedding models: {[kb.embd_id for kb in kbs]}' return normalized_ids def _apply_prompt_defaults(req): prompt_config = req.setdefault("prompt_config", {}) for key, value in _DEFAULT_PROMPT_CONFIG.items(): temp = prompt_config.get(key) if (key == "system" and not temp) or key not in prompt_config: prompt_config[key] = deepcopy(value) if req.get("kb_ids") and not prompt_config.get("parameters") and "{knowledge}" in prompt_config.get("system", ""): prompt_config["parameters"] = [{"key": "knowledge", "optional": False}] @manager.route("/chats", methods=["POST"]) # noqa: F821 @login_required async def create(): try: req = await get_request_json() ok, tenant = TenantService.get_by_id(current_user.id) if not ok: return get_data_error_result(message="Tenant not found!") # Validate tenant_id should not be provided if req.get("tenant_id"): return 
get_data_error_result(message="`tenant_id` must not be provided.") # Validate name name, err = _validate_name(req.get("name"), required=True) if err: return get_data_error_result(message=err) req["name"] = name if "dataset_ids" in req: kb_ids = await _validate_dataset_ids(req.get("dataset_ids"), current_user.id) if isinstance(kb_ids, str): return get_data_error_result(message=kb_ids) req["kb_ids"] = kb_ids req.pop("dataset_ids", None) if "llm_id" in req: err = await _validate_llm_id(req.get("llm_id"), current_user.id, req.get("llm_setting")) if err: return get_data_error_result(message=err) if "rerank_id" in req: err = await _validate_rerank_id(req.get("rerank_id"), current_user.id) if err: return get_data_error_result(message=err) if "prompt_config" in req: if not isinstance(req["prompt_config"], dict): return get_data_error_result(message="`prompt_config` should be an object.") # err = _validate_prompt_config(req["prompt_config"]) # if err: # return get_data_error_result(message=err) req.setdefault("kb_ids", []) req.setdefault("llm_id", tenant.llm_id) if req["llm_id"] is None: req["llm_id"] = tenant.llm_id req.setdefault("llm_setting", {}) req.setdefault("description", "A helpful Assistant") req.setdefault("top_n", 6) req.setdefault("top_k", 1024) req.setdefault("rerank_id", "") req.setdefault("similarity_threshold", 0.1) req.setdefault("vector_similarity_weight", 0.3) req.setdefault("icon", "") _apply_prompt_defaults(req) # err = _validate_prompt_config(req["prompt_config"]) # if err: # return get_data_error_result(message=err) req = {field: value for field, value in req.items() if field in _PERSISTED_FIELDS} for field in _READONLY_FIELDS: req.pop(field, None) if DialogService.query( name=req["name"], tenant_id=current_user.id, status=StatusEnum.VALID.value, ): return get_data_error_result(message="Duplicated chat name in creating chat.") req["id"] = get_uuid() req["tenant_id"] = current_user.id if not DialogService.save(**req): return 
@manager.route("/chats", methods=["GET"])  # noqa: F821
@login_required
async def list_chats():
    """List chats visible to the current user.

    Query parameters: ``id`` and ``name`` (exact filters; either one
    disables ``keywords``), ``keywords``, ``orderby`` (default
    ``create_time``), ``desc`` (anything but ``false`` means descending),
    ``owner_ids`` (repeatable), ``page`` and ``page_size``.

    When ``owner_ids`` is given, the service is asked for page 0 / size 0
    and the owner filter, total and pagination are applied here in memory.
    """
    args = request.args
    owner_ids = args.getlist("owner_ids")
    exact_filters = {"id": args.get("id"), "name": args.get("name")}
    # An exact id/name lookup replaces keyword search.
    has_exact_filter = bool(exact_filters["id"] or exact_filters["name"])
    keywords = "" if has_exact_filter else args.get("keywords", "")
    orderby = args.get("orderby", "create_time")
    desc = args.get("desc", "true").lower() != "false"

    try:
        page_number = int(args.get("page", 0))
        items_per_page = validate_rest_api_page_size(int(args.get("page_size", 0)))

        if not owner_ids:
            chats, total = await thread_pool_exec(
                DialogService.get_by_tenant_ids,
                [],
                current_user.id,
                page_number,
                items_per_page,
                orderby,
                desc,
                keywords,
                **exact_filters,
            )
        else:
            fetched, _ = await thread_pool_exec(
                DialogService.get_by_tenant_ids,
                owner_ids,
                current_user.id,
                0,
                0,
                orderby,
                desc,
                keywords,
                **exact_filters,
            )
            chats = [row for row in fetched if row["tenant_id"] in owner_ids]
            total = len(chats)
            if page_number and items_per_page:
                offset = (page_number - 1) * items_per_page
                chats = chats[offset : offset + items_per_page]

        payload = {"chats": [_build_chat_response(row) for row in chats], "total": total}
        return get_json_result(data=payload)
    except Exception as ex:
        return server_error_response(ex)
@manager.route("/chats/<chat_id>", methods=["PUT"])  # noqa: F821
@login_required
async def update_chat(chat_id):
    """Update a chat owned by the current user from the JSON request body.

    ``name``, ``dataset_ids``, ``llm_id``, ``rerank_id`` and
    ``prompt_config`` are validated only when present in the body;
    ``dataset_ids`` is stored as ``kb_ids``. The payload is then reduced to
    persisted, non read-only fields before being written.

    The URL rule declares the ``<chat_id>`` variable so the router can
    supply the ``chat_id`` argument this handler requires.

    Returns:
        The refreshed chat (via ``_build_chat_response``) on success; an
        authentication error when the caller does not own the chat; a data
        error for invalid input, a duplicated name or a failed write; or
        ``server_error_response`` for any unexpected exception.
    """
    if not await _ensure_owned_chat(chat_id):
        return get_json_result(
            data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR
        )

    try:
        req = await get_request_json()

        # Only the existence of the tenant matters here.
        ok, _ = TenantService.get_by_id(current_user.id)
        if not ok:
            return get_data_error_result(message="Tenant not found!")

        ok, current_chat = DialogService.get_by_id(chat_id)
        if not ok:
            return get_data_error_result(message="Chat not found!")
        current_chat = current_chat.to_dict()

        if req.get("tenant_id"):
            return get_data_error_result(message="`tenant_id` must not be provided.")

        if "name" in req:
            name, err = _validate_name(req.get("name"), required=True)
            if err:
                return get_data_error_result(message=err)
            req["name"] = name

        if "dataset_ids" in req:
            kb_ids = await _validate_dataset_ids(req.get("dataset_ids"), current_user.id)
            # A string result is an error message, a list is the validated ids.
            if isinstance(kb_ids, str):
                return get_data_error_result(message=kb_ids)
            req["kb_ids"] = kb_ids
            req.pop("dataset_ids", None)

        if "llm_id" in req:
            err = await _validate_llm_id(req.get("llm_id"), current_user.id, req.get("llm_setting"))
            if err:
                return get_data_error_result(message=err)

        if "rerank_id" in req:
            err = await _validate_rerank_id(req.get("rerank_id"), current_user.id)
            if err:
                return get_data_error_result(message=err)

        if "prompt_config" in req and not isinstance(req["prompt_config"], dict):
            return get_data_error_result(message="`prompt_config` should be an object.")

        # Keep only columns that exist on the model and may be written by clients.
        req = {
            field: value
            for field, value in req.items()
            if field in _PERSISTED_FIELDS and field not in _READONLY_FIELDS
        }

        # A rename that only changes letter case skips the duplicate lookup.
        if (
            "name" in req
            and req["name"].lower() != current_chat["name"].lower()
            and DialogService.query(
                name=req["name"],
                tenant_id=current_user.id,
                status=StatusEnum.VALID.value,
            )
        ):
            return get_data_error_result(message="Duplicated chat name.")

        if not DialogService.update_by_id(chat_id, req):
            return get_data_error_result(message="Chat not found!")

        ok, chat = DialogService.get_by_id(chat_id)
        if not ok:
            return get_data_error_result(message="Failed to retrieve updated chat.")
        return get_json_result(data=_build_chat_response(chat))
    except Exception as ex:
        return server_error_response(ex)
req.pop("dataset_ids", None) if "llm_id" in req: err = await _validate_llm_id(req.get("llm_id"), current_user.id, req.get("llm_setting")) if err: return get_data_error_result(message=err) if "rerank_id" in req: err = await _validate_rerank_id(req.get("rerank_id"), current_user.id) if err: return get_data_error_result(message=err) if "prompt_config" in req: if not isinstance(req["prompt_config"], dict): return get_data_error_result(message="`prompt_config` should be an object.") prompt_config = deepcopy(current_chat.get("prompt_config", {})) prompt_config.update(req["prompt_config"]) req["prompt_config"] = prompt_config # err = _validate_prompt_config(prompt_config) # if err: # return get_data_error_result(message=err) if "llm_setting" in req: llm_setting = deepcopy(current_chat.get("llm_setting", {})) llm_setting.update(req["llm_setting"]) req["llm_setting"] = llm_setting # if "prompt_config" in req or "kb_ids" in req: # prompt_config = req.get("prompt_config", current_chat.get("prompt_config", {})) # kb_ids = req.get("kb_ids", current_chat.get("kb_ids", [])) # if not kb_ids and not prompt_config.get("tavily_api_key") and _has_knowledge_placeholder(prompt_config): # return get_data_error_result(message="Please remove `{knowledge}` in system prompt since no dataset / Tavily used here.") req = {field: value for field, value in req.items() if field in _PERSISTED_FIELDS} for field in _READONLY_FIELDS: req.pop(field, None) if ( "name" in req and req["name"].lower() != current_chat["name"].lower() and DialogService.query( name=req["name"], tenant_id=current_user.id, status=StatusEnum.VALID.value, ) ): return get_data_error_result(message="Duplicated chat name.") if not DialogService.update_by_id(chat_id, req): return get_data_error_result(message="Failed to update chat.") ok, chat = DialogService.get_by_id(chat_id) if not ok: return get_data_error_result(message="Failed to retrieve updated chat.") return get_json_result(data=_build_chat_response(chat)) except Exception 
@manager.route("/chats", methods=["DELETE"])  # noqa: F821
@login_required
async def bulk_delete_chats():
    """Soft-delete chats named in the JSON request body.

    Accepted body forms, in order of precedence:

    * ``{"ids": [...]}`` - delete the listed chats.
    * ``{"delete_all": true}`` - delete every valid chat of the current user.
    * ``{"chat_id": "..."}`` - legacy single delete kept for backward
      compatibility.

    Every chat, including the legacy single one, must pass
    ``_ensure_owned_chat`` before its status is set to invalid; otherwise
    any logged-in user could delete a chat just by knowing its id.

    Returns:
        ``{}`` when there is nothing to delete; ``True`` for a successful
        legacy delete; ``{"success_count": n}`` for a clean bulk delete; a
        partial-success payload or a data error when some ids failed; or
        ``server_error_response`` on an unexpected exception.
    """
    req = await get_request_json()
    if not req:
        return get_json_result(data={})

    try:
        ids = req.get("ids")
        if not ids:
            if req.get("delete_all") is True:
                ids = [
                    chat.id
                    for chat in DialogService.query(
                        tenant_id=current_user.id, status=StatusEnum.VALID.value
                    )
                ]
                if not ids:
                    return get_json_result(data={})
            else:
                # keep backward compatibility, DELETE with chat_id in request body
                chat_id = req.get("chat_id")
                if not chat_id:
                    return get_json_result(data={})
                # Same ownership gate as the single-chat DELETE handler.
                if not await _ensure_owned_chat(chat_id):
                    return get_json_result(
                        data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR
                    )
                if not DialogService.update_by_id(chat_id, {"status": StatusEnum.INVALID.value}):
                    return get_data_error_result(message=f"Failed to delete chat {chat_id}")
                return get_json_result(data=True)

        errors = []
        success_count = 0
        unique_ids, duplicate_messages = check_duplicate_ids(ids, "chat")
        for chat_id in unique_ids:
            if not await _ensure_owned_chat(chat_id):
                errors.append(f"Chat({chat_id}) not found.")
                continue
            # update_by_id's return value is added directly to the counter.
            success_count += DialogService.update_by_id(chat_id, {"status": StatusEnum.INVALID.value})

        all_errors = errors + duplicate_messages
        if not all_errors:
            return get_json_result(data={"success_count": success_count})
        if success_count > 0:
            return get_json_result(
                data={"success_count": success_count, "errors": all_errors},
                message=f"Partially deleted {success_count} chats with {len(all_errors)} errors",
            )
        return get_data_error_result(message="; ".join(all_errors))
    except Exception as ex:
        return server_error_response(ex)
@manager.route("/chats/<chat_id>/sessions/<session_id>", methods=["GET"])  # noqa: F821
@login_required
async def get_session(chat_id, session_id):
    """Return one session of a chat owned by the current user.

    The URL rule declares ``<chat_id>`` and ``<session_id>`` so the router
    can supply both handler arguments. Each non-list entry of the session's
    ``reference`` gets a ``chunks`` key computed by ``chunks_format``, and
    the owning chat's ``icon`` is returned as ``avatar``.

    Returns:
        The session payload on success; an authentication error when the
        chat is not owned by the caller; a data error when the session is
        missing or belongs to another chat; or ``server_error_response`` on
        an unexpected exception.
    """
    # The ownership query result is kept so the chat's icon can be read from
    # it without a second identical query.
    owned_chats = await _ensure_owned_chat(chat_id)
    if not owned_chats:
        return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)

    try:
        ok, conv = await thread_pool_exec(ConversationService.get_by_id, session_id)
        if not ok:
            return get_data_error_result(message="Session not found!")
        if conv.dialog_id != chat_id:
            return get_data_error_result(message="Session does not belong to this chat!")

        avatar = owned_chats[0].icon

        # Tolerate a session whose reference field is empty or None.
        for ref in conv.reference or []:
            if isinstance(ref, list):
                continue
            ref["chunks"] = chunks_format(ref)

        result = _build_session_response(conv.to_dict())
        result["avatar"] = avatar
        return get_json_result(data=result)
    except Exception as ex:
        return server_error_response(ex)
@manager.route("/chats/<chat_id>/sessions", methods=["DELETE"])  # noqa: F821
@login_required
async def delete_sessions(chat_id):
    """Delete sessions of a chat owned by the current user.

    The URL rule declares ``<chat_id>`` so the router can supply the
    handler argument. The JSON body carries either ``ids`` (a list of
    session ids) or ``delete_all: true``. Before a session row is deleted,
    files referenced by its messages are removed from the
    ``"<user id>-downloads"`` storage bucket on a best-effort basis.

    Returns:
        ``{}`` when nothing was requested; ``True`` when every requested
        session was deleted; a partial-success payload or a data error when
        some ids failed; an authentication error when the chat is not
        owned; or ``server_error_response`` on an unexpected exception.
    """
    if not await _ensure_owned_chat(chat_id):
        return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)

    try:
        req = await get_request_json()
        if not req:
            return get_json_result(data={})

        session_ids = req.get("ids")
        if not session_ids:
            if req.get("delete_all") is not True:
                return get_json_result(data={})
            session_ids = [conv.id for conv in ConversationService.query(dialog_id=chat_id)]
            if not session_ids:
                return get_json_result(data={})

        unique_ids, duplicate_messages = check_duplicate_ids(session_ids, "session")
        errors = []
        success_count = 0
        bucket = f"{current_user.id}-downloads"

        for sid in unique_ids:
            if not ConversationService.query(id=sid, dialog_id=chat_id):
                errors.append(f"The chat doesn't own the session {sid}")
                continue

            ok, conv = ConversationService.get_by_id(sid)
            if ok:
                for msg in conv.message or []:
                    # Skip malformed entries instead of aborting the batch midway.
                    if not isinstance(msg, dict):
                        continue
                    for file in msg.get("files") or []:
                        file_id = file.get("id") if isinstance(file, dict) else None
                        if not file_id:
                            continue
                        try:
                            settings.STORAGE_IMPL.rm(bucket, file_id)
                        except Exception:
                            # Best effort: a leftover blob must not block the
                            # deletion, but keep the cause in the log.
                            logging.warning(
                                "Failed to delete chat upload blob %s/%s",
                                current_user.id,
                                file_id,
                                exc_info=True,
                            )

            ConversationService.delete_by_id(sid)
            success_count += 1

        all_errors = errors + duplicate_messages
        if not all_errors:
            return get_json_result(data=True)
        if success_count > 0:
            return get_json_result(
                data={"success_count": success_count, "errors": all_errors},
                message=f"Partially deleted {success_count} sessions with {len(all_errors)} errors",
            )
        return get_data_error_result(message="; ".join(all_errors))
    except Exception as ex:
        return server_error_response(ex)
@manager.route("/chats/<chat_id>/sessions/<session_id>/messages/<msg_id>", methods=["DELETE"])  # noqa: F821
@login_required
async def delete_session_message(chat_id, session_id, msg_id):
    """Delete a message pair from a session.

    The first message whose id equals `msg_id` and the message right after it
    (which must carry the same id) are removed, together with one entry of
    the session's `reference` list.

    Args:
        chat_id: Id of the chat (dialog) taken from the URL.
        session_id: Id of the session (conversation) taken from the URL.
        msg_id: Id shared by the two messages to delete.

    Returns:
        A JSON result with the updated session, or an error result.
    """
    if not await _ensure_owned_chat(chat_id):
        return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
    try:
        ok, conv = await thread_pool_exec(ConversationService.get_by_id, session_id)
        if not ok or conv.dialog_id != chat_id:
            return get_data_error_result(message="Session not found!")

        conv = conv.to_dict()
        messages = conv.get("message") or []
        for i, msg in enumerate(messages):
            if msg_id != msg.get("id", ""):
                continue
            # Explicit validation instead of `assert`: asserts vanish under -O
            # and a trailing unmatched message would otherwise raise IndexError.
            if i + 1 >= len(messages) or messages[i + 1].get("id") != msg_id:
                return get_data_error_result(message="Message pair not found!")
            del messages[i : i + 2]
            references = conv.get("reference")
            ref_index = max(0, i // 2 - 1)
            if references and ref_index < len(references):
                references.pop(ref_index)
            break

        await thread_pool_exec(ConversationService.update_by_id, conv["id"], conv)
        return get_json_result(data=_build_session_response(conv))
    except Exception as ex:
        return server_error_response(ex)


@manager.route("/chats/<chat_id>/sessions/<session_id>/messages/<msg_id>/feedback", methods=["PUT"])  # noqa: F821
@login_required
async def update_message_feedback(chat_id, session_id, msg_id):
    """Record thumb-up / thumb-down feedback on an assistant message.

    The request body must contain a boolean `thumbup` and may contain a
    textual `feedback`. When the thumb state actually changes, the feedback
    is also forwarded to `ChunkFeedbackService` for the reference entry
    associated with the message.

    Args:
        chat_id: Id of the chat (dialog) taken from the URL.
        session_id: Id of the session (conversation) taken from the URL.
        msg_id: Id of the assistant message receiving the feedback.

    Returns:
        A JSON result with the updated session, or an error result.
    """
    owned = await _ensure_owned_chat(chat_id)
    if not owned:
        return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR)
    try:
        req = await get_request_json()
        ok, conv = await thread_pool_exec(ConversationService.get_by_id, session_id)
        if not ok or conv.dialog_id != chat_id:
            return get_data_error_result(message="Session not found!")

        thumb_raw = req.get("thumbup")
        if not isinstance(thumb_raw, bool):
            return get_data_error_result(message="thumbup must be a boolean")
        feedback = req.get("feedback", "")

        conv_dict = conv.to_dict()
        message_index = None
        apply_chunk_feedback = False
        prior_thumb = None
        for i, msg in enumerate(conv_dict["message"]):
            if msg_id == msg.get("id", "") and msg.get("role", "") == "assistant":
                prior_thumb = msg.get("thumbup")
                if thumb_raw is True:
                    msg["thumbup"] = True
                    msg.pop("feedback", None)
                    apply_chunk_feedback = prior_thumb is not True
                else:
                    msg["thumbup"] = False
                    if feedback:
                        msg["feedback"] = feedback
                    apply_chunk_feedback = prior_thumb is not False
                message_index = i
                break

        if message_index is not None and apply_chunk_feedback:
            # Chunk feedback is best effort; failures are logged and do not
            # prevent the message feedback from being stored.
            try:
                references = conv_dict.get("reference") or []
                ref_index = (message_index - 1) // 2
                if 0 <= ref_index < len(references):
                    reference = references[ref_index]
                    if reference:
                        if isinstance(prior_thumb, bool) and prior_thumb != thumb_raw:
                            # The vote flipped: an extra call is issued with the
                            # inverse of the previous vote before the new one.
                            await thread_pool_exec(
                                ChunkFeedbackService.apply_feedback,
                                tenant_id=current_user.id,
                                reference=reference,
                                is_positive=not prior_thumb,
                            )
                        feedback_result = await thread_pool_exec(
                            ChunkFeedbackService.apply_feedback,
                            tenant_id=current_user.id,
                            reference=reference,
                            is_positive=thumb_raw is True,
                        )
                        logging.debug(
                            "Chunk feedback applied: %s succeeded, %s failed",
                            feedback_result["success_count"],
                            feedback_result["fail_count"],
                        )
            except Exception as e:
                logging.warning("Failed to apply chunk feedback: %s", e)

        await thread_pool_exec(ConversationService.update_by_id, conv_dict["id"], conv_dict)
        return get_json_result(data=_build_session_response(conv_dict))
    except Exception as ex:
        return server_error_response(ex)


@manager.route("/chat/audio/speech", methods=["POST"])  # noqa: F821
@login_required
async def tts():
    """Stream synthesized speech for the `text` field of the request body.

    The text is split on punctuation and line breaks and each segment is sent
    to the tenant's default TTS model; audio chunks are streamed back as
    `audio/mpeg`.

    Returns:
        A streaming response, or a data error when `text` is missing or the
        default TTS model cannot be resolved.
    """
    req = await get_request_json()
    text = (req or {}).get("text")
    if not isinstance(text, str) or not text.strip():
        # Previously a missing key escaped as an unhandled KeyError.
        return get_data_error_result(message="`text` is required.")

    try:
        default_tts_model_config = get_tenant_default_model_by_type(current_user.id, LLMType.TTS)
    except Exception as e:
        return get_data_error_result(message=str(e))
    tts_mdl = LLMBundle(current_user.id, default_tts_model_config)

    def stream_audio():
        try:
            for txt in re.split(r"[,。/《》?;:!\n\r:;]+", text):
                # Splitting leaves empty segments around leading/trailing
                # punctuation; there is nothing to synthesize for those.
                if not txt.strip():
                    continue
                for chunk in tts_mdl.tts(txt):
                    yield chunk
        except Exception as e:
            yield ("data:" + json.dumps({"code": 500, "message": str(e), "data": {"answer": "**ERROR**: " + str(e)}}, ensure_ascii=False)).encode("utf-8")

    resp = Response(stream_audio(), mimetype="audio/mpeg")
    resp.headers.add_header("Cache-Control", "no-cache")
    resp.headers.add_header("Connection", "keep-alive")
    resp.headers.add_header("X-Accel-Buffering", "no")
    return resp


@manager.route("/chat/audio/transcription", methods=["POST"])  # noqa: F821
@login_required
async def transcription():
    """Transcribe an uploaded audio file with the tenant's default ASR model.

    Expects multipart form-data with a `file` part and an optional `stream`
    field (`"true"` enables server-sent events).

    Returns:
        `{"text": ...}` as a JSON result in non-stream mode, an event stream
        in stream mode, or a data error for a missing file, an unsupported
        extension, or an unresolved ASR model.
    """
    req = await request.form
    stream_mode = req.get("stream", "false").lower() == "true"
    files = await request.files
    if "file" not in files:
        return get_data_error_result(message="Missing 'file' in multipart form-data")

    uploaded = files["file"]

    ALLOWED_EXTS = {
        ".wav",
        ".mp3",
        ".m4a",
        ".aac",
        ".flac",
        ".ogg",
        ".webm",
        ".opus",
        ".wma",
    }

    filename = uploaded.filename or ""
    suffix = os.path.splitext(filename)[-1].lower()
    if suffix not in ALLOWED_EXTS:
        return get_data_error_result(
            message=f"Unsupported audio format: {suffix}. Allowed: {', '.join(sorted(ALLOWED_EXTS))}"
        )

    # Resolve the model before touching the filesystem so that an early
    # return cannot leave a temporary file behind.
    try:
        default_asr_model_config = get_tenant_default_model_by_type(current_user.id, LLMType.SPEECH2TEXT)
    except Exception as e:
        return get_data_error_result(message=str(e))
    asr_mdl = LLMBundle(current_user.id, default_asr_model_config)

    def _remove_temp_audio(path):
        """Delete the temporary audio file, logging instead of raising on failure."""
        try:
            os.remove(path)
        except Exception as e:
            logging.error(f"Failed to remove temp audio file: {str(e)}")

    fd, temp_audio_path = tempfile.mkstemp(suffix=suffix)
    os.close(fd)
    try:
        await uploaded.save(temp_audio_path)
    except Exception:
        _remove_temp_audio(temp_audio_path)
        raise

    if not stream_mode:
        try:
            text = asr_mdl.transcription(temp_audio_path)
        finally:
            # Clean up even when the model call raises.
            _remove_temp_audio(temp_audio_path)
        return get_json_result(data={"text": text})

    async def event_stream():
        try:
            for evt in asr_mdl.stream_transcription(temp_audio_path):
                yield f"data: {json.dumps(evt, ensure_ascii=False)}\n\n"
        except Exception as e:
            err = {"event": "error", "text": str(e)}
            yield f"data: {json.dumps(err, ensure_ascii=False)}\n\n"
        finally:
            _remove_temp_audio(temp_audio_path)

    return Response(event_stream(), content_type="text/event-stream")


@manager.route("/chat/mindmap", methods=["POST"])  # noqa: F821
@login_required
@validate_request("question", "kb_ids")
async def mindmap():
    """Generate a mind map for a question over the given knowledge bases.

    The request's `kb_ids` are combined with the `kb_ids` of the optional
    search app referenced by `search_id`, and duplicates are removed.

    Returns:
        The mind map as a JSON result, or a server error when the generator
        reports an `error`.
    """
    req = await get_request_json()
    search_id = req.get("search_id", "")
    # An unknown search id must behave like "no search app" instead of
    # failing later on `None.get(...)`.
    search_app = (SearchService.get_detail(search_id) if search_id else None) or {}
    search_config = search_app.get("search_config") or {}
    kb_ids = search_config.get("kb_ids", [])
    kb_ids.extend(req["kb_ids"])
    kb_ids = list(set(kb_ids))

    mind_map = await gen_mindmap(req["question"], kb_ids, search_app.get("tenant_id", current_user.id), search_config)
    if "error" in mind_map:
        return server_error_response(Exception(mind_map["error"]))
    return get_json_result(data=mind_map)


@manager.route("/chat/recommendation", methods=["POST"])  # noqa: F821
@login_required
@validate_request("question")
async def recommendation():
    """Suggest related search terms for a question.

    The chat model comes from the search app's `chat_id` when configured,
    otherwise from the tenant default. Only answer lines that start with a
    single-digit numbered prefix (`"1. "`) are returned, with the prefix
    stripped.

    Returns:
        A JSON result holding the list of related search terms.
    """
    req = await get_request_json()
    search_id = req.get("search_id", "")
    search_config = {}
    if search_id:
        if search_app := SearchService.get_detail(search_id):
            search_config = search_app.get("search_config", {})

    question = req["question"]

    chat_id = search_config.get("chat_id", "")
    if chat_id:
        chat_model_config = get_model_config_from_provider_instance(current_user.id, LLMType.CHAT, chat_id)
    else:
        chat_model_config = get_tenant_default_model_by_type(current_user.id, LLMType.CHAT)
    chat_mdl = LLMBundle(current_user.id, chat_model_config)

    # Build a filtered copy so the stored search configuration is not mutated.
    llm_setting = search_config.get("llm_setting", {"temperature": 0.9})
    gen_conf = {k: v for k, v in llm_setting.items() if k != "parameter"}

    prompt = load_prompt("related_question")
    ans = await chat_mdl.async_chat(
        prompt,
        [
            {
                "role": "user",
                "content": f"\nKeywords: {question}\nRelated search terms:\n ",
            }
        ],
        gen_conf,
    )
    numbered = re.compile(r"^[0-9]\. ")
    return get_json_result(data=[numbered.sub("", a) for a in ans.split("\n") if numbered.match(a)])


@manager.route("/chat/completions", methods=["POST"])  # noqa: F821
@login_required
async def session_completion(chat_id_in_arg=""):
    """Handle chat completion requests, streaming or non-streaming, scoped to the authenticated user.

    The request may name a `chat_id` (falling back to `chat_id_in_arg`), a
    `session_id` / `conversation_id`, an `llm_id`, and generation parameters.
    With a chat id, the session is loaded, or created when no session id is
    given, and is persisted after the answer is produced. Without one, a
    default dialog is used and nothing is persisted.

    Args:
        chat_id_in_arg: Chat id used when the request body carries none.

    Returns:
        A server-sent-event response when `stream` is true (the default), or
        a JSON result with the answer; error results for validation and
        authorization failures.
    """
    req = await get_request_json()
    normalized, error = _normalize_completion_messages(req)
    if error:
        return error

    request_messages, request_msg = normalized
    pass_all_history_messages = _get_bool_request_flag(req, "pass_all_history_messages", "pass_all_history", default=False)
    msg = request_msg
    message_id = request_msg[-1].get("id")

    chat_id = req.pop("chat_id", "") or ""
    chat_id = chat_id or chat_id_in_arg
    session_id = req.pop("session_id", "") or req.pop("conversation_id", "") or ""
    chat_model_id = req.pop("llm_id", "")
    chat_model_config = pop_generation_config(req)

    try:
        conv = None
        if session_id and not chat_id:
            return get_data_error_result(message="`chat_id` is required when `session_id` is provided.")

        if chat_id:
            if not await _ensure_owned_chat(chat_id):
                return get_json_result(
                    data=False,
                    message="No authorization.",
                    code=RetCode.AUTHENTICATION_ERROR,
                )
            e, dia = await thread_pool_exec(DialogService.get_by_id, chat_id)
            if not e:
                return get_data_error_result(message="Chat not found!")

            if session_id:
                e, conv = await thread_pool_exec(ConversationService.get_by_id, session_id)
                if not e:
                    return get_data_error_result(message="Session not found!")
                if conv.dialog_id != chat_id:
                    return get_data_error_result(message="Session does not belong to this chat!")
            else:
                conv = await _create_session_for_completion(chat_id, dia, current_user.id)
                session_id = conv.id

            if pass_all_history_messages:
                # The caller supplies the whole history; it replaces the stored one.
                conv.message = deepcopy(request_messages)
                msg = request_msg
            else:
                if not conv.message:
                    conv.message = []
                conv.message.append(deepcopy(request_msg[-1]))
                # Model input: stored history without system messages and
                # without assistant messages that precede the first kept one.
                msg = []
                for m in conv.message:
                    if m["role"] == "system":
                        continue
                    if m["role"] == "assistant" and not msg:
                        continue
                    msg.append(m)
        else:
            dia = _build_default_completion_dialog()

        req.pop("messages", None)
        req.pop("question", None)

        if conv is not None:
            if not conv.reference:
                conv.reference = []
            # Drop empty placeholders, then reserve a slot for this turn.
            conv.reference = [r for r in conv.reference if r]
            conv.reference.append({"chunks": [], "doc_aggs": []})

        if chat_model_id:
            if not await thread_pool_exec(get_api_key, tenant_id=dia.tenant_id, model_name=chat_model_id):
                return get_data_error_result(message=f"Cannot use specified model {chat_model_id}.")
            dia.llm_id = chat_model_id
            dia.llm_setting = chat_model_config
        elif not dia.llm_id:
            logging.info("empty chat_model_id in req, use default chat model.")
            _, tenant_info = await thread_pool_exec(TenantService.get_by_id, dia.tenant_id)
            if not tenant_info or not tenant_info.llm_id:
                raise LookupError("No default chat model for tenant.")
            dia.llm_id = tenant_info.llm_id
        # NOTE(review): the merge runs whichever branch was taken above —
        # confirm against merge_generation_config's contract.
        merge_generation_config(dia, chat_model_config)

        stream_mode = req.pop("stream", True)

        def _format_answer(ans):
            """Wrap a raw answer dict with session and chat identifiers."""
            formatted = structure_answer(conv, ans, message_id, session_id)
            if chat_id:
                formatted["chat_id"] = chat_id
            return formatted

        async def stream():
            """Yield SSE-formatted chunks from the async chat generator."""
            try:
                async for ans in async_chat(dia, msg, True, session_id=session_id, **req):
                    ans = _format_answer(ans)
                    payload = _sanitize_json_floats({"code": 0, "message": "", "data": ans})
                    yield "data:" + json.dumps(payload, ensure_ascii=False) + "\n\n"
                if conv is not None:
                    await thread_pool_exec(ConversationService.update_by_id, conv.id, conv.to_dict())
            except Exception as ex:
                logging.exception(ex)
                yield "data:" + json.dumps({"code": 500, "message": str(ex), "data": {"answer": "**ERROR**: " + str(ex), "reference": []}}, ensure_ascii=False) + "\n\n"
            # Terminal event marking the end of the stream.
            yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"

        if stream_mode:
            # Set the content type once; adding it as an extra header on top
            # of `mimetype` would emit two Content-Type headers.
            resp = Response(stream(), content_type="text/event-stream; charset=utf-8")
            resp.headers.add_header("Cache-control", "no-cache")
            resp.headers.add_header("Connection", "keep-alive")
            resp.headers.add_header("X-Accel-Buffering", "no")
            return resp

        answer = None
        async for ans in async_chat(dia, msg, False, session_id=session_id, **req):
            answer = _format_answer(ans)
            if conv is not None:
                await thread_pool_exec(ConversationService.update_by_id, conv.id, conv.to_dict())
            # Only the first yielded item is used in non-stream mode.
            break
        return get_json_result(data=_sanitize_json_floats(answer))
    except Exception as ex:
        return server_error_response(ex)