# # Copyright 2025 The InfiniFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # import asyncio import logging import threading import weakref from concurrent.futures import ThreadPoolExecutor from concurrent.futures import TimeoutError as FuturesTimeoutError from dataclasses import dataclass from string import Template from typing import Any, Literal, Protocol from typing_extensions import override from common.constants import MCPServerType from mcp.client.session import ClientSession from mcp.client.sse import sse_client from mcp.client.streamable_http import streamablehttp_client from mcp.types import CallToolResult, ListToolsResult, TextContent, Tool MCPTaskType = Literal["list_tools", "tool_call"] MCPTask = tuple[MCPTaskType, dict[str, Any], asyncio.Queue[Any]] class ToolCallSession(Protocol): def tool_call(self, name: str, arguments: dict[str, Any], timeout: float | int = 10) -> str: ... @dataclass(frozen=True) class MCPToolBinding: session: ToolCallSession original_name: str class MCPToolCallSession(ToolCallSession): _ALL_INSTANCES: weakref.WeakSet["MCPToolCallSession"] = weakref.WeakSet() def __init__(self, mcp_server: Any, server_variables: dict[str, Any] | None = None, custom_header = None) -> None: self.__class__._ALL_INSTANCES.add(self) self._custom_header = custom_header self._mcp_server = mcp_server self._server_variables = server_variables or {} self._queue = asyncio.Queue() self._close = False self._event_loop = asyncio.new_event_loop() self._thread_pool = ThreadPoolExecutor(max_workers=1) self._thread_pool.submit(self._event_loop.run_forever) asyncio.run_coroutine_threadsafe(self._mcp_server_loop(), self._event_loop) async def _mcp_server_loop(self) -> None: url = self._mcp_server.url.strip() raw_headers: dict[str, str] = self._mcp_server.headers or {} custom_header: dict[str, str] = self._custom_header or {} headers: dict[str, str] = {} for h, v in raw_headers.items(): nh = Template(h).safe_substitute(self._server_variables) nv = Template(v).safe_substitute(self._server_variables) if nh.strip() and nv.strip().strip("Bearer"): headers[nh] = nv for h, v in custom_header.items(): nh = Template(h).safe_substitute(custom_header) nv = Template(v).safe_substitute(custom_header) headers[nh] = nv if self._mcp_server.server_type == MCPServerType.SSE: # SSE transport try: async with sse_client(url, headers) as stream: async with ClientSession(*stream) as client_session: try: await asyncio.wait_for(client_session.initialize(), timeout=5) logging.info("client_session initialized successfully") await self._process_mcp_tasks(client_session) except asyncio.TimeoutError: msg = f"Timeout initializing client_session for server {self._mcp_server.id}" logging.error(msg) await self._process_mcp_tasks(None, msg) except asyncio.CancelledError: logging.warning(f"SSE transport MCP session cancelled for server {self._mcp_server.id}") return except Exception: msg = "Connection failed (possibly due to auth error). Please check authentication settings first" await self._process_mcp_tasks(None, msg) elif self._mcp_server.server_type == MCPServerType.STREAMABLE_HTTP: # Streamable HTTP transport try: async with streamablehttp_client(url, headers) as (read_stream, write_stream, _): async with ClientSession(read_stream, write_stream) as client_session: try: await asyncio.wait_for(client_session.initialize(), timeout=5) logging.info("client_session initialized successfully") await self._process_mcp_tasks(client_session) except asyncio.TimeoutError: msg = f"Timeout initializing client_session for server {self._mcp_server.id}" logging.error(msg) await self._process_mcp_tasks(None, msg) except asyncio.CancelledError: logging.warning(f"STREAMABLE_HTTP MCP session cancelled for server {self._mcp_server.id}") return except Exception as e: logging.exception(e) msg = "Connection failed (possibly due to auth error). Please check authentication settings first" await self._process_mcp_tasks(None, msg) else: await self._process_mcp_tasks(None, f"Unsupported MCP server type: {self._mcp_server.server_type}, id: {self._mcp_server.id}") async def _process_mcp_tasks(self, client_session: ClientSession | None, error_message: str | None = None) -> None: while not self._close: try: mcp_task, arguments, result_queue = await asyncio.wait_for(self._queue.get(), timeout=1) except asyncio.TimeoutError: continue except asyncio.CancelledError: break logging.debug(f"Got MCP task {mcp_task} arguments {arguments}") r: Any = None if not client_session or error_message: r = ValueError(error_message) try: await result_queue.put(r) except asyncio.CancelledError: break continue try: if mcp_task == "list_tools": r = await client_session.list_tools() elif mcp_task == "tool_call": r = await client_session.call_tool(**arguments) else: r = ValueError(f"Unknown MCP task {mcp_task}") except Exception as e: r = e except asyncio.CancelledError: break try: await result_queue.put(r) except asyncio.CancelledError: break async def _call_mcp_server(self, task_type: MCPTaskType, request_timeout: float | int = 8, **kwargs) -> Any: if self._close: raise ValueError("Session is closed") results = asyncio.Queue() await self._queue.put((task_type, kwargs, results)) try: result: CallToolResult | Exception = await asyncio.wait_for(results.get(), timeout=request_timeout) if isinstance(result, Exception): raise result return result except asyncio.TimeoutError: raise asyncio.TimeoutError(f"MCP task '{task_type}' timeout after {request_timeout}s") except Exception: raise async def _call_mcp_tool(self, name: str, arguments: dict[str, Any], request_timeout: float | int = 10) -> str: result: CallToolResult = await self._call_mcp_server("tool_call", name=name, arguments=arguments, request_timeout=request_timeout) if result.isError: return f"MCP server error: {result.content}" # For now, we only support text content if not result.content: return "MCP server returned empty content." if isinstance(result.content[0], TextContent): return result.content[0].text else: return f"Unsupported content type {type(result.content)}" async def _get_tools_from_mcp_server(self, request_timeout: float | int = 8) -> list[Tool]: try: result: ListToolsResult = await self._call_mcp_server("list_tools", request_timeout=request_timeout) return result.tools except Exception: raise def get_tools(self, timeout: float | int = 10) -> list[Tool]: if self._close: raise ValueError("Session is closed") future = asyncio.run_coroutine_threadsafe(self._get_tools_from_mcp_server(request_timeout=timeout), self._event_loop) try: return future.result(timeout=timeout) except FuturesTimeoutError: msg = f"Timeout when fetching tools from MCP server: {self._mcp_server.id} (timeout={timeout})" logging.error(msg) raise RuntimeError(msg) except Exception: logging.exception(f"Error fetching tools from MCP server: {self._mcp_server.id}") raise @override def tool_call(self, name: str, arguments: dict[str, Any], timeout: float | int = 10) -> str: if self._close: return "Error: Session is closed" future = asyncio.run_coroutine_threadsafe( self._call_mcp_tool(name, arguments, request_timeout=timeout), self._event_loop, ) try: return future.result(timeout=timeout) except FuturesTimeoutError: logging.error(f"Timeout calling tool '{name}' on MCP server: {self._mcp_server.id} (timeout={timeout})") return f"Timeout calling tool '{name}' (timeout={timeout})." except Exception as e: logging.exception(f"Error calling tool '{name}' on MCP server: {self._mcp_server.id}") return f"Error calling tool '{name}': {e}." async def close(self) -> None: if self._close: return self._close = True while not self._queue.empty(): try: _, _, result_queue = self._queue.get_nowait() try: await result_queue.put(asyncio.CancelledError("Session is closing")) except Exception: pass except asyncio.QueueEmpty: break except Exception: break try: self._event_loop.call_soon_threadsafe(self._event_loop.stop) except Exception: pass try: self._thread_pool.shutdown(wait=True) except Exception: pass self.__class__._ALL_INSTANCES.discard(self) def close_sync(self, timeout: float | int = 5) -> None: if not self._event_loop.is_running(): logging.warning(f"Event loop already stopped for {self._mcp_server.id}") return try: future = asyncio.run_coroutine_threadsafe(self.close(), self._event_loop) try: future.result(timeout=timeout) except FuturesTimeoutError: logging.error(f"Timeout while closing session for server {self._mcp_server.id} (timeout={timeout})") except Exception: logging.exception(f"Unexpected error during close_sync for {self._mcp_server.id}") except Exception: logging.exception(f"Exception while scheduling close for server {self._mcp_server.id}") def close_multiple_mcp_toolcall_sessions(sessions: list[MCPToolCallSession]) -> None: logging.info(f"Want to clean up {len(sessions)} MCP sessions") async def _gather_and_stop() -> None: try: await asyncio.gather(*[s.close() for s in sessions if s is not None], return_exceptions=True) except Exception: logging.exception("Exception during MCP session cleanup") finally: try: loop.call_soon_threadsafe(loop.stop) except Exception: pass try: loop = asyncio.new_event_loop() thread = threading.Thread(target=loop.run_forever, daemon=True) thread.start() asyncio.run_coroutine_threadsafe(_gather_and_stop(), loop).result() thread.join() except Exception: logging.exception("Exception during MCP session cleanup thread management") logging.info( f"{len(sessions)} MCP sessions has been cleaned up. {len(list(MCPToolCallSession._ALL_INSTANCES))} in global context.") def shutdown_all_mcp_sessions(): """Gracefully shutdown all active MCPToolCallSession instances.""" sessions = list(MCPToolCallSession._ALL_INSTANCES) if not sessions: logging.info("No MCPToolCallSession instances to close.") return logging.info(f"Shutting down {len(sessions)} MCPToolCallSession instances...") close_multiple_mcp_toolcall_sessions(sessions) logging.info("All MCPToolCallSession instances have been closed.") def mcp_tool_metadata_to_openai_tool(mcp_tool: Tool | dict, function_name: str | None = None) -> dict[str, Any]: if isinstance(mcp_tool, dict): return { "type": "function", "function": { "name": function_name or mcp_tool["name"], "description": mcp_tool["description"], "parameters": mcp_tool["inputSchema"], }, } return { "type": "function", "function": { "name": function_name or mcp_tool.name, "description": mcp_tool.description, "parameters": mcp_tool.inputSchema, }, }