fix(llm): replace mutable default gen_conf={} with None + defensive copy (#14566)

### What

19 methods across `rag/llm/chat_model.py` and `rag/llm/cv_model.py`
declare `gen_conf={}` (or `gen_conf: dict = {}`) as a parameter default
and then mutate `gen_conf` in place — typically `del
gen_conf["max_tokens"]`, `gen_conf["penalty_score"] = ...`, or
`gen_conf.pop(...)` as part of provider-specific normalization.

### The two bugs in this pattern

**1. Mutable default argument (Python footgun).** Python evaluates
default values **once** at function-definition time, so the single `{}`
dict is *shared* across every caller that doesn't pass `gen_conf`. The
first such call's mutations leak into the default seen by every
subsequent call.

```python
# Before
def chat_streamly(self, system, history, gen_conf={}, **kwargs):
    if "max_tokens" in gen_conf:
        del gen_conf["max_tokens"]   # mutates the SHARED default dict
    ...
```

After call N with `max_tokens` set, call N+1 that omits `gen_conf` no
longer sees `max_tokens` — even though the caller never touched it.

**2. Caller-dict pollution.** When the caller *does* pass a `gen_conf`
dict, the same in-place mutations modify the caller's dict. A reused
`gen_conf` (very common for chat-loop callers that build the config once
and pass it on every turn) silently loses `max_tokens`,
`presence_penalty`, etc. after the first round.

### The fix

In every affected method:

- Change `gen_conf={}` (or `gen_conf: dict = {}`) → `gen_conf=None`.
- Add `gen_conf = dict(gen_conf or {})` as the first statement of the
body so all subsequent mutations operate on a fresh local copy.

```python
# After
def chat_streamly(self, system, history, gen_conf=None, **kwargs):
    gen_conf = dict(gen_conf or {})
    if "max_tokens" in gen_conf:
        del gen_conf["max_tokens"]   # local copy — safe
    ...
```

This is byte-for-byte identical provider-side behavior for callers that
already pass a fresh `gen_conf` per call. The new `dict(...)` copy is
O(small constant) per call.

### Files changed

- `rag/llm/chat_model.py` — 17 methods
- `rag/llm/cv_model.py` — 2 methods

### Tests

Adds `test/unit_test/rag/llm/test_gen_conf_no_mutable_default.py` — an
`ast`-based regression guard that walks both modules and asserts no
parameter named `gen_conf` ever has a mutable literal (`{}` or `[]`) as
its default. The test caught **five additional `gen_conf: dict = {}`
sites** that an initial `gen_conf={}` text grep had missed (annotated
parameters with whitespace), and would fail again if the pattern is ever
reintroduced.

```
$ pytest test/unit_test/rag/llm/test_gen_conf_no_mutable_default.py -v
============================== 3 passed in 0.04s ===============================
```

`ruff check` passes on all touched files.

### Notes

- This PR is intentionally focused on **just** the `gen_conf` default +
copy fix. There's a related (but separate) `history.insert(0, ...)`
pattern in the same files that mutates the caller's history list in 12
places — left for a follow-up so this PR stays mechanical and easy to
review.

### Latest revision (`700bb54a7`) — addresses CodeRabbit review

- Type annotation: `gen_conf: dict = None` → `gen_conf: dict | None =
None` (5 occurrences in `chat_model.py`). The old annotation was a
static-checker mismatch since `None` isn't a `dict`.
- Regression test: the AST check accessed `default.keys` directly.
`ast.List` has no `.keys` attribute — a future `gen_conf=[]` would crash
with `AttributeError` instead of being caught. Use `getattr` for both
`.keys` (Dict) and `.elts` (List). Manually verified the updated check
correctly catches both `gen_conf={}` and `gen_conf=[]` while ignoring
`gen_conf=None` and non-empty literals.

---------

Co-authored-by: Ricardo <ricardo@example.com>
This commit is contained in:
Ricardo-M-L
2026-05-09 13:11:44 +08:00
committed by GitHub
parent 42504fa18c
commit 1046042e01
3 changed files with 132 additions and 19 deletions

View File

@@ -221,7 +221,8 @@ class Base(ABC):
ans += LENGTH_NOTIFICATION_EN
yield ans, tol
async def async_chat_streamly(self, system, history, gen_conf: dict = {}, **kwargs):
async def async_chat_streamly(self, system, history, gen_conf: dict | None = None, **kwargs):
gen_conf = dict(gen_conf or {})
if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system})
gen_conf = self._clean_conf(gen_conf)
@@ -356,7 +357,8 @@ class Base(ABC):
self.toolcall_session = toolcall_session
self.tools = tools
async def async_chat_with_tools(self, system: str, history: list, gen_conf: dict = {}):
async def async_chat_with_tools(self, system: str, history: list, gen_conf: dict | None = None):
gen_conf = dict(gen_conf or {})
gen_conf = self._clean_conf(gen_conf)
if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system})
@@ -417,7 +419,8 @@ class Base(ABC):
assert False, "Shouldn't be here."
async def async_chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict = {}):
async def async_chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict | None = None):
gen_conf = dict(gen_conf or {})
gen_conf = self._clean_conf(gen_conf)
tools = self.tools
if system and history and history[0].get("role") != "system":
@@ -576,7 +579,8 @@ class Base(ABC):
ans = self._length_stop(ans)
return ans, total_token_count_from_response(response)
async def async_chat(self, system, history, gen_conf={}, **kwargs):
async def async_chat(self, system, history, gen_conf=None, **kwargs):
gen_conf = dict(gen_conf or {})
if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system})
gen_conf = self._clean_conf(gen_conf)
@@ -642,7 +646,8 @@ class BaiChuanChat(Base):
"top_p": gen_conf.get("top_p", 0.85),
}
def _chat(self, history, gen_conf={}, **kwargs):
def _chat(self, history, gen_conf=None, **kwargs):
gen_conf = dict(gen_conf or {})
response = self.client.chat.completions.create(
model=self.model_name,
messages=history,
@@ -657,7 +662,8 @@ class BaiChuanChat(Base):
ans += LENGTH_NOTIFICATION_EN
return ans, total_token_count_from_response(response)
def chat_streamly(self, system, history, gen_conf={}, **kwargs):
def chat_streamly(self, system, history, gen_conf=None, **kwargs):
gen_conf = dict(gen_conf or {})
if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system})
if "max_tokens" in gen_conf:
@@ -740,7 +746,8 @@ class LocalLLM(Base):
yield answer + "\n**ERROR**: " + str(e)
yield num_tokens_from_string(answer)
def chat(self, system, history, gen_conf={}, **kwargs):
def chat(self, system, history, gen_conf=None, **kwargs):
gen_conf = dict(gen_conf or {})
if "max_tokens" in gen_conf:
del gen_conf["max_tokens"]
prompt = self._prepare_prompt(system, history, gen_conf)
@@ -749,7 +756,8 @@ class LocalLLM(Base):
total_tokens = next(chat_gen)
return ans, total_tokens
def chat_streamly(self, system, history, gen_conf={}, **kwargs):
def chat_streamly(self, system, history, gen_conf=None, **kwargs):
gen_conf = dict(gen_conf or {})
if "max_tokens" in gen_conf:
del gen_conf["max_tokens"]
prompt = self._prepare_prompt(system, history, gen_conf)
@@ -788,7 +796,8 @@ class MistralChat(Base):
del gen_conf[k]
return gen_conf
def _chat(self, history, gen_conf={}, **kwargs):
def _chat(self, history, gen_conf=None, **kwargs):
gen_conf = dict(gen_conf or {})
gen_conf = self._clean_conf(gen_conf)
response = self.client.chat(model=self.model_name, messages=history, **gen_conf)
ans = response.choices[0].message.content
@@ -799,7 +808,8 @@ class MistralChat(Base):
ans += LENGTH_NOTIFICATION_EN
return ans, total_token_count_from_response(response)
def chat_streamly(self, system, history, gen_conf={}, **kwargs):
def chat_streamly(self, system, history, gen_conf=None, **kwargs):
gen_conf = dict(gen_conf or {})
if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system})
gen_conf = self._clean_conf(gen_conf)
@@ -867,7 +877,8 @@ class ReplicateChat(Base):
self.model_name = model_name
self.client = Client(api_token=key)
def _chat(self, history, gen_conf={}, **kwargs):
def _chat(self, history, gen_conf=None, **kwargs):
gen_conf = dict(gen_conf or {})
system = history[0]["content"] if history and history[0]["role"] == "system" else ""
prompt = "\n".join([item["role"] + ":" + item["content"] for item in history[-5:] if item["role"] != "system"])
response = self.client.run(
@@ -877,7 +888,8 @@ class ReplicateChat(Base):
ans = "".join(response)
return ans, num_tokens_from_string(ans)
def chat_streamly(self, system, history, gen_conf={}, **kwargs):
def chat_streamly(self, system, history, gen_conf=None, **kwargs):
gen_conf = dict(gen_conf or {})
if "max_tokens" in gen_conf:
del gen_conf["max_tokens"]
prompt = "\n".join([item["role"] + ":" + item["content"] for item in history[-5:]])
@@ -946,7 +958,8 @@ class BaiduYiyanChat(Base):
ans = response["result"]
return ans, total_token_count_from_response(response)
def chat_streamly(self, system, history, gen_conf={}, **kwargs):
def chat_streamly(self, system, history, gen_conf=None, **kwargs):
gen_conf = dict(gen_conf or {})
gen_conf["penalty_score"] = ((gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty", 0)) / 2) + 1
if "max_tokens" in gen_conf:
del gen_conf["max_tokens"]
@@ -1020,7 +1033,8 @@ class GoogleChat(Base):
del gen_conf[k]
return gen_conf
def _chat(self, history, gen_conf={}, **kwargs):
def _chat(self, history, gen_conf=None, **kwargs):
gen_conf = dict(gen_conf or {})
system = history[0]["content"] if history and history[0]["role"] == "system" else ""
if "claude" in self.model_name:
@@ -1098,7 +1112,8 @@ class GoogleChat(Base):
return ans, total_tokens
def chat_streamly(self, system, history, gen_conf={}, **kwargs):
def chat_streamly(self, system, history, gen_conf=None, **kwargs):
gen_conf = dict(gen_conf or {})
if "claude" in self.model_name:
if "max_tokens" in gen_conf:
del gen_conf["max_tokens"]
@@ -1545,7 +1560,8 @@ class LiteLLMBase(ABC):
self.toolcall_session = toolcall_session
self.tools = tools
async def async_chat_with_tools(self, system: str, history: list, gen_conf: dict = {}):
async def async_chat_with_tools(self, system: str, history: list, gen_conf: dict | None = None):
gen_conf = dict(gen_conf or {})
gen_conf = self._clean_conf(gen_conf)
if system and history and history[0].get("role") != "system":
history.insert(0, {"role": "system", "content": system})
@@ -1622,7 +1638,8 @@ class LiteLLMBase(ABC):
assert False, "Shouldn't be here."
async def async_chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict = {}):
async def async_chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict | None = None):
gen_conf = dict(gen_conf or {})
gen_conf = self._clean_conf(gen_conf)
tools = self.tools
if system and history and history[0].get("role") != "system":

View File

@@ -437,7 +437,8 @@ class Zhipu4V(GptV4):
del gen_conf["frequency_penalty"]
return gen_conf
def _request(self, msg, stream, gen_conf={}):
def _request(self, msg, stream, gen_conf=None):
gen_conf = dict(gen_conf or {})
response = requests.post(
self.base_url,
json={"model": self.model_name, "messages": msg, "stream": stream, **gen_conf},
@@ -1035,7 +1036,8 @@ class NvidiaCV(Base):
total_token_count_from_response(response),
)
def _request(self, msg, gen_conf={}):
def _request(self, msg, gen_conf=None):
gen_conf = dict(gen_conf or {})
response = requests.post(
url=self.base_url,
headers={

View File

@@ -0,0 +1,94 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Regression guard for mutable default `gen_conf={}` in the LLM provider
integration layer (`rag/llm/chat_model.py`, `rag/llm/cv_model.py`).
Many provider methods used to declare ``def chat_streamly(..., gen_conf={}, ...)``
and then mutate ``gen_conf`` in place (``del gen_conf["max_tokens"]``,
``gen_conf["penalty_score"] = ...``). Because Python evaluates default
argument values **once** at function-definition time, that single shared
dict accumulated mutations across calls — every later caller that omitted
``gen_conf`` saw the polluted dict from the previous call.
The fix is to default to ``None`` and copy at the call site
(``gen_conf = dict(gen_conf or {})``). This test parses both modules with
the ``ast`` module and asserts no parameter named ``gen_conf`` ever has
a mutable literal as its default.
"""
import ast
from pathlib import Path
from typing import Union
import pytest
REPO_ROOT = Path(__file__).resolve().parents[4]
TARGET_FILES = [
REPO_ROOT / "rag" / "llm" / "chat_model.py",
REPO_ROOT / "rag" / "llm" / "cv_model.py",
]
def _iter_param_defaults(func: Union[ast.FunctionDef, ast.AsyncFunctionDef]):
"""Yield (param_name, default_node) for every parameter with a
non-empty default — covers positional, keyword-only, and the new
positional-only syntax."""
args = func.args
pos_args = args.args
pos_defaults = args.defaults
# positional defaults are right-aligned with args
for arg, default in zip(pos_args[-len(pos_defaults):], pos_defaults):
yield arg.arg, default
for arg, default in zip(args.kwonlyargs, args.kw_defaults):
if default is not None:
yield arg.arg, default
def _find_mutable_gen_conf_defaults(path: Path):
tree = ast.parse(path.read_text(encoding="utf-8"))
bad = []
for node in ast.walk(tree):
if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
continue
for name, default in _iter_param_defaults(node):
if name != "gen_conf":
continue
# An empty dict literal `{}` is the original bug. A list literal
# `[]` would be the same class of mistake. Anything else is fine.
# ast.Dict exposes `.keys`; ast.List exposes `.elts`. Use getattr
# for both so `gen_conf=[]` doesn't crash on a missing `.keys`.
if isinstance(default, (ast.Dict, ast.List)) and not getattr(default, "keys", None) and not getattr(default, "elts", None):
bad.append((node.name, default.lineno))
return bad
@pytest.mark.parametrize("path", TARGET_FILES, ids=lambda p: p.name)
def test_no_mutable_default_for_gen_conf(path: Path):
"""No function in chat_model.py / cv_model.py should declare
``gen_conf={}`` (or ``gen_conf=[]``) as a default value."""
bad = _find_mutable_gen_conf_defaults(path)
assert not bad, (
f"{path.name} has functions declaring `gen_conf` with a mutable "
f"default: {bad}. Use `gen_conf=None` and copy with "
f"`gen_conf = dict(gen_conf or {{}})` at the top of the function."
)
def test_target_files_exist():
"""Sanity check — if the LLM modules move, this regression guard
must follow them."""
for path in TARGET_FILES:
assert path.is_file(), f"Expected target file at {path}"