mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 15:31:05 +08:00
226 lines
8.2 KiB
Python
226 lines
8.2 KiB
Python
#
|
|
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
"""Shared SSRF-guard utilities.
|
|
|
|
Uses only the standard library so it can be imported from both ``api/`` and
|
|
``common/`` without pulling in any heavyweight dependencies.
|
|
"""
|
|
|
|
import ipaddress
|
|
import logging
|
|
import os
|
|
import socket
|
|
import threading
|
|
from contextlib import contextmanager
|
|
from urllib.parse import urlparse
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# ---------------------------------------------------------------------------
# DNS pinning — closes the TOCTOU / rebinding window between SSRF validation
# and the actual TCP connection.  The monkey-patch is a no-op for any host
# that has no active pin, so it cannot affect unrelated code.
# ---------------------------------------------------------------------------
|
|
|
|
# Per-thread pins live in ``_tl.dns_pins`` as a ``{hostname: ip}`` dict; the
# attribute is created lazily, so a thread that never pinned has none.
_tl = threading.local()
# Process-wide ``{hostname: ip}`` pins; every access goes through
# ``_global_pin_lock``.
_global_dns_pins: dict[str, str] = {}
_global_pin_lock = threading.Lock()
# Resolver that was in place before this module patched ``socket``.
_orig_getaddrinfo = socket.getaddrinfo


def _getaddrinfo_with_pins(host, port, *args, **kwargs):
    """Replacement for ``socket.getaddrinfo`` that honours active DNS pins.

    Lookup order for *host*:

    1. the calling thread's pins (``_tl.dns_pins``);
    2. the process-wide pins (``_global_dns_pins``);
    3. no pin: the call is forwarded unchanged to the original resolver.

    When a pin exists, the pinned IP is handed to the original resolver in
    place of *host*, with *port* and every extra argument passed through.
    The result therefore has the shape callers expect from ``getaddrinfo``:
    the requested family / socket type / protocol are respected, service-name
    ports are converted, and IPv6 entries carry the full
    ``(address, port, flowinfo, scope_id)`` sockaddr.

    Pinned values are expected to be IP literals; anything else would be
    resolved by the original resolver like an ordinary hostname.

    Raises whatever the original resolver raises (e.g. ``socket.gaierror``
    when the pinned address does not match the requested family).
    """
    # Thread-local pins win over process-wide ones.
    pinned_ip = getattr(_tl, "dns_pins", {}).get(host)
    if pinned_ip is None:
        # Hold the lock only for the dictionary read, never during resolution.
        with _global_pin_lock:
            pinned_ip = _global_dns_pins.get(host)

    target = host if pinned_ip is None else pinned_ip
    return _orig_getaddrinfo(target, port, *args, **kwargs)


# Installed at import time; hosts without an active pin fall straight through
# to the original resolver.
socket.getaddrinfo = _getaddrinfo_with_pins
|
|
|
|
|
|
@contextmanager
def pin_dns(hostname: str, ip: str):
    """Pin *hostname* → *ip* in the current thread for the duration of this context.

    Use for synchronous ``requests.get()`` callers to prevent DNS rebinding
    between SSRF validation and the actual TCP connection.

    Contexts may be nested for the same *hostname*: on exit the pin that was
    active when the context was entered is put back, so leaving an inner
    context never strips the protection of an enclosing one.  If there was no
    earlier pin, the entry is removed.
    """
    # The mapping is stored on the thread-local object and created on first use.
    pins = _tl.__dict__.setdefault("dns_pins", {})
    # Remember what (if anything) was pinned before so it can be restored.
    had_previous = hostname in pins
    previous_ip = pins.get(hostname)
    pins[hostname] = ip
    try:
        yield
    finally:
        if had_previous:
            pins[hostname] = previous_ip
        else:
            pins.pop(hostname, None)
