mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 15:31:05 +08:00
Fix /chat/completions to allow send only the latest message (#15197)
### What problem does this PR solve? 1. Fix /chat/completions to send only the latest message 2. Allo chat stream=False ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
@@ -169,6 +169,73 @@ async def _create_session_for_completion(chat_id, dialog, user_id):
|
||||
return conv_obj
|
||||
|
||||
|
||||
def _get_bool_request_flag(req, *names, default=False):
|
||||
for name in names:
|
||||
if name not in req:
|
||||
continue
|
||||
value = req.pop(name)
|
||||
if isinstance(value, str):
|
||||
return value.strip().lower() in {"1", "true", "yes", "on"}
|
||||
return bool(value)
|
||||
return default
|
||||
|
||||
|
||||
def _normalize_completion_messages(req):
|
||||
messages = req.get("messages")
|
||||
if messages is None:
|
||||
question = req.get("question")
|
||||
if question is None:
|
||||
return None, get_data_error_result(
|
||||
code=RetCode.ARGUMENT_ERROR,
|
||||
message="required argument are missing: messages",
|
||||
)
|
||||
messages = [{"role": "user", "content": question}]
|
||||
if req.get("files"):
|
||||
messages[-1]["files"] = req["files"]
|
||||
|
||||
if not isinstance(messages, list) or not messages:
|
||||
return None, get_data_error_result(
|
||||
code=RetCode.ARGUMENT_ERROR,
|
||||
message="`messages` must be a non-empty list.",
|
||||
)
|
||||
|
||||
for message in messages:
|
||||
if not isinstance(message, dict):
|
||||
return None, get_data_error_result(
|
||||
code=RetCode.ARGUMENT_ERROR,
|
||||
message="Every item in `messages` must be an object.",
|
||||
)
|
||||
if "role" not in message or "content" not in message:
|
||||
return None, get_data_error_result(
|
||||
code=RetCode.ARGUMENT_ERROR,
|
||||
message="Every item in `messages` must include `role` and `content`.",
|
||||
)
|
||||
|
||||
msg = []
|
||||
for m in messages:
|
||||
if m["role"] == "system":
|
||||
continue
|
||||
if m["role"] == "assistant" and not msg:
|
||||
continue
|
||||
msg.append(m)
|
||||
|
||||
if not msg:
|
||||
return None, get_data_error_result(
|
||||
code=RetCode.ARGUMENT_ERROR,
|
||||
message="`messages` must contain a user message.",
|
||||
)
|
||||
if msg[-1]["role"] != "user":
|
||||
return None, get_data_error_result(
|
||||
code=RetCode.ARGUMENT_ERROR,
|
||||
message="The last message must be from user.",
|
||||
)
|
||||
if not msg[-1].get("id"):
|
||||
msg[-1]["id"] = get_uuid()
|
||||
|
||||
# till now, message and msg are sharing the same copy
|
||||
return (messages, msg), None
|
||||
|
||||
|
||||
async def _validate_llm_id(llm_id, tenant_id, llm_setting=None):
|
||||
if not llm_id:
|
||||
return None
|
||||
@@ -1057,18 +1124,16 @@ async def recommendation():
|
||||
|
||||
@manager.route("/chat/completions", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("messages")
|
||||
async def session_completion(chat_id_in_arg=""):
|
||||
"""Handle chat completion requests, streaming or non-streaming, scoped to the authenticated user."""
|
||||
req = await get_request_json()
|
||||
msg = []
|
||||
for m in req["messages"]:
|
||||
if m["role"] == "system":
|
||||
continue
|
||||
if m["role"] == "assistant" and not msg:
|
||||
continue
|
||||
msg.append(m)
|
||||
message_id = msg[-1].get("id") if msg else None
|
||||
normalized, error = _normalize_completion_messages(req)
|
||||
if error:
|
||||
return error
|
||||
request_messages, request_msg = normalized
|
||||
pass_all_history_messages = _get_bool_request_flag(req, "pass_all_history_messages", "pass_all_history", default=False)
|
||||
msg = request_msg
|
||||
message_id = request_msg[-1].get("id")
|
||||
chat_id = req.pop("chat_id", "") or ""
|
||||
chat_id = chat_id or chat_id_in_arg
|
||||
session_id = req.pop("session_id", "") or req.pop("conversation_id", "") or ""
|
||||
@@ -1104,12 +1169,27 @@ async def session_completion(chat_id_in_arg=""):
|
||||
else:
|
||||
conv = await _create_session_for_completion(chat_id, dia, current_user.id)
|
||||
session_id = conv.id
|
||||
conv.message = deepcopy(req["messages"])
|
||||
|
||||
if pass_all_history_messages:
|
||||
conv.message = deepcopy(request_messages)
|
||||
msg = request_msg
|
||||
else:
|
||||
if not conv.message:
|
||||
conv.message = []
|
||||
conv.message.append(deepcopy(request_msg[-1]))
|
||||
msg = []
|
||||
for m in conv.message:
|
||||
if m["role"] == "system":
|
||||
continue
|
||||
if m["role"] == "assistant" and not msg:
|
||||
continue
|
||||
msg.append(m)
|
||||
else:
|
||||
dia = _build_default_completion_dialog()
|
||||
dia.llm_setting = chat_model_config
|
||||
|
||||
del req["messages"]
|
||||
req.pop("messages", None)
|
||||
req.pop("question", None)
|
||||
|
||||
if conv is not None:
|
||||
if not conv.reference:
|
||||
@@ -1155,7 +1235,7 @@ async def session_completion(chat_id_in_arg=""):
|
||||
return resp
|
||||
|
||||
answer = None
|
||||
async for ans in async_chat(dia, msg, **req):
|
||||
async for ans in async_chat(dia, msg, False, **req):
|
||||
answer = _format_answer(ans)
|
||||
if conv is not None:
|
||||
await thread_pool_exec(ConversationService.update_by_id, conv.id, conv.to_dict())
|
||||
|
||||
@@ -4082,11 +4082,14 @@ The previous endpoint `POST /api/v1/chats/{chat_id}/completions` is deprecated.
|
||||
- `'content-Type: application/json'`
|
||||
- `'Authorization: Bearer <YOUR_API_KEY>'`
|
||||
- Body:
|
||||
|
||||
- `"messages"`: `list[object]`
|
||||
- `"question"`: `string`
|
||||
- `"stream"`: `boolean`
|
||||
- `"chat_id"`: `string` (optional)
|
||||
- `"session_id"`: `string` (optional)
|
||||
- `"llm_id"`: `string` (optional)
|
||||
- `"pass_all_history_messages"`: `boolean` (optional)
|
||||
|
||||
##### Request example
|
||||
|
||||
@@ -4117,10 +4120,6 @@ curl --request POST \
|
||||
"stream": true,
|
||||
"session_id":"9fa7691cb85c11ef9c5f0242ac120005",
|
||||
"messages": [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Hi! I'\''m your assistant. What can I do for you?"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Who are you?"
|
||||
@@ -4131,8 +4130,10 @@ curl --request POST \
|
||||
|
||||
##### Request Parameters
|
||||
|
||||
- `"messages"`: (*Body Parameter*), `list[object]`, *Required*
|
||||
The conversation messages sent to the model.
|
||||
- `"messages"`: (*Body Parameter*), `list[object]`
|
||||
The latest user message, or the conversation messages sent to the model when `pass_all_history_messages` is `true`. Either `messages` or `question` is required.
|
||||
- `"question"`: (*Body Parameter*), `string`
|
||||
Latest user question. This is equivalent to passing `messages: [{"role": "user", "content": question}]`.
|
||||
- `"stream"`: (*Body Parameter*), `boolean`
|
||||
Indicates whether to output responses in a streaming way:
|
||||
- `true`: Enable streaming (default).
|
||||
@@ -4143,6 +4144,8 @@ curl --request POST \
|
||||
Optional session ID. If `chat_id` is provided but `session_id` is omitted, a new session will be generated automatically.
|
||||
- `"llm_id"`: (*Body Parameter*), `string`
|
||||
Optional model override when a specific chat model should be used for this request.
|
||||
- `"pass_all_history_messages"`: (*Body Parameter*), `boolean`
|
||||
When `chat_id` and `session_id` are provided, defaults to `false`, so the server uses stored session history and only the latest user message from the request. Set to `true` to replace/use the submitted full `messages` history, and overrides the stored session history.
|
||||
|
||||
#### Response
|
||||
|
||||
|
||||
@@ -2039,6 +2039,7 @@ def _load_chat_api_module(monkeypatch):
|
||||
|
||||
class _RetCode(int, Enum):
|
||||
SUCCESS = 0
|
||||
ARGUMENT_ERROR = 101
|
||||
DATA_ERROR = 102
|
||||
AUTHENTICATION_ERROR = 109
|
||||
SERVER_ERROR = 500
|
||||
@@ -2258,3 +2259,183 @@ def test_session_completion_user_id_not_spoofable(monkeypatch):
|
||||
_run(inspect.unwrap(module.session_completion)())
|
||||
|
||||
assert captured_user_ids == [module.current_user.id]
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_session_completion_uses_server_history_by_default(monkeypatch):
|
||||
"""Session chat completions should append only the latest user message to stored history by default."""
|
||||
module = _load_chat_api_module(monkeypatch)
|
||||
|
||||
captured_messages = []
|
||||
conv = SimpleNamespace(
|
||||
id="session-1",
|
||||
dialog_id="chat-1",
|
||||
message=[
|
||||
{"role": "assistant", "content": "prologue"},
|
||||
{"role": "user", "content": "server old question", "id": "old-user"},
|
||||
{"role": "assistant", "content": "server old answer", "id": "old-user"},
|
||||
],
|
||||
reference=[],
|
||||
user_id="authenticated-user",
|
||||
name="test",
|
||||
)
|
||||
conv.to_dict = lambda: {
|
||||
"id": conv.id,
|
||||
"dialog_id": conv.dialog_id,
|
||||
"message": conv.message,
|
||||
"reference": conv.reference,
|
||||
"user_id": conv.user_id,
|
||||
"name": conv.name,
|
||||
}
|
||||
|
||||
async def _fake_async_chat(_dia, messages, stream=True, **_kwargs):
|
||||
captured_messages.append([dict(message) for message in messages])
|
||||
yield {"answer": "ok", "reference": {}}
|
||||
|
||||
monkeypatch.setattr(module.ConversationService, "get_by_id", lambda _id: (True, conv))
|
||||
monkeypatch.setattr(module.ConversationService, "update_by_id", lambda *_a, **_k: True, raising=False)
|
||||
monkeypatch.setattr(module, "async_chat", _fake_async_chat)
|
||||
monkeypatch.setattr(module, "structure_answer", lambda _conv, ans, _message_id, _session_id: ans)
|
||||
monkeypatch.setattr(
|
||||
module,
|
||||
"get_request_json",
|
||||
lambda: _AwaitableValue({
|
||||
"chat_id": "chat-1",
|
||||
"session_id": "session-1",
|
||||
"stream": False,
|
||||
"messages": [
|
||||
{"role": "user", "content": "client old question", "id": "client-old"},
|
||||
{"role": "assistant", "content": "client old answer", "id": "client-old"},
|
||||
{"role": "user", "content": "latest question", "id": "latest"},
|
||||
],
|
||||
}),
|
||||
)
|
||||
|
||||
res = _run(inspect.unwrap(module.session_completion)())
|
||||
|
||||
assert res["code"] == 0, res
|
||||
assert [message["content"] for message in captured_messages[0]] == [
|
||||
"server old question",
|
||||
"server old answer",
|
||||
"latest question",
|
||||
]
|
||||
assert [message["content"] for message in conv.message] == [
|
||||
"prologue",
|
||||
"server old question",
|
||||
"server old answer",
|
||||
"latest question",
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_session_completion_can_use_submitted_full_history(monkeypatch):
|
||||
"""The UI opt-in flag should preserve the previous full-history request behavior."""
|
||||
module = _load_chat_api_module(monkeypatch)
|
||||
|
||||
captured_messages = []
|
||||
conv = SimpleNamespace(
|
||||
id="session-1",
|
||||
dialog_id="chat-1",
|
||||
message=[
|
||||
{"role": "assistant", "content": "prologue"},
|
||||
{"role": "user", "content": "server old question", "id": "old-user"},
|
||||
],
|
||||
reference=[],
|
||||
user_id="authenticated-user",
|
||||
name="test",
|
||||
)
|
||||
conv.to_dict = lambda: {
|
||||
"id": conv.id,
|
||||
"dialog_id": conv.dialog_id,
|
||||
"message": conv.message,
|
||||
"reference": conv.reference,
|
||||
"user_id": conv.user_id,
|
||||
"name": conv.name,
|
||||
}
|
||||
|
||||
async def _fake_async_chat(_dia, messages, stream=True, **_kwargs):
|
||||
captured_messages.append([dict(message) for message in messages])
|
||||
yield {"answer": "ok", "reference": {}}
|
||||
|
||||
monkeypatch.setattr(module.ConversationService, "get_by_id", lambda _id: (True, conv))
|
||||
monkeypatch.setattr(module.ConversationService, "update_by_id", lambda *_a, **_k: True, raising=False)
|
||||
monkeypatch.setattr(module, "async_chat", _fake_async_chat)
|
||||
monkeypatch.setattr(module, "structure_answer", lambda _conv, ans, _message_id, _session_id: ans)
|
||||
monkeypatch.setattr(
|
||||
module,
|
||||
"get_request_json",
|
||||
lambda: _AwaitableValue({
|
||||
"chat_id": "chat-1",
|
||||
"session_id": "session-1",
|
||||
"stream": False,
|
||||
"pass_all_history_messages": True,
|
||||
"messages": [
|
||||
{"role": "user", "content": "client old question", "id": "client-old"},
|
||||
{"role": "assistant", "content": "client old answer", "id": "client-old"},
|
||||
{"role": "user", "content": "latest question", "id": "latest"},
|
||||
],
|
||||
}),
|
||||
)
|
||||
|
||||
res = _run(inspect.unwrap(module.session_completion)())
|
||||
|
||||
assert res["code"] == 0, res
|
||||
assert [message["content"] for message in captured_messages[0]] == [
|
||||
"client old question",
|
||||
"client old answer",
|
||||
"latest question",
|
||||
]
|
||||
assert [message["content"] for message in conv.message] == [
|
||||
"client old question",
|
||||
"client old answer",
|
||||
"latest question",
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_session_completion_accepts_question_payload(monkeypatch):
|
||||
"""Compatibility calls from /chats/{chat_id}/completions may send only `question`."""
|
||||
module = _load_chat_api_module(monkeypatch)
|
||||
|
||||
captured_messages = []
|
||||
conv = SimpleNamespace(
|
||||
id="session-1",
|
||||
dialog_id="chat-1",
|
||||
message=[{"role": "assistant", "content": "prologue"}],
|
||||
reference=[],
|
||||
user_id="authenticated-user",
|
||||
name="test",
|
||||
)
|
||||
conv.to_dict = lambda: {
|
||||
"id": conv.id,
|
||||
"dialog_id": conv.dialog_id,
|
||||
"message": conv.message,
|
||||
"reference": conv.reference,
|
||||
"user_id": conv.user_id,
|
||||
"name": conv.name,
|
||||
}
|
||||
|
||||
async def _fake_async_chat(_dia, messages, stream=True, **_kwargs):
|
||||
captured_messages.append([dict(message) for message in messages])
|
||||
yield {"answer": "ok", "reference": {}}
|
||||
|
||||
monkeypatch.setattr(module.ConversationService, "get_by_id", lambda _id: (True, conv))
|
||||
monkeypatch.setattr(module.ConversationService, "update_by_id", lambda *_a, **_k: True, raising=False)
|
||||
monkeypatch.setattr(module, "async_chat", _fake_async_chat)
|
||||
monkeypatch.setattr(module, "structure_answer", lambda _conv, ans, _message_id, _session_id: ans)
|
||||
monkeypatch.setattr(
|
||||
module,
|
||||
"get_request_json",
|
||||
lambda: _AwaitableValue({
|
||||
"chat_id": "chat-1",
|
||||
"session_id": "session-1",
|
||||
"stream": False,
|
||||
"question": "latest question",
|
||||
}),
|
||||
)
|
||||
|
||||
res = _run(inspect.unwrap(module.session_completion)())
|
||||
|
||||
assert res["code"] == 0, res
|
||||
assert [message["content"] for message in captured_messages[0]] == ["latest question"]
|
||||
assert conv.message[-1]["content"] == "latest question"
|
||||
|
||||
@@ -108,6 +108,7 @@ export const useSendMessage = (controller: AbortController) => {
|
||||
: (derivedMessages ?? [])),
|
||||
message,
|
||||
],
|
||||
pass_all_history_messages: true,
|
||||
reasoning: enableThinking,
|
||||
internet: enableInternet,
|
||||
},
|
||||
|
||||
@@ -80,6 +80,7 @@ export function useSendSingleMessage({
|
||||
reasoning: enableThinking,
|
||||
internet: enableInternet,
|
||||
...params,
|
||||
pass_all_history_messages: true,
|
||||
},
|
||||
controller,
|
||||
);
|
||||
|
||||
Reference in New Issue
Block a user