Files
ragflow/test/unit_test/agent/tools/test_exesql_ssrf.py
philluiz2323 e256d91ade fix: guard SSRF in ExeSQL agent tool DB host (#15609)
### What problem does this PR solve?

Closes #15608.

The ExeSQL agent tool (`agent/tools/exesql.py`) opens database
connections to a node-author-controlled host/port with no SSRF
validation. The sibling `test_db_connection` endpoint already validates
the host via `common.ssrf_guard.assert_host_is_safe` (added by PR
#14860), but the tool that actually performs the connection at agent run
time was left unguarded — so the guard is bypassed simply by running the
agent. An agent author can point the host at `127.0.0.1`,
`169.254.169.254` (cloud metadata), or any internal RFC1918 host/port,
turning ExeSQL into an internal port-scanner / metadata-fetch primitive.

### Fix

Mirror the accepted endpoint guard: validate (and resolve) the host
once, before the `db_type` dispatch, and connect to the validated public
IP so a later DNS change cannot rebind the host to an internal address.

- Add `from common.ssrf_guard import assert_host_is_safe`.
- Call `safe_host = assert_host_is_safe(self._param.host)` before the
dispatch; it rejects loopback, link-local/metadata, RFC1918, and
unresolvable hosts.
- Substitute the validated IP into all 6 driver branches: mysql/mariadb,
oceanbase, postgres, mssql, trino, IBM DB2.

Adds `test/unit_test/agent/tools/test_exesql_ssrf.py` covering loopback,
link-local/metadata, RFC1918, and empty-host rejection (before any
connection), plus an allowed host dialing the validated IP.

### Validation

- `python3 -m py_compile agent/tools/exesql.py`
- `ruff check agent/tools/exesql.py test/unit_test/agent/tools/test_exesql_ssrf.py`
- `pytest test/unit_test/agent/tools/test_exesql_ssrf.py` — 5 passed

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

---------

Co-authored-by: Zhichang Yu <yuzhichang@gmail.com>
2026-06-29 09:45:16 +08:00

144 lines
4.8 KiB
Python

#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""SSRF-guard regression tests for the ExeSQL agent tool.
The DB host/port are node-author-controlled and connected to server-side, so
``ExeSQL._invoke`` must reject hosts that resolve to non-public addresses
(loopback, link-local/metadata, RFC1918) before opening any connection, and
must dial the validated/resolved public IP for allowed hosts — mirroring the
``test_db_connection`` endpoint guard (PR #14860).
``agent.tools.exesql`` is loaded in isolation (its package ``__init__`` would
auto-discover every tool and pull in the full agent framework), with the heavy
DB drivers and the agent base classes stubbed so only the real SSRF guard runs.
"""
import importlib.util
import sys
import types
from pathlib import Path
from types import SimpleNamespace
import pytest
# Repository root: this file lives at test/unit_test/agent/tools/, so four
# directory levels above its parent directory is the checkout root.
_REPO_ROOT = Path(__file__).resolve().parents[4]
class _RecordingPyMySQL:
    """Stand-in for ``pymysql`` that remembers which host was requested.

    ``connect`` never opens a socket: it stores the ``host`` keyword argument
    on ``dialed_host`` and then raises, so a test can tell both whether a
    connection was attempted and which address it targeted.
    """

    def __init__(self):
        # No host has been recorded yet.
        self.dialed_host = None

    def connect(self, *args, **kwargs):
        """Record the ``host`` keyword (``None`` if absent), then raise.

        Raises:
            RuntimeError: always, so that no real database I/O can follow.
        """
        requested_host = kwargs.get("host")
        self.dialed_host = requested_host
        raise RuntimeError("connection attempted")
# Single shared recorder: its connect() is what the stubbed ``pymysql`` module
# exposes, so the tests read ``dialed_host`` from it after calling ``_invoke``.
_fake_pymysql = _RecordingPyMySQL()
def _load_exesql_module():
    """Load ``agent/tools/exesql.py`` in isolation and return the module.

    While the file is executed, the DB drivers, the agent base classes and the
    ``@timeout`` decorator are replaced by lightweight stubs so that the module
    imports cleanly without the full agent framework.

    Every ``sys.modules`` entry touched here is put back afterwards (or removed
    if it did not exist before). Without that, the stubs would stay installed
    for the rest of the pytest session and be picked up by unrelated tests that
    import ``pymysql``, ``pandas``, ``agent`` or ``common.connection_utils``.

    Returns:
        The executed module object (named ``exesql_uut``). It is not registered
        in ``sys.modules``.

    Raises:
        ImportError: if no import spec/loader can be built for the file.
    """
    absent = object()  # sentinel meaning "this name was not in sys.modules"
    stubbed_names = (
        "pandas",
        "psycopg2",
        "pyodbc",
        "pymysql",
        "agent",
        "agent.tools",
        "agent.tools.base",
        "common.connection_utils",
    )
    previous = {name: sys.modules.get(name, absent) for name in stubbed_names}

    try:
        # Heavy drivers: a stub is installed only when the real module has not
        # been imported already.
        for name in ("pandas", "psycopg2", "pyodbc"):
            driver_stub = types.ModuleType(name)
            driver_stub.connect = lambda *a, **k: None
            sys.modules.setdefault(name, driver_stub)

        # pymysql is always replaced so connect() goes to the shared recorder.
        pymysql_stub = types.ModuleType("pymysql")
        pymysql_stub.connect = _fake_pymysql.connect
        sys.modules["pymysql"] = pymysql_stub

        # Minimal agent base classes; ExeSQL only needs them to be importable.
        base = types.ModuleType("agent.tools.base")

        class _ToolParamBase:
            def __init__(self):
                pass

        class _ToolBase:
            def __init__(self, *a, **k):
                pass

        base.ToolParamBase = _ToolParamBase
        base.ToolBase = _ToolBase
        base.ToolMeta = dict
        for pkg in ("agent", "agent.tools"):
            sys.modules.setdefault(pkg, types.ModuleType(pkg))
        sys.modules["agent.tools.base"] = base

        # Neutralize the @timeout decorator so _invoke is a plain method.
        conn_utils = types.ModuleType("common.connection_utils")
        conn_utils.timeout = lambda *a, **k: (lambda f: f)
        sys.modules["common.connection_utils"] = conn_utils

        spec = importlib.util.spec_from_file_location(
            "exesql_uut", _REPO_ROOT / "agent" / "tools" / "exesql.py"
        )
        if spec is None or spec.loader is None:
            raise ImportError("cannot build an import spec for agent/tools/exesql.py")
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        return module
    finally:
        # NOTE(review): restoring is safe only if exesql.py binds the stubbed
        # names at import time (top-level imports); confirm it does not
        # re-import pymysql or the other stubs lazily inside _invoke.
        for name, original in previous.items():
            if original is absent:
                sys.modules.pop(name, None)
            else:
                sys.modules[name] = original
# Load the unit under test once, when this file is imported; every test below
# shares this module object and the ExeSQL class taken from it.
_exesql_mod = _load_exesql_module()
ExeSQL = _exesql_mod.ExeSQL
def _build_exesql(host, db_type="mysql"):
    """Return an ``ExeSQL`` instance aimed at *host*, bypassing ``__init__``.

    Args:
        host: value stored as ``_param.host``; this is what the guard checks.
        db_type: value stored as ``_param.db_type`` (defaults to ``"mysql"``).

    Returns:
        An ``ExeSQL`` object created with ``__new__`` whose ``_canvas`` and
        ``_param`` are plain namespaces and whose pre-check hooks are inert.
    """
    connection_params = {
        "host": host,
        "port": 3306,
        "db_type": db_type,
        "database": "db",
        "username": "u",
        "password": "p",
    }

    component = ExeSQL.__new__(ExeSQL)
    component._canvas = SimpleNamespace()
    component._param = SimpleNamespace(**connection_params)

    # Replace the component machinery that runs before the host check with
    # stand-ins: never canceled, no input elements, SQL passed through as-is.
    component.check_if_canceled = lambda *_a, **_k: False
    component.get_input_elements_from_text = lambda _sql: {}
    component.set_input_value = lambda *_a, **_k: None
    component.string_format = lambda sql, _args: sql
    return component
@pytest.mark.p2
@pytest.mark.parametrize("host", ["127.0.0.1", "169.254.169.254", "10.0.0.5"])
def test_internal_host_rejected_before_connect(host):
    """Loopback, link-local/metadata and RFC1918 hosts fail before connect()."""
    component = _build_exesql(host)
    # Forget whatever an earlier test may have recorded on the shared fake.
    _fake_pymysql.dialed_host = None

    with pytest.raises(Exception) as excinfo:
        component._invoke(sql="SELECT 1")

    message = str(excinfo.value)
    assert "not allowed" in message
    # The recorder is untouched, so the guard fired before any connection.
    assert _fake_pymysql.dialed_host is None
@pytest.mark.p2
def test_empty_host_rejected():
    """An empty host is refused by the guard before any connection attempt."""
    # Start from a clean recorder: another test may have left a host behind on
    # the shared fake, which would make the final assertion meaningless.
    _fake_pymysql.dialed_host = None
    cpn = _build_exesql("")
    with pytest.raises(Exception) as ei:
        cpn._invoke(sql="SELECT 1")
    assert "not allowed" in str(ei.value)
    # Same guarantee as for internal hosts: rejection precedes connect().
    assert _fake_pymysql.dialed_host is None
@pytest.mark.p2
def test_public_host_dials_validated_ip(monkeypatch):
    """An allowed host is dialed via the IP returned by the guard.

    The guard is replaced so that it reports a fixed public IP; the driver must
    then be called with that IP rather than with the raw hostname.
    """
    validated_ip = "93.184.216.34"

    def _guard_returning_public_ip(_h):
        return validated_ip

    monkeypatch.setattr(_exesql_mod, "assert_host_is_safe", _guard_returning_public_ip)

    component = _build_exesql("db.example.com")
    _fake_pymysql.dialed_host = None

    # The recording connect() raises RuntimeError once it has stored the host.
    with pytest.raises(Exception):
        component._invoke(sql="SELECT 1")

    assert _fake_pymysql.dialed_host == validated_ip