From 02c2587ca4309833c4574681a0505ae24554f628 Mon Sep 17 00:00:00 2001 From: hyl64 <78853927+hyl64@users.noreply.github.com> Date: Tue, 12 May 2026 13:05:21 +0800 Subject: [PATCH] fix(agent): support iteration item aliases in child nodes (#14146) ## Summary This PR fixes the iteration variable mismatch reported in #14142. Changes: - restore compatibility for `IterationItem@result` by exposing `result` alongside `item` - support bare iteration aliases like `{item}`, `{index}`, and `{result}` inside iteration child-node inputs - add focused unit/runtime tests covering both alias styles and multi-item iteration execution ## Validation ```bash pytest -q --noconftest \ test/testcases/test_web_api/test_canvas_app/test_iterationitem_unit.py \ test/testcases/test_web_api/test_canvas_app/test_iteration_runtime_unit.py \ test/testcases/test_web_api/test_canvas_app/test_invoke_component_unit.py ``` Result: `12 passed` Closes #14142 --- agent/component/base.py | 32 ++ agent/component/iterationitem.py | 6 +- .../test_iteration_runtime_unit.py | 391 ++++++++++++++++++ .../test_iterationitem_unit.py | 148 +++++++ 4 files changed, 576 insertions(+), 1 deletion(-) create mode 100644 test/testcases/test_web_api/test_canvas_app/test_iteration_runtime_unit.py create mode 100644 test/testcases/test_web_api/test_canvas_app/test_iterationitem_unit.py diff --git a/agent/component/base.py b/agent/component/base.py index 1acfa773d6..299adcd453 100644 --- a/agent/component/base.py +++ b/agent/component/base.py @@ -366,6 +366,7 @@ class ComponentBase(ABC): component_name: str thread_limiter = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT_CHATS", 10))) variable_ref_patt = r"\{* *\{([a-zA-Z:0-9]+@[A-Za-z0-9_.-]+|sys\.[A-Za-z0-9_.]+|env\.[A-Za-z0-9_.]+)\} *\}*" + iteration_alias_patt = r"\{* *\{(item|index|result)\} *\}*" def __str__(self): """ @@ -501,6 +502,23 @@ class ComponentBase(ABC): return {var: self.get_input_value(var) for var, o in self.get_input_elements().items()} + def _resolve_iteration_alias_ref(self, exp: str) -> str | None: + if exp not in {"item", "index", "result"}: + return None + + parent = self.get_parent() + if not parent or parent.component_name.lower() != "iteration": + return None + + for cid, cpn in self._canvas.components.items(): + if cpn.get("parent_id") != parent._id: + continue + if cpn["obj"].component_name.lower() != "iterationitem": + continue + return f"{cid}@{exp}" + + return None + def get_input_elements_from_text(self, txt: str) -> dict[str, dict[str, str]]: res = {} for r in re.finditer(self.variable_ref_patt, txt, flags=re.IGNORECASE | re.DOTALL): @@ -512,6 +530,20 @@ class ComponentBase(ABC): "_retrieval": self._canvas.get_variable_value(f"{cpn_id}@_references") if cpn_id else None, "_cpn_id": cpn_id } + for r in re.finditer(self.iteration_alias_patt, txt, flags=re.IGNORECASE | re.DOTALL): + exp = r.group(1) + if exp in res: + continue + ref = self._resolve_iteration_alias_ref(exp) + if not ref: + continue + cpn_id, var_nm = ref.split("@", 1) + res[exp] = { + "name": (self._canvas.get_component_name(cpn_id) + f"@{var_nm}"), + "value": self._canvas.get_variable_value(ref), + "_retrieval": self._canvas.get_variable_value(f"{cpn_id}@_references"), + "_cpn_id": cpn_id + } return res def get_input_elements(self) -> dict[str, Any]: diff --git a/agent/component/iterationitem.py b/agent/component/iterationitem.py index fad4a44e98..c9134e7c77 100644 --- a/agent/component/iterationitem.py +++ b/agent/component/iterationitem.py @@ -54,7 +54,11 @@ class IterationItem(ComponentBase, ABC): if self.check_if_canceled("IterationItem processing"): return - self.set_output("item", arr[self._idx]) + current_item = arr[self._idx] + self.set_output("item", current_item) + # Keep `result` as a compatibility alias because existing DSL examples + # and downstream references may still consume IterationItem via `@result`. + self.set_output("result", current_item) self.set_output("index", self._idx) self._idx += 1 diff --git a/test/testcases/test_web_api/test_canvas_app/test_iteration_runtime_unit.py b/test/testcases/test_web_api/test_canvas_app/test_iteration_runtime_unit.py new file mode 100644 index 0000000000..e73139ec26 --- /dev/null +++ b/test/testcases/test_web_api/test_canvas_app/test_iteration_runtime_unit.py @@ -0,0 +1,391 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import asyncio +import importlib.util +import json +import sys +from pathlib import Path +from types import ModuleType, SimpleNamespace + +import pytest + + +def _load_canvas_runtime(monkeypatch): + repo_root = Path(__file__).resolve().parents[4] + + quart = ModuleType("quart") + quart.make_response = lambda *a, **kw: None + quart.jsonify = lambda *a, **kw: None + monkeypatch.setitem(sys.modules, "quart", quart) + + common_pkg = ModuleType("common") + common_pkg.__path__ = [str(repo_root / "common")] + monkeypatch.setitem(sys.modules, "common", common_pkg) + + common_constants = ModuleType("common.constants") + common_constants.LLMType = SimpleNamespace(TTS="tts") + monkeypatch.setitem(sys.modules, "common.constants", common_constants) + + common_misc = ModuleType("common.misc_utils") + common_misc.get_uuid = lambda: "uuid" + common_misc.hash_str2int = lambda x: 1 + + async def _thread_pool_exec(fn, *args, **kwargs): + return fn(*args, **kwargs) + + common_misc.thread_pool_exec = _thread_pool_exec + monkeypatch.setitem(sys.modules, "common.misc_utils", common_misc) + + common_conn = ModuleType("common.connection_utils") + + def timeout(_seconds): + def decorator(fn): + return fn + + return decorator + + common_conn.timeout = timeout + monkeypatch.setitem(sys.modules, "common.connection_utils", common_conn) + + common_ex = ModuleType("common.exceptions") + + class TaskCanceledException(Exception): + pass + + common_ex.TaskCanceledException = TaskCanceledException + monkeypatch.setitem(sys.modules, "common.exceptions", common_ex) + + api_pkg = ModuleType("api") + api_pkg.__path__ = [str(repo_root / "api")] + monkeypatch.setitem(sys.modules, "api", api_pkg) + api_db_pkg = ModuleType("api.db") + api_db_pkg.__path__ = [str(repo_root / "api" / "db")] + monkeypatch.setitem(sys.modules, "api.db", api_db_pkg) + api_db_services_pkg = ModuleType("api.db.services") + api_db_services_pkg.__path__ = [str(repo_root / "api" / "db" / "services")] + monkeypatch.setitem(sys.modules, "api.db.services", api_db_services_pkg) + api_db_joint_pkg = ModuleType("api.db.joint_services") + api_db_joint_pkg.__path__ = [str(repo_root / "api" / "db" / "joint_services")] + monkeypatch.setitem(sys.modules, "api.db.joint_services", api_db_joint_pkg) + + file_service = ModuleType("api.db.services.file_service") + file_service.FileService = object + monkeypatch.setitem(sys.modules, "api.db.services.file_service", file_service) + + llm_service = ModuleType("api.db.services.llm_service") + llm_service.LLMBundle = object + monkeypatch.setitem(sys.modules, "api.db.services.llm_service", llm_service) + + task_service = ModuleType("api.db.services.task_service") + task_service.has_canceled = lambda _task_id: False + monkeypatch.setitem(sys.modules, "api.db.services.task_service", task_service) + + tenant_model_service = ModuleType("api.db.joint_services.tenant_model_service") + tenant_model_service.get_tenant_default_model_by_type = lambda *_a, **_kw: None + monkeypatch.setitem( + sys.modules, + "api.db.joint_services.tenant_model_service", + tenant_model_service, + ) + + rag_pkg = ModuleType("rag") + rag_pkg.__path__ = [str(repo_root / "rag")] + monkeypatch.setitem(sys.modules, "rag", rag_pkg) + rag_prompts_pkg = ModuleType("rag.prompts") + rag_prompts_pkg.__path__ = [str(repo_root / "rag" / "prompts")] + monkeypatch.setitem(sys.modules, "rag.prompts", rag_prompts_pkg) + rag_prompts = ModuleType("rag.prompts.generator") + rag_prompts.chunks_format = lambda *_a, **_kw: "" + monkeypatch.setitem(sys.modules, "rag.prompts.generator", rag_prompts) + + rag_utils_pkg = ModuleType("rag.utils") + rag_utils_pkg.__path__ = [str(repo_root / "rag" / "utils")] + monkeypatch.setitem(sys.modules, "rag.utils", rag_utils_pkg) + rag_redis = ModuleType("rag.utils.redis_conn") + rag_redis.REDIS_CONN = SimpleNamespace(delete=lambda *_a, **_kw: None, set=lambda *_a, **_kw: None) + monkeypatch.setitem(sys.modules, "rag.utils.redis_conn", rag_redis) + + agent_pkg = ModuleType("agent") + agent_pkg.__path__ = [str(repo_root / "agent")] + monkeypatch.setitem(sys.modules, "agent", agent_pkg) + + agent_settings = ModuleType("agent.settings") + agent_settings.FLOAT_ZERO = 1e-8 + agent_settings.PARAM_MAXDEPTH = 5 + monkeypatch.setitem(sys.modules, "agent.settings", agent_settings) + + dsl_migration = ModuleType("agent.dsl_migration") + dsl_migration.normalize_chunker_dsl = lambda dsl: dsl + monkeypatch.setitem(sys.modules, "agent.dsl_migration", dsl_migration) + + component_pkg = ModuleType("agent.component") + component_pkg.__path__ = [str(repo_root / "agent" / "component")] + monkeypatch.setitem(sys.modules, "agent.component", component_pkg) + + base_spec = importlib.util.spec_from_file_location( + "agent.component.base", repo_root / "agent" / "component" / "base.py" + ) + base_mod = importlib.util.module_from_spec(base_spec) + monkeypatch.setitem(sys.modules, "agent.component.base", base_mod) + base_spec.loader.exec_module(base_mod) + + iteration_spec = importlib.util.spec_from_file_location( + "agent.component.iteration", repo_root / "agent" / "component" / "iteration.py" + ) + iteration_mod = importlib.util.module_from_spec(iteration_spec) + monkeypatch.setitem(sys.modules, "agent.component.iteration", iteration_mod) + iteration_spec.loader.exec_module(iteration_mod) + + iterationitem_spec = importlib.util.spec_from_file_location( + "agent.component.iterationitem", + repo_root / "agent" / "component" / "iterationitem.py", + ) + iterationitem_mod = importlib.util.module_from_spec(iterationitem_spec) + monkeypatch.setitem(sys.modules, "agent.component.iterationitem", iterationitem_mod) + iterationitem_spec.loader.exec_module(iterationitem_mod) + + class BeginParam(base_mod.ComponentParamBase): + def check(self): + return True + + class Begin(base_mod.ComponentBase): + component_name = "Begin" + + def _invoke(self, **kwargs): + return + + def thoughts(self): + return "begin" + + class ProbeParam(base_mod.ComponentParamBase): + def __init__(self): + super().__init__() + self.query = "" + self.inputs = {"query": {"value": None}} + + def get_input_form(self): + return {"query": {"name": "Query", "type": "line"}} + + def check(self): + return True + + class Probe(base_mod.ComponentBase): + component_name = "Probe" + + def _invoke(self, **kwargs): + query_text = kwargs.get("query") + vars_map = self.get_input_elements_from_text(query_text) + query = self.string_format( + query_text, {key: value["value"] for key, value in vars_map.items()} + ) + calls = self._canvas.globals.setdefault("probe.calls", []) + calls.append(query) + self.set_output("result", query) + + def thoughts(self): + return "probe" + + class SinkParam(base_mod.ComponentParamBase): + def check(self): + return True + + class Sink(base_mod.ComponentBase): + component_name = "Sink" + + def _invoke(self, **kwargs): + self.set_output("done", True) + + def thoughts(self): + return "sink" + + class_map = { + "Begin": Begin, + "BeginParam": BeginParam, + "Iteration": iteration_mod.Iteration, + "IterationParam": iteration_mod.IterationParam, + "IterationItem": iterationitem_mod.IterationItem, + "IterationItemParam": iterationitem_mod.IterationItemParam, + "Probe": Probe, + "ProbeParam": ProbeParam, + "Sink": Sink, + "SinkParam": SinkParam, + } + + component_pkg.component_class = lambda name: class_map[name] + + canvas_spec = importlib.util.spec_from_file_location( + "agent.canvas", repo_root / "agent" / "canvas.py" + ) + canvas_mod = importlib.util.module_from_spec(canvas_spec) + monkeypatch.setitem(sys.modules, "agent.canvas", canvas_mod) + canvas_spec.loader.exec_module(canvas_mod) + + return canvas_mod + + +async def _collect_events(canvas): + events = [] + async for event in canvas.run(): + events.append(event) + return events + + +@pytest.mark.p2 +def test_iteration_runtime_processes_all_array_items(monkeypatch): + canvas_mod = _load_canvas_runtime(monkeypatch) + + dsl = { + "components": { + "begin": { + "obj": {"component_name": "Begin", "params": {}}, + "downstream": ["Iteration:1"], + "upstream": [], + }, + "Iteration:1": { + "obj": { + "component_name": "Iteration", + "params": {"items_ref": "env.items"}, + }, + "downstream": ["Sink:1"], + "upstream": ["begin"], + }, + "IterationItem:1": { + "obj": {"component_name": "IterationItem", "params": {}}, + "parent_id": "Iteration:1", + "downstream": ["Probe:1"], + "upstream": [], + }, + "Probe:1": { + "obj": { + "component_name": "Probe", + "params": {"query": "IterationItem:1@result"}, + }, + "parent_id": "Iteration:1", + "downstream": [], + "upstream": ["IterationItem:1"], + }, + "Sink:1": { + "obj": {"component_name": "Sink", "params": {}}, + "downstream": [], + "upstream": ["Iteration:1"], + }, + }, + "graph": { + "nodes": [ + {"id": "begin", "data": {"name": "Begin"}}, + {"id": "Iteration:1", "data": {"name": "Iteration"}}, + {"id": "IterationItem:1", "data": {"name": "IterationItem"}}, + {"id": "Probe:1", "data": {"name": "Probe"}}, + {"id": "Sink:1", "data": {"name": "Sink"}}, + ] + }, + "history": [], + "path": [], + "retrieval": [], + "globals": { + "sys.query": "", + "sys.user_id": "", + "sys.conversation_turns": 0, + "sys.files": [], + "sys.history": [], + "sys.date": "", + "env.items": ["a", "b", "c"], + }, + } + + canvas = canvas_mod.Canvas(json.dumps(dsl)) + events = asyncio.run(_collect_events(canvas)) + + assert canvas.globals["probe.calls"] == ["a", "b", "c"] + assert any(event["event"] == "workflow_finished" for event in events) + + +@pytest.mark.parametrize( + ("query", "expected_calls"), + [ + ("{item}", ["a", "b", "c"]), + ("{index}", ["0", "1", "2"]), + ("{result}", ["a", "b", "c"]), + ], +) +@pytest.mark.p2 +def test_iteration_runtime_supports_bare_iteration_aliases(monkeypatch, query, expected_calls): + canvas_mod = _load_canvas_runtime(monkeypatch) + + dsl = { + "components": { + "begin": { + "obj": {"component_name": "Begin", "params": {}}, + "downstream": ["Iteration:1"], + "upstream": [], + }, + "Iteration:1": { + "obj": { + "component_name": "Iteration", + "params": {"items_ref": "env.items"}, + }, + "downstream": ["Sink:1"], + "upstream": ["begin"], + }, + "IterationItem:1": { + "obj": {"component_name": "IterationItem", "params": {}}, + "parent_id": "Iteration:1", + "downstream": ["Probe:1"], + "upstream": [], + }, + "Probe:1": { + "obj": { + "component_name": "Probe", + "params": {"query": query}, + }, + "parent_id": "Iteration:1", + "downstream": [], + "upstream": ["IterationItem:1"], + }, + "Sink:1": { + "obj": {"component_name": "Sink", "params": {}}, + "downstream": [], + "upstream": ["Iteration:1"], + }, + }, + "graph": { + "nodes": [ + {"id": "begin", "data": {"name": "Begin"}}, + {"id": "Iteration:1", "data": {"name": "Iteration"}}, + {"id": "IterationItem:1", "data": {"name": "IterationItem"}}, + {"id": "Probe:1", "data": {"name": "Probe"}}, + {"id": "Sink:1", "data": {"name": "Sink"}}, + ] + }, + "history": [], + "path": [], + "retrieval": [], + "globals": { + "sys.query": "", + "sys.user_id": "", + "sys.conversation_turns": 0, + "sys.files": [], + "sys.history": [], + "sys.date": "", + "env.items": ["a", "b", "c"], + }, + } + + canvas = canvas_mod.Canvas(json.dumps(dsl)) + asyncio.run(_collect_events(canvas)) + + assert canvas.globals["probe.calls"] == expected_calls diff --git a/test/testcases/test_web_api/test_canvas_app/test_iterationitem_unit.py b/test/testcases/test_web_api/test_canvas_app/test_iterationitem_unit.py new file mode 100644 index 0000000000..1151bb60dc --- /dev/null +++ b/test/testcases/test_web_api/test_canvas_app/test_iterationitem_unit.py @@ -0,0 +1,148 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import importlib.util +import sys +from pathlib import Path +from types import ModuleType, SimpleNamespace +from unittest.mock import MagicMock + +import pytest + + +def _load_iterationitem_module(monkeypatch): + repo_root = Path(__file__).resolve().parents[4] + + quart = ModuleType("quart") + quart.make_response = lambda *a, **kw: None + quart.jsonify = lambda *a, **kw: None + monkeypatch.setitem(sys.modules, "quart", quart) + + common_pkg = ModuleType("common") + common_pkg.__path__ = [str(repo_root / "common")] + monkeypatch.setitem(sys.modules, "common", common_pkg) + + constants = ModuleType("common.constants") + + class _RetCode: + SUCCESS = 0 + EXCEPTION_ERROR = 100 + + constants.RetCode = _RetCode + monkeypatch.setitem(sys.modules, "common.constants", constants) + + conn_spec = importlib.util.spec_from_file_location( + "common.connection_utils", repo_root / "common" / "connection_utils.py" + ) + conn_mod = importlib.util.module_from_spec(conn_spec) + monkeypatch.setitem(sys.modules, "common.connection_utils", conn_mod) + conn_spec.loader.exec_module(conn_mod) + + misc_spec = importlib.util.spec_from_file_location( + "common.misc_utils", repo_root / "common" / "misc_utils.py" + ) + misc_mod = importlib.util.module_from_spec(misc_spec) + monkeypatch.setitem(sys.modules, "common.misc_utils", misc_mod) + misc_spec.loader.exec_module(misc_mod) + + agent_pkg = ModuleType("agent") + agent_pkg.__path__ = [str(repo_root / "agent")] + monkeypatch.setitem(sys.modules, "agent", agent_pkg) + + agent_settings = ModuleType("agent.settings") + agent_settings.FLOAT_ZERO = 1e-8 + agent_settings.PARAM_MAXDEPTH = 5 + monkeypatch.setitem(sys.modules, "agent.settings", agent_settings) + + component_pkg = ModuleType("agent.component") + component_pkg.__path__ = [str(repo_root / "agent" / "component")] + monkeypatch.setitem(sys.modules, "agent.component", component_pkg) + + canvas_mod = ModuleType("agent.canvas") + + class Graph: + pass + + canvas_mod.Graph = Graph + monkeypatch.setitem(sys.modules, "agent.canvas", canvas_mod) + + base_spec = importlib.util.spec_from_file_location( + "agent.component.base", repo_root / "agent" / "component" / "base.py" + ) + base_mod = importlib.util.module_from_spec(base_spec) + monkeypatch.setitem(sys.modules, "agent.component.base", base_mod) + base_spec.loader.exec_module(base_mod) + + iterationitem_spec = importlib.util.spec_from_file_location( + "agent.component.iterationitem", + repo_root / "agent" / "component" / "iterationitem.py", + ) + iterationitem_mod = importlib.util.module_from_spec(iterationitem_spec) + monkeypatch.setitem( + sys.modules, "agent.component.iterationitem", iterationitem_mod + ) + iterationitem_spec.loader.exec_module(iterationitem_mod) + + return iterationitem_mod + + +def _make_iterationitem(module, values): + canvas = MagicMock() + canvas.is_canceled = MagicMock(return_value=False) + canvas.get_variable_value = MagicMock(return_value=values) + canvas.components = {} + + param = module.IterationItemParam() + param.outputs = {} + param.inputs = {} + + inst = module.IterationItem.__new__(module.IterationItem) + inst._canvas = canvas + inst._id = "IterationItem:test" + inst._param = param + inst._idx = 0 + inst.get_parent = MagicMock( + return_value=SimpleNamespace( + _id="Iteration:test", + _param=SimpleNamespace(items_ref="code:1@tempList"), + component_name="Iteration", + ) + ) + return inst + + +@pytest.mark.p2 +def test_iterationitem_exposes_result_alias_for_each_item(monkeypatch): + module = _load_iterationitem_module(monkeypatch) + item = _make_iterationitem(module, ["a", "b", "c"]) + + item._invoke() + assert item.output("item") == "a" + assert item.output("result") == "a" + assert item.output("index") == 0 + + item._invoke() + assert item.output("item") == "b" + assert item.output("result") == "b" + assert item.output("index") == 1 + + item._invoke() + assert item.output("item") == "c" + assert item.output("result") == "c" + assert item.output("index") == 2 + + item._invoke() + assert item.end() is True