From fe82a961932fdb656d4d65517e3705746b361e0c Mon Sep 17 00:00:00 2001 From: dale053 Date: Sun, 17 May 2026 23:32:44 -0700 Subject: [PATCH] Fix: add SSRF guard for agent test_db_connection endpoint (#14860) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What problem does this PR solve? Closes #14858 The `test_db_connection` endpoint in the agent API accepts a user-supplied `host` and connects to it directly via database drivers (MySQL/PostgreSQL) without any validation. This allows an attacker to probe internal network addresses (e.g. `127.0.0.1`, `10.x.x.x`, link-local, etc.) through the server — a classic Server-Side Request Forgery (SSRF) vulnerability. This PR adds an SSRF guard that resolves the host and rejects any address that is not globally routable before the database connection is attempted. **Changes:** - **`common/ssrf_guard.py`** — Added `assert_host_is_safe()`, a host-level counterpart of the existing `assert_url_is_safe()`, designed for non-HTTP protocols (database drivers) where there is no URL to parse. - **`api/apps/restful_apis/agent_api.py`** — Call `assert_host_is_safe(req["host"])` at the top of `test_db_connection` so that non-public hosts are rejected early with a clear error message. 
Fixes #14858 ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --------- Co-authored-by: Jin Hai --- api/apps/restful_apis/agent_api.py | 30 +++++++++++++++++------ common/ssrf_guard.py | 39 ++++++++++++++++++++++++++++++ 2 files changed, 62 insertions(+), 7 deletions(-) diff --git a/api/apps/restful_apis/agent_api.py b/api/apps/restful_apis/agent_api.py index 4d07842592..f88ce90b3f 100644 --- a/api/apps/restful_apis/agent_api.py +++ b/api/apps/restful_apis/agent_api.py @@ -60,6 +60,7 @@ from api.utils.api_utils import ( validate_request, ) from common import settings +from common.ssrf_guard import assert_host_is_safe from common.constants import RetCode from common.misc_utils import get_uuid, thread_pool_exec from peewee import MySQLDatabase, PostgresqlDatabase @@ -782,12 +783,27 @@ async def rerun_agent(tenant_id): @login_required async def test_db_connection(): req = await get_request_json() + try: + safe_host = assert_host_is_safe(req["host"]) + except ValueError as exc: + logging.warning( + "Rejected test_db_connection: unsafe host %r (db_type=%s, user=%s): %s", + req.get("host"), req.get("db_type"), current_user.id, exc, + ) + return get_data_error_result(message=str(exc)) + except OSError as exc: + logging.warning( + "Rejected test_db_connection: cannot resolve host %r (db_type=%s, user=%s): %s", + req.get("host"), req.get("db_type"), current_user.id, exc, + ) + logging.debug("Full resolver exception for host %r", req.get("host"), exc_info=True) + return get_data_error_result(message=f"Could not resolve host {req.get('host')!r}.") try: if req["db_type"] in ["mysql", "mariadb"]: db = MySQLDatabase( req["database"], user=req["username"], - host=req["host"], + host=safe_host, port=req["port"], password=req["password"], ) @@ -797,7 +813,7 @@ async def test_db_connection(): db = MySQLDatabase( req["database"], user=req["username"], - host=req["host"], + host=safe_host, port=req["port"], password=req["password"], charset="utf8mb4", @@ 
def assert_host_is_safe(host: str) -> str:
    """Resolve *host* and reject it unless every resolved address is public.

    SSRF guard for raw host/port connections. It is the host-level counterpart
    of :func:`assert_url_is_safe`, meant for callers that connect through
    database drivers or other non-HTTP protocols where there is no URL to parse.

    Args:
        host: Hostname or IP literal supplied by the (untrusted) caller.

    Returns:
        The first resolved address as a string, so the caller can connect to
        the exact IP that was validated instead of resolving the name again.

    Raises:
        ValueError: If *host* is empty, is not a string, cannot be resolved,
            resolves to no addresses, or resolves to any address that is not
            globally routable.
    """
    if not host:
        raise ValueError("Host must not be empty.")
    if not isinstance(host, str):
        # The value typically comes straight from a JSON body, so it may be a
        # number, list or object. ``socket.getaddrinfo`` would raise TypeError
        # for those, which callers handling ValueError/OSError do not catch.
        raise ValueError(f"Host must be a string, not {type(host).__name__}.")

    try:
        addr_infos = socket.getaddrinfo(host, None)
    except socket.gaierror as exc:
        logger.warning("SSRF guard could not resolve host=%r reason=%s", host, exc)
        raise ValueError(f"Could not resolve host {host!r}: {exc}") from exc

    resolved_ip: str | None = None
    for _family, _type, _proto, _canonname, sockaddr in addr_infos:
        raw_ip = ipaddress.ip_address(sockaddr[0])
        # The public/private decision is made on the address returned by
        # ``_effective_ip`` (defined elsewhere in this module; presumably it
        # unwraps embedded IPv4 forms -- confirm against its definition).
        eff_ip = _effective_ip(raw_ip)
        if not eff_ip.is_global:
            # Fail closed: one non-public record is enough to reject the host,
            # even if other records for the same name are public.
            logger.warning(
                "SSRF guard blocked host: host=%r resolved to non-public address=%s",
                host,
                raw_ip,
            )
            raise ValueError(f"Host resolves to a non-public address ({raw_ip}), which is not allowed.")
        if resolved_ip is None:
            # Remember the first validated address; it is the return value.
            resolved_ip = str(raw_ip)

    if resolved_ip is None:
        logger.warning("SSRF guard blocked host: host=%r resolved to no addresses", host)
        raise ValueError(f"Host {host!r} resolved to no addresses.")

    return resolved_ip