|
|
|
|
|
|
# Number of live ``pin_dns_global`` contexts per hostname; guarded by
# ``_global_pin_lock``.
_global_pin_refcounts: dict[str, int] = {}


@contextmanager
def pin_dns_global(hostname: str, ip: str):
    """Pin *hostname* → *ip* across all threads for the duration of this context.

    Use for async callers (e.g. asyncio-based crawlers) where DNS resolution
    may happen in thread-pool executor threads rather than the calling thread.

    Several contexts for the same *hostname* may be open at once and may exit
    in any order.  The pin is reference-counted and is only removed when the
    last of them exits, so one finished request cannot unpin a host that
    another in-flight request still relies on.  While contexts overlap, the
    most recently entered *ip* is the one in effect.
    """
    with _global_pin_lock:
        _global_pin_refcounts[hostname] = _global_pin_refcounts.get(hostname, 0) + 1
        _global_dns_pins[hostname] = ip
    try:
        yield
    finally:
        with _global_pin_lock:
            remaining = _global_pin_refcounts.get(hostname, 0) - 1
            if remaining > 0:
                _global_pin_refcounts[hostname] = remaining
            else:
                # Last holder gone: drop both the counter and the pin itself.
                _global_pin_refcounts.pop(hostname, None)
                _global_dns_pins.pop(hostname, None)
|
|
|
|
|
|
# URL schemes accepted when a caller does not supply its own set.
_DEFAULT_ALLOWED_SCHEMES: frozenset[str] = frozenset(("http", "https"))
# Environment variable consulted by ``_allow_any_host``.
_ALLOW_ANY_HOST_ENV = "ALLOW_ANY_HOST"


def _allow_any_host() -> bool:
    """Tell whether the ``ALLOW_ANY_HOST`` environment switch is turned on.

    The variable is read on every call.  Surrounding whitespace and letter
    case are ignored; ``1``, ``true``, ``yes`` and ``on`` count as enabled,
    anything else (including an unset variable) as disabled.
    """
    raw_setting = os.environ.get(_ALLOW_ANY_HOST_ENV, "")
    normalized = raw_setting.strip().lower()
    return normalized in ("1", "true", "yes", "on")
|
|
|
|
|
|
def _effective_ip(
|
|
ip: ipaddress.IPv4Address | ipaddress.IPv6Address,
|
|
) -> ipaddress.IPv4Address | ipaddress.IPv6Address:
|
|
"""Return the IPv4 equivalent for IPv4-mapped IPv6 addresses, unchanged otherwise.
|
|
|
|
Without this normalization ``::ffff:127.0.0.1`` would pass ``is_global``
|
|
as an IPv6Address in some Python versions, bypassing the loopback check.
|
|
"""
|
|
if isinstance(ip, ipaddress.IPv6Address):
|
|
mapped = ip.ipv4_mapped
|
|
if mapped is not None:
|
|
return mapped
|
|
return ip
|
|
|
|
|
|
def assert_url_is_safe(
    url: str,
    *,
    allowed_schemes: frozenset[str] = _DEFAULT_ALLOWED_SCHEMES,
) -> tuple[str, str]:
    """Validate *url* against the SSRF policy, raising ``ValueError`` on any violation.

    A URL is rejected when, checked in this order:

    * its scheme is not a member of *allowed_schemes*;
    * it has no hostname;
    * the hostname cannot be resolved;
    * **any** address returned by ``getaddrinfo`` is not globally routable
      (``is_global`` is false).  IPv4-mapped IPv6 addresses are first
      converted to IPv4 by :func:`_effective_ip`, so ``::ffff:127.0.0.1``
      is judged as ``127.0.0.1``;
    * resolution succeeded but produced no addresses.

    Every rejection is also logged at WARNING level.

    Returns ``(hostname, resolved_ip)`` where *resolved_ip* is the string
    form of the first resolved address.  Callers can pin that address in
    their HTTP client so the hostname is not resolved a second time, which
    is what defeats DNS rebinding.
    """
    parts = urlparse(url)

    scheme = parts.scheme
    if scheme not in allowed_schemes:
        logger.warning(
            "SSRF guard blocked URL with disallowed scheme: scheme=%r url=%r",
            scheme,
            url,
        )
        raise ValueError(f"Disallowed URL scheme: {scheme!r}. Only {sorted(allowed_schemes)} are allowed.")

    hostname = parts.hostname
    if not hostname:
        logger.warning("SSRF guard blocked URL with missing host: url=%r", url)
        raise ValueError("URL is missing a host.")

    try:
        resolved = socket.getaddrinfo(hostname, None)
    except socket.gaierror as exc:
        logger.warning("SSRF guard could not resolve hostname=%r reason=%s", hostname, exc)
        raise ValueError(f"Could not resolve hostname {hostname!r}: {exc}") from exc

    # Addresses are checked lazily, one by one, so the first offender aborts
    # the walk; the survivors are kept in resolver order.
    validated: list[str] = []
    for _, _, _, _, endpoint in resolved:
        raw_ip = ipaddress.ip_address(endpoint[0])
        if _effective_ip(raw_ip).is_global:
            validated.append(str(raw_ip))
            continue
        logger.warning(
            "SSRF guard blocked URL: hostname=%r resolved to non-public address=%s",
            hostname,
            raw_ip,
        )
        raise ValueError(f"URL resolves to a non-public address ({raw_ip}), which is not allowed.")

    if not validated:
        logger.warning("SSRF guard blocked URL: hostname=%r resolved to no addresses", hostname)
        raise ValueError(f"Hostname {hostname!r} resolved to no addresses.")

    return hostname, validated[0]
