mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 15:31:05 +08:00
Refa: align list operations and strict mode (#14387)
### What problem does this PR solve?

Align list operations and strict mode.

### Type of change

- [x] Refactoring
This commit is contained in:
@@ -10,8 +10,9 @@ class ListOperationsParam(ComponentParamBase):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.query = ""
|
||||
self.operations = "topN"
|
||||
self.n=0
|
||||
self.operations = "nth"
|
||||
self.n = 0
|
||||
self.strict = False
|
||||
self.sort_method = "asc"
|
||||
self.filter = {
|
||||
"operator": "=",
|
||||
@@ -34,7 +35,11 @@ class ListOperationsParam(ComponentParamBase):
|
||||
|
||||
def check(self):
    """Validate the component parameters.

    Passes ``query`` to ``check_empty`` and ``operations`` to
    ``check_valid_value`` together with the list of supported operation
    names (``nth`` replaces the former ``topN``).
    """
    self.check_empty(self.query, "query")
    self.check_valid_value(
        self.operations,
        "Support operations",
        ["nth", "head", "tail", "filter", "sort", "drop_duplicates"],
    )
|
||||
|
||||
def get_input_form(self) -> dict[str, dict]:
    """Return the input form description; this component defines none."""
    return {}
|
||||
@@ -51,8 +56,8 @@ class ListOperations(ComponentBase,ABC):
|
||||
if not isinstance(self.inputs, list):
|
||||
raise TypeError("The input of List Operations should be an array.")
|
||||
self.set_input_value(inputs, self.inputs)
|
||||
if self._param.operations == "topN":
|
||||
self._topN()
|
||||
if self._param.operations == "nth":
|
||||
self._nth()
|
||||
elif self._param.operations == "head":
|
||||
self._head()
|
||||
elif self._param.operations == "tail":
|
||||
@@ -70,35 +75,74 @@ class ListOperations(ComponentBase,ABC):
|
||||
return int(getattr(self._param, "n", 0))
|
||||
except Exception:
|
||||
return 0
|
||||
|
||||
|
||||
def _is_strict(self):
    """Return True when strict mode is enabled on the component parameters.

    A missing ``strict`` attribute counts as disabled. String values are
    trimmed and compared case-insensitively against "1", "true", "yes" and
    "on"; any other type falls back to ordinary truthiness.
    """
    strict = getattr(self._param, "strict", False)
    if isinstance(strict, str):
        # Serialized flags may arrive as text, so accept common spellings.
        return strict.strip().lower() in {"1", "true", "yes", "on"}
    return bool(strict)
|
||||
|
||||
def _set_outputs(self, outputs):
    """Store ``outputs`` and its first/last elements in the output slots.

    Writes the ``value`` entry of the ``result``, ``first`` and ``last``
    outputs; ``first`` and ``last`` become ``None`` when ``outputs`` is
    empty.
    """
    slots = self._param.outputs
    slots["result"]["value"] = outputs
    slots["first"]["value"] = outputs[0] if outputs else None
    slots["last"]["value"] = outputs[-1] if outputs else None
|
||||
|
||||
def _topN(self):
|
||||
|
||||
def _raise_strict_range_error(self, operation, n):
    """Raise the strict-mode error for an out-of-range ``n``.

    Args:
        operation: Name of the list operation, used as the message prefix.
        n: The offending value, echoed in the message.

    Raises:
        ValueError: Always.
    """
    raise ValueError(
        f"{operation} requires n to be within the valid range in strict mode, got {n}."
    )
|
||||
|
||||
def _nth(self):
    """Select a single element of ``self.inputs`` by position ``n``.

    Positive ``n`` is 1-based from the front (``n=1`` is the first item);
    negative ``n`` indexes from the end (``n=-1`` is the last item). When
    ``n`` is 0 or beyond either end, strict mode raises through
    ``_raise_strict_range_error`` and lenient mode yields an empty list.
    The result is published through ``_set_outputs``.
    """
    n = self._coerce_n()
    strict = self._is_strict()
    size = len(self.inputs)

    if 1 <= n <= size:
        outputs = [self.inputs[n - 1]]
    elif -size <= n <= -1:
        outputs = [self.inputs[n]]
    elif strict:
        # Covers n == 0 as well as positions past either end of the list.
        self._raise_strict_range_error("nth", n)
    else:
        outputs = []
    self._set_outputs(outputs)
|
||||
|
||||
def _head(self):
    """Take the first ``n`` elements of ``self.inputs``.

    In strict mode ``n`` must satisfy ``1 <= n <= len(self.inputs)``,
    otherwise ``_raise_strict_range_error`` is called. In lenient mode
    ``n < 1`` yields an empty list and an oversized ``n`` returns the whole
    list. The result is published through ``_set_outputs``.
    """
    n = self._coerce_n()
    if self._is_strict() and not 1 <= n <= len(self.inputs):
        self._raise_strict_range_error("head", n)
    # The explicit n < 1 guard matters: a negative n would otherwise slice
    # everything except the last |n| items instead of returning nothing.
    outputs = self.inputs[:n] if n >= 1 else []
    self._set_outputs(outputs)
|
||||
|
||||
def _tail(self):
    """Take the last ``n`` elements of ``self.inputs``.

    In strict mode ``n`` must satisfy ``1 <= n <= len(self.inputs)``,
    otherwise ``_raise_strict_range_error`` is called. In lenient mode
    ``n < 1`` yields an empty list and an oversized ``n`` returns the whole
    list. The result is published through ``_set_outputs``.
    """
    n = self._coerce_n()
    if self._is_strict() and not 1 <= n <= len(self.inputs):
        self._raise_strict_range_error("tail", n)
    # n == 0 must be special-cased: inputs[-0:] would be the entire list.
    outputs = self.inputs[-n:] if n >= 1 else []
    self._set_outputs(outputs)
|
||||
|
||||
def _filter(self):
|
||||
@@ -107,7 +151,7 @@ class ListOperations(ComponentBase,ABC):
|
||||
def _norm(self, v):
    """Return ``v`` as a string, mapping ``None`` to the empty string."""
    return "" if v is None else str(v)
|
||||
|
||||
|
||||
def _eval(self, v, operator, value):
|
||||
if operator == "=":
|
||||
return v == value
|
||||
@@ -163,6 +207,6 @@ class ListOperations(ComponentBase,ABC):
|
||||
if isinstance(x, set):
|
||||
return tuple(sorted(self._hashable(v) for v in x))
|
||||
return x
|
||||
|
||||
|
||||
def thoughts(self) -> str:
    """Return the fixed progress message for this component."""
    return "ListOperation in progress"
|
||||
|
||||
@@ -0,0 +1,191 @@
|
||||
#
|
||||
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import importlib.util
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from types import ModuleType, SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
def _load_list_operations_module(monkeypatch):
    """Load ``agent/component/list_operations.py`` in isolation.

    Stand-ins for ``agent``, ``agent.component``, ``agent.component.base``,
    ``api`` and ``api.utils.api_utils`` are placed in ``sys.modules`` through
    ``monkeypatch.setitem``, then the component source is executed under a
    private module name so the real packages are never imported.

    Args:
        monkeypatch: pytest's ``monkeypatch`` fixture.

    Returns:
        The executed module object for ``list_operations.py``.

    Raises:
        ImportError: If no import spec or loader can be built for the file.
    """
    # NOTE(review): assumes this test file sits four directories below the
    # repository root - confirm if the file is ever moved.
    repo_root = Path(__file__).resolve().parents[4]

    def _register(name, module):
        monkeypatch.setitem(sys.modules, name, module)
        return module

    def _stub_package(name, *parts):
        # A module with __path__ is treated as a package by the import system.
        package = ModuleType(name)
        package.__path__ = [str(repo_root.joinpath(*parts))]
        return _register(name, package)

    _stub_package("agent", "agent")
    _stub_package("agent.component", "agent", "component")

    class _ComponentParamBase:
        def __init__(self):
            self.outputs = {}

        def check_empty(self, *_args, **_kwargs):
            return None

        def check_valid_value(self, *_args, **_kwargs):
            return None

    class _ComponentBase:
        def set_input_value(self, *_args, **_kwargs):
            return None

    base_mod = ModuleType("agent.component.base")
    base_mod.ComponentBase = _ComponentBase
    base_mod.ComponentParamBase = _ComponentParamBase
    _register("agent.component.base", base_mod)

    _stub_package("api", "api")

    api_utils_mod = ModuleType("api.utils.api_utils")
    # Replace the timeout decorator factory with a pass-through.
    api_utils_mod.timeout = lambda *_args, **_kwargs: (lambda func: func)
    _register("api.utils.api_utils", api_utils_mod)

    module_path = repo_root / "agent" / "component" / "list_operations.py"
    spec = importlib.util.spec_from_file_location(
        "test_list_operations_unit_module", module_path
    )
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot build an import spec for {module_path}")
    module = importlib.util.module_from_spec(spec)
    _register("test_list_operations_unit_module", module)
    spec.loader.exec_module(module)
    return module
|
||||
|
||||
|
||||
def _make_component(module, *, inputs, operation, n, strict=False):
    """Build a ``ListOperations`` instance without running its ``__init__``.

    Args:
        module: Loaded module exposing the ``ListOperations`` class.
        inputs: List assigned to ``component.inputs``.
        operation: Operation name, stored as ``_param.operations``.
        n: Value stored as ``_param.n``.
        strict: Value stored as ``_param.strict``.

    Returns:
        The bare component with a stub ``_param`` whose ``result``,
        ``first`` and ``last`` outputs start empty.
    """
    component = module.ListOperations.__new__(module.ListOperations)
    component.inputs = inputs
    component._param = SimpleNamespace(
        # Previously accepted but dropped; kept so the stub mirrors the
        # operation under test.
        operations=operation,
        n=n,
        strict=strict,
        outputs={
            "result": {"value": []},
            "first": {"value": None},
            "last": {"value": None},
        },
    )
    return component
|
||||
|
||||
|
||||
@pytest.mark.p2
@pytest.mark.parametrize(
    ("n", "expected"),
    [
        (0, []),
        (-1, ["e"]),
        (-5, ["a"]),
        (-6, []),
        (2, ["b"]),
        (5, ["e"]),
        (6, []),
    ],
)
def test_nth_behaves_like_lenient_indexing(monkeypatch, n, expected):
    """Lenient ``nth`` returns one item, or an empty list when ``n`` is out of range."""
    module = _load_list_operations_module(monkeypatch)
    component = _make_component(
        module, inputs=["a", "b", "c", "d", "e"], operation="nth", n=n
    )
    component._nth()
    assert component._param.outputs["result"]["value"] == expected
|
||||
|
||||
|
||||
@pytest.mark.p2
@pytest.mark.parametrize(
    ("strict", "n", "expected"),
    [
        (False, 0, []),
        (False, 2, ["a", "b"]),
        (False, 10, ["a", "b", "c", "d", "e"]),
        (True, 2, ["a", "b"]),
    ],
)
def test_head_supports_lenient_and_strict(monkeypatch, strict, n, expected):
    """``head`` returns the first ``n`` items in both modes for accepted ``n``."""
    module = _load_list_operations_module(monkeypatch)
    component = _make_component(
        module, inputs=["a", "b", "c", "d", "e"], operation="head", n=n, strict=strict
    )
    component._head()
    assert component._param.outputs["result"]["value"] == expected
|
||||
|
||||
|
||||
@pytest.mark.p2
@pytest.mark.parametrize("n", [0, 10])
def test_head_strict_raises_for_out_of_range(monkeypatch, n):
    """Strict ``head`` raises ``ValueError`` when ``n`` is outside 1..len(inputs)."""
    module = _load_list_operations_module(monkeypatch)
    component = _make_component(
        module, inputs=["a", "b", "c", "d", "e"], operation="head", n=n, strict=True
    )
    with pytest.raises(ValueError, match="head requires n"):
        component._head()
|
||||
|
||||
|
||||
@pytest.mark.p2
@pytest.mark.parametrize(
    ("strict", "n", "expected"),
    [
        (False, 0, []),
        (False, 2, ["d", "e"]),
        (False, 10, ["a", "b", "c", "d", "e"]),
        (True, 2, ["d", "e"]),
    ],
)
def test_tail_supports_lenient_and_strict(monkeypatch, strict, n, expected):
    """``tail`` returns the last ``n`` items in both modes for accepted ``n``."""
    module = _load_list_operations_module(monkeypatch)
    component = _make_component(
        module, inputs=["a", "b", "c", "d", "e"], operation="tail", n=n, strict=strict
    )
    component._tail()
    assert component._param.outputs["result"]["value"] == expected
|
||||
|
||||
|
||||
@pytest.mark.p2
@pytest.mark.parametrize("n", [0, 10])
def test_tail_strict_raises_for_out_of_range(monkeypatch, n):
    """Strict ``tail`` raises ``ValueError`` when ``n`` is outside 1..len(inputs)."""
    module = _load_list_operations_module(monkeypatch)
    component = _make_component(
        module, inputs=["a", "b", "c", "d", "e"], operation="tail", n=n, strict=True
    )
    with pytest.raises(ValueError, match="tail requires n"):
        component._tail()
|
||||
|
||||
|
||||
@pytest.mark.p2
@pytest.mark.parametrize("n", [0, 6, -6])
def test_nth_strict_raises_for_out_of_range(monkeypatch, n):
    """Strict ``nth`` raises ``ValueError`` for zero and for positions past either end."""
    module = _load_list_operations_module(monkeypatch)
    component = _make_component(
        module, inputs=["a", "b", "c", "d", "e"], operation="nth", n=n, strict=True
    )
    with pytest.raises(ValueError, match="nth requires n"):
        component._nth()
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_set_outputs_tracks_first_and_last(monkeypatch):
    """After ``tail`` with n=3, ``first`` and ``last`` mirror the ends of ``result``."""
    module = _load_list_operations_module(monkeypatch)
    component = _make_component(
        module, inputs=["a", "b", "c", "d", "e"], operation="tail", n=3
    )
    component._tail()
    outputs = component._param.outputs
    assert outputs["result"]["value"] == ["c", "d", "e"]
    assert outputs["first"]["value"] == "c"
    assert outputs["last"]["value"] == "e"
|
||||
@@ -2394,7 +2394,7 @@ Important structured information may include: names, dates, locations, events, k
|
||||
renameKeys: 'Rename keys',
|
||||
},
|
||||
ListOperationsOptions: {
|
||||
topN: 'Top N',
|
||||
nth: 'Nth',
|
||||
head: 'Head',
|
||||
tail: 'Tail',
|
||||
sort: 'Sort',
|
||||
@@ -2402,6 +2402,9 @@ Important structured information may include: names, dates, locations, events, k
|
||||
dropDuplicates: 'Drop duplicates',
|
||||
},
|
||||
sortMethod: 'Sort method',
|
||||
strictMode: 'Strict mode',
|
||||
strictModeTip:
|
||||
'Off uses lenient behavior and returns an empty result for invalid n. On uses strict behavior and raises an error for out-of-range n.',
|
||||
SortMethodOptions: {
|
||||
asc: 'Ascending',
|
||||
desc: 'Descending',
|
||||
|
||||
@@ -2080,14 +2080,17 @@ Tokenizer 会根据所选方式将内容存储为对应的数据结构。`,
|
||||
renameKeys: '重命名键',
|
||||
},
|
||||
ListOperationsOptions: {
|
||||
topN: '取前N项',
|
||||
head: '取前第N项',
|
||||
tail: '取后第N项',
|
||||
nth: '第N项',
|
||||
head: '取前N项',
|
||||
tail: '取后N项',
|
||||
sort: '排序',
|
||||
filter: '筛选',
|
||||
dropDuplicates: '去重',
|
||||
},
|
||||
sortMethod: '排序方式',
|
||||
strictMode: '严格模式',
|
||||
strictModeTip:
|
||||
'关闭时使用宽松模式,非法 n 返回空结果;开启时使用严格模式,超出范围的 n 会直接报错。',
|
||||
SortMethodOptions: {
|
||||
asc: '升序',
|
||||
desc: '降序',
|
||||
|
||||
@@ -587,7 +587,7 @@ export enum SortMethod {
|
||||
}
|
||||
|
||||
export enum ListOperations {
|
||||
TopN = 'topN',
|
||||
Nth = 'nth',
|
||||
Head = 'head',
|
||||
Tail = 'tail',
|
||||
Filter = 'filter',
|
||||
@@ -597,7 +597,8 @@ export enum ListOperations {
|
||||
|
||||
export const initialListOperationsValues = {
|
||||
query: '',
|
||||
operations: ListOperations.TopN,
|
||||
operations: ListOperations.Nth,
|
||||
strict: false,
|
||||
outputs: {
|
||||
// result: {
|
||||
// type: 'Array<?>',
|
||||
|
||||
@@ -10,6 +10,7 @@ import {
|
||||
FormMessage,
|
||||
} from '@/components/ui/form';
|
||||
import { Separator } from '@/components/ui/separator';
|
||||
import { Switch } from '@/components/ui/switch';
|
||||
import { useBuildSwitchOperatorOptions } from '@/hooks/logic-hooks/use-build-operator-options';
|
||||
import { buildOptions } from '@/utils/form';
|
||||
import { zodResolver } from '@hookform/resolvers/zod';
|
||||
@@ -38,7 +39,8 @@ import { QueryVariable } from '../components/query-variable';
|
||||
export const RetrievalPartialSchema = {
|
||||
query: z.string(),
|
||||
operations: z.string(),
|
||||
n: z.number().int().min(1).optional(),
|
||||
n: z.number().int().optional(),
|
||||
strict: z.boolean().optional(),
|
||||
sort_method: z.string().optional(),
|
||||
filter: z
|
||||
.object({
|
||||
@@ -50,7 +52,7 @@ export const RetrievalPartialSchema = {
|
||||
};
|
||||
|
||||
// Operations that take the numeric `n` argument.
const NumFields = [
  ListOperations.Nth,
  ListOperations.Head,
  ListOperations.Tail,
];
|
||||
@@ -71,6 +73,13 @@ function showField(operations: string) {
|
||||
};
|
||||
}
|
||||
|
||||
/**
 * Lower bound for the numeric `n` input of the given operation.
 *
 * `nth` accepts negative positions, so it gets the smallest safe integer;
 * every other operation is bounded below by 0.
 */
function getMinValue(operations: string) {
  return operations === ListOperations.Nth ? Number.MIN_SAFE_INTEGER : 0;
}
|
||||
|
||||
export const FormSchema = z.object(RetrievalPartialSchema);
|
||||
|
||||
export type ListOperationsFormSchemaType = z.infer<typeof FormSchema>;
|
||||
@@ -129,6 +138,7 @@ function ListOperationsForm({ node }: INextOperatorForm) {
|
||||
);
|
||||
|
||||
const { showFilter, showNum, showSortMethod } = showField(operations);
|
||||
const minValue = getMinValue(operations);
|
||||
|
||||
const handleOperationsChange = useCallback(
|
||||
(operations: string) => {
|
||||
@@ -180,23 +190,45 @@ function ListOperationsForm({ node }: INextOperatorForm) {
|
||||
)}
|
||||
</RAGFlowFormItem>
|
||||
{showNum && (
|
||||
<FormField
|
||||
control={form.control}
|
||||
name="n"
|
||||
render={({ field }) => (
|
||||
<FormItem>
|
||||
<FormLabel>{t('flow.flowNum')}</FormLabel>
|
||||
<FormControl>
|
||||
<NumberInput
|
||||
{...field}
|
||||
className="w-full"
|
||||
min={1}
|
||||
></NumberInput>
|
||||
</FormControl>
|
||||
<FormMessage />
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
<>
|
||||
<FormField
|
||||
control={form.control}
|
||||
name="n"
|
||||
render={({ field }) => (
|
||||
<FormItem>
|
||||
<FormLabel>{t('flow.flowNum')}</FormLabel>
|
||||
<FormControl>
|
||||
<NumberInput
|
||||
{...field}
|
||||
className="w-full"
|
||||
min={minValue}
|
||||
></NumberInput>
|
||||
</FormControl>
|
||||
<FormMessage />
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
<FormField
|
||||
control={form.control}
|
||||
name="strict"
|
||||
render={({ field }) => (
|
||||
<FormItem className="space-y-2">
|
||||
<FormLabel tooltip={t('flow.strictModeTip')}>
|
||||
{t('flow.strictMode')}
|
||||
</FormLabel>
|
||||
<FormControl>
|
||||
<div className="pt-1">
|
||||
<Switch
|
||||
checked={field.value}
|
||||
onCheckedChange={field.onChange}
|
||||
></Switch>
|
||||
</div>
|
||||
</FormControl>
|
||||
<FormMessage />
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
</>
|
||||
)}
|
||||
{showSortMethod && (
|
||||
<RAGFlowFormItem name="sort_method" label={t('flow.sortMethod')}>
|
||||
|
||||
Reference in New Issue
Block a user