Files
ragflow/agent/tools/exesql.py
philluiz2323 e256d91ade fix: guard SSRF in ExeSQL agent tool DB host (#15609)
### What problem does this PR solve?

Closes #15608.

The ExeSQL agent tool (`agent/tools/exesql.py`) opens database
connections to a node-author-controlled host/port with no SSRF
validation. The sibling `test_db_connection` endpoint already validates
the host via `common.ssrf_guard.assert_host_is_safe` (added by PR
#14860), but the tool that actually performs the connection at agent run
time was left unguarded — so the guard is bypassed simply by running the
agent. An agent author can point the host at `127.0.0.1`,
`169.254.169.254` (cloud metadata), or any internal RFC1918 host/port,
turning ExeSQL into an internal port-scanner / metadata-fetch primitive.

### Fix

Mirror the accepted endpoint guard: validate (and resolve) the host
once, before the `db_type` dispatch, and connect to the validated public
IP so a later DNS change cannot rebind the host to an internal address.

- Add `from common.ssrf_guard import assert_host_is_safe`.
- Call `safe_host = assert_host_is_safe(self._param.host)` once, before the
`db_type` dispatch; the call rejects loopback, link-local/metadata, RFC1918,
and unresolvable hosts.
- Substitute the validated IP into all 6 driver branches: mysql/mariadb,
oceanbase, postgres, mssql, trino, IBM DB2.

Adds `test/unit_test/agent/tools/test_exesql_ssrf.py` covering loopback,
link-local/metadata, RFC1918, and empty-host rejection (before any
connection), plus an allowed host dialing the validated IP.

### Validation

- `python3 -m py_compile agent/tools/exesql.py`
- `ruff check agent/tools/exesql.py
test/unit_test/agent/tools/test_exesql_ssrf.py`
- `pytest test/unit_test/agent/tools/test_exesql_ssrf.py` — 5 passed

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

---------

Co-authored-by: Zhichang Yu <yuzhichang@gmail.com>
2026-06-29 09:45:16 +08:00

329 lines
14 KiB
Python

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import contextlib
import json
import logging
import os
import re
from abc import ABC
import pandas as pd
import pymysql
import psycopg2
import pyodbc
from agent.tools.base import ToolParamBase, ToolBase, ToolMeta
from common.connection_utils import timeout
from common.ssrf_guard import assert_host_is_safe
class ExeSQLParam(ToolParamBase):
    """
    Parameters of the ExeSQL tool: the tool metadata (name, description and
    the single `sql` argument) plus the database connection settings.
    """

    def __init__(self):
        # Schema of the only argument the tool accepts.
        sql_argument = {
            "type": "string",
            "description": "The SQL needs to be executed.",
            "default": "{sys.query}",
            "required": True
        }
        self.meta: ToolMeta = {
            "name": "execute_sql",
            "description": "This is a tool that can execute SQL.",
            "parameters": {"sql": sql_argument}
        }
        super().__init__()
        # Connection settings; the defaults target MySQL on port 3306.
        self.db_type = "mysql"
        self.database = ""
        self.username = ""
        self.host = ""
        self.port = 3306
        self.password = ""
        self.max_records = 1024

    def check(self):
        """
        Validate the configured parameters through the base-class check
        helpers, then refuse the `rag_flow` database when it is combined with
        the `ragflow-mysql` host or the `infini_rag_flow` password.

        Raises:
            ValueError: for the refused `rag_flow` combinations.
        """
        supported_types = ['mysql', 'postgres', 'mariadb', 'mssql', 'IBM DB2', 'trino', 'oceanbase']
        self.check_valid_value(self.db_type, "Choose DB type", supported_types)
        self.check_empty(self.database, "Database name")
        self.check_empty(self.username, "database username")
        self.check_empty(self.host, "IP Address")
        self.check_positive_integer(self.port, "IP Port")
        # Trino is the only type for which an empty password is accepted.
        if self.db_type != "trino":
            self.check_empty(self.password, "Database password")
        self.check_positive_integer(self.max_records, "Maximum number of records")

        if self.database != "rag_flow":
            return
        if self.host == "ragflow-mysql" or self.password == "infini_rag_flow":
            raise ValueError("For the security reason, it does not support database named rag_flow.")

    def get_input_form(self) -> dict[str, dict]:
        """Return the input form description: a single line field named `sql`."""
        sql_field = {
            "name": "SQL",
            "type": "line"
        }
        return {"sql": sql_field}
class ExeSQL(ToolBase, ABC):
    """Agent tool that executes SQL statements against the configured database.

    The connection settings (`db_type`, `host`, `port`, `database`,
    `username`, `password`, `max_records`) are read from `self._param`.
    """
    component_name = "ExeSQL"

    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60)))
    def _invoke(self, **kwargs):
        """Run the SQL given as `sql` and publish the results as tool outputs.

        Input variables referenced in the SQL text are substituted, the
        database host is validated with `assert_host_is_safe`, the text is
        split on `;`, and every non-empty statement is executed on a
        connection opened for `self._param.db_type`. A statement that fails is
        reported in the outputs and does not stop the remaining statements.

        Outputs set:
            json: one entry per statement, either a list of row dicts or a
                `{"content": message}` dict.
            formalized_content: markdown tables / error messages joined by
                blank lines.

        Args:
            **kwargs: must contain a non-empty `sql` string.

        Returns:
            The value of `self.output("formalized_content")`, or None when
            `check_if_canceled` reports a cancellation.

        Raises:
            Exception: when `sql` is empty, the host is rejected, the `trino`
                package is missing, or the Trino/DB2 connection or the cursor
                cannot be opened. Connection errors of the other drivers
                propagate as raised by the driver.
        """
        if self.check_if_canceled("ExeSQL processing"):
            return

        def convert_decimals(obj):
            """Recursively make `obj` JSON-friendly: NaN/Infinity floats become
            None and Decimal values become float; dicts and lists are walked."""
            from decimal import Decimal
            import math
            if isinstance(obj, float):
                # Handle NaN and Infinity which are not valid JSON values
                if math.isnan(obj) or math.isinf(obj):
                    return None
                return obj
            if isinstance(obj, Decimal):
                return float(obj)  # or str(obj)
            elif isinstance(obj, dict):
                return {k: convert_decimals(v) for k, v in obj.items()}
            elif isinstance(obj, list):
                return [convert_decimals(item) for item in obj]
            return obj

        sql = kwargs.get("sql")
        if not sql:
            raise Exception("SQL for `ExeSQL` MUST not be empty.")

        if self.check_if_canceled("ExeSQL processing"):
            return

        # Substitute the referenced input variables; non-string values are
        # serialized to JSON, falling back to str() when that fails.
        elements = self.get_input_elements_from_text(sql)
        args = {}
        for k, o in elements.items():
            args[k] = o["value"]
            if not isinstance(args[k], str):
                try:
                    args[k] = json.dumps(args[k], ensure_ascii=False)
                except Exception:
                    args[k] = str(args[k])
            self.set_input_value(k, args[k])
        sql = self.string_format(sql, args)

        if self.check_if_canceled("ExeSQL processing"):
            return

        # The DB host/port are node-author-controlled and are connected to
        # server-side, so guard against SSRF (internal hosts, loopback, cloud
        # metadata) the same way the `test_db_connection` endpoint does. Connect
        # to the validated, resolved public IP so a later DNS change cannot
        # rebind the host to an internal address (mirrors agent_api.py).
        logging.info(f"ExeSQL validating database host: {self._param.host}")
        try:
            safe_host = assert_host_is_safe(self._param.host)
        except ValueError as e:
            logging.warning(f"ExeSQL rejected database host {self._param.host}: {e}")
            raise Exception(f"Database host '{self._param.host}' is not allowed: {e}") from e
        logging.info(f"ExeSQL validated database host {self._param.host} -> {safe_host}")

        # NOTE(review): a plain split also cuts at a `;` that sits inside a
        # string literal or comment of the SQL text.
        sqls = sql.split(";")
        if self._param.db_type in ["mysql", "mariadb"]:
            db = pymysql.connect(db=self._param.database, user=self._param.username, host=safe_host,
                                 port=self._param.port, password=self._param.password)
        elif self._param.db_type == 'oceanbase':
            db = pymysql.connect(db=self._param.database, user=self._param.username, host=safe_host,
                                 port=self._param.port, password=self._param.password, charset='utf8mb4')
        elif self._param.db_type == 'postgres':
            db = psycopg2.connect(dbname=self._param.database, user=self._param.username, host=safe_host,
                                  port=self._param.port, password=self._param.password)
        elif self._param.db_type == 'mssql':
            # NOTE(review): database/username/password are concatenated
            # unescaped; a `;` inside one of them would presumably be parsed
            # as an attribute separator - confirm and escape or reject.
            conn_str = (
                r'DRIVER={ODBC Driver 17 for SQL Server};'
                r'SERVER=' + safe_host + ',' + str(self._param.port) + ';'
                r'DATABASE=' + self._param.database + ';'
                r'UID=' + self._param.username + ';'
                r'PWD=' + self._param.password
            )
            db = pyodbc.connect(conn_str)
        elif self._param.db_type == 'trino':
            try:
                import trino
                from trino.auth import BasicAuthentication
            except Exception as e:
                raise Exception("Missing dependency 'trino'. Please install: pip install trino") from e

            def _parse_catalog_schema(db: str):
                """Split `catalog.schema` or `catalog/schema`; a bare catalog
                gets the schema "default"; an empty value gives (None, None)."""
                if not db:
                    return None, None
                if "." in db:
                    c, s = db.split(".", 1)
                elif "/" in db:
                    c, s = db.split("/", 1)
                else:
                    c, s = db, "default"
                return c, s

            catalog, schema = _parse_catalog_schema(self._param.database)
            if not catalog:
                raise Exception("For Trino, `database` must be 'catalog.schema' or at least 'catalog'.")

            # TLS is opted into through the TRINO_USE_TLS environment variable;
            # basic authentication is only configured together with https.
            http_scheme = "https" if os.environ.get("TRINO_USE_TLS", "0") == "1" else "http"
            auth = None
            if http_scheme == "https" and self._param.password:
                auth = BasicAuthentication(self._param.username, self._param.password)

            try:
                db = trino.dbapi.connect(
                    host=safe_host,
                    port=int(self._param.port or 8080),
                    user=self._param.username or "ragflow",
                    catalog=catalog,
                    schema=schema or "default",
                    http_scheme=http_scheme,
                    auth=auth
                )
            except Exception as e:
                raise Exception("Database Connection Failed! \n" + str(e)) from e
        elif self._param.db_type == 'IBM DB2':
            import ibm_db
            # NOTE(review): the values are interpolated unescaped (see the
            # mssql branch); DATABASE precedes HOSTNAME here, so confirm that
            # a `;` in the database name cannot introduce another HOSTNAME.
            conn_str = (
                f"DATABASE={self._param.database};"
                f"HOSTNAME={safe_host};"
                f"PORT={self._param.port};"
                f"PROTOCOL=TCPIP;"
                f"UID={self._param.username};"
                f"PWD={self._param.password};"
            )
            try:
                conn = ibm_db.connect(conn_str, "", "")
            except Exception as e:
                raise Exception("Database Connection Failed! \n" + str(e)) from e

            # DB2 is driven through the ibm_db function API instead of a
            # cursor, so it has its own execution loop and returns from here.
            # NOTE(review): unlike the cursor path below, this loop does not
            # reject INSERT/UPDATE/DELETE statements - confirm that is intended.
            try:
                sql_res = []
                formalized_content = []
                for single_sql in sqls:
                    if self.check_if_canceled("ExeSQL processing"):
                        return

                    single_sql = single_sql.replace("```", "").strip()
                    if not single_sql:
                        continue
                    single_sql = re.sub(r"\[ID:[0-9]+\]", "", single_sql)

                    try:
                        stmt = ibm_db.exec_immediate(conn, single_sql)
                        rows = []
                        row = ibm_db.fetch_assoc(stmt)
                        while row and len(rows) < self._param.max_records:
                            if self.check_if_canceled("ExeSQL processing"):
                                return
                            rows.append(row)
                            row = ibm_db.fetch_assoc(stmt)

                        if not rows:
                            sql_res.append({"content": "No record in the database!"})
                            continue

                        df = pd.DataFrame(rows)
                        for col in df.columns:
                            if pd.api.types.is_datetime64_any_dtype(df[col]):
                                df[col] = df[col].dt.strftime("%Y-%m-%d")

                        df = df.where(pd.notnull(df), None)

                        sql_res.append(convert_decimals(df.to_dict(orient="records")))
                        formalized_content.append(df.to_markdown(index=False, floatfmt=".6f"))
                    except Exception as e:
                        # Keep the node alive on a bad statement: report and continue.
                        with contextlib.suppress(Exception):
                            ibm_db.rollback(conn)
                        msg = f"SQL Execution Failed: {single_sql}\n{str(e)}"
                        sql_res.append({"content": msg})
                        formalized_content.append(msg)
                        continue
            finally:
                # Runs on normal completion and on the early cancel returns.
                with contextlib.suppress(Exception):
                    ibm_db.close(conn)

            self.set_output("json", sql_res)
            self.set_output("formalized_content", "\n\n".join(formalized_content))
            return self.output("formalized_content")

        try:
            cursor = db.cursor()
        except Exception as e:
            with contextlib.suppress(Exception):
                db.close()
            raise Exception("Database Connection Failed! \n" + str(e)) from e

        try:
            sql_res = []
            formalized_content = []
            for single_sql in sqls:
                if self.check_if_canceled("ExeSQL processing"):
                    return

                single_sql = single_sql.replace('```', '').strip()
                if not single_sql:
                    continue
                single_sql = re.sub(r"\[ID:[0-9]+\]", "", single_sql)
                # Only statements that start with one of these keywords are
                # rejected; other statement kinds are passed to the database.
                if re.match(r"^(insert|update|delete)\b", single_sql, flags=re.IGNORECASE):
                    sql_res.append({"content": "For security reasons, INSERT, UPDATE, and DELETE statements are not supported."})
                    formalized_content.append("For security reasons, INSERT, UPDATE, and DELETE statements are not supported.")
                    continue

                try:
                    cursor.execute(single_sql)
                    if cursor.rowcount == 0:
                        sql_res.append({"content": "No record in the database!"})
                        # An empty result must not drop the statements that
                        # follow it (same handling as the DB2 loop above).
                        continue
                    if self._param.db_type == 'mssql':
                        single_res = pd.DataFrame.from_records(cursor.fetchmany(self._param.max_records),
                                                               columns=[desc[0] for desc in cursor.description])
                    else:
                        single_res = pd.DataFrame([i for i in cursor.fetchmany(self._param.max_records)])
                        single_res.columns = [i[0] for i in cursor.description]

                    for col in single_res.columns:
                        if pd.api.types.is_datetime64_any_dtype(single_res[col]):
                            single_res[col] = single_res[col].dt.strftime('%Y-%m-%d')

                    single_res = single_res.where(pd.notnull(single_res), None)

                    sql_res.append(convert_decimals(single_res.to_dict(orient='records')))
                    formalized_content.append(single_res.to_markdown(index=False, floatfmt=".6f"))
                except Exception as e:
                    # A failing statement must not abort the node: report it and keep
                    # going so earlier results survive and later statements still run.
                    # The rollback clears PostgreSQL's aborted-transaction state, which
                    # would otherwise make every subsequent statement fail too.
                    with contextlib.suppress(Exception):
                        db.rollback()
                    msg = f"SQL Execution Failed: {single_sql}\n{str(e)}"
                    sql_res.append({"content": msg})
                    formalized_content.append(msg)
                    continue
        finally:
            # Runs on normal completion and on the early cancel return.
            with contextlib.suppress(Exception):
                cursor.close()
            with contextlib.suppress(Exception):
                db.close()

        self.set_output("json", sql_res)
        self.set_output("formalized_content", "\n\n".join(formalized_content))
        return self.output("formalized_content")

    def thoughts(self) -> str:
        """Return the progress message shown for this tool."""
        return "Query sent—waiting for the data."