# ragflow/mcp_client/mcp_tool_call.py

import asyncio
import logging
import threading
import weakref
from concurrent.futures import ThreadPoolExecutor
from string import Template
from typing import Any, Literal
from typing_extensions import override
from api.db import MCPServerType
from mcp.client.session import ClientSession
from mcp.client.sse import sse_client
from mcp.client.streamable_http import streamablehttp_client
from mcp.types import CallToolResult, ListToolsResult, TextContent, Tool
from rag.llm.chat_model import ToolCallSession
# The two request kinds the background server loop knows how to execute.
MCPTaskType = Literal["list_tools", "tool_call"]
# A queued request: (task kind, keyword arguments for the MCP call,
# per-request queue that receives the call's result or the exception it raised).
MCPTask = tuple[
    MCPTaskType,
    dict[str, Any],
    asyncio.Queue[Any],
]
class MCPToolCallSession(ToolCallSession):
    """Blocking tool-call facade over a single MCP server connection.

    Each instance owns a private asyncio event loop that runs on the single
    worker thread of ``self._thread_pool``. ``_mcp_server_loop`` runs on that
    loop: it opens the transport chosen by ``mcp_server.server_type``,
    initializes an MCP ``ClientSession`` and then executes requests taken from
    ``self._queue``. The synchronous methods (``get_tools``, ``tool_call``,
    ``close_sync``) submit coroutines to that loop with
    ``asyncio.run_coroutine_threadsafe`` and block on the result.
    """

    # Every constructed session registers itself here so shutdown_all_mcp_sessions()
    # can find sessions that were never closed explicitly.
    _ALL_INSTANCES: weakref.WeakSet["MCPToolCallSession"] = weakref.WeakSet()

    def __init__(self, mcp_server: Any, server_variables: dict[str, Any] | None = None) -> None:
        """Start the background event loop and begin connecting to ``mcp_server``.

        Args:
            mcp_server: Server record; its ``url``, ``headers``, ``server_type``
                and ``id`` attributes are read.
            server_variables: Values for ``$name`` placeholders in header names
                and values (applied with ``string.Template.safe_substitute``).
        """
        self.__class__._ALL_INSTANCES.add(self)
        self._mcp_server = mcp_server
        self._server_variables = server_variables or {}
        # Items are MCPTask tuples; ``None`` is a wake-up sentinel pushed on close.
        self._queue = asyncio.Queue()
        self._close = False
        self._event_loop = asyncio.new_event_loop()
        self._thread_pool = ThreadPoolExecutor(max_workers=1)
        self._thread_pool.submit(self._event_loop.run_forever)
        # Kept so requests can fail fast once the connection has ended and so
        # closing can wait for the connection to unwind.
        self._server_loop_future = asyncio.run_coroutine_threadsafe(self._mcp_server_loop(), self._event_loop)

    async def _mcp_server_loop(self) -> None:
        """Connect to the MCP server and serve queued requests until the session closes.

        Runs on the session's own event loop. Nothing awaits this coroutine's
        result, so failures (unsupported server type, connection errors) are
        logged here instead of raised. Requests still queued when it ends are
        failed immediately rather than left to time out.
        """
        try:
            url = self._mcp_server.url.strip()
            raw_headers: dict[str, str] = self._mcp_server.headers or {}
            headers: dict[str, str] = {
                Template(h).safe_substitute(self._server_variables): Template(v).safe_substitute(self._server_variables)
                for h, v in raw_headers.items()
            }
            if self._mcp_server.server_type == MCPServerType.SSE:
                # SSE transport
                async with sse_client(url, headers) as stream:
                    async with ClientSession(*stream) as client_session:
                        if await self._initialize_session(client_session):
                            await self._process_mcp_tasks(client_session)
            elif self._mcp_server.server_type == MCPServerType.STREAMABLE_HTTP:
                # Streamable HTTP transport (the third yielded value is unused here)
                async with streamablehttp_client(url, headers) as (read_stream, write_stream, _):
                    async with ClientSession(read_stream, write_stream) as client_session:
                        # Initialize exactly once; the old code repeated initialize() here.
                        if await self._initialize_session(client_session):
                            await self._process_mcp_tasks(client_session)
            else:
                raise ValueError(f"Unsupported MCP server type {self._mcp_server.server_type} id {self._mcp_server.id}")
        except Exception:
            logging.exception(f"MCP server loop failed for server {self._mcp_server.id}")
        finally:
            self._fail_pending_tasks(f"MCP session for server {self._mcp_server.id} is not running")

    async def _initialize_session(self, client_session: ClientSession) -> bool:
        """Perform the MCP initialize handshake (5 s limit); return True on success."""
        try:
            await asyncio.wait_for(client_session.initialize(), timeout=5)
        except asyncio.TimeoutError:
            logging.error(f"Timeout initializing client_session for server {self._mcp_server.id}")
            return False
        logging.info("client_session initialized successfully")
        return True

    async def _process_mcp_tasks(self, client_session: ClientSession) -> None:
        """Execute queued requests against ``client_session`` until the session is closed.

        The queue is polled with a 1 s timeout so the ``_close`` flag is noticed
        even when idle; a ``None`` item wakes the loop immediately. Each request's
        outcome (its return value or the exception it raised) is put on the
        request's own result queue.
        """
        while not self._close:
            try:
                item = await asyncio.wait_for(self._queue.get(), timeout=1)
            except asyncio.TimeoutError:
                continue
            if item is None:
                # Wake-up sentinel from closing; re-check the loop condition.
                continue
            mcp_task, arguments, result_queue = item
            logging.debug(f"Got MCP task {mcp_task} arguments {arguments}")
            r: Any = None
            try:
                if mcp_task == "list_tools":
                    r = await client_session.list_tools()
                elif mcp_task == "tool_call":
                    r = await client_session.call_tool(**arguments)
                else:
                    r = ValueError(f"Unknown MCP task {mcp_task}")
            except Exception as e:
                r = e
            await result_queue.put(r)

    def _fail_pending_tasks(self, reason: str) -> None:
        """Answer every request still in the queue with ``RuntimeError(reason)``.

        Must run on the session's event loop, because asyncio queues are not
        thread-safe.
        """
        while True:
            try:
                item = self._queue.get_nowait()
            except asyncio.QueueEmpty:
                return
            if item is not None:
                _, _, result_queue = item
                result_queue.put_nowait(RuntimeError(reason))

    async def _call_mcp_server(self, task_type: MCPTaskType, timeout: float = 8, **kwargs) -> Any:
        """Queue one request for the server loop and wait for its outcome.

        Must run on the session's event loop.

        Raises:
            RuntimeError: the session is closed or its server loop has ended.
            TimeoutError: no outcome arrived within ``timeout`` seconds.
            Exception: whatever exception the request itself produced.
        """
        if self._close or self._server_loop_future.done():
            raise RuntimeError(f"MCP session for server {self._mcp_server.id} is not running")
        results = asyncio.Queue()
        await self._queue.put((task_type, kwargs, results))
        try:
            result: CallToolResult | Exception = await asyncio.wait_for(results.get(), timeout=timeout)
        except asyncio.TimeoutError:
            raise TimeoutError(f"MCP task '{task_type}' timeout after {timeout}s") from None
        if isinstance(result, Exception):
            raise result
        return result

    async def _call_mcp_tool(self, name: str, arguments: dict[str, Any]) -> str:
        """Invoke tool ``name`` and return its output as text.

        Only text content is supported for now. The first content item is
        returned if it is ``TextContent``. A successful call with no content
        returns an empty string.
        """
        result: CallToolResult = await self._call_mcp_server("tool_call", name=name, arguments=arguments)
        if result.isError:
            return f"MCP server error: {result.content}"
        if not result.content:
            return ""
        first = result.content[0]
        if isinstance(first, TextContent):
            return first.text
        return f"Unsupported content type {type(first)}"

    async def _get_tools_from_mcp_server(self) -> list[Tool]:
        """Return the tool list reported by the server's ``list_tools`` call."""
        result: ListToolsResult = await self._call_mcp_server("list_tools")
        return result.tools

    def get_tools(self, timeout: float = 10) -> list[Tool]:
        """Return the server's tools, or ``[]`` if the session is closed, times out or fails.

        Blocks the calling thread for at most ``timeout`` seconds.
        """
        from concurrent.futures import TimeoutError as FutureTimeoutError

        if self._close:
            logging.error(f"Cannot fetch tools: MCP session for server {self._mcp_server.id} is closed")
            return []
        future = asyncio.run_coroutine_threadsafe(self._get_tools_from_mcp_server(), self._event_loop)
        try:
            return future.result(timeout=timeout)
        except (TimeoutError, FutureTimeoutError):
            # Before Python 3.11, future.result() raises concurrent.futures.TimeoutError,
            # which is not the builtin TimeoutError.
            future.cancel()
            logging.error(f"Timeout when fetching tools from MCP server: {self._mcp_server.id}")
            return []
        except Exception:
            logging.exception(f"Error fetching tools from MCP server: {self._mcp_server.id}")
            return []

    @override
    def tool_call(self, name: str, arguments: dict[str, Any], timeout: float = 10) -> str:
        """Call MCP tool ``name`` with ``arguments`` and return its text output.

        Never raises: timeouts and errors are logged and described in the
        returned string. Blocks the calling thread for at most ``timeout`` seconds.
        """
        from concurrent.futures import TimeoutError as FutureTimeoutError

        if self._close:
            logging.error(f"Cannot call tool '{name}': MCP session for server {self._mcp_server.id} is closed")
            return f"Error calling tool '{name}': MCP session is closed."
        future = asyncio.run_coroutine_threadsafe(self._call_mcp_tool(name, arguments), self._event_loop)
        try:
            return future.result(timeout=timeout)
        except (TimeoutError, FutureTimeoutError) as te:
            future.cancel()
            logging.error(f"Timeout calling tool '{name}' on MCP server: {self._mcp_server.id}")
            return f"Timeout calling tool '{name}': {te}."
        except Exception as e:
            logging.exception(f"Error calling tool '{name}' on MCP server: {self._mcp_server.id}")
            return f"Error calling tool '{name}': {e}."

    def _mark_closed(self) -> bool:
        """Set the closed flag and unregister the session; return False if it was already closed."""
        if self._close:
            return False
        self._close = True
        self.__class__._ALL_INSTANCES.discard(self)
        return True

    def _shutdown(self, timeout: float) -> None:
        """Tear down the connection, event loop and worker thread, blocking the caller.

        Must not run on the session's own event-loop thread, because it waits for
        that thread to exit. Waits up to ``timeout`` seconds for the server loop
        to leave its ``async with`` blocks so the transport is closed cleanly.
        """
        from concurrent.futures import CancelledError as FutureCancelledError
        from concurrent.futures import TimeoutError as FutureTimeoutError

        loop = self._event_loop
        if loop.is_closed():
            return
        if not self._server_loop_future.done():
            # Wake _process_mcp_tasks now instead of after its next 1 s poll.
            loop.call_soon_threadsafe(self._queue.put_nowait, None)
            try:
                self._server_loop_future.result(timeout=timeout)
            except (TimeoutError, FutureTimeoutError):
                logging.error(f"Timeout while closing session for server {self._mcp_server.id}")
            except FutureCancelledError:
                logging.debug(f"MCP server loop for {self._mcp_server.id} was cancelled")
        # Scheduled even if run_forever has not started yet; it will then stop at once.
        loop.call_soon_threadsafe(loop.stop)
        self._thread_pool.shutdown(wait=True)
        loop.close()

    async def close(self, timeout: float = 5) -> None:
        """Close the session: stop serving requests, disconnect and release the loop thread.

        Idempotent. When awaited from any event loop other than the session's
        own, this blocks that loop until teardown completes. When awaited on the
        session's own loop, it cannot wait for its own thread. In that case it
        waits (up to ``timeout``) for the server loop, schedules the loop to stop
        and lets the worker thread exit by itself; the loop is left unclosed.
        """
        if not self._mark_closed():
            return
        if asyncio.get_running_loop() is self._event_loop:
            self._queue.put_nowait(None)
            await asyncio.wait({asyncio.wrap_future(self._server_loop_future)}, timeout=timeout)
            self._event_loop.call_soon(self._event_loop.stop)
            self._thread_pool.shutdown(wait=False)
            return
        self._shutdown(timeout)

    def close_sync(self, timeout: float = 5) -> None:
        """Blocking equivalent of ``close()`` for synchronous callers.

        Performs the teardown on the calling thread. It is deliberately not
        routed through the session's own loop, because the worker thread cannot
        join itself.
        """
        if self._close or self._event_loop.is_closed():
            logging.warning(f"Event loop already stopped for {self._mcp_server.id}")
            return
        try:
            running_loop = asyncio.get_running_loop()
        except RuntimeError:
            running_loop = None
        if running_loop is self._event_loop:
            # Invoked from code on the session's own loop: blocking here would
            # deadlock, so let close() finish asynchronously on this loop.
            self._event_loop.create_task(self.close(timeout))
            return
        if not self._mark_closed():
            return
        try:
            self._shutdown(timeout)
        except Exception:
            logging.exception(f"Unexpected error during close_sync for {self._mcp_server.id}")
