mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-07-01 16:25:44 +08:00
Harden closed-advisory fixes (#16409)
## Summary - harden reopened advisory fixes across REST connector, invoke, document downloads, and markdown rendering - add targeted regression coverage for redirect-safe SSRF handling, invoke SSRF checks, document access control, and markdown sanitization - verify each referenced GHSA against the original GitHub advisory text and align the closed-advisory plan with the implemented remediation ## What changed - add tenant access checks to document download endpoints to avoid cross-tenant document disclosure - add per-hop SSRF validation, DNS pinning, redirect handling, and redirect limits to the REST API connector - ensure invoke requests validate and pin the resolved host and never follow redirects implicitly - keep the generic rate-limited request path wrapped, not just GET and POST helpers - sanitize markdown HTML before rendering in the highlight markdown component ## Validation - `cd web && npm test -- --runInBand src/components/highlight-markdown/__tests__/index.test.tsx` - `.venv/bin/python -m pytest -q test/unit_test/data_source/test_rest_api_connector.py` - targeted `test/testcases/test_web_api/...` unit additions were reviewed, but the suite cannot be executed end-to-end in this environment because parent `test/testcases/conftest.py` requires a local service on `127.0.0.1:9380` ## Notes - all GHSA entries referenced by the plan were checked against the original GitHub advisory text, not sampled - the closed-advisory plan document was updated locally during review, but is intentionally not included in this PR
This commit is contained in:
@@ -45,7 +45,7 @@ def test_document_download_by_id_invalid_id_contract(rest_client):
|
||||
assert res.status_code == 200
|
||||
payload = res.json()
|
||||
assert payload["code"] == 102, payload
|
||||
assert payload["message"] == "The dataset not own the document invalid_document_id.", payload
|
||||
assert payload["message"] == "Document not found!", payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
|
||||
@@ -1503,14 +1503,14 @@ def test_documents_download_requires_auth_and_invalid_id_contract(rest_client, c
|
||||
assert invalid_doc_res.status_code == 200
|
||||
invalid_doc_payload = invalid_doc_res.json()
|
||||
assert invalid_doc_payload["code"] == 102, invalid_doc_payload
|
||||
assert "The dataset not own the document invalid_document_id." in invalid_doc_payload["message"], invalid_doc_payload
|
||||
assert invalid_doc_payload["message"] == "Document not found!", invalid_doc_payload
|
||||
|
||||
invalid_dataset_path = tmp_path / "invalid_dataset_download.txt"
|
||||
invalid_dataset_res = _download_document_to_file(rest_client, "invalid_dataset_id", document_id, invalid_dataset_path)
|
||||
assert invalid_dataset_res.status_code == 200
|
||||
invalid_dataset_payload = invalid_dataset_res.json()
|
||||
assert invalid_dataset_payload["code"] == 102, invalid_dataset_payload
|
||||
assert f"The dataset not own the document {document_id}." in invalid_dataset_payload["message"], invalid_dataset_payload
|
||||
assert invalid_dataset_payload["message"] == "Document not found!", invalid_dataset_payload
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
|
||||
@@ -271,3 +271,37 @@ def test_header_variable_with_put(monkeypatch):
|
||||
monkeypatch.setattr(module.requests, "put", mock_put)
|
||||
invoke._invoke()
|
||||
assert mock_put.call_args[1]["headers"]["Authorization"] == "Bearer put_token"
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_invoke_rejects_private_url(monkeypatch):
|
||||
module = _load_invoke_module(monkeypatch)
|
||||
invoke = _make_invoke(module, url="http://127.0.0.1:22")
|
||||
monkeypatch.setattr(
|
||||
module,
|
||||
"assert_url_is_safe",
|
||||
MagicMock(side_effect=ValueError("URL resolves to a non-public address")),
|
||||
)
|
||||
# Coderabbit MAJOR #3486038793: _build_url() is now inside the retry
|
||||
# try/except block, so the ValueError from assert_url_is_safe is caught
|
||||
# and the message is stored in _ERROR via the standard error path.
|
||||
invoke._invoke()
|
||||
assert "non-public address" in invoke._param.outputs["_ERROR"]
|
||||
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_invoke_pins_dns_and_disables_redirects(monkeypatch):
|
||||
module = _load_invoke_module(monkeypatch)
|
||||
invoke = _make_invoke(module, url="http://example.com")
|
||||
pin_ctx = MagicMock()
|
||||
pin_ctx.__enter__ = MagicMock(return_value=None)
|
||||
pin_ctx.__exit__ = MagicMock(return_value=None)
|
||||
monkeypatch.setattr(module, "assert_url_is_safe", MagicMock(return_value=("example.com", "93.184.216.34")))
|
||||
monkeypatch.setattr(module, "pin_dns", MagicMock(return_value=pin_ctx))
|
||||
mock_get = MagicMock(return_value=SimpleNamespace(text="ok"))
|
||||
monkeypatch.setattr(module.requests, "get", mock_get)
|
||||
|
||||
invoke._invoke()
|
||||
|
||||
module.pin_dns.assert_called_once_with("example.com", "93.184.216.34")
|
||||
assert mock_get.call_args[1]["allow_redirects"] is False
|
||||
|
||||
@@ -450,6 +450,23 @@ class TestDocumentMetadataUnit:
|
||||
assert res["code"] == 500
|
||||
assert "download boom" in res["message"]
|
||||
|
||||
def test_download_document_rejects_other_tenant_unit(self, document_rest_api_module, monkeypatch):
|
||||
module = document_rest_api_module
|
||||
monkeypatch.setattr(module.DocumentService, "accessible", lambda _doc_id, _user_id: False)
|
||||
|
||||
res = _run(module.download_document("doc1"))
|
||||
assert res["code"] == RetCode.DATA_ERROR
|
||||
assert "Document not found!" in res["message"]
|
||||
|
||||
def test_dataset_document_download_rejects_other_tenant_unit(self, document_rest_api_module, monkeypatch):
|
||||
module = document_rest_api_module
|
||||
monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda kb_id, user_id: False)
|
||||
monkeypatch.setattr(module.DocumentService, "accessible", lambda _doc_id, _user_id: True)
|
||||
|
||||
res = _run(module.download("kb1", "doc1"))
|
||||
assert res["code"] == RetCode.DATA_ERROR
|
||||
assert "Document not found!" in res["message"]
|
||||
|
||||
@pytest.mark.p2
|
||||
def test_get_document_image_content_type_from_object_extension_unit(self, document_app_module, monkeypatch):
|
||||
module = document_app_module
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from contextlib import contextmanager
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from unittest.mock import MagicMock, patch
|
||||
from urllib.parse import urlparse
|
||||
|
||||
@@ -211,6 +211,95 @@ class TestSSRFValidation:
|
||||
with pytest.raises(ConnectorValidationError, match="scheme"):
|
||||
_make_connector(url="file:///etc/passwd")
|
||||
|
||||
def test_redirect_to_loopback_rejected(self):
|
||||
"""Redirect targets must be revalidated before they are fetched.
|
||||
|
||||
Exercise ``_safe_request`` directly rather than ``_fetch_page``: the
|
||||
latter is wrapped by ``@retry_builder`` and in some CI environments
|
||||
the retry path on the first ``ConnectorValidationError`` exhausts the
|
||||
``side_effect`` and lets the loop run all 6 iterations, surfacing
|
||||
``Exceeded 5 redirects`` instead of the expected ``loopback blocked``.
|
||||
``_safe_request`` is the actual unit under test for redirect SSRF
|
||||
handling, so testing it directly is the more faithful check.
|
||||
|
||||
Note on the DNS mock: ``_mocked_rest_api_requests_and_dns`` uses a
|
||||
fixed ``return_value`` for ``socket.getaddrinfo``. With a constant
|
||||
return value, a redirect to ``127.0.0.1`` would be reported as
|
||||
resolving to ``93.184.216.34`` (a public address) and would slip
|
||||
through ``is_global`` checks. To exercise the actual rejection path,
|
||||
we override the patched ``getaddrinfo`` here to return the literal
|
||||
loopback address for the loopback hostname.
|
||||
"""
|
||||
connector = _make_connector()
|
||||
first = _mock_response([], status_code=302)
|
||||
first.headers = {"Location": "http://127.0.0.1/secret"}
|
||||
|
||||
def _dns_for_host(host, *args, **kwargs):
|
||||
# Mirror _MOCK_DNS_ADDRINFO shape: (family, type, proto, canon, sockaddr).
|
||||
if host == "127.0.0.1":
|
||||
return [(2, 1, 6, "", ("127.0.0.1", 0))]
|
||||
return list(_MOCK_DNS_ADDRINFO)
|
||||
|
||||
with _mocked_rest_api_requests_and_dns() as mock_rl:
|
||||
mock_rl.get.return_value = first
|
||||
# The ``_mocked_rest_api_requests_and_dns`` context manager already
|
||||
# patches ``socket.getaddrinfo`` with a constant ``return_value``;
|
||||
# replace it here with a side_effect that distinguishes loopback
|
||||
# from public hostnames so the SSRF guard actually rejects the
|
||||
# redirect target.
|
||||
import common.data_source.rest_api_connector as rc_module
|
||||
from unittest.mock import patch as _patch
|
||||
with _patch.object(rc_module.socket, "getaddrinfo", side_effect=_dns_for_host):
|
||||
# Coderabbit MAJOR #3486038795: SSRF validation failures inside
|
||||
# _safe_request are now wrapped to raise ConnectorValidationError
|
||||
# (the connector's documented error contract) instead of leaking
|
||||
# raw ValueError from ssrf_guard.
|
||||
with pytest.raises(ConnectorValidationError, match=r"non-public address|loopback"):
|
||||
connector._safe_request(
|
||||
"GET",
|
||||
connector.url,
|
||||
headers={},
|
||||
params={},
|
||||
)
|
||||
|
||||
@patch("common.data_source.rest_api_connector.assert_url_is_safe")
|
||||
@patch("common.data_source.rest_api_connector.pin_dns")
|
||||
def test_post_307_preserves_body(self, mock_pin_dns, mock_safe):
|
||||
"""307 redirects should keep method and JSON body."""
|
||||
connector = _make_connector(method="POST", request_body={"hello": "world"})
|
||||
first = _mock_response([], status_code=307)
|
||||
first.headers = {"Location": "https://api.example.com/redirected"}
|
||||
second = _mock_response({"items": []}, status_code=200)
|
||||
mock_safe.side_effect = [
|
||||
("api.example.com", "93.184.216.34"),
|
||||
("api.example.com", "93.184.216.34"),
|
||||
]
|
||||
mock_pin_dns.return_value = nullcontext()
|
||||
|
||||
with _mocked_rest_api_requests_and_dns() as mock_rl:
|
||||
mock_rl.post.side_effect = [first, second]
|
||||
connector._fetch_page({})
|
||||
|
||||
assert mock_rl.post.call_count == 2
|
||||
assert mock_rl.post.call_args_list[0].kwargs["json"] == {"hello": "world"}
|
||||
assert mock_rl.post.call_args_list[1].kwargs["json"] == {"hello": "world"}
|
||||
assert mock_rl.post.call_args_list[1].kwargs["allow_redirects"] is False
|
||||
|
||||
@patch("common.data_source.rest_api_connector.assert_url_is_safe")
|
||||
@patch("common.data_source.rest_api_connector.pin_dns")
|
||||
def test_exceeds_max_redirects_raises(self, mock_pin_dns, mock_safe):
|
||||
"""Too many redirects should raise a connector validation error."""
|
||||
connector = _make_connector()
|
||||
redirect = _mock_response([], status_code=302)
|
||||
redirect.headers = {"Location": "https://api.example.com/next"}
|
||||
mock_safe.side_effect = [("api.example.com", "93.184.216.34")] * 6
|
||||
mock_pin_dns.return_value = nullcontext()
|
||||
|
||||
with _mocked_rest_api_requests_and_dns() as mock_rl:
|
||||
mock_rl.get.side_effect = [redirect] * 6
|
||||
with pytest.raises(ConnectorValidationError, match="Exceeded 5 redirects"):
|
||||
connector._fetch_page({})
|
||||
|
||||
|
||||
# ===================================================================== #
|
||||
# 3. Authentication setup #
|
||||
|
||||
Reference in New Issue
Block a user