mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 23:41:12 +08:00
fix: block SSRF in misc_utils.download_img for OAuth avatars (#14868)
### What problem does this PR solve? Closes #14865 `download_img` in `common/misc_utils.py` is used for OAuth avatar URLs. The previous implementation called `async_request` from `common.http_client`, which followed redirects without re-validating each hop and did not apply the same SSRF protections as this path needs. That made it possible to reach non-public or disallowed targets (for example via redirects or unsafe URLs) when fetching avatars. This change replaces that flow with an explicit, bounded fetch: each URL (including every redirect target) is checked with `common.ssrf_guard.assert_url_is_safe`, DNS is pinned with `pin_dns_global`, `httpx` streams the body with `follow_redirects=False` and a manual redirect loop (capped by `RAGFLOW_OAUTH_AVATAR_MAX_REDIRECTS`), and total response size is capped (`RAGFLOW_OAUTH_AVATAR_MAX_BYTES`). Timeouts, proxy, and user agent align with `HTTP_CLIENT_*` env vars without importing `http_client`, so lightweight tests stay simple. Unit tests cover empty/None URLs, loopback, cloud metadata-style addresses, and disallowed schemes so SSRF regressions are caught early. ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --------- Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
This commit is contained in:
@@ -24,22 +24,146 @@ import subprocess
|
||||
import sys
|
||||
import threading
|
||||
import uuid
|
||||
from urllib.parse import urljoin
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_uuid():
|
||||
return uuid.uuid1().hex
|
||||
|
||||
|
||||
# OAuth avatar fetch: bounded size; each redirect hop is SSRF-checked and DNS-pinned
|
||||
# (see common.ssrf_guard).
|
||||
_OAUTH_AVATAR_MAX_BYTES = int(os.environ.get("RAGFLOW_OAUTH_AVATAR_MAX_BYTES", str(5 * 1024 * 1024)))
|
||||
_OAUTH_AVATAR_MAX_REDIRECTS = int(os.environ.get("RAGFLOW_OAUTH_AVATAR_MAX_REDIRECTS", "5"))
|
||||
_REDIRECT_STATUS = frozenset({301, 302, 303, 307, 308})
|
||||
|
||||
|
||||
async def download_img(url):
|
||||
"""Fetch an image URL and return a data URI, or empty string on failure / SSRF block.
|
||||
|
||||
URLs must resolve only to globally routable addresses; redirects are followed
|
||||
only up to ``_OAUTH_AVATAR_MAX_REDIRECTS`` with each target validated.
|
||||
"""
|
||||
if not url:
|
||||
return ""
|
||||
from common.http_client import async_request
|
||||
response = await async_request("GET", url)
|
||||
return "data:" + \
|
||||
response.headers.get('Content-Type', 'image/jpg') + ";" + \
|
||||
"base64," + base64.b64encode(response.content).decode("utf-8")
|
||||
if not isinstance(url, str):
|
||||
url = str(url)
|
||||
url = url.strip()
|
||||
if not url:
|
||||
return ""
|
||||
|
||||
current_url = url
|
||||
redirect_hops = 0
|
||||
|
||||
# Match common/http_client.py defaults without importing http_client (avoids
|
||||
# pulling settings and keeps this path usable in lightweight test envs).
|
||||
request_timeout = float(os.environ.get("HTTP_CLIENT_TIMEOUT", "15"))
|
||||
proxy = os.environ.get("HTTP_CLIENT_PROXY")
|
||||
user_agent = os.environ.get("HTTP_CLIENT_USER_AGENT", "ragflow-http-client")
|
||||
|
||||
from common.ssrf_guard import assert_url_is_safe, pin_dns_global
|
||||
|
||||
while redirect_hops <= _OAUTH_AVATAR_MAX_REDIRECTS:
|
||||
try:
|
||||
hostname, pin_ip = assert_url_is_safe(current_url)
|
||||
except ValueError as exc:
|
||||
logger.warning("download_img rejected URL (SSRF guard): %s", exc)
|
||||
return ""
|
||||
|
||||
import httpx
|
||||
|
||||
timeout = httpx.Timeout(request_timeout)
|
||||
headers = {}
|
||||
if user_agent:
|
||||
headers["User-Agent"] = user_agent
|
||||
|
||||
async def _stream_one_get() -> tuple[str, str | None]:
|
||||
"""Return ``('redirect', new_url)``, ``('data', data_uri)``, or ``('fail', None)``."""
|
||||
with pin_dns_global(hostname, pin_ip):
|
||||
async with httpx.AsyncClient(
|
||||
timeout=timeout,
|
||||
follow_redirects=False,
|
||||
proxy=proxy,
|
||||
) as client:
|
||||
async with client.stream("GET", current_url, headers=headers or None) as response:
|
||||
if response.status_code in _REDIRECT_STATUS:
|
||||
await response.aclose()
|
||||
location = response.headers.get("location")
|
||||
if not location:
|
||||
logger.warning(
|
||||
"download_img redirect missing Location header: url=%r status=%s redirect_hops=%s",
|
||||
current_url,
|
||||
response.status_code,
|
||||
redirect_hops,
|
||||
)
|
||||
return ("fail", None)
|
||||
return ("redirect", urljoin(current_url, location))
|
||||
if response.status_code != 200:
|
||||
logger.warning(
|
||||
"download_img non-200 response: url=%r status=%s redirect_hops=%s",
|
||||
current_url,
|
||||
response.status_code,
|
||||
redirect_hops,
|
||||
)
|
||||
return ("fail", None)
|
||||
body = bytearray()
|
||||
async for chunk in response.aiter_bytes():
|
||||
if len(body) + len(chunk) > _OAUTH_AVATAR_MAX_BYTES:
|
||||
logger.warning(
|
||||
"download_img response exceeded max size: url=%r max_bytes=%s",
|
||||
current_url,
|
||||
_OAUTH_AVATAR_MAX_BYTES,
|
||||
)
|
||||
await response.aclose()
|
||||
return ("fail", None)
|
||||
body.extend(chunk)
|
||||
content_type = response.headers.get("Content-Type", "image/jpeg")
|
||||
data_uri = (
|
||||
"data:"
|
||||
+ content_type
|
||||
+ ";base64,"
|
||||
+ base64.b64encode(bytes(body)).decode("utf-8")
|
||||
)
|
||||
return ("data", data_uri)
|
||||
|
||||
try:
|
||||
kind, payload = await asyncio.wait_for(_stream_one_get(), timeout=request_timeout)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(
|
||||
"download_img total wall-clock timeout: url=%r redirect_hops=%s timeout=%s",
|
||||
current_url,
|
||||
redirect_hops,
|
||||
request_timeout,
|
||||
)
|
||||
return ""
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"download_img request failed: url=%r redirect_hops=%s err=%s",
|
||||
current_url,
|
||||
redirect_hops,
|
||||
exc,
|
||||
)
|
||||
return ""
|
||||
|
||||
if kind == "redirect":
|
||||
current_url = str(payload)
|
||||
redirect_hops += 1
|
||||
continue
|
||||
if kind == "fail":
|
||||
return ""
|
||||
return str(payload)
|
||||
|
||||
logger.warning(
|
||||
"download_img redirect hop limit exceeded: url=%r redirect_hops=%s max_redirects=%s",
|
||||
current_url,
|
||||
redirect_hops,
|
||||
_OAUTH_AVATAR_MAX_REDIRECTS,
|
||||
)
|
||||
return ""
|
||||
|
||||
|
||||
def hash_str2int(line: str, mod: int = 10 ** 8) -> int:
|
||||
|
||||
@@ -13,10 +13,96 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import uuid
|
||||
import asyncio
|
||||
import hashlib
|
||||
import pytest
|
||||
from common.misc_utils import get_uuid, download_img, hash_str2int, convert_bytes
|
||||
import sys
|
||||
import types
|
||||
import uuid
|
||||
from contextlib import contextmanager
|
||||
from unittest.mock import patch
|
||||
|
||||
from common import ssrf_guard
|
||||
from common.misc_utils import convert_bytes, download_img, get_uuid, hash_str2int
|
||||
|
||||
|
||||
class _Hdr:
|
||||
def __init__(self, mapping: dict[str, str]):
|
||||
self._m = {k.lower(): v for k, v in mapping.items()}
|
||||
|
||||
def get(self, key: str, default=None):
|
||||
return self._m.get(key.lower(), default)
|
||||
|
||||
|
||||
class _MockStreamResp:
|
||||
def __init__(self, status_code: int, *, location: str | None = None, body: bytes = b""):
|
||||
self.status_code = status_code
|
||||
hdrs: dict[str, str] = {}
|
||||
if location is not None:
|
||||
hdrs["Location"] = location
|
||||
if body:
|
||||
hdrs.setdefault("Content-Type", "image/jpeg")
|
||||
self.headers = _Hdr(hdrs)
|
||||
self._body = body
|
||||
|
||||
async def aclose(self):
|
||||
return None
|
||||
|
||||
async def aiter_bytes(self):
|
||||
if self._body:
|
||||
yield self._body
|
||||
|
||||
|
||||
class _FakeStreamCtx:
|
||||
def __init__(self, resp: _MockStreamResp):
|
||||
self._resp = resp
|
||||
|
||||
async def __aenter__(self):
|
||||
return self._resp
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return None
|
||||
|
||||
|
||||
class _FakeAsyncClient:
|
||||
def __init__(self, responses: list[_MockStreamResp]):
|
||||
self._responses = responses
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return None
|
||||
|
||||
def stream(self, method, url, headers=None):
|
||||
if not self._responses:
|
||||
return _FakeStreamCtx(_MockStreamResp(404))
|
||||
return _FakeStreamCtx(self._responses.pop(0))
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _fake_httpx_sys_modules(client):
|
||||
"""Minimal ``httpx`` stub so ``download_img`` can be exercised without real httpx."""
|
||||
saved = sys.modules.get("httpx")
|
||||
fake = types.ModuleType("httpx")
|
||||
|
||||
class _Timeout:
|
||||
def __init__(self, *_a, **_kw):
|
||||
pass
|
||||
|
||||
fake.Timeout = _Timeout
|
||||
|
||||
def _AsyncClient(*_a, **_kw):
|
||||
return client
|
||||
|
||||
fake.AsyncClient = _AsyncClient
|
||||
sys.modules["httpx"] = fake
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if saved is not None:
|
||||
sys.modules["httpx"] = saved
|
||||
else:
|
||||
sys.modules.pop("httpx", None)
|
||||
|
||||
|
||||
class TestGetUuid:
|
||||
@@ -92,16 +178,68 @@ class TestGetUuid:
|
||||
class TestDownloadImg:
|
||||
"""Test cases for download_img function"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_url_returns_empty_string(self):
|
||||
def test_empty_url_returns_empty_string(self):
|
||||
"""Test that empty URL returns empty string"""
|
||||
result = await download_img("")
|
||||
result = asyncio.run(download_img(""))
|
||||
assert result == ""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_none_url_returns_empty_string(self):
|
||||
def test_none_url_returns_empty_string(self):
|
||||
"""Test that None URL returns empty string"""
|
||||
result = await download_img(None)
|
||||
result = asyncio.run(download_img(None))
|
||||
assert result == ""
|
||||
|
||||
def test_loopback_url_blocked(self):
|
||||
"""OAuth avatar fetch must not call loopback (SSRF regression)."""
|
||||
result = asyncio.run(download_img("http://127.0.0.1/avatar.png"))
|
||||
assert result == ""
|
||||
|
||||
def test_metadata_ip_blocked(self):
|
||||
"""Link-local / cloud metadata ranges are non-global and must be rejected."""
|
||||
result = asyncio.run(download_img("http://169.254.169.254/latest/meta-data/"))
|
||||
assert result == ""
|
||||
|
||||
def test_disallowed_scheme_blocked(self):
|
||||
result = asyncio.run(download_img("file:///etc/passwd"))
|
||||
assert result == ""
|
||||
|
||||
def test_redirect_to_loopback_blocked(self):
|
||||
"""Redirect from an allowed host to loopback must be rejected (SSRF)."""
|
||||
from urllib.parse import urlparse
|
||||
|
||||
real_assert = ssrf_guard.assert_url_is_safe
|
||||
|
||||
def selective_assert(url: str, **kwargs):
|
||||
host = urlparse(url).hostname or ""
|
||||
if host in ("127.0.0.1", "localhost", "169.254.169.254"):
|
||||
return real_assert(url, **kwargs)
|
||||
return ("public-avatar.test", "8.8.8.8")
|
||||
|
||||
client = _FakeAsyncClient([_MockStreamResp(302, location="http://127.0.0.1/next")])
|
||||
|
||||
with (
|
||||
patch.object(ssrf_guard, "assert_url_is_safe", side_effect=selective_assert),
|
||||
_fake_httpx_sys_modules(client),
|
||||
):
|
||||
result = asyncio.run(download_img("http://public-avatar.test/start.png"))
|
||||
assert result == ""
|
||||
|
||||
def test_redirect_too_many_hops_blocked(self):
|
||||
"""Excessive redirect chains must return empty without hanging."""
|
||||
import common.misc_utils as misc_utils
|
||||
|
||||
hops = [
|
||||
_MockStreamResp(302, location="http://h.example/1"),
|
||||
_MockStreamResp(302, location="http://h.example/2"),
|
||||
_MockStreamResp(302, location="http://h.example/3"),
|
||||
]
|
||||
client = _FakeAsyncClient(hops)
|
||||
|
||||
with (
|
||||
patch.object(misc_utils, "_OAUTH_AVATAR_MAX_REDIRECTS", 2),
|
||||
patch.object(ssrf_guard, "assert_url_is_safe", return_value=("h.example", "8.8.8.8")),
|
||||
_fake_httpx_sys_modules(client),
|
||||
):
|
||||
result = asyncio.run(misc_utils.download_img("http://h.example/start"))
|
||||
assert result == ""
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user