|
|
|
|
|
|
def assert_host_is_safe(host: str) -> str:
    """Validate a bare *host* against the SSRF policy, raising ``ValueError`` on violation.

    Host-level sibling of :func:`assert_url_is_safe` for callers that open
    raw host/port connections (database drivers and other non-HTTP
    protocols) and therefore have no URL to parse.

    *host* is stripped of surrounding whitespace and must not be empty.
    It is then resolved, and **every** resulting address must be globally
    routable (``is_global``) once IPv4-mapped IPv6 addresses have been
    converted by :func:`_effective_ip`.

    Returns the string form of the first resolved address so the caller can
    pin it if needed.

    When the ``ALLOW_ANY_HOST`` bypass reported by :func:`_allow_any_host`
    is active, no resolution or validation takes place: a warning is logged
    and the stripped *host* itself is returned.
    """
    host = host.strip()
    if not host:
        raise ValueError("Host must not be empty.")

    if _allow_any_host():
        logger.warning(
            "SSRF guard bypass enabled via %s; allowing host without validation: host=%r",
            _ALLOW_ANY_HOST_ENV,
            host,
        )
        return host

    try:
        resolved = socket.getaddrinfo(host, None)
    except socket.gaierror as exc:
        logger.warning("SSRF guard could not resolve host=%r reason=%s", host, exc)
        raise ValueError(f"Could not resolve host {host!r}: {exc}") from exc

    # Addresses are checked lazily, one by one, so the first offender aborts
    # the walk; the survivors are kept in resolver order.
    validated: list[str] = []
    for _, _, _, _, endpoint in resolved:
        raw_ip = ipaddress.ip_address(endpoint[0])
        if _effective_ip(raw_ip).is_global:
            validated.append(str(raw_ip))
            continue
        logger.warning(
            "SSRF guard blocked host: host=%r resolved to non-public address=%s",
            host,
            raw_ip,
        )
        raise ValueError(f"Host resolves to a non-public address ({raw_ip}), which is not allowed.")

    if not validated:
        logger.warning("SSRF guard blocked host: host=%r resolved to no addresses", host)
        raise ValueError(f"Host {host!r} resolved to no addresses.")

    return validated[0]
|