Make shared session possible for MCP tool (#6312)

Resolves #6232, #6198

This PR introduces an optional parameter `session` to `mcp_server_tools`
to support reuse of the same session.

```python
import asyncio

from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.conditions import TextMentionTermination
from autogen_agentchat.teams import RoundRobinGroupChat
from autogen_agentchat.ui import Console
from autogen_ext.models.openai import OpenAIChatCompletionClient
from autogen_ext.tools.mcp import StdioServerParams, create_mcp_server_session, mcp_server_tools


async def main() -> None:
    model_client = OpenAIChatCompletionClient(model="gpt-4o", parallel_tool_calls=False)  # type: ignore
    params = StdioServerParams(
        command="npx",
        args=["@playwright/mcp@latest"],
        read_timeout_seconds=60,
    )
    async with create_mcp_server_session(params) as session:
        await session.initialize()
        tools = await mcp_server_tools(server_params=params, session=session)
        print(f"Tools: {[tool.name for tool in tools]}")

        agent = AssistantAgent(
            name="Assistant",
            model_client=model_client,
            tools=tools,  # type: ignore
        )

        termination = TextMentionTermination("TERMINATE")
        team = RoundRobinGroupChat([agent], termination_condition=termination)
        await Console(
            team.run_stream(
                task="Go to https://ekzhu.com/, visit the first link in the page, then tell me about the linked page."
            )
        )


asyncio.run(main())
``` 

Based on discussion in this thread: #6284, we will consider
serialization and deserialization of MCP server tools when used in this
manner in a separate issue.

This PR also replaces the `json_schema_to_pydantic` dependency with
built-in utils.
This commit is contained in:
Eric Zhu 2025-04-16 17:43:28 -07:00 committed by GitHub
parent 844de21c00
commit 27b834f296
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 179 additions and 46 deletions

View File

@ -229,11 +229,14 @@ class _JSONSchemaToPydantic:
item_type = self.get_ref(item_schema["$ref"].split("/")[-1])
else:
item_type_name = item_schema.get("type")
if item_type_name not in TYPE_MAPPING:
if item_type_name is None:
item_type = List[str]
elif item_type_name not in TYPE_MAPPING:
raise UnsupportedKeywordError(
f"Unsupported or missing item type `{item_type_name}` for array field `{key}` in `{model_name}`"
)
item_type = TYPE_MAPPING[item_type_name]
else:
item_type = TYPE_MAPPING[item_type_name]
base_type = conlist(item_type, **constraints) if constraints else List[item_type] # type: ignore[valid-type]

View File

@ -135,10 +135,7 @@ semantic-kernel-all = [
rich = ["rich>=13.9.4"]
mcp = [
"mcp>=1.6.0",
"json-schema-to-pydantic>=0.2.4"
]
mcp = ["mcp>=1.6.0"]
[tool.hatch.build.targets.wheel]
packages = ["src/autogen_ext"]

View File

@ -1,4 +1,4 @@
from .memory_controller import MemoryController, MemoryControllerConfig
from ._memory_bank import MemoryBankConfig
from .memory_controller import MemoryController, MemoryControllerConfig
__all__ = ["MemoryController", "MemoryControllerConfig", "MemoryBankConfig"]

View File

@ -1,9 +1,11 @@
from ._config import McpServerParams, SseServerParams, StdioServerParams
from ._factory import mcp_server_tools
from ._session import create_mcp_server_session
from ._sse import SseMcpToolAdapter
from ._stdio import StdioMcpToolAdapter
__all__ = [
"create_mcp_server_session",
"StdioMcpToolAdapter",
"StdioServerParams",
"SseMcpToolAdapter",

View File

@ -1,12 +1,12 @@
import asyncio
import builtins
from abc import ABC
from typing import Any, Generic, Type, TypeVar
from typing import Any, Dict, Generic, Type, TypeVar
from autogen_core import CancellationToken
from autogen_core.tools import BaseTool
from json_schema_to_pydantic import create_model
from mcp import Tool
from autogen_core.utils import schema_to_pydantic_model
from mcp import ClientSession, Tool
from pydantic import BaseModel
from ._config import McpServerParams
@ -26,16 +26,17 @@ class McpToolAdapter(BaseTool[BaseModel, Any], ABC, Generic[TServerParams]):
component_type = "tool"
def __init__(self, server_params: TServerParams, tool: Tool) -> None:
def __init__(self, server_params: TServerParams, tool: Tool, session: ClientSession | None = None) -> None:
self._tool = tool
self._server_params = server_params
self._session = session
# Extract name and description
name = tool.name
description = tool.description or ""
# Create the input model from the tool's schema
input_model = create_model(tool.inputSchema, allow_undefined_array_items=True)
input_model = schema_to_pydantic_model(tool.inputSchema)
# Use Any as return type since MCP tool returns can vary
return_type: Type[Any] = object
@ -61,20 +62,27 @@ class McpToolAdapter(BaseTool[BaseModel, Any], ABC, Generic[TServerParams]):
# for many servers.
kwargs = args.model_dump(exclude_unset=True)
if self._session is not None:
# If a session is provided, use it directly.
session = self._session
return await self._run(args=kwargs, cancellation_token=cancellation_token, session=session)
async with create_mcp_server_session(self._server_params) as session:
await session.initialize()
return await self._run(args=kwargs, cancellation_token=cancellation_token, session=session)
async def _run(self, args: Dict[str, Any], cancellation_token: CancellationToken, session: ClientSession) -> Any:
try:
async with create_mcp_server_session(self._server_params) as session:
await session.initialize()
if cancellation_token.is_cancelled():
raise Exception("Operation cancelled")
if cancellation_token.is_cancelled():
raise Exception("Operation cancelled")
result_future = asyncio.ensure_future(session.call_tool(name=self._tool.name, arguments=args))
cancellation_token.link_future(result_future)
result = await result_future
result_future = asyncio.ensure_future(session.call_tool(name=self._tool.name, arguments=kwargs))
cancellation_token.link_future(result_future)
result = await result_future
if result.isError:
raise Exception(f"MCP tool execution failed: {result.content}")
return result.content
if result.isError:
raise Exception(f"MCP tool execution failed: {result.content}")
return result.content
except Exception as e:
error_message = self._format_errors(e)
raise Exception(error_message) from e

View File

@ -1,3 +1,5 @@
from mcp import ClientSession
from ._config import McpServerParams, SseServerParams, StdioServerParams
from ._session import create_mcp_server_session
from ._sse import SseMcpToolAdapter
@ -6,6 +8,7 @@ from ._stdio import StdioMcpToolAdapter
async def mcp_server_tools(
server_params: McpServerParams,
session: ClientSession | None = None,
) -> list[StdioMcpToolAdapter | SseMcpToolAdapter]:
"""Creates a list of MCP tool adapters that can be used with AutoGen agents.
@ -24,6 +27,9 @@ async def mcp_server_tools(
server_params (McpServerParams): Connection parameters for the MCP server.
Can be either StdioServerParams for command-line tools or
SseServerParams for HTTP/SSE services.
session (ClientSession | None): Optional existing session to use. This is used
when you want to reuse an existing connection to the MCP server. The session
will be reused when creating the MCP tool adapters.
Returns:
list[StdioMcpToolAdapter | SseMcpToolAdapter]: A list of tool adapters ready to use
@ -110,6 +116,58 @@ async def mcp_server_tools(
asyncio.run(main())
**Sharing an MCP client session across multiple tools:**
You can create a single MCP client session and share it across multiple tools.
This is sometimes required when the server maintains a session state
(e.g., a browser state) that should be reused for multiple requests.
The following example show how to create a single MCP client session
to a local `Playwright <https://github.com/microsoft/playwright-mcp>`_
server and share it across multiple tools.
.. code-block:: python
import asyncio
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.conditions import TextMentionTermination
from autogen_agentchat.teams import RoundRobinGroupChat
from autogen_agentchat.ui import Console
from autogen_ext.models.openai import OpenAIChatCompletionClient
from autogen_ext.tools.mcp import StdioServerParams, create_mcp_server_session, mcp_server_tools
async def main() -> None:
model_client = OpenAIChatCompletionClient(model="gpt-4o", parallel_tool_calls=False) # type: ignore
params = StdioServerParams(
command="npx",
args=["@playwright/mcp@latest"],
read_timeout_seconds=60,
)
async with create_mcp_server_session(params) as session:
await session.initialize()
tools = await mcp_server_tools(server_params=params, session=session)
print(f"Tools: {[tool.name for tool in tools]}")
agent = AssistantAgent(
name="Assistant",
model_client=model_client,
tools=tools, # type: ignore
)
termination = TextMentionTermination("TERMINATE")
team = RoundRobinGroupChat([agent], termination_condition=termination)
await Console(
team.run_stream(
task="Go to https://ekzhu.com/, visit the first link in the page, then tell me about the linked page."
)
)
asyncio.run(main())
**Remote MCP service over SSE example:**
@ -130,13 +188,16 @@ async def mcp_server_tools(
For more examples and detailed usage, see the samples directory in the package repository.
"""
async with create_mcp_server_session(server_params) as session:
await session.initialize()
if session is None:
async with create_mcp_server_session(server_params) as temp_session:
await temp_session.initialize()
tools = await temp_session.list_tools()
else:
tools = await session.list_tools()
if isinstance(server_params, StdioServerParams):
return [StdioMcpToolAdapter(server_params=server_params, tool=tool) for tool in tools.tools]
return [StdioMcpToolAdapter(server_params=server_params, tool=tool, session=session) for tool in tools.tools]
elif isinstance(server_params, SseServerParams):
return [SseMcpToolAdapter(server_params=server_params, tool=tool) for tool in tools.tools]
return [SseMcpToolAdapter(server_params=server_params, tool=tool, session=session) for tool in tools.tools]
raise ValueError(f"Unsupported server params type: {type(server_params)}")

View File

@ -1,5 +1,5 @@
from autogen_core import Component
from mcp import Tool
from mcp import ClientSession, Tool
from pydantic import BaseModel
from typing_extensions import Self
@ -35,8 +35,11 @@ class SseMcpToolAdapter(
Args:
server_params (SseServerParameters): Parameters for the MCP server connection,
including URL, headers, and timeouts
tool (Tool): The MCP tool to wrap
including URL, headers, and timeouts.
tool (Tool): The MCP tool to wrap.
session (ClientSession, optional): The MCP client session to use. If not provided,
it will create a new session. This is useful for testing or when you want to
manage the session lifecycle yourself.
Examples:
Use a remote translation service that implements MCP over SSE to create tools
@ -86,8 +89,8 @@ class SseMcpToolAdapter(
component_config_schema = SseMcpToolAdapterConfig
component_provider_override = "autogen_ext.tools.mcp.SseMcpToolAdapter"
def __init__(self, server_params: SseServerParams, tool: Tool) -> None:
super().__init__(server_params=server_params, tool=tool)
def __init__(self, server_params: SseServerParams, tool: Tool, session: ClientSession | None = None) -> None:
super().__init__(server_params=server_params, tool=tool, session=session)
def _to_config(self) -> SseMcpToolAdapterConfig:
"""

View File

@ -1,5 +1,5 @@
from autogen_core import Component
from mcp import Tool
from mcp import ClientSession, Tool
from pydantic import BaseModel
from typing_extensions import Self
@ -37,6 +37,9 @@ class StdioMcpToolAdapter(
server_params (StdioServerParams): Parameters for the MCP server connection,
including command to run and its arguments
tool (Tool): The MCP tool to wrap
session (ClientSession, optional): The MCP client session to use. If not provided,
a new session will be created. This is useful for testing or when you want to
manage the session lifecycle yourself.
See :func:`~autogen_ext.tools.mcp.mcp_server_tools` for examples.
"""
@ -44,8 +47,8 @@ class StdioMcpToolAdapter(
component_config_schema = StdioMcpToolAdapterConfig
component_provider_override = "autogen_ext.tools.mcp.StdioMcpToolAdapter"
def __init__(self, server_params: StdioServerParams, tool: Tool) -> None:
super().__init__(server_params=server_params, tool=tool)
def __init__(self, server_params: StdioServerParams, tool: Tool, session: ClientSession | None = None) -> None:
super().__init__(server_params=server_params, tool=tool, session=session)
def _to_config(self) -> StdioMcpToolAdapterConfig:
"""

View File

@ -4,14 +4,15 @@ from unittest.mock import AsyncMock, MagicMock
import pytest
from autogen_core import CancellationToken
from autogen_core.utils import schema_to_pydantic_model
from autogen_ext.tools.mcp import (
SseMcpToolAdapter,
SseServerParams,
StdioMcpToolAdapter,
StdioServerParams,
create_mcp_server_session,
mcp_server_tools,
)
from json_schema_to_pydantic import create_model
from mcp import ClientSession, Tool
@ -127,7 +128,7 @@ async def test_mcp_tool_execution(
with caplog.at_level(logging.INFO):
adapter = StdioMcpToolAdapter(server_params=sample_server_params, tool=sample_tool)
result = await adapter.run_json(
args=create_model(sample_tool.inputSchema)(**{"test_param": "test"}).model_dump(),
args=schema_to_pydantic_model(sample_tool.inputSchema)(**{"test_param": "test"}).model_dump(),
cancellation_token=cancellation_token,
)
@ -179,6 +180,48 @@ async def test_adapter_from_server_params(
)
@pytest.mark.asyncio
async def test_adapter_from_factory(
sample_tool: Tool,
sample_server_params: StdioServerParams,
mock_session: AsyncMock,
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Test that factory function returns a list of tools."""
mock_context = AsyncMock()
mock_context.__aenter__.return_value = mock_session
monkeypatch.setattr(
"autogen_ext.tools.mcp._factory.create_mcp_server_session",
lambda *args, **kwargs: mock_context, # type: ignore
)
mock_session.list_tools.return_value.tools = [sample_tool]
tools = await mcp_server_tools(server_params=sample_server_params)
assert tools is not None
assert len(tools) > 0
assert isinstance(tools[0], StdioMcpToolAdapter)
@pytest.mark.asyncio
async def test_adapter_from_factory_existing_session(
sample_tool: Tool,
sample_server_params: StdioServerParams,
mock_session: AsyncMock,
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Test that factory function returns a list of tools with an existing session."""
mock_context = AsyncMock()
mock_context.__aenter__.return_value = mock_session
monkeypatch.setattr(
"autogen_ext.tools.mcp._factory.create_mcp_server_session",
lambda *args, **kwargs: mock_context, # type: ignore
)
mock_session.list_tools.return_value.tools = [sample_tool]
tools = await mcp_server_tools(server_params=sample_server_params, session=mock_session)
assert tools is not None
assert len(tools) > 0
assert isinstance(tools[0], StdioMcpToolAdapter)
@pytest.mark.asyncio
async def test_sse_adapter_config_serialization(sample_sse_tool: Tool) -> None:
"""Test that SSE adapter can be saved to and loaded from config."""
@ -231,7 +274,7 @@ async def test_sse_tool_execution(
with caplog.at_level(logging.INFO):
adapter = SseMcpToolAdapter(server_params=params, tool=sample_sse_tool)
result = await adapter.run_json(
args=create_model(sample_sse_tool.inputSchema)(**{"test_param": "test"}).model_dump(),
args=schema_to_pydantic_model(sample_sse_tool.inputSchema)(**{"test_param": "test"}).model_dump(),
cancellation_token=CancellationToken(),
)
@ -284,8 +327,6 @@ async def test_sse_adapter_from_server_params(
)
# TODO: why is this test not working in CI?
@pytest.mark.skip(reason="Skipping test_mcp_server_fetch due to CI issues.")
@pytest.mark.asyncio
async def test_mcp_server_fetch() -> None:
params = StdioServerParams(
@ -300,8 +341,6 @@ async def test_mcp_server_fetch() -> None:
assert result is not None
# TODO: why is this test not working in CI?
@pytest.mark.skip(reason="Skipping due to CI issues.")
@pytest.mark.asyncio
async def test_mcp_server_filesystem() -> None:
params = StdioServerParams(
@ -322,8 +361,6 @@ async def test_mcp_server_filesystem() -> None:
assert result is not None
# TODO: why is this test not working in CI?
@pytest.mark.skip(reason="Skipping due to CI issues.")
@pytest.mark.asyncio
async def test_mcp_server_git() -> None:
params = StdioServerParams(
@ -341,6 +378,27 @@ async def test_mcp_server_git() -> None:
assert result is not None
@pytest.mark.asyncio
async def test_mcp_server_git_existing_session() -> None:
params = StdioServerParams(
command="uvx",
args=["mcp-server-git"],
read_timeout_seconds=60,
)
async with create_mcp_server_session(params) as session:
await session.initialize()
tools = await mcp_server_tools(server_params=params, session=session)
assert tools is not None
git_log = [tool for tool in tools if tool.name == "git_log"][0]
repo_path = os.path.join(os.path.dirname(__file__), "..", "..", "..", "..", "..")
result = await git_log.run_json({"repo_path": repo_path}, CancellationToken())
assert result is not None
git_status = [tool for tool in tools if tool.name == "git_status"][0]
result = await git_status.run_json({"repo_path": repo_path}, CancellationToken())
assert result is not None
@pytest.mark.asyncio
async def test_mcp_server_github() -> None:
# Check if GITHUB_TOKEN is set.

2
python/uv.lock generated
View File

@ -644,7 +644,6 @@ magentic-one = [
{ name = "playwright" },
]
mcp = [
{ name = "json-schema-to-pydantic" },
{ name = "mcp" },
]
ollama = [
@ -746,7 +745,6 @@ requires-dist = [
{ name = "httpx", marker = "extra == 'http-tool'", specifier = ">=0.27.0" },
{ name = "ipykernel", marker = "extra == 'jupyter-executor'", specifier = ">=6.29.5" },
{ name = "json-schema-to-pydantic", marker = "extra == 'http-tool'", specifier = ">=0.2.0" },
{ name = "json-schema-to-pydantic", marker = "extra == 'mcp'", specifier = ">=0.2.4" },
{ name = "langchain-core", marker = "extra == 'langchain'", specifier = "~=0.3.3" },
{ name = "llama-cpp-python", marker = "extra == 'llama-cpp'", specifier = ">=0.3.8" },
{ name = "magika", marker = "extra == 'file-surfer'", specifier = ">=0.6.1rc2" },