diff --git a/rag/app/naive.py b/rag/app/naive.py index 2b97442f45..211ee29bcd 100644 --- a/rag/app/naive.py +++ b/rag/app/naive.py @@ -643,6 +643,11 @@ class Pdf(PdfParser): return [(b["text"], self._line_tag(b, zoomin)) for b in self.boxes], tbls +# Maximum number of HTTP redirects followed when fetching a remote image +# referenced by a markdown document (each hop is SSRF-validated). +MAX_IMAGE_REDIRECTS = 5 + + class Markdown(MarkdownParser): def md_to_html(self, sections): if not sections: @@ -714,6 +719,9 @@ class Markdown(MarkdownParser): def load_images_from_urls(self, urls, cache=None): import requests from pathlib import Path + from urllib.parse import urljoin + + from common.ssrf_guard import assert_url_is_safe, pin_dns cache = cache or {} images = [] @@ -725,9 +733,40 @@ class Markdown(MarkdownParser): img_obj = None try: if url.startswith(("http://", "https://")): - response = requests.get(url, stream=True, timeout=30) - if response.status_code == 200 and response.headers.get("Content-Type", "").startswith("image/"): - img_obj = Image.open(BytesIO(response.content)).convert("RGB") + # SSRF guard: image references come from the (untrusted) uploaded + # document, so validate and DNS-pin every hop before connecting. + # Otherwise a markdown image like ![x](http://169.254.169.254/...) + # would make the server fetch internal services / cloud metadata. + # Redirects are followed manually so each hop is re-validated, + # mirroring common/data_source/rss_connector.py. + current_hostname, current_ip = assert_url_is_safe(url) + current_url = url + response = None + try: + for _ in range(MAX_IMAGE_REDIRECTS + 1): + # Release the previous hop before opening the next: with + # stream=True the connection isn't returned to the pool + # until the body is read or the response is closed. + if response is not None: + response.close() + with pin_dns(current_hostname, current_ip): + response = requests.get(current_url, stream=True, timeout=30, allow_redirects=False) + if response.status_code not in (301, 302, 303, 307, 308): + break + location = response.headers.get("Location") + if not location: + break + current_url = urljoin(current_url, location) + current_hostname, current_ip = assert_url_is_safe(current_url) + else: + raise ValueError(f"Exceeded {MAX_IMAGE_REDIRECTS} redirects fetching {url!r}") + if response.status_code == 200 and response.headers.get("Content-Type", "").startswith("image/"): + img_obj = Image.open(BytesIO(response.content)).convert("RGB") + finally: + # Always release the final/streamed response, including the + # non-image and redirect-cap paths where the body is unread. + if response is not None: + response.close() else: local_path = Path(url) if local_path.exists(): diff --git a/test/unit_test/rag/app/test_markdown_image_ssrf.py b/test/unit_test/rag/app/test_markdown_image_ssrf.py new file mode 100644 index 0000000000..6ef2487bab --- /dev/null +++ b/test/unit_test/rag/app/test_markdown_image_ssrf.py @@ -0,0 +1,140 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License on an "AS IS" BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, either express or implied. See the License +# for the specific language governing permissions and limitations under +# the License. +# + +"""SSRF-guard regression tests for rag.app.naive.Markdown.load_images_from_urls. + +Image references are parsed out of the (untrusted) uploaded markdown document +and fetched server-side, so the loader must validate + DNS-pin every hop before +connecting. These tests assert that internal/loopback targets are rejected, +that redirects to internal targets are rejected, and that a legitimate public +image is still fetched. +""" + +from __future__ import annotations + +import io +import sys +from unittest.mock import MagicMock, patch + +# Mock heavy modules that trigger ONNX/OCR model loading or optional parser +# backends at import time so importing rag.app.naive stays lightweight. +for _mod in [ + "deepdoc.vision.ocr", + "deepdoc.parser.figure_parser", + "deepdoc.parser.docling_parser", + "deepdoc.parser.tcadp_parser", + "rag.app.picture", +]: + if _mod not in sys.modules: + sys.modules[_mod] = MagicMock() + +import pytest +from PIL import Image + +from common import ssrf_guard +from rag.app.naive import MAX_IMAGE_REDIRECTS, Markdown + + +def _png_bytes() -> bytes: + buf = io.BytesIO() + Image.new("RGB", (1, 1), (255, 0, 0)).save(buf, format="PNG") + return buf.getvalue() + + +class _Resp: + """Minimal stand-in for requests.Response.""" + + def __init__(self, status_code, headers=None, content=b""): + self.status_code = status_code + self.headers = headers or {} + self.content = content + + def close(self): + pass + + +@pytest.fixture +def parser(): + return Markdown(128) + + +@pytest.mark.p1 +def test_blocks_internal_url_without_fetching(parser): + """A markdown image pointing at an internal host must never be requested.""" + with ( + patch.object(ssrf_guard, "assert_url_is_safe", side_effect=ValueError("non-public")) as guard, + patch("requests.get") as get, + ): + images, cache = parser.load_images_from_urls(["http://169.254.169.254/latest/meta-data/"]) + + guard.assert_called_once() + get.assert_not_called() # SSRF guard rejects before any connection is made + assert images == [] + assert cache["http://169.254.169.254/latest/meta-data/"] is None + + +@pytest.mark.p1 +def test_blocks_redirect_to_internal_target(parser): + """A public URL that 302-redirects to a loopback target must be rejected.""" + + def selective_assert(url, **kwargs): + if "127.0.0.1" in url or "localhost" in url: + raise ValueError("redirect resolves to non-public address") + return ("public.example", "8.8.8.8") + + redirect = _Resp(302, headers={"Location": "http://127.0.0.1/secret"}) + with ( + patch.object(ssrf_guard, "assert_url_is_safe", side_effect=selective_assert), + patch("requests.get", return_value=redirect) as get, + ): + images, _ = parser.load_images_from_urls(["http://public.example/logo.png"]) + + # Only the first (public) hop is fetched; the redirect target is blocked + # by re-validation before a second request is made. + assert get.call_count == 1 + assert images == [] + + +@pytest.mark.p1 +def test_fetches_legitimate_public_image(parser): + png = _png_bytes() + ok = _Resp(200, headers={"Content-Type": "image/png"}, content=png) + with ( + patch.object(ssrf_guard, "assert_url_is_safe", return_value=("public.example", "8.8.8.8")), + patch("requests.get", return_value=ok) as get, + ): + images, cache = parser.load_images_from_urls(["http://public.example/logo.png"]) + + get.assert_called_once() + # allow_redirects must be disabled so redirects are validated per hop. + assert get.call_args.kwargs.get("allow_redirects") is False + assert len(images) == 1 + assert isinstance(images[0], Image.Image) + + +@pytest.mark.p1 +def test_redirect_chain_is_bounded(parser): + """An endless redirect loop is abandoned instead of being followed forever.""" + loop = _Resp(302, headers={"Location": "http://public.example/next"}) + with ( + patch.object(ssrf_guard, "assert_url_is_safe", return_value=("public.example", "8.8.8.8")), + patch("requests.get", return_value=loop) as get, + ): + images, _ = parser.load_images_from_urls(["http://public.example/start"]) + + assert get.call_count == MAX_IMAGE_REDIRECTS + 1 + assert images == []