Files
ragflow/agent/tools/exesql.py

302 lines
13 KiB
Python

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import contextlib
import json
import logging
import os
import re
from abc import ABC
import pandas as pd
import pymysql
import psycopg2
import pyodbc
from agent.tools.base import ToolParamBase, ToolBase, ToolMeta
from common.connection_utils import timeout
from common.ssrf_guard import assert_host_is_safe
class ExeSQLParam(ToolParamBase):
    """
    Parameters of the ExeSQL tool: the tool metadata exposed to callers plus
    the connection settings of the target database.
    """

    def __init__(self):
        # Schema of the single `sql` argument the tool accepts.
        sql_argument = {
            "type": "string",
            "description": "The SQL needs to be executed.",
            "default": "{sys.query}",
            "required": True,
        }
        # `meta` is assigned before the base initialiser runs, as in the
        # original ordering.
        self.meta: ToolMeta = {
            "name": "execute_sql",
            "description": "This is a tool that can execute SQL.",
            "parameters": {"sql": sql_argument},
        }
        super().__init__()

        # Connection settings and their defaults.
        self.db_type = "mysql"
        self.database = ""
        self.username = ""
        self.host = ""
        self.port = 3306
        self.password = ""
        # Upper bound on the number of rows fetched per statement.
        self.max_records = 1024

    def check(self):
        """
        Validate the connection settings.

        The password is mandatory for every database type except ``trino``.
        A database named ``rag_flow`` is rejected when the host is
        ``ragflow-mysql`` or the password is ``infini_rag_flow``.

        Raises:
            ValueError: for the rejected ``rag_flow`` combinations described
                above. The ``check_*`` helpers report their own failures.
        """
        supported_db_types = ["mysql", "postgres", "mariadb", "mssql", "IBM DB2", "trino", "oceanbase"]
        self.check_valid_value(self.db_type, "Choose DB type", supported_db_types)
        self.check_empty(self.database, "Database name")
        self.check_empty(self.username, "database username")
        self.check_empty(self.host, "IP Address")
        self.check_positive_integer(self.port, "IP Port")
        if self.db_type != "trino":
            self.check_empty(self.password, "Database password")
        self.check_positive_integer(self.max_records, "Maximum number of records")

        if self.database != "rag_flow":
            return
        if self.host == "ragflow-mysql" or self.password == "infini_rag_flow":
            raise ValueError("For the security reason, it does not support database named rag_flow.")

    def get_input_form(self) -> dict[str, dict]:
        """Describe the input form: a single one-line field labelled ``SQL``."""
        sql_field = {"name": "SQL", "type": "line"}
        return {"sql": sql_field}
