Harden closed-advisory fixes (#16409)

## Summary
- harden reopened advisory fixes across REST connector, invoke, document
downloads, and markdown rendering
- add targeted regression coverage for redirect-safe SSRF handling,
invoke SSRF checks, document access control, and markdown sanitization
- verify each referenced GHSA against the original GitHub advisory text
and align the closed-advisory plan with the implemented remediation

## What changed
- add tenant access checks to document download endpoints to avoid
cross-tenant document disclosure
- add per-hop SSRF validation, DNS pinning, redirect handling, and
redirect limits to the REST API connector
- ensure invoke requests validate and pin the resolved host and never
follow redirects implicitly
- keep the generic rate-limited request path wrapped, not just GET and
POST helpers
- sanitize markdown HTML before rendering in the highlight markdown
component

## Validation
- `cd web && npm test -- --runInBand
src/components/highlight-markdown/__tests__/index.test.tsx`
- `.venv/bin/python -m pytest -q
test/unit_test/data_source/test_rest_api_connector.py`
- targeted `test/testcases/test_web_api/...` unit additions were
reviewed, but the suite cannot be executed end-to-end in this
environment because parent `test/testcases/conftest.py` requires a local
service on `127.0.0.1:9380`

## Notes
- all GHSA entries referenced by the plan were checked against the
original GitHub advisory text, not sampled
- the closed-advisory plan document was updated locally during review,
but is intentionally not included in this PR
This commit is contained in:
Zhichang Yu
2026-06-28 11:17:54 +08:00
committed by GitHub
parent f90be41eab
commit c4fe68eaa0
11 changed files with 398 additions and 12 deletions

View File

@@ -25,6 +25,7 @@ import requests
from agent.component.base import ComponentBase, ComponentParamBase
from common.connection_utils import timeout
from common.ssrf_guard import assert_url_is_safe, pin_dns
from deepdoc.parser import HtmlParser
@@ -56,6 +57,11 @@ class Invoke(ComponentBase, ABC):
component_name = "Invoke"
header_variable_ref_patt = r"\{([a-zA-Z_][a-zA-Z0-9_.@-]*)\}"
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._pinned_hostname: str | None = None
self._pinned_ip: str | None = None
@staticmethod
def _coerce_json_arg_if_possible(key, value):
raw_value = value
@@ -169,6 +175,9 @@ class Invoke(ComponentBase, ABC):
url = self._resolve_template_text(self._param.url.strip(), kwargs)
if not url.startswith(("http://", "https://")):
url = "http://" + url
hostname, ip = assert_url_is_safe(url)
self._pinned_hostname = hostname
self._pinned_ip = ip
return url
def _build_headers(self, kwargs: dict) -> dict:
@@ -194,6 +203,7 @@ class Invoke(ComponentBase, ABC):
"headers": headers,
"proxies": proxies,
"timeout": self._param.timeout,
"allow_redirects": False,
}
# GET sends query params; POST/PUT send either JSON or form data based on datatype.
@@ -219,7 +229,6 @@ class Invoke(ComponentBase, ABC):
return
args = self._build_request_args(kwargs)
url = self._build_url(kwargs)
headers = self._build_headers(kwargs)
proxies = self._build_proxies()
@@ -229,7 +238,15 @@ class Invoke(ComponentBase, ABC):
return
try:
response = self._send_request(url, args, headers, proxies)
# Coderabbit MAJOR #3486038788: URL validation is now inside the
# retry/except block so SSRF rejections (ValueError from
# assert_url_is_safe) populate _ERROR via the standard error
# path instead of escaping _invoke().
url = self._build_url(kwargs)
if not self._pinned_hostname or not self._pinned_ip:
raise ValueError("Invoke URL was not validated before request.")
with pin_dns(self._pinned_hostname, self._pinned_ip):
response = self._send_request(url, args, headers, proxies)
result = self._format_response(response)
self.set_output("result", result)
return result

View File

