Fix: add SSRF guard for agent test_db_connection endpoint (#14860)

### What problem does this PR solve?

Closes #14858

The `test_db_connection` endpoint in the agent API accepts a
user-supplied `host` and connects to it directly via database drivers
(e.g. MySQL/PostgreSQL) without any validation. This allows an attacker
to probe internal network addresses (e.g. `127.0.0.1`, `10.x.x.x`,
link-local, etc.) through the server — a classic Server-Side Request
Forgery (SSRF) vulnerability.

This PR adds an SSRF guard that resolves the host and rejects any
address that is not globally routable before the database connection is
attempted.

**Changes:**
- **`common/ssrf_guard.py`** — Added `assert_host_is_safe()`, a
host-level counterpart of the existing `assert_url_is_safe()`, designed
for non-HTTP protocols (database drivers) where there is no URL to
parse.
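- **`api/apps/restful_apis/agent_api.py`** — Call
`assert_host_is_safe(req["host"])` at the top of `test_db_connection` so
that non-public hosts are rejected early with a clear error message, and
pass the validated IP it returns to the database drivers instead of the
raw `host`, so the connection goes to the address that was checked.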

Fixes #14858

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

---------

Co-authored-by: Jin Hai <haijin.chn@gmail.com>
This commit is contained in:
dale053
2026-05-17 23:32:44 -07:00
committed by GitHub
parent b09da6e347
commit fe82a96193
2 changed files with 62 additions and 7 deletions

View File

@@ -60,6 +60,7 @@ from api.utils.api_utils import (
validate_request,
)
from common import settings
from common.ssrf_guard import assert_host_is_safe
from common.constants import RetCode
from common.misc_utils import get_uuid, thread_pool_exec
from peewee import MySQLDatabase, PostgresqlDatabase
@@ -782,12 +783,27 @@ async def rerun_agent(tenant_id):
@login_required
async def test_db_connection():
req = await get_request_json()
try:
safe_host = assert_host_is_safe(req["host"])
except ValueError as exc:
logging.warning(
"Rejected test_db_connection: unsafe host %r (db_type=%s, user=%s): %s",
req.get("host"), req.get("db_type"), current_user.id, exc,
)
return get_data_error_result(message=str(exc))
except OSError as exc:
logging.warning(
"Rejected test_db_connection: cannot resolve host %r (db_type=%s, user=%s): %s",
req.get("host"), req.get("db_type"), current_user.id, exc,
)
logging.debug("Full resolver exception for host %r", req.get("host"), exc_info=True)
return get_data_error_result(message=f"Could not resolve host {req.get('host')!r}.")
try:
if req["db_type"] in ["mysql", "mariadb"]:
db = MySQLDatabase(
req["database"],
user=req["username"],
host=req["host"],
host=safe_host,
port=req["port"],
password=req["password"],
)
@@ -797,7 +813,7 @@ async def test_db_connection():
db = MySQLDatabase(
req["database"],
user=req["username"],
host=req["host"],
host=safe_host,
port=req["port"],
password=req["password"],
charset="utf8mb4",
@@ -808,7 +824,7 @@ async def test_db_connection():
db = PostgresqlDatabase(
req["database"],
user=req["username"],
host=req["host"],
host=safe_host,
port=req["port"],
password=req["password"],
)
@@ -819,7 +835,7 @@ async def test_db_connection():
connection_string = (
f"DRIVER={{ODBC Driver 17 for SQL Server}};"
f"SERVER={req['host']},{req['port']};"
f"SERVER={safe_host},{req['port']};"
f"DATABASE={req['database']};"
f"UID={req['username']};"
f"PWD={req['password']};"
@@ -838,7 +854,7 @@ async def test_db_connection():
conn_str = (
f"DATABASE={req['database']};"
f"HOSTNAME={req['host']};"
f"HOSTNAME={safe_host};"
f"PORT={req['port']};"
f"PROTOCOL=TCPIP;"
f"UID={req['username']};"
@@ -847,7 +863,7 @@ async def test_db_connection():
logging.info(
"DATABASE=%s;HOSTNAME=%s;PORT=%s;PROTOCOL=TCPIP;UID=%s;PWD=****;",
req["database"],
req["host"],
safe_host,
req["port"],
req["username"],
)
@@ -873,7 +889,7 @@ async def test_db_connection():
auth = trino.BasicAuthentication(req.get("username") or "ragflow", req["password"])
conn = trino.dbapi.connect(
host=req["host"],
host=safe_host,
port=int(req["port"] or 8080),
user=req["username"] or "ragflow",
catalog=catalog,

View File

@@ -170,3 +170,42 @@ def assert_url_is_safe(
raise ValueError(f"Hostname {hostname!r} resolved to no addresses.")
return hostname, resolved_ip
def assert_host_is_safe(host: str) -> str:
    """Raise ``ValueError`` if *host* resolves to a non-public IP (SSRF guard for raw host/port connections).

    This is the host-level counterpart of :func:`assert_url_is_safe`, intended
    for callers that connect via database drivers or other non-HTTP protocols
    where there is no URL to parse.

    Args:
        host: Hostname or IP literal to validate. Must be a non-empty string.

    Returns:
        The first validated public IP string so the caller can pin it if needed.

    Raises:
        ValueError: If *host* is empty or not a string, cannot be resolved,
            resolves to no addresses, or resolves to at least one address
            that is not globally routable.
    """
    if not host:
        raise ValueError("Host must not be empty.")
    if not isinstance(host, str):
        # A truthy non-string (e.g. a number or list from a JSON payload) would
        # make getaddrinfo raise TypeError; report it through the same
        # ValueError path as every other rejection instead.
        raise ValueError(f"Host must be a string, got {type(host).__name__}.")

    try:
        addr_infos = socket.getaddrinfo(host, None)
    except socket.gaierror as exc:
        logger.warning("SSRF guard could not resolve host=%r reason=%s", host, exc)
        raise ValueError(f"Could not resolve host {host!r}: {exc}") from exc

    resolved_ip: str | None = None
    for _family, _type, _proto, _canonname, sockaddr in addr_infos:
        raw_ip = ipaddress.ip_address(sockaddr[0])
        # The globality check is applied to whatever _effective_ip (defined
        # elsewhere in this module) returns for the raw address.
        eff_ip = _effective_ip(raw_ip)
        if not eff_ip.is_global:
            # Reject on the first non-public record: every resolved address
            # must be public, not just the one that would be returned.
            logger.warning(
                "SSRF guard blocked host: host=%r resolved to non-public address=%s",
                host,
                raw_ip,
            )
            raise ValueError(f"Host resolves to a non-public address ({raw_ip}), which is not allowed.")
        if resolved_ip is None:
            # Remember only the first validated address; later records are
            # still checked but do not replace it.
            resolved_ip = str(raw_ip)

    if resolved_ip is None:
        logger.warning("SSRF guard blocked host: host=%r resolved to no addresses", host)
        raise ValueError(f"Host {host!r} resolved to no addresses.")
    return resolved_ip