fix: block SSRF in misc_utils.download_img for OAuth avatars (#14868)

### What problem does this PR solve?

Closes #14865

`download_img` in `common/misc_utils.py` is used for OAuth avatar URLs.
The previous implementation called `async_request` from
`common.http_client`, which followed redirects without re-validating
each hop and did not apply the same SSRF protections as this path needs.
That made it possible to reach non-public or disallowed targets (for
example via redirects or unsafe URLs) when fetching avatars.

This change replaces that flow with an explicit, bounded fetch: each URL
(including every redirect target) is checked with
`common.ssrf_guard.assert_url_is_safe`, DNS is pinned with
`pin_dns_global`, `httpx` streams the body with `follow_redirects=False`
and a manual redirect loop (capped by
`RAGFLOW_OAUTH_AVATAR_MAX_REDIRECTS`), and total response size is capped
(`RAGFLOW_OAUTH_AVATAR_MAX_BYTES`). Timeouts, proxy, and user agent
align with `HTTP_CLIENT_*` env vars without importing `http_client`, so
lightweight tests stay simple.

Unit tests cover empty/None URLs, loopback, cloud metadata-style
addresses, and disallowed schemes so SSRF regressions are caught early.

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

---------

Co-authored-by: Kevin Hu <kevinhu.sh@gmail.com>
This commit is contained in:
dale053
2026-05-21 21:12:04 -07:00
committed by GitHub
parent b2bf9155ed
commit 6ab25bf715
2 changed files with 276 additions and 14 deletions

View File

@@ -24,22 +24,146 @@ import subprocess
import sys
import threading
import uuid
from urllib.parse import urljoin
from concurrent.futures import ThreadPoolExecutor
logger = logging.getLogger(__name__)
def get_uuid():
return uuid.uuid1().hex
# OAuth avatar fetch: bounded size; each redirect hop is SSRF-checked and DNS-pinned
# (see common.ssrf_guard).
_OAUTH_AVATAR_MAX_BYTES = int(os.environ.get("RAGFLOW_OAUTH_AVATAR_MAX_BYTES", str(5 * 1024 * 1024)))
_OAUTH_AVATAR_MAX_REDIRECTS = int(os.environ.get("RAGFLOW_OAUTH_AVATAR_MAX_REDIRECTS", "5"))
_REDIRECT_STATUS = frozenset({301, 302, 303, 307, 308})
async def download_img(url):
"""Fetch an image URL and return a data URI, or empty string on failure / SSRF block.
URLs must resolve only to globally routable addresses; redirects are followed
only up to ``_OAUTH_AVATAR_MAX_REDIRECTS`` with each target validated.
"""
if not url:
return ""
from common.http_client import async_request
response = await async_request("GET", url)
return "data:" + \
response.headers.get('Content-Type', 'image/jpg') + ";" + \
"base64," + base64.b64encode(response.content).decode("utf-8")
if not isinstance(url, str):
url = str(url)
url = url.strip()
if not url:
return ""
current_url = url
redirect_hops = 0
# Match common/http_client.py defaults without importing http_client (avoids
# pulling settings and keeps this path usable in lightweight test envs).
request_timeout = float(os.environ.get("HTTP_CLIENT_TIMEOUT", "15"))
proxy = os.environ.get("HTTP_CLIENT_PROXY")
user_agent = os.environ.get("HTTP_CLIENT_USER_AGENT", "ragflow-http-client")
from common.ssrf_guard import assert_url_is_safe, pin_dns_global
while redirect_hops <= _OAUTH_AVATAR_MAX_REDIRECTS:
try:
hostname, pin_ip = assert_url_is_safe(current_url)
except ValueError as exc:
logger.warning("download_img rejected URL (SSRF guard): %s", exc)
return ""
import httpx
timeout = httpx.Timeout(request_timeout)
headers = {}
if user_agent:
headers["User-Agent"] = user_agent
async def _stream_one_get() -> tuple[str, str | None]:
"""Return ``('redirect', new_url)``, ``('data', data_uri)``, or ``('fail', None)``."""
with pin_dns_global(hostname, pin_ip):
async with httpx.AsyncClient(
timeout=timeout,
follow_redirects=False,
proxy=proxy,
) as client:
async with client.stream("GET", current_url, headers=headers or None) as response:
if response.status_code in _REDIRECT_STATUS:
await response.aclose()
location = response.headers.get("location")
if not location:
logger.warning(
"download_img redirect missing Location header: url=%r status=%s redirect_hops=%s",
current_url,
response.status_code,
redirect_hops,
)
return ("fail", None)
return ("redirect", urljoin(current_url, location))
if response.status_code != 200:
logger.warning(
"download_img non-200 response: url=%r status=%s redirect_hops=%s",
current_url,
response.status_code,
redirect_hops,
)
return ("fail", None)
body = bytearray()
async for chunk in response.aiter_bytes():
if len(body) + len(chunk) > _OAUTH_AVATAR_MAX_BYTES:
logger.warning(
"download_img response exceeded max size: url=%r max_bytes=%s",
current_url,
_OAUTH_AVATAR_MAX_BYTES,
)
await response.aclose()
return ("fail", None)
body.extend(chunk)
content_type = response.headers.get("Content-Type", "image/jpeg")
data_uri = (
"data:"
+ content_type
+ ";base64,"
+ base64.b64encode(bytes(body)).decode("utf-8")
)
return ("data", data_uri)
try:
kind, payload = await asyncio.wait_for(_stream_one_get(), timeout=request_timeout)
except asyncio.TimeoutError:
logger.warning(
"download_img total wall-clock timeout: url=%r redirect_hops=%s timeout=%s",
current_url,
redirect_hops,
request_timeout,
)
return ""
except Exception as exc:
logger.warning(
"download_img request failed: url=%r redirect_hops=%s err=%s",
current_url,
redirect_hops,
exc,
)
return ""
if kind == "redirect":
current_url = str(payload)
redirect_hops += 1
continue
if kind == "fail":
return ""
return str(payload)
logger.warning(
"download_img redirect hop limit exceeded: url=%r redirect_hops=%s max_redirects=%s",
current_url,
redirect_hops,
_OAUTH_AVATAR_MAX_REDIRECTS,
)
return ""
def hash_str2int(line: str, mod: int = 10 ** 8) -> int: