mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-07-01 00:05:43 +08:00
Fix: add SSRF guard for agent test_db_connection endpoint (#14860)
### What problem does this PR solve? Closes #14858 The `test_db_connection` endpoint in the agent API accepts a user-supplied `host` and connects to it directly via database drivers (MySQL/PostgreSQL) without any validation. This allows an attacker to probe internal network addresses (e.g. `127.0.0.1`, `10.x.x.x`, link-local, etc.) through the server — a classic Server-Side Request Forgery (SSRF) vulnerability. This PR adds an SSRF guard that resolves the host and rejects any address that is not globally routable before the database connection is attempted. **Changes:** - **`common/ssrf_guard.py`** — Added `assert_host_is_safe()`, a host-level counterpart of the existing `assert_url_is_safe()`, designed for non-HTTP protocols (database drivers) where there is no URL to parse. - **`api/apps/restful_apis/agent_api.py`** — Call `assert_host_is_safe(req["host"])` at the top of `test_db_connection` so that non-public hosts are rejected early with a clear error message. Fixes #14858 ### Type of change - [x] Bug Fix (non-breaking change which fixes an issue) --------- Co-authored-by: Jin Hai <haijin.chn@gmail.com>
This commit is contained in:
@@ -60,6 +60,7 @@ from api.utils.api_utils import (
|
||||
validate_request,
|
||||
)
|
||||
from common import settings
|
||||
from common.ssrf_guard import assert_host_is_safe
|
||||
from common.constants import RetCode
|
||||
from common.misc_utils import get_uuid, thread_pool_exec
|
||||
from peewee import MySQLDatabase, PostgresqlDatabase
|
||||
@@ -782,12 +783,27 @@ async def rerun_agent(tenant_id):
|
||||
@login_required
|
||||
async def test_db_connection():
|
||||
req = await get_request_json()
|
||||
try:
|
||||
safe_host = assert_host_is_safe(req["host"])
|
||||
except ValueError as exc:
|
||||
logging.warning(
|
||||
"Rejected test_db_connection: unsafe host %r (db_type=%s, user=%s): %s",
|
||||
req.get("host"), req.get("db_type"), current_user.id, exc,
|
||||
)
|
||||
return get_data_error_result(message=str(exc))
|
||||
except OSError as exc:
|
||||
logging.warning(
|
||||
"Rejected test_db_connection: cannot resolve host %r (db_type=%s, user=%s): %s",
|
||||
req.get("host"), req.get("db_type"), current_user.id, exc,
|
||||
)
|
||||
logging.debug("Full resolver exception for host %r", req.get("host"), exc_info=True)
|
||||
return get_data_error_result(message=f"Could not resolve host {req.get('host')!r}.")
|
||||
try:
|
||||
if req["db_type"] in ["mysql", "mariadb"]:
|
||||
db = MySQLDatabase(
|
||||
req["database"],
|
||||
user=req["username"],
|
||||
host=req["host"],
|
||||
host=safe_host,
|
||||
port=req["port"],
|
||||
password=req["password"],
|
||||
)
|
||||
@@ -797,7 +813,7 @@ async def test_db_connection():
|
||||
db = MySQLDatabase(
|
||||
req["database"],
|
||||
user=req["username"],
|
||||
host=req["host"],
|
||||
host=safe_host,
|
||||
port=req["port"],
|
||||
password=req["password"],
|
||||
charset="utf8mb4",
|
||||
@@ -808,7 +824,7 @@ async def test_db_connection():
|
||||
db = PostgresqlDatabase(
|
||||
req["database"],
|
||||
user=req["username"],
|
||||
host=req["host"],
|
||||
host=safe_host,
|
||||
port=req["port"],
|
||||
password=req["password"],
|
||||
)
|
||||
@@ -819,7 +835,7 @@ async def test_db_connection():
|
||||
|
||||
connection_string = (
|
||||
f"DRIVER={{ODBC Driver 17 for SQL Server}};"
|
||||
f"SERVER={req['host']},{req['port']};"
|
||||
f"SERVER={safe_host},{req['port']};"
|
||||
f"DATABASE={req['database']};"
|
||||
f"UID={req['username']};"
|
||||
f"PWD={req['password']};"
|
||||
@@ -838,7 +854,7 @@ async def test_db_connection():
|
||||
|
||||
conn_str = (
|
||||
f"DATABASE={req['database']};"
|
||||
f"HOSTNAME={req['host']};"
|
||||
f"HOSTNAME={safe_host};"
|
||||
f"PORT={req['port']};"
|
||||
f"PROTOCOL=TCPIP;"
|
||||
f"UID={req['username']};"
|
||||
@@ -847,7 +863,7 @@ async def test_db_connection():
|
||||
logging.info(
|
||||
"DATABASE=%s;HOSTNAME=%s;PORT=%s;PROTOCOL=TCPIP;UID=%s;PWD=****;",
|
||||
req["database"],
|
||||
req["host"],
|
||||
safe_host,
|
||||
req["port"],
|
||||
req["username"],
|
||||
)
|
||||
@@ -873,7 +889,7 @@ async def test_db_connection():
|
||||
auth = trino.BasicAuthentication(req.get("username") or "ragflow", req["password"])
|
||||
|
||||
conn = trino.dbapi.connect(
|
||||
host=req["host"],
|
||||
host=safe_host,
|
||||
port=int(req["port"] or 8080),
|
||||
user=req["username"] or "ragflow",
|
||||
catalog=catalog,
|
||||
|
||||
@@ -170,3 +170,42 @@ def assert_url_is_safe(
|
||||
raise ValueError(f"Hostname {hostname!r} resolved to no addresses.")
|
||||
|
||||
return hostname, resolved_ip
|
||||
|
||||
|
||||
def assert_host_is_safe(host: str) -> str:
|
||||
"""Raise ``ValueError`` if *host* resolves to a non-public IP (SSRF guard for raw host/port connections).
|
||||
|
||||
This is the host-level counterpart of :func:`assert_url_is_safe`, intended
|
||||
for callers that connect via database drivers or other non-HTTP protocols
|
||||
where there is no URL to parse.
|
||||
|
||||
Returns the first validated public IP string so the caller can pin it if needed.
|
||||
"""
|
||||
if not host:
|
||||
raise ValueError("Host must not be empty.")
|
||||
|
||||
try:
|
||||
addr_infos = socket.getaddrinfo(host, None)
|
||||
except socket.gaierror as exc:
|
||||
logger.warning("SSRF guard could not resolve host=%r reason=%s", host, exc)
|
||||
raise ValueError(f"Could not resolve host {host!r}: {exc}") from exc
|
||||
|
||||
resolved_ip: str | None = None
|
||||
for _family, _type, _proto, _canonname, sockaddr in addr_infos:
|
||||
raw_ip = ipaddress.ip_address(sockaddr[0])
|
||||
eff_ip = _effective_ip(raw_ip)
|
||||
if not eff_ip.is_global:
|
||||
logger.warning(
|
||||
"SSRF guard blocked host: host=%r resolved to non-public address=%s",
|
||||
host,
|
||||
raw_ip,
|
||||
)
|
||||
raise ValueError(f"Host resolves to a non-public address ({raw_ip}), which is not allowed.")
|
||||
if resolved_ip is None:
|
||||
resolved_ip = str(raw_ip)
|
||||
|
||||
if resolved_ip is None:
|
||||
logger.warning("SSRF guard blocked host: host=%r resolved to no addresses", host)
|
||||
raise ValueError(f"Host {host!r} resolved to no addresses.")
|
||||
|
||||
return resolved_ip
|
||||
|
||||
Reference in New Issue
Block a user