diff --git a/agent/component/pipeline_chunker.py b/agent/component/pipeline_chunker.py new file mode 100644 index 0000000000..3bf5bd0305 --- /dev/null +++ b/agent/component/pipeline_chunker.py @@ -0,0 +1,194 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +PipelineChunker Component + +Run RAGFlow Pipeline-style chunkers (rag.app.*) against uploaded files inside an +Agent workflow. Emits plain text chunks for downstream Agent nodes — no +embedding, no persistence. Wraps existing chunker functions; does not +re-implement chunking logic. +""" + +import importlib +import logging +import os +from abc import ABC + +from agent.component.base import ComponentBase, ComponentParamBase +from api.db.services.file_service import FileService +from common.connection_utils import timeout + + +# Parser id -> dotted module path under rag.app. Imported lazily so we don't +# pull deepdoc/OCR/VLM machinery at component-discovery time. +_PARSER_MODULES: dict[str, str] = { + "general": "rag.app.naive", + "naive": "rag.app.naive", + "paper": "rag.app.paper", + "book": "rag.app.book", + "presentation": "rag.app.presentation", + "manual": "rag.app.manual", + "laws": "rag.app.laws", + "qa": "rag.app.qa", + "table": "rag.app.table", + "resume": "rag.app.resume", + "picture": "rag.app.picture", + "one": "rag.app.one", + "audio": "rag.app.audio", + "email": "rag.app.email", + "tag": "rag.app.tag", +} + + +def _load_chunker(parser_id: str): + """Resolve a parser id to the underlying ``rag.app..chunk`` callable.""" + module_path = _PARSER_MODULES[parser_id.lower()] + return importlib.import_module(module_path).chunk + + +class PipelineChunkerParam(ComponentParamBase): + """ + Define the PipelineChunker component parameters. + """ + + def __init__(self): + """Initialise PipelineChunker defaults and declare component outputs.""" + super().__init__() + self.inputs = [] # variable references to uploaded files + self.parser_id = "naive" + self.lang = "English" + self.from_page = 0 + self.to_page = 100000000 + self.parser_config = {} + + self.outputs = { + "chunks": {"type": "list", "value": []}, + "chunks_full": {"type": "list", "value": []}, + "summary": {"type": "str", "value": ""}, + } + + def check(self): + """Validate parser id, page range, and parser_config shape.""" + self.check_valid_value( + self.parser_id.lower(), + "[PipelineChunker] parser_id", + list(_PARSER_MODULES.keys()), + ) + self.check_nonnegative_number(self.from_page, "[PipelineChunker] from_page") + self.check_nonnegative_number(self.to_page, "[PipelineChunker] to_page") + if isinstance(self.from_page, (int, float)) and isinstance(self.to_page, (int, float)) and self.from_page > self.to_page: + raise ValueError("[PipelineChunker] from_page must be <= to_page") + if not isinstance(self.parser_config, dict): + raise ValueError("[PipelineChunker] parser_config must be a dict.") + return True + + +class PipelineChunker(ComponentBase, ABC): + """ + Run a Pipeline-style chunker (naive, paper, qa, manual, book, ...) against + one or more uploaded files and surface the resulting chunks to downstream + Agent nodes. + """ + + component_name = "PipelineChunker" + + def get_input_form(self) -> dict[str, dict]: + """Expose each referenced file input as a file-typed form element.""" + res = {} + for ref in self._param.inputs or []: + for k, o in self.get_input_elements_from_text(ref).items(): + res[k] = {"name": o.get("name", ""), "type": "file"} + return res + + def _get_file_content(self, file_ref: str) -> tuple[bytes | None, str | None]: + """Resolve a canvas variable reference to ``(content_bytes, filename)``.""" + value = self._canvas.get_variable_value(file_ref) + if value is None: + return None, None + + if isinstance(value, list) and value: + value = value[0] + + if isinstance(value, dict): + file_id = value.get("id") or value.get("file_id") + created_by = value.get("created_by") or self._canvas.get_tenant_id() + filename = value.get("name") or value.get("filename") or "uploaded" + if file_id: + try: + return FileService.get_blob(created_by, file_id), filename + except Exception as e: + logging.exception( + f"[PipelineChunker] FileService.get_blob failed for " + f"file_id={file_id} created_by={created_by} filename={filename}: {e}" + ) + return None, None + return None, None + + @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10 * 60))) + def _invoke(self, **kwargs): + """Run the configured chunker over every referenced file and publish outputs.""" + if self.check_if_canceled("PipelineChunker processing"): + return + + chunker = _load_chunker(self._param.parser_id) + tenant_id = self._canvas.get_tenant_id() + chunk_kwargs = dict( + lang=self._param.lang, + tenant_id=tenant_id, + from_page=self._param.from_page, + to_page=self._param.to_page, + parser_config=self._param.parser_config or {}, + callback=lambda prog=0, msg="": logging.info(f"[PipelineChunker] {prog}: {msg}"), + ) + + all_chunks: list[dict] = [] + per_file_counts: list[str] = [] + + for file_ref in self._param.inputs or []: + if self.check_if_canceled("PipelineChunker processing"): + return + + content, filename = self._get_file_content(file_ref) + self.set_input_value(file_ref, filename or "") + if content is None: + logging.warning(f"[PipelineChunker] could not resolve file ref: {file_ref}") + per_file_counts.append(f"{filename or file_ref}: error (could not resolve file)") + continue + + try: + file_chunks = chunker(filename, binary=content, **chunk_kwargs) or [] + except Exception as e: + logging.exception(e) + per_file_counts.append(f"{filename}: error (chunking failed)") + continue + + all_chunks.extend(file_chunks) + per_file_counts.append(f"{filename}: {len(file_chunks)} chunks") + + text_only = [(c.get("content_with_weight") or c.get("text") or "") for c in all_chunks if isinstance(c, dict)] + text_only = [t for t in text_only if t] + + self.set_output("chunks", text_only) + self.set_output("chunks_full", all_chunks) + self.set_output( + "summary", + f"Parser: {self._param.parser_id} | Files: {len(self._param.inputs or [])} | Chunks: {len(text_only)}" + (" | " + "; ".join(per_file_counts) if per_file_counts else ""), + ) + + def thoughts(self) -> str: + """Return a short status line for UI display.""" + return f"Chunking with `{self._param.parser_id}` strategy..." diff --git a/test/unit_test/agent/test_pipeline_chunker.py b/test/unit_test/agent/test_pipeline_chunker.py new file mode 100644 index 0000000000..d19981a73d --- /dev/null +++ b/test/unit_test/agent/test_pipeline_chunker.py @@ -0,0 +1,155 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Unit tests for the PipelineChunker agent component (#14773). + +These tests cover only the pieces that don't require a live Canvas/Graph: +parameter validation and the parser-id -> module lookup table. Full +end-to-end behavior is intentionally left to higher-level integration tests. +""" + +from __future__ import annotations + +import sys +from unittest.mock import MagicMock + +import pytest + +pytestmark = pytest.mark.p2 + + +# The component pulls in api.db.services.file_service (-> quart_auth, peewee, +# the entire backend stack) and rag.app.* (-> deepdoc, OCR, xgboost, +# transformers). None of that is exercised by these unit tests, so replace +# the heavy modules with stubs to keep the test runnable without the full +# runtime environment. We track every key we install and restore the prior +# sys.modules state in teardown_module so the stubs don't leak into other +# test files. +_SENTINEL_ABSENT = object() +_INSTALLED_STUBS: dict[str, object] = {} + + +def _install_stub(name: str, stub: object) -> None: + """Insert ``stub`` into sys.modules and remember the prior entry.""" + if name in _INSTALLED_STUBS: + return + _INSTALLED_STUBS[name] = sys.modules.get(name, _SENTINEL_ABSENT) + sys.modules[name] = stub + + +_file_service_stub = MagicMock() +_file_service_stub.FileService = MagicMock() +if "api.db.services.file_service" not in sys.modules: + _install_stub("api.db.services.file_service", _file_service_stub) + +for mod in [ + "deepdoc.vision.ocr", + "deepdoc.parser.figure_parser", + "rag.app.picture", + "rag.app.audio", + "rag.app.resume", + "rag.app.naive", + "rag.app.paper", + "rag.app.book", + "rag.app.presentation", + "rag.app.manual", + "rag.app.laws", + "rag.app.qa", + "rag.app.table", + "rag.app.one", + "rag.app.email", + "rag.app.tag", +]: + if mod not in sys.modules: + stub = MagicMock() + stub.chunk = MagicMock(return_value=[{"content_with_weight": "stub"}]) + _install_stub(mod, stub) + + +def teardown_module(module): + """Restore sys.modules to its pre-stub state when this file's tests finish.""" + for name, original in _INSTALLED_STUBS.items(): + if original is _SENTINEL_ABSENT: + sys.modules.pop(name, None) + else: + sys.modules[name] = original + _INSTALLED_STUBS.clear() + +from agent.component.pipeline_chunker import ( # noqa: E402 + _PARSER_MODULES, + PipelineChunkerParam, + _load_chunker, +) + + +class TestPipelineChunkerParam: + """Validate parameter parsing and the strategy whitelist.""" + + def test_default_param_validates(self): + """A freshly constructed param object should pass ``check()``.""" + p = PipelineChunkerParam() + assert p.check() is True + + def test_accepts_each_known_parser(self): + """Every parser id in the lookup table must validate.""" + for parser_id in _PARSER_MODULES: + p = PipelineChunkerParam() + p.parser_id = parser_id + assert p.check() is True + + def test_rejects_unknown_parser(self): + """Unknown parser ids must raise ``ValueError`` at validation time.""" + p = PipelineChunkerParam() + p.parser_id = "nonsense-parser" + with pytest.raises(ValueError): + p.check() + + def test_rejects_non_dict_parser_config(self): + """``parser_config`` must be a dict; anything else must raise.""" + p = PipelineChunkerParam() + p.parser_config = "not a dict" + with pytest.raises(ValueError): + p.check() + + def test_rejects_negative_pages(self): + """Negative page indices must raise ``ValueError``.""" + p = PipelineChunkerParam() + p.from_page = -1 + with pytest.raises(ValueError): + p.check() + + def test_rejects_inverted_page_range(self): + """``from_page`` greater than ``to_page`` must raise ``ValueError``.""" + p = PipelineChunkerParam() + p.from_page = 10 + p.to_page = 5 + with pytest.raises(ValueError, match="from_page must be <= to_page"): + p.check() + + +class TestLoadChunker: + """Verify the lazy parser-id -> chunker callable resolver.""" + + def test_load_chunker_returns_callable_for_each_known_parser(self): + """Every known parser id should resolve to a callable ``chunk`` function.""" + for parser_id in _PARSER_MODULES: + chunker = _load_chunker(parser_id) + assert callable(chunker) + + def test_load_chunker_raises_for_unknown_parser(self): + """Unknown parser ids should raise ``KeyError`` from the lookup.""" + with pytest.raises(KeyError): + _load_chunker("not-a-real-parser")