feat: pass chat_template_kwargs through agent chat completion (#14542)

### What problem does this PR solve?

The agent API currently does not pass chat_template_kwargs to the
underlying LLM call path, so clients cannot control template-level model
behavior (such as thinking-mode toggles) when invoking
/agents/chat/completion. This PR adds passthrough support for
chat_template_kwargs across agent execution flows (session and
non-session, streaming and non-streaming) by propagating it through
canvas runtime state and into LLM invocation kwargs. This addresses the
feature gap raised in [Issue
#14182](https://github.com/infiniflow/ragflow/issues/14182).

Closes #14182 

### Type of change

- [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
Full Stack Developer
2026-05-22 02:15:49 -05:00
committed by GitHub
parent c33d0b8081
commit 8f90740d2e
5 changed files with 49 additions and 7 deletions

View File

@@ -402,7 +402,7 @@ class Canvas(Graph):
break
for k in kwargs.keys():
if k in ["query", "user_id", "files"] and kwargs[k]:
if k in ["query", "user_id", "files", "chat_template_kwargs"] and kwargs[k]:
if k == "files":
self.globals[f"sys.{k}"] = await self.get_files_async(kwargs[k], layout_recognize)
else:

View File

@@ -345,6 +345,8 @@ class LLM(ComponentBase):
return re.sub(r"(<think>|</think>)", "", delta_ans)
stream_kwargs = {"images": self.imgs} if self.imgs else {}
extra_chat_kwargs = self._get_chat_template_kwargs()
stream_kwargs.update(extra_chat_kwargs)
async for ans in self.chat_mdl.async_chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), **stream_kwargs):
if self.check_if_canceled("LLM streaming"):
return
@@ -375,6 +377,7 @@ class LLM(ComponentBase):
return re.sub(r"```\n*$", "", ans, flags=re.DOTALL)
prompt, msg, _ = self._prepare_prompt_variables()
extra_chat_kwargs = self._get_chat_template_kwargs()
error: str = ""
output_structure = None
try:
@@ -393,7 +396,7 @@ class LLM(ComponentBase):
int(self.chat_mdl.max_length * 0.97),
)
error = ""
ans = await self._generate_async(msg_fit)
ans = await self._generate_async(msg_fit, **extra_chat_kwargs)
msg_fit.pop(0)
if ans.find("**ERROR**") >= 0:
logging.error(f"LLM response error: {ans}")
@@ -426,7 +429,7 @@ class LLM(ComponentBase):
[{"role": "system", "content": prompt}, *deepcopy(msg)], int(self.chat_mdl.max_length * 0.97)
)
error = ""
ans = await self._generate_async(msg_fit)
ans = await self._generate_async(msg_fit, **extra_chat_kwargs)
msg_fit.pop(0)
if ans.find("**ERROR**") >= 0:
logging.error(f"LLM response error: {ans}")
@@ -445,6 +448,24 @@ class LLM(ComponentBase):
def _invoke(self, **kwargs):
return asyncio.run(self._invoke_async(**kwargs))
def _get_chat_template_kwargs(self) -> dict[str, Any]:
chat_template_kwargs = self._canvas.globals.get("sys.chat_template_kwargs")
if chat_template_kwargs is None:
return {}
# The API should pass this as a JSON object, but accept a JSON string for compatibility.
if isinstance(chat_template_kwargs, str):
try:
chat_template_kwargs = json_repair.loads(chat_template_kwargs)
except Exception:
logging.warning("Ignore invalid sys.chat_template_kwargs: expected JSON object or JSON string object.")
return {}
if not isinstance(chat_template_kwargs, dict):
logging.warning("Ignore invalid sys.chat_template_kwargs type: %s", type(chat_template_kwargs).__name__)
return {}
return {"chat_template_kwargs": chat_template_kwargs}
async def add_memory(self, user:str, assist:str, func_name: str, params: dict, results: str, user_defined_prompt:dict={}):
summ = await tool_call_summary(self.chat_mdl, func_name, params, results, user_defined_prompt)
logging.info(f"[MEMORY]: {summ}")

View File

@@ -177,6 +177,7 @@ async def _run_workflow_session(
canvas_category,
return_trace,
stream,
chat_template_kwargs=None,
):
async def commit_runtime_replica():
commit_ok = CanvasReplicaService.commit_after_run(
@@ -213,6 +214,14 @@ async def _run_workflow_session(
final_ans = {}
trace_items = []
structured_output = {}
run_kwargs = {
"query": query,
"files": files,
"user_id": user_id,
"inputs": inputs,
}
if chat_template_kwargs is not None:
run_kwargs["chat_template_kwargs"] = chat_template_kwargs
async def persist_workflow_session():
if not final_ans:
@@ -237,7 +246,7 @@ async def _run_workflow_session(
nonlocal full_content, reference, final_ans, trace_items, structured_output
done_sent = False
try:
async for ans in canvas.run(query=query, files=files, user_id=user_id, inputs=inputs):
async for ans in canvas.run(**run_kwargs):
ans["session_id"] = session_id
if ans.get("event") == "message":
full_content += ans.get("data", {}).get("content", "")
@@ -285,7 +294,7 @@ async def _run_workflow_session(
return _build_sse_response(sse())
try:
async for ans in canvas.run(query=query, files=files, user_id=user_id, inputs=inputs):
async for ans in canvas.run(**run_kwargs):
ans["session_id"] = session_id
if ans.get("event") == "message":
full_content += ans.get("data", {}).get("content", "")
@@ -1258,6 +1267,7 @@ async def agent_chat_completion(tenant_id, agent_id=None):
canvas_category=getattr(cvs, "canvas_category", CanvasCategory.Agent),
return_trace=bool(req.get("return_trace", False)),
stream=req.get("stream", True),
chat_template_kwargs=req.get("chat_template_kwargs"),
)
if not session_id:
@@ -1404,6 +1414,7 @@ async def agent_chat_completion(tenant_id, agent_id=None):
canvas_category=canvas_category,
return_trace=bool(req.get("return_trace", False)),
stream=req.get("stream", True),
chat_template_kwargs=req.get("chat_template_kwargs"),
)
return_trace = bool(req.get("return_trace", False))

View File

@@ -315,6 +315,7 @@ async def completion(tenant_id, agent_id, session_id=None, **kwargs):
files = kwargs.get("files", [])
inputs = kwargs.get("inputs", {})
user_id = kwargs.get("user_id", "")
chat_template_kwargs = kwargs.get("chat_template_kwargs")
custom_header = kwargs.get("custom_header", "")
release_mode = str(kwargs.get("release", "")).strip().lower()
@@ -347,7 +348,16 @@ async def completion(tenant_id, agent_id, session_id=None, **kwargs):
"files": files
})
txt = ""
async for ans in canvas.run(query=query, files=files, user_id=user_id, inputs=inputs):
run_kwargs = {
"query": query,
"files": files,
"user_id": user_id,
"inputs": inputs,
}
if chat_template_kwargs is not None:
run_kwargs["chat_template_kwargs"] = chat_template_kwargs
async for ans in canvas.run(**run_kwargs):
ans["session_id"] = session_id
if ans["event"] == "message":
txt += ans["data"]["content"]

View File

@@ -1,4 +1,4 @@
#
#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");