@@ -2003,6 +2003,10 @@ async def download(dataset_id, document_id):
"""
if not document_id:
return get_error_data_result(message="Specify document_id please.")
if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=current_user.id):
return get_data_error_result(message="Document not found!")
if not DocumentService.accessible(document_id, current_user.id):
return get_data_error_result(message="Document not found!")
doc = DocumentService.query(kb_id=dataset_id, id=document_id)
if not doc:
return get_error_data_result(message=f"The dataset not own the document {document_id}.")
@@ -2060,6 +2064,8 @@ async def download_document(document_id):
"""
if not document_id:
return get_error_data_result(message="Specify document_id please.")
if not DocumentService.accessible(document_id, current_user.id):
return get_data_error_result(message="Document not found!")
doc = DocumentService.query(id=document_id)
if not doc:
return get_error_data_result(message=f"The dataset not own the document {document_id}.")

View File

@@ -13,7 +13,7 @@ import re
import time
from datetime import datetime, timezone
from typing import Any, Dict, Generator, Iterable, List, Mapping, Optional
from urllib.parse import parse_qs, urlparse, urlunparse
from urllib.parse import parse_qs, urljoin, urlparse, urlunparse
import ipaddress
import socket
@@ -35,6 +35,7 @@ from common.data_source.interfaces import (
)
from common.data_source.models import Document
from common.data_source.utils import rl_requests, retry_builder
from common.ssrf_guard import assert_url_is_safe, pin_dns
try:
from jsonpath import jsonpath as _jsonpath # type: ignore[import]
@@ -43,6 +44,8 @@ except Exception: # pragma: no cover
_FIELD_SEGMENT_RE = re.compile(r'^(?P<key>[^\[\]]+)(\[(?P<index>\d+|\*)\])?$')
_DEFAULT_MAX_PAGES = 1000
_REDIRECT_STATUSES = frozenset({301, 302, 303, 307, 308})
_MAX_REDIRECTS = 5
class AuthType:
@@ -604,11 +607,19 @@ class RestAPIConnector(LoadConnector, PollConnector):
)
if self.method == "GET":
resp = rl_requests.get(url, headers=headers, params=query_params, auth=self._basic_auth, timeout=60)
resp = self._safe_request(
"GET",
url,
headers=headers,
params=query_params,
)
elif self.method == "POST":
resp = rl_requests.post(
url, headers=headers, params=query_params,
json=self._static_request_body or {}, auth=self._basic_auth, timeout=60,
resp = self._safe_request(
"POST",
url,
headers=headers,
params=query_params,
json_body=self._static_request_body or {},
)
else:
raise ConnectorValidationError(f"Unsupported HTTP method: {self.method}")
@@ -647,6 +658,113 @@ class RestAPIConnector(LoadConnector, PollConnector):
except ValueError as exc:
raise ConnectorValidationError("REST API response is not valid JSON") from exc
# Headers that carry auth state. Stripped on cross-origin redirects to
# prevent credential exfiltration to a third-party host. (Coderabbit MAJOR #3486038792)
_AUTH_SENSITIVE_HEADER_KEYS = frozenset({
"authorization",
"proxy-authorization",
"apikey",
"api-key",
"x-api-key",
"x-auth-token",
})
def _safe_request(
self,
method: str,
url: str,
*,
headers: Dict[str, str],
params: Dict[str, Any],
body: Any = None,
json_body: Any = None,
) -> requests.Response:
"""Issue an HTTP request with per-hop SSRF validation and DNS pinning."""
current_url = url
current_method = method.upper()
current_body = body
current_json = json_body
current_params = dict(params)
# Local auth handle: cleared when crossing origins, even though
# ``self._basic_auth`` may still hold the original credentials.
current_auth = self._basic_auth
previous_netloc = urlparse(current_url).netloc
for _ in range(_MAX_REDIRECTS + 1):
# Normalize SSRF validation failures to the connector's documented
# ConnectorValidationError so they don't leak ValueError out of
# _page_iter_for_validation(). (Coderabbit MAJOR #3486038789)
try:
hostname, pin_ip = assert_url_is_safe(current_url)
except ValueError as exc:
raise ConnectorValidationError(
f"Unsafe REST API URL: {exc}"
) from exc
with pin_dns(hostname, pin_ip):
if current_method == "GET":
resp = rl_requests.get(
current_url,
headers=headers,
params=current_params,
auth=current_auth,
timeout=60,
allow_redirects=False,
)
elif current_method == "POST":
resp = rl_requests.post(
current_url,
headers=headers,
params=current_params,
json=current_json,
auth=current_auth,
timeout=60,
allow_redirects=False,
)
else:
resp = rl_requests.request(
current_method,
current_url,
headers=headers,
params=current_params,
auth=current_auth,
timeout=60,
data=current_body,
json=current_json,
allow_redirects=False,
)
if resp.status_code not in _REDIRECT_STATUSES:
return resp
location = resp.headers.get("Location")
if not location:
return resp
current_url = urljoin(current_url, location)
next_netloc = urlparse(current_url).netloc
# Coderabbit MAJOR #3486038792: strip credentials when the redirect
# crosses to a different origin so a public→private redirect chain
# cannot exfiltrate Bearer/Basic/API-key headers.
if next_netloc and next_netloc != previous_netloc:
headers = {
k: v
for k, v in headers.items()
if k.lower() not in self._AUTH_SENSITIVE_HEADER_KEYS
}
current_auth = None
previous_netloc = next_netloc
if resp.status_code in (301, 302, 303):
current_method = "GET"
current_body = None
current_json = None
# Clear carried params — only the new Location URL's query
# string should apply for the downgraded GET.
current_params = {}
raise ConnectorValidationError(f"Exceeded {_MAX_REDIRECTS} redirects fetching {url!r}")
def _build_url_with_templates(self, params: Dict[str, Any]) -> tuple[str, Dict[str, Any]]:
"""Substitute ``{key}`` placeholders in the URL; return remaining query params."""
url = self.url

View File

@@ -224,11 +224,13 @@ def wrap_request_to_handle_ratelimiting(request_fn: R, default_wait_time_sec: in
_rate_limited_get = wrap_request_to_handle_ratelimiting(requests.get)
_rate_limited_post = wrap_request_to_handle_ratelimiting(requests.post)
_rate_limited_request = wrap_request_to_handle_ratelimiting(requests.request)
class _RateLimitedRequest:
get = _rate_limited_get
post = _rate_limited_post
request = _rate_limited_request
rl_requests = _RateLimitedRequest

View File

@@ -45,7 +45,7 @@ def test_document_download_by_id_invalid_id_contract(rest_client):
assert res.status_code == 200
payload = res.json()
assert payload["code"] == 102, payload
assert payload["message"] == "The dataset not own the document invalid_document_id.", payload
assert payload["message"] == "Document not found!", payload
@pytest.mark.p2

View File

@@ -1503,14 +1503,14 @@ def test_documents_download_requires_auth_and_invalid_id_contract(rest_client, c
assert invalid_doc_res.status_code == 200
invalid_doc_payload = invalid_doc_res.json()
assert invalid_doc_payload["code"] == 102, invalid_doc_payload
assert "The dataset not own the document invalid_document_id." in invalid_doc_payload["message"], invalid_doc_payload
assert invalid_doc_payload["message"] == "Document not found!", invalid_doc_payload
invalid_dataset_path = tmp_path / "invalid_dataset_download.txt"
invalid_dataset_res = _download_document_to_file(rest_client, "invalid_dataset_id", document_id, invalid_dataset_path)
assert invalid_dataset_res.status_code == 200
invalid_dataset_payload = invalid_dataset_res.json()
assert invalid_dataset_payload["code"] == 102, invalid_dataset_payload
assert f"The dataset not own the document {document_id}." in invalid_dataset_payload["message"], invalid_dataset_payload
assert invalid_dataset_payload["message"] == "Document not found!", invalid_dataset_payload
@pytest.mark.p2

View File

@@ -271,3 +271,37 @@ def test_header_variable_with_put(monkeypatch):
monkeypatch.setattr(module.requests, "put", mock_put)
invoke._invoke()
assert mock_put.call_args[1]["headers"]["Authorization"] == "Bearer put_token"
@pytest.mark.p2
def test_invoke_rejects_private_url(monkeypatch):
module = _load_invoke_module(monkeypatch)
invoke = _make_invoke(module, url="http://127.0.0.1:22")
monkeypatch.setattr(
module,
"assert_url_is_safe",
MagicMock(side_effect=ValueError("URL resolves to a non-public address")),
)
# Coderabbit MAJOR #3486038793: _build_url() is now inside the retry
# try/except block, so the ValueError from assert_url_is_safe is caught
# and the message is stored in _ERROR via the standard error path.
invoke._invoke()
assert "non-public address" in invoke._param.outputs["_ERROR"]
@pytest.mark.p2
def test_invoke_pins_dns_and_disables_redirects(monkeypatch):
module = _load_invoke_module(monkeypatch)
invoke = _make_invoke(module, url="http://example.com")
pin_ctx = MagicMock()
pin_ctx.__enter__ = MagicMock(return_value=None)
pin_ctx.__exit__ = MagicMock(return_value=None)
monkeypatch.setattr(module, "assert_url_is_safe", MagicMock(return_value=("example.com", "93.184.216.34")))
monkeypatch.setattr(module, "pin_dns", MagicMock(return_value=pin_ctx))
mock_get = MagicMock(return_value=SimpleNamespace(text="ok"))
monkeypatch.setattr(module.requests, "get", mock_get)
invoke._invoke()
module.pin_dns.assert_called_once_with("example.com", "93.184.216.34")
assert mock_get.call_args[1]["allow_redirects"] is False

View File

@@ -450,6 +450,23 @@ class TestDocumentMetadataUnit:
assert res["code"] == 500
assert "download boom" in res["message"]
def test_download_document_rejects_other_tenant_unit(self, document_rest_api_module, monkeypatch):
module = document_rest_api_module
monkeypatch.setattr(module.DocumentService, "accessible", lambda _doc_id, _user_id: False)
res = _run(module.download_document("doc1"))
assert res["code"] == RetCode.DATA_ERROR
assert "Document not found!" in res["message"]
def test_dataset_document_download_rejects_other_tenant_unit(self, document_rest_api_module, monkeypatch):
module = document_rest_api_module
monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda kb_id, user_id: False)
monkeypatch.setattr(module.DocumentService, "accessible", lambda _doc_id, _user_id: True)
res = _run(module.download("kb1", "doc1"))
assert res["code"] == RetCode.DATA_ERROR
assert "Document not found!" in res["message"]
@pytest.mark.p2
def test_get_document_image_content_type_from_object_extension_unit(self, document_app_module, monkeypatch):
module = document_app_module

View File

@@ -14,7 +14,7 @@
# limitations under the License.
#
from contextlib import contextmanager
from contextlib import contextmanager, nullcontext
from unittest.mock import MagicMock, patch
from urllib.parse import urlparse
@@ -211,6 +211,95 @@ class TestSSRFValidation:
with pytest.raises(ConnectorValidationError, match="scheme"):
_make_connector(url="file:///etc/passwd")
def test_redirect_to_loopback_rejected(self):
"""Redirect targets must be revalidated before they are fetched.
Exercise ``_safe_request`` directly rather than ``_fetch_page``: the
latter is wrapped by ``@retry_builder`` and in some CI environments
the retry path on the first ``ConnectorValidationError`` exhausts the
``side_effect`` and lets the loop run all 6 iterations, surfacing
``Exceeded 5 redirects`` instead of the expected ``loopback blocked``.
``_safe_request`` is the actual unit under test for redirect SSRF
handling, so testing it directly is the more faithful check.
Note on the DNS mock: ``_mocked_rest_api_requests_and_dns`` uses a
fixed ``return_value`` for ``socket.getaddrinfo``. With a constant
return value, a redirect to ``127.0.0.1`` would be reported as
resolving to ``93.184.216.34`` (a public address) and would slip
through ``is_global`` checks. To exercise the actual rejection path,
we override the patched ``getaddrinfo`` here to return the literal
loopback address for the loopback hostname.
"""
connector = _make_connector()
first = _mock_response([], status_code=302)
first.headers = {"Location": "http://127.0.0.1/secret"}
def _dns_for_host(host, *args, **kwargs):
# Mirror _MOCK_DNS_ADDRINFO shape: (family, type, proto, canon, sockaddr).
if host == "127.0.0.1":
return [(2, 1, 6, "", ("127.0.0.1", 0))]
return list(_MOCK_DNS_ADDRINFO)
with _mocked_rest_api_requests_and_dns() as mock_rl:
mock_rl.get.return_value = first
# The ``_mocked_rest_api_requests_and_dns`` context manager already
# patches ``socket.getaddrinfo`` with a constant ``return_value``;
# replace it here with a side_effect that distinguishes loopback
# from public hostnames so the SSRF guard actually rejects the
# redirect target.
import common.data_source.rest_api_connector as rc_module
from unittest.mock import patch as _patch
with _patch.object(rc_module.socket, "getaddrinfo", side_effect=_dns_for_host):
# Coderabbit MAJOR #3486038795: SSRF validation failures inside
# _safe_request are now wrapped to raise ConnectorValidationError
# (the connector's documented error contract) instead of leaking
# raw ValueError from ssrf_guard.
with pytest.raises(ConnectorValidationError, match=r"non-public address|loopback"):
connector._safe_request(
"GET",
connector.url,
headers={},
params={},
)
@patch("common.data_source.rest_api_connector.assert_url_is_safe")
@patch("common.data_source.rest_api_connector.pin_dns")
def test_post_307_preserves_body(self, mock_pin_dns, mock_safe):
"""307 redirects should keep method and JSON body."""
connector = _make_connector(method="POST", request_body={"hello": "world"})
first = _mock_response([], status_code=307)
first.headers = {"Location": "https://api.example.com/redirected"}
second = _mock_response({"items": []}, status_code=200)
mock_safe.side_effect = [
("api.example.com", "93.184.216.34"),
("api.example.com", "93.184.216.34"),
]
mock_pin_dns.return_value = nullcontext()
with _mocked_rest_api_requests_and_dns() as mock_rl:
mock_rl.post.side_effect = [first, second]
connector._fetch_page({})
assert mock_rl.post.call_count == 2
assert mock_rl.post.call_args_list[0].kwargs["json"] == {"hello": "world"}
assert mock_rl.post.call_args_list[1].kwargs["json"] == {"hello": "world"}
assert mock_rl.post.call_args_list[1].kwargs["allow_redirects"] is False
@patch("common.data_source.rest_api_connector.assert_url_is_safe")
@patch("common.data_source.rest_api_connector.pin_dns")
def test_exceeds_max_redirects_raises(self, mock_pin_dns, mock_safe):
"""Too many redirects should raise a connector validation error."""
connector = _make_connector()
redirect = _mock_response([], status_code=302)
redirect.headers = {"Location": "https://api.example.com/next"}
mock_safe.side_effect = [("api.example.com", "93.184.216.34")] * 6
mock_pin_dns.return_value = nullcontext()
with _mocked_rest_api_requests_and_dns() as mock_rl:
mock_rl.get.side_effect = [redirect] * 6
with pytest.raises(ConnectorValidationError, match="Exceeded 5 redirects"):
connector._fetch_page({})
# ===================================================================== #
# 3. Authentication setup #

View File

@@ -0,0 +1,96 @@
import { render } from '@testing-library/react';
import React from 'react';
import HighLightMarkdown from '..';
jest.mock('@/constants/markdown-remark-plugins', () => ({
MarkdownRemarkPlugins: [],
}));
// Coderabbit MAJOR #3486038797: the previous mock rendered react-markdown's
// output as a plain <div> containing `children` as text, so the spec never
// exercised rehypeRaw or the post-preprocessLaTeX sanitization path. With
// that mock, `<b>safe</b>` was just text inside a div, and an entity-encoded
// `<img onerror=...>` payload would never reach the DOM no matter what the
// component did — masking the exact bypass DOMPurify is meant to catch.
//
// We mock react-markdown to render `children` as raw HTML (via
// dangerouslySetInnerHTML). This mimics the real pipeline: if the component
// fails to sanitize (e.g. sanitizes BEFORE preprocessLaTeX), the unsafe HTML
// will reach this mock and be inserted into the DOM, failing the assertions.
jest.mock('react-markdown', () => ({
__esModule: true,
default: ({ children }: any) => {
const ReactLib = jest.requireActual('react');
return ReactLib.createElement('div', {
dangerouslySetInnerHTML: { __html: children },
});
},
}));
jest.mock('react-syntax-highlighter', () => ({
Prism: ({ children }: any) => {
const react = jest.requireActual('react');
return react.createElement('pre', null, children);
},
}));
jest.mock('react-syntax-highlighter/dist/esm/styles/prism', () => ({
oneDark: {},
oneLight: {},
}));
jest.mock('rehype-katex', () => jest.fn());
jest.mock('rehype-raw', () => jest.fn());
jest.mock('../../theme-provider', () => ({
useIsDarkTheme: () => false,
}));
describe('HighLightMarkdown', () => {
it('sanitizes unsafe html before rendering', () => {
const { container } = render(
React.createElement(
HighLightMarkdown,
null,
'hello <img src=x onerror="alert(1)" /><script>alert(1)</script><b>safe</b>',
),
);
// <b>safe</b> is allowed by DOMPurify default profile, so it should
// render as a real <b> element with text "safe".
expect(container.querySelector('b')?.textContent).toBe('safe');
// <script> is removed entirely by DOMPurify.
expect(container.querySelector('script')).toBeNull();
// <img> is kept but its dangerous handler attribute must be stripped.
// (DOMPurify default profile removes on* event attributes.)
const imgs = container.querySelectorAll('img');
imgs.forEach((img) => {
expect(img.getAttribute('onerror')).toBeNull();
});
});
it('strips html encoded as entities (preprocessLaTeX bypass)', () => {
// preprocessLaTeX() decodes &lt;/&gt;/&amp; back to raw HTML before
// rehypeRaw runs. Sanitization must occur AFTER preprocessLaTeX, so
// a payload delivered as &lt;img onerror=...&gt; cannot survive.
const { container } = render(
React.createElement(
HighLightMarkdown,
null,
'&lt;img src=x onerror="alert(1)" /&gt;&lt;script&gt;alert(1)&lt;/script&gt;safe',
),
);
// <script> entirely removed.
expect(container.querySelector('script')).toBeNull();
// <img> kept (allowed by default profile) but onerror must be stripped.
container.querySelectorAll('img').forEach((img) => {
expect(img.getAttribute('onerror')).toBeNull();
});
// The literal word "safe" should still be visible.
expect(container.textContent).toContain('safe');
});
});

View File

@@ -1,5 +1,6 @@
import { MarkdownRemarkPlugins } from '@/constants/markdown-remark-plugins';
import classNames from 'classnames';
import DOMPurify from 'dompurify';
import Markdown from 'react-markdown';
import { Prism as SyntaxHighlighter } from 'react-syntax-highlighter';
import {
@@ -25,6 +26,12 @@ const HighLightMarkdown = ({
children: string | null | undefined;
}) => {
const isDarkTheme = useIsDarkTheme();
// IMPORTANT: preprocessLaTeX() decodes &lt;/&gt;/&amp; back to raw HTML before
// rehypeRaw parses the markdown. Sanitizing children *before* preprocessLaTeX
// would let entity-encoded payloads bypass DOMPurify and inject HTML.
// Sanitize the *post*-processed string instead. (Coderabbit CRITICAL #3486038798)
const processed = children ? preprocessLaTeX(children) : children;
const safeChildren = processed ? DOMPurify.sanitize(processed) : processed;
const dir = children
? getDirAttribute(children.replace(citationMarkerReg, ''))
: undefined;
@@ -60,7 +67,7 @@ const HighLightMarkdown = ({
} as any
}
>
{children ? preprocessLaTeX(children) : children}
{safeChildren}
</Markdown>
</div>
);