diff --git a/api/apps/restful_apis/openai_api.py b/api/apps/restful_apis/openai_api.py index baa011f32a..f422795006 100644 --- a/api/apps/restful_apis/openai_api.py +++ b/api/apps/restful_apis/openai_api.py @@ -15,20 +15,25 @@ # import json +import logging import time +from uuid import uuid4 from quart import Response, jsonify from api.apps import current_user, login_required +from api.db.services.conversation_service import ConversationService, structure_answer from api.db.services.dialog_service import DialogService, async_chat from api.db.services.doc_metadata_service import DocMetadataService from api.db.services.tenant_llm_service import TenantLLMService from api.utils.api_utils import get_error_data_result, get_request_json, validate_request +from api.utils.reference_metadata_utils import enrich_chunks_with_document_metadata from common.constants import RetCode, StatusEnum from common.metadata_utils import convert_conditions, meta_filter from common.token_utils import num_tokens_from_string from rag.prompts.generator import chunks_format + def _validate_llm_id(llm_id, tenant_id, llm_setting=None): if not llm_id: return None @@ -48,9 +53,6 @@ def _validate_llm_id(llm_id, tenant_id, llm_setting=None): return None -import logging -from api.utils.reference_metadata_utils import enrich_chunks_with_document_metadata - def _build_reference_chunks(reference, include_metadata=False, metadata_fields=None): chunks = chunks_format(reference) if not include_metadata: @@ -90,6 +92,47 @@ def _build_sse_response(body): return resp +def _build_chat_messages(messages, conv=None): + if conv is not None: + if not conv.message: + conv.message = [] + user_message = dict(messages[-1]) + user_message.setdefault("id", str(uuid4())) + user_message.setdefault("created_at", time.time()) + conv.message.append(user_message) + source_messages = conv.message + else: + source_messages = messages + + chat_messages = [] + for message in source_messages: + if message["role"] == "system": + continue + if message["role"] == "assistant" and not chat_messages: + continue + chat_messages.append(message) + + message_id = chat_messages[-1].get("id") if chat_messages else None + return chat_messages, message_id + + +def _prepare_session(chat_id, session_id): + if not session_id: + return None + + ok, conv = ConversationService.get_by_id(session_id) + if not ok: + return "Session not found!" + if conv.dialog_id != chat_id: + return "Session does not belong to this chat!" + + if not conv.reference: + conv.reference = [] + conv.reference = [ref for ref in conv.reference if ref] + conv.reference.append({"chunks": [], "doc_aggs": []}) + return conv + + @manager.route("/openai//chat/completions", methods=["POST"]) # noqa: F821 @login_required @validate_request("model", "messages") @@ -109,6 +152,12 @@ async def openai_chat_completions(chat_id): if metadata_fields is not None and not isinstance(metadata_fields, list): return get_error_data_result("reference_metadata.fields must be an array.") + session_id = req.get("session_id", extra_body.get("session_id", "")) + if session_id is None: + session_id = "" + if not isinstance(session_id, str): + return get_error_data_result("session_id must be a string.") + messages = req.get("messages", []) if len(messages) < 1: return get_error_data_result("You have to provide messages.") @@ -125,6 +174,10 @@ async def openai_chat_completions(chat_id): return get_error_data_result(f"You don't own the chat {chat_id}") dia = dia[0] + conv = _prepare_session(chat_id, session_id) + if isinstance(conv, str): + return get_error_data_result(conv) + using_placeholder_model = requested_model == "model" if using_placeholder_model: requested_model = dia.llm_id or requested_model @@ -152,13 +205,7 @@ async def openai_chat_completions(chat_id): filtered_doc_ids = ["-999"] doc_ids_str = ",".join(filtered_doc_ids) if filtered_doc_ids else None - msg = [] - for message in messages: - if message["role"] == "system": - continue - if message["role"] == "assistant" and not msg: - continue - msg.append(message) + msg, message_id = _build_chat_messages(messages, conv) tools = None toolcall_session = None @@ -194,12 +241,16 @@ async def openai_chat_completions(chat_id): "system_fingerprint": "", "usage": None, } + if session_id: + response["session_id"] = session_id try: chat_kwargs = {"toolcall_session": toolcall_session, "tools": tools, "quote": need_reference} if doc_ids_str: chat_kwargs["doc_ids"] = doc_ids_str async for ans in async_chat(dia, msg, True, **chat_kwargs): + if conv is not None: + ans = structure_answer(conv, ans, message_id, session_id) last_ans = ans if ans.get("final"): if ans.get("answer"): @@ -228,6 +279,8 @@ async def openai_chat_completions(chat_id): response["choices"][0]["delta"]["content"] = delta response["choices"][0]["delta"]["reasoning_content"] = None yield f"data:{json.dumps(response, ensure_ascii=False)}\n\n" + if conv is not None: + ConversationService.update_by_id(conv.id, conv.to_dict()) except Exception as e: response["choices"][0]["delta"]["content"] = "**ERROR**: " + str(e) yield f"data:{json.dumps(response, ensure_ascii=False)}\n\n" @@ -248,7 +301,8 @@ async def openai_chat_completions(chat_id): include_metadata=include_reference_metadata, metadata_fields=metadata_fields, ) - response["choices"][0]["delta"]["final_content"] = final_answer if final_answer is not None else full_content + final_content = final_answer if final_answer is not None else full_content + response["choices"][0]["delta"]["final_content"] = final_content yield f"data:{json.dumps(response, ensure_ascii=False)}\n\n" yield "data:[DONE]\n\n" @@ -261,6 +315,9 @@ async def openai_chat_completions(chat_id): async for ans in async_chat(dia, msg, False, **chat_kwargs): answer = ans break + if conv is not None: + answer = structure_answer(conv, answer, message_id, session_id) + ConversationService.update_by_id(conv.id, conv.to_dict()) content = answer["answer"] response = { @@ -290,6 +347,8 @@ async def openai_chat_completions(chat_id): } ], } + if session_id: + response["session_id"] = session_id if need_reference: response["choices"][0]["message"]["reference"] = _build_reference_chunks( answer.get("reference", {}), diff --git a/test/testcases/restful_api/test_openai_compatible.py b/test/testcases/restful_api/test_openai_compatible.py index 49e2c55ca5..cd6fd0bf44 100644 --- a/test/testcases/restful_api/test_openai_compatible.py +++ b/test/testcases/restful_api/test_openai_compatible.py @@ -140,6 +140,60 @@ def test_openai_compatible_nonstream_shape(rest_client, create_chat): assert usage["total_tokens"] == usage["prompt_tokens"] + usage["completion_tokens"], usage +@pytest.mark.p2 +def test_openai_compatible_nonstream_supports_session_id(rest_client, create_chat): + chat_id = create_chat("restful_openai_session_chat") + create_session = rest_client.post(f"/chats/{chat_id}/sessions", json={"name": "openai_session"}) + assert create_session.status_code == 200 + create_session_payload = create_session.json() + assert create_session_payload["code"] == 0, create_session_payload + session_id = create_session_payload["data"]["id"] + + res = rest_client.post( + f"/openai/{chat_id}/chat/completions", + json={ + "model": "model", + "messages": [{"role": "user", "content": "remember this turn"}], + "stream": False, + "session_id": session_id, + }, + timeout=60, + ) + assert res.status_code == 200 + payload = res.json() + assert payload["session_id"] == session_id, payload + assert payload["object"] == "chat.completion", payload + + get_session = rest_client.get(f"/chats/{chat_id}/sessions/{session_id}") + assert get_session.status_code == 200 + session_payload = get_session.json() + assert session_payload["code"] == 0, session_payload + messages = session_payload["data"]["messages"] + assert any( + message["role"] == "user" and message["content"] == "remember this turn" + for message in messages + ), messages + assert any(message["role"] == "assistant" for message in messages), messages + + +@pytest.mark.p2 +def test_openai_compatible_invalid_session_id(rest_client, create_chat): + chat_id = create_chat("restful_openai_invalid_session_chat") + res = rest_client.post( + f"/openai/{chat_id}/chat/completions", + json={ + "model": "model", + "messages": [{"role": "user", "content": "hello"}], + "stream": False, + "session_id": "invalid_session", + }, + ) + assert res.status_code == 200 + payload = res.json() + assert payload["code"] != 0, payload + assert "Session not found!" in payload["message"], payload + + @pytest.mark.p2 def test_openai_compatible_nonstream_with_reference_output_shape(rest_client, create_chat): chat_id = create_chat("restful_openai_reference_chat")