mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
Revert "Fix: /openai/<chat_id>/chat/completions not aware of session_id" (#15205)
Reverts infiniflow/ragflow#15155 because this is never supported, keep it as it is.
This commit is contained in:
@@ -15,25 +15,20 @@
|
||||
#
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from uuid import uuid4
|
||||
|
||||
from quart import Response, jsonify
|
||||
|
||||
from api.apps import current_user, login_required
|
||||
from api.db.services.conversation_service import ConversationService, structure_answer
|
||||
from api.db.services.dialog_service import DialogService, async_chat
|
||||
from api.db.services.doc_metadata_service import DocMetadataService
|
||||
from api.db.services.tenant_llm_service import TenantLLMService
|
||||
from api.utils.api_utils import get_error_data_result, get_request_json, validate_request
|
||||
from api.utils.reference_metadata_utils import enrich_chunks_with_document_metadata
|
||||
from common.constants import RetCode, StatusEnum
|
||||
from common.metadata_utils import convert_conditions, meta_filter
|
||||
from common.token_utils import num_tokens_from_string
|
||||
from rag.prompts.generator import chunks_format
|
||||
|
||||
|
||||
def _validate_llm_id(llm_id, tenant_id, llm_setting=None):
|
||||
if not llm_id:
|
||||
return None
|
||||
@@ -53,6 +48,9 @@ def _validate_llm_id(llm_id, tenant_id, llm_setting=None):
|
||||
return None
|
||||
|
||||
|
||||
import logging
|
||||
from api.utils.reference_metadata_utils import enrich_chunks_with_document_metadata
|
||||
|
||||
def _build_reference_chunks(reference, include_metadata=False, metadata_fields=None):
|
||||
chunks = chunks_format(reference)
|
||||
if not include_metadata:
|
||||
@@ -92,47 +90,6 @@ def _build_sse_response(body):
|
||||
return resp
|
||||
|
||||
|
||||
def _build_chat_messages(messages, conv=None):
|
||||
if conv is not None:
|
||||
if not conv.message:
|
||||
conv.message = []
|
||||
user_message = dict(messages[-1])
|
||||
user_message.setdefault("id", str(uuid4()))
|
||||
user_message.setdefault("created_at", time.time())
|
||||
conv.message.append(user_message)
|
||||
source_messages = conv.message
|
||||
else:
|
||||
source_messages = messages
|
||||
|
||||
chat_messages = []
|
||||
for message in source_messages:
|
||||
if message["role"] == "system":
|
||||
continue
|
||||
if message["role"] == "assistant" and not chat_messages:
|
||||
continue
|
||||
chat_messages.append(message)
|
||||
|
||||
message_id = chat_messages[-1].get("id") if chat_messages else None
|
||||
return chat_messages, message_id
|
||||
|
||||
|
||||
def _prepare_session(chat_id, session_id):
|
||||
if not session_id:
|
||||
return None
|
||||
|
||||
ok, conv = ConversationService.get_by_id(session_id)
|
||||
if not ok:
|
||||
return "Session not found!"
|
||||
if conv.dialog_id != chat_id:
|
||||
return "Session does not belong to this chat!"
|
||||
|
||||
if not conv.reference:
|
||||
conv.reference = []
|
||||
conv.reference = [ref for ref in conv.reference if ref]
|
||||
conv.reference.append({"chunks": [], "doc_aggs": []})
|
||||
return conv
|
||||
|
||||
|
||||
@manager.route("/openai/<chat_id>/chat/completions", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("model", "messages")
|
||||
@@ -152,12 +109,6 @@ async def openai_chat_completions(chat_id):
|
||||
if metadata_fields is not None and not isinstance(metadata_fields, list):
|
||||
return get_error_data_result("reference_metadata.fields must be an array.")
|
||||
|
||||
session_id = req.get("session_id", extra_body.get("session_id", ""))
|
||||
if session_id is None:
|
||||
session_id = ""
|
||||
if not isinstance(session_id, str):
|
||||
return get_error_data_result("session_id must be a string.")
|
||||
|
||||
messages = req.get("messages", [])
|
||||
if len(messages) < 1:
|
||||
return get_error_data_result("You have to provide messages.")
|
||||
@@ -174,10 +125,6 @@ async def openai_chat_completions(chat_id):
|
||||
return get_error_data_result(f"You don't own the chat {chat_id}")
|
||||
dia = dia[0]
|
||||
|
||||
conv = _prepare_session(chat_id, session_id)
|
||||
if isinstance(conv, str):
|
||||
return get_error_data_result(conv)
|
||||
|
||||
using_placeholder_model = requested_model == "model"
|
||||
if using_placeholder_model:
|
||||
requested_model = dia.llm_id or requested_model
|
||||
@@ -205,7 +152,13 @@ async def openai_chat_completions(chat_id):
|
||||
filtered_doc_ids = ["-999"]
|
||||
doc_ids_str = ",".join(filtered_doc_ids) if filtered_doc_ids else None
|
||||
|
||||
msg, message_id = _build_chat_messages(messages, conv)
|
||||
msg = []
|
||||
for message in messages:
|
||||
if message["role"] == "system":
|
||||
continue
|
||||
if message["role"] == "assistant" and not msg:
|
||||
continue
|
||||
msg.append(message)
|
||||
|
||||
tools = None
|
||||
toolcall_session = None
|
||||
@@ -241,16 +194,12 @@ async def openai_chat_completions(chat_id):
|
||||
"system_fingerprint": "",
|
||||
"usage": None,
|
||||
}
|
||||
if session_id:
|
||||
response["session_id"] = session_id
|
||||
|
||||
try:
|
||||
chat_kwargs = {"toolcall_session": toolcall_session, "tools": tools, "quote": need_reference}
|
||||
if doc_ids_str:
|
||||
chat_kwargs["doc_ids"] = doc_ids_str
|
||||
async for ans in async_chat(dia, msg, True, **chat_kwargs):
|
||||
if conv is not None:
|
||||
ans = structure_answer(conv, ans, message_id, session_id)
|
||||
last_ans = ans
|
||||
if ans.get("final"):
|
||||
if ans.get("answer"):
|
||||
@@ -279,8 +228,6 @@ async def openai_chat_completions(chat_id):
|
||||
response["choices"][0]["delta"]["content"] = delta
|
||||
response["choices"][0]["delta"]["reasoning_content"] = None
|
||||
yield f"data:{json.dumps(response, ensure_ascii=False)}\n\n"
|
||||
if conv is not None:
|
||||
ConversationService.update_by_id(conv.id, conv.to_dict())
|
||||
except Exception as e:
|
||||
response["choices"][0]["delta"]["content"] = "**ERROR**: " + str(e)
|
||||
yield f"data:{json.dumps(response, ensure_ascii=False)}\n\n"
|
||||
@@ -301,8 +248,7 @@ async def openai_chat_completions(chat_id):
|
||||
include_metadata=include_reference_metadata,
|
||||
metadata_fields=metadata_fields,
|
||||
)
|
||||
final_content = final_answer if final_answer is not None else full_content
|
||||
response["choices"][0]["delta"]["final_content"] = final_content
|
||||
response["choices"][0]["delta"]["final_content"] = final_answer if final_answer is not None else full_content
|
||||
yield f"data:{json.dumps(response, ensure_ascii=False)}\n\n"
|
||||
yield "data:[DONE]\n\n"
|
||||
|
||||
@@ -315,9 +261,6 @@ async def openai_chat_completions(chat_id):
|
||||
async for ans in async_chat(dia, msg, False, **chat_kwargs):
|
||||
answer = ans
|
||||
break
|
||||
if conv is not None:
|
||||
answer = structure_answer(conv, answer, message_id, session_id)
|
||||
ConversationService.update_by_id(conv.id, conv.to_dict())
|
||||
|
||||
content = answer["answer"]
|
||||
response = {
|
||||
@@ -347,8 +290,6 @@ async def openai_chat_completions(chat_id):
|
||||
}
|
||||
],
|
||||
}
|
||||
if session_id:
|
||||
response["session_id"] = session_id
|
||||
if need_reference:
|
||||
response["choices"][0]["message"]["reference"] = _build_reference_chunks(
|
||||
answer.get("reference", {}),
|
||||
|
||||
@@ -140,60 +140,6 @@ def test_openai_compatible_nonstream_shape(rest_client, create_chat):
|
||||
assert usage["total_tokens"] == usage["prompt_tokens"] + usage["completion_tokens"], usage
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_openai_compatible_nonstream_supports_session_id(rest_client, create_chat):
|
||||
chat_id = create_chat("restful_openai_session_chat")
|
||||
create_session = rest_client.post(f"/chats/{chat_id}/sessions", json={"name": "openai_session"})
|
||||
assert create_session.status_code == 200
|
||||
create_session_payload = create_session.json()
|
||||
assert create_session_payload["code"] == 0, create_session_payload
|
||||
session_id = create_session_payload["data"]["id"]
|
||||
|
||||
res = rest_client.post(
|
||||
f"/openai/{chat_id}/chat/completions",
|
||||
json={
|
||||
"model": "model",
|
||||
"messages": [{"role": "user", "content": "remember this turn"}],
|
||||
"stream": False,
|
||||
"session_id": session_id,
|
||||
},
|
||||
timeout=60,
|
||||
)
|
||||
assert res.status_code == 200
|
||||
payload = res.json()
|
||||
assert payload["session_id"] == session_id, payload
|
||||
assert payload["object"] == "chat.completion", payload
|
||||
|
||||
get_session = rest_client.get(f"/chats/{chat_id}/sessions/{session_id}")
|
||||
assert get_session.status_code == 200
|
||||
session_payload = get_session.json()
|
||||
assert session_payload["code"] == 0, session_payload
|
||||
messages = session_payload["data"]["messages"]
|
||||
assert any(
|
||||
message["role"] == "user" and message["content"] == "remember this turn"
|
||||
for message in messages
|
||||
), messages
|
||||
assert any(message["role"] == "assistant" for message in messages), messages
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_openai_compatible_invalid_session_id(rest_client, create_chat):
|
||||
chat_id = create_chat("restful_openai_invalid_session_chat")
|
||||
res = rest_client.post(
|
||||
f"/openai/{chat_id}/chat/completions",
|
||||
json={
|
||||
"model": "model",
|
||||
"messages": [{"role": "user", "content": "hello"}],
|
||||
"stream": False,
|
||||
"session_id": "invalid_session",
|
||||
},
|
||||
)
|
||||
assert res.status_code == 200
|
||||
payload = res.json()
|
||||
assert payload["code"] != 0, payload
|
||||
assert "Session not found!" in payload["message"], payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_openai_compatible_nonstream_with_reference_output_shape(rest_client, create_chat):
|
||||
chat_id = create_chat("restful_openai_reference_chat")
|
||||
|
||||
Reference in New Issue
Block a user