Refa: align list operations and strict mode (#14387)

### What problem does this PR solve?

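Align the List Operations component: rename the `topN` operation to `nth` (a single-element lookup where a negative `n` counts from the end), make `head` / `tail` return the first / last `n` items, and add a `strict` mode that raises an error for an out-of-range `n` instead of returning an empty result.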

### Type of change
- [x] Refactoring
This commit is contained in:
buua436
2026-04-27 19:13:00 +08:00
committed by GitHub
parent c1941fd503
commit 82313020c7
6 changed files with 318 additions and 44 deletions

View File

@@ -10,8 +10,9 @@ class ListOperationsParam(ComponentParamBase):
def __init__(self):
super().__init__()
self.query = ""
self.operations = "topN"
self.n=0
self.operations = "nth"
self.n = 0
self.strict = False
self.sort_method = "asc"
self.filter = {
"operator": "=",
@@ -34,7 +35,11 @@ class ListOperationsParam(ComponentParamBase):
def check(self):
    """Validate the component parameters.

    Passes ``query`` to ``check_empty`` and ``operations`` to
    ``check_valid_value`` together with the supported operation names.
    """
    self.check_empty(self.query, "query")
    # Only the current names are listed: "topN" was renamed to "nth", and a
    # second validation against the old list would reject the new default.
    self.check_valid_value(
        self.operations,
        "Support operations",
        ["nth", "head", "tail", "filter", "sort", "drop_duplicates"],
    )
def get_input_form(self) -> dict[str, dict]:
    """Return the input form description, which is always an empty mapping."""
    form: dict[str, dict] = {}
    return form
@@ -51,8 +56,8 @@ class ListOperations(ComponentBase,ABC):
if not isinstance(self.inputs, list):
raise TypeError("The input of List Operations should be an array.")
self.set_input_value(inputs, self.inputs)
if self._param.operations == "topN":
self._topN()
if self._param.operations == "nth":
self._nth()
elif self._param.operations == "head":
self._head()
elif self._param.operations == "tail":
@@ -70,35 +75,74 @@ class ListOperations(ComponentBase,ABC):
return int(getattr(self._param, "n", 0))
except Exception:
return 0
def _is_strict(self):
    """Report whether strict mode is enabled on the component parameters.

    A missing ``strict`` attribute counts as disabled. A string value enables
    strict mode only when, trimmed and lower-cased, it is one of "1", "true",
    "yes" or "on"; any other value is judged by its truthiness.
    """
    flag = getattr(self._param, "strict", False)
    if not isinstance(flag, str):
        return bool(flag)
    return flag.strip().lower() in {"1", "true", "yes", "on"}
def _set_outputs(self, outputs):
    """Store ``outputs`` under "result" and its end elements under "first"/"last".

    "first" and "last" are set to None when ``outputs`` is empty.
    """
    slots = self._param.outputs
    slots["result"]["value"] = outputs
    first, last = (outputs[0], outputs[-1]) if outputs else (None, None)
    slots["first"]["value"] = first
    slots["last"]["value"] = last
def _topN(self):
def _raise_strict_range_error(self, operation, n):
    """Raise the strict-mode error for an out-of-range ``n``.

    Raises:
        ValueError: always; the message names ``operation`` and the value ``n``.
    """
    message = f"{operation} requires n to be within the valid range in strict mode, got {n}."
    raise ValueError(message)
def _nth(self):
    """Select the single element at position ``n`` and publish it as a one-item list.

    A positive ``n`` is 1-based from the start (1 is the first element); a
    negative ``n`` counts from the end (-1 is the last element). When ``n`` is
    0 or its magnitude exceeds the number of inputs, the result is an empty
    list in lenient mode.

    Raises:
        ValueError: in strict mode, when ``n`` is 0 or out of range.
    """
    n = self._coerce_n()
    size = len(self.inputs)
    if n != 0 and abs(n) <= size:
        # Positive positions are 1-based; negative ones already line up with
        # Python's from-the-end indexing.
        index = n - 1 if n > 0 else n
        outputs = [self.inputs[index]]
    elif self._is_strict():
        self._raise_strict_range_error("nth", n)
    else:
        outputs = []
    self._set_outputs(outputs)
def _head(self):
    """Publish the first ``n`` input elements.

    Lenient mode: ``n < 1`` gives an empty list, and an ``n`` larger than the
    input simply returns every element.

    Raises:
        ValueError: in strict mode, when ``n`` is not within 1..len(inputs).
    """
    n = self._coerce_n()
    if self._is_strict() and not 1 <= n <= len(self.inputs):
        self._raise_strict_range_error("head", n)
    # Guard n < 1 explicitly: a negative slice bound would otherwise drop
    # elements from the end instead of returning nothing.
    outputs = self.inputs[:n] if n >= 1 else []
    self._set_outputs(outputs)
def _tail(self):
    """Publish the last ``n`` input elements.

    Lenient mode: ``n < 1`` gives an empty list, and an ``n`` larger than the
    input simply returns every element.

    Raises:
        ValueError: in strict mode, when ``n`` is not within 1..len(inputs).
    """
    n = self._coerce_n()
    if self._is_strict() and not 1 <= n <= len(self.inputs):
        self._raise_strict_range_error("tail", n)
    # Guard n < 1 explicitly: inputs[-0:] would return the whole list.
    outputs = self.inputs[-n:] if n >= 1 else []
    self._set_outputs(outputs)
def _filter(self):
@@ -107,7 +151,7 @@ class ListOperations(ComponentBase,ABC):
def _norm(self, v):
    """Return ``v`` converted to a string, mapping None to the empty string."""
    if v is None:
        return ""
    return str(v)
def _eval(self, v, operator, value):
if operator == "=":
return v == value
@@ -163,6 +207,6 @@ class ListOperations(ComponentBase,ABC):
if isinstance(x, set):
return tuple(sorted(self._hashable(v) for v in x))
return x
def thoughts(self) -> str:
    """Return the component's fixed status text."""
    status = "ListOperation in progress"
    return status

View File

@@ -0,0 +1,191 @@
#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import importlib.util
import sys
from pathlib import Path
from types import ModuleType, SimpleNamespace
import pytest
def _load_list_operations_module(monkeypatch):
    """Load ``agent/component/list_operations.py`` against stub dependencies.

    Stand-in modules for ``agent``, ``agent.component``,
    ``agent.component.base``, ``api`` and ``api.utils.api_utils`` are placed in
    ``sys.modules`` through ``monkeypatch``. The source file is then executed
    under the private module name ``test_list_operations_unit_module`` and the
    resulting module object is returned.
    """
    repo_root = Path(__file__).resolve().parents[4]

    def register(name, search_path=None):
        # Install an empty stand-in module; package stubs also get a __path__.
        stub = ModuleType(name)
        if search_path is not None:
            stub.__path__ = [str(search_path)]
        monkeypatch.setitem(sys.modules, name, stub)
        return stub

    class _ComponentParamBase:
        def __init__(self):
            self.outputs = {}

        def check_empty(self, *_args, **_kwargs):
            return None

        def check_valid_value(self, *_args, **_kwargs):
            return None

    class _ComponentBase:
        def set_input_value(self, *_args, **_kwargs):
            return None

    register("agent", repo_root / "agent")
    register("agent.component", repo_root / "agent" / "component")

    base_stub = register("agent.component.base")
    base_stub.ComponentBase = _ComponentBase
    base_stub.ComponentParamBase = _ComponentParamBase

    register("api", repo_root / "api")

    api_utils_stub = register("api.utils.api_utils")
    # ``timeout(...)`` becomes a decorator factory that returns the function unchanged.
    api_utils_stub.timeout = lambda *_args, **_kwargs: (lambda func: func)

    private_name = "test_list_operations_unit_module"
    source_file = repo_root / "agent" / "component" / "list_operations.py"
    spec = importlib.util.spec_from_file_location(private_name, source_file)
    loaded = importlib.util.module_from_spec(spec)
    monkeypatch.setitem(sys.modules, private_name, loaded)
    spec.loader.exec_module(loaded)
    return loaded
def _make_component(module, *, inputs, operation, n, strict=False):
    """Build a ``ListOperations`` instance without running its ``__init__``.

    Args:
        module: Module object exposing the ``ListOperations`` class.
        inputs: List assigned to ``component.inputs``.
        operation: Operation name, recorded as ``_param.operations``.
        n: Value assigned to ``_param.n``.
        strict: Value assigned to ``_param.strict``.

    Returns:
        The bare component, with ``_param.outputs`` holding empty "result",
        "first" and "last" slots.
    """
    component = module.ListOperations.__new__(module.ListOperations)
    component.inputs = inputs
    component._param = SimpleNamespace(
        # Previously the ``operation`` argument was accepted but discarded.
        operations=operation,
        n=n,
        strict=strict,
        outputs={
            "result": {"value": []},
            "first": {"value": None},
            "last": {"value": None},
        },
    )
    return component
@pytest.mark.p2
@pytest.mark.parametrize(
    ("n", "expected"),
    [
        (0, []),
        (-1, ["e"]),
        (-5, ["a"]),
        (-6, []),
        (2, ["b"]),
        (5, ["e"]),
        (6, []),
    ],
)
def test_nth_behaves_like_lenient_indexing(monkeypatch, n, expected):
    """Lenient ``nth`` yields the addressed element, or an empty list when out of range."""
    list_ops = _load_list_operations_module(monkeypatch)
    letters = ["a", "b", "c", "d", "e"]
    subject = _make_component(list_ops, inputs=letters, operation="nth", n=n)

    subject._nth()

    result = subject._param.outputs["result"]["value"]
    assert result == expected
@pytest.mark.p2
@pytest.mark.parametrize(
    ("strict", "n", "expected"),
    [
        (False, 0, []),
        (False, 2, ["a", "b"]),
        (False, 10, ["a", "b", "c", "d", "e"]),
        (True, 2, ["a", "b"]),
    ],
)
def test_head_supports_lenient_and_strict(monkeypatch, strict, n, expected):
    """``head`` returns the leading slice for in-range and lenient out-of-range ``n``."""
    list_ops = _load_list_operations_module(monkeypatch)
    letters = ["a", "b", "c", "d", "e"]
    subject = _make_component(
        list_ops, inputs=letters, operation="head", n=n, strict=strict
    )

    subject._head()

    result = subject._param.outputs["result"]["value"]
    assert result == expected
@pytest.mark.p2
@pytest.mark.parametrize("n", [0, 10])
def test_head_strict_raises_for_out_of_range(monkeypatch, n):
    """Strict ``head`` rejects ``n`` outside 1..len(inputs) with a ValueError."""
    list_ops = _load_list_operations_module(monkeypatch)
    letters = ["a", "b", "c", "d", "e"]
    subject = _make_component(
        list_ops, inputs=letters, operation="head", n=n, strict=True
    )

    with pytest.raises(ValueError, match="head requires n"):
        subject._head()
@pytest.mark.p2
@pytest.mark.parametrize(
    ("strict", "n", "expected"),
    [
        (False, 0, []),
        (False, 2, ["d", "e"]),
        (False, 10, ["a", "b", "c", "d", "e"]),
        (True, 2, ["d", "e"]),
    ],
)
def test_tail_supports_lenient_and_strict(monkeypatch, strict, n, expected):
    """``tail`` returns the trailing slice for in-range and lenient out-of-range ``n``."""
    list_ops = _load_list_operations_module(monkeypatch)
    letters = ["a", "b", "c", "d", "e"]
    subject = _make_component(
        list_ops, inputs=letters, operation="tail", n=n, strict=strict
    )

    subject._tail()

    result = subject._param.outputs["result"]["value"]
    assert result == expected
@pytest.mark.p2
@pytest.mark.parametrize("n", [0, 10])
def test_tail_strict_raises_for_out_of_range(monkeypatch, n):
    """Strict ``tail`` rejects ``n`` outside 1..len(inputs) with a ValueError."""
    list_ops = _load_list_operations_module(monkeypatch)
    letters = ["a", "b", "c", "d", "e"]
    subject = _make_component(
        list_ops, inputs=letters, operation="tail", n=n, strict=True
    )

    with pytest.raises(ValueError, match="tail requires n"):
        subject._tail()
@pytest.mark.p2
@pytest.mark.parametrize("n", [0, 6, -6])
def test_nth_strict_raises_for_out_of_range(monkeypatch, n):
    """Strict ``nth`` rejects zero and positions beyond either end with a ValueError."""
    list_ops = _load_list_operations_module(monkeypatch)
    letters = ["a", "b", "c", "d", "e"]
    subject = _make_component(
        list_ops, inputs=letters, operation="nth", n=n, strict=True
    )

    with pytest.raises(ValueError, match="nth requires n"):
        subject._nth()
@pytest.mark.p2
def test_set_outputs_tracks_first_and_last(monkeypatch):
    """After ``tail`` with n=3, "first" and "last" mirror the ends of "result"."""
    list_ops = _load_list_operations_module(monkeypatch)
    letters = ["a", "b", "c", "d", "e"]
    subject = _make_component(list_ops, inputs=letters, operation="tail", n=3)

    subject._tail()

    published = subject._param.outputs
    assert published["result"]["value"] == ["c", "d", "e"]
    assert published["first"]["value"] == "c"
    assert published["last"]["value"] == "e"

View File

@@ -2394,7 +2394,7 @@ Important structured information may include: names, dates, locations, events, k
renameKeys: 'Rename keys',
},
ListOperationsOptions: {
topN: 'Top N',
nth: 'Nth',
head: 'Head',
tail: 'Tail',
sort: 'Sort',
@@ -2402,6 +2402,9 @@ Important structured information may include: names, dates, locations, events, k
dropDuplicates: 'Drop duplicates',
},
sortMethod: 'Sort method',
strictMode: 'Strict mode',
strictModeTip:
'Off uses lenient behavior and returns an empty result for invalid n. On uses strict behavior and raises an error for out-of-range n.',
SortMethodOptions: {
asc: 'Ascending',
desc: 'Descending',

View File

@@ -2080,14 +2080,17 @@ Tokenizer 会根据所选方式将内容存储为对应的数据结构。`,
renameKeys: '重命名键',
},
ListOperationsOptions: {
topN: '取前N项',
head: '取前N项',
tail: '取后N项',
nth: 'N项',
head: '取前N项',
tail: '取后N项',
sort: '排序',
filter: '筛选',
dropDuplicates: '去重',
},
sortMethod: '排序方式',
strictMode: '严格模式',
strictModeTip:
'关闭时使用宽松模式,非法 n 返回空结果;开启时使用严格模式,超出范围的 n 会直接报错。',
SortMethodOptions: {
asc: '升序',
desc: '降序',

View File

@@ -587,7 +587,7 @@ export enum SortMethod {
}
export enum ListOperations {
TopN = 'topN',
Nth = 'nth',
Head = 'head',
Tail = 'tail',
Filter = 'filter',
@@ -597,7 +597,8 @@ export enum ListOperations {
export const initialListOperationsValues = {
query: '',
operations: ListOperations.TopN,
operations: ListOperations.Nth,
strict: false,
outputs: {
// result: {
// type: 'Array<?>',

View File

@@ -10,6 +10,7 @@ import {
FormMessage,
} from '@/components/ui/form';
import { Separator } from '@/components/ui/separator';
import { Switch } from '@/components/ui/switch';
import { useBuildSwitchOperatorOptions } from '@/hooks/logic-hooks/use-build-operator-options';
import { buildOptions } from '@/utils/form';
import { zodResolver } from '@hookform/resolvers/zod';
@@ -38,7 +39,8 @@ import { QueryVariable } from '../components/query-variable';
export const RetrievalPartialSchema = {
query: z.string(),
operations: z.string(),
n: z.number().int().min(1).optional(),
n: z.number().int().optional(),
strict: z.boolean().optional(),
sort_method: z.string().optional(),
filter: z
.object({
@@ -50,7 +52,7 @@ export const RetrievalPartialSchema = {
};
const NumFields = [
ListOperations.TopN,
ListOperations.Nth,
ListOperations.Head,
ListOperations.Tail,
];
@@ -71,6 +73,13 @@ function showField(operations: string) {
};
}
/**
 * Lower bound for the `n` input of the given operation: the Nth operation
 * gets Number.MIN_SAFE_INTEGER (no practical lower bound), every other
 * operation gets 0.
 */
function getMinValue(operations: string) {
  return operations === ListOperations.Nth ? Number.MIN_SAFE_INTEGER : 0;
}
export const FormSchema = z.object(RetrievalPartialSchema);
export type ListOperationsFormSchemaType = z.infer<typeof FormSchema>;
@@ -129,6 +138,7 @@ function ListOperationsForm({ node }: INextOperatorForm) {
);
const { showFilter, showNum, showSortMethod } = showField(operations);
const minValue = getMinValue(operations);
const handleOperationsChange = useCallback(
(operations: string) => {
@@ -180,23 +190,45 @@ function ListOperationsForm({ node }: INextOperatorForm) {
)}
</RAGFlowFormItem>
{showNum && (
<FormField
control={form.control}
name="n"
render={({ field }) => (
<FormItem>
<FormLabel>{t('flow.flowNum')}</FormLabel>
<FormControl>
<NumberInput
{...field}
className="w-full"
min={1}
></NumberInput>
</FormControl>
<FormMessage />
</FormItem>
)}
/>
<>
<FormField
control={form.control}
name="n"
render={({ field }) => (
<FormItem>
<FormLabel>{t('flow.flowNum')}</FormLabel>
<FormControl>
<NumberInput
{...field}
className="w-full"
min={minValue}
></NumberInput>
</FormControl>
<FormMessage />
</FormItem>
)}
/>
<FormField
control={form.control}
name="strict"
render={({ field }) => (
<FormItem className="space-y-2">
<FormLabel tooltip={t('flow.strictModeTip')}>
{t('flow.strictMode')}
</FormLabel>
<FormControl>
<div className="pt-1">
<Switch
checked={field.value}
onCheckedChange={field.onChange}
></Switch>
</div>
</FormControl>
<FormMessage />
</FormItem>
)}
/>
</>
)}
{showSortMethod && (
<RAGFlowFormItem name="sort_method" label={t('flow.sortMethod')}>