mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 15:31:05 +08:00
feat: pass chat_template_kwargs through agent chat completion (#14542)
### What problem does this PR solve? The agent API currently does not pass chat_template_kwargs to the underlying LLM call path, so clients cannot control template-level model behavior (such as thinking-mode toggles) when invoking /agents/chat/completion. This PR adds passthrough support for chat_template_kwargs across agent execution flows (session and non-session, streaming and non-streaming) by propagating it through canvas runtime state and into LLM invocation kwargs. This addresses the feature gap raised in [Issue #14182](https://github.com/infiniflow/ragflow/issues/14182). Closes #14182 ### Type of change - [x] New Feature (non-breaking change which adds functionality)
This commit is contained in:
committed by
GitHub
parent
c33d0b8081
commit
8f90740d2e
@@ -402,7 +402,7 @@ class Canvas(Graph):
|
||||
break
|
||||
|
||||
for k in kwargs.keys():
|
||||
if k in ["query", "user_id", "files"] and kwargs[k]:
|
||||
if k in ["query", "user_id", "files", "chat_template_kwargs"] and kwargs[k]:
|
||||
if k == "files":
|
||||
self.globals[f"sys.{k}"] = await self.get_files_async(kwargs[k], layout_recognize)
|
||||
else:
|
||||
|
||||
@@ -345,6 +345,8 @@ class LLM(ComponentBase):
|
||||
return re.sub(r"(<think>|</think>)", "", delta_ans)
|
||||
|
||||
stream_kwargs = {"images": self.imgs} if self.imgs else {}
|
||||
extra_chat_kwargs = self._get_chat_template_kwargs()
|
||||
stream_kwargs.update(extra_chat_kwargs)
|
||||
async for ans in self.chat_mdl.async_chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), **stream_kwargs):
|
||||
if self.check_if_canceled("LLM streaming"):
|
||||
return
|
||||
@@ -375,6 +377,7 @@ class LLM(ComponentBase):
|
||||
return re.sub(r"```\n*$", "", ans, flags=re.DOTALL)
|
||||
|
||||
prompt, msg, _ = self._prepare_prompt_variables()
|
||||
extra_chat_kwargs = self._get_chat_template_kwargs()
|
||||
error: str = ""
|
||||
output_structure = None
|
||||
try:
|
||||
@@ -393,7 +396,7 @@ class LLM(ComponentBase):
|
||||
int(self.chat_mdl.max_length * 0.97),
|
||||
)
|
||||
error = ""
|
||||
ans = await self._generate_async(msg_fit)
|
||||
ans = await self._generate_async(msg_fit, **extra_chat_kwargs)
|
||||
msg_fit.pop(0)
|
||||
if ans.find("**ERROR**") >= 0:
|
||||
logging.error(f"LLM response error: {ans}")
|
||||
@@ -426,7 +429,7 @@ class LLM(ComponentBase):
|
||||
[{"role": "system", "content": prompt}, *deepcopy(msg)], int(self.chat_mdl.max_length * 0.97)
|
||||
)
|
||||
error = ""
|
||||
ans = await self._generate_async(msg_fit)
|
||||
ans = await self._generate_async(msg_fit, **extra_chat_kwargs)
|
||||
msg_fit.pop(0)
|
||||
if ans.find("**ERROR**") >= 0:
|
||||
logging.error(f"LLM response error: {ans}")
|
||||
@@ -445,6 +448,24 @@ class LLM(ComponentBase):
|
||||
def _invoke(self, **kwargs):
|
||||
return asyncio.run(self._invoke_async(**kwargs))
|
||||
|
||||
def _get_chat_template_kwargs(self) -> dict[str, Any]:
|
||||
chat_template_kwargs = self._canvas.globals.get("sys.chat_template_kwargs")
|
||||
if chat_template_kwargs is None:
|
||||
return {}
|
||||
|
||||
# The API should pass this as a JSON object, but accept a JSON string for compatibility.
|
||||
if isinstance(chat_template_kwargs, str):
|
||||
try:
|
||||
chat_template_kwargs = json_repair.loads(chat_template_kwargs)
|
||||
except Exception:
|
||||
logging.warning("Ignore invalid sys.chat_template_kwargs: expected JSON object or JSON string object.")
|
||||
return {}
|
||||
|
||||
if not isinstance(chat_template_kwargs, dict):
|
||||
logging.warning("Ignore invalid sys.chat_template_kwargs type: %s", type(chat_template_kwargs).__name__)
|
||||
return {}
|
||||
return {"chat_template_kwargs": chat_template_kwargs}
|
||||
|
||||
async def add_memory(self, user:str, assist:str, func_name: str, params: dict, results: str, user_defined_prompt:dict={}):
|
||||
summ = await tool_call_summary(self.chat_mdl, func_name, params, results, user_defined_prompt)
|
||||
logging.info(f"[MEMORY]: {summ}")
|
||||
|
||||
@@ -177,6 +177,7 @@ async def _run_workflow_session(
|
||||
canvas_category,
|
||||
return_trace,
|
||||
stream,
|
||||
chat_template_kwargs=None,
|
||||
):
|
||||
async def commit_runtime_replica():
|
||||
commit_ok = CanvasReplicaService.commit_after_run(
|
||||
@@ -213,6 +214,14 @@ async def _run_workflow_session(
|
||||
final_ans = {}
|
||||
trace_items = []
|
||||
structured_output = {}
|
||||
run_kwargs = {
|
||||
"query": query,
|
||||
"files": files,
|
||||
"user_id": user_id,
|
||||
"inputs": inputs,
|
||||
}
|
||||
if chat_template_kwargs is not None:
|
||||
run_kwargs["chat_template_kwargs"] = chat_template_kwargs
|
||||
|
||||
async def persist_workflow_session():
|
||||
if not final_ans:
|
||||
@@ -237,7 +246,7 @@ async def _run_workflow_session(
|
||||
nonlocal full_content, reference, final_ans, trace_items, structured_output
|
||||
done_sent = False
|
||||
try:
|
||||
async for ans in canvas.run(query=query, files=files, user_id=user_id, inputs=inputs):
|
||||
async for ans in canvas.run(**run_kwargs):
|
||||
ans["session_id"] = session_id
|
||||
if ans.get("event") == "message":
|
||||
full_content += ans.get("data", {}).get("content", "")
|
||||
@@ -285,7 +294,7 @@ async def _run_workflow_session(
|
||||
return _build_sse_response(sse())
|
||||
|
||||
try:
|
||||
async for ans in canvas.run(query=query, files=files, user_id=user_id, inputs=inputs):
|
||||
async for ans in canvas.run(**run_kwargs):
|
||||
ans["session_id"] = session_id
|
||||
if ans.get("event") == "message":
|
||||
full_content += ans.get("data", {}).get("content", "")
|
||||
@@ -1258,6 +1267,7 @@ async def agent_chat_completion(tenant_id, agent_id=None):
|
||||
canvas_category=getattr(cvs, "canvas_category", CanvasCategory.Agent),
|
||||
return_trace=bool(req.get("return_trace", False)),
|
||||
stream=req.get("stream", True),
|
||||
chat_template_kwargs=req.get("chat_template_kwargs"),
|
||||
)
|
||||
|
||||
if not session_id:
|
||||
@@ -1404,6 +1414,7 @@ async def agent_chat_completion(tenant_id, agent_id=None):
|
||||
canvas_category=canvas_category,
|
||||
return_trace=bool(req.get("return_trace", False)),
|
||||
stream=req.get("stream", True),
|
||||
chat_template_kwargs=req.get("chat_template_kwargs"),
|
||||
)
|
||||
|
||||
return_trace = bool(req.get("return_trace", False))
|
||||
|
||||
@@ -315,6 +315,7 @@ async def completion(tenant_id, agent_id, session_id=None, **kwargs):
|
||||
files = kwargs.get("files", [])
|
||||
inputs = kwargs.get("inputs", {})
|
||||
user_id = kwargs.get("user_id", "")
|
||||
chat_template_kwargs = kwargs.get("chat_template_kwargs")
|
||||
custom_header = kwargs.get("custom_header", "")
|
||||
release_mode = str(kwargs.get("release", "")).strip().lower()
|
||||
|
||||
@@ -347,7 +348,16 @@ async def completion(tenant_id, agent_id, session_id=None, **kwargs):
|
||||
"files": files
|
||||
})
|
||||
txt = ""
|
||||
async for ans in canvas.run(query=query, files=files, user_id=user_id, inputs=inputs):
|
||||
run_kwargs = {
|
||||
"query": query,
|
||||
"files": files,
|
||||
"user_id": user_id,
|
||||
"inputs": inputs,
|
||||
}
|
||||
if chat_template_kwargs is not None:
|
||||
run_kwargs["chat_template_kwargs"] = chat_template_kwargs
|
||||
|
||||
async for ans in canvas.run(**run_kwargs):
|
||||
ans["session_id"] = session_id
|
||||
if ans["event"] == "message":
|
||||
txt += ans["data"]["content"]
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#
|
||||
#
|
||||
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
||||
Reference in New Issue
Block a user