fix: guard SSRF in ExeSQL agent tool DB host (#15609)

### What problem does this PR solve?

Closes #15608.

The ExeSQL agent tool (`agent/tools/exesql.py`) opens database
connections to a node-author-controlled host/port with no SSRF
validation. The sibling `test_db_connection` endpoint already validates
the host via `common.ssrf_guard.assert_host_is_safe` (added by PR
#14860), but the tool that actually performs the connection at agent run
time was left unguarded — so the guard is bypassed simply by running the
agent. An agent author can point the host at `127.0.0.1`,
`169.254.169.254` (cloud metadata), or any internal RFC1918 host/port,
turning ExeSQL into an internal port-scanner / metadata-fetch primitive.

### Fix

Mirror the accepted endpoint guard: validate (and resolve) the host
once, before the `db_type` dispatch, and connect to the validated public
IP so a later DNS change cannot rebind the host to an internal address.

- Add `from common.ssrf_guard import assert_host_is_safe`.
- `safe_host = assert_host_is_safe(self._param.host)` before the
dispatch (rejects loopback, link-local/metadata, RFC1918, and
unresolvable hosts).
- Substitute the validated IP into all 6 driver branches: mysql/mariadb,
oceanbase, postgres, mssql, trino, IBM DB2.

Adds `test/unit_test/agent/tools/test_exesql_ssrf.py` covering loopback,
link-local/metadata, RFC1918, and empty-host rejection (before any
connection), plus an allowed host dialing the validated IP.

### Validation

- `python3 -m py_compile agent/tools/exesql.py`
- `ruff check agent/tools/exesql.py
test/unit_test/agent/tools/test_exesql_ssrf.py`
- `pytest test/unit_test/agent/tools/test_exesql_ssrf.py` — 5 passed

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

---------

Co-authored-by: Zhichang Yu <yuzhichang@gmail.com>
This commit is contained in:
philluiz2323
2026-06-27 21:11:21 -07:00
committed by yzc
parent 0d7ad0ed0c
commit e256d91ade
2 changed files with 164 additions and 6 deletions

View File

@@ -15,6 +15,7 @@
#
import contextlib
import json
import logging
import os
import re
from abc import ABC
@@ -24,6 +25,7 @@ import psycopg2
import pyodbc
from agent.tools.base import ToolParamBase, ToolBase, ToolMeta
from common.connection_utils import timeout
from common.ssrf_guard import assert_host_is_safe
class ExeSQLParam(ToolParamBase):
@@ -123,20 +125,33 @@ class ExeSQL(ToolBase, ABC):
if self.check_if_canceled("ExeSQL processing"):
return
# The DB host/port are node-author-controlled and are connected to
# server-side, so guard against SSRF (internal hosts, loopback, cloud
# metadata) the same way the `test_db_connection` endpoint does. Connect
# to the validated, resolved public IP so a later DNS change cannot
# rebind the host to an internal address (mirrors agent_api.py).
logging.info(f"ExeSQL validating database host: {self._param.host}")
try:
safe_host = assert_host_is_safe(self._param.host)
except ValueError as e:
logging.warning(f"ExeSQL rejected database host {self._param.host}: {e}")
raise Exception(f"Database host '{self._param.host}' is not allowed: {e}")
logging.info(f"ExeSQL validated database host {self._param.host} -> {safe_host}")
sqls = sql.split(";")
if self._param.db_type in ["mysql", "mariadb"]:
db = pymysql.connect(db=self._param.database, user=self._param.username, host=self._param.host,
db = pymysql.connect(db=self._param.database, user=self._param.username, host=safe_host,
port=self._param.port, password=self._param.password)
elif self._param.db_type == 'oceanbase':
db = pymysql.connect(db=self._param.database, user=self._param.username, host=self._param.host,
db = pymysql.connect(db=self._param.database, user=self._param.username, host=safe_host,
port=self._param.port, password=self._param.password, charset='utf8mb4')
elif self._param.db_type == 'postgres':
db = psycopg2.connect(dbname=self._param.database, user=self._param.username, host=self._param.host,
db = psycopg2.connect(dbname=self._param.database, user=self._param.username, host=safe_host,
port=self._param.port, password=self._param.password)
elif self._param.db_type == 'mssql':
conn_str = (
r'DRIVER={ODBC Driver 17 for SQL Server};'
r'SERVER=' + self._param.host + ',' + str(self._param.port) + ';'
r'SERVER=' + safe_host + ',' + str(self._param.port) + ';'
r'DATABASE=' + self._param.database + ';'
r'UID=' + self._param.username + ';'
r'PWD=' + self._param.password
@@ -171,7 +186,7 @@ class ExeSQL(ToolBase, ABC):
try:
db = trino.dbapi.connect(
host=self._param.host,
host=safe_host,
port=int(self._param.port or 8080),
user=self._param.username or "ragflow",
catalog=catalog,
@@ -185,7 +200,7 @@ class ExeSQL(ToolBase, ABC):
import ibm_db
conn_str = (
f"DATABASE={self._param.database};"
f"HOSTNAME={self._param.host};"
f"HOSTNAME={safe_host};"
f"PORT={self._param.port};"
f"PROTOCOL=TCPIP;"
f"UID={self._param.username};"

View File

@@ -0,0 +1,143 @@
#
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""SSRF-guard regression tests for the ExeSQL agent tool.
The DB host/port are node-author-controlled and connected to server-side, so
``ExeSQL._invoke`` must reject hosts that resolve to non-public addresses
(loopback, link-local/metadata, RFC1918) before opening any connection, and
must dial the validated/resolved public IP for allowed hosts — mirroring the
``test_db_connection`` endpoint guard (PR #14860).
``agent.tools.exesql`` is loaded in isolation (its package ``__init__`` would
auto-discover every tool and pull in the full agent framework), with the heavy
DB drivers and the agent base classes stubbed so only the real SSRF guard runs.
"""
import importlib.util
import sys
import types
from pathlib import Path
from types import SimpleNamespace
import pytest
_REPO_ROOT = Path(__file__).resolve().parents[4]
class _RecordingPyMySQL:
"""Fake pymysql whose connect() records the host it was asked to dial."""
def __init__(self):
self.dialed_host = None
def connect(self, *args, **kwargs):
self.dialed_host = kwargs.get("host")
raise RuntimeError("connection attempted") # stop before real DB I/O
_fake_pymysql = _RecordingPyMySQL()
def _load_exesql_module():
# Stub the heavy drivers and the agent base so the module imports cleanly.
for name in ("pandas", "psycopg2", "pyodbc"):
mod = types.ModuleType(name)
mod.connect = lambda *a, **k: None
sys.modules.setdefault(name, mod)
pymysql_stub = types.ModuleType("pymysql")
pymysql_stub.connect = _fake_pymysql.connect
sys.modules["pymysql"] = pymysql_stub
base = types.ModuleType("agent.tools.base")
class _ToolParamBase:
def __init__(self):
pass
class _ToolBase:
def __init__(self, *a, **k):
pass
base.ToolParamBase = _ToolParamBase
base.ToolBase = _ToolBase
base.ToolMeta = dict
for pkg in ("agent", "agent.tools"):
sys.modules.setdefault(pkg, types.ModuleType(pkg))
sys.modules["agent.tools.base"] = base
# Neutralize the @timeout decorator so _invoke is a plain method.
conn_utils = types.ModuleType("common.connection_utils")
conn_utils.timeout = lambda *a, **k: (lambda f: f)
sys.modules["common.connection_utils"] = conn_utils
spec = importlib.util.spec_from_file_location(
"exesql_uut", _REPO_ROOT / "agent" / "tools" / "exesql.py"
)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
return mod
_exesql_mod = _load_exesql_module()
ExeSQL = _exesql_mod.ExeSQL
def _build_exesql(host, db_type="mysql"):
cpn = ExeSQL.__new__(ExeSQL)
cpn._canvas = SimpleNamespace()
cpn._param = SimpleNamespace(
host=host, port=3306, db_type=db_type,
database="db", username="u", password="p",
)
# Neutralize the component machinery that runs before the host check.
cpn.check_if_canceled = lambda *_a, **_k: False
cpn.get_input_elements_from_text = lambda _sql: {}
cpn.set_input_value = lambda *_a, **_k: None
cpn.string_format = lambda sql, _args: sql
return cpn
@pytest.mark.p2
@pytest.mark.parametrize("host", ["127.0.0.1", "169.254.169.254", "10.0.0.5"])
def test_internal_host_rejected_before_connect(host):
_fake_pymysql.dialed_host = None
cpn = _build_exesql(host)
with pytest.raises(Exception) as ei:
cpn._invoke(sql="SELECT 1")
assert "not allowed" in str(ei.value)
# The SSRF guard must fire before any connection is attempted.
assert _fake_pymysql.dialed_host is None
@pytest.mark.p2
def test_empty_host_rejected():
cpn = _build_exesql("")
with pytest.raises(Exception) as ei:
cpn._invoke(sql="SELECT 1")
assert "not allowed" in str(ei.value)
@pytest.mark.p2
def test_public_host_dials_validated_ip(monkeypatch):
# Public host: pretend it resolves to a public IP, and ensure the driver is
# dialed with that validated IP (not the raw hostname).
monkeypatch.setattr(_exesql_mod, "assert_host_is_safe", lambda _h: "93.184.216.34")
_fake_pymysql.dialed_host = None
cpn = _build_exesql("db.example.com")
with pytest.raises(Exception):
cpn._invoke(sql="SELECT 1") # RuntimeError from the recording connect
assert _fake_pymysql.dialed_host == "93.184.216.34"