diff --git a/agent/tools/exesql.py b/agent/tools/exesql.py index 3f969f4316..305801124c 100644 --- a/agent/tools/exesql.py +++ b/agent/tools/exesql.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import contextlib import json import os import re @@ -195,43 +196,43 @@ class ExeSQL(ToolBase, ABC): except Exception as e: raise Exception("Database Connection Failed! \n" + str(e)) - sql_res = [] - formalized_content = [] - for single_sql in sqls: - if self.check_if_canceled("ExeSQL processing"): - ibm_db.close(conn) - return - - single_sql = single_sql.replace("```", "").strip() - if not single_sql: - continue - single_sql = re.sub(r"\[ID:[0-9]+\]", "", single_sql) - - stmt = ibm_db.exec_immediate(conn, single_sql) - rows = [] - row = ibm_db.fetch_assoc(stmt) - while row and len(rows) < self._param.max_records: + try: + sql_res = [] + formalized_content = [] + for single_sql in sqls: if self.check_if_canceled("ExeSQL processing"): - ibm_db.close(conn) return - rows.append(row) + + single_sql = single_sql.replace("```", "").strip() + if not single_sql: + continue + single_sql = re.sub(r"\[ID:[0-9]+\]", "", single_sql) + + stmt = ibm_db.exec_immediate(conn, single_sql) + rows = [] row = ibm_db.fetch_assoc(stmt) + while row and len(rows) < self._param.max_records: + if self.check_if_canceled("ExeSQL processing"): + return + rows.append(row) + row = ibm_db.fetch_assoc(stmt) - if not rows: - sql_res.append({"content": "No record in the database!"}) - continue + if not rows: + sql_res.append({"content": "No record in the database!"}) + continue - df = pd.DataFrame(rows) - for col in df.columns: - if pd.api.types.is_datetime64_any_dtype(df[col]): - df[col] = df[col].dt.strftime("%Y-%m-%d") + df = pd.DataFrame(rows) + for col in df.columns: + if pd.api.types.is_datetime64_any_dtype(df[col]): + df[col] = df[col].dt.strftime("%Y-%m-%d") - df = df.where(pd.notnull(df), None) + df = df.where(pd.notnull(df), None) - sql_res.append(convert_decimals(df.to_dict(orient="records"))) - formalized_content.append(df.to_markdown(index=False, floatfmt=".6f")) - - ibm_db.close(conn) + sql_res.append(convert_decimals(df.to_dict(orient="records"))) + formalized_content.append(df.to_markdown(index=False, floatfmt=".6f")) + finally: + with contextlib.suppress(Exception): + ibm_db.close(conn) self.set_output("json", sql_res) self.set_output("formalized_content", "\n\n".join(formalized_content)) @@ -239,42 +240,45 @@ class ExeSQL(ToolBase, ABC): try: cursor = db.cursor() except Exception as e: + with contextlib.suppress(Exception): + db.close() raise Exception("Database Connection Failed! \n" + str(e)) - sql_res = [] - formalized_content = [] - for single_sql in sqls: - if self.check_if_canceled("ExeSQL processing"): + try: + sql_res = [] + formalized_content = [] + for single_sql in sqls: + if self.check_if_canceled("ExeSQL processing"): + return + + single_sql = single_sql.replace('```', '').strip() + if not single_sql: + continue + single_sql = re.sub(r"\[ID:[0-9]+\]", "", single_sql) + cursor.execute(single_sql) + if cursor.rowcount == 0: + sql_res.append({"content": "No record in the database!"}) + break + if self._param.db_type == 'mssql': + single_res = pd.DataFrame.from_records(cursor.fetchmany(self._param.max_records), + columns=[desc[0] for desc in cursor.description]) + else: + single_res = pd.DataFrame([i for i in cursor.fetchmany(self._param.max_records)]) + single_res.columns = [i[0] for i in cursor.description] + + for col in single_res.columns: + if pd.api.types.is_datetime64_any_dtype(single_res[col]): + single_res[col] = single_res[col].dt.strftime('%Y-%m-%d') + + single_res = single_res.where(pd.notnull(single_res), None) + + sql_res.append(convert_decimals(single_res.to_dict(orient='records'))) + formalized_content.append(single_res.to_markdown(index=False, floatfmt=".6f")) + finally: + with contextlib.suppress(Exception): cursor.close() + with contextlib.suppress(Exception): db.close() - return - - single_sql = single_sql.replace('```','') - if not single_sql: - continue - single_sql = re.sub(r"\[ID:[0-9]+\]", "", single_sql) - cursor.execute(single_sql) - if cursor.rowcount == 0: - sql_res.append({"content": "No record in the database!"}) - break - if self._param.db_type == 'mssql': - single_res = pd.DataFrame.from_records(cursor.fetchmany(self._param.max_records), - columns=[desc[0] for desc in cursor.description]) - else: - single_res = pd.DataFrame([i for i in cursor.fetchmany(self._param.max_records)]) - single_res.columns = [i[0] for i in cursor.description] - - for col in single_res.columns: - if pd.api.types.is_datetime64_any_dtype(single_res[col]): - single_res[col] = single_res[col].dt.strftime('%Y-%m-%d') - - single_res = single_res.where(pd.notnull(single_res), None) - - sql_res.append(convert_decimals(single_res.to_dict(orient='records'))) - formalized_content.append(single_res.to_markdown(index=False, floatfmt=".6f")) - - cursor.close() - db.close() self.set_output("json", sql_res) self.set_output("formalized_content", "\n\n".join(formalized_content))