diff --git a/agent/tools/github.py b/agent/tools/github.py index f48ab0a2d6..4a95ac366a 100644 --- a/agent/tools/github.py +++ b/agent/tools/github.py @@ -20,6 +20,7 @@ from abc import ABC import requests from agent.tools.base import ToolParamBase, ToolMeta, ToolBase from common.connection_utils import timeout +from common.http_client import DEFAULT_TIMEOUT class GitHubParam(ToolParamBase): @@ -75,7 +76,7 @@ class GitHub(ToolBase, ABC): url = 'https://api.github.com/search/repositories?q=' + kwargs["query"] + '&sort=stars&order=desc&per_page=' + str( self._param.top_n) headers = {"Content-Type": "application/vnd.github+json", "X-GitHub-Api-Version": '2022-11-28'} - response = requests.get(url=url, headers=headers).json() + response = requests.get(url=url, headers=headers, timeout=DEFAULT_TIMEOUT).json() if self.check_if_canceled("GitHub processing"): return diff --git a/agent/tools/jin10.py b/agent/tools/jin10.py index b477dba81e..a37249ca40 100644 --- a/agent/tools/jin10.py +++ b/agent/tools/jin10.py @@ -18,6 +18,7 @@ from abc import ABC import pandas as pd import requests from agent.component.base import ComponentBase, ComponentParamBase +from common.http_client import DEFAULT_TIMEOUT class Jin10Param(ComponentParamBase): @@ -72,7 +73,7 @@ class Jin10(ComponentBase, ABC): } response = requests.get( url='https://open-data-api.jin10.com/data-api/flash?category=' + self._param.flash_type, - headers=headers, data=json.dumps(params)) + headers=headers, data=json.dumps(params), timeout=DEFAULT_TIMEOUT) response = response.json() for i in response['data']: if self.check_if_canceled("Jin10 processing"): @@ -84,7 +85,7 @@ class Jin10(ComponentBase, ABC): } response = requests.get( url='https://open-data-api.jin10.com/data-api/calendar/' + self._param.calendar_datatype + '?category=' + self._param.calendar_type, - headers=headers, data=json.dumps(params)) + headers=headers, data=json.dumps(params), timeout=DEFAULT_TIMEOUT) response = response.json() if 
self.check_if_canceled("Jin10 processing"): @@ -98,7 +99,7 @@ class Jin10(ComponentBase, ABC): params['codes'] = 'BTCUSD' response = requests.get( url='https://open-data-api.jin10.com/data-api/' + self._param.symbols_datatype + '?type=' + self._param.symbols_type, - headers=headers, data=json.dumps(params)) + headers=headers, data=json.dumps(params), timeout=DEFAULT_TIMEOUT) response = response.json() if self.check_if_canceled("Jin10 processing"): return @@ -134,7 +135,7 @@ class Jin10(ComponentBase, ABC): } response = requests.get( url='https://open-data-api.jin10.com/data-api/news', - headers=headers, data=json.dumps(params)) + headers=headers, data=json.dumps(params), timeout=DEFAULT_TIMEOUT) response = response.json() if self.check_if_canceled("Jin10 processing"): return diff --git a/agent/tools/qweather.py b/agent/tools/qweather.py index a597c2c5b8..2a1b7d9772 100644 --- a/agent/tools/qweather.py +++ b/agent/tools/qweather.py @@ -17,6 +17,7 @@ from abc import ABC import pandas as pd import requests from agent.component.base import ComponentBase, ComponentParamBase +from common.http_client import DEFAULT_TIMEOUT class QWeatherParam(ComponentParamBase): @@ -71,7 +72,8 @@ class QWeather(ComponentBase, ABC): return response = requests.get( - url="https://geoapi.qweather.com/v2/city/lookup?location=" + ans + "&key=" + self._param.web_apikey).json() + url="https://geoapi.qweather.com/v2/city/lookup?location=" + ans + "&key=" + self._param.web_apikey, + timeout=DEFAULT_TIMEOUT).json() if response["code"] == "200": location_id = response["location"][0]["id"] else: @@ -84,7 +86,7 @@ class QWeather(ComponentBase, ABC): if self._param.type == "weather": url = base_url + "weather/" + self._param.time_period + "?location=" + location_id + "&key=" + self._param.web_apikey + "&lang=" + self._param.lang - response = requests.get(url=url).json() + response = requests.get(url=url, timeout=DEFAULT_TIMEOUT).json() if self.check_if_canceled("Qweather processing"): return if 
response["code"] == "200": @@ -104,7 +106,7 @@ class QWeather(ComponentBase, ABC): elif self._param.type == "indices": url = base_url + "indices/1d?type=0&location=" + location_id + "&key=" + self._param.web_apikey + "&lang=" + self._param.lang - response = requests.get(url=url).json() + response = requests.get(url=url, timeout=DEFAULT_TIMEOUT).json() if self.check_if_canceled("Qweather processing"): return if response["code"] == "200": @@ -117,7 +119,7 @@ class QWeather(ComponentBase, ABC): elif self._param.type == "airquality": url = base_url + "air/now?location=" + location_id + "&key=" + self._param.web_apikey + "&lang=" + self._param.lang - response = requests.get(url=url).json() + response = requests.get(url=url, timeout=DEFAULT_TIMEOUT).json() if self.check_if_canceled("Qweather processing"): return if response["code"] == "200": diff --git a/agent/tools/tushare.py b/agent/tools/tushare.py index 6a0d0c2a34..feec503067 100644 --- a/agent/tools/tushare.py +++ b/agent/tools/tushare.py @@ -19,6 +19,7 @@ import pandas as pd import time import requests from agent.component.base import ComponentBase, ComponentParamBase +from common.http_client import DEFAULT_TIMEOUT class TuShareParam(ComponentParamBase): @@ -62,7 +63,7 @@ class TuShare(ComponentBase, ABC): "params": {"src": self._param.src, "start_date": self._param.start_date, "end_date": self._param.end_date} } - response = requests.post(url="http://api.tushare.pro", data=json.dumps(params).encode('utf-8')) + response = requests.post(url="http://api.tushare.pro", data=json.dumps(params).encode('utf-8'), timeout=DEFAULT_TIMEOUT) response = response.json() if self.check_if_canceled("TuShare processing"): return diff --git a/test/unit_test/agent/tools/test_http_timeout.py b/test/unit_test/agent/tools/test_http_timeout.py new file mode 100644 index 0000000000..49e9d9be97 --- /dev/null +++ b/test/unit_test/agent/tools/test_http_timeout.py @@ -0,0 +1,105 @@ +# +# Copyright 2026 The InfiniFlow Authors. 
All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""Guard against external-API agent tools issuing HTTP requests without a timeout.
+
+A blocking ``requests``/``httpx`` call with no ``timeout`` will hang forever if
+the upstream stalls. Because these tools run inside agent canvas execution, a
+single stalled socket hangs the whole agent run with no recovery. This test
+parses the tool sources and fails if any ``requests``/``httpx`` request call is
+missing a ``timeout`` keyword, covering current and future call sites.
+"""
+
+import ast
+from pathlib import Path
+
+import pytest
+
+
+def _repo_root() -> Path:
+    """Return the repo root: the nearest ancestor dir holding ``pyproject.toml``.
+
+    Walking for the first ``*/agent/tools`` directory is unsafe: this test file
+    itself lives under ``test/unit_test/agent/tools``, so that heuristic would
+    resolve to the test dir and scan nothing. Raises ``RuntimeError`` if no ancestor qualifies.
+    """
+    for parent in Path(__file__).resolve().parents:
+        if (parent / "pyproject.toml").is_file():
+            return parent
+    raise RuntimeError("Could not locate repo root (no pyproject.toml found)")
+
+
+TOOLS_DIR = _repo_root() / "agent" / "tools"
+# Fail loudly at import time if we ever point at the wrong place, rather than
+# silently scanning zero real tools and passing.
+assert (TOOLS_DIR / "github.py").is_file(), f"agent tools not found at {TOOLS_DIR}"
+
+# Methods on ``requests`` / ``httpx`` (and a session/client) that open a socket
+# and therefore must be bounded by a timeout.
+_REQUEST_METHODS = {"get", "post", "put", "patch", "delete", "head", "options", "request"}
+# Modules whose request methods we police, as referenced at the call site.
+_HTTP_NAMESPACES = {"requests", "httpx"}
+# Common bare local names for a session/client built from those modules, so
+# ``session.get(...)`` is covered; ``self.session.get(...)`` resolves to ``self`` and is not.
+_HTTP_INSTANCE_NAMES = {"session", "client", "session_client"}
+
+
+def _root_name(node: ast.AST) -> str | None:
+    """Resolve the leftmost identifier of an HTTP call target.
+
+    Unwraps a constructor call (``requests.Session()`` -> ``requests``) and any
+    chained attribute (``requests.sessions.Session`` -> ``requests``), returning
+    the base ``Name`` id or ``None``.
+    """
+    if isinstance(node, ast.Call):
+        node = node.func
+    while isinstance(node, ast.Attribute):
+        node = node.value
+    return node.id if isinstance(node, ast.Name) else None
+
+
+def _iter_request_calls(tree: ast.AST):
+    """Yield ``ast.Call`` nodes that look like an HTTP request, e.g. ``requests.get(...)``."""
+    for node in ast.walk(tree):
+        if not isinstance(node, ast.Call):
+            continue
+        func = node.func
+        if not isinstance(func, ast.Attribute) or func.attr not in _REQUEST_METHODS:
+            continue
+        root = _root_name(func.value)
+        if root in _HTTP_NAMESPACES or root in _HTTP_INSTANCE_NAMES:
+            yield node
+
+
+def _has_timeout(call: ast.Call) -> bool:
+    """True if ``call`` sets a usable ``timeout``; a literal ``timeout=None`` does not count."""
+    for kw in call.keywords:
+        if kw.arg == "timeout":
+            return not (isinstance(kw.value, ast.Constant) and kw.value.value is None)
+    # ``**{"timeout": ...}`` / ``**kwargs`` spread — assume the caller is explicit.
+    return any(kw.arg is None for kw in call.keywords)
+
+
+def _tool_files():
+    """Return the tool modules under ``TOOLS_DIR`` (``__init__.py`` excluded), sorted."""
+    return sorted(p for p in TOOLS_DIR.glob("*.py") if p.name != "__init__.py")
+
+
+@pytest.mark.parametrize("path", _tool_files(), ids=lambda p: p.name)
+def test_http_calls_have_timeout(path: Path):
+    """Fail if ``path`` contains an HTTP request call that lacks a usable timeout."""
+    tree = ast.parse(path.read_text(encoding="utf-8"), filename=str(path))
+    missing = [f"{path.name}:{c.lineno}" for c in _iter_request_calls(tree) if not _has_timeout(c)]
+    assert not missing, "HTTP request(s) without timeout=: " + ", ".join(missing)