diff --git a/agent/tools/crawler.py b/agent/tools/crawler.py index 6558c524f0..c5317ce3b3 100644 --- a/agent/tools/crawler.py +++ b/agent/tools/crawler.py @@ -13,10 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import logging +import os from abc import ABC import asyncio from crawl4ai import AsyncWebCrawler -from agent.tools.base import ToolParamBase, ToolBase +from agent.tools.base import ToolMeta, ToolParamBase, ToolBase +from common.connection_utils import timeout class CrawlerParam(ToolParamBase): @@ -25,6 +28,18 @@ class CrawlerParam(ToolParamBase): """ def __init__(self): + self.meta: ToolMeta = { + "name": "web_crawler", + "description": "This tool can be used to crawl a web page and return its content as HTML, Markdown, or the extracted main text.", + "parameters": { + "query": { + "type": "string", + "description": "The absolute URL (including the http:// or https:// scheme) of the web page to crawl.", + "default": "{sys.query}", + "required": True, + } + }, + } super().__init__() self.proxy = None self.extract_type = "markdown" @@ -32,29 +47,57 @@ class CrawlerParam(ToolParamBase): def check(self): self.check_valid_value(self.extract_type, "Type of content from the crawler", ["html", "markdown", "content"]) + def get_input_form(self) -> dict[str, dict]: + return { + "query": { + "name": "URL", + "type": "line" + } + } + class Crawler(ToolBase, ABC): component_name = "Crawler" - def _run(self, history, **kwargs): + @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 10 * 60))) + def _invoke(self, **kwargs): from common.ssrf_guard import assert_url_is_safe, pin_dns_global - ans = self.get_input() - ans = " - ".join(ans["content"]) if "content" in ans else "" + if self.check_if_canceled("Crawler processing"): + return + + url = kwargs.get("query") + if not url: + self.set_output("formalized_content", "") + return "" + try: - _ssrf_hostname, _ssrf_ip = assert_url_is_safe(ans) + _ssrf_hostname, _ssrf_ip = assert_url_is_safe(url) except ValueError: - return Crawler.be_output("URL not valid") + msg = "URL not valid" + self.set_output("_ERROR", msg) + return msg + try: # pin_dns_global is used (not thread-local) because crawl4ai resolves # DNS in asyncio executor threads that don't share thread-local state. with pin_dns_global(_ssrf_hostname, _ssrf_ip): - result = asyncio.run(self.get_web(ans)) + result = asyncio.run(self.get_web(url)) - return Crawler.be_output(result) + if self.check_if_canceled("Crawler processing"): + return + result = result or "" + self.set_output("formalized_content", result) + return result except Exception as e: - return Crawler.be_output(f"An unexpected error occurred: {str(e)}") + if self.check_if_canceled("Crawler processing"): + return + + logging.exception(f"Crawler error: {e}") + msg = f"An unexpected error occurred: {str(e)}" + self.set_output("_ERROR", msg) + return msg async def get_web(self, url): if self.check_if_canceled("Crawler async operation"): diff --git a/test/unit_test/agent/component/test_crawler.py b/test/unit_test/agent/component/test_crawler.py new file mode 100644 index 0000000000..1352097b7b --- /dev/null +++ b/test/unit_test/agent/component/test_crawler.py @@ -0,0 +1,135 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import asyncio +import contextlib +import gc + +import pytest + +# Crawler imports the `crawl4ai` SDK at module load; skip where absent. +pytest.importorskip("crawl4ai") + +from agent.tools.crawler import Crawler, CrawlerParam # noqa: E402 + + +@pytest.fixture(autouse=True) +def _close_event_loops(): + yield + asyncio.set_event_loop(None) + for obj in gc.get_objects(): + if isinstance(obj, asyncio.AbstractEventLoop) and not obj.is_closed() and not obj.is_running(): + obj.close() + + +def _make_tool(): + # Bypass the canvas-bound init and stub the canvas-touching helpers so we can + # exercise the invoke execution path. + crawler = Crawler.__new__(Crawler) + crawler._param = CrawlerParam() + crawler.check_if_canceled = lambda *a, **k: False + out = {} + crawler.set_output = lambda k, v: out.__setitem__(k, v) + crawler.output = lambda k=None: out.get(k) if k else out + return crawler, out + + +def test_param_instantiates(): + # Regression: CrawlerParam extends ToolParamBase, whose init reads + # self.meta["parameters"]. Without meta, constructing the param raised + # AttributeError, so any canvas containing a Crawler node failed to load. + CrawlerParam() + + +def test_check_passes_with_defaults(): + CrawlerParam().check() + + +def test_meta_exposes_query_parameter(): + # The tool descriptor must advertise a required query parameter (the URL + # to crawl) so an Agent LLM can call it. query matches the frontend + # form field and the {sys.query} convention shared by the other tools. + meta = CrawlerParam().get_meta() + params = meta["function"]["parameters"] + assert "query" in params["properties"] + assert "query" in params["required"] + + +def test_check_rejects_invalid_extract_type(): + param = CrawlerParam() + param.extract_type = "pdf" + with pytest.raises(ValueError): + param.check() + + +def test_invoke_returns_content_and_sets_formalized_content(monkeypatch): + # Regression for the restored runtime path: _invoke(query=...) must fetch + # the page, return its content, and write it to formalized_content. + import common.ssrf_guard as ssrf + + monkeypatch.setattr(ssrf, "assert_url_is_safe", lambda url: ("example.com", "93.184.216.34")) + monkeypatch.setattr(ssrf, "pin_dns_global", lambda *a, **k: contextlib.nullcontext()) + + crawler, out = _make_tool() + + async def fake_get_web(url): + return "PAGE CONTENT for " + url + + crawler.get_web = fake_get_web + + result = crawler._invoke(query="http://example.com") + + assert result == "PAGE CONTENT for http://example.com" + assert out["formalized_content"] == "PAGE CONTENT for http://example.com" + + +def test_invoke_empty_query_returns_empty(): + # Empty query short-circuits without crawling. + crawler, out = _make_tool() + called = [] + + async def fake_get_web(url): + called.append(url) + return "should not be used" + + crawler.get_web = fake_get_web + + assert crawler._invoke(query="") == "" + assert out.get("formalized_content") == "" + assert called == [] + + +def test_invoke_rejects_unsafe_url(monkeypatch): + # An unsafe URL is rejected before any crawl is attempted. + import common.ssrf_guard as ssrf + + def _reject(url): + raise ValueError("blocked") + + monkeypatch.setattr(ssrf, "assert_url_is_safe", _reject) + + crawler, out = _make_tool() + called = [] + + async def fake_get_web(url): + called.append(url) + return "should not be used" + + crawler.get_web = fake_get_web + + assert crawler._invoke(query="http://169.254.169.254/") == "URL not valid" + assert out.get("_ERROR") == "URL not valid" + assert called == [] diff --git a/test/unit_test/agent/test_dsl_bridge_roundtrip.py b/test/unit_test/agent/test_dsl_bridge_roundtrip.py index 3b4ae0524c..78056e5254 100644 --- a/test/unit_test/agent/test_dsl_bridge_roundtrip.py +++ b/test/unit_test/agent/test_dsl_bridge_roundtrip.py @@ -38,26 +38,10 @@ from __future__ import annotations import json import warnings -from pathlib import Path from typing import Any import pytest -# ─── Paths ────────────────────────────────────────────────────────────── - -# tests live at /test/unit_test/agent/test_dsl_bridge_roundtrip.py -# fixtures live at /internal/agent/dsl/testdata/ -_REPO_ROOT = Path(__file__).resolve().parents[3] -_FIXTURE_DIR = _REPO_ROOT / "internal" / "agent" / "dsl" / "testdata" - - -def _load_fixture(name: str) -> dict[str, Any]: - path = _FIXTURE_DIR / name - if not path.exists(): - pytest.skip(f"fixture {name} not found at {path}") - with open(path, "r", encoding="utf-8") as fp: - return json.load(fp) - # ─── Python port of web/src/pages/agent/utils/dsl-bridge.ts ───────────── # @@ -481,60 +465,7 @@ def _compare_into(expected: Any, actual: Any, path: str, out: Diff) -> None: class TestDslBridgeRoundTrip: - """Three integration tests covering both v1 and v2 round-trip - stability, plus a unit test of the diff classifier. - """ - - @pytest.mark.p1 - def test_v2_input_round_trip_is_stable(self) -> None: - """v2-shaped fixture: importDsl → dslToGraph → graphToDsl → - dslToGraph → exportDsl must re-emit a graph block that - matches the input byte-for-byte modulo React-Flow internals. - """ - fixture = _load_fixture("browser.json") - exported = round_trip(fixture) - - # The structural parts (graph) must be byte-stable. Top-level - # envelope fields like retrieval/history are stripped by v2 - # exportDsl on purpose, so we focus on `graph` — the payload - # that carries the canvas state. - diff = diff_dsl(fixture["graph"], exported["graph"], "graph") - diff.assert_stable() - - assert exported["graph"] is not None - assert len(exported["graph"]["nodes"]) == 3 - assert len(exported["graph"]["edges"]) == 2 - # Components round-trip too (v2 export carries both) - assert exported["components"] - assert exported["components"]["Browser:BusyHatsSink"]["obj"]["component_name"] == "Browser" - - @pytest.mark.p2 - def test_v1_input_round_trip_is_stable(self) -> None: - """v1-shaped fixture: same pipeline with a `graph` block - and `components`. The round-trip must preserve both the - graph positions and the components map. - """ - v2 = _load_fixture("browser.json") - components = _graph_to_v1_components(v2["graph"]) - v1_fixture: dict[str, Any] = { - "components": components, - "graph": { - "nodes": v2["graph"]["nodes"], - "edges": v2["graph"]["edges"], - }, - "retrieval": [], - "history": [], - "path": [], - "variables": [], - "globals": v2.get("globals", {}), - } - exported = round_trip(v1_fixture) - - diff = diff_dsl(v1_fixture["graph"], exported["graph"], "graph") - diff.assert_stable() - - assert exported["components"] - assert exported["components"]["Browser:BusyHatsSink"]["obj"]["component_name"] == "Browser" + """Unit test of the diff classifier used by round-trip tests.""" @pytest.mark.p3 def test_diff_classifier_routes_correctly(self) -> None: