Fix: validate URL scheme and resolved IP before crawling to prevent SSRF (#14090)

### What problem does this PR solve?

The POST /upload_info?url=<url> endpoint accepted a user-supplied URL
and passed it directly to AsyncWebCrawler without any validation. There
were no restrictions on URL scheme, destination hostname, or resolved IP
address. This allowed any authenticated user to instruct the server to
make outbound HTTP requests to internal infrastructure — including RFC
1918 private networks, loopback addresses, and cloud metadata services
such as http://169.254.169.254 — effectively using the server as a proxy
for internal network reconnaissance or credential theft.

This PR adds an SSRF guard (_validate_url_for_crawl) that runs before
any crawl is initiated. It enforces an allowlist of safe schemes
(http/https), resolves the hostname at validation time, and rejects any
URL whose resolved IP falls within a private or reserved network range.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)
This commit is contained in:
Xing Hong
2026-04-25 15:30:15 +09:00
committed by GitHub
parent 78188ce9e9
commit fb95136f39
10 changed files with 485 additions and 109 deletions

View File

@@ -179,10 +179,7 @@ class Invoke(ComponentBase, ABC):
if not isinstance(headers, dict):
raise ValueError("Invoke headers must be a JSON object.")
return {
key: self._resolve_header_text(value, kwargs) if isinstance(value, str) else value
for key, value in headers.items()
}
return {key: self._resolve_header_text(value, kwargs) if isinstance(value, str) else value for key, value in headers.items()}
def _build_proxies(self) -> dict | None:
if not re.sub(r"https?:?/?/?", "", self._param.proxy):
@@ -215,7 +212,7 @@ class Invoke(ComponentBase, ABC):
# HtmlParser keeps the Invoke output text-focused when the endpoint returns HTML.
sections = HtmlParser()(None, response.content)
return "\n".join(sections)
@timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 3)))
def _invoke(self, **kwargs):
if self.check_if_canceled("Invoke processing"):

View File

@@ -19,7 +19,6 @@ from crawl4ai import AsyncWebCrawler
from agent.tools.base import ToolParamBase, ToolBase
class CrawlerParam(ToolParamBase):
"""
Define the Crawler component parameters.
@@ -31,20 +30,26 @@ class CrawlerParam(ToolParamBase):
self.extract_type = "markdown"
def check(self):
self.check_valid_value(self.extract_type, "Type of content from the crawler", ['html', 'markdown', 'content'])
self.check_valid_value(self.extract_type, "Type of content from the crawler", ["html", "markdown", "content"])
class Crawler(ToolBase, ABC):
component_name = "Crawler"
def _run(self, history, **kwargs):
from api.utils.web_utils import is_valid_url
from common.ssrf_guard import assert_url_is_safe, pin_dns_global
ans = self.get_input()
ans = " - ".join(ans["content"]) if "content" in ans else ""
if not is_valid_url(ans):
try:
_ssrf_hostname, _ssrf_ip = assert_url_is_safe(ans)
except ValueError:
return Crawler.be_output("URL not valid")
try:
result = asyncio.run(self.get_web(ans))
# pin_dns_global is used (not thread-local) because crawl4ai resolves
# DNS in asyncio executor threads that don't share thread-local state.
with pin_dns_global(_ssrf_hostname, _ssrf_ip):
result = asyncio.run(self.get_web(ans))
return Crawler.be_output(result)
@@ -57,18 +62,15 @@ class Crawler(ToolBase, ABC):
proxy = self._param.proxy if self._param.proxy else None
async with AsyncWebCrawler(verbose=True, proxy=proxy) as crawler:
result = await crawler.arun(
url=url,
bypass_cache=True
)
result = await crawler.arun(url=url, bypass_cache=True)
if self.check_if_canceled("Crawler async operation"):
return
if self._param.extract_type == 'html':
if self._param.extract_type == "html":
return result.cleaned_html
elif self._param.extract_type == 'markdown':
elif self._param.extract_type == "markdown":
return result.markdown
elif self._param.extract_type == 'content':
elif self._param.extract_type == "content":
return result.extracted_content
return result.markdown

View File

@@ -20,6 +20,7 @@ from abc import ABC
import requests
from agent.tools.base import ToolMeta, ToolParamBase, ToolBase
from common.connection_utils import timeout
from common.ssrf_guard import assert_url_is_safe, pin_dns
class SearXNGParam(ToolParamBase):
@@ -36,15 +37,15 @@ class SearXNGParam(ToolParamBase):
"type": "string",
"description": "The search keywords to execute with SearXNG. The keywords should be the most important words/terms(includes synonyms) from the original request.",
"default": "{sys.query}",
"required": True
"required": True,
},
"searxng_url": {
"type": "string",
"description": "The base URL of your SearXNG instance (e.g., http://localhost:4000). This is required to connect to your SearXNG server.",
"required": False,
"default": ""
}
}
"default": "",
},
},
}
super().__init__()
self.top_n = 10
@@ -61,17 +62,7 @@ class SearXNGParam(ToolParamBase):
self.check_positive_integer(self.top_n, "Top N")
def get_input_form(self) -> dict[str, dict]:
return {
"query": {
"name": "Query",
"type": "line"
},
"searxng_url": {
"name": "SearXNG URL",
"type": "line",
"placeholder": "http://localhost:4000"
}
}
return {"query": {"name": "Query", "type": "line"}, "searxng_url": {"name": "SearXNG URL", "type": "line", "placeholder": "http://localhost:4000"}}
class SearXNG(ToolBase, ABC):
@@ -94,26 +85,22 @@ class SearXNG(ToolBase, ABC):
self.set_output("formalized_content", "")
return ""
try:
_ssrf_hostname, _ssrf_ip = assert_url_is_safe(searxng_url)
except ValueError as e:
self.set_output("_ERROR", str(e))
return f"SearXNG error: SSRF guard blocked {searxng_url!r}: {e}"
last_e = ""
for _ in range(self._param.max_retries+1):
for _ in range(self._param.max_retries + 1):
if self.check_if_canceled("SearXNG processing"):
return
try:
search_params = {
'q': query,
'format': 'json',
'categories': 'general',
'language': 'auto',
'safesearch': 1,
'pageno': 1
}
search_params = {"q": query, "format": "json", "categories": "general", "language": "auto", "safesearch": 1, "pageno": 1}
response = requests.get(
f"{searxng_url}/search",
params=search_params,
timeout=10
)
with pin_dns(_ssrf_hostname, _ssrf_ip):
response = requests.get(f"{searxng_url}/search", params=search_params, timeout=10)
response.raise_for_status()
if self.check_if_canceled("SearXNG processing"):
@@ -128,15 +115,12 @@ class SearXNG(ToolBase, ABC):
if not isinstance(results, list):
raise ValueError("Invalid results format from SearXNG")
results = results[:self._param.top_n]
results = results[: self._param.top_n]
if self.check_if_canceled("SearXNG processing"):
return
self._retrieve_chunks(results,
get_title=lambda r: r.get("title", ""),
get_url=lambda r: r.get("url", ""),
get_content=lambda r: r.get("content", ""))
self._retrieve_chunks(results, get_title=lambda r: r.get("title", ""), get_url=lambda r: r.get("url", ""), get_content=lambda r: r.get("content", ""))
self.set_output("json", results)
return self.output("formalized_content")

View File

@@ -43,6 +43,7 @@ from common import settings
from common.constants import SANDBOX_ARTIFACT_BUCKET, ParserType, RetCode, TaskStatus
from common.file_utils import get_project_base_directory
from common.misc_utils import get_uuid, thread_pool_exec
from common.ssrf_guard import assert_url_is_safe
from deepdoc.parser.html_parser import RAGFlowHtmlParser
from rag.nlp import search
@@ -333,6 +334,7 @@ async def run():
except Exception as e:
return server_error_response(e)
@manager.route("/get/<doc_id>", methods=["GET"]) # noqa: F821
@login_required
async def get(doc_id):
@@ -581,6 +583,7 @@ async def upload_info():
try:
if url and not file_objs:
assert_url_is_safe(url)
return get_json_result(data=FileService.upload_info(current_user.id, None, url))
if len(file_objs) == 1:

View File

@@ -23,6 +23,8 @@ from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import Union
logger = logging.getLogger(__name__)
import xxhash
from peewee import fn
@@ -33,6 +35,7 @@ from api.db.services.common_service import CommonService
from api.db.services.document_service import DocumentService
from api.db.services.file2document_service import File2DocumentService
from common.misc_utils import get_uuid
from common.ssrf_guard import assert_url_is_safe
from common.constants import TaskStatus, FileSource, ParserType
from api.db.services.knowledgebase_service import KnowledgebaseService
from api.db.services.task_service import TaskService
@@ -624,6 +627,26 @@ class FileService(CommonService):
return errors
# Schemes the crawl path accepts; immutable so no caller can widen it by accident.
_ALLOWED_SCHEMES = frozenset({"http", "https"})

@staticmethod
def _validate_url_for_crawl(url: str) -> tuple[str, str]:
    """Raise ValueError if the URL is not safe to crawl (SSRF guard).

    Delegates to :func:`common.ssrf_guard.assert_url_is_safe` and returns
    its ``(hostname, resolved_ip)`` result, which callers use for DNS
    pinning.

    Only the scheme and host (and port when present) are forwarded to the
    guard so that credentials, path, or query parameters in *url* are never
    passed along to it.

    Raises:
        ValueError: if the guard rejects the URL, or if the port in *url*
            is not a valid number (raised by ``urlparse(...).port``).
    """
    from urllib.parse import urlparse

    parsed = urlparse(url)

    # ``parsed.hostname`` is None when the URL has no host.  Formatting None
    # into the rebuilt URL would produce the literal host "None", which the
    # guard would then try to resolve; an empty host lets the guard report
    # the missing host instead.
    host = parsed.hostname or ""

    # ``hostname`` strips the brackets from IPv6 literals; put them back so
    # the rebuilt URL parses to the same host instead of being split at the
    # first colon.
    if ":" in host:
        host = f"[{host}]"

    port_suffix = f":{parsed.port}" if parsed.port else ""
    redacted = f"{parsed.scheme}://{host}{port_suffix}"
    return assert_url_is_safe(redacted, allowed_schemes=FileService._ALLOWED_SCHEMES)
@staticmethod
def upload_info(user_id, file, url: str|None=None):
def structured(filename, filetype, blob, content_type):
@@ -646,6 +669,53 @@ class FileService(CommonService):
}
if url:
import requests as _requests
from urllib.parse import urljoin as _urljoin
_MAX_CRAWL_REDIRECTS = 10
# Pre-resolve the full redirect chain so that AsyncWebCrawler never
# follows a server-sent redirect to an unvalidated (potentially
# internal) host. Each hop is SSRF-checked before being followed;
# the validated (hostname, ip) pairs are pinned via Chromium's
# --host-resolver-rules so the browser cannot re-resolve any of them
# through a fresh DNS query.
current_url = url
current_hostname, current_ip = FileService._validate_url_for_crawl(current_url)
# Accumulate MAP rules for every hostname we encounter in the chain.
host_pins: dict[str, str] = {current_hostname: current_ip}
for _ in range(_MAX_CRAWL_REDIRECTS):
try:
_resp = _requests.get(
current_url,
timeout=10,
allow_redirects=False,
)
except _requests.RequestException as _exc:
raise ValueError(f"Failed to fetch {current_url!r}: {_exc}") from _exc
if _resp.status_code not in (301, 302, 303, 307, 308):
break
_location = _resp.headers.get("Location")
if not _location:
break
_next_url = _urljoin(current_url, _location)
_next_hostname, _next_ip = FileService._validate_url_for_crawl(_next_url)
host_pins[_next_hostname] = _next_ip
current_url = _next_url
else:
raise ValueError(
f"Exceeded {_MAX_CRAWL_REDIRECTS} redirects fetching {url!r}"
)
# Build a single MAP rule string covering every validated hostname
# in the redirect chain. Chromium uses the pinned IP for each,
# skipping DNS entirely and eliminating the rebinding window.
_map_rules = ",".join(f"MAP {h} {ip}" for h, ip in host_pins.items())
from crawl4ai import (
AsyncWebCrawler,
BrowserConfig,
@@ -659,6 +729,7 @@ class FileService(CommonService):
browser_config = BrowserConfig(
headless=True,
verbose=False,
extra_args=[f"--host-resolver-rules={_map_rules}"],
)
async with AsyncWebCrawler(config=browser_config) as crawler:
crawler_config = CrawlerRunConfig(
@@ -668,8 +739,10 @@ class FileService(CommonService):
pdf=True,
screenshot=False
)
# Use the final resolved URL so the browser starts at the
# redirect destination rather than re-following the chain.
result: CrawlResult = await crawler.arun(
url=url,
url=current_url,
config=crawler_config
)
return result
@@ -679,7 +752,7 @@ class FileService(CommonService):
filename += ".pdf"
return structured(filename, "pdf", page.pdf, page.response_headers["content-type"])
return structured(filename, "html", str(page.markdown).encode("utf-8"), page.response_headers["content-type"], user_id)
return structured(filename, "html", str(page.markdown).encode("utf-8"), page.response_headers["content-type"])
DocumentService.check_doc_health(user_id, file.filename)
return structured(file.filename, filename_type(file.filename), file.read(), file.content_type)

View File

@@ -15,11 +15,8 @@
#
import base64
import ipaddress
import json
import re
import socket
from urllib.parse import urlparse
import aiosmtplib
from email.mime.text import MIMEText
from email.header import Header
@@ -37,10 +34,10 @@ from webdriver_manager.chrome import ChromeDriverManager
OTP_LENGTH = 4
OTP_TTL_SECONDS = 5 * 60 # valid for 5 minutes
ATTEMPT_LIMIT = 5 # maximum attempts
ATTEMPT_LOCK_SECONDS = 30 * 60 # lock for 30 minutes
RESEND_COOLDOWN_SECONDS = 60 # cooldown for 1 minute
OTP_TTL_SECONDS = 5 * 60 # valid for 5 minutes
ATTEMPT_LIMIT = 5 # maximum attempts
ATTEMPT_LOCK_SECONDS = 30 * 60 # lock for 30 minutes
RESEND_COOLDOWN_SECONDS = 60 # cooldown for 1 minute
CONTENT_TYPE_MAP = {
@@ -188,29 +185,16 @@ def __get_pdf_from_html(path: str, timeout: int, install_driver: bool, print_opt
return base64.b64decode(result["data"])
def is_private_ip(ip: str) -> bool:
try:
ip_obj = ipaddress.ip_address(ip)
return ip_obj.is_private
except ValueError:
return False
def is_valid_url(url: str) -> bool:
if not re.match(r"(https?)://[-A-Za-z0-9+&@#/%?=~_|!:,.;]+[-A-Za-z0-9+&@#/%=~_|]", url):
return False
parsed_url = urlparse(url)
hostname = parsed_url.hostname
from common.ssrf_guard import assert_url_is_safe
if not hostname:
return False
try:
ip = socket.gethostbyname(hostname)
if is_private_ip(ip):
return False
except socket.gaierror:
assert_url_is_safe(url)
return True
except ValueError:
return False
return True
def safe_json_parse(data: str | dict) -> dict:

View File

@@ -1,11 +1,9 @@
import hashlib
import ipaddress
import socket
from datetime import datetime, timezone
from email.utils import parsedate_to_datetime
from time import struct_time
from typing import Any
from urllib.parse import urlparse
from urllib.parse import urljoin, urlparse
import bs4
import feedparser
@@ -14,28 +12,9 @@ import requests
from common.data_source.config import INDEX_BATCH_SIZE, REQUEST_TIMEOUT_SECONDS, DocumentSource
from common.data_source.interfaces import LoadConnector, PollConnector
from common.data_source.models import Document, GenerateDocumentsOutput, SecondsSinceUnixEpoch
from common.ssrf_guard import assert_url_is_safe, pin_dns as _pin_dns
def _is_private_ip(ip: str) -> bool:
try:
ip_obj = ipaddress.ip_address(ip)
return ip_obj.is_private or ip_obj.is_link_local or ip_obj.is_loopback
except ValueError:
return False
def _validate_url_no_ssrf(url: str) -> None:
parsed = urlparse(url)
hostname = parsed.hostname
if not hostname:
raise ValueError("URL must have a valid hostname")
try:
ip = socket.gethostbyname(hostname)
if _is_private_ip(ip):
raise ValueError(f"URL resolves to private/internal IP address: {ip}")
except socket.gaierror as e:
raise ValueError(f"Failed to resolve hostname: {hostname}") from e
_MAX_REDIRECTS = 10
class RSSConnector(LoadConnector, PollConnector):
@@ -87,7 +66,8 @@ class RSSConnector(LoadConnector, PollConnector):
if batch:
yield batch
def _validate_feed_url(self) -> None:
def _validate_feed_url(self) -> tuple[str, str]:
"""Validate ``self.feed_url`` and return ``(hostname, resolved_ip)``."""
if not self.feed_url:
raise ValueError("feed_url is required")
@@ -95,7 +75,7 @@ class RSSConnector(LoadConnector, PollConnector):
if parsed.scheme not in {"http", "https"} or not parsed.netloc:
raise ValueError("feed_url must be a valid http or https URL")
_validate_url_no_ssrf(self.feed_url)
return assert_url_is_safe(self.feed_url)
def _read_feed(self, require_entries: bool) -> Any:
if self._cached_feed is not None:
@@ -103,15 +83,38 @@ class RSSConnector(LoadConnector, PollConnector):
raise ValueError("RSS feed contains no entries")
return self._cached_feed
self._validate_feed_url()
# Validate once to get the pinned IP for the initial request.
current_hostname, current_ip = self._validate_feed_url()
current_url = self.feed_url
# Follow redirects manually: each hop is validated and DNS-pinned
# *before* the connection is made, closing the TOCTOU rebinding window
# that existed when allow_redirects=True was used with post-hoc checks.
response: requests.Response | None = None
for _ in range(_MAX_REDIRECTS + 1):
with _pin_dns(current_hostname, current_ip):
response = requests.get(
current_url,
timeout=REQUEST_TIMEOUT_SECONDS,
allow_redirects=False,
)
if response.status_code not in (301, 302, 303, 307, 308):
break
location = response.headers.get("Location")
if not location:
break # broken redirect; let raise_for_status() handle it
redirect_url = urljoin(current_url, location)
# Validate redirect target before following it.
current_hostname, current_ip = assert_url_is_safe(redirect_url)
current_url = redirect_url
else:
raise ValueError(f"Exceeded {_MAX_REDIRECTS} redirects fetching {self.feed_url!r}")
response = requests.get(self.feed_url, timeout=REQUEST_TIMEOUT_SECONDS, allow_redirects=True)
response.raise_for_status()
final_url = getattr(response, "url", self.feed_url)
if final_url != self.feed_url and urlparse(final_url).hostname:
_validate_url_no_ssrf(final_url)
feed = feedparser.parse(response.content)
if getattr(feed, "bozo", False) and not feed.entries:
error = getattr(feed, "bozo_exception", None)

172
common/ssrf_guard.py Normal file
View File

@@ -0,0 +1,172 @@
#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Shared SSRF-guard utilities.
Uses only the standard library so it can be imported from both ``api/`` and
``common/`` without pulling in any heavyweight dependencies.
"""
import ipaddress
import logging
import socket
import threading
from contextlib import contextmanager
from urllib.parse import urlparse
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# DNS pinning — closes the TOCTOU / rebinding window between SSRF validation
# and the actual TCP connection.  The replacement resolver defers to the
# original ``socket.getaddrinfo`` for any host that has no active pin, so
# lookups for unpinned hosts are unchanged.
# ---------------------------------------------------------------------------
# Per-thread storage; pin_dns() keeps a ``dns_pins`` dict (hostname -> ip) on it.
_tl = threading.local()
# Pins visible to every thread (hostname -> ip); managed by pin_dns_global().
_global_dns_pins: dict[str, str] = {}
# Guards every read and write of ``_global_dns_pins``.
_global_pin_lock = threading.Lock()
# The resolver that was installed when this module was imported; used as the
# fallback for hosts without a pin.
_orig_getaddrinfo = socket.getaddrinfo
def _getaddrinfo_with_pins(host, port, *args, **kwargs):
    """Resolve *host*, answering from an active DNS pin when one exists.

    Lookup order:

    1. the calling thread's pins (set by ``pin_dns``),
    2. the process-wide pins (set by ``pin_dns_global``),
    3. the original ``socket.getaddrinfo``, with all arguments passed through.

    A pinned answer is a single TCP/stream record in the same 5-tuple layout
    ``socket.getaddrinfo`` uses.  The sockaddr follows the documented shape
    for its family: ``(ip, port)`` for IPv4 and
    ``(ip, port, flowinfo, scope_id)`` for IPv6, so consumers that unpack or
    length-check IPv6 sockaddrs keep working.
    """
    # Thread-local pins (synchronous callers: requests.get in the same thread)
    ip = getattr(_tl, "dns_pins", {}).get(host)

    if ip is None:
        # Process-global pins (async callers whose DNS resolves in executor threads)
        with _global_pin_lock:
            ip = _global_dns_pins.get(host)

    if ip is None:
        return _orig_getaddrinfo(host, port, *args, **kwargs)

    pinned_port = port or 0
    if ":" in ip:
        # IPv6 sockaddrs carry flowinfo and scope_id after the port.
        return [(socket.AF_INET6, socket.SOCK_STREAM, socket.IPPROTO_TCP, "", (ip, pinned_port, 0, 0))]
    return [(socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP, "", (ip, pinned_port))]


# Install the pin-aware resolver process-wide.
socket.getaddrinfo = _getaddrinfo_with_pins
@contextmanager
def pin_dns(hostname: str, ip: str):
    """Pin *hostname* → *ip* in the current thread for the duration of this context.

    Use for synchronous ``requests.get()`` callers to prevent DNS rebinding
    between SSRF validation and the actual TCP connection.

    Contexts may be nested for the same hostname: on exit the pin that was
    active when the context was entered is restored, so leaving an inner
    context never strips the protection of an enclosing one.
    """
    pins = _tl.__dict__.setdefault("dns_pins", {})
    previous = pins.get(hostname)
    pins[hostname] = ip
    try:
        yield
    finally:
        if previous is None:
            # No enclosing pin for this host: remove ours entirely.
            pins.pop(hostname, None)
        else:
            # Hand the hostname back to the enclosing context's pin.
            pins[hostname] = previous
# Number of pin_dns_global() contexts currently active per hostname.
# Guarded by ``_global_pin_lock`` together with ``_global_dns_pins``.
_global_pin_refcounts: dict[str, int] = {}


@contextmanager
def pin_dns_global(hostname: str, ip: str):
    """Pin *hostname* → *ip* across all threads for the duration of this context.

    Use for async callers (e.g. asyncio-based crawlers) where DNS resolution
    may happen in thread-pool executor threads rather than the calling thread.

    Several contexts for the same hostname may overlap (nested, or running
    concurrently in different threads).  The pin is reference-counted and is
    only removed when the last of them exits, so one caller finishing cannot
    unpin a host another caller is still fetching.  While contexts overlap,
    the most recently entered one decides the pinned address.
    """
    with _global_pin_lock:
        _global_dns_pins[hostname] = ip
        _global_pin_refcounts[hostname] = _global_pin_refcounts.get(hostname, 0) + 1
    try:
        yield
    finally:
        with _global_pin_lock:
            remaining = _global_pin_refcounts.get(hostname, 0) - 1
            if remaining > 0:
                # Another context still relies on this pin; keep it.
                _global_pin_refcounts[hostname] = remaining
            else:
                _global_pin_refcounts.pop(hostname, None)
                _global_dns_pins.pop(hostname, None)
# Schemes assert_url_is_safe() accepts when the caller passes no ``allowed_schemes``.
_DEFAULT_ALLOWED_SCHEMES: frozenset[str] = frozenset({"http", "https"})
def _effective_ip(
ip: ipaddress.IPv4Address | ipaddress.IPv6Address,
) -> ipaddress.IPv4Address | ipaddress.IPv6Address:
"""Return the IPv4 equivalent for IPv4-mapped IPv6 addresses, unchanged otherwise.
Without this normalization ``::ffff:127.0.0.1`` would pass ``is_global``
as an IPv6Address in some Python versions, bypassing the loopback check.
"""
if isinstance(ip, ipaddress.IPv6Address):
mapped = ip.ipv4_mapped
if mapped is not None:
return mapped
return ip
def assert_url_is_safe(
    url: str,
    *,
    allowed_schemes: frozenset[str] = _DEFAULT_ALLOWED_SCHEMES,
) -> tuple[str, str]:
    """Raise ``ValueError`` if *url* is not safe to fetch (SSRF guard).

    Checks performed in order:

    1. Scheme is in *allowed_schemes*.
    2. Hostname is present.
    3. **Every** address returned by ``getaddrinfo`` is globally routable
       (``ip.is_global``) and is not multicast.  ``is_global`` rejects the
       private, loopback, link-local, and reserved special-purpose ranges,
       but it is true for multicast ranges (``224.0.0.0/4``, ``ff00::/8``),
       so those are rejected with an explicit ``is_multicast`` check.
       IPv4-mapped IPv6 addresses (e.g. ``::ffff:127.0.0.1``) are normalised
       to their IPv4 form via :func:`_effective_ip` before the check.

    Args:
        url: The URL to validate.
        allowed_schemes: Schemes that may be fetched; defaults to http/https.

    Returns:
        ``(hostname, resolved_ip)`` — the parsed hostname and the first
        validated address as a string — so the caller can **pin** that
        address in its HTTP client and prevent DNS-rebinding attacks.

    Raises:
        ValueError: if the scheme is not allowed, the host is missing, the
            hostname cannot be resolved, resolves to no addresses, or any
            resolved address is non-public or multicast.
    """
    parsed = urlparse(url)

    scheme = parsed.scheme
    if scheme not in allowed_schemes:
        logger.warning(
            "SSRF guard blocked URL with disallowed scheme: scheme=%r url=%r",
            scheme,
            url,
        )
        raise ValueError(f"Disallowed URL scheme: {scheme!r}. Only {sorted(allowed_schemes)} are allowed.")

    hostname = parsed.hostname
    if not hostname:
        logger.warning("SSRF guard blocked URL with missing host: url=%r", url)
        raise ValueError("URL is missing a host.")

    try:
        addr_infos = socket.getaddrinfo(hostname, None)
    except socket.gaierror as exc:
        logger.warning("SSRF guard could not resolve hostname=%r reason=%s", hostname, exc)
        raise ValueError(f"Could not resolve hostname {hostname!r}: {exc}") from exc

    resolved_ip: str | None = None
    for _family, _type, _proto, _canonname, sockaddr in addr_infos:
        raw_ip = ipaddress.ip_address(sockaddr[0])
        eff_ip = _effective_ip(raw_ip)
        # One bad record is enough to block: the HTTP client may connect to
        # any of the returned addresses.
        if not eff_ip.is_global or eff_ip.is_multicast:
            logger.warning(
                "SSRF guard blocked URL: hostname=%r resolved to non-public address=%s",
                hostname,
                raw_ip,
            )
            raise ValueError(f"URL resolves to a non-public address ({raw_ip}), which is not allowed.")
        if resolved_ip is None:
            # Remember the first validated address; it becomes the DNS pin.
            resolved_ip = str(raw_ip)

    if resolved_ip is None:
        logger.warning("SSRF guard blocked URL: hostname=%r resolved to no addresses", hostname)
        raise ValueError(f"Hostname {hostname!r} resolved to no addresses.")

    return hostname, resolved_ip

View File

@@ -79,6 +79,7 @@ def _load_document_app_module(monkeypatch):
@pytest.mark.p2
def test_upload_info_rejects_mixed_inputs(monkeypatch):
module = _load_document_app_module(monkeypatch)
monkeypatch.setattr(module, "assert_url_is_safe", lambda url: ("example.com", "93.184.216.34"))
files = _DummyFiles({"file": [_DummyFile("a.txt")]})
monkeypatch.setattr(module, "request", _DummyRequest(files=files, args={"url": "https://example.com/a.txt"}))
@@ -100,6 +101,7 @@ def test_upload_info_requires_file_or_url(monkeypatch):
@pytest.mark.p2
def test_upload_info_supports_url_single_and_multiple_files(monkeypatch):
module = _load_document_app_module(monkeypatch)
monkeypatch.setattr(module, "assert_url_is_safe", lambda url: ("example.com", "93.184.216.34"))
captured = []
def fake_upload_info(user_id, file_obj, url=None):

View File

@@ -14,6 +14,7 @@
# limitations under the License.
#
import importlib.util
import socket
import sys
import types
import warnings
@@ -120,3 +121,158 @@ def test_upload_document_skips_cross_kb_document_id_collision(monkeypatch):
assert len(err) == 1
assert err[0].startswith("collision.txt: ")
assert "Existing document id collision with another knowledge base; skipping update." in err[0]
# ---------------------------------------------------------------------------
# Helpers shared by TestValidateUrlForCrawl
# ---------------------------------------------------------------------------
def _addrinfo(ip_str: str) -> list:
    """Return a one-record, getaddrinfo-shaped list describing *ip_str*.

    The address family is chosen by whether the string contains a colon;
    the record is a TCP stream entry with port 0.
    """
    looks_like_v6 = ":" in ip_str
    record = (
        socket.AF_INET6 if looks_like_v6 else socket.AF_INET,
        socket.SOCK_STREAM,
        6,
        "",
        (ip_str, 0),
    )
    return [record]
