From 82313020c71b8b91873232c2334c2c1c382f1c49 Mon Sep 17 00:00:00 2001 From: buua436 Date: Mon, 27 Apr 2026 19:13:00 +0800 Subject: [PATCH] Refa: align list operations and strict mode (#14387) ### What problem does this PR solve? align list operations and strict mode ### Type of change - [x] Refactoring --- agent/component/list_operations.py | 82 ++++++-- .../test_list_operations_unit.py | 191 ++++++++++++++++++ web/src/locales/en.ts | 5 +- web/src/locales/zh.ts | 9 +- web/src/pages/agent/constant/index.tsx | 5 +- .../agent/form/list-operations-form/index.tsx | 70 +++++-- 6 files changed, 318 insertions(+), 44 deletions(-) create mode 100644 test/testcases/test_web_api/test_canvas_app/test_list_operations_unit.py diff --git a/agent/component/list_operations.py b/agent/component/list_operations.py index 6016f75850..953e145529 100644 --- a/agent/component/list_operations.py +++ b/agent/component/list_operations.py @@ -10,8 +10,9 @@ class ListOperationsParam(ComponentParamBase): def __init__(self): super().__init__() self.query = "" - self.operations = "topN" - self.n=0 + self.operations = "nth" + self.n = 0 + self.strict = False self.sort_method = "asc" self.filter = { "operator": "=", @@ -34,7 +35,11 @@ class ListOperationsParam(ComponentParamBase): def check(self): self.check_empty(self.query, "query") - self.check_valid_value(self.operations, "Support operations", ["topN","head","tail","filter","sort","drop_duplicates"]) + self.check_valid_value( + self.operations, + "Support operations", + ["nth", "head", "tail", "filter", "sort", "drop_duplicates"], + ) def get_input_form(self) -> dict[str, dict]: return {} @@ -51,8 +56,8 @@ class ListOperations(ComponentBase,ABC): if not isinstance(self.inputs, list): raise TypeError("The input of List Operations should be an array.") self.set_input_value(inputs, self.inputs) - if self._param.operations == "topN": - self._topN() + if self._param.operations == "nth": + self._nth() elif self._param.operations == "head": self._head() elif self._param.operations == "tail": @@ -70,35 +75,74 @@ class ListOperations(ComponentBase,ABC): return int(getattr(self._param, "n", 0)) except Exception: return 0 - + + def _is_strict(self): + strict = getattr(self._param, "strict", False) + if isinstance(strict, str): + return strict.strip().lower() in {"1", "true", "yes", "on"} + return bool(strict) + def _set_outputs(self, outputs): self._param.outputs["result"]["value"] = outputs self._param.outputs["first"]["value"] = outputs[0] if outputs else None self._param.outputs["last"]["value"] = outputs[-1] if outputs else None - - def _topN(self): + + def _raise_strict_range_error(self, operation, n): + raise ValueError( + f"{operation} requires n to be within the valid range in strict mode, got {n}." + ) + + def _nth(self): n = self._coerce_n() - if n < 1: + strict = self._is_strict() + if n == 0: + if strict: + self._raise_strict_range_error("nth", n) outputs = [] + elif n > 0: + if n <= len(self.inputs): + outputs = [self.inputs[n - 1]] + elif strict: + self._raise_strict_range_error("nth", n) + else: + outputs = [] else: - n = min(n, len(self.inputs)) - outputs = self.inputs[:n] + if abs(n) <= len(self.inputs): + outputs = [self.inputs[n]] + elif strict: + self._raise_strict_range_error("nth", n) + else: + outputs = [] self._set_outputs(outputs) def _head(self): n = self._coerce_n() - if 1 <= n <= len(self.inputs): - outputs = [self.inputs[n - 1]] + strict = self._is_strict() + if strict: + if 1 <= n <= len(self.inputs): + outputs = self.inputs[:n] + else: + self._raise_strict_range_error("head", n) else: - outputs = [] + if n < 1: + outputs = [] + else: + outputs = self.inputs[:n] self._set_outputs(outputs) def _tail(self): n = self._coerce_n() - if 1 <= n <= len(self.inputs): - outputs = [self.inputs[-n]] + strict = self._is_strict() + if strict: + if 1 <= n <= len(self.inputs): + outputs = self.inputs[-n:] + else: + self._raise_strict_range_error("tail", n) else: - outputs = [] + if n < 1: + outputs = [] + else: + outputs = self.inputs[-n:] self._set_outputs(outputs) def _filter(self): @@ -107,7 +151,7 @@ class ListOperations(ComponentBase,ABC): def _norm(self,v): s = "" if v is None else str(v) return s - + def _eval(self, v, operator, value): if operator == "=": return v == value @@ -163,6 +207,6 @@ class ListOperations(ComponentBase,ABC): if isinstance(x, set): return tuple(sorted(self._hashable(v) for v in x)) return x - + def thoughts(self) -> str: return "ListOperation in progress" diff --git a/test/testcases/test_web_api/test_canvas_app/test_list_operations_unit.py b/test/testcases/test_web_api/test_canvas_app/test_list_operations_unit.py new file mode 100644 index 0000000000..869a8dc5d6 --- /dev/null +++ b/test/testcases/test_web_api/test_canvas_app/test_list_operations_unit.py @@ -0,0 +1,191 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import importlib.util +import sys +from pathlib import Path +from types import ModuleType, SimpleNamespace + +import pytest + + +def _load_list_operations_module(monkeypatch): + repo_root = Path(__file__).resolve().parents[4] + + agent_pkg = ModuleType("agent") + agent_pkg.__path__ = [str(repo_root / "agent")] + monkeypatch.setitem(sys.modules, "agent", agent_pkg) + + component_pkg = ModuleType("agent.component") + component_pkg.__path__ = [str(repo_root / "agent" / "component")] + monkeypatch.setitem(sys.modules, "agent.component", component_pkg) + + base_mod = ModuleType("agent.component.base") + + class _ComponentParamBase: + def __init__(self): + self.outputs = {} + + def check_empty(self, *_args, **_kwargs): + return None + + def check_valid_value(self, *_args, **_kwargs): + return None + + class _ComponentBase: + def set_input_value(self, *_args, **_kwargs): + return None + + base_mod.ComponentBase = _ComponentBase + base_mod.ComponentParamBase = _ComponentParamBase + monkeypatch.setitem(sys.modules, "agent.component.base", base_mod) + + api_pkg = ModuleType("api") + api_pkg.__path__ = [str(repo_root / "api")] + monkeypatch.setitem(sys.modules, "api", api_pkg) + + api_utils_mod = ModuleType("api.utils.api_utils") + api_utils_mod.timeout = lambda *_args, **_kwargs: (lambda func: func) + monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod) + + module_path = repo_root / "agent" / "component" / "list_operations.py" + spec = importlib.util.spec_from_file_location( + "test_list_operations_unit_module", module_path + ) + module = importlib.util.module_from_spec(spec) + monkeypatch.setitem(sys.modules, "test_list_operations_unit_module", module) + spec.loader.exec_module(module) + return module + + +def _make_component(module, *, inputs, operation, n, strict=False): + component = module.ListOperations.__new__(module.ListOperations) + component.inputs = inputs + component._param = SimpleNamespace( + n=n, + strict=strict, + outputs={ + "result": {"value": []}, + "first": {"value": None}, + "last": {"value": None}, + }, + ) + return component + + +@pytest.mark.p2 +@pytest.mark.parametrize( + ("n", "expected"), + [ + (0, []), + (-1, ["e"]), + (-5, ["a"]), + (-6, []), + (2, ["b"]), + (5, ["e"]), + (6, []), + ], +) +def test_nth_behaves_like_lenient_indexing(monkeypatch, n, expected): + module = _load_list_operations_module(monkeypatch) + component = _make_component( + module, inputs=["a", "b", "c", "d", "e"], operation="nth", n=n + ) + component._nth() + assert component._param.outputs["result"]["value"] == expected + + +@pytest.mark.p2 +@pytest.mark.parametrize( + ("strict", "n", "expected"), + [ + (False, 0, []), + (False, 2, ["a", "b"]), + (False, 10, ["a", "b", "c", "d", "e"]), + (True, 2, ["a", "b"]), + ], +) +def test_head_supports_lenient_and_strict(monkeypatch, strict, n, expected): + module = _load_list_operations_module(monkeypatch) + component = _make_component( + module, inputs=["a", "b", "c", "d", "e"], operation="head", n=n, strict=strict + ) + component._head() + assert component._param.outputs["result"]["value"] == expected + + +@pytest.mark.p2 +@pytest.mark.parametrize("n", [0, 10]) +def test_head_strict_raises_for_out_of_range(monkeypatch, n): + module = _load_list_operations_module(monkeypatch) + component = _make_component( + module, inputs=["a", "b", "c", "d", "e"], operation="head", n=n, strict=True + ) + with pytest.raises(ValueError, match="head requires n"): + component._head() + + +@pytest.mark.p2 +@pytest.mark.parametrize( + ("strict", "n", "expected"), + [ + (False, 0, []), + (False, 2, ["d", "e"]), + (False, 10, ["a", "b", "c", "d", "e"]), + (True, 2, ["d", "e"]), + ], +) +def test_tail_supports_lenient_and_strict(monkeypatch, strict, n, expected): + module = _load_list_operations_module(monkeypatch) + component = _make_component( + module, inputs=["a", "b", "c", "d", "e"], operation="tail", n=n, strict=strict + ) + component._tail() + assert component._param.outputs["result"]["value"] == expected + + +@pytest.mark.p2 +@pytest.mark.parametrize("n", [0, 10]) +def test_tail_strict_raises_for_out_of_range(monkeypatch, n): + module = _load_list_operations_module(monkeypatch) + component = _make_component( + module, inputs=["a", "b", "c", "d", "e"], operation="tail", n=n, strict=True + ) + with pytest.raises(ValueError, match="tail requires n"): + component._tail() + + +@pytest.mark.p2 +@pytest.mark.parametrize("n", [0, 6, -6]) +def test_nth_strict_raises_for_out_of_range(monkeypatch, n): + module = _load_list_operations_module(monkeypatch) + component = _make_component( + module, inputs=["a", "b", "c", "d", "e"], operation="nth", n=n, strict=True + ) + with pytest.raises(ValueError, match="nth requires n"): + component._nth() + + +@pytest.mark.p2 +def test_set_outputs_tracks_first_and_last(monkeypatch): + module = _load_list_operations_module(monkeypatch) + component = _make_component( + module, inputs=["a", "b", "c", "d", "e"], operation="tail", n=3 + ) + component._tail() + assert component._param.outputs["result"]["value"] == ["c", "d", "e"] + assert component._param.outputs["first"]["value"] == "c" + assert component._param.outputs["last"]["value"] == "e" diff --git a/web/src/locales/en.ts b/web/src/locales/en.ts index 1876b2b879..88d70fe358 100644 --- a/web/src/locales/en.ts +++ b/web/src/locales/en.ts @@ -2394,7 +2394,7 @@ Important structured information may include: names, dates, locations, events, k renameKeys: 'Rename keys', }, ListOperationsOptions: { - topN: 'Top N', + nth: 'Nth', head: 'Head', tail: 'Tail', sort: 'Sort', @@ -2402,6 +2402,9 @@ Important structured information may include: names, dates, locations, events, k dropDuplicates: 'Drop duplicates', }, sortMethod: 'Sort method', + strictMode: 'Strict mode', + strictModeTip: + 'Off uses lenient behavior and returns an empty result for invalid n. On uses strict behavior and raises an error for out-of-range n.', SortMethodOptions: { asc: 'Ascending', desc: 'Descending', diff --git a/web/src/locales/zh.ts b/web/src/locales/zh.ts index 1a49402c2a..9d62b1b6bc 100644 --- a/web/src/locales/zh.ts +++ b/web/src/locales/zh.ts @@ -2080,14 +2080,17 @@ Tokenizer 会根据所选方式将内容存储为对应的数据结构。`, renameKeys: '重命名键', }, ListOperationsOptions: { - topN: '取前N项', - head: '取前第N项', - tail: '取后第N项', + nth: '第N项', + head: '取前N项', + tail: '取后N项', sort: '排序', filter: '筛选', dropDuplicates: '去重', }, sortMethod: '排序方式', + strictMode: '严格模式', + strictModeTip: + '关闭时使用宽松模式,非法 n 返回空结果;开启时使用严格模式,超出范围的 n 会直接报错。', SortMethodOptions: { asc: '升序', desc: '降序', diff --git a/web/src/pages/agent/constant/index.tsx b/web/src/pages/agent/constant/index.tsx index d4fd25335b..6cbb516715 100644 --- a/web/src/pages/agent/constant/index.tsx +++ b/web/src/pages/agent/constant/index.tsx @@ -587,7 +587,7 @@ export enum SortMethod { } export enum ListOperations { - TopN = 'topN', + Nth = 'nth', Head = 'head', Tail = 'tail', Filter = 'filter', @@ -597,7 +597,8 @@ export enum ListOperations { export const initialListOperationsValues = { query: '', - operations: ListOperations.TopN, + operations: ListOperations.Nth, + strict: false, outputs: { // result: { // type: 'Array', diff --git a/web/src/pages/agent/form/list-operations-form/index.tsx b/web/src/pages/agent/form/list-operations-form/index.tsx index afc44e9075..22cca2519e 100644 --- a/web/src/pages/agent/form/list-operations-form/index.tsx +++ b/web/src/pages/agent/form/list-operations-form/index.tsx @@ -10,6 +10,7 @@ import { FormMessage, } from '@/components/ui/form'; import { Separator } from '@/components/ui/separator'; +import { Switch } from '@/components/ui/switch'; import { useBuildSwitchOperatorOptions } from '@/hooks/logic-hooks/use-build-operator-options'; import { buildOptions } from '@/utils/form'; import { zodResolver } from '@hookform/resolvers/zod'; @@ -38,7 +39,8 @@ import { QueryVariable } from '../components/query-variable'; export const RetrievalPartialSchema = { query: z.string(), operations: z.string(), - n: z.number().int().min(1).optional(), + n: z.number().int().optional(), + strict: z.boolean().optional(), sort_method: z.string().optional(), filter: z .object({ @@ -50,7 +52,7 @@ export const RetrievalPartialSchema = { }; const NumFields = [ - ListOperations.TopN, + ListOperations.Nth, ListOperations.Head, ListOperations.Tail, ]; @@ -71,6 +73,13 @@ function showField(operations: string) { }; } +function getMinValue(operations: string) { + if (operations === ListOperations.Nth) { + return Number.MIN_SAFE_INTEGER; + } + return 0; +} + export const FormSchema = z.object(RetrievalPartialSchema); export type ListOperationsFormSchemaType = z.infer; @@ -129,6 +138,7 @@ function ListOperationsForm({ node }: INextOperatorForm) { ); const { showFilter, showNum, showSortMethod } = showField(operations); + const minValue = getMinValue(operations); const handleOperationsChange = useCallback( (operations: string) => { @@ -180,23 +190,45 @@ function ListOperationsForm({ node }: INextOperatorForm) { )} {showNum && ( - ( - - {t('flow.flowNum')} - - - - - - )} - /> + <> + ( + + {t('flow.flowNum')} + + + + + + )} + /> + ( + + + {t('flow.strictMode')} + + +
+ +
+
+ +
+ )} + /> + )} {showSortMethod && (