def close_multiple_mcp_toolcall_sessions(sessions: list[MCPToolCallSession]) -> None:
    """Close ``sessions`` together and block the calling thread until all are done.

    ``None`` entries are skipped. The ``close()`` coroutines are gathered on a
    temporary event loop running in a helper daemon thread. A failure in one
    session's close is logged and does not stop the others. The temporary
    loop is closed afterwards so its resources are not leaked.
    """
    logging.info(f"Want to clean up {len(sessions)} MCP sessions")
    live_sessions = [s for s in sessions if s is not None]
    loop = asyncio.new_event_loop()

    async def _gather_and_stop() -> None:
        try:
            outcomes = await asyncio.gather(*[s.close() for s in live_sessions], return_exceptions=True)
            for outcome in outcomes:
                if isinstance(outcome, BaseException):
                    logging.error("Error while closing an MCP session", exc_info=outcome)
        finally:
            # Stop the helper loop so thread.join() below can return.
            loop.call_soon_threadsafe(loop.stop)

    thread = threading.Thread(target=loop.run_forever, daemon=True)
    thread.start()
    asyncio.run_coroutine_threadsafe(_gather_and_stop(), loop).result()
    thread.join()
    loop.close()
    logging.info(
        f"{len(sessions)} MCP sessions has been cleaned up. "
        f"{len(list(MCPToolCallSession._ALL_INSTANCES))} in global context."
    )
def shutdown_all_mcp_sessions():
    """Close every MCPToolCallSession instance still tracked in ``_ALL_INSTANCES``.

    The registry is a WeakSet, so it is copied into a list before any session
    is closed (closing removes sessions from it).
    """
    pending = [*MCPToolCallSession._ALL_INSTANCES]
    if pending:
        logging.info(f"Shutting down {len(pending)} MCPToolCallSession instances...")
        close_multiple_mcp_toolcall_sessions(pending)
        logging.info("All MCPToolCallSession instances have been closed.")
    else:
        logging.info("No MCPToolCallSession instances to close.")
def mcp_tool_metadata_to_openai_tool(mcp_tool: Tool) -> dict[str, Any]:
    """Describe an MCP tool as an OpenAI-style ``{"type": "function", ...}`` tool entry.

    The tool's name, description and input schema are copied unchanged into
    the ``function`` object (the schema becomes ``parameters``).
    """
    function_spec = dict(
        name=mcp_tool.name,
        description=mcp_tool.description,
        parameters=mcp_tool.inputSchema,
    )
    return {"type": "function", "function": function_spec}