diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py index 95d86c4b93..6571f11701 100644 --- a/rag/llm/chat_model.py +++ b/rag/llm/chat_model.py @@ -61,6 +61,47 @@ ERROR_PREFIX = "**ERROR**" LENGTH_NOTIFICATION_CN = "······\n由于大模型的上下文窗口大小限制,回答已经被大模型截断。" LENGTH_NOTIFICATION_EN = "...\nThe answer is truncated by your chosen LLM due to its limitation on context length." +# Generation parameters that are safe to forward to the underlying completion +# call. `gen_conf` originates from a chat assistant's `llm_setting`, which can +# also carry RAGFlow-internal metadata (e.g. `model_type`). Anything outside +# this set is dropped so providers don't reject the request with errors like +# "Extra inputs are not permitted" / "Unknown parameter: 'model_type'" (#15427). +ALLOWED_GEN_CONF_KEYS = frozenset( + { + "temperature", + "max_completion_tokens", + "top_p", + "stream", + "stream_options", + "stop", + "n", + "presence_penalty", + "frequency_penalty", + "functions", + "function_call", + "logit_bias", + "user", + "response_format", + "seed", + "tools", + "tool_choice", + "logprobs", + "top_logprobs", + "extra_headers", + } +) + +# LiteLLM additionally understands reasoning-control parameters that the +# model-family policies may inject into `gen_conf` (e.g. `thinking` for +# Anthropic / Kimi reasoning models, `reasoning_effort` for OpenAI o-series). +LITELLM_ALLOWED_GEN_CONF_KEYS = ALLOWED_GEN_CONF_KEYS | frozenset( + { + "thinking", + "reasoning_effort", + "extra_body", + } +) + def _apply_model_family_policies( model_name: str, @@ -159,30 +200,7 @@ class Base(ABC): if "max_tokens" in gen_conf: del gen_conf["max_tokens"] - allowed_conf = { - "temperature", - "max_completion_tokens", - "top_p", - "stream", - "stream_options", - "stop", - "n", - "presence_penalty", - "frequency_penalty", - "functions", - "function_call", - "logit_bias", - "user", - "response_format", - "seed", - "tools", - "tool_choice", - "logprobs", - "top_logprobs", - "extra_headers", - } - - gen_conf = {k: v for k, v in gen_conf.items() if k in allowed_conf} + gen_conf = {k: v for k, v in gen_conf.items() if k in ALLOWED_GEN_CONF_KEYS} return gen_conf async def _async_chat_streamly(self, history, gen_conf, **kwargs): @@ -1395,6 +1413,7 @@ class LiteLLMBase(ABC): ) gen_conf.pop("max_tokens", None) + gen_conf = {k: v for k, v in gen_conf.items() if k in LITELLM_ALLOWED_GEN_CONF_KEYS} return gen_conf def _need_reasoning_content_back(self) -> bool: diff --git a/test/unit_test/rag/graphrag/conftest.py b/test/unit_test/rag/graphrag/conftest.py index 8aa4f43e81..a980592b5d 100644 --- a/test/unit_test/rag/graphrag/conftest.py +++ b/test/unit_test/rag/graphrag/conftest.py @@ -47,4 +47,12 @@ sys.modules["common.connection_utils"].timeout = lambda *a, **kw: (lambda fn: fn sys.modules["api.db.services.task_service"].has_canceled = lambda *_a, **_kw: False sys.modules["rag.graphrag.general.leiden"].run = lambda *_a, **_kw: {} sys.modules["rag.graphrag.general.leiden"].add_community_info2graph = lambda *_a, **_kw: None -sys.modules["rag.llm.chat_model"].Base = object +# Only stub ``Base`` when we actually mocked chat_model. This conftest mutates +# the global sys.modules at import time, and rag/graphrag/ is collected before +# rag/llm/. If an earlier test package already imported the real chat_model, +# unconditionally assigning ``Base = object`` clobbered the genuine class and +# leaked into the rag/llm unit tests that import it (AttributeError: no +# attribute '_clean_conf'). graphrag only uses ``Base`` as a type alias, so the +# real class works just as well when it is already loaded. +if isinstance(sys.modules["rag.llm.chat_model"], MagicMock): + sys.modules["rag.llm.chat_model"].Base = object diff --git a/test/unit_test/rag/llm/test_clean_conf_whitelist.py b/test/unit_test/rag/llm/test_clean_conf_whitelist.py new file mode 100644 index 0000000000..019a27be1a --- /dev/null +++ b/test/unit_test/rag/llm/test_clean_conf_whitelist.py @@ -0,0 +1,123 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Regression guard for issue #15427: LiteLLM-routed chats failed with +``model_type: Extra inputs are not permitted`` (Anthropic) / +``Unknown parameter: 'model_type'`` (OpenAI). + +A chat assistant's ``llm_setting`` is forwarded to the provider as +``gen_conf``. ``llm_setting`` can legitimately carry RAGFlow-internal +metadata such as ``model_type`` (the new chat REST APIs read it back out). +``Base._clean_conf`` already whitelisted the keys it forwards, so OpenAI- +compatible providers were unaffected, but ``LiteLLMBase._clean_conf`` only +dropped ``max_tokens`` and passed everything else straight through to +``litellm.acompletion`` — which forwarded ``model_type`` to the upstream +provider and got rejected. + +These tests pin the whitelisting behaviour for both backends so the leak +cannot reappear. +""" + +import pytest + +from rag.llm.chat_model import ( + ALLOWED_GEN_CONF_KEYS, + LITELLM_ALLOWED_GEN_CONF_KEYS, + Base, + LiteLLMBase, +) + + +class _ConcreteBase(Base): + """Concrete subclass so we can build an instance without touching the + real OpenAI client constructor.""" + + +class _ConcreteLiteLLM(LiteLLMBase): + """Concrete subclass for the same reason on the LiteLLM path.""" + + +def _make_base(model_name="gpt-4o"): + inst = _ConcreteBase.__new__(_ConcreteBase) + inst.model_name = model_name + return inst + + +def _make_litellm(model_name="gpt-4o", provider=""): + inst = _ConcreteLiteLLM.__new__(_ConcreteLiteLLM) + inst.model_name = model_name + inst.provider = provider + return inst + + +# --------------------------------------------------------------------------- # +# The actual bug: model_type must never reach the provider. +# --------------------------------------------------------------------------- # +def test_litellm_drops_model_type(): + cleaned = _make_litellm()._clean_conf({"temperature": 0.5, "model_type": "chat"}) + assert "model_type" not in cleaned + assert cleaned["temperature"] == 0.5 + + +def test_base_drops_model_type(): + cleaned = _make_base()._clean_conf({"temperature": 0.5, "model_type": "chat"}) + assert "model_type" not in cleaned + assert cleaned["temperature"] == 0.5 + + +@pytest.mark.parametrize("stray_key", ["model_type", "llm_id", "parameter", "icon", "foo"]) +def test_litellm_drops_arbitrary_internal_keys(stray_key): + cleaned = _make_litellm()._clean_conf({stray_key: "x", "top_p": 0.9}) + assert stray_key not in cleaned + assert cleaned["top_p"] == 0.9 + + +# --------------------------------------------------------------------------- # +# The fix must not over-filter: genuine generation params still pass through. +# --------------------------------------------------------------------------- # +def test_litellm_preserves_known_generation_params(): + gen_conf = { + "temperature": 0.7, + "top_p": 0.95, + "presence_penalty": 0.1, + "frequency_penalty": 0.2, + } + cleaned = _make_litellm()._clean_conf(dict(gen_conf)) + assert cleaned == gen_conf + + +def test_litellm_preserves_thinking_param(): + """``thinking`` is injected by the model-family policy for reasoning + models and must survive the whitelist (it is a valid LiteLLM param).""" + cleaned = _make_litellm()._clean_conf({"thinking": {"type": "enabled"}, "temperature": 1.0}) + assert cleaned["thinking"] == {"type": "enabled"} + + +def test_max_tokens_is_dropped_on_both_backends(): + assert "max_tokens" not in _make_litellm()._clean_conf({"max_tokens": 100, "temperature": 0.3}) + assert "max_tokens" not in _make_base()._clean_conf({"max_tokens": 100, "temperature": 0.3}) + + +# --------------------------------------------------------------------------- # +# Whitelist invariants. +# --------------------------------------------------------------------------- # +def test_litellm_whitelist_is_superset_of_base(): + assert ALLOWED_GEN_CONF_KEYS <= LITELLM_ALLOWED_GEN_CONF_KEYS + + +def test_model_type_not_whitelisted_anywhere(): + assert "model_type" not in ALLOWED_GEN_CONF_KEYS + assert "model_type" not in LITELLM_ALLOWED_GEN_CONF_KEYS