From 27b834f296703b0a14968ad8a6302bb342b20d75 Mon Sep 17 00:00:00 2001 From: Eric Zhu Date: Wed, 16 Apr 2025 17:43:28 -0700 Subject: [PATCH] Make shared session possible for MCP tool (#6312) Resolves #6232, #6198 This PR introduces an optional parameter `session` to `mcp_server_tools` to support reuse of the same session. ```python import asyncio from autogen_agentchat.agents import AssistantAgent from autogen_agentchat.conditions import TextMentionTermination from autogen_agentchat.teams import RoundRobinGroupChat from autogen_agentchat.ui import Console from autogen_ext.models.openai import OpenAIChatCompletionClient from autogen_ext.tools.mcp import StdioServerParams, create_mcp_server_session, mcp_server_tools async def main() -> None: model_client = OpenAIChatCompletionClient(model="gpt-4o", parallel_tool_calls=False) # type: ignore params = StdioServerParams( command="npx", args=["@playwright/mcp@latest"], read_timeout_seconds=60, ) async with create_mcp_server_session(params) as session: await session.initialize() tools = await mcp_server_tools(server_params=params, session=session) print(f"Tools: {[tool.name for tool in tools]}") agent = AssistantAgent( name="Assistant", model_client=model_client, tools=tools, # type: ignore ) termination = TextMentionTermination("TERMINATE") team = RoundRobinGroupChat([agent], termination_condition=termination) await Console( team.run_stream( task="Go to https://ekzhu.com/, visit the first link in the page, then tell me about the linked page." ) ) asyncio.run(main()) ``` Based on discussion in this thread: #6284, we will consider serialization and deserialization of MCP server tools when used in this manner in a separate issue. This PR also replaces the `json_schema_to_pydantic` dependency with built-in utils. --- .../autogen_core/utils/_json_to_pydantic.py | 7 +- python/packages/autogen-ext/pyproject.toml | 5 +- .../task_centric_memory/__init__.py | 2 +- .../src/autogen_ext/tools/mcp/__init__.py | 2 + .../src/autogen_ext/tools/mcp/_base.py | 40 ++++++---- .../src/autogen_ext/tools/mcp/_factory.py | 69 ++++++++++++++++- .../src/autogen_ext/tools/mcp/_sse.py | 13 ++-- .../src/autogen_ext/tools/mcp/_stdio.py | 9 ++- .../autogen-ext/tests/tools/test_mcp_tools.py | 76 ++++++++++++++++--- python/uv.lock | 2 - 10 files changed, 179 insertions(+), 46 deletions(-) diff --git a/python/packages/autogen-core/src/autogen_core/utils/_json_to_pydantic.py b/python/packages/autogen-core/src/autogen_core/utils/_json_to_pydantic.py index 892a22a90..082d950a1 100644 --- a/python/packages/autogen-core/src/autogen_core/utils/_json_to_pydantic.py +++ b/python/packages/autogen-core/src/autogen_core/utils/_json_to_pydantic.py @@ -229,11 +229,14 @@ class _JSONSchemaToPydantic: item_type = self.get_ref(item_schema["$ref"].split("/")[-1]) else: item_type_name = item_schema.get("type") - if item_type_name not in TYPE_MAPPING: + if item_type_name is None: + item_type = List[str] + elif item_type_name not in TYPE_MAPPING: raise UnsupportedKeywordError( f"Unsupported or missing item type `{item_type_name}` for array field `{key}` in `{model_name}`" ) - item_type = TYPE_MAPPING[item_type_name] + else: + item_type = TYPE_MAPPING[item_type_name] base_type = conlist(item_type, **constraints) if constraints else List[item_type] # type: ignore[valid-type] diff --git a/python/packages/autogen-ext/pyproject.toml b/python/packages/autogen-ext/pyproject.toml index ebf38593b..5b08ab891 100644 --- a/python/packages/autogen-ext/pyproject.toml +++ b/python/packages/autogen-ext/pyproject.toml @@ -135,10 +135,7 @@ semantic-kernel-all = [ rich = ["rich>=13.9.4"] -mcp = [ - "mcp>=1.6.0", - "json-schema-to-pydantic>=0.2.4" -] +mcp = ["mcp>=1.6.0"] [tool.hatch.build.targets.wheel] packages = ["src/autogen_ext"] diff --git a/python/packages/autogen-ext/src/autogen_ext/experimental/task_centric_memory/__init__.py b/python/packages/autogen-ext/src/autogen_ext/experimental/task_centric_memory/__init__.py index f56657f6f..97415af2b 100644 --- a/python/packages/autogen-ext/src/autogen_ext/experimental/task_centric_memory/__init__.py +++ b/python/packages/autogen-ext/src/autogen_ext/experimental/task_centric_memory/__init__.py @@ -1,4 +1,4 @@ -from .memory_controller import MemoryController, MemoryControllerConfig from ._memory_bank import MemoryBankConfig +from .memory_controller import MemoryController, MemoryControllerConfig __all__ = ["MemoryController", "MemoryControllerConfig", "MemoryBankConfig"] diff --git a/python/packages/autogen-ext/src/autogen_ext/tools/mcp/__init__.py b/python/packages/autogen-ext/src/autogen_ext/tools/mcp/__init__.py index 83d76fcad..eeae32f1c 100644 --- a/python/packages/autogen-ext/src/autogen_ext/tools/mcp/__init__.py +++ b/python/packages/autogen-ext/src/autogen_ext/tools/mcp/__init__.py @@ -1,9 +1,11 @@ from ._config import McpServerParams, SseServerParams, StdioServerParams from ._factory import mcp_server_tools +from ._session import create_mcp_server_session from ._sse import SseMcpToolAdapter from ._stdio import StdioMcpToolAdapter __all__ = [ + "create_mcp_server_session", "StdioMcpToolAdapter", "StdioServerParams", "SseMcpToolAdapter", diff --git a/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_base.py b/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_base.py index 0901be9ed..488984ede 100644 --- a/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_base.py +++ b/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_base.py @@ -1,12 +1,12 @@ import asyncio import builtins from abc import ABC -from typing import Any, Generic, Type, TypeVar +from typing import Any, Dict, Generic, Type, TypeVar from autogen_core import CancellationToken from autogen_core.tools import BaseTool -from json_schema_to_pydantic import create_model -from mcp import Tool +from autogen_core.utils import schema_to_pydantic_model +from mcp import ClientSession, Tool from pydantic import BaseModel from ._config import McpServerParams @@ -26,16 +26,17 @@ class McpToolAdapter(BaseTool[BaseModel, Any], ABC, Generic[TServerParams]): component_type = "tool" - def __init__(self, server_params: TServerParams, tool: Tool) -> None: + def __init__(self, server_params: TServerParams, tool: Tool, session: ClientSession | None = None) -> None: self._tool = tool self._server_params = server_params + self._session = session # Extract name and description name = tool.name description = tool.description or "" # Create the input model from the tool's schema - input_model = create_model(tool.inputSchema, allow_undefined_array_items=True) + input_model = schema_to_pydantic_model(tool.inputSchema) # Use Any as return type since MCP tool returns can vary return_type: Type[Any] = object @@ -61,20 +62,27 @@ class McpToolAdapter(BaseTool[BaseModel, Any], ABC, Generic[TServerParams]): # for many servers. kwargs = args.model_dump(exclude_unset=True) + if self._session is not None: + # If a session is provided, use it directly. + session = self._session + return await self._run(args=kwargs, cancellation_token=cancellation_token, session=session) + + async with create_mcp_server_session(self._server_params) as session: + await session.initialize() + return await self._run(args=kwargs, cancellation_token=cancellation_token, session=session) + + async def _run(self, args: Dict[str, Any], cancellation_token: CancellationToken, session: ClientSession) -> Any: try: - async with create_mcp_server_session(self._server_params) as session: - await session.initialize() + if cancellation_token.is_cancelled(): + raise Exception("Operation cancelled") - if cancellation_token.is_cancelled(): - raise Exception("Operation cancelled") + result_future = asyncio.ensure_future(session.call_tool(name=self._tool.name, arguments=args)) + cancellation_token.link_future(result_future) + result = await result_future - result_future = asyncio.ensure_future(session.call_tool(name=self._tool.name, arguments=kwargs)) - cancellation_token.link_future(result_future) - result = await result_future - - if result.isError: - raise Exception(f"MCP tool execution failed: {result.content}") - return result.content + if result.isError: + raise Exception(f"MCP tool execution failed: {result.content}") + return result.content except Exception as e: error_message = self._format_errors(e) raise Exception(error_message) from e diff --git a/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_factory.py b/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_factory.py index 3b8c2356b..5e23cff92 100644 --- a/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_factory.py +++ b/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_factory.py @@ -1,3 +1,5 @@ +from mcp import ClientSession + from ._config import McpServerParams, SseServerParams, StdioServerParams from ._session import create_mcp_server_session from ._sse import SseMcpToolAdapter @@ -6,6 +8,7 @@ from ._stdio import StdioMcpToolAdapter async def mcp_server_tools( server_params: McpServerParams, + session: ClientSession | None = None, ) -> list[StdioMcpToolAdapter | SseMcpToolAdapter]: """Creates a list of MCP tool adapters that can be used with AutoGen agents. @@ -24,6 +27,9 @@ async def mcp_server_tools( server_params (McpServerParams): Connection parameters for the MCP server. Can be either StdioServerParams for command-line tools or SseServerParams for HTTP/SSE services. + session (ClientSession | None): Optional existing session to use. This is used + when you want to reuse an existing connection to the MCP server. The session + will be reused when creating the MCP tool adapters. Returns: list[StdioMcpToolAdapter | SseMcpToolAdapter]: A list of tool adapters ready to use @@ -110,6 +116,58 @@ async def mcp_server_tools( asyncio.run(main()) + **Sharing an MCP client session across multiple tools:** + + You can create a single MCP client session and share it across multiple tools. + This is sometimes required when the server maintains a session state + (e.g., a browser state) that should be reused for multiple requests. + + The following example show how to create a single MCP client session + to a local `Playwright `_ + server and share it across multiple tools. + + + .. code-block:: python + + import asyncio + + from autogen_agentchat.agents import AssistantAgent + from autogen_agentchat.conditions import TextMentionTermination + from autogen_agentchat.teams import RoundRobinGroupChat + from autogen_agentchat.ui import Console + from autogen_ext.models.openai import OpenAIChatCompletionClient + from autogen_ext.tools.mcp import StdioServerParams, create_mcp_server_session, mcp_server_tools + + + async def main() -> None: + model_client = OpenAIChatCompletionClient(model="gpt-4o", parallel_tool_calls=False) # type: ignore + params = StdioServerParams( + command="npx", + args=["@playwright/mcp@latest"], + read_timeout_seconds=60, + ) + async with create_mcp_server_session(params) as session: + await session.initialize() + tools = await mcp_server_tools(server_params=params, session=session) + print(f"Tools: {[tool.name for tool in tools]}") + + agent = AssistantAgent( + name="Assistant", + model_client=model_client, + tools=tools, # type: ignore + ) + + termination = TextMentionTermination("TERMINATE") + team = RoundRobinGroupChat([agent], termination_condition=termination) + await Console( + team.run_stream( + task="Go to https://ekzhu.com/, visit the first link in the page, then tell me about the linked page." + ) + ) + + + asyncio.run(main()) + **Remote MCP service over SSE example:** @@ -130,13 +188,16 @@ async def mcp_server_tools( For more examples and detailed usage, see the samples directory in the package repository. """ - async with create_mcp_server_session(server_params) as session: - await session.initialize() + if session is None: + async with create_mcp_server_session(server_params) as temp_session: + await temp_session.initialize() + tools = await temp_session.list_tools() + else: tools = await session.list_tools() if isinstance(server_params, StdioServerParams): - return [StdioMcpToolAdapter(server_params=server_params, tool=tool) for tool in tools.tools] + return [StdioMcpToolAdapter(server_params=server_params, tool=tool, session=session) for tool in tools.tools] elif isinstance(server_params, SseServerParams): - return [SseMcpToolAdapter(server_params=server_params, tool=tool) for tool in tools.tools] + return [SseMcpToolAdapter(server_params=server_params, tool=tool, session=session) for tool in tools.tools] raise ValueError(f"Unsupported server params type: {type(server_params)}") diff --git a/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_sse.py b/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_sse.py index 252af7ce5..48e3f348d 100644 --- a/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_sse.py +++ b/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_sse.py @@ -1,5 +1,5 @@ from autogen_core import Component -from mcp import Tool +from mcp import ClientSession, Tool from pydantic import BaseModel from typing_extensions import Self @@ -35,8 +35,11 @@ class SseMcpToolAdapter( Args: server_params (SseServerParameters): Parameters for the MCP server connection, - including URL, headers, and timeouts - tool (Tool): The MCP tool to wrap + including URL, headers, and timeouts. + tool (Tool): The MCP tool to wrap. + session (ClientSession, optional): The MCP client session to use. If not provided, + it will create a new session. This is useful for testing or when you want to + manage the session lifecycle yourself. Examples: Use a remote translation service that implements MCP over SSE to create tools @@ -86,8 +89,8 @@ class SseMcpToolAdapter( component_config_schema = SseMcpToolAdapterConfig component_provider_override = "autogen_ext.tools.mcp.SseMcpToolAdapter" - def __init__(self, server_params: SseServerParams, tool: Tool) -> None: - super().__init__(server_params=server_params, tool=tool) + def __init__(self, server_params: SseServerParams, tool: Tool, session: ClientSession | None = None) -> None: + super().__init__(server_params=server_params, tool=tool, session=session) def _to_config(self) -> SseMcpToolAdapterConfig: """ diff --git a/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_stdio.py b/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_stdio.py index 4f827785e..14f024c6e 100644 --- a/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_stdio.py +++ b/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_stdio.py @@ -1,5 +1,5 @@ from autogen_core import Component -from mcp import Tool +from mcp import ClientSession, Tool from pydantic import BaseModel from typing_extensions import Self @@ -37,6 +37,9 @@ class StdioMcpToolAdapter( server_params (StdioServerParams): Parameters for the MCP server connection, including command to run and its arguments tool (Tool): The MCP tool to wrap + session (ClientSession, optional): The MCP client session to use. If not provided, + a new session will be created. This is useful for testing or when you want to + manage the session lifecycle yourself. See :func:`~autogen_ext.tools.mcp.mcp_server_tools` for examples. """ @@ -44,8 +47,8 @@ class StdioMcpToolAdapter( component_config_schema = StdioMcpToolAdapterConfig component_provider_override = "autogen_ext.tools.mcp.StdioMcpToolAdapter" - def __init__(self, server_params: StdioServerParams, tool: Tool) -> None: - super().__init__(server_params=server_params, tool=tool) + def __init__(self, server_params: StdioServerParams, tool: Tool, session: ClientSession | None = None) -> None: + super().__init__(server_params=server_params, tool=tool, session=session) def _to_config(self) -> StdioMcpToolAdapterConfig: """ diff --git a/python/packages/autogen-ext/tests/tools/test_mcp_tools.py b/python/packages/autogen-ext/tests/tools/test_mcp_tools.py index 14ba9c89c..998d678ab 100644 --- a/python/packages/autogen-ext/tests/tools/test_mcp_tools.py +++ b/python/packages/autogen-ext/tests/tools/test_mcp_tools.py @@ -4,14 +4,15 @@ from unittest.mock import AsyncMock, MagicMock import pytest from autogen_core import CancellationToken +from autogen_core.utils import schema_to_pydantic_model from autogen_ext.tools.mcp import ( SseMcpToolAdapter, SseServerParams, StdioMcpToolAdapter, StdioServerParams, + create_mcp_server_session, mcp_server_tools, ) -from json_schema_to_pydantic import create_model from mcp import ClientSession, Tool @@ -127,7 +128,7 @@ async def test_mcp_tool_execution( with caplog.at_level(logging.INFO): adapter = StdioMcpToolAdapter(server_params=sample_server_params, tool=sample_tool) result = await adapter.run_json( - args=create_model(sample_tool.inputSchema)(**{"test_param": "test"}).model_dump(), + args=schema_to_pydantic_model(sample_tool.inputSchema)(**{"test_param": "test"}).model_dump(), cancellation_token=cancellation_token, ) @@ -179,6 +180,48 @@ async def test_adapter_from_server_params( ) +@pytest.mark.asyncio +async def test_adapter_from_factory( + sample_tool: Tool, + sample_server_params: StdioServerParams, + mock_session: AsyncMock, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Test that factory function returns a list of tools.""" + mock_context = AsyncMock() + mock_context.__aenter__.return_value = mock_session + monkeypatch.setattr( + "autogen_ext.tools.mcp._factory.create_mcp_server_session", + lambda *args, **kwargs: mock_context, # type: ignore + ) + mock_session.list_tools.return_value.tools = [sample_tool] + tools = await mcp_server_tools(server_params=sample_server_params) + assert tools is not None + assert len(tools) > 0 + assert isinstance(tools[0], StdioMcpToolAdapter) + + +@pytest.mark.asyncio +async def test_adapter_from_factory_existing_session( + sample_tool: Tool, + sample_server_params: StdioServerParams, + mock_session: AsyncMock, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Test that factory function returns a list of tools with an existing session.""" + mock_context = AsyncMock() + mock_context.__aenter__.return_value = mock_session + monkeypatch.setattr( + "autogen_ext.tools.mcp._factory.create_mcp_server_session", + lambda *args, **kwargs: mock_context, # type: ignore + ) + mock_session.list_tools.return_value.tools = [sample_tool] + tools = await mcp_server_tools(server_params=sample_server_params, session=mock_session) + assert tools is not None + assert len(tools) > 0 + assert isinstance(tools[0], StdioMcpToolAdapter) + + @pytest.mark.asyncio async def test_sse_adapter_config_serialization(sample_sse_tool: Tool) -> None: """Test that SSE adapter can be saved to and loaded from config.""" @@ -231,7 +274,7 @@ async def test_sse_tool_execution( with caplog.at_level(logging.INFO): adapter = SseMcpToolAdapter(server_params=params, tool=sample_sse_tool) result = await adapter.run_json( - args=create_model(sample_sse_tool.inputSchema)(**{"test_param": "test"}).model_dump(), + args=schema_to_pydantic_model(sample_sse_tool.inputSchema)(**{"test_param": "test"}).model_dump(), cancellation_token=CancellationToken(), ) @@ -284,8 +327,6 @@ async def test_sse_adapter_from_server_params( ) -# TODO: why is this test not working in CI? -@pytest.mark.skip(reason="Skipping test_mcp_server_fetch due to CI issues.") @pytest.mark.asyncio async def test_mcp_server_fetch() -> None: params = StdioServerParams( @@ -300,8 +341,6 @@ async def test_mcp_server_fetch() -> None: assert result is not None -# TODO: why is this test not working in CI? -@pytest.mark.skip(reason="Skipping due to CI issues.") @pytest.mark.asyncio async def test_mcp_server_filesystem() -> None: params = StdioServerParams( @@ -322,8 +361,6 @@ async def test_mcp_server_filesystem() -> None: assert result is not None -# TODO: why is this test not working in CI? -@pytest.mark.skip(reason="Skipping due to CI issues.") @pytest.mark.asyncio async def test_mcp_server_git() -> None: params = StdioServerParams( @@ -341,6 +378,27 @@ async def test_mcp_server_git() -> None: assert result is not None +@pytest.mark.asyncio +async def test_mcp_server_git_existing_session() -> None: + params = StdioServerParams( + command="uvx", + args=["mcp-server-git"], + read_timeout_seconds=60, + ) + async with create_mcp_server_session(params) as session: + await session.initialize() + tools = await mcp_server_tools(server_params=params, session=session) + assert tools is not None + git_log = [tool for tool in tools if tool.name == "git_log"][0] + repo_path = os.path.join(os.path.dirname(__file__), "..", "..", "..", "..", "..") + result = await git_log.run_json({"repo_path": repo_path}, CancellationToken()) + assert result is not None + + git_status = [tool for tool in tools if tool.name == "git_status"][0] + result = await git_status.run_json({"repo_path": repo_path}, CancellationToken()) + assert result is not None + + @pytest.mark.asyncio async def test_mcp_server_github() -> None: # Check if GITHUB_TOKEN is set. diff --git a/python/uv.lock b/python/uv.lock index 4325b4c80..919821edd 100644 --- a/python/uv.lock +++ b/python/uv.lock @@ -644,7 +644,6 @@ magentic-one = [ { name = "playwright" }, ] mcp = [ - { name = "json-schema-to-pydantic" }, { name = "mcp" }, ] ollama = [ @@ -746,7 +745,6 @@ requires-dist = [ { name = "httpx", marker = "extra == 'http-tool'", specifier = ">=0.27.0" }, { name = "ipykernel", marker = "extra == 'jupyter-executor'", specifier = ">=6.29.5" }, { name = "json-schema-to-pydantic", marker = "extra == 'http-tool'", specifier = ">=0.2.0" }, - { name = "json-schema-to-pydantic", marker = "extra == 'mcp'", specifier = ">=0.2.4" }, { name = "langchain-core", marker = "extra == 'langchain'", specifier = "~=0.3.3" }, { name = "llama-cpp-python", marker = "extra == 'llama-cpp'", specifier = ">=0.3.8" }, { name = "magika", marker = "extra == 'file-surfer'", specifier = ">=0.6.1rc2" },