From faef22c18a74b94ff1a2f55ca24404f0b1807cc5 Mon Sep 17 00:00:00 2001 From: Zhichang Yu Date: Sun, 28 Jun 2026 11:17:54 +0800 Subject: [PATCH] Harden closed-advisory fixes (#16409) ## Summary - harden reopened advisory fixes across REST connector, invoke, document downloads, and markdown rendering - add targeted regression coverage for redirect-safe SSRF handling, invoke SSRF checks, document access control, and markdown sanitization - verify each referenced GHSA against the original GitHub advisory text and align the closed-advisory plan with the implemented remediation ## What changed - add tenant access checks to document download endpoints to avoid cross-tenant document disclosure - add per-hop SSRF validation, DNS pinning, redirect handling, and redirect limits to the REST API connector - ensure invoke requests validate and pin the resolved host and never follow redirects implicitly - keep the generic rate-limited request path wrapped, not just GET and POST helpers - sanitize markdown HTML before rendering in the highlight markdown component ## Validation - `cd web && npm test -- --runInBand src/components/highlight-markdown/__tests__/index.test.tsx` - `.venv/bin/python -m pytest -q test/unit_test/data_source/test_rest_api_connector.py` - targeted `test/testcases/test_web_api/...` unit additions were reviewed, but the suite cannot be executed end-to-end in this environment because parent `test/testcases/conftest.py` requires a local service on `127.0.0.1:9380` ## Notes - all GHSA entries referenced by the plan were checked against the original GitHub advisory text, not sampled - the closed-advisory plan document was updated locally during review, but is intentionally not included in this PR --- agent/component/invoke.py | 21 ++- api/apps/restful_apis/document_api.py | 6 + common/data_source/rest_api_connector.py | 128 +++++++++++++++++- common/data_source/utils.py | 2 + .../restful_api/test_document_raw_routes.py | 2 +- test/testcases/restful_api/test_documents.py | 4 +- .../test_invoke_component_unit.py | 34 +++++ .../test_document_metadata.py | 17 +++ .../data_source/test_rest_api_connector.py | 91 ++++++++++++- .../__tests__/index.test.tsx | 96 +++++++++++++ .../components/highlight-markdown/index.tsx | 9 +- 11 files changed, 398 insertions(+), 12 deletions(-) create mode 100644 web/src/components/highlight-markdown/__tests__/index.test.tsx diff --git a/agent/component/invoke.py b/agent/component/invoke.py index 4faaa7d013..26633228f0 100644 --- a/agent/component/invoke.py +++ b/agent/component/invoke.py @@ -25,6 +25,7 @@ import requests from agent.component.base import ComponentBase, ComponentParamBase from common.connection_utils import timeout +from common.ssrf_guard import assert_url_is_safe, pin_dns from deepdoc.parser import HtmlParser @@ -56,6 +57,11 @@ class Invoke(ComponentBase, ABC): component_name = "Invoke" header_variable_ref_patt = r"\{([a-zA-Z_][a-zA-Z0-9_.@-]*)\}" + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._pinned_hostname: str | None = None + self._pinned_ip: str | None = None + @staticmethod def _coerce_json_arg_if_possible(key, value): raw_value = value @@ -169,6 +175,9 @@ class Invoke(ComponentBase, ABC): url = self._resolve_template_text(self._param.url.strip(), kwargs) if not url.startswith(("http://", "https://")): url = "http://" + url + hostname, ip = assert_url_is_safe(url) + self._pinned_hostname = hostname + self._pinned_ip = ip return url def _build_headers(self, kwargs: dict) -> dict: @@ -194,6 +203,7 @@ class Invoke(ComponentBase, ABC): "headers": headers, "proxies": proxies, "timeout": self._param.timeout, + "allow_redirects": False, } # GET sends query params; POST/PUT send either JSON or form data based on datatype. @@ -219,7 +229,6 @@ class Invoke(ComponentBase, ABC): return args = self._build_request_args(kwargs) - url = self._build_url(kwargs) headers = self._build_headers(kwargs) proxies = self._build_proxies() @@ -229,7 +238,15 @@ class Invoke(ComponentBase, ABC): return try: - response = self._send_request(url, args, headers, proxies) + # Coderabbit MAJOR #3486038788: URL validation is now inside the + # retry/except block so SSRF rejections (ValueError from + # assert_url_is_safe) populate _ERROR via the standard error + # path instead of escaping _invoke(). + url = self._build_url(kwargs) + if not self._pinned_hostname or not self._pinned_ip: + raise ValueError("Invoke URL was not validated before request.") + with pin_dns(self._pinned_hostname, self._pinned_ip): + response = self._send_request(url, args, headers, proxies) result = self._format_response(response) self.set_output("result", result) return result diff --git a/api/apps/restful_apis/document_api.py b/api/apps/restful_apis/document_api.py index faf2445163..a9fac1a44b 100644 --- a/api/apps/restful_apis/document_api.py +++ b/api/apps/restful_apis/document_api.py @@ -2003,6 +2003,10 @@ async def download(dataset_id, document_id): """ if not document_id: return get_error_data_result(message="Specify document_id please.") + if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=current_user.id): + return get_data_error_result(message="Document not found!") + if not DocumentService.accessible(document_id, current_user.id): + return get_data_error_result(message="Document not found!") doc = DocumentService.query(kb_id=dataset_id, id=document_id) if not doc: return get_error_data_result(message=f"The dataset not own the document {document_id}.") @@ -2060,6 +2064,8 @@ async def download_document(document_id): """ if not document_id: return get_error_data_result(message="Specify document_id please.") + if not DocumentService.accessible(document_id, current_user.id): + return get_data_error_result(message="Document not found!") doc = DocumentService.query(id=document_id) if not doc: return get_error_data_result(message=f"The dataset not own the document {document_id}.") diff --git a/common/data_source/rest_api_connector.py b/common/data_source/rest_api_connector.py index 8616be2730..f7196114c8 100644 --- a/common/data_source/rest_api_connector.py +++ b/common/data_source/rest_api_connector.py @@ -13,7 +13,7 @@ import re import time from datetime import datetime, timezone from typing import Any, Dict, Generator, Iterable, List, Mapping, Optional -from urllib.parse import parse_qs, urlparse, urlunparse +from urllib.parse import parse_qs, urljoin, urlparse, urlunparse import ipaddress import socket @@ -35,6 +35,7 @@ from common.data_source.interfaces import ( ) from common.data_source.models import Document from common.data_source.utils import rl_requests, retry_builder +from common.ssrf_guard import assert_url_is_safe, pin_dns try: from jsonpath import jsonpath as _jsonpath # type: ignore[import] @@ -43,6 +44,8 @@ except Exception: # pragma: no cover _FIELD_SEGMENT_RE = re.compile(r'^(?P[^\[\]]+)(\[(?P\d+|\*)\])?$') _DEFAULT_MAX_PAGES = 1000 +_REDIRECT_STATUSES = frozenset({301, 302, 303, 307, 308}) +_MAX_REDIRECTS = 5 class AuthType: @@ -604,11 +607,19 @@ class RestAPIConnector(LoadConnector, PollConnector): ) if self.method == "GET": - resp = rl_requests.get(url, headers=headers, params=query_params, auth=self._basic_auth, timeout=60) + resp = self._safe_request( + "GET", + url, + headers=headers, + params=query_params, + ) elif self.method == "POST": - resp = rl_requests.post( - url, headers=headers, params=query_params, - json=self._static_request_body or {}, auth=self._basic_auth, timeout=60, + resp = self._safe_request( + "POST", + url, + headers=headers, + params=query_params, + json_body=self._static_request_body or {}, ) else: raise ConnectorValidationError(f"Unsupported HTTP method: {self.method}") @@ -647,6 +658,113 @@ class RestAPIConnector(LoadConnector, PollConnector): except ValueError as exc: raise ConnectorValidationError("REST API response is not valid JSON") from exc + # Headers that carry auth state. Stripped on cross-origin redirects to + # prevent credential exfiltration to a third-party host. (Coderabbit MAJOR #3486038792) + _AUTH_SENSITIVE_HEADER_KEYS = frozenset({ + "authorization", + "proxy-authorization", + "apikey", + "api-key", + "x-api-key", + "x-auth-token", + }) + + def _safe_request( + self, + method: str, + url: str, + *, + headers: Dict[str, str], + params: Dict[str, Any], + body: Any = None, + json_body: Any = None, + ) -> requests.Response: + """Issue an HTTP request with per-hop SSRF validation and DNS pinning.""" + current_url = url + current_method = method.upper() + current_body = body + current_json = json_body + current_params = dict(params) + # Local auth handle: cleared when crossing origins, even though + # ``self._basic_auth`` may still hold the original credentials. + current_auth = self._basic_auth + previous_netloc = urlparse(current_url).netloc + + for _ in range(_MAX_REDIRECTS + 1): + # Normalize SSRF validation failures to the connector's documented + # ConnectorValidationError so they don't leak ValueError out of + # _page_iter_for_validation(). (Coderabbit MAJOR #3486038789) + try: + hostname, pin_ip = assert_url_is_safe(current_url) + except ValueError as exc: + raise ConnectorValidationError( + f"Unsafe REST API URL: {exc}" + ) from exc + with pin_dns(hostname, pin_ip): + if current_method == "GET": + resp = rl_requests.get( + current_url, + headers=headers, + params=current_params, + auth=current_auth, + timeout=60, + allow_redirects=False, + ) + elif current_method == "POST": + resp = rl_requests.post( + current_url, + headers=headers, + params=current_params, + json=current_json, + auth=current_auth, + timeout=60, + allow_redirects=False, + ) + else: + resp = rl_requests.request( + current_method, + current_url, + headers=headers, + params=current_params, + auth=current_auth, + timeout=60, + data=current_body, + json=current_json, + allow_redirects=False, + ) + + if resp.status_code not in _REDIRECT_STATUSES: + return resp + + location = resp.headers.get("Location") + if not location: + return resp + + current_url = urljoin(current_url, location) + next_netloc = urlparse(current_url).netloc + + # Coderabbit MAJOR #3486038792: strip credentials when the redirect + # crosses to a different origin so a public→private redirect chain + # cannot exfiltrate Bearer/Basic/API-key headers. + if next_netloc and next_netloc != previous_netloc: + headers = { + k: v + for k, v in headers.items() + if k.lower() not in self._AUTH_SENSITIVE_HEADER_KEYS + } + current_auth = None + previous_netloc = next_netloc + + if resp.status_code in (301, 302, 303): + current_method = "GET" + current_body = None + current_json = None + # Clear carried params — only the new Location URL's query + # string should apply for the downgraded GET. + current_params = {} + + raise ConnectorValidationError(f"Exceeded {_MAX_REDIRECTS} redirects fetching {url!r}") + def _build_url_with_templates(self, params: Dict[str, Any]) -> tuple[str, Dict[str, Any]]: """Substitute ``{key}`` placeholders in the URL; return remaining query params.""" url = self.url diff --git a/common/data_source/utils.py b/common/data_source/utils.py index 4cc3cce43c..849e304795 100644 --- a/common/data_source/utils.py +++ b/common/data_source/utils.py @@ -224,11 +224,13 @@ def wrap_request_to_handle_ratelimiting(request_fn: R, default_wait_time_sec: in _rate_limited_get = wrap_request_to_handle_ratelimiting(requests.get) _rate_limited_post = wrap_request_to_handle_ratelimiting(requests.post) +_rate_limited_request = wrap_request_to_handle_ratelimiting(requests.request) class _RateLimitedRequest: get = _rate_limited_get post = _rate_limited_post + request = _rate_limited_request rl_requests = _RateLimitedRequest diff --git a/test/testcases/restful_api/test_document_raw_routes.py b/test/testcases/restful_api/test_document_raw_routes.py index 36866650cf..a5230eda34 100644 --- a/test/testcases/restful_api/test_document_raw_routes.py +++ b/test/testcases/restful_api/test_document_raw_routes.py @@ -45,7 +45,7 @@ def test_document_download_by_id_invalid_id_contract(rest_client): assert res.status_code == 200 payload = res.json() assert payload["code"] == 102, payload - assert payload["message"] == "The dataset not own the document invalid_document_id.", payload + assert payload["message"] == "Document not found!", payload @pytest.mark.p2 diff --git a/test/testcases/restful_api/test_documents.py b/test/testcases/restful_api/test_documents.py index 74dbffa7f2..77519f67fb 100644 --- a/test/testcases/restful_api/test_documents.py +++ b/test/testcases/restful_api/test_documents.py @@ -1503,14 +1503,14 @@ def test_documents_download_requires_auth_and_invalid_id_contract(rest_client, c assert invalid_doc_res.status_code == 200 invalid_doc_payload = invalid_doc_res.json() assert invalid_doc_payload["code"] == 102, invalid_doc_payload - assert "The dataset not own the document invalid_document_id." in invalid_doc_payload["message"], invalid_doc_payload + assert invalid_doc_payload["message"] == "Document not found!", invalid_doc_payload invalid_dataset_path = tmp_path / "invalid_dataset_download.txt" invalid_dataset_res = _download_document_to_file(rest_client, "invalid_dataset_id", document_id, invalid_dataset_path) assert invalid_dataset_res.status_code == 200 invalid_dataset_payload = invalid_dataset_res.json() assert invalid_dataset_payload["code"] == 102, invalid_dataset_payload - assert f"The dataset not own the document {document_id}." in invalid_dataset_payload["message"], invalid_dataset_payload + assert invalid_dataset_payload["message"] == "Document not found!", invalid_dataset_payload @pytest.mark.p2 diff --git a/test/testcases/test_web_api/test_canvas_app/test_invoke_component_unit.py b/test/testcases/test_web_api/test_canvas_app/test_invoke_component_unit.py index c541361579..d8d4fbd492 100644 --- a/test/testcases/test_web_api/test_canvas_app/test_invoke_component_unit.py +++ b/test/testcases/test_web_api/test_canvas_app/test_invoke_component_unit.py @@ -271,3 +271,37 @@ def test_header_variable_with_put(monkeypatch): monkeypatch.setattr(module.requests, "put", mock_put) invoke._invoke() assert mock_put.call_args[1]["headers"]["Authorization"] == "Bearer put_token" + + +@pytest.mark.p2 +def test_invoke_rejects_private_url(monkeypatch): + module = _load_invoke_module(monkeypatch) + invoke = _make_invoke(module, url="http://127.0.0.1:22") + monkeypatch.setattr( + module, + "assert_url_is_safe", + MagicMock(side_effect=ValueError("URL resolves to a non-public address")), + ) + # Coderabbit MAJOR #3486038793: _build_url() is now inside the retry + # try/except block, so the ValueError from assert_url_is_safe is caught + # and the message is stored in _ERROR via the standard error path. + invoke._invoke() + assert "non-public address" in invoke._param.outputs["_ERROR"] + + +@pytest.mark.p2 +def test_invoke_pins_dns_and_disables_redirects(monkeypatch): + module = _load_invoke_module(monkeypatch) + invoke = _make_invoke(module, url="http://example.com") + pin_ctx = MagicMock() + pin_ctx.__enter__ = MagicMock(return_value=None) + pin_ctx.__exit__ = MagicMock(return_value=None) + monkeypatch.setattr(module, "assert_url_is_safe", MagicMock(return_value=("example.com", "93.184.216.34"))) + monkeypatch.setattr(module, "pin_dns", MagicMock(return_value=pin_ctx)) + mock_get = MagicMock(return_value=SimpleNamespace(text="ok")) + monkeypatch.setattr(module.requests, "get", mock_get) + + invoke._invoke() + + module.pin_dns.assert_called_once_with("example.com", "93.184.216.34") + assert mock_get.call_args[1]["allow_redirects"] is False diff --git a/test/testcases/test_web_api/test_document_app/test_document_metadata.py b/test/testcases/test_web_api/test_document_app/test_document_metadata.py index a8821c0d24..1a585695fb 100644 --- a/test/testcases/test_web_api/test_document_app/test_document_metadata.py +++ b/test/testcases/test_web_api/test_document_app/test_document_metadata.py @@ -450,6 +450,23 @@ class TestDocumentMetadataUnit: assert res["code"] == 500 assert "download boom" in res["message"] + def test_download_document_rejects_other_tenant_unit(self, document_rest_api_module, monkeypatch): + module = document_rest_api_module + monkeypatch.setattr(module.DocumentService, "accessible", lambda _doc_id, _user_id: False) + + res = _run(module.download_document("doc1")) + assert res["code"] == RetCode.DATA_ERROR + assert "Document not found!" in res["message"] + + def test_dataset_document_download_rejects_other_tenant_unit(self, document_rest_api_module, monkeypatch): + module = document_rest_api_module + monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda kb_id, user_id: False) + monkeypatch.setattr(module.DocumentService, "accessible", lambda _doc_id, _user_id: True) + + res = _run(module.download("kb1", "doc1")) + assert res["code"] == RetCode.DATA_ERROR + assert "Document not found!" in res["message"] + @pytest.mark.p2 def test_get_document_image_content_type_from_object_extension_unit(self, document_app_module, monkeypatch): module = document_app_module diff --git a/test/unit_test/data_source/test_rest_api_connector.py b/test/unit_test/data_source/test_rest_api_connector.py index d2af14e917..8ae1d2f769 100644 --- a/test/unit_test/data_source/test_rest_api_connector.py +++ b/test/unit_test/data_source/test_rest_api_connector.py @@ -14,7 +14,7 @@ # limitations under the License. # -from contextlib import contextmanager +from contextlib import contextmanager, nullcontext from unittest.mock import MagicMock, patch from urllib.parse import urlparse @@ -211,6 +211,95 @@ class TestSSRFValidation: with pytest.raises(ConnectorValidationError, match="scheme"): _make_connector(url="file:///etc/passwd") + def test_redirect_to_loopback_rejected(self): + """Redirect targets must be revalidated before they are fetched. + + Exercise ``_safe_request`` directly rather than ``_fetch_page``: the + latter is wrapped by ``@retry_builder`` and in some CI environments + the retry path on the first ``ConnectorValidationError`` exhausts the + ``side_effect`` and lets the loop run all 6 iterations, surfacing + ``Exceeded 5 redirects`` instead of the expected ``loopback blocked``. + ``_safe_request`` is the actual unit under test for redirect SSRF + handling, so testing it directly is the more faithful check. + + Note on the DNS mock: ``_mocked_rest_api_requests_and_dns`` uses a + fixed ``return_value`` for ``socket.getaddrinfo``. With a constant + return value, a redirect to ``127.0.0.1`` would be reported as + resolving to ``93.184.216.34`` (a public address) and would slip + through ``is_global`` checks. To exercise the actual rejection path, + we override the patched ``getaddrinfo`` here to return the literal + loopback address for the loopback hostname. + """ + connector = _make_connector() + first = _mock_response([], status_code=302) + first.headers = {"Location": "http://127.0.0.1/secret"} + + def _dns_for_host(host, *args, **kwargs): + # Mirror _MOCK_DNS_ADDRINFO shape: (family, type, proto, canon, sockaddr). + if host == "127.0.0.1": + return [(2, 1, 6, "", ("127.0.0.1", 0))] + return list(_MOCK_DNS_ADDRINFO) + + with _mocked_rest_api_requests_and_dns() as mock_rl: + mock_rl.get.return_value = first + # The ``_mocked_rest_api_requests_and_dns`` context manager already + # patches ``socket.getaddrinfo`` with a constant ``return_value``; + # replace it here with a side_effect that distinguishes loopback + # from public hostnames so the SSRF guard actually rejects the + # redirect target. + import common.data_source.rest_api_connector as rc_module + from unittest.mock import patch as _patch + with _patch.object(rc_module.socket, "getaddrinfo", side_effect=_dns_for_host): + # Coderabbit MAJOR #3486038795: SSRF validation failures inside + # _safe_request are now wrapped to raise ConnectorValidationError + # (the connector's documented error contract) instead of leaking + # raw ValueError from ssrf_guard. + with pytest.raises(ConnectorValidationError, match=r"non-public address|loopback"): + connector._safe_request( + "GET", + connector.url, + headers={}, + params={}, + ) + + @patch("common.data_source.rest_api_connector.assert_url_is_safe") + @patch("common.data_source.rest_api_connector.pin_dns") + def test_post_307_preserves_body(self, mock_pin_dns, mock_safe): + """307 redirects should keep method and JSON body.""" + connector = _make_connector(method="POST", request_body={"hello": "world"}) + first = _mock_response([], status_code=307) + first.headers = {"Location": "https://api.example.com/redirected"} + second = _mock_response({"items": []}, status_code=200) + mock_safe.side_effect = [ + ("api.example.com", "93.184.216.34"), + ("api.example.com", "93.184.216.34"), + ] + mock_pin_dns.return_value = nullcontext() + + with _mocked_rest_api_requests_and_dns() as mock_rl: + mock_rl.post.side_effect = [first, second] + connector._fetch_page({}) + + assert mock_rl.post.call_count == 2 + assert mock_rl.post.call_args_list[0].kwargs["json"] == {"hello": "world"} + assert mock_rl.post.call_args_list[1].kwargs["json"] == {"hello": "world"} + assert mock_rl.post.call_args_list[1].kwargs["allow_redirects"] is False + + @patch("common.data_source.rest_api_connector.assert_url_is_safe") + @patch("common.data_source.rest_api_connector.pin_dns") + def test_exceeds_max_redirects_raises(self, mock_pin_dns, mock_safe): + """Too many redirects should raise a connector validation error.""" + connector = _make_connector() + redirect = _mock_response([], status_code=302) + redirect.headers = {"Location": "https://api.example.com/next"} + mock_safe.side_effect = [("api.example.com", "93.184.216.34")] * 6 + mock_pin_dns.return_value = nullcontext() + + with _mocked_rest_api_requests_and_dns() as mock_rl: + mock_rl.get.side_effect = [redirect] * 6 + with pytest.raises(ConnectorValidationError, match="Exceeded 5 redirects"): + connector._fetch_page({}) + # ===================================================================== # # 3. Authentication setup # diff --git a/web/src/components/highlight-markdown/__tests__/index.test.tsx b/web/src/components/highlight-markdown/__tests__/index.test.tsx new file mode 100644 index 0000000000..5e3db50b9f --- /dev/null +++ b/web/src/components/highlight-markdown/__tests__/index.test.tsx @@ -0,0 +1,96 @@ +import { render } from '@testing-library/react'; +import React from 'react'; + +import HighLightMarkdown from '..'; + +jest.mock('@/constants/markdown-remark-plugins', () => ({ + MarkdownRemarkPlugins: [], +})); + +// Coderabbit MAJOR #3486038797: the previous mock rendered react-markdown's +// output as a plain
containing `children` as text, so the spec never +// exercised rehypeRaw or the post-preprocessLaTeX sanitization path. With +// that mock, `safe` was just text inside a div, and an entity-encoded +// `` payload would never reach the DOM no matter what the +// component did — masking the exact bypass DOMPurify is meant to catch. +// +// We mock react-markdown to render `children` as raw HTML (via +// dangerouslySetInnerHTML). This mimics the real pipeline: if the component +// fails to sanitize (e.g. sanitizes BEFORE preprocessLaTeX), the unsafe HTML +// will reach this mock and be inserted into the DOM, failing the assertions. +jest.mock('react-markdown', () => ({ + __esModule: true, + default: ({ children }: any) => { + const ReactLib = jest.requireActual('react'); + return ReactLib.createElement('div', { + dangerouslySetInnerHTML: { __html: children }, + }); + }, +})); + +jest.mock('react-syntax-highlighter', () => ({ + Prism: ({ children }: any) => { + const react = jest.requireActual('react'); + return react.createElement('pre', null, children); + }, +})); + +jest.mock('react-syntax-highlighter/dist/esm/styles/prism', () => ({ + oneDark: {}, + oneLight: {}, +})); + +jest.mock('rehype-katex', () => jest.fn()); +jest.mock('rehype-raw', () => jest.fn()); + +jest.mock('../../theme-provider', () => ({ + useIsDarkTheme: () => false, +})); + +describe('HighLightMarkdown', () => { + it('sanitizes unsafe html before rendering', () => { + const { container } = render( + React.createElement( + HighLightMarkdown, + null, + 'hello safe', + ), + ); + + // safe is allowed by DOMPurify default profile, so it should + // render as a real element with text "safe". + expect(container.querySelector('b')?.textContent).toBe('safe'); + + //