diff --git a/agent/tools/exesql.py b/agent/tools/exesql.py
index dec6942f2e..86935cc49d 100644
--- a/agent/tools/exesql.py
+++ b/agent/tools/exesql.py
@@ -15,6 +15,7 @@
 #
 import contextlib
 import json
+import logging
 import os
 import re
 from abc import ABC
@@ -24,6 +25,7 @@ import psycopg2
 import pyodbc
 from agent.tools.base import ToolParamBase, ToolBase, ToolMeta
 from common.connection_utils import timeout
+from common.ssrf_guard import assert_host_is_safe
 
 
 class ExeSQLParam(ToolParamBase):
@@ -123,20 +125,33 @@ class ExeSQL(ToolBase, ABC):
         if self.check_if_canceled("ExeSQL processing"):
             return
 
+        # The DB host/port are node-author-controlled and are connected to
+        # server-side, so guard against SSRF (internal hosts, loopback, cloud
+        # metadata) the same way the `test_db_connection` endpoint does. Connect
+        # to the validated, resolved public IP so a later DNS change cannot
+        # rebind the host to an internal address (mirrors agent_api.py).
+        logging.info("ExeSQL validating database host: %s", self._param.host)
+        try:
+            safe_host = assert_host_is_safe(self._param.host)
+        except ValueError as e:
+            logging.warning("ExeSQL rejected database host %s: %s", self._param.host, e)
+            raise Exception(f"Database host '{self._param.host}' is not allowed: {e}") from e
+        logging.info("ExeSQL validated database host %s -> %s", self._param.host, safe_host)
+
         sqls = sql.split(";")
         if self._param.db_type in ["mysql", "mariadb"]:
