From a75ea7ba7c1b0bd4bc76dd3cb7db96810258b68c Mon Sep 17 00:00:00 2001 From: bitloi <89318445+bitloi@users.noreply.github.com> Date: Wed, 3 Jun 2026 00:46:10 -0300 Subject: [PATCH] Fix: Chat completion generation parameter overrides (#15389) ### What problem does this PR solve? Closes #15388. Chat completion routes did not reliably honor per-request generation settings: - `/api/v1/chat/completions` copied generation settings with a truthiness check, so valid zero values such as `temperature: 0`, `top_p: 0`, `frequency_penalty: 0`, `presence_penalty: 0`, and `max_tokens: 0` were dropped. - `/api/v1/openai/{chat_id}/chat/completions` did not forward standard generation settings into the request-specific dialog LLM settings before calling `async_chat`. This PR preserves explicitly supplied generation parameters, including zero values, and merges request-level overrides into existing dialog settings where appropriate. The supported generation parameter keys and merge behavior live in a shared REST API helper to keep both completion routes aligned. Validation: - `git diff --check` - `python3 -m py_compile api/apps/restful_apis/_generation_params.py api/apps/restful_apis/chat_api.py api/apps/restful_apis/openai_api.py test/testcases/test_http_api/test_session_management/test_session_sdk_routes_unit.py` - `uv run ruff check api/apps/restful_apis/_generation_params.py api/apps/restful_apis/chat_api.py api/apps/restful_apis/openai_api.py test/testcases/test_http_api/test_session_management/test_session_sdk_routes_unit.py` - `ZHIPU_AI_API_KEY=dummy uv run pytest test/testcases/test_http_api/test_session_management/test_session_sdk_routes_unit.py -q -k generation_params` ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --- api/apps/restful_apis/_generation_params.py | 38 ++++ api/apps/restful_apis/chat_api.py | 9 +- api/apps/restful_apis/openai_api.py | 3 + .../test_session_sdk_routes_unit.py | 183 +++++++++++++++++- 4 files changed, 226 insertions(+), 7 deletions(-) create mode 100644 api/apps/restful_apis/_generation_params.py diff --git a/api/apps/restful_apis/_generation_params.py b/api/apps/restful_apis/_generation_params.py new file mode 100644 index 0000000000..e5fa79c2bf --- /dev/null +++ b/api/apps/restful_apis/_generation_params.py @@ -0,0 +1,38 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from copy import deepcopy + +GENERATION_CONFIG_KEYS = ("temperature", "top_p", "frequency_penalty", "presence_penalty", "max_tokens") + + +def extract_generation_config(req): + return {key: req[key] for key in GENERATION_CONFIG_KEYS if key in req and req[key] is not None} + + +def pop_generation_config(req): + generation_config = extract_generation_config(req) + for key in GENERATION_CONFIG_KEYS: + req.pop(key, None) + return generation_config + + +def merge_generation_config(dialog, generation_config): + if not generation_config: + return + llm_setting = deepcopy(getattr(dialog, "llm_setting", None) or {}) + llm_setting.update(generation_config) + dialog.llm_setting = llm_setting diff --git a/api/apps/restful_apis/chat_api.py b/api/apps/restful_apis/chat_api.py index d0ea5f6280..5d7966aa42 100644 --- a/api/apps/restful_apis/chat_api.py +++ b/api/apps/restful_apis/chat_api.py @@ -26,6 +26,7 @@ from types import SimpleNamespace from quart import Response, request from api.apps import current_user, login_required +from api.apps.restful_apis._generation_params import merge_generation_config, pop_generation_config from api.db.joint_services.tenant_model_service import ( get_tenant_default_model_by_type, get_model_config_from_provider_instance, get_api_key, split_model_name ) @@ -1173,11 +1174,7 @@ async def session_completion(chat_id_in_arg=""): session_id = req.pop("session_id", "") or req.pop("conversation_id", "") or "" chat_model_id = req.pop("llm_id", "") - chat_model_config = {} - for model_config in ["temperature", "top_p", "frequency_penalty", "presence_penalty", "max_tokens"]: - config = req.get(model_config) - if config: - chat_model_config[model_config] = config + chat_model_config = pop_generation_config(req) try: conv = None @@ -1220,7 +1217,6 @@ async def session_completion(chat_id_in_arg=""): msg.append(m) else: dia = _build_default_completion_dialog() - dia.llm_setting = chat_model_config req.pop("messages", None) req.pop("question", None) @@ -1242,6 +1238,7 @@ async def session_completion(chat_id_in_arg=""): if not tenant_info or not tenant_info.llm_id: raise LookupError("No default chat model for tenant.") dia.llm_id = tenant_info.llm_id + merge_generation_config(dia, chat_model_config) stream_mode = req.pop("stream", True) diff --git a/api/apps/restful_apis/openai_api.py b/api/apps/restful_apis/openai_api.py index 7db6625871..0642dd04ca 100644 --- a/api/apps/restful_apis/openai_api.py +++ b/api/apps/restful_apis/openai_api.py @@ -20,6 +20,7 @@ import time from quart import Response, jsonify from api.apps import current_user, login_required +from api.apps.restful_apis._generation_params import extract_generation_config, merge_generation_config from api.db.services.dialog_service import DialogService, async_chat from api.db.services.doc_metadata_service import DocMetadataService from api.db.joint_services.tenant_model_service import get_model_config_from_provider_instance, get_api_key @@ -29,6 +30,7 @@ from common.metadata_utils import convert_conditions, meta_filter from common.token_utils import num_tokens_from_string from rag.prompts.generator import chunks_format + def _validate_llm_id(llm_id, tenant_id, llm_setting=None): if not llm_id: return None @@ -278,6 +280,7 @@ async def openai_chat_completions(chat_id): dia.llm_id = requested_model if not get_api_key(tenant_id=dia.tenant_id, model_name=requested_model): return get_error_data_result(message=f"Cannot use specified model {requested_model}.") + merge_generation_config(dia, extract_generation_config(req)) metadata_condition = extra_body.get("metadata_condition") or {} if metadata_condition and not isinstance(metadata_condition, dict): diff --git a/test/testcases/test_http_api/test_session_management/test_session_sdk_routes_unit.py b/test/testcases/test_http_api/test_session_management/test_session_sdk_routes_unit.py index 2fa7f824c4..6c1aa4b961 100644 --- a/test/testcases/test_http_api/test_session_management/test_session_sdk_routes_unit.py +++ b/test/testcases/test_http_api/test_session_management/test_session_sdk_routes_unit.py @@ -486,6 +486,19 @@ def _load_session_module(monkeypatch): raise Exception("Model Name is required") return _MockModelConfig2(tenant_id, model_name, model_type).to_dict() + def _get_api_key(tenant_id: str, model_name: str): + if not tenant_id or not model_name: + return None + return "fake-api-key" + + def _split_model_name(model_name: str): + parts = model_name.split("@") + if len(parts) == 1: + return parts[0], "", "" + if len(parts) == 2: + return parts[0], "default", parts[1] + return parts[0], parts[1], parts[2] + def _get_tenant_default_model_by_type(tenant_id: str, model_type): # Check if tenant exists from api.db.services.tenant_llm_service import TenantService @@ -527,6 +540,8 @@ def _load_session_module(monkeypatch): tenant_model_service_mod.get_model_config_by_id = _get_model_config_by_id tenant_model_service_mod.get_model_config_from_provider_instance = _get_model_config_from_provider_instance tenant_model_service_mod.get_tenant_default_model_by_type = _get_tenant_default_model_by_type + tenant_model_service_mod.get_api_key = _get_api_key + tenant_model_service_mod.split_model_name = _split_model_name monkeypatch.setitem(sys.modules, "api.db.joint_services.tenant_model_service", tenant_model_service_mod) agent_pkg = ModuleType("agent") @@ -1152,6 +1167,58 @@ def test_openai_invalid_message_content_type_unit(monkeypatch): assert "messages[].content must be a string or an array of content parts." in res["message"] +@pytest.mark.p2 +def test_openai_nonstream_forwards_generation_params_unit(monkeypatch): + module = _load_openai_api_module(monkeypatch) + + base_llm_setting = {"temperature": 0.7, "model_type": "chat"} + dia = SimpleNamespace( + kb_ids=[], + llm_id="chat-model", + tenant_id="tenant-1", + llm_setting=base_llm_setting, + ) + captured = {} + + monkeypatch.setattr(module, "num_tokens_from_string", lambda text: len(text or "")) + monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: [dia]) + + async def fake_async_chat(captured_dia, _msg, _stream, **_kwargs): + captured["llm_setting"] = dict(captured_dia.llm_setting) + yield {"answer": "world", "reference": {}} + + monkeypatch.setattr(module, "async_chat", fake_async_chat) + monkeypatch.setattr( + module, + "get_request_json", + lambda: _AwaitableValue( + { + "model": "model", + "messages": [{"role": "user", "content": "hello"}], + "stream": False, + "temperature": 0, + "top_p": 0, + "frequency_penalty": 0, + "presence_penalty": 0, + "max_tokens": 0, + } + ), + ) + + res = _run(inspect.unwrap(module.openai_chat_completions)("chat-1")) + + assert res["choices"][0]["message"]["content"] == "world" + assert captured["llm_setting"] == { + "temperature": 0, + "model_type": "chat", + "top_p": 0, + "frequency_penalty": 0, + "presence_penalty": 0, + "max_tokens": 0, + } + assert base_llm_setting == {"temperature": 0.7, "model_type": "chat"} + + @pytest.mark.p2 def test_agents_openai_compatibility_unit(monkeypatch): module = _load_agent_api_module(monkeypatch) @@ -2201,6 +2268,9 @@ def _load_chat_api_module(monkeypatch): tenant_model_svc = ModuleType("api.db.joint_services.tenant_model_service") tenant_model_svc.get_tenant_default_model_by_type = lambda *_a, **_k: {} + tenant_model_svc.get_model_config_from_provider_instance = lambda **_k: {} + tenant_model_svc.get_api_key = lambda **_k: "fake-api-key" + tenant_model_svc.split_model_name = lambda model_name: (model_name, "", "") monkeypatch.setitem(sys.modules, "api.db.joint_services.tenant_model_service", tenant_model_svc) chunk_feedback_mod = ModuleType("api.db.services.chunk_feedback_service") @@ -2274,7 +2344,7 @@ def _load_chat_api_module(monkeypatch): user_svc_mod = ModuleType("api.db.services.user_service") user_svc_mod.TenantService = SimpleNamespace( - get_by_id=lambda _id: (True, SimpleNamespace(id=_id)), + get_by_id=lambda _id: (True, SimpleNamespace(id=_id, llm_id="chat-model")), get_joined_tenants_by_user_id=lambda _id: [], ) user_svc_mod.UserTenantService = SimpleNamespace(query=lambda **_k: []) @@ -2427,6 +2497,117 @@ def test_session_completion_uses_server_history_by_default(monkeypatch): ] +@pytest.mark.p2 +def test_session_completion_preserves_zero_generation_params(monkeypatch): + module = _load_chat_api_module(monkeypatch) + + captured = {} + + async def _fake_async_chat(dia, _messages, stream=True, **_kwargs): + captured["llm_setting"] = dict(dia.llm_setting) + captured["kwargs"] = dict(_kwargs) + yield {"answer": "ok", "reference": {}} + + monkeypatch.setattr(module, "async_chat", _fake_async_chat) + monkeypatch.setattr(module, "structure_answer", lambda _conv, ans, _message_id, _session_id: ans) + monkeypatch.setattr( + module, + "get_request_json", + lambda: _AwaitableValue({ + "stream": False, + "messages": [{"role": "user", "content": "latest question"}], + "temperature": 0, + "top_p": 0, + "frequency_penalty": 0, + "presence_penalty": 0, + "max_tokens": 0, + }), + ) + + res = _run(inspect.unwrap(module.session_completion)()) + + assert res["code"] == 0, res + assert captured["llm_setting"] == { + "temperature": 0, + "top_p": 0, + "frequency_penalty": 0, + "presence_penalty": 0, + "max_tokens": 0, + } + assert not { + "temperature", + "top_p", + "frequency_penalty", + "presence_penalty", + "max_tokens", + }.intersection(captured["kwargs"]) + + +@pytest.mark.p2 +def test_session_completion_merges_generation_params_for_existing_chat(monkeypatch): + module = _load_chat_api_module(monkeypatch) + + base_llm_setting = {"temperature": 0.7, "top_p": 0.3, "custom": "keep"} + dia = SimpleNamespace( + id="chat-1", + tenant_id="tenant-1", + llm_id="model", + llm_setting=base_llm_setting, + prompt_config={"prologue": ""}, + kb_ids=[], + ) + conv = SimpleNamespace( + id="session-1", + dialog_id="chat-1", + message=[], + reference=[], + user_id="authenticated-user", + name="test", + ) + conv.to_dict = lambda: { + "id": conv.id, + "dialog_id": conv.dialog_id, + "message": conv.message, + "reference": conv.reference, + "user_id": conv.user_id, + "name": conv.name, + } + captured = {} + + async def _fake_async_chat(captured_dia, _messages, stream=True, **_kwargs): + captured["llm_setting"] = dict(captured_dia.llm_setting) + yield {"answer": "ok", "reference": {}} + + monkeypatch.setattr(module.DialogService, "get_by_id", lambda _dialog_id: (True, dia)) + monkeypatch.setattr(module.ConversationService, "get_by_id", lambda _id: (True, conv)) + monkeypatch.setattr(module.ConversationService, "update_by_id", lambda *_a, **_k: True, raising=False) + monkeypatch.setattr(module, "async_chat", _fake_async_chat) + monkeypatch.setattr(module, "structure_answer", lambda _conv, ans, _message_id, _session_id: ans) + monkeypatch.setattr( + module, + "get_request_json", + lambda: _AwaitableValue({ + "chat_id": "chat-1", + "session_id": "session-1", + "stream": False, + "messages": [{"role": "user", "content": "latest question"}], + "temperature": 0, + "presence_penalty": 0, + }), + ) + + res = _run(inspect.unwrap(module.session_completion)()) + + assert res["code"] == 0, res + assert captured["llm_setting"] == { + "temperature": 0, + "top_p": 0.3, + "custom": "keep", + "presence_penalty": 0, + } + assert base_llm_setting == {"temperature": 0.7, "top_p": 0.3, "custom": "keep"} + + @pytest.mark.p2 def test_session_completion_can_use_submitted_full_history(monkeypatch): """The UI opt-in flag should preserve the previous full-history request behavior."""