From 1046042e01979a83fc2dc807422f674da093faaa Mon Sep 17 00:00:00 2001 From: Ricardo-M-L <69202550+Ricardo-M-L@users.noreply.github.com> Date: Sat, 9 May 2026 13:11:44 +0800 Subject: [PATCH] fix(llm): replace mutable default `gen_conf={}` with None + defensive copy (#14566) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What 19 methods across `rag/llm/chat_model.py` and `rag/llm/cv_model.py` declare `gen_conf={}` (or `gen_conf: dict = {}`) as a parameter default and then mutate `gen_conf` in place — typically `del gen_conf["max_tokens"]`, `gen_conf["penalty_score"] = ...`, or `gen_conf.pop(...)` as part of provider-specific normalization. ### The two bugs in this pattern **1. Mutable default argument (Python footgun).** Python evaluates default values **once** at function-definition time, so the single `{}` dict is *shared* across every caller that doesn't pass `gen_conf`. The first such call's mutations leak into the default seen by every subsequent call. ```python # Before def chat_streamly(self, system, history, gen_conf={}, **kwargs): if "max_tokens" in gen_conf: del gen_conf["max_tokens"] # mutates the SHARED default dict ... ``` After call N with `max_tokens` set, call N+1 that omits `gen_conf` no longer sees `max_tokens` — even though the caller never touched it. **2. Caller-dict pollution.** When the caller *does* pass a `gen_conf` dict, the same in-place mutations modify the caller's dict. A reused `gen_conf` (very common for chat-loop callers that build the config once and pass it on every turn) silently loses `max_tokens`, `presence_penalty`, etc. after the first round. ### The fix In every affected method: - Change `gen_conf={}` (or `gen_conf: dict = {}`) → `gen_conf=None`. - Add `gen_conf = dict(gen_conf or {})` as the first statement of the body so all subsequent mutations operate on a fresh local copy. ```python # After def chat_streamly(self, system, history, gen_conf=None, **kwargs): gen_conf = dict(gen_conf or {}) if "max_tokens" in gen_conf: del gen_conf["max_tokens"] # local copy — safe ... ``` This is byte-for-byte identical provider-side behavior for callers that already pass a fresh `gen_conf` per call. The new `dict(...)` copy is O(small constant) per call. ### Files changed - `rag/llm/chat_model.py` — 17 methods - `rag/llm/cv_model.py` — 2 methods ### Tests Adds `test/unit_test/rag/llm/test_gen_conf_no_mutable_default.py` — an `ast`-based regression guard that walks both modules and asserts no parameter named `gen_conf` ever has a mutable literal (`{}` or `[]`) as its default. The test caught **five additional `gen_conf: dict = {}` sites** that an initial `gen_conf={}` text grep had missed (annotated parameters with whitespace), and would fail again if the pattern is ever reintroduced. ``` $ pytest test/unit_test/rag/llm/test_gen_conf_no_mutable_default.py -v ============================== 3 passed in 0.04s =============================== ``` `ruff check` passes on all touched files. ### Notes - This PR is intentionally focused on **just** the `gen_conf` default + copy fix. There's a related (but separate) `history.insert(0, ...)` pattern in the same files that mutates the caller's history list in 12 places — left for a follow-up so this PR stays mechanical and easy to review. ### Latest revision (`700bb54a7`) — addresses CodeRabbit review - Type annotation: `gen_conf: dict = None` → `gen_conf: dict | None = None` (5 occurrences in `chat_model.py`). The old annotation was a static-checker mismatch since `None` isn't a `dict`. - Regression test: the AST check accessed `default.keys` directly. `ast.List` has no `.keys` attribute — a future `gen_conf=[]` would crash with `AttributeError` instead of being caught. Use `getattr` for both `.keys` (Dict) and `.elts` (List). Manually verified the updated check correctly catches both `gen_conf={}` and `gen_conf=[]` while ignoring `gen_conf=None` and non-empty literals. --------- Co-authored-by: Ricardo --- rag/llm/chat_model.py | 51 ++++++---- rag/llm/cv_model.py | 6 +- .../llm/test_gen_conf_no_mutable_default.py | 94 +++++++++++++++++++ 3 files changed, 132 insertions(+), 19 deletions(-) create mode 100644 test/unit_test/rag/llm/test_gen_conf_no_mutable_default.py diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py index 717c43ad93..45b81a6cc7 100644 --- a/rag/llm/chat_model.py +++ b/rag/llm/chat_model.py @@ -221,7 +221,8 @@ class Base(ABC): ans += LENGTH_NOTIFICATION_EN yield ans, tol - async def async_chat_streamly(self, system, history, gen_conf: dict = {}, **kwargs): + async def async_chat_streamly(self, system, history, gen_conf: dict | None = None, **kwargs): + gen_conf = dict(gen_conf or {}) if system and history and history[0].get("role") != "system": history.insert(0, {"role": "system", "content": system}) gen_conf = self._clean_conf(gen_conf) @@ -356,7 +357,8 @@ class Base(ABC): self.toolcall_session = toolcall_session self.tools = tools - async def async_chat_with_tools(self, system: str, history: list, gen_conf: dict = {}): + async def async_chat_with_tools(self, system: str, history: list, gen_conf: dict | None = None): + gen_conf = dict(gen_conf or {}) gen_conf = self._clean_conf(gen_conf) if system and history and history[0].get("role") != "system": history.insert(0, {"role": "system", "content": system}) @@ -417,7 +419,8 @@ class Base(ABC): assert False, "Shouldn't be here." - async def async_chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict = {}): + async def async_chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict | None = None): + gen_conf = dict(gen_conf or {}) gen_conf = self._clean_conf(gen_conf) tools = self.tools if system and history and history[0].get("role") != "system": @@ -576,7 +579,8 @@ class Base(ABC): ans = self._length_stop(ans) return ans, total_token_count_from_response(response) - async def async_chat(self, system, history, gen_conf={}, **kwargs): + async def async_chat(self, system, history, gen_conf=None, **kwargs): + gen_conf = dict(gen_conf or {}) if system and history and history[0].get("role") != "system": history.insert(0, {"role": "system", "content": system}) gen_conf = self._clean_conf(gen_conf) @@ -642,7 +646,8 @@ class BaiChuanChat(Base): "top_p": gen_conf.get("top_p", 0.85), } - def _chat(self, history, gen_conf={}, **kwargs): + def _chat(self, history, gen_conf=None, **kwargs): + gen_conf = dict(gen_conf or {}) response = self.client.chat.completions.create( model=self.model_name, messages=history, @@ -657,7 +662,8 @@ class BaiChuanChat(Base): ans += LENGTH_NOTIFICATION_EN return ans, total_token_count_from_response(response) - def chat_streamly(self, system, history, gen_conf={}, **kwargs): + def chat_streamly(self, system, history, gen_conf=None, **kwargs): + gen_conf = dict(gen_conf or {}) if system and history and history[0].get("role") != "system": history.insert(0, {"role": "system", "content": system}) if "max_tokens" in gen_conf: @@ -740,7 +746,8 @@ class LocalLLM(Base): yield answer + "\n**ERROR**: " + str(e) yield num_tokens_from_string(answer) - def chat(self, system, history, gen_conf={}, **kwargs): + def chat(self, system, history, gen_conf=None, **kwargs): + gen_conf = dict(gen_conf or {}) if "max_tokens" in gen_conf: del gen_conf["max_tokens"] prompt = self._prepare_prompt(system, history, gen_conf) @@ -749,7 +756,8 @@ class LocalLLM(Base): total_tokens = next(chat_gen) return ans, total_tokens - def chat_streamly(self, system, history, gen_conf={}, **kwargs): + def chat_streamly(self, system, history, gen_conf=None, **kwargs): + gen_conf = dict(gen_conf or {}) if "max_tokens" in gen_conf: del gen_conf["max_tokens"] prompt = self._prepare_prompt(system, history, gen_conf) @@ -788,7 +796,8 @@ class MistralChat(Base): del gen_conf[k] return gen_conf - def _chat(self, history, gen_conf={}, **kwargs): + def _chat(self, history, gen_conf=None, **kwargs): + gen_conf = dict(gen_conf or {}) gen_conf = self._clean_conf(gen_conf) response = self.client.chat(model=self.model_name, messages=history, **gen_conf) ans = response.choices[0].message.content @@ -799,7 +808,8 @@ class MistralChat(Base): ans += LENGTH_NOTIFICATION_EN return ans, total_token_count_from_response(response) - def chat_streamly(self, system, history, gen_conf={}, **kwargs): + def chat_streamly(self, system, history, gen_conf=None, **kwargs): + gen_conf = dict(gen_conf or {}) if system and history and history[0].get("role") != "system": history.insert(0, {"role": "system", "content": system}) gen_conf = self._clean_conf(gen_conf) @@ -867,7 +877,8 @@ class ReplicateChat(Base): self.model_name = model_name self.client = Client(api_token=key) - def _chat(self, history, gen_conf={}, **kwargs): + def _chat(self, history, gen_conf=None, **kwargs): + gen_conf = dict(gen_conf or {}) system = history[0]["content"] if history and history[0]["role"] == "system" else "" prompt = "\n".join([item["role"] + ":" + item["content"] for item in history[-5:] if item["role"] != "system"]) response = self.client.run( @@ -877,7 +888,8 @@ class ReplicateChat(Base): ans = "".join(response) return ans, num_tokens_from_string(ans) - def chat_streamly(self, system, history, gen_conf={}, **kwargs): + def chat_streamly(self, system, history, gen_conf=None, **kwargs): + gen_conf = dict(gen_conf or {}) if "max_tokens" in gen_conf: del gen_conf["max_tokens"] prompt = "\n".join([item["role"] + ":" + item["content"] for item in history[-5:]]) @@ -946,7 +958,8 @@ class BaiduYiyanChat(Base): ans = response["result"] return ans, total_token_count_from_response(response) - def chat_streamly(self, system, history, gen_conf={}, **kwargs): + def chat_streamly(self, system, history, gen_conf=None, **kwargs): + gen_conf = dict(gen_conf or {}) gen_conf["penalty_score"] = ((gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty", 0)) / 2) + 1 if "max_tokens" in gen_conf: del gen_conf["max_tokens"] @@ -1020,7 +1033,8 @@ class GoogleChat(Base): del gen_conf[k] return gen_conf - def _chat(self, history, gen_conf={}, **kwargs): + def _chat(self, history, gen_conf=None, **kwargs): + gen_conf = dict(gen_conf or {}) system = history[0]["content"] if history and history[0]["role"] == "system" else "" if "claude" in self.model_name: @@ -1098,7 +1112,8 @@ class GoogleChat(Base): return ans, total_tokens - def chat_streamly(self, system, history, gen_conf={}, **kwargs): + def chat_streamly(self, system, history, gen_conf=None, **kwargs): + gen_conf = dict(gen_conf or {}) if "claude" in self.model_name: if "max_tokens" in gen_conf: del gen_conf["max_tokens"] @@ -1545,7 +1560,8 @@ class LiteLLMBase(ABC): self.toolcall_session = toolcall_session self.tools = tools - async def async_chat_with_tools(self, system: str, history: list, gen_conf: dict = {}): + async def async_chat_with_tools(self, system: str, history: list, gen_conf: dict | None = None): + gen_conf = dict(gen_conf or {}) gen_conf = self._clean_conf(gen_conf) if system and history and history[0].get("role") != "system": history.insert(0, {"role": "system", "content": system}) @@ -1622,7 +1638,8 @@ class LiteLLMBase(ABC): assert False, "Shouldn't be here." - async def async_chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict = {}): + async def async_chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict | None = None): + gen_conf = dict(gen_conf or {}) gen_conf = self._clean_conf(gen_conf) tools = self.tools if system and history and history[0].get("role") != "system": diff --git a/rag/llm/cv_model.py b/rag/llm/cv_model.py index 3d23c0a32e..6c3e6e7a1e 100644 --- a/rag/llm/cv_model.py +++ b/rag/llm/cv_model.py @@ -437,7 +437,8 @@ class Zhipu4V(GptV4): del gen_conf["frequency_penalty"] return gen_conf - def _request(self, msg, stream, gen_conf={}): + def _request(self, msg, stream, gen_conf=None): + gen_conf = dict(gen_conf or {}) response = requests.post( self.base_url, json={"model": self.model_name, "messages": msg, "stream": stream, **gen_conf}, @@ -1035,7 +1036,8 @@ class NvidiaCV(Base): total_token_count_from_response(response), ) - def _request(self, msg, gen_conf={}): + def _request(self, msg, gen_conf=None): + gen_conf = dict(gen_conf or {}) response = requests.post( url=self.base_url, headers={ diff --git a/test/unit_test/rag/llm/test_gen_conf_no_mutable_default.py b/test/unit_test/rag/llm/test_gen_conf_no_mutable_default.py new file mode 100644 index 0000000000..075d4a65f4 --- /dev/null +++ b/test/unit_test/rag/llm/test_gen_conf_no_mutable_default.py @@ -0,0 +1,94 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Regression guard for mutable default `gen_conf={}` in the LLM provider +integration layer (`rag/llm/chat_model.py`, `rag/llm/cv_model.py`). + +Many provider methods used to declare ``def chat_streamly(..., gen_conf={}, ...)`` +and then mutate ``gen_conf`` in place (``del gen_conf["max_tokens"]``, +``gen_conf["penalty_score"] = ...``). Because Python evaluates default +argument values **once** at function-definition time, that single shared +dict accumulated mutations across calls — every later caller that omitted +``gen_conf`` saw the polluted dict from the previous call. + +The fix is to default to ``None`` and copy at the call site +(``gen_conf = dict(gen_conf or {})``). This test parses both modules with +the ``ast`` module and asserts no parameter named ``gen_conf`` ever has +a mutable literal as its default. +""" +import ast +from pathlib import Path +from typing import Union + +import pytest + +REPO_ROOT = Path(__file__).resolve().parents[4] +TARGET_FILES = [ + REPO_ROOT / "rag" / "llm" / "chat_model.py", + REPO_ROOT / "rag" / "llm" / "cv_model.py", +] + + +def _iter_param_defaults(func: Union[ast.FunctionDef, ast.AsyncFunctionDef]): + """Yield (param_name, default_node) for every parameter with a + non-empty default — covers positional, keyword-only, and the new + positional-only syntax.""" + args = func.args + pos_args = args.args + pos_defaults = args.defaults + # positional defaults are right-aligned with args + for arg, default in zip(pos_args[-len(pos_defaults):], pos_defaults): + yield arg.arg, default + for arg, default in zip(args.kwonlyargs, args.kw_defaults): + if default is not None: + yield arg.arg, default + + +def _find_mutable_gen_conf_defaults(path: Path): + tree = ast.parse(path.read_text(encoding="utf-8")) + bad = [] + for node in ast.walk(tree): + if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): + continue + for name, default in _iter_param_defaults(node): + if name != "gen_conf": + continue + # An empty dict literal `{}` is the original bug. A list literal + # `[]` would be the same class of mistake. Anything else is fine. + # ast.Dict exposes `.keys`; ast.List exposes `.elts`. Use getattr + # for both so `gen_conf=[]` doesn't crash on a missing `.keys`. + if isinstance(default, (ast.Dict, ast.List)) and not getattr(default, "keys", None) and not getattr(default, "elts", None): + bad.append((node.name, default.lineno)) + return bad + + +@pytest.mark.parametrize("path", TARGET_FILES, ids=lambda p: p.name) +def test_no_mutable_default_for_gen_conf(path: Path): + """No function in chat_model.py / cv_model.py should declare + ``gen_conf={}`` (or ``gen_conf=[]``) as a default value.""" + bad = _find_mutable_gen_conf_defaults(path) + assert not bad, ( + f"{path.name} has functions declaring `gen_conf` with a mutable " + f"default: {bad}. Use `gen_conf=None` and copy with " + f"`gen_conf = dict(gen_conf or {{}})` at the top of the function." + ) + + +def test_target_files_exist(): + """Sanity check — if the LLM modules move, this regression guard + must follow them.""" + for path in TARGET_FILES: + assert path.is_file(), f"Expected target file at {path}"