class ExeSQL(ToolBase, ABC):
    """
    Tool component that executes SQL against the database described by its
    parameters.

    The incoming ``sql`` text has its referenced inputs substituted, is split
    on ``;`` and every resulting statement is processed in turn. Two outputs
    are published:

    * ``json`` - a list with one entry per processed statement: the result
      rows as a list of dicts, or a ``{"content": message}`` dict for an empty
      result, a refused statement or a failed statement.
    * ``formalized_content`` - Markdown tables and messages joined by blank
      lines.

    NOTE(review): statements are split with a plain ``str.split(";")``, so a
    ``;`` inside a string literal also splits the statement.
    """

    component_name = "ExeSQL"

    # Statements refused on every backend.
    # NOTE(review): only the leading keyword is inspected, so other writing
    # statements (DDL, REPLACE, MERGE, ...) or a statement that starts with a
    # SQL comment are not covered by this check.
    _WRITE_STATEMENT_RE = re.compile(r"^(insert|update|delete)\b", flags=re.IGNORECASE)
    _WRITE_REFUSED_MSG = "For security reasons, INSERT, UPDATE, and DELETE statements are not supported."
    _NO_RECORD_MSG = "No record in the database!"

    @timeout(int(os.environ.get("COMPONENT_EXEC_TIMEOUT", 60)))
    def _invoke(self, **kwargs):
        """
        Run the SQL passed as ``kwargs["sql"]`` and publish the results.

        Returns:
            The ``formalized_content`` output, or ``None`` when the task was
            canceled (in which case no output is set).

        Raises:
            Exception: when ``sql`` is empty, when the configured host is
                rejected by the SSRF guard, or when the connection cannot be
                established. A failure of an individual statement is reported
                in the outputs instead of being raised.
        """
        if self.check_if_canceled("ExeSQL processing"):
            return

        sql = kwargs.get("sql")
        if not sql:
            raise Exception("SQL for `ExeSQL` MUST not be empty.")

        if self.check_if_canceled("ExeSQL processing"):
            return

        # Substitute the inputs referenced in the SQL text; non-string values
        # are rendered as JSON, falling back to str() when not serialisable.
        input_elements = self.get_input_elements_from_text(sql)
        args = {}
        for k, o in input_elements.items():
            args[k] = o["value"]
            if not isinstance(args[k], str):
                try:
                    args[k] = json.dumps(args[k], ensure_ascii=False)
                except Exception:
                    args[k] = str(args[k])
            self.set_input_value(k, args[k])
        sql = self.string_format(sql, args)

        if self.check_if_canceled("ExeSQL processing"):
            return

        # The DB host/port are node-author-controlled and are connected to
        # server-side, so guard against SSRF (internal hosts, loopback, cloud
        # metadata) the same way the `test_db_connection` endpoint does. Connect
        # to the validated, resolved public IP so a later DNS change cannot
        # rebind the host to an internal address (mirrors agent_api.py).
        logging.info(f"ExeSQL validating database host: {self._param.host}")
        try:
            safe_host = assert_host_is_safe(self._param.host)
        except ValueError as e:
            logging.warning(f"ExeSQL rejected database host {self._param.host}: {e}")
            raise Exception(f"Database host '{self._param.host}' is not allowed: {e}") from e
        logging.info(f"ExeSQL validated database host {self._param.host} -> {safe_host}")

        sqls = sql.split(";")

        # IBM DB2 uses the ibm_db functional API; every other backend goes
        # through a DB-API connection/cursor pair.
        if self._param.db_type == "IBM DB2":
            outcome = self._run_db2(safe_host, sqls)
        else:
            outcome = self._run_dbapi(self._connect(safe_host), sqls)

        if outcome is None:
            # Canceled while executing: publish nothing.
            return

        sql_res, formalized_content = outcome
        self.set_output("json", sql_res)
        self.set_output("formalized_content", "\n\n".join(formalized_content))
        return self.output("formalized_content")

    @staticmethod
    def _convert_decimals(obj):
        """Recursively make ``obj`` JSON-friendly: Decimal -> float, NaN/Infinity -> None."""
        from decimal import Decimal
        import math

        if isinstance(obj, Decimal):
            # Convert first so Decimal('NaN') / Decimal('Infinity') go through
            # the same finite check as native floats (or use str(obj) instead).
            obj = float(obj)
        if isinstance(obj, float):
            # NaN and Infinity are not valid JSON values.
            if math.isnan(obj) or math.isinf(obj):
                return None
            return obj
        if isinstance(obj, dict):
            return {k: ExeSQL._convert_decimals(v) for k, v in obj.items()}
        if isinstance(obj, list):
            return [ExeSQL._convert_decimals(item) for item in obj]
        return obj

    @staticmethod
    def _clean_statement(raw_sql):
        """
        Remove Markdown code fences and ``[ID:n]`` markers from one statement.

        The markers are removed before stripping so a statement such as
        ``"[ID:3] DELETE ..."`` starts with its keyword afterwards, and a
        statement consisting only of a marker becomes empty.
        """
        cleaned = raw_sql.replace("```", "")
        cleaned = re.sub(r"\[ID:[0-9]+\]", "", cleaned)
        return cleaned.strip()

    @staticmethod
    def _parse_catalog_schema(db):
        """Split a Trino ``catalog.schema`` / ``catalog/schema`` string; schema defaults to ``default``."""
        if not db:
            return None, None
        if "." in db:
            catalog, schema = db.split(".", 1)
        elif "/" in db:
            catalog, schema = db.split("/", 1)
        else:
            catalog, schema = db, "default"
        return catalog, schema

    def _frame_to_outputs(self, frame):
        """Turn a result DataFrame into ``(records, markdown)`` for the two outputs."""
        for col in frame.columns:
            if pd.api.types.is_datetime64_any_dtype(frame[col]):
                frame[col] = frame[col].dt.strftime("%Y-%m-%d")
        # Replace missing values with None before serialising.
        frame = frame.where(pd.notnull(frame), None)
        records = self._convert_decimals(frame.to_dict(orient="records"))
        markdown = frame.to_markdown(index=False, floatfmt=".6f")
        return records, markdown

    def _connect(self, safe_host):
        """
        Open a DB-API connection to ``safe_host`` for the configured backend.

        Raises:
            Exception: when the Trino client is missing, the Trino settings
                are invalid or its connection fails, or the database type is
                not supported. Driver errors of the other backends propagate
                unchanged.
        """
        p = self._param
        if p.db_type in ["mysql", "mariadb"]:
            return pymysql.connect(db=p.database, user=p.username, host=safe_host, port=p.port, password=p.password)
        if p.db_type == "oceanbase":
            return pymysql.connect(db=p.database, user=p.username, host=safe_host, port=p.port, password=p.password, charset="utf8mb4")
        if p.db_type == "postgres":
            return psycopg2.connect(dbname=p.database, user=p.username, host=safe_host, port=p.port, password=p.password)
        if p.db_type == "mssql":
            # NOTE(review): values are interpolated as-is; a `;` in any of
            # them would change how the connection string is parsed.
            conn_str = (
                "DRIVER={ODBC Driver 17 for SQL Server};"
                f"SERVER={safe_host},{p.port};"
                f"DATABASE={p.database};"
                f"UID={p.username};"
                f"PWD={p.password}"
            )
            return pyodbc.connect(conn_str)
        if p.db_type == "trino":
            return self._connect_trino(safe_host)
        raise Exception(f"Database Connection Failed! \nUnsupported database type: {p.db_type}")

    def _connect_trino(self, safe_host):
        """Open a Trino connection; TLS and basic auth are used only when ``TRINO_USE_TLS`` is ``1``."""
        try:
            import trino
            from trino.auth import BasicAuthentication
        except Exception as e:
            raise Exception("Missing dependency 'trino'. Please install: pip install trino") from e

        catalog, schema = self._parse_catalog_schema(self._param.database)
        if not catalog:
            raise Exception("For Trino, `database` must be 'catalog.schema' or at least 'catalog'.")

        http_scheme = "https" if os.environ.get("TRINO_USE_TLS", "0") == "1" else "http"
        auth = None
        if http_scheme == "https" and self._param.password:
            auth = BasicAuthentication(self._param.username, self._param.password)

        try:
            return trino.dbapi.connect(
                host=safe_host,
                port=int(self._param.port or 8080),
                user=self._param.username or "ragflow",
                catalog=catalog,
                schema=schema or "default",
                http_scheme=http_scheme,
                auth=auth,
            )
        except Exception as e:
            raise Exception("Database Connection Failed! \n" + str(e)) from e

    def _run_db2(self, safe_host, sqls):
        """
        Execute ``sqls`` on IBM DB2 through ``ibm_db``.

        Returns:
            ``(sql_res, formalized_content)``, or ``None`` when canceled.
        """
        import ibm_db

        conn_str = f"DATABASE={self._param.database};HOSTNAME={safe_host};PORT={self._param.port};PROTOCOL=TCPIP;UID={self._param.username};PWD={self._param.password};"
        try:
            conn = ibm_db.connect(conn_str, "", "")
        except Exception as e:
            raise Exception("Database Connection Failed! \n" + str(e)) from e

        try:
            sql_res = []
            formalized_content = []
            for raw_sql in sqls:
                if self.check_if_canceled("ExeSQL processing"):
                    return None
                single_sql = self._clean_statement(raw_sql)
                if not single_sql:
                    continue
                # Same refusal as on the DB-API backends.
                if self._WRITE_STATEMENT_RE.match(single_sql):
                    sql_res.append({"content": self._WRITE_REFUSED_MSG})
                    formalized_content.append(self._WRITE_REFUSED_MSG)
                    continue
                try:
                    stmt = ibm_db.exec_immediate(conn, single_sql)
                    rows = []
                    row = ibm_db.fetch_assoc(stmt)
                    while row and len(rows) < self._param.max_records:
                        if self.check_if_canceled("ExeSQL processing"):
                            return None
                        rows.append(row)
                        row = ibm_db.fetch_assoc(stmt)
                    if not rows:
                        sql_res.append({"content": self._NO_RECORD_MSG})
                        continue
                    records, markdown = self._frame_to_outputs(pd.DataFrame(rows))
                    sql_res.append(records)
                    formalized_content.append(markdown)
                except Exception as e:
                    # Keep the node alive on a bad statement: report and continue.
                    with contextlib.suppress(Exception):
                        ibm_db.rollback(conn)
                    msg = f"SQL Execution Failed: {single_sql}\n{str(e)}"
                    sql_res.append({"content": msg})
                    formalized_content.append(msg)
                    continue
            return sql_res, formalized_content
        finally:
            with contextlib.suppress(Exception):
                ibm_db.close(conn)

    def _run_dbapi(self, db, sqls):
        """
        Execute ``sqls`` on an open DB-API connection and close it afterwards.

        Returns:
            ``(sql_res, formalized_content)``, or ``None`` when canceled.

        Raises:
            Exception: when a cursor cannot be obtained from ``db``.
        """
        try:
            cursor = db.cursor()
        except Exception as e:
            with contextlib.suppress(Exception):
                db.close()
            raise Exception("Database Connection Failed! \n" + str(e)) from e

        try:
            sql_res = []
            formalized_content = []
            for raw_sql in sqls:
                if self.check_if_canceled("ExeSQL processing"):
                    return None
                single_sql = self._clean_statement(raw_sql)
                if not single_sql:
                    continue
                if self._WRITE_STATEMENT_RE.match(single_sql):
                    sql_res.append({"content": self._WRITE_REFUSED_MSG})
                    formalized_content.append(self._WRITE_REFUSED_MSG)
                    continue
                try:
                    cursor.execute(single_sql)
                    # An empty result must not stop the remaining statements.
                    if cursor.rowcount == 0:
                        sql_res.append({"content": self._NO_RECORD_MSG})
                        continue
                    # Drivers may report rowcount as -1 (PEP 249), so also
                    # treat an empty fetch as "no record" instead of failing
                    # while building the DataFrame.
                    rows = cursor.fetchmany(self._param.max_records)
                    if not rows:
                        sql_res.append({"content": self._NO_RECORD_MSG})
                        continue
                    columns = [desc[0] for desc in cursor.description]
                    if self._param.db_type == "mssql":
                        single_res = pd.DataFrame.from_records(rows, columns=columns)
                    else:
                        single_res = pd.DataFrame(list(rows))
                        single_res.columns = columns
                    records, markdown = self._frame_to_outputs(single_res)
                    sql_res.append(records)
                    formalized_content.append(markdown)
                except Exception as e:
                    # A failing statement must not abort the node: report it and keep
                    # going so earlier results survive and later statements still run.
                    # The rollback clears PostgreSQL's aborted-transaction state, which
                    # would otherwise make every subsequent statement fail too.
                    with contextlib.suppress(Exception):
                        db.rollback()
                    msg = f"SQL Execution Failed: {single_sql}\n{str(e)}"
                    sql_res.append({"content": msg})
                    formalized_content.append(msg)
                    continue
            return sql_res, formalized_content
        finally:
            with contextlib.suppress(Exception):
                cursor.close()
            with contextlib.suppress(Exception):
                db.close()

    def thoughts(self) -> str:
        """Return the progress message shown for this tool."""
        return "Query sent—waiting for the data."