mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-06-29 15:31:05 +08:00
fix: guard SSRF in ExeSQL agent tool DB host (#15609)
### What problem does this PR solve? Closes #15608. The ExeSQL agent tool (`agent/tools/exesql.py`) opens database connections to a node-author-controlled host/port with no SSRF validation. The sibling `test_db_connection` endpoint already validates the host via `common.ssrf_guard.assert_host_is_safe` (added by PR #14860), but the tool that actually performs the connection at agent run time was left unguarded — so the guard is bypassed simply by running the agent. An agent author can point the host at `127.0.0.1`, `169.254.169.254` (cloud metadata), or any internal RFC1918 host/port, turning ExeSQL into an internal port-scanner / metadata-fetch primitive. ### Fix Mirror the accepted endpoint guard: validate (and resolve) the host once, before the `db_type` dispatch, and connect to the validated public IP so a later DNS change cannot rebind the host to an internal address. - Add `from common.ssrf_guard import assert_host_is_safe`. - `safe_host = assert_host_is_safe(self._param.host)` before the dispatch (rejects loopback, link-local/metadata, RFC1918, and unresolvable hosts). - Substitute the validated IP into all 6 driver branches: mysql/mariadb, oceanbase, postgres, mssql, trino, IBM DB2. Adds `test/unit_test/agent/tools/test_exesql_ssrf.py` covering loopback, link-local/metadata, RFC1918, and empty-host rejection (before any connection), plus an allowed host dialing the validated IP. ### Validation - `python3 -m py_compile agent/tools/exesql.py` - `ruff check agent/tools/exesql.py test/unit_test/agent/tools/test_exesql_ssrf.py` - `pytest test/unit_test/agent/tools/test_exesql_ssrf.py` — 5 passed ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --------- Co-authored-by: Zhichang Yu <yuzhichang@gmail.com>
This commit is contained in:
@@ -15,6 +15,7 @@
|
||||
#
|
||||
import contextlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from abc import ABC
|
||||
@@ -24,6 +25,7 @@ import psycopg2
|
||||
import pyodbc
|
||||
from agent.tools.base import ToolParamBase, ToolBase, ToolMeta
|
||||
from common.connection_utils import timeout
|
||||
from common.ssrf_guard import assert_host_is_safe
|
||||
|
||||
|
||||
class ExeSQLParam(ToolParamBase):
|
||||
@@ -123,20 +125,33 @@ class ExeSQL(ToolBase, ABC):
|
||||
if self.check_if_canceled("ExeSQL processing"):
|
||||
return
|
||||
|
||||
# The DB host/port are node-author-controlled and are connected to
|
||||
# server-side, so guard against SSRF (internal hosts, loopback, cloud
|
||||
# metadata) the same way the `test_db_connection` endpoint does. Connect
|
||||
# to the validated, resolved public IP so a later DNS change cannot
|
||||
# rebind the host to an internal address (mirrors agent_api.py).
|
||||
logging.info(f"ExeSQL validating database host: {self._param.host}")
|
||||
try:
|
||||
safe_host = assert_host_is_safe(self._param.host)
|
||||
except ValueError as e:
|
||||
logging.warning(f"ExeSQL rejected database host {self._param.host}: {e}")
|
||||
raise Exception(f"Database host '{self._param.host}' is not allowed: {e}")
|
||||
logging.info(f"ExeSQL validated database host {self._param.host} -> {safe_host}")
|
||||
|
||||
sqls = sql.split(";")
|
||||
if self._param.db_type in ["mysql", "mariadb"]:
|
||||
db = pymysql.connect(db=self._param.database, user=self._param.username, host=self._param.host,
|
||||
db = pymysql.connect(db=self._param.database, user=self._param.username, host=safe_host,
|
||||
port=self._param.port, password=self._param.password)
|
||||
elif self._param.db_type == 'oceanbase':
|
||||
db = pymysql.connect(db=self._param.database, user=self._param.username, host=self._param.host,
|
||||
db = pymysql.connect(db=self._param.database, user=self._param.username, host=safe_host,
|
||||
port=self._param.port, password=self._param.password, charset='utf8mb4')
|
||||
elif self._param.db_type == 'postgres':
|
||||
db = psycopg2.connect(dbname=self._param.database, user=self._param.username, host=self._param.host,
|
||||
db = psycopg2.connect(dbname=self._param.database, user=self._param.username, host=safe_host,
|
||||
port=self._param.port, password=self._param.password)
|
||||
elif self._param.db_type == 'mssql':
|
||||
conn_str = (
|
||||
r'DRIVER={ODBC Driver 17 for SQL Server};'
|
||||
r'SERVER=' + self._param.host + ',' + str(self._param.port) + ';'
|
||||
r'SERVER=' + safe_host + ',' + str(self._param.port) + ';'
|
||||
r'DATABASE=' + self._param.database + ';'
|
||||
r'UID=' + self._param.username + ';'
|
||||
r'PWD=' + self._param.password
|
||||
@@ -171,7 +186,7 @@ class ExeSQL(ToolBase, ABC):
|
||||
|
||||
try:
|
||||
db = trino.dbapi.connect(
|
||||
host=self._param.host,
|
||||
host=safe_host,
|
||||
port=int(self._param.port or 8080),
|
||||
user=self._param.username or "ragflow",
|
||||
catalog=catalog,
|
||||
@@ -185,7 +200,7 @@ class ExeSQL(ToolBase, ABC):
|
||||
import ibm_db
|
||||
conn_str = (
|
||||
f"DATABASE={self._param.database};"
|
||||
f"HOSTNAME={self._param.host};"
|
||||
f"HOSTNAME={safe_host};"
|
||||
f"PORT={self._param.port};"
|
||||
f"PROTOCOL=TCPIP;"
|
||||
f"UID={self._param.username};"
|
||||
|
||||
143
test/unit_test/agent/tools/test_exesql_ssrf.py
Normal file
143
test/unit_test/agent/tools/test_exesql_ssrf.py
Normal file
@@ -0,0 +1,143 @@
|
||||
#
|
||||
# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""SSRF-guard regression tests for the ExeSQL agent tool.
|
||||
|
||||
The DB host/port are node-author-controlled and connected to server-side, so
|
||||
``ExeSQL._invoke`` must reject hosts that resolve to non-public addresses
|
||||
(loopback, link-local/metadata, RFC1918) before opening any connection, and
|
||||
must dial the validated/resolved public IP for allowed hosts — mirroring the
|
||||
``test_db_connection`` endpoint guard (PR #14860).
|
||||
|
||||
``agent.tools.exesql`` is loaded in isolation (its package ``__init__`` would
|
||||
auto-discover every tool and pull in the full agent framework), with the heavy
|
||||
DB drivers and the agent base classes stubbed so only the real SSRF guard runs.
|
||||
"""
|
||||
import importlib.util
|
||||
import sys
|
||||
import types
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
# Repository root: four directory levels above this file's directory
# (test/unit_test/agent/tools/test_exesql_ssrf.py -> repo root).
_REPO_ROOT = Path(__file__).resolve().parents[4]
|
||||
|
||||
|
||||
class _RecordingPyMySQL:
    """Stand-in for ``pymysql`` that remembers the host passed to ``connect``.

    ``connect`` never returns a connection: it stores the ``host`` keyword
    argument on ``dialed_host`` and then raises ``RuntimeError`` so the caller
    stops before any real database I/O.
    """

    def __init__(self):
        # No connection attempt has been observed yet.
        self.dialed_host = None

    def connect(self, *args, **kwargs):
        """Record the ``host`` keyword argument, then abort with RuntimeError."""
        requested_host = kwargs.get("host")
        self.dialed_host = requested_host
        # Abort here so nothing after the connect call is executed.
        raise RuntimeError("connection attempted")
|
||||
|
||||
|
||||
# Shared recorder: its connect() is installed as the stub pymysql.connect by
# _load_exesql_module, and the tests inspect/reset its dialed_host attribute.
_fake_pymysql = _RecordingPyMySQL()
|
||||
|
||||
|
||||
def _load_exesql_module():
    """Import ``agent/tools/exesql.py`` in isolation and return the module.

    The file is executed straight from ``_REPO_ROOT`` under the module name
    ``exesql_uut``. While it executes, ``sys.modules`` holds stand-ins for:

    * ``pandas`` / ``psycopg2`` / ``pyodbc`` (only when not already imported),
    * ``pymysql`` (its ``connect`` is ``_fake_pymysql.connect``),
    * the ``agent`` and ``agent.tools`` packages (only when not already
      imported) and ``agent.tools.base`` with minimal base classes,
    * ``common.connection_utils`` whose ``timeout`` is a no-op decorator
      factory, so ``_invoke`` stays a plain method.

    Every touched ``sys.modules`` entry is restored afterwards so the stubs
    cannot leak into other test modules collected in the same session. The
    loaded module is expected to keep the stub objects it bound while
    importing (it uses module-level imports).
    """
    touched = (
        "pandas", "psycopg2", "pyodbc", "pymysql",
        "agent", "agent.tools", "agent.tools.base",
        "common.connection_utils",
    )
    absent = object()  # sentinel: the name was not in sys.modules beforehand
    previous = {name: sys.modules.get(name, absent) for name in touched}

    try:
        # Stub the heavy drivers and the agent base so the module imports cleanly.
        for name in ("pandas", "psycopg2", "pyodbc"):
            mod = types.ModuleType(name)
            mod.connect = lambda *a, **k: None
            sys.modules.setdefault(name, mod)

        pymysql_stub = types.ModuleType("pymysql")
        pymysql_stub.connect = _fake_pymysql.connect
        sys.modules["pymysql"] = pymysql_stub

        base = types.ModuleType("agent.tools.base")

        class _ToolParamBase:
            def __init__(self):
                pass

        class _ToolBase:
            def __init__(self, *a, **k):
                pass

        base.ToolParamBase = _ToolParamBase
        base.ToolBase = _ToolBase
        base.ToolMeta = dict
        for pkg in ("agent", "agent.tools"):
            sys.modules.setdefault(pkg, types.ModuleType(pkg))
        sys.modules["agent.tools.base"] = base

        # Neutralize the @timeout decorator so _invoke is a plain method.
        conn_utils = types.ModuleType("common.connection_utils")
        conn_utils.timeout = lambda *a, **k: (lambda f: f)
        sys.modules["common.connection_utils"] = conn_utils

        spec = importlib.util.spec_from_file_location(
            "exesql_uut", _REPO_ROOT / "agent" / "tools" / "exesql.py"
        )
        mod = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(mod)
        return mod
    finally:
        # Put sys.modules back exactly as it was for the names we touched,
        # whether or not loading succeeded.
        for name, original in previous.items():
            if original is absent:
                sys.modules.pop(name, None)
            else:
                sys.modules[name] = original
|
||||
|
||||
|
||||
# Load the module under test once, at import time; all tests below share it.
_exesql_mod = _load_exesql_module()
|
||||
# Convenience alias for the tool class defined in the isolated module.
ExeSQL = _exesql_mod.ExeSQL
|
||||
|
||||
|
||||
def _build_exesql(host, db_type="mysql"):
    """Return a bare ``ExeSQL`` instance configured to connect to ``host``.

    The instance is created with ``ExeSQL.__new__`` (no ``__init__``), given an
    empty canvas and a fixed parameter set (port 3306, database ``db``, user
    ``u``, password ``p``), and has the component helpers that run before the
    host check replaced with trivial callables.
    """
    params = SimpleNamespace(
        host=host,
        port=3306,
        db_type=db_type,
        database="db",
        username="u",
        password="p",
    )

    tool = ExeSQL.__new__(ExeSQL)
    tool._canvas = SimpleNamespace()
    tool._param = params

    # Neutralize the component machinery that runs before the host check.
    tool.check_if_canceled = lambda *_a, **_k: False
    tool.get_input_elements_from_text = lambda _sql: {}
    tool.set_input_value = lambda *_a, **_k: None
    tool.string_format = lambda sql, _args: sql
    return tool
|
||||
|
||||
|
||||
@pytest.mark.p2
@pytest.mark.parametrize("host", ["127.0.0.1", "169.254.169.254", "10.0.0.5"])
def test_internal_host_rejected_before_connect(host):
    """Loopback, link-local/metadata and RFC1918 hosts are refused up front."""
    tool = _build_exesql(host)
    _fake_pymysql.dialed_host = None

    with pytest.raises(Exception) as excinfo:
        tool._invoke(sql="SELECT 1")

    message = str(excinfo.value)
    assert "not allowed" in message
    # The SSRF guard must fire before any connection is attempted.
    assert _fake_pymysql.dialed_host is None
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_empty_host_rejected():
    """An empty host is refused by the guard before any connection attempt.

    Mirrors ``test_internal_host_rejected_before_connect``: the shared recorder
    is reset first so a host dialed by an earlier test cannot mask a
    regression, and the test asserts that the fake driver was never called.
    """
    _fake_pymysql.dialed_host = None
    cpn = _build_exesql("")
    with pytest.raises(Exception) as ei:
        cpn._invoke(sql="SELECT 1")
    assert "not allowed" in str(ei.value)
    # The SSRF guard must fire before any connection is attempted.
    assert _fake_pymysql.dialed_host is None
|
||||
|
||||
|
||||
@pytest.mark.p2
def test_public_host_dials_validated_ip(monkeypatch):
    """An allowed host is dialed via the IP returned by the guard."""
    validated_ip = "93.184.216.34"
    # Pretend the hostname resolves to a public IP; the driver must then be
    # dialed with that validated IP rather than the raw hostname.
    monkeypatch.setattr(_exesql_mod, "assert_host_is_safe", lambda _h: "93.184.216.34")
    _fake_pymysql.dialed_host = None

    tool = _build_exesql("db.example.com")
    with pytest.raises(Exception):
        # The recording connect() raises RuntimeError once it has been called.
        tool._invoke(sql="SELECT 1")

    assert _fake_pymysql.dialed_host == validated_ip
|
||||
Reference in New Issue
Block a user