diff --git a/common/misc_utils.py b/common/misc_utils.py index 1826be77f3..9225fcd25d 100644 --- a/common/misc_utils.py +++ b/common/misc_utils.py @@ -24,22 +24,146 @@ import subprocess import sys import threading import uuid +from urllib.parse import urljoin from concurrent.futures import ThreadPoolExecutor +logger = logging.getLogger(__name__) + def get_uuid(): return uuid.uuid1().hex +# OAuth avatar fetch: bounded size; each redirect hop is SSRF-checked and DNS-pinned +# (see common.ssrf_guard). +_OAUTH_AVATAR_MAX_BYTES = int(os.environ.get("RAGFLOW_OAUTH_AVATAR_MAX_BYTES", str(5 * 1024 * 1024))) +_OAUTH_AVATAR_MAX_REDIRECTS = int(os.environ.get("RAGFLOW_OAUTH_AVATAR_MAX_REDIRECTS", "5")) +_REDIRECT_STATUS = frozenset({301, 302, 303, 307, 308}) + + async def download_img(url): + """Fetch an image URL and return a data URI, or empty string on failure / SSRF block. + + URLs must resolve only to globally routable addresses; redirects are followed + only up to ``_OAUTH_AVATAR_MAX_REDIRECTS`` with each target validated. + """ if not url: return "" - from common.http_client import async_request - response = await async_request("GET", url) - return "data:" + \ - response.headers.get('Content-Type', 'image/jpg') + ";" + \ - "base64," + base64.b64encode(response.content).decode("utf-8") + if not isinstance(url, str): + url = str(url) + url = url.strip() + if not url: + return "" + + current_url = url + redirect_hops = 0 + + # Match common/http_client.py defaults without importing http_client (avoids + # pulling settings and keeps this path usable in lightweight test envs). + request_timeout = float(os.environ.get("HTTP_CLIENT_TIMEOUT", "15")) + proxy = os.environ.get("HTTP_CLIENT_PROXY") + user_agent = os.environ.get("HTTP_CLIENT_USER_AGENT", "ragflow-http-client") + + from common.ssrf_guard import assert_url_is_safe, pin_dns_global + + while redirect_hops <= _OAUTH_AVATAR_MAX_REDIRECTS: + try: + hostname, pin_ip = assert_url_is_safe(current_url) + except ValueError as exc: + logger.warning("download_img rejected URL (SSRF guard): %s", exc) + return "" + + import httpx + + timeout = httpx.Timeout(request_timeout) + headers = {} + if user_agent: + headers["User-Agent"] = user_agent + + async def _stream_one_get() -> tuple[str, str | None]: + """Return ``('redirect', new_url)``, ``('data', data_uri)``, or ``('fail', None)``.""" + with pin_dns_global(hostname, pin_ip): + async with httpx.AsyncClient( + timeout=timeout, + follow_redirects=False, + proxy=proxy, + ) as client: + async with client.stream("GET", current_url, headers=headers or None) as response: + if response.status_code in _REDIRECT_STATUS: + await response.aclose() + location = response.headers.get("location") + if not location: + logger.warning( + "download_img redirect missing Location header: url=%r status=%s redirect_hops=%s", + current_url, + response.status_code, + redirect_hops, + ) + return ("fail", None) + return ("redirect", urljoin(current_url, location)) + if response.status_code != 200: + logger.warning( + "download_img non-200 response: url=%r status=%s redirect_hops=%s", + current_url, + response.status_code, + redirect_hops, + ) + return ("fail", None) + body = bytearray() + async for chunk in response.aiter_bytes(): + if len(body) + len(chunk) > _OAUTH_AVATAR_MAX_BYTES: + logger.warning( + "download_img response exceeded max size: url=%r max_bytes=%s", + current_url, + _OAUTH_AVATAR_MAX_BYTES, + ) + await response.aclose() + return ("fail", None) + body.extend(chunk) + content_type = response.headers.get("Content-Type", "image/jpeg") + data_uri = ( + "data:" + + content_type + + ";base64," + + base64.b64encode(bytes(body)).decode("utf-8") + ) + return ("data", data_uri) + + try: + kind, payload = await asyncio.wait_for(_stream_one_get(), timeout=request_timeout) + except asyncio.TimeoutError: + logger.warning( + "download_img total wall-clock timeout: url=%r redirect_hops=%s timeout=%s", + current_url, + redirect_hops, + request_timeout, + ) + return "" + except Exception as exc: + logger.warning( + "download_img request failed: url=%r redirect_hops=%s err=%s", + current_url, + redirect_hops, + exc, + ) + return "" + + if kind == "redirect": + current_url = str(payload) + redirect_hops += 1 + continue + if kind == "fail": + return "" + return str(payload) + + logger.warning( + "download_img redirect hop limit exceeded: url=%r redirect_hops=%s max_redirects=%s", + current_url, + redirect_hops, + _OAUTH_AVATAR_MAX_REDIRECTS, + ) + return "" def hash_str2int(line: str, mod: int = 10 ** 8) -> int: diff --git a/test/unit_test/common/test_misc_utils.py b/test/unit_test/common/test_misc_utils.py index 82c8f97657..6ca24f1bbd 100644 --- a/test/unit_test/common/test_misc_utils.py +++ b/test/unit_test/common/test_misc_utils.py @@ -13,10 +13,96 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import uuid +import asyncio import hashlib -import pytest -from common.misc_utils import get_uuid, download_img, hash_str2int, convert_bytes +import sys +import types +import uuid +from contextlib import contextmanager +from unittest.mock import patch + +from common import ssrf_guard +from common.misc_utils import convert_bytes, download_img, get_uuid, hash_str2int + + +class _Hdr: + def __init__(self, mapping: dict[str, str]): + self._m = {k.lower(): v for k, v in mapping.items()} + + def get(self, key: str, default=None): + return self._m.get(key.lower(), default) + + +class _MockStreamResp: + def __init__(self, status_code: int, *, location: str | None = None, body: bytes = b""): + self.status_code = status_code + hdrs: dict[str, str] = {} + if location is not None: + hdrs["Location"] = location + if body: + hdrs.setdefault("Content-Type", "image/jpeg") + self.headers = _Hdr(hdrs) + self._body = body + + async def aclose(self): + return None + + async def aiter_bytes(self): + if self._body: + yield self._body + + +class _FakeStreamCtx: + def __init__(self, resp: _MockStreamResp): + self._resp = resp + + async def __aenter__(self): + return self._resp + + async def __aexit__(self, exc_type, exc, tb): + return None + + +class _FakeAsyncClient: + def __init__(self, responses: list[_MockStreamResp]): + self._responses = responses + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return None + + def stream(self, method, url, headers=None): + if not self._responses: + return _FakeStreamCtx(_MockStreamResp(404)) + return _FakeStreamCtx(self._responses.pop(0)) + + +@contextmanager +def _fake_httpx_sys_modules(client): + """Minimal ``httpx`` stub so ``download_img`` can be exercised without real httpx.""" + saved = sys.modules.get("httpx") + fake = types.ModuleType("httpx") + + class _Timeout: + def __init__(self, *_a, **_kw): + pass + + fake.Timeout = _Timeout + + def _AsyncClient(*_a, **_kw): + return client + + fake.AsyncClient = _AsyncClient + sys.modules["httpx"] = fake + try: + yield + finally: + if saved is not None: + sys.modules["httpx"] = saved + else: + sys.modules.pop("httpx", None) class TestGetUuid: @@ -92,16 +178,68 @@ class TestGetUuid: class TestDownloadImg: """Test cases for download_img function""" - @pytest.mark.asyncio - async def test_empty_url_returns_empty_string(self): + def test_empty_url_returns_empty_string(self): """Test that empty URL returns empty string""" - result = await download_img("") + result = asyncio.run(download_img("")) assert result == "" - @pytest.mark.asyncio - async def test_none_url_returns_empty_string(self): + def test_none_url_returns_empty_string(self): """Test that None URL returns empty string""" - result = await download_img(None) + result = asyncio.run(download_img(None)) + assert result == "" + + def test_loopback_url_blocked(self): + """OAuth avatar fetch must not call loopback (SSRF regression).""" + result = asyncio.run(download_img("http://127.0.0.1/avatar.png")) + assert result == "" + + def test_metadata_ip_blocked(self): + """Link-local / cloud metadata ranges are non-global and must be rejected.""" + result = asyncio.run(download_img("http://169.254.169.254/latest/meta-data/")) + assert result == "" + + def test_disallowed_scheme_blocked(self): + result = asyncio.run(download_img("file:///etc/passwd")) + assert result == "" + + def test_redirect_to_loopback_blocked(self): + """Redirect from an allowed host to loopback must be rejected (SSRF).""" + from urllib.parse import urlparse + + real_assert = ssrf_guard.assert_url_is_safe + + def selective_assert(url: str, **kwargs): + host = urlparse(url).hostname or "" + if host in ("127.0.0.1", "localhost", "169.254.169.254"): + return real_assert(url, **kwargs) + return ("public-avatar.test", "8.8.8.8") + + client = _FakeAsyncClient([_MockStreamResp(302, location="http://127.0.0.1/next")]) + + with ( + patch.object(ssrf_guard, "assert_url_is_safe", side_effect=selective_assert), + _fake_httpx_sys_modules(client), + ): + result = asyncio.run(download_img("http://public-avatar.test/start.png")) + assert result == "" + + def test_redirect_too_many_hops_blocked(self): + """Excessive redirect chains must return empty without hanging.""" + import common.misc_utils as misc_utils + + hops = [ + _MockStreamResp(302, location="http://h.example/1"), + _MockStreamResp(302, location="http://h.example/2"), + _MockStreamResp(302, location="http://h.example/3"), + ] + client = _FakeAsyncClient(hops) + + with ( + patch.object(misc_utils, "_OAUTH_AVATAR_MAX_REDIRECTS", 2), + patch.object(ssrf_guard, "assert_url_is_safe", return_value=("h.example", "8.8.8.8")), + _fake_httpx_sys_modules(client), + ): + result = asyncio.run(misc_utils.download_img("http://h.example/start")) assert result == ""