diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py index 717c43ad93..45b81a6cc7 100644 --- a/rag/llm/chat_model.py +++ b/rag/llm/chat_model.py @@ -221,7 +221,8 @@ class Base(ABC): ans += LENGTH_NOTIFICATION_EN yield ans, tol - async def async_chat_streamly(self, system, history, gen_conf: dict = {}, **kwargs): + async def async_chat_streamly(self, system, history, gen_conf: dict | None = None, **kwargs): + gen_conf = dict(gen_conf or {}) if system and history and history[0].get("role") != "system": history.insert(0, {"role": "system", "content": system}) gen_conf = self._clean_conf(gen_conf) @@ -356,7 +357,8 @@ class Base(ABC): self.toolcall_session = toolcall_session self.tools = tools - async def async_chat_with_tools(self, system: str, history: list, gen_conf: dict = {}): + async def async_chat_with_tools(self, system: str, history: list, gen_conf: dict | None = None): + gen_conf = dict(gen_conf or {}) gen_conf = self._clean_conf(gen_conf) if system and history and history[0].get("role") != "system": history.insert(0, {"role": "system", "content": system}) @@ -417,7 +419,8 @@ class Base(ABC): assert False, "Shouldn't be here." 
- async def async_chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict = {}): + async def async_chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict | None = None): + gen_conf = dict(gen_conf or {}) gen_conf = self._clean_conf(gen_conf) tools = self.tools if system and history and history[0].get("role") != "system": @@ -576,7 +579,8 @@ class Base(ABC): ans = self._length_stop(ans) return ans, total_token_count_from_response(response) - async def async_chat(self, system, history, gen_conf={}, **kwargs): + async def async_chat(self, system, history, gen_conf=None, **kwargs): + gen_conf = dict(gen_conf or {}) if system and history and history[0].get("role") != "system": history.insert(0, {"role": "system", "content": system}) gen_conf = self._clean_conf(gen_conf) @@ -642,7 +646,8 @@ class BaiChuanChat(Base): "top_p": gen_conf.get("top_p", 0.85), } - def _chat(self, history, gen_conf={}, **kwargs): + def _chat(self, history, gen_conf=None, **kwargs): + gen_conf = dict(gen_conf or {}) response = self.client.chat.completions.create( model=self.model_name, messages=history, @@ -657,7 +662,8 @@ class BaiChuanChat(Base): ans += LENGTH_NOTIFICATION_EN return ans, total_token_count_from_response(response) - def chat_streamly(self, system, history, gen_conf={}, **kwargs): + def chat_streamly(self, system, history, gen_conf=None, **kwargs): + gen_conf = dict(gen_conf or {}) if system and history and history[0].get("role") != "system": history.insert(0, {"role": "system", "content": system}) if "max_tokens" in gen_conf: @@ -740,7 +746,8 @@ class LocalLLM(Base): yield answer + "\n**ERROR**: " + str(e) yield num_tokens_from_string(answer) - def chat(self, system, history, gen_conf={}, **kwargs): + def chat(self, system, history, gen_conf=None, **kwargs): + gen_conf = dict(gen_conf or {}) if "max_tokens" in gen_conf: del gen_conf["max_tokens"] prompt = self._prepare_prompt(system, history, gen_conf) @@ -749,7 +756,8 @@ class LocalLLM(Base): 
total_tokens = next(chat_gen) return ans, total_tokens - def chat_streamly(self, system, history, gen_conf={}, **kwargs): + def chat_streamly(self, system, history, gen_conf=None, **kwargs): + gen_conf = dict(gen_conf or {}) if "max_tokens" in gen_conf: del gen_conf["max_tokens"] prompt = self._prepare_prompt(system, history, gen_conf) @@ -788,7 +796,8 @@ class MistralChat(Base): del gen_conf[k] return gen_conf - def _chat(self, history, gen_conf={}, **kwargs): + def _chat(self, history, gen_conf=None, **kwargs): + gen_conf = dict(gen_conf or {}) gen_conf = self._clean_conf(gen_conf) response = self.client.chat(model=self.model_name, messages=history, **gen_conf) ans = response.choices[0].message.content @@ -799,7 +808,8 @@ class MistralChat(Base): ans += LENGTH_NOTIFICATION_EN return ans, total_token_count_from_response(response) - def chat_streamly(self, system, history, gen_conf={}, **kwargs): + def chat_streamly(self, system, history, gen_conf=None, **kwargs): + gen_conf = dict(gen_conf or {}) if system and history and history[0].get("role") != "system": history.insert(0, {"role": "system", "content": system}) gen_conf = self._clean_conf(gen_conf) @@ -867,7 +877,8 @@ class ReplicateChat(Base): self.model_name = model_name self.client = Client(api_token=key) - def _chat(self, history, gen_conf={}, **kwargs): + def _chat(self, history, gen_conf=None, **kwargs): + gen_conf = dict(gen_conf or {}) system = history[0]["content"] if history and history[0]["role"] == "system" else "" prompt = "\n".join([item["role"] + ":" + item["content"] for item in history[-5:] if item["role"] != "system"]) response = self.client.run( @@ -877,7 +888,8 @@ class ReplicateChat(Base): ans = "".join(response) return ans, num_tokens_from_string(ans) - def chat_streamly(self, system, history, gen_conf={}, **kwargs): + def chat_streamly(self, system, history, gen_conf=None, **kwargs): + gen_conf = dict(gen_conf or {}) if "max_tokens" in gen_conf: del gen_conf["max_tokens"] prompt = 
"\n".join([item["role"] + ":" + item["content"] for item in history[-5:]]) @@ -946,7 +958,8 @@ class BaiduYiyanChat(Base): ans = response["result"] return ans, total_token_count_from_response(response) - def chat_streamly(self, system, history, gen_conf={}, **kwargs): + def chat_streamly(self, system, history, gen_conf=None, **kwargs): + gen_conf = dict(gen_conf or {}) gen_conf["penalty_score"] = ((gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty", 0)) / 2) + 1 if "max_tokens" in gen_conf: del gen_conf["max_tokens"] @@ -1020,7 +1033,8 @@ class GoogleChat(Base): del gen_conf[k] return gen_conf - def _chat(self, history, gen_conf={}, **kwargs): + def _chat(self, history, gen_conf=None, **kwargs): + gen_conf = dict(gen_conf or {}) system = history[0]["content"] if history and history[0]["role"] == "system" else "" if "claude" in self.model_name: @@ -1098,7 +1112,8 @@ class GoogleChat(Base): return ans, total_tokens - def chat_streamly(self, system, history, gen_conf={}, **kwargs): + def chat_streamly(self, system, history, gen_conf=None, **kwargs): + gen_conf = dict(gen_conf or {}) if "claude" in self.model_name: if "max_tokens" in gen_conf: del gen_conf["max_tokens"] @@ -1545,7 +1560,8 @@ class LiteLLMBase(ABC): self.toolcall_session = toolcall_session self.tools = tools - async def async_chat_with_tools(self, system: str, history: list, gen_conf: dict = {}): + async def async_chat_with_tools(self, system: str, history: list, gen_conf: dict | None = None): + gen_conf = dict(gen_conf or {}) gen_conf = self._clean_conf(gen_conf) if system and history and history[0].get("role") != "system": history.insert(0, {"role": "system", "content": system}) @@ -1622,7 +1638,8 @@ class LiteLLMBase(ABC): assert False, "Shouldn't be here." 
- async def async_chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict = {}): + async def async_chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict | None = None): + gen_conf = dict(gen_conf or {}) gen_conf = self._clean_conf(gen_conf) tools = self.tools if system and history and history[0].get("role") != "system": diff --git a/rag/llm/cv_model.py b/rag/llm/cv_model.py index 3d23c0a32e..6c3e6e7a1e 100644 --- a/rag/llm/cv_model.py +++ b/rag/llm/cv_model.py @@ -437,7 +437,8 @@ class Zhipu4V(GptV4): del gen_conf["frequency_penalty"] return gen_conf - def _request(self, msg, stream, gen_conf={}): + def _request(self, msg, stream, gen_conf=None): + gen_conf = dict(gen_conf or {}) response = requests.post( self.base_url, json={"model": self.model_name, "messages": msg, "stream": stream, **gen_conf}, @@ -1035,7 +1036,8 @@ class NvidiaCV(Base): total_token_count_from_response(response), ) - def _request(self, msg, gen_conf={}): + def _request(self, msg, gen_conf=None): + gen_conf = dict(gen_conf or {}) response = requests.post( url=self.base_url, headers={ diff --git a/test/unit_test/rag/llm/test_gen_conf_no_mutable_default.py b/test/unit_test/rag/llm/test_gen_conf_no_mutable_default.py new file mode 100644 index 0000000000..075d4a65f4 --- /dev/null +++ b/test/unit_test/rag/llm/test_gen_conf_no_mutable_default.py @@ -0,0 +1,94 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""
+Regression guard for mutable default `gen_conf={}` in the LLM provider
+integration layer (`rag/llm/chat_model.py`, `rag/llm/cv_model.py`).
+
+Many provider methods used to declare ``def chat_streamly(..., gen_conf={}, ...)``
+and then mutate ``gen_conf`` in place (``del gen_conf["max_tokens"]``,
+``gen_conf["penalty_score"] = ...``). Because Python evaluates default
+argument values **once** at function-definition time, that single shared
+dict accumulated mutations across calls — every later caller that omitted
+``gen_conf`` saw the polluted dict from the previous call.
+
+The fix is to default to ``None`` and copy at the top of the function body
+(``gen_conf = dict(gen_conf or {})``). This test parses both modules with
+the ``ast`` module and asserts no parameter named ``gen_conf`` ever has
+a mutable default: a dict/list/set literal (empty or not), a comprehension,
+or a call to a builtin mutable constructor such as ``dict()``.
+"""
+import ast
+from pathlib import Path
+from typing import Iterator, List, Tuple, Union
+
+import pytest
+
+REPO_ROOT = Path(__file__).resolve().parents[4]
+TARGET_FILES = [
+    REPO_ROOT / "rag" / "llm" / "chat_model.py",
+    REPO_ROOT / "rag" / "llm" / "cv_model.py",
+]
+
+# Expression nodes that evaluate to a mutable container. Used as a default,
+# the container is built once at definition time and shared by every call.
+_MUTABLE_DEFAULT_NODES = (
+    ast.Dict,
+    ast.List,
+    ast.Set,
+    ast.DictComp,
+    ast.ListComp,
+    ast.SetComp,
+)
+# Builtin constructors whose result is mutable (``gen_conf=dict()`` is the
+# same mistake as ``gen_conf={}``).
+_MUTABLE_FACTORY_NAMES = frozenset({"dict", "list", "set", "bytearray"})
+
+
+def _iter_param_defaults(
+    func: Union[ast.FunctionDef, ast.AsyncFunctionDef],
+) -> Iterator[Tuple[str, ast.expr]]:
+    """Yield ``(param_name, default_node)`` for every parameter of *func*
+    that declares a default.
+
+    Covers positional-only, regular positional, and keyword-only
+    parameters. ``*args`` / ``**kwargs`` never carry defaults and are
+    therefore not reported.
+    """
+    args = func.args
+    # `args.defaults` is shared by positional-only and regular parameters,
+    # so both lists must be concatenated before aligning; otherwise a
+    # signature such as `def f(a=1, /, gen_conf={})` pairs `gen_conf`
+    # with the default that belongs to `a`.
+    positional = [*args.posonlyargs, *args.args]
+    defaults = args.defaults
+    if defaults:
+        # Defaults are right-aligned with the parameters. The guard keeps
+        # the slice from degenerating into `[-0:]` when there are none.
+        for arg, default in zip(positional[-len(defaults):], defaults):
+            yield arg.arg, default
+    # `kw_defaults` is index-aligned with `kwonlyargs`; `None` marks a
+    # required keyword-only parameter.
+    for arg, default in zip(args.kwonlyargs, args.kw_defaults):
+        if default is not None:
+            yield arg.arg, default
+
+
+def _is_mutable_default(default: ast.expr) -> bool:
+    """Return True when *default* is syntactically a mutable container:
+    a literal, a comprehension, or a call to a builtin mutable constructor."""
+    if isinstance(default, _MUTABLE_DEFAULT_NODES):
+        return True
+    return (
+        isinstance(default, ast.Call)
+        and isinstance(default.func, ast.Name)
+        and default.func.id in _MUTABLE_FACTORY_NAMES
+    )
+
+
+def _find_mutable_gen_conf_defaults(path: Path) -> List[Tuple[str, int]]:
+    """Parse *path* and return ``(function_name, default_lineno)`` for every
+    (async) function whose ``gen_conf`` parameter has a mutable default.
+
+    Non-empty literals are flagged as well: ``gen_conf={"temperature": 0}``
+    is shared across calls exactly like ``gen_conf={}``.
+    """
+    tree = ast.parse(path.read_text(encoding="utf-8"), filename=str(path))
+    return [
+        (node.name, default.lineno)
+        for node in ast.walk(tree)
+        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
+        for name, default in _iter_param_defaults(node)
+        if name == "gen_conf" and _is_mutable_default(default)
+    ]
+
+
+@pytest.mark.parametrize("path", TARGET_FILES, ids=lambda p: p.name)
+def test_no_mutable_default_for_gen_conf(path: Path):
+    """No function in chat_model.py / cv_model.py should declare
+    ``gen_conf`` with a mutable default (``{}``, ``[]``, ``dict()``, ...)."""
+    bad = _find_mutable_gen_conf_defaults(path)
+    assert not bad, (
+        f"{path.name} has functions declaring `gen_conf` with a mutable "
+        f"default: {bad}. Use `gen_conf=None` and copy with "
+        f"`gen_conf = dict(gen_conf or {{}})` at the top of the function."
+    )
+
+
+def test_target_files_exist():
+    """Sanity check — if the LLM modules move, this regression guard
+    must follow them."""
+    for path in TARGET_FILES:
+        assert path.is_file(), f"Expected target file at {path}"