# dify/api/core/mcp/mcp_client.py

import logging
from collections.abc import Callable
from contextlib import AbstractContextManager, ExitStack
from types import TracebackType
from typing import Any
from urllib.parse import urlparse
from core.mcp.client.sse_client import sse_client
from core.mcp.client.streamable_client import streamablehttp_client
from core.mcp.error import MCPConnectionError
from core.mcp.session.client_session import ClientSession
from core.mcp.types import CallToolResult, Tool
logger = logging.getLogger(__name__)


class MCPClient:
    """Client for an MCP server reachable over streamable HTTP or SSE.

    The transport and ``ClientSession`` context managers are owned by an
    ``ExitStack`` and released by ``cleanup()``. Intended usage::

        with MCPClient(url) as client:
            tools = client.list_tools()
    """

    def __init__(
        self,
        server_url: str,
        headers: dict[str, str] | None = None,
        timeout: float | None = None,
        sse_read_timeout: float | None = None,
    ):
        """Store connection settings; nothing connects until ``__enter__``.

        Args:
            server_url: URL of the MCP server. Its last path segment ("mcp" or
                "sse") selects the transport; see ``_initialize``.
            headers: Headers forwarded to the transport factory (defaults to {}).
            timeout: Forwarded unchanged to the transport factory as ``timeout``.
            sse_read_timeout: Forwarded unchanged to the transport factory as
                ``sse_read_timeout``.
        """
        self.server_url = server_url
        self.headers = headers or {}
        self.timeout = timeout
        self.sse_read_timeout = sse_read_timeout

        # Set by connect_server() once a session has been initialized.
        self._session: ClientSession | None = None
        # Owns every context manager of the live connection; closed by cleanup().
        self._exit_stack = ExitStack()
        self._initialized = False

    def __enter__(self):
        """Connect to the server and return this client."""
        self._initialize()
        self._initialized = True
        return self

    def __exit__(self, exc_type: type | None, exc_value: BaseException | None, traceback: TracebackType | None):
        """Release the connection; returns None, so body exceptions are not suppressed."""
        self.cleanup()

    def _initialize(
        self,
    ):
        """Connect using the transport named by the last segment of the URL path.

        A path ending in "mcp" uses ``streamablehttp_client`` and one ending in
        "sse" uses ``sse_client``. For any other path, SSE is tried first and,
        if it raises ``MCPConnectionError``, streamable HTTP is used instead.
        """
        connection_methods: dict[str, Callable[..., AbstractContextManager[Any]]] = {
            "mcp": streamablehttp_client,
            "sse": sse_client,
        }

        # e.g. "https://host/v1/mcp/" -> "mcp"; an empty path yields "".
        path = urlparse(self.server_url).path
        method_name = path.rstrip("/").split("/")[-1]

        client_factory = connection_methods.get(method_name)
        if client_factory is not None:
            self.connect_server(client_factory, method_name)
            return

        logger.debug("Not supported method %s found in URL path, trying 'sse' method first.", method_name)
        try:
            self.connect_server(sse_client, "sse")
        except MCPConnectionError:
            logger.debug("MCP connection failed with 'sse', falling back to 'mcp' method.")
            self.connect_server(streamablehttp_client, "mcp")

    def connect_server(self, client_factory: Callable[..., AbstractContextManager[Any]], method_name: str) -> None:
        """
        Connect to the MCP server with the given transport and initialize a session.

        The transport and session context managers are entered on a temporary
        ExitStack and handed over to ``self._exit_stack`` only after
        ``initialize()`` succeeds. If any step raises, they are closed right
        away. A failed attempt therefore leaves no open connection behind, even
        when ``_initialize`` goes on to try another transport.

        Args:
            client_factory: The client factory to use (streamablehttp_client or sse_client).
            method_name: "mcp" or "sse". For "mcp" the factory's context value is
                unpacked as (read_stream, write_stream, _). Otherwise the value is
                passed to ``ClientSession`` as-is via ``*``.
        Raises:
            Whatever the factory, ``ClientSession`` or ``initialize()`` raise.
        """
        streams_context = client_factory(
            url=self.server_url,
            headers=self.headers,
            timeout=self.timeout,
            sse_read_timeout=self.sse_read_timeout,
        )

        with ExitStack() as attempt:
            if method_name == "mcp":
                read_stream, write_stream, _ = attempt.enter_context(streams_context)
                streams = (read_stream, write_stream)
            else:  # sse_client
                streams = attempt.enter_context(streams_context)

            session = attempt.enter_context(ClientSession(*streams))
            session.initialize()

            # Success: move ownership to the long-lived stack so cleanup() closes it.
            self._exit_stack.enter_context(attempt.pop_all())

        self._session = session

    def list_tools(self) -> list[Tool]:
        """Return the tools advertised by the server.

        Raises:
            ValueError: If no session has been established.
        """
        if self._session is None:
            raise ValueError("Session not initialized.")
        response = self._session.list_tools()
        return response.tools

    def invoke_tool(self, tool_name: str, tool_args: dict[str, Any]) -> CallToolResult:
        """Call ``tool_name`` on the server with ``tool_args``.

        Raises:
            ValueError: If no session has been established.
        """
        if self._session is None:
            raise ValueError("Session not initialized.")
        return self._session.call_tool(tool_name, tool_args)

    def cleanup(self):
        """Close the session and transport and reset the client's connection state.

        ``_session`` and ``_initialized`` are reset even if closing fails.

        Raises:
            ValueError: If closing any managed context raised; the original
                exception is chained as ``__cause__``.
        """
        try:
            # ExitStack closes in reverse entry order: session before its streams.
            self._exit_stack.close()
        except Exception as e:
            logger.exception("Error during cleanup")
            raise ValueError(f"Error during cleanup: {e}") from e
        finally:
            self._session = None
            self._initialized = False