# ---------------------------------------------------------------------------
# _validate_url_for_crawl SSRF-guard tests
# ---------------------------------------------------------------------------
@pytest.mark.p2
class TestValidateUrlForCrawl:
    """Regression tests for ``FileService._validate_url_for_crawl``.

    Tests that exercise address checks replace ``socket.getaddrinfo`` with a
    canned answer, so their outcome does not depend on real DNS.
    """

    @staticmethod
    def _fake_dns(monkeypatch, *addresses):
        """Make ``socket.getaddrinfo`` answer with one record per address given."""
        records = [record for address in addresses for record in _addrinfo(address)]
        monkeypatch.setattr(socket, "getaddrinfo", lambda h, p: records)

    @staticmethod
    def _assert_blocked(url, pattern):
        """Expect the guard to raise ``ValueError`` matching *pattern* for *url*."""
        with pytest.raises(ValueError, match=pattern):
            FileService._validate_url_for_crawl(url)

    # -- scheme checks -------------------------------------------------------

    def test_rejects_ftp_scheme(self):
        self._assert_blocked("ftp://example.com/file.txt", "scheme")

    def test_rejects_file_scheme(self):
        self._assert_blocked("file:///etc/passwd", "scheme")

    def test_rejects_javascript_scheme(self):
        self._assert_blocked("javascript:alert(1)", "scheme")

    # -- host checks ---------------------------------------------------------

    def test_rejects_missing_host(self):
        self._assert_blocked("http:///path", "host")

    def test_rejects_dns_resolution_failure(self, monkeypatch):
        def _fail_lookup(h, p):
            raise socket.gaierror("NXDOMAIN")

        monkeypatch.setattr(socket, "getaddrinfo", _fail_lookup)
        self._assert_blocked("http://nxdomain.invalid/", "Could not resolve")

    # -- blocked address families --------------------------------------------

    def test_rejects_loopback_ipv4(self, monkeypatch):
        self._fake_dns(monkeypatch, "127.0.0.1")
        self._assert_blocked("http://localhost/", "non-public")

    def test_rejects_private_class_a(self, monkeypatch):
        self._fake_dns(monkeypatch, "10.0.0.1")
        self._assert_blocked("http://internal.example/", "non-public")

    def test_rejects_private_class_b(self, monkeypatch):
        self._fake_dns(monkeypatch, "172.16.0.1")
        self._assert_blocked("http://internal.example/", "non-public")

    def test_rejects_private_class_c(self, monkeypatch):
        self._fake_dns(monkeypatch, "192.168.1.100")
        self._assert_blocked("http://internal.example/", "non-public")

    def test_rejects_link_local_ipv4(self, monkeypatch):
        self._fake_dns(monkeypatch, "169.254.0.1")
        self._assert_blocked("http://link-local.example/", "non-public")

    def test_rejects_reserved_ipv4(self, monkeypatch):
        # 240.0.0.0/4 is IANA reserved — not globally routable
        self._fake_dns(monkeypatch, "240.0.0.1")
        self._assert_blocked("http://reserved.example/", "non-public")

    def test_rejects_ipv4_mapped_loopback(self, monkeypatch):
        """::ffff:127.0.0.1 must not bypass the loopback check."""
        self._fake_dns(monkeypatch, "::ffff:127.0.0.1")
        self._assert_blocked("http://mapped-loopback.example/", "non-public")

    def test_rejects_ipv4_mapped_private(self, monkeypatch):
        """::ffff:192.168.1.1 must not bypass the private-range check."""
        self._fake_dns(monkeypatch, "::ffff:192.168.1.1")
        self._assert_blocked("http://mapped-private.example/", "non-public")

    def test_rejects_when_any_record_is_private(self, monkeypatch):
        """All DNS records must pass; one private record is enough to block."""
        self._fake_dns(monkeypatch, "93.184.216.34", "10.0.0.1")
        self._assert_blocked("http://mixed.example/", "non-public")

    # -- allowed cases -------------------------------------------------------

    def test_allows_public_ipv4(self, monkeypatch):
        self._fake_dns(monkeypatch, "93.184.216.34")
        host, pinned = FileService._validate_url_for_crawl("https://example.com/doc.pdf")
        assert host == "example.com"
        assert pinned == "93.184.216.34"

    def test_allows_public_ipv6(self, monkeypatch):
        self._fake_dns(monkeypatch, "2606:2800:220:1:248:1893:25c8:1946")
        host, pinned = FileService._validate_url_for_crawl("https://example.com/")
        assert host == "example.com"
        assert pinned == "2606:2800:220:1:248:1893:25c8:1946"

    def test_allows_http_scheme(self, monkeypatch):
        self._fake_dns(monkeypatch, "1.2.3.4")
        host, _ = FileService._validate_url_for_crawl("http://example.com/")
        assert host == "example.com"

    # -- multi-record behaviour ----------------------------------------------

    def test_returns_first_ip_for_multi_record_host(self, monkeypatch):
        """The first public IP is returned as the DNS pin value."""
        self._fake_dns(monkeypatch, "1.2.3.4", "5.6.7.8")
        _, pinned = FileService._validate_url_for_crawl("http://multi.example/")
        assert pinned == "1.2.3.4"

    def test_allows_dual_stack_host(self, monkeypatch):
        """A host with both public IPv4 and public IPv6 records is allowed."""
        self._fake_dns(monkeypatch, "93.184.216.34", "2606:2800:220:1:248:1893:25c8:1946")
        host, pinned = FileService._validate_url_for_crawl("https://example.com/")
        assert host == "example.com"
        assert pinned == "93.184.216.34"