diff --git a/api/apps/restful_apis/agent_api.py b/api/apps/restful_apis/agent_api.py index 4d07842592..f88ce90b3f 100644 --- a/api/apps/restful_apis/agent_api.py +++ b/api/apps/restful_apis/agent_api.py @@ -60,6 +60,7 @@ from api.utils.api_utils import ( validate_request, ) from common import settings +from common.ssrf_guard import assert_host_is_safe from common.constants import RetCode from common.misc_utils import get_uuid, thread_pool_exec from peewee import MySQLDatabase, PostgresqlDatabase @@ -782,12 +783,27 @@ async def rerun_agent(tenant_id): @login_required async def test_db_connection(): req = await get_request_json() + try: + safe_host = assert_host_is_safe(req["host"]) + except ValueError as exc: + logging.warning( + "Rejected test_db_connection: unsafe host %r (db_type=%s, user=%s): %s", + req.get("host"), req.get("db_type"), current_user.id, exc, + ) + return get_data_error_result(message=str(exc)) + except OSError as exc: + logging.warning( + "Rejected test_db_connection: cannot resolve host %r (db_type=%s, user=%s): %s", + req.get("host"), req.get("db_type"), current_user.id, exc, + ) + logging.debug("Full resolver exception for host %r", req.get("host"), exc_info=True) + return get_data_error_result(message=f"Could not resolve host {req.get('host')!r}.") try: if req["db_type"] in ["mysql", "mariadb"]: db = MySQLDatabase( req["database"], user=req["username"], - host=req["host"], + host=safe_host, port=req["port"], password=req["password"], ) @@ -797,7 +813,7 @@ async def test_db_connection(): db = MySQLDatabase( req["database"], user=req["username"], - host=req["host"], + host=safe_host, port=req["port"], password=req["password"], charset="utf8mb4", @@ -808,7 +824,7 @@ async def test_db_connection(): db = PostgresqlDatabase( req["database"], user=req["username"], - host=req["host"], + host=safe_host, port=req["port"], password=req["password"], ) @@ -819,7 +835,7 @@ async def test_db_connection(): connection_string = ( f"DRIVER={{ODBC Driver 17 for SQL Server}};" - f"SERVER={req['host']},{req['port']};" + f"SERVER={safe_host},{req['port']};" f"DATABASE={req['database']};" f"UID={req['username']};" f"PWD={req['password']};" @@ -838,7 +854,7 @@ async def test_db_connection(): conn_str = ( f"DATABASE={req['database']};" - f"HOSTNAME={req['host']};" + f"HOSTNAME={safe_host};" f"PORT={req['port']};" f"PROTOCOL=TCPIP;" f"UID={req['username']};" @@ -847,7 +863,7 @@ async def test_db_connection(): logging.info( "DATABASE=%s;HOSTNAME=%s;PORT=%s;PROTOCOL=TCPIP;UID=%s;PWD=****;", req["database"], - req["host"], + safe_host, req["port"], req["username"], ) @@ -873,7 +889,7 @@ async def test_db_connection(): auth = trino.BasicAuthentication(req.get("username") or "ragflow", req["password"]) conn = trino.dbapi.connect( - host=req["host"], + host=safe_host, port=int(req["port"] or 8080), user=req["username"] or "ragflow", catalog=catalog, diff --git a/common/ssrf_guard.py b/common/ssrf_guard.py index b60bcd4bc9..4f87b94d7b 100644 --- a/common/ssrf_guard.py +++ b/common/ssrf_guard.py @@ -170,3 +170,42 @@ def assert_url_is_safe( raise ValueError(f"Hostname {hostname!r} resolved to no addresses.") return hostname, resolved_ip + + +def assert_host_is_safe(host: str) -> str: + """Raise ``ValueError`` if *host* resolves to a non-public IP (SSRF guard for raw host/port connections). + + This is the host-level counterpart of :func:`assert_url_is_safe`, intended + for callers that connect via database drivers or other non-HTTP protocols + where there is no URL to parse. + + Returns the first validated public IP string so the caller can pin it if needed. + """ + if not host: + raise ValueError("Host must not be empty.") + + try: + addr_infos = socket.getaddrinfo(host, None) + except socket.gaierror as exc: + logger.warning("SSRF guard could not resolve host=%r reason=%s", host, exc) + raise ValueError(f"Could not resolve host {host!r}: {exc}") from exc + + resolved_ip: str | None = None + for _family, _type, _proto, _canonname, sockaddr in addr_infos: + raw_ip = ipaddress.ip_address(sockaddr[0]) + eff_ip = _effective_ip(raw_ip) + if not eff_ip.is_global: + logger.warning( + "SSRF guard blocked host: host=%r resolved to non-public address=%s", + host, + raw_ip, + ) + raise ValueError(f"Host resolves to a non-public address ({raw_ip}), which is not allowed.") + if resolved_ip is None: + resolved_ip = str(raw_ip) + + if resolved_ip is None: + logger.warning("SSRF guard blocked host: host=%r resolved to no addresses", host) + raise ValueError(f"Host {host!r} resolved to no addresses.") + + return resolved_ip