-            db = pymysql.connect(db=self._param.database, user=self._param.username, host=self._param.host,
+            db = pymysql.connect(db=self._param.database, user=self._param.username, host=safe_host,
                                  port=self._param.port, password=self._param.password)
         elif self._param.db_type == 'oceanbase':
-            db = pymysql.connect(db=self._param.database, user=self._param.username, host=self._param.host,
+            db = pymysql.connect(db=self._param.database, user=self._param.username, host=safe_host,
                                  port=self._param.port, password=self._param.password, charset='utf8mb4')
         elif self._param.db_type == 'postgres':
-            db = psycopg2.connect(dbname=self._param.database, user=self._param.username, host=self._param.host,
+            db = psycopg2.connect(dbname=self._param.database, user=self._param.username, host=safe_host,
                                   port=self._param.port, password=self._param.password)
         elif self._param.db_type == 'mssql':
             conn_str = (
                 r'DRIVER={ODBC Driver 17 for SQL Server};'
-                r'SERVER=' + self._param.host + ',' + str(self._param.port) + ';'
+                r'SERVER=' + safe_host + ',' + str(self._param.port) + ';'
                 r'DATABASE=' + self._param.database + ';'
                 r'UID=' + self._param.username + ';'
                 r'PWD=' + self._param.password
@@ -171,7 +186,7 @@ class ExeSQL(ToolBase, ABC):
 
             try:
                 db = trino.dbapi.connect(
-                    host=self._param.host,
+                    host=safe_host,
                     port=int(self._param.port or 8080),
                     user=self._param.username or "ragflow",
                     catalog=catalog,
@@ -185,7 +200,7 @@ class ExeSQL(ToolBase, ABC):
             import ibm_db
             conn_str = (
                 f"DATABASE={self._param.database};"
-                f"HOSTNAME={self._param.host};"
+                f"HOSTNAME={safe_host};"
                 f"PORT={self._param.port};"
                 f"PROTOCOL=TCPIP;"
                 f"UID={self._param.username};"
diff --git a/test/unit_test/agent/tools/test_exesql_ssrf.py b/test/unit_test/agent/tools/test_exesql_ssrf.py
new file mode 100644
index 0000000000..ec8961d8cd
--- /dev/null
+++ b/test/unit_test/agent/tools/test_exesql_ssrf.py
@@ -0,0 +1,155 @@
+#
+# Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""SSRF-guard regression tests for the ExeSQL agent tool.
+
+The DB host/port are node-author-controlled and connected to server-side, so
+``ExeSQL._invoke`` must reject hosts that resolve to non-public addresses
+(loopback, link-local/metadata, RFC1918) before opening any connection, and
+must dial the validated/resolved public IP for allowed hosts — mirroring the
+``test_db_connection`` endpoint guard (PR #14860).
+
+``agent.tools.exesql`` is loaded in isolation (its package ``__init__`` would
+auto-discover every tool and pull in the full agent framework), with the heavy
+DB drivers and the agent base classes stubbed so only the real SSRF guard runs.
+"""
+import importlib.util
+import sys
+import types
+from pathlib import Path
+from types import SimpleNamespace
+
+import pytest
+
+_REPO_ROOT = Path(__file__).resolve().parents[4]
+
+
+class _RecordingPyMySQL:
+    """Fake pymysql whose connect() records the host it was asked to dial."""
+
+    def __init__(self):
+        self.dialed_host = None
+
+    def connect(self, *args, **kwargs):
+        self.dialed_host = kwargs.get("host")
+        raise RuntimeError("connection attempted")  # stop before real DB I/O
+
+
+_fake_pymysql = _RecordingPyMySQL()
+
+
+def _load_exesql_module():
+    """Import ``agent/tools/exesql.py`` by path with its heavy imports stubbed.
+
+    NOTE(review): the stub modules are left in ``sys.modules`` after loading;
+    confirm no other test in the same session needs the real modules.
+    """
+    # Stub the heavy drivers and the agent base so the module imports cleanly.
+    for name in ("pandas", "psycopg2", "pyodbc"):
+        mod = types.ModuleType(name)
+        mod.connect = lambda *a, **k: None
+        sys.modules.setdefault(name, mod)
+
+    pymysql_stub = types.ModuleType("pymysql")
+    pymysql_stub.connect = _fake_pymysql.connect
+    sys.modules["pymysql"] = pymysql_stub
+
+    base = types.ModuleType("agent.tools.base")
+
+    class _ToolParamBase:
+        def __init__(self):
+            pass
+
+    class _ToolBase:
+        def __init__(self, *a, **k):
+            pass
+
+    base.ToolParamBase = _ToolParamBase
+    base.ToolBase = _ToolBase
+    base.ToolMeta = dict
+    for pkg in ("agent", "agent.tools"):
+        sys.modules.setdefault(pkg, types.ModuleType(pkg))
+    sys.modules["agent.tools.base"] = base
+
+    # Neutralize the @timeout decorator so _invoke is a plain method.
+    conn_utils = types.ModuleType("common.connection_utils")
+    conn_utils.timeout = lambda *a, **k: (lambda f: f)
+    sys.modules["common.connection_utils"] = conn_utils
+
+    spec = importlib.util.spec_from_file_location(
+        "exesql_uut", _REPO_ROOT / "agent" / "tools" / "exesql.py"
+    )
+    mod = importlib.util.module_from_spec(spec)
+    spec.loader.exec_module(mod)
+    return mod
+
+
+_exesql_mod = _load_exesql_module()
+ExeSQL = _exesql_mod.ExeSQL
+
+
+def _build_exesql(host, db_type="mysql"):
+    """Return an ``ExeSQL`` created without ``__init__`` and with its helper methods stubbed."""
+    cpn = ExeSQL.__new__(ExeSQL)
+    cpn._canvas = SimpleNamespace()
+    cpn._param = SimpleNamespace(
+        host=host, port=3306, db_type=db_type,
+        database="db", username="u", password="p",
+    )
+    # Neutralize the component machinery that runs before the host check.
+    cpn.check_if_canceled = lambda *_a, **_k: False
+    cpn.get_input_elements_from_text = lambda _sql: {}
+    cpn.set_input_value = lambda *_a, **_k: None
+    cpn.string_format = lambda sql, _args: sql
+    return cpn
+
+
+@pytest.mark.p2
+@pytest.mark.parametrize("host", ["127.0.0.1", "169.254.169.254", "10.0.0.5"])
+def test_internal_host_rejected_before_connect(host):
+    """Internal hosts must be rejected before any connection is attempted."""
+    _fake_pymysql.dialed_host = None
+    cpn = _build_exesql(host)
+    with pytest.raises(Exception) as ei:
+        cpn._invoke(sql="SELECT 1")
+    assert "not allowed" in str(ei.value)
+    # The SSRF guard must fire before any connection is attempted.
+    assert _fake_pymysql.dialed_host is None
+
+
+@pytest.mark.p2
+def test_empty_host_rejected():
+    """An empty host must be rejected without attempting a connection."""
+    _fake_pymysql.dialed_host = None
+    cpn = _build_exesql("")
+    with pytest.raises(Exception) as ei:
+        cpn._invoke(sql="SELECT 1")
+    assert "not allowed" in str(ei.value)
+    # Rejection must happen before any connection is attempted.
+    assert _fake_pymysql.dialed_host is None
+
+
+@pytest.mark.p2
+def test_public_host_dials_validated_ip(monkeypatch):
+    """An allowed host must be dialed via the IP returned by the guard, not the raw hostname."""
+    # Public host: pretend it resolves to a public IP, and ensure the driver is
+    # dialed with that validated IP (not the raw hostname).
+    monkeypatch.setattr(_exesql_mod, "assert_host_is_safe", lambda _h: "93.184.216.34")
+    _fake_pymysql.dialed_host = None
+
+    cpn = _build_exesql("db.example.com")
+    with pytest.raises(Exception):
+        cpn._invoke(sql="SELECT 1")  # RuntimeError from the recording connect
+    assert _fake_pymysql.dialed_host == "93.184.216.34"