mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
fix(llm): replace mutable default gen_conf={} with None + defensive copy (#14566)
### What
19 methods across `rag/llm/chat_model.py` and `rag/llm/cv_model.py`
declare `gen_conf={}` (or `gen_conf: dict = {}`) as a parameter default
and then mutate `gen_conf` in place — typically `del
gen_conf["max_tokens"]`, `gen_conf["penalty_score"] = ...`, or
`gen_conf.pop(...)` as part of provider-specific normalization.
### The two bugs in this pattern
**1. Mutable default argument (Python footgun).** Python evaluates
default values **once** at function-definition time, so the single `{}`
dict is *shared* across every caller that doesn't pass `gen_conf`. The
first such call's mutations leak into the default seen by every
subsequent call.
```python
# Before
def chat_streamly(self, system, history, gen_conf={}, **kwargs):
if "max_tokens" in gen_conf:
del gen_conf["max_tokens"] # mutates the SHARED default dict
...
```
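The leak through the default shows up with in-place writes: a call that
omits `gen_conf` and runs `gen_conf["penalty_score"] = ...` stores that key
in the shared default, so every later call that also omits `gen_conf`
starts with `penalty_score` already present — even though no caller ever
supplied it.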
**2. Caller-dict pollution.** When the caller *does* pass a `gen_conf`
dict, the same in-place mutations modify the caller's dict. A reused
`gen_conf` (very common for chat-loop callers that build the config once
and pass it on every turn) silently loses `max_tokens`,
`presence_penalty`, etc. after the first round.
### The fix
In every affected method:
- Change `gen_conf={}` (or `gen_conf: dict = {}`) → `gen_conf=None`.
- Add `gen_conf = dict(gen_conf or {})` as the first statement of the
body so all subsequent mutations operate on a fresh local copy.
```python
# After
def chat_streamly(self, system, history, gen_conf=None, **kwargs):
gen_conf = dict(gen_conf or {})
if "max_tokens" in gen_conf:
del gen_conf["max_tokens"] # local copy — safe
...
```
Provider-side behavior is unchanged for callers that already pass a fresh
`gen_conf` per call. The new `dict(...)` call is a shallow copy of a
handful of keys, so its per-call cost is negligible.
### Files changed
- `rag/llm/chat_model.py` — 17 methods
- `rag/llm/cv_model.py` — 2 methods
### Tests
Adds `test/unit_test/rag/llm/test_gen_conf_no_mutable_default.py` — an
`ast`-based regression guard that walks both modules and asserts no
parameter named `gen_conf` ever has a mutable literal (`{}` or `[]`) as
its default. The test caught **five additional `gen_conf: dict = {}`
sites** that an initial `gen_conf={}` text grep had missed (annotated
parameters with whitespace), and would fail again if the pattern is ever
reintroduced.
```
$ pytest test/unit_test/rag/llm/test_gen_conf_no_mutable_default.py -v
============================== 3 passed in 0.04s ===============================
```
`ruff check` passes on all touched files.
### Notes
- This PR is intentionally focused on **just** the `gen_conf` default +
copy fix. There's a related (but separate) `history.insert(0, ...)`
pattern in the same files that mutates the caller's history list in 12
places — left for a follow-up so this PR stays mechanical and easy to
review.
### Latest revision (`700bb54a7`) — addresses CodeRabbit review
- Type annotation: `gen_conf: dict = None` → `gen_conf: dict | None =
None` (5 occurrences in `chat_model.py`). The old annotation was a
static-checker mismatch since `None` isn't a `dict`.
- Regression test: the AST check accessed `default.keys` directly.
`ast.List` has no `.keys` attribute — a future `gen_conf=[]` would crash
with `AttributeError` instead of being caught. Use `getattr` for both
`.keys` (Dict) and `.elts` (List). Manually verified the updated check
correctly catches both `gen_conf={}` and `gen_conf=[]` while ignoring
`gen_conf=None` and non-empty literals.
---------
Co-authored-by: Ricardo <ricardo@example.com>
This commit is contained in:
@@ -221,7 +221,8 @@ class Base(ABC):
|
||||
ans += LENGTH_NOTIFICATION_EN
|
||||
yield ans, tol
|
||||
|
||||
async def async_chat_streamly(self, system, history, gen_conf: dict = {}, **kwargs):
|
||||
async def async_chat_streamly(self, system, history, gen_conf: dict | None = None, **kwargs):
|
||||
gen_conf = dict(gen_conf or {})
|
||||
if system and history and history[0].get("role") != "system":
|
||||
history.insert(0, {"role": "system", "content": system})
|
||||
gen_conf = self._clean_conf(gen_conf)
|
||||
@@ -356,7 +357,8 @@ class Base(ABC):
|
||||
self.toolcall_session = toolcall_session
|
||||
self.tools = tools
|
||||
|
||||
async def async_chat_with_tools(self, system: str, history: list, gen_conf: dict = {}):
|
||||
async def async_chat_with_tools(self, system: str, history: list, gen_conf: dict | None = None):
|
||||
gen_conf = dict(gen_conf or {})
|
||||
gen_conf = self._clean_conf(gen_conf)
|
||||
if system and history and history[0].get("role") != "system":
|
||||
history.insert(0, {"role": "system", "content": system})
|
||||
@@ -417,7 +419,8 @@ class Base(ABC):
|
||||
|
||||
assert False, "Shouldn't be here."
|
||||
|
||||
async def async_chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict = {}):
|
||||
async def async_chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict | None = None):
|
||||
gen_conf = dict(gen_conf or {})
|
||||
gen_conf = self._clean_conf(gen_conf)
|
||||
tools = self.tools
|
||||
if system and history and history[0].get("role") != "system":
|
||||
@@ -576,7 +579,8 @@ class Base(ABC):
|
||||
ans = self._length_stop(ans)
|
||||
return ans, total_token_count_from_response(response)
|
||||
|
||||
async def async_chat(self, system, history, gen_conf={}, **kwargs):
|
||||
async def async_chat(self, system, history, gen_conf=None, **kwargs):
|
||||
gen_conf = dict(gen_conf or {})
|
||||
if system and history and history[0].get("role") != "system":
|
||||
history.insert(0, {"role": "system", "content": system})
|
||||
gen_conf = self._clean_conf(gen_conf)
|
||||
@@ -642,7 +646,8 @@ class BaiChuanChat(Base):
|
||||
"top_p": gen_conf.get("top_p", 0.85),
|
||||
}
|
||||
|
||||
def _chat(self, history, gen_conf={}, **kwargs):
|
||||
def _chat(self, history, gen_conf=None, **kwargs):
|
||||
gen_conf = dict(gen_conf or {})
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
messages=history,
|
||||
@@ -657,7 +662,8 @@ class BaiChuanChat(Base):
|
||||
ans += LENGTH_NOTIFICATION_EN
|
||||
return ans, total_token_count_from_response(response)
|
||||
|
||||
def chat_streamly(self, system, history, gen_conf={}, **kwargs):
|
||||
def chat_streamly(self, system, history, gen_conf=None, **kwargs):
|
||||
gen_conf = dict(gen_conf or {})
|
||||
if system and history and history[0].get("role") != "system":
|
||||
history.insert(0, {"role": "system", "content": system})
|
||||
if "max_tokens" in gen_conf:
|
||||
@@ -740,7 +746,8 @@ class LocalLLM(Base):
|
||||
yield answer + "\n**ERROR**: " + str(e)
|
||||
yield num_tokens_from_string(answer)
|
||||
|
||||
def chat(self, system, history, gen_conf={}, **kwargs):
|
||||
def chat(self, system, history, gen_conf=None, **kwargs):
|
||||
gen_conf = dict(gen_conf or {})
|
||||
if "max_tokens" in gen_conf:
|
||||
del gen_conf["max_tokens"]
|
||||
prompt = self._prepare_prompt(system, history, gen_conf)
|
||||
@@ -749,7 +756,8 @@ class LocalLLM(Base):
|
||||
total_tokens = next(chat_gen)
|
||||
return ans, total_tokens
|
||||
|
||||
def chat_streamly(self, system, history, gen_conf={}, **kwargs):
|
||||
def chat_streamly(self, system, history, gen_conf=None, **kwargs):
|
||||
gen_conf = dict(gen_conf or {})
|
||||
if "max_tokens" in gen_conf:
|
||||
del gen_conf["max_tokens"]
|
||||
prompt = self._prepare_prompt(system, history, gen_conf)
|
||||
@@ -788,7 +796,8 @@ class MistralChat(Base):
|
||||
del gen_conf[k]
|
||||
return gen_conf
|
||||
|
||||
def _chat(self, history, gen_conf={}, **kwargs):
|
||||
def _chat(self, history, gen_conf=None, **kwargs):
|
||||
gen_conf = dict(gen_conf or {})
|
||||
gen_conf = self._clean_conf(gen_conf)
|
||||
response = self.client.chat(model=self.model_name, messages=history, **gen_conf)
|
||||
ans = response.choices[0].message.content
|
||||
@@ -799,7 +808,8 @@ class MistralChat(Base):
|
||||
ans += LENGTH_NOTIFICATION_EN
|
||||
return ans, total_token_count_from_response(response)
|
||||
|
||||
def chat_streamly(self, system, history, gen_conf={}, **kwargs):
|
||||
def chat_streamly(self, system, history, gen_conf=None, **kwargs):
|
||||
gen_conf = dict(gen_conf or {})
|
||||
if system and history and history[0].get("role") != "system":
|
||||
history.insert(0, {"role": "system", "content": system})
|
||||
gen_conf = self._clean_conf(gen_conf)
|
||||
@@ -867,7 +877,8 @@ class ReplicateChat(Base):
|
||||
self.model_name = model_name
|
||||
self.client = Client(api_token=key)
|
||||
|
||||
def _chat(self, history, gen_conf={}, **kwargs):
|
||||
def _chat(self, history, gen_conf=None, **kwargs):
|
||||
gen_conf = dict(gen_conf or {})
|
||||
system = history[0]["content"] if history and history[0]["role"] == "system" else ""
|
||||
prompt = "\n".join([item["role"] + ":" + item["content"] for item in history[-5:] if item["role"] != "system"])
|
||||
response = self.client.run(
|
||||
@@ -877,7 +888,8 @@ class ReplicateChat(Base):
|
||||
ans = "".join(response)
|
||||
return ans, num_tokens_from_string(ans)
|
||||
|
||||
def chat_streamly(self, system, history, gen_conf={}, **kwargs):
|
||||
def chat_streamly(self, system, history, gen_conf=None, **kwargs):
|
||||
gen_conf = dict(gen_conf or {})
|
||||
if "max_tokens" in gen_conf:
|
||||
del gen_conf["max_tokens"]
|
||||
prompt = "\n".join([item["role"] + ":" + item["content"] for item in history[-5:]])
|
||||
@@ -946,7 +958,8 @@ class BaiduYiyanChat(Base):
|
||||
ans = response["result"]
|
||||
return ans, total_token_count_from_response(response)
|
||||
|
||||
def chat_streamly(self, system, history, gen_conf={}, **kwargs):
|
||||
def chat_streamly(self, system, history, gen_conf=None, **kwargs):
|
||||
gen_conf = dict(gen_conf or {})
|
||||
gen_conf["penalty_score"] = ((gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty", 0)) / 2) + 1
|
||||
if "max_tokens" in gen_conf:
|
||||
del gen_conf["max_tokens"]
|
||||
@@ -1020,7 +1033,8 @@ class GoogleChat(Base):
|
||||
del gen_conf[k]
|
||||
return gen_conf
|
||||
|
||||
def _chat(self, history, gen_conf={}, **kwargs):
|
||||
def _chat(self, history, gen_conf=None, **kwargs):
|
||||
gen_conf = dict(gen_conf or {})
|
||||
system = history[0]["content"] if history and history[0]["role"] == "system" else ""
|
||||
|
||||
if "claude" in self.model_name:
|
||||
@@ -1098,7 +1112,8 @@ class GoogleChat(Base):
|
||||
|
||||
return ans, total_tokens
|
||||
|
||||
def chat_streamly(self, system, history, gen_conf={}, **kwargs):
|
||||
def chat_streamly(self, system, history, gen_conf=None, **kwargs):
|
||||
gen_conf = dict(gen_conf or {})
|
||||
if "claude" in self.model_name:
|
||||
if "max_tokens" in gen_conf:
|
||||
del gen_conf["max_tokens"]
|
||||
@@ -1545,7 +1560,8 @@ class LiteLLMBase(ABC):
|
||||
self.toolcall_session = toolcall_session
|
||||
self.tools = tools
|
||||
|
||||
async def async_chat_with_tools(self, system: str, history: list, gen_conf: dict = {}):
|
||||
async def async_chat_with_tools(self, system: str, history: list, gen_conf: dict | None = None):
|
||||
gen_conf = dict(gen_conf or {})
|
||||
gen_conf = self._clean_conf(gen_conf)
|
||||
if system and history and history[0].get("role") != "system":
|
||||
history.insert(0, {"role": "system", "content": system})
|
||||
@@ -1622,7 +1638,8 @@ class LiteLLMBase(ABC):
|
||||
|
||||
assert False, "Shouldn't be here."
|
||||
|
||||
async def async_chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict = {}):
|
||||
async def async_chat_streamly_with_tools(self, system: str, history: list, gen_conf: dict | None = None):
|
||||
gen_conf = dict(gen_conf or {})
|
||||
gen_conf = self._clean_conf(gen_conf)
|
||||
tools = self.tools
|
||||
if system and history and history[0].get("role") != "system":
|
||||
|
||||
@@ -437,7 +437,8 @@ class Zhipu4V(GptV4):
|
||||
del gen_conf["frequency_penalty"]
|
||||
return gen_conf
|
||||
|
||||
def _request(self, msg, stream, gen_conf={}):
|
||||
def _request(self, msg, stream, gen_conf=None):
|
||||
gen_conf = dict(gen_conf or {})
|
||||
response = requests.post(
|
||||
self.base_url,
|
||||
json={"model": self.model_name, "messages": msg, "stream": stream, **gen_conf},
|
||||
@@ -1035,7 +1036,8 @@ class NvidiaCV(Base):
|
||||
total_token_count_from_response(response),
|
||||
)
|
||||
|
||||
def _request(self, msg, gen_conf={}):
|
||||
def _request(self, msg, gen_conf=None):
|
||||
gen_conf = dict(gen_conf or {})
|
||||
response = requests.post(
|
||||
url=self.base_url,
|
||||
headers={
|
||||
|
||||
94
test/unit_test/rag/llm/test_gen_conf_no_mutable_default.py
Normal file
94
test/unit_test/rag/llm/test_gen_conf_no_mutable_default.py
Normal file
@@ -0,0 +1,94 @@
|
||||
#
|
||||
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""
|
||||
Regression guard for mutable default `gen_conf={}` in the LLM provider
|
||||
integration layer (`rag/llm/chat_model.py`, `rag/llm/cv_model.py`).
|
||||
|
||||
Many provider methods used to declare ``def chat_streamly(..., gen_conf={}, ...)``
|
||||
and then mutate ``gen_conf`` in place (``del gen_conf["max_tokens"]``,
|
||||
``gen_conf["penalty_score"] = ...``). Because Python evaluates default
|
||||
argument values **once** at function-definition time, that single shared
|
||||
dict accumulated mutations across calls — every later caller that omitted
|
||||
``gen_conf`` saw the polluted dict from the previous call.
|
||||
|
||||
The fix is to default to ``None`` and copy inside the function body
|
||||
(``gen_conf = dict(gen_conf or {})``). This test parses both modules with
|
||||
the ``ast`` module and asserts no parameter named ``gen_conf`` ever has
|
||||
a mutable literal as its default.
|
||||
"""
|
||||
import ast
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import pytest
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parents[4]
|
||||
TARGET_FILES = [
|
||||
REPO_ROOT / "rag" / "llm" / "chat_model.py",
|
||||
REPO_ROOT / "rag" / "llm" / "cv_model.py",
|
||||
]
|
||||
|
||||
|
||||
def _iter_param_defaults(func: Union[ast.FunctionDef, ast.AsyncFunctionDef]):
    """Yield ``(param_name, default_node)`` for each parameter with a default.

    Covers positional-only, positional-or-keyword and keyword-only
    parameters of *func*. Parameters without a default are skipped.

    Args:
        func: A parsed ``def`` / ``async def`` node.

    Yields:
        Tuples of the parameter name and the AST node of its default value.
    """
    args = func.args
    # `args.defaults` is shared by positional-only and regular positional
    # parameters and is right-aligned against the two lists combined.
    # Aligning against `args.args` alone mis-pairs names and defaults as soon
    # as a positional-only parameter carries a default.
    positional = [*args.posonlyargs, *args.args]
    defaults = args.defaults
    # Guard the empty case: `positional[-0:]` would be the whole list.
    if defaults:
        for arg, default in zip(positional[-len(defaults):], defaults):
            yield arg.arg, default
    # `kw_defaults` is index-aligned with `kwonlyargs`; a `None` entry marks
    # a required keyword-only parameter.
    for arg, default in zip(args.kwonlyargs, args.kw_defaults):
        if default is not None:
            yield arg.arg, default
|
||||
|
||||
|
||||
def _find_mutable_gen_conf_defaults(path: Path):
    """Collect functions in *path* whose ``gen_conf`` default is ``{}`` or ``[]``.

    The file is parsed with ``ast`` and never imported. Returns a list of
    ``(function_name, lineno)`` tuples, where ``lineno`` is the line of the
    offending default node.
    """
    module = ast.parse(path.read_text(encoding="utf-8"))
    function_nodes = (
        node
        for node in ast.walk(module)
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
    )
    offenders = []
    for func in function_nodes:
        for param, default in _iter_param_defaults(func):
            if param != "gen_conf":
                continue
            # Only *empty* literals are flagged: `{}` is the original bug and
            # `[]` is the same class of mistake. ast.Dict keeps its entries
            # in `.keys`, ast.List in `.elts`, so each node type is checked
            # through its own attribute.
            if isinstance(default, ast.Dict):
                is_empty_literal = not default.keys
            elif isinstance(default, ast.List):
                is_empty_literal = not default.elts
            else:
                is_empty_literal = False
            if is_empty_literal:
                offenders.append((func.name, default.lineno))
    return offenders
|
||||
|
||||
|
||||
@pytest.mark.parametrize("path", TARGET_FILES, ids=lambda p: p.name)
def test_no_mutable_default_for_gen_conf(path: Path):
    """Fail if any function in the audited module declares ``gen_conf={}``
    (or ``gen_conf=[]``) as a parameter default."""
    offenders = _find_mutable_gen_conf_defaults(path)
    # Build the message up front so the assertion itself stays a one-liner.
    failure_message = (
        f"{path.name} has functions declaring `gen_conf` with a mutable "
        f"default: {offenders}. Use `gen_conf=None` and copy with "
        f"`gen_conf = dict(gen_conf or {{}})` at the top of the function."
    )
    assert not offenders, failure_message
|
||||
|
||||
|
||||
def test_target_files_exist():
    """Guard against the audited modules being moved or renamed: if they
    are, this regression test has to be updated to follow them."""
    # Report the first missing path, in TARGET_FILES order.
    missing = next((path for path in TARGET_FILES if not path.is_file()), None)
    assert missing is None, f"Expected target file at {missing}"
|
||||
Reference in New Issue
Block a user