mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 15:31:05 +08:00
fix(agent): add SSRF guard to Invoke HTTP component (#15426)
## Summary Closes #15425. The agent **Invoke** (HTTP Request) component now calls `assert_url_is_safe` and `pin_dns` before `requests.*`, matching Crawler and SearXNG. ## Changes - `agent/component/invoke.py`: SSRF guard + DNS pinning on outbound requests. - `test_invoke_component_unit.py`: unit test blocks loopback URL without calling `requests.get`. ## Test plan - [x] `pytest test/testcases/test_web_api/test_canvas_app/test_invoke_component_unit.py::test_invoke_blocks_loopback_url_with_ssrf_guard` (requires project test env / `ZHIPU_AI_API_KEY` in CI) --------- Co-authored-by: Cursor <cursoragent@cursor.com> Co-authored-by: Zhichang Yu <yuzhichang@gmail.com>
This commit is contained in:
@@ -20,6 +20,7 @@ import re
|
|||||||
import time
|
import time
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
@@ -190,8 +191,24 @@ class Invoke(ComponentBase, ABC):
|
|||||||
|
|
||||||
return {key: self._resolve_header_text(value, kwargs) if isinstance(value, str) else value for key, value in headers.items()}
|
return {key: self._resolve_header_text(value, kwargs) if isinstance(value, str) else value for key, value in headers.items()}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _ssrf_log_target(url: str) -> str:
|
||||||
|
parsed = urlparse(url)
|
||||||
|
if not parsed.scheme or not parsed.hostname:
|
||||||
|
return "invalid-url"
|
||||||
|
return f"{parsed.scheme}://{parsed.hostname}"
|
||||||
|
|
||||||
|
def _normalize_proxy_url(self) -> str | None:
|
||||||
|
proxy = (self._param.proxy or "").strip()
|
||||||
|
if not re.sub(r"https?:?/?/?", "", proxy):
|
||||||
|
return None
|
||||||
|
if not proxy.startswith(("http://", "https://")):
|
||||||
|
proxy = "http://" + proxy
|
||||||
|
return proxy
|
||||||
|
|
||||||
def _build_proxies(self) -> dict | None:
|
def _build_proxies(self) -> dict | None:
|
||||||
if not re.sub(r"https?:?/?/?", "", self._param.proxy):
|
proxy_url = self._normalize_proxy_url()
|
||||||
|
if not proxy_url:
|
||||||
return None
|
return None
|
||||||
return {"http": self._param.proxy, "https": self._param.proxy}
|
return {"http": self._param.proxy, "https": self._param.proxy}
|
||||||
|
|
||||||
@@ -231,6 +248,20 @@ class Invoke(ComponentBase, ABC):
|
|||||||
args = self._build_request_args(kwargs)
|
args = self._build_request_args(kwargs)
|
||||||
headers = self._build_headers(kwargs)
|
headers = self._build_headers(kwargs)
|
||||||
proxies = self._build_proxies()
|
proxies = self._build_proxies()
|
||||||
|
proxy_hostname = proxy_ip = None
|
||||||
|
|
||||||
|
if proxies:
|
||||||
|
proxy_url = self._normalize_proxy_url()
|
||||||
|
try:
|
||||||
|
proxy_hostname, proxy_ip = assert_url_is_safe(proxy_url)
|
||||||
|
except ValueError as exc:
|
||||||
|
logging.warning(
|
||||||
|
"Invoke SSRF guard blocked proxy=%s: %s",
|
||||||
|
self._ssrf_log_target(proxy_url),
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
self.set_output("_ERROR", "URL not valid")
|
||||||
|
return "Http request error: URL not valid"
|
||||||
|
|
||||||
last_error = None
|
last_error = None
|
||||||
for _ in range(self._param.max_retries + 1):
|
for _ in range(self._param.max_retries + 1):
|
||||||
@@ -238,18 +269,26 @@ class Invoke(ComponentBase, ABC):
|
|||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Coderabbit MAJOR #3486038788: URL validation is now inside the
|
|
||||||
# retry/except block so SSRF rejections (ValueError from
|
|
||||||
# assert_url_is_safe) populate _ERROR via the standard error
|
|
||||||
# path instead of escaping _invoke().
|
|
||||||
url = self._build_url(kwargs)
|
url = self._build_url(kwargs)
|
||||||
if not self._pinned_hostname or not self._pinned_ip:
|
if not self._pinned_hostname or not self._pinned_ip:
|
||||||
raise ValueError("Invoke URL was not validated before request.")
|
raise ValueError("Invoke URL was not validated before request.")
|
||||||
with pin_dns(self._pinned_hostname, self._pinned_ip):
|
with pin_dns(self._pinned_hostname, self._pinned_ip):
|
||||||
response = self._send_request(url, args, headers, proxies)
|
if proxy_hostname and proxy_ip:
|
||||||
|
with pin_dns(proxy_hostname, proxy_ip):
|
||||||
|
response = self._send_request(url, args, headers, proxies)
|
||||||
|
else:
|
||||||
|
response = self._send_request(url, args, headers, proxies)
|
||||||
result = self._format_response(response)
|
result = self._format_response(response)
|
||||||
self.set_output("result", result)
|
self.set_output("result", result)
|
||||||
return result
|
return result
|
||||||
|
except ValueError as e:
|
||||||
|
logging.warning(
|
||||||
|
"Invoke SSRF guard blocked url=%s: %s",
|
||||||
|
self._ssrf_log_target(locals().get("url", self._param.url)),
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
self.set_output("_ERROR", "URL not valid")
|
||||||
|
return "Http request error: URL not valid"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if self.check_if_canceled("Invoke processing"):
|
if self.check_if_canceled("Invoke processing"):
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -274,19 +274,62 @@ def test_header_variable_with_put(monkeypatch):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.p2
|
@pytest.mark.p2
|
||||||
def test_invoke_rejects_private_url(monkeypatch):
|
def test_invoke_blocks_loopback_url_with_ssrf_guard(monkeypatch):
|
||||||
|
"""Invoke must use the shared SSRF guard before requests.* (issue Invoke SSRF)."""
|
||||||
module = _load_invoke_module(monkeypatch)
|
module = _load_invoke_module(monkeypatch)
|
||||||
invoke = _make_invoke(module, url="http://127.0.0.1:22")
|
invoke = _make_invoke(module, url="http://127.0.0.1:8123/api")
|
||||||
monkeypatch.setattr(
|
mock_get = MagicMock(return_value=SimpleNamespace(text="ok"))
|
||||||
module,
|
monkeypatch.setattr(module.requests, "get", mock_get)
|
||||||
"assert_url_is_safe",
|
result = invoke._invoke()
|
||||||
MagicMock(side_effect=ValueError("URL resolves to a non-public address")),
|
mock_get.assert_not_called()
|
||||||
)
|
assert result == "Http request error: URL not valid"
|
||||||
# Coderabbit MAJOR #3486038793: _build_url() is now inside the retry
|
assert invoke.output("_ERROR") == "URL not valid"
|
||||||
# try/except block, so the ValueError from assert_url_is_safe is caught
|
|
||||||
# and the message is stored in _ERROR via the standard error path.
|
|
||||||
|
@pytest.mark.p2
|
||||||
|
def test_invoke_blocks_metadata_ip(monkeypatch):
|
||||||
|
module = _load_invoke_module(monkeypatch)
|
||||||
|
invoke = _make_invoke(module, url="http://169.254.169.254/latest/meta-data/")
|
||||||
|
mock_get = MagicMock(return_value=SimpleNamespace(text="should not run"))
|
||||||
|
monkeypatch.setattr(module.requests, "get", mock_get)
|
||||||
|
result = invoke._invoke()
|
||||||
|
mock_get.assert_not_called()
|
||||||
|
assert "URL not valid" in result
|
||||||
|
assert invoke.output("_ERROR") == "URL not valid"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.p2
|
||||||
|
def test_invoke_url_without_scheme_gets_scheme_then_validated(monkeypatch):
|
||||||
|
"""Bare hostnames are prefixed with http:// before SSRF validation."""
|
||||||
|
module = _load_invoke_module(monkeypatch)
|
||||||
|
invoke = _make_invoke(module, url="127.0.0.1:9380/")
|
||||||
|
mock_get = MagicMock(return_value=SimpleNamespace(text="should not run"))
|
||||||
|
monkeypatch.setattr(module.requests, "get", mock_get)
|
||||||
|
result = invoke._invoke()
|
||||||
|
mock_get.assert_not_called()
|
||||||
|
assert "URL not valid" in result
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.p2
|
||||||
|
def test_invoke_blocks_loopback_proxy(monkeypatch):
|
||||||
|
module = _load_invoke_module(monkeypatch)
|
||||||
|
invoke = _make_invoke(module, url="http://example.com", proxy="http://127.0.0.1:8080")
|
||||||
|
mock_get = MagicMock(return_value=SimpleNamespace(text="should not run"))
|
||||||
|
monkeypatch.setattr(module.requests, "get", mock_get)
|
||||||
|
result = invoke._invoke()
|
||||||
|
mock_get.assert_not_called()
|
||||||
|
assert "URL not valid" in result
|
||||||
|
assert invoke.output("_ERROR") == "URL not valid"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.p2
|
||||||
|
def test_invoke_disables_redirect_following(monkeypatch):
|
||||||
|
module = _load_invoke_module(monkeypatch)
|
||||||
|
invoke = _make_invoke(module, url="http://example.com")
|
||||||
|
mock_get = MagicMock(return_value=SimpleNamespace(text="ok"))
|
||||||
|
monkeypatch.setattr(module.requests, "get", mock_get)
|
||||||
invoke._invoke()
|
invoke._invoke()
|
||||||
assert "non-public address" in invoke._param.outputs["_ERROR"]
|
assert mock_get.call_args[1]["allow_redirects"] is False
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.p2
|
@pytest.mark.p2
|
||||||
|
|||||||
Reference in New Issue
Block a user