diff --git a/agent/canvas.py b/agent/canvas.py index 3421d207ed..8de3d4bfcc 100644 --- a/agent/canvas.py +++ b/agent/canvas.py @@ -402,7 +402,7 @@ class Canvas(Graph): break for k in kwargs.keys(): - if k in ["query", "user_id", "files"] and kwargs[k]: + if k in ["query", "user_id", "files", "chat_template_kwargs"] and kwargs[k]: if k == "files": self.globals[f"sys.{k}"] = await self.get_files_async(kwargs[k], layout_recognize) else: diff --git a/agent/component/llm.py b/agent/component/llm.py index b4e66690a3..126a5a0e8d 100644 --- a/agent/component/llm.py +++ b/agent/component/llm.py @@ -345,6 +345,8 @@ class LLM(ComponentBase): return re.sub(r"(|)", "", delta_ans) stream_kwargs = {"images": self.imgs} if self.imgs else {} + extra_chat_kwargs = self._get_chat_template_kwargs() + stream_kwargs.update(extra_chat_kwargs) async for ans in self.chat_mdl.async_chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf(), **stream_kwargs): if self.check_if_canceled("LLM streaming"): return @@ -375,6 +377,7 @@ class LLM(ComponentBase): return re.sub(r"```\n*$", "", ans, flags=re.DOTALL) prompt, msg, _ = self._prepare_prompt_variables() + extra_chat_kwargs = self._get_chat_template_kwargs() error: str = "" output_structure = None try: @@ -393,7 +396,7 @@ class LLM(ComponentBase): int(self.chat_mdl.max_length * 0.97), ) error = "" - ans = await self._generate_async(msg_fit) + ans = await self._generate_async(msg_fit, **extra_chat_kwargs) msg_fit.pop(0) if ans.find("**ERROR**") >= 0: logging.error(f"LLM response error: {ans}") @@ -426,7 +429,7 @@ class LLM(ComponentBase): [{"role": "system", "content": prompt}, *deepcopy(msg)], int(self.chat_mdl.max_length * 0.97) ) error = "" - ans = await self._generate_async(msg_fit) + ans = await self._generate_async(msg_fit, **extra_chat_kwargs) msg_fit.pop(0) if ans.find("**ERROR**") >= 0: logging.error(f"LLM response error: {ans}") @@ -445,6 +448,24 @@ class LLM(ComponentBase): def _invoke(self, **kwargs): return asyncio.run(self._invoke_async(**kwargs)) + def _get_chat_template_kwargs(self) -> dict[str, Any]: + chat_template_kwargs = self._canvas.globals.get("sys.chat_template_kwargs") + if chat_template_kwargs is None: + return {} + + # The API should pass this as a JSON object, but accept a JSON string for compatibility. + if isinstance(chat_template_kwargs, str): + try: + chat_template_kwargs = json_repair.loads(chat_template_kwargs) + except Exception: + logging.warning("Ignore invalid sys.chat_template_kwargs: expected JSON object or JSON string object.") + return {} + + if not isinstance(chat_template_kwargs, dict): + logging.warning("Ignore invalid sys.chat_template_kwargs type: %s", type(chat_template_kwargs).__name__) + return {} + return {"chat_template_kwargs": chat_template_kwargs} + async def add_memory(self, user:str, assist:str, func_name: str, params: dict, results: str, user_defined_prompt:dict={}): summ = await tool_call_summary(self.chat_mdl, func_name, params, results, user_defined_prompt) logging.info(f"[MEMORY]: {summ}") diff --git a/api/apps/restful_apis/agent_api.py b/api/apps/restful_apis/agent_api.py index 5035822c5b..fada33c594 100644 --- a/api/apps/restful_apis/agent_api.py +++ b/api/apps/restful_apis/agent_api.py @@ -177,6 +177,7 @@ async def _run_workflow_session( canvas_category, return_trace, stream, + chat_template_kwargs=None, ): async def commit_runtime_replica(): commit_ok = CanvasReplicaService.commit_after_run( @@ -213,6 +214,14 @@ async def _run_workflow_session( final_ans = {} trace_items = [] structured_output = {} + run_kwargs = { + "query": query, + "files": files, + "user_id": user_id, + "inputs": inputs, + } + if chat_template_kwargs is not None: + run_kwargs["chat_template_kwargs"] = chat_template_kwargs async def persist_workflow_session(): if not final_ans: @@ -237,7 +246,7 @@ async def _run_workflow_session( nonlocal full_content, reference, final_ans, trace_items, structured_output done_sent = False try: - async for ans in canvas.run(query=query, files=files, user_id=user_id, inputs=inputs): + async for ans in canvas.run(**run_kwargs): ans["session_id"] = session_id if ans.get("event") == "message": full_content += ans.get("data", {}).get("content", "") @@ -285,7 +294,7 @@ async def _run_workflow_session( return _build_sse_response(sse()) try: - async for ans in canvas.run(query=query, files=files, user_id=user_id, inputs=inputs): + async for ans in canvas.run(**run_kwargs): ans["session_id"] = session_id if ans.get("event") == "message": full_content += ans.get("data", {}).get("content", "") @@ -1258,6 +1267,7 @@ async def agent_chat_completion(tenant_id, agent_id=None): canvas_category=getattr(cvs, "canvas_category", CanvasCategory.Agent), return_trace=bool(req.get("return_trace", False)), stream=req.get("stream", True), + chat_template_kwargs=req.get("chat_template_kwargs"), ) if not session_id: @@ -1404,6 +1414,7 @@ async def agent_chat_completion(tenant_id, agent_id=None): canvas_category=canvas_category, return_trace=bool(req.get("return_trace", False)), stream=req.get("stream", True), + chat_template_kwargs=req.get("chat_template_kwargs"), ) return_trace = bool(req.get("return_trace", False)) diff --git a/api/db/services/canvas_service.py b/api/db/services/canvas_service.py index 8c7fe4748f..1777a21b4d 100644 --- a/api/db/services/canvas_service.py +++ b/api/db/services/canvas_service.py @@ -315,6 +315,7 @@ async def completion(tenant_id, agent_id, session_id=None, **kwargs): files = kwargs.get("files", []) inputs = kwargs.get("inputs", {}) user_id = kwargs.get("user_id", "") + chat_template_kwargs = kwargs.get("chat_template_kwargs") custom_header = kwargs.get("custom_header", "") release_mode = str(kwargs.get("release", "")).strip().lower() @@ -347,7 +348,16 @@ async def completion(tenant_id, agent_id, session_id=None, **kwargs): "files": files }) txt = "" - async for ans in canvas.run(query=query, files=files, user_id=user_id, inputs=inputs): + run_kwargs = { + "query": query, + "files": files, + "user_id": user_id, + "inputs": inputs, + } + if chat_template_kwargs is not None: + run_kwargs["chat_template_kwargs"] = chat_template_kwargs + + async for ans in canvas.run(**run_kwargs): ans["session_id"] = session_id if ans["event"] == "message": txt += ans["data"]["content"] diff --git a/test/unit_test/api/apps/sdk/test_dify_retrieval.py b/test/unit_test/api/apps/sdk/test_dify_retrieval.py index a74da5a649..ba830177af 100644 --- a/test/unit_test/api/apps/sdk/test_dify_retrieval.py +++ b/test/unit_test/api/apps/sdk/test_dify_retrieval.py @@ -1,4 +1,4 @@ -# +# # Copyright 2026 The InfiniFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License");