diff --git a/python/packages/autogen-core/docs/src/user-guide/core-user-guide/cookbook/llamaindex-agent.ipynb b/python/packages/autogen-core/docs/src/user-guide/core-user-guide/cookbook/llamaindex-agent.ipynb index 08a322403..03894e686 100644 --- a/python/packages/autogen-core/docs/src/user-guide/core-user-guide/cookbook/llamaindex-agent.ipynb +++ b/python/packages/autogen-core/docs/src/user-guide/core-user-guide/cookbook/llamaindex-agent.ipynb @@ -38,7 +38,6 @@ "outputs": [], "source": [ "import os\n", - "from pydantic import BaseModel\n", "from typing import List, Optional\n", "\n", "from autogen_core import AgentId, MessageContext, RoutedAgent, SingleThreadedAgentRuntime, message_handler\n", @@ -57,7 +56,8 @@ "from llama_index.embeddings.openai import OpenAIEmbedding\n", "from llama_index.llms.azure_openai import AzureOpenAI\n", "from llama_index.llms.openai import OpenAI\n", - "from llama_index.tools.wikipedia import WikipediaToolSpec" + "from llama_index.tools.wikipedia import WikipediaToolSpec\n", + "from pydantic import BaseModel" ] }, { diff --git a/python/packages/autogen-core/src/autogen_core/_component_config.py b/python/packages/autogen-core/src/autogen_core/_component_config.py index ce768ddaf..bb603a839 100644 --- a/python/packages/autogen-core/src/autogen_core/_component_config.py +++ b/python/packages/autogen-core/src/autogen_core/_component_config.py @@ -7,7 +7,7 @@ from typing import Any, ClassVar, Dict, Generic, Literal, Type, TypeGuard, cast, from pydantic import BaseModel from typing_extensions import Self, TypeVar -ComponentType = Literal["model", "agent", "tool", "termination", "token_provider"] | str +ComponentType = Literal["model", "agent", "tool", "termination", "token_provider", "workbench"] | str ConfigT = TypeVar("ConfigT", bound=BaseModel) FromConfigT = TypeVar("FromConfigT", bound=BaseModel, contravariant=True) ToConfigT = TypeVar("ToConfigT", bound=BaseModel, covariant=True) diff --git 
a/python/packages/autogen-core/src/autogen_core/tools/__init__.py b/python/packages/autogen-core/src/autogen_core/tools/__init__.py index 52a9d725f..4f48463a0 100644 --- a/python/packages/autogen-core/src/autogen_core/tools/__init__.py +++ b/python/packages/autogen-core/src/autogen_core/tools/__init__.py @@ -1,5 +1,7 @@ from ._base import BaseTool, BaseToolWithState, ParametersSchema, Tool, ToolSchema from ._function_tool import FunctionTool +from ._static_workbench import StaticWorkbench +from ._workbench import ImageResultContent, TextResultContent, ToolResult, Workbench __all__ = [ "Tool", @@ -8,4 +10,9 @@ __all__ = [ "BaseTool", "BaseToolWithState", "FunctionTool", + "Workbench", + "ToolResult", + "TextResultContent", + "ImageResultContent", + "StaticWorkbench", ] diff --git a/python/packages/autogen-core/src/autogen_core/tools/_function_tool.py b/python/packages/autogen-core/src/autogen_core/tools/_function_tool.py index 048b26525..985d7d1d1 100644 --- a/python/packages/autogen-core/src/autogen_core/tools/_function_tool.py +++ b/python/packages/autogen-core/src/autogen_core/tools/_function_tool.py @@ -178,4 +178,4 @@ class FunctionTool(BaseTool[BaseModel, BaseModel], Component[FunctionToolConfig] if not callable(func): raise TypeError(f"Expected function but got {type(func)}") - return cls(func, "", None) + return cls(func, name=config.name, description=config.description, global_imports=config.global_imports) diff --git a/python/packages/autogen-core/src/autogen_core/tools/_static_workbench.py b/python/packages/autogen-core/src/autogen_core/tools/_static_workbench.py new file mode 100644 index 000000000..5a997defa --- /dev/null +++ b/python/packages/autogen-core/src/autogen_core/tools/_static_workbench.py @@ -0,0 +1,88 @@ +import asyncio +from typing import Any, Dict, List, Literal, Mapping + +from pydantic import BaseModel +from typing_extensions import Self + +from .._cancellation_token import CancellationToken +from .._component_config import Component, 
ComponentModel +from ._base import BaseTool, ToolSchema +from ._workbench import TextResultContent, ToolResult, Workbench + + +class StaticWorkbenchConfig(BaseModel): + tools: List[ComponentModel] = [] + + +class StateicWorkbenchState(BaseModel): + type: Literal["StaticWorkbenchState"] = "StaticWorkbenchState" + tools: Dict[str, Mapping[str, Any]] = {} + + +class StaticWorkbench(Workbench, Component[StaticWorkbenchConfig]): + """ + A workbench that provides a static set of tools that do not change after + each tool execution. + + Args: + tools (List[BaseTool[Any, Any]]): A list of tools to be included in the workbench. + The tools should be subclasses of :class:`~autogen_core.tools.BaseTool`. + """ + + component_provider_override = "autogen_core.tools.StaticWorkbench" + component_config_schema = StaticWorkbenchConfig + + def __init__(self, tools: List[BaseTool[Any, Any]]) -> None: + self._tools = tools + + async def list_tools(self) -> List[ToolSchema]: + return [tool.schema for tool in self._tools] + + async def call_tool( + self, name: str, arguments: Mapping[str, Any] | None = None, cancellation_token: CancellationToken | None = None + ) -> ToolResult: + tool = next((tool for tool in self._tools if tool.name == name), None) + if tool is None: + raise ValueError(f"Tool {name} not found in workbench.") + if not cancellation_token: + cancellation_token = CancellationToken() + if not arguments: + arguments = {} + try: + result_future = asyncio.ensure_future(tool.run_json(arguments, cancellation_token)) + cancellation_token.link_future(result_future) + result = await result_future + is_error = False + except Exception as e: + result = str(e) + is_error = True + result_str = tool.return_value_as_string(result) + return ToolResult(name=tool.name, result=[TextResultContent(content=result_str)], is_error=is_error) + + async def start(self) -> None: + return None + + async def stop(self) -> None: + return None + + async def reset(self) -> None: + return None + + async def 
save_state(self) -> Mapping[str, Any]: + tool_states = StateicWorkbenchState() + for tool in self._tools: + tool_states.tools[tool.name] = await tool.save_state_json() + return tool_states.model_dump() + + async def load_state(self, state: Mapping[str, Any]) -> None: + parsed_state = StateicWorkbenchState.model_validate(state) + for tool in self._tools: + if tool.name in parsed_state.tools: + await tool.load_state_json(parsed_state.tools[tool.name]) + + def _to_config(self) -> StaticWorkbenchConfig: + return StaticWorkbenchConfig(tools=[tool.dump_component() for tool in self._tools]) + + @classmethod + def _from_config(cls, config: StaticWorkbenchConfig) -> Self: + return cls(tools=[BaseTool.load_component(tool) for tool in config.tools]) diff --git a/python/packages/autogen-core/src/autogen_core/tools/_workbench.py b/python/packages/autogen-core/src/autogen_core/tools/_workbench.py new file mode 100644 index 000000000..23f1b4769 --- /dev/null +++ b/python/packages/autogen-core/src/autogen_core/tools/_workbench.py @@ -0,0 +1,164 @@ +from abc import ABC, abstractmethod +from types import TracebackType +from typing import Any, List, Literal, Mapping, Optional, Type + +from pydantic import BaseModel, Field +from typing_extensions import Annotated, Self + +from .._cancellation_token import CancellationToken +from .._component_config import ComponentBase +from .._image import Image +from ._base import ToolSchema + + +class TextResultContent(BaseModel): + """ + Text result content of a tool execution. + """ + + type: Literal["TextResultContent"] = "TextResultContent" + + content: str + """The text content of the result.""" + + +class ImageResultContent(BaseModel): + """ + Image result content of a tool execution. 
+ """ + + type: Literal["ImageResultContent"] = "ImageResultContent" + + content: Image + """The image content of the result.""" + + +ResultContent = Annotated[TextResultContent | ImageResultContent, Field(discriminator="type")] + + +class ToolResult(BaseModel): + """ + A result of a tool execution by a workbench. + """ + + type: Literal["ToolResult"] = "ToolResult" + + name: str + """The name of the tool that was executed.""" + + result: List[ResultContent] + """The result of the tool execution.""" + + is_error: bool = False + """Whether the tool execution resulted in an error.""" + + +class Workbench(ABC, ComponentBase[BaseModel]): + """ + A workbench is a component that provides a set of tools that may share + resources and state. + + A workbench is responsible for managing the lifecycle of the tools and + providing a single interface to call them. The tools provided by the workbench + may be dynamic and their availabilities may change after each tool execution. + + A workbench can be started by calling the :meth:`~autogen_core.tools.Workbench.start` method + and stopped by calling the :meth:`~autogen_core.tools.Workbench.stop` method. + It can also be used as an asynchronous context manager, which will automatically + start and stop the workbench when entering and exiting the context. + """ + + component_type = "workbench" + + @abstractmethod + async def list_tools(self) -> List[ToolSchema]: + """ + List the currently available tools in the workbench as :class:`ToolSchema` + objects. + + The list of tools may be dynamic, and their content may change after + tool execution. + """ + ... + + @abstractmethod + async def call_tool( + self, name: str, arguments: Mapping[str, Any] | None = None, cancellation_token: CancellationToken | None = None + ) -> ToolResult: + """ + Call a tool in the workbench. + + Args: + name (str): The name of the tool to call. + arguments (Mapping[str, Any] | None): The arguments to pass to the tool. 
+ If None, the tool will be called with no arguments. + cancellation_token (CancellationToken | None): An optional cancellation token + to cancel the tool execution. + Returns: + ToolResult: The result of the tool execution. + """ + ... + + @abstractmethod + async def start(self) -> None: + """ + Start the workbench and initialize any resources. + + This method should be called before using the workbench. + """ + ... + + @abstractmethod + async def stop(self) -> None: + """ + Stop the workbench and release any resources. + + This method should be called when the workbench is no longer needed. + """ + ... + + @abstractmethod + async def reset(self) -> None: + """ + Reset the workbench to its initialized, started state. + """ + ... + + @abstractmethod + async def save_state(self) -> Mapping[str, Any]: + """ + Save the state of the workbench. + + This method should be called to persist the state of the workbench. + """ + ... + + @abstractmethod + async def load_state(self, state: Mapping[str, Any]) -> None: + """ + Load the state of the workbench. + + Args: + state (Mapping[str, Any]): The state to load into the workbench. + """ + ... + + async def __aenter__(self) -> Self: + """ + Enter the workbench context manager. + + This method is called when the workbench is used in an `async with` statement. + It calls the :meth:`~autogen_core.tools.Workbench.start` method to start the workbench. + """ + await self.start() + return self + + async def __aexit__( + self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType] + ) -> None: + """ + Exit the workbench context manager. + This method is called when the workbench is used in an `async with` statement. + It calls the :meth:`~autogen_core.tools.Workbench.stop` method to stop the workbench.
+ """ + await self.stop() diff --git a/python/packages/autogen-core/tests/test_workbench.py b/python/packages/autogen-core/tests/test_workbench.py new file mode 100644 index 000000000..0de372c45 --- /dev/null +++ b/python/packages/autogen-core/tests/test_workbench.py @@ -0,0 +1,119 @@ +from typing import Annotated + +import pytest +from autogen_core.code_executor import ImportFromModule +from autogen_core.tools import FunctionTool, StaticWorkbench, Workbench + + +@pytest.mark.asyncio +async def test_static_workbench() -> None: + def test_tool_func_1(x: Annotated[int, "The number to double."]) -> int: + return x * 2 + + def test_tool_func_2(x: Annotated[int, "The number to add 2."]) -> int: + raise ValueError("This is a test error") # Simulate an error + + test_tool_1 = FunctionTool( + test_tool_func_1, + name="test_tool_1", + description="A test tool that doubles a number.", + global_imports=[ImportFromModule(module="typing_extensions", imports=["Annotated"])], + ) + test_tool_2 = FunctionTool( + test_tool_func_2, + name="test_tool_2", + description="A test tool that adds 2 to a number.", + global_imports=[ImportFromModule(module="typing_extensions", imports=["Annotated"])], + ) + + # Create a StaticWorkbench instance with the test tools. + async with StaticWorkbench(tools=[test_tool_1, test_tool_2]) as workbench: + # List tools + tools = await workbench.list_tools() + assert len(tools) == 2 + assert "description" in tools[0] + assert "parameters" in tools[0] + assert tools[0]["name"] == "test_tool_1" + assert tools[0]["description"] == "A test tool that doubles a number." + assert tools[0]["parameters"] == { + "type": "object", + "properties": {"x": {"type": "integer", "title": "X", "description": "The number to double."}}, + "required": ["x"], + "additionalProperties": False, + } + assert "description" in tools[1] + assert "parameters" in tools[1] + assert tools[1]["name"] == "test_tool_2" + assert tools[1]["description"] == "A test tool that adds 2 to a number." 
+ assert tools[1]["parameters"] == { + "type": "object", + "properties": {"x": {"type": "integer", "title": "X", "description": "The number to add 2."}}, + "required": ["x"], + "additionalProperties": False, + } + + # Call tools + result_1 = await workbench.call_tool("test_tool_1", {"x": 5}) + assert result_1.name == "test_tool_1" + assert result_1.result[0].type == "TextResultContent" + assert result_1.result[0].content == "10" + assert result_1.is_error is False + + # Call tool with error + result_2 = await workbench.call_tool("test_tool_2", {"x": 5}) + assert result_2.name == "test_tool_2" + assert result_2.result[0].type == "TextResultContent" + assert result_2.result[0].content == "This is a test error" + assert result_2.is_error is True + + # Save state. + state = await workbench.save_state() + assert state["type"] == "StaticWorkbenchState" + assert "tools" in state + assert len(state["tools"]) == 2 + + # Dump config. + config = workbench.dump_component() + + # Load the workbench from the config. + async with Workbench.load_component(config) as new_workbench: + # Load state. + await new_workbench.load_state(state) + + # Verify that the tools are still available after loading the state. + tools = await new_workbench.list_tools() + assert len(tools) == 2 + assert "description" in tools[0] + assert "parameters" in tools[0] + assert tools[0]["name"] == "test_tool_1" + assert tools[0]["description"] == "A test tool that doubles a number." + assert tools[0]["parameters"] == { + "type": "object", + "properties": {"x": {"type": "integer", "title": "X", "description": "The number to double."}}, + "required": ["x"], + "additionalProperties": False, + } + assert "description" in tools[1] + assert "parameters" in tools[1] + assert tools[1]["name"] == "test_tool_2" + assert tools[1]["description"] == "A test tool that adds 2 to a number." 
+ assert tools[1]["parameters"] == { + "type": "object", + "properties": {"x": {"type": "integer", "title": "X", "description": "The number to add 2."}}, + "required": ["x"], + "additionalProperties": False, + } + + # Call tools + result_1 = await new_workbench.call_tool("test_tool_1", {"x": 5}) + assert result_1.name == "test_tool_1" + assert result_1.result[0].type == "TextResultContent" + assert result_1.result[0].content == "10" + assert result_1.is_error is False + + # Call tool with error + result_2 = await new_workbench.call_tool("test_tool_2", {"x": 5}) + assert result_2.name == "test_tool_2" + assert result_2.result[0].type == "TextResultContent" + assert result_2.result[0].content == "This is a test error" + assert result_2.is_error is True diff --git a/python/packages/autogen-ext/src/autogen_ext/tools/mcp/__init__.py b/python/packages/autogen-ext/src/autogen_ext/tools/mcp/__init__.py index eeae32f1c..bde36794c 100644 --- a/python/packages/autogen-ext/src/autogen_ext/tools/mcp/__init__.py +++ b/python/packages/autogen-ext/src/autogen_ext/tools/mcp/__init__.py @@ -1,15 +1,19 @@ +from ._actor import McpSessionActor from ._config import McpServerParams, SseServerParams, StdioServerParams from ._factory import mcp_server_tools from ._session import create_mcp_server_session from ._sse import SseMcpToolAdapter from ._stdio import StdioMcpToolAdapter +from ._workbench import McpWorkbench __all__ = [ "create_mcp_server_session", + "McpSessionActor", "StdioMcpToolAdapter", "StdioServerParams", "SseMcpToolAdapter", "SseServerParams", "McpServerParams", "mcp_server_tools", + "McpWorkbench", ] diff --git a/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_actor.py b/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_actor.py new file mode 100644 index 000000000..7e84e24b5 --- /dev/null +++ b/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_actor.py @@ -0,0 +1,147 @@ +import asyncio +import atexit +from typing import Any, Coroutine, Dict, Mapping, 
TypedDict + +from autogen_core import Component, ComponentBase +from mcp.types import CallToolResult, ListToolsResult +from pydantic import BaseModel +from typing_extensions import Self + +from ._config import McpServerParams +from ._session import create_mcp_server_session + +McpResult = Coroutine[Any, Any, ListToolsResult] | Coroutine[Any, Any, CallToolResult] +McpFuture = asyncio.Future[McpResult] + + +class McpActorArgs(TypedDict): + name: str | None + kargs: Mapping[str, Any] + + +class McpSessionActorConfig(BaseModel): + server_params: McpServerParams + + +class McpSessionActor(ComponentBase[BaseModel], Component[McpSessionActorConfig]): + component_type = "mcp_session_actor" + component_config_schema = McpSessionActorConfig + component_provider_override = "autogen_ext.tools.mcp.McpSessionActor" + + server_params: McpServerParams + + # model_config = ConfigDict(arbitrary_types_allowed=True) + + def __init__(self, server_params: McpServerParams) -> None: + self.server_params: McpServerParams = server_params + self.name = "mcp_session_actor" + self.description = "MCP session actor" + self._command_queue: asyncio.Queue[Dict[str, Any]] = asyncio.Queue() + self._actor_task: asyncio.Task[Any] | None = None + self._shutdown_future: asyncio.Future[Any] | None = None + self._active = False + atexit.register(self._sync_shutdown) + + async def initialize(self) -> None: + if not self._active: + self._active = True + self._actor_task = asyncio.create_task(self._run_actor()) + + async def call(self, type: str, args: McpActorArgs | None = None) -> McpFuture: + if not self._active: + raise RuntimeError("MCP Actor not running, call initialize() first") + if self._actor_task and self._actor_task.done(): + raise RuntimeError("MCP actor task crashed", self._actor_task.exception()) + fut: asyncio.Future[McpFuture] = asyncio.Future() + if type in {"list_tools", "shutdown"}: + await self._command_queue.put({"type": type, "future": fut}) + res = await fut + elif type == "call_tool": 
+ if args is None: + raise ValueError("args is required for call_tool") + name = args.get("name", None) + kwargs = args.get("kargs", {}) + if name is None: + raise ValueError("name is required for call_tool") + await self._command_queue.put({"type": type, "name": name, "args": kwargs, "future": fut}) + res = await fut + else: + raise ValueError(f"Unknown command type: {type}") + return res + + async def close(self) -> None: + if not self._active or self._actor_task is None: + return + self._shutdown_future = asyncio.Future() + await self._command_queue.put({"type": "shutdown", "future": self._shutdown_future}) + await self._shutdown_future + await self._actor_task + self._active = False + + async def _run_actor(self) -> None: + result: McpResult + try: + async with create_mcp_server_session(self.server_params) as session: + await session.initialize() + while True: + cmd = await self._command_queue.get() + if cmd["type"] == "shutdown": + cmd["future"].set_result("ok") + break + elif cmd["type"] == "call_tool": + try: + result = session.call_tool(name=cmd["name"], arguments=cmd["args"]) + cmd["future"].set_result(result) + except Exception as e: + cmd["future"].set_exception(e) + elif cmd["type"] == "list_tools": + try: + result = session.list_tools() + cmd["future"].set_result(result) + except Exception as e: + cmd["future"].set_exception(e) + except Exception as e: + if self._shutdown_future and not self._shutdown_future.done(): + self._shutdown_future.set_exception(e) + finally: + self._active = False + self._actor_task = None + + def _sync_shutdown(self) -> None: + if not self._active or self._actor_task is None: + return + try: + loop = asyncio.get_event_loop() + except RuntimeError: + # No loop available — interpreter is likely shutting down + return + + if loop.is_closed(): + return + + if loop.is_running(): + loop.create_task(self.close()) + else: + loop.run_until_complete(self.close()) + + def _to_config(self) -> McpSessionActorConfig: + """ + Convert the 
adapter to its configuration representation. + + Returns: + McpSessionActorConfig: The configuration of the adapter. + """ + return McpSessionActorConfig(server_params=self.server_params) + + @classmethod + def _from_config(cls, config: McpSessionActorConfig) -> Self: + """ + Create an instance of McpSessionActor from its configuration. + + Args: + config (McpSessionActorConfig): The configuration of the adapter. + + Returns: + McpSessionActor: An instance of McpSessionActor. + """ + return cls(server_params=config.server_params) diff --git a/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_config.py b/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_config.py index 236ff6892..215102c5e 100644 --- a/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_config.py +++ b/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_config.py @@ -1,22 +1,27 @@ -from typing import Any, TypeAlias +from typing import Any, Literal from mcp import StdioServerParameters -from pydantic import BaseModel +from pydantic import BaseModel, Field +from typing_extensions import Annotated class StdioServerParams(StdioServerParameters): """Parameters for connecting to an MCP server over STDIO.""" + type: Literal["StdioServerParams"] = "StdioServerParams" + read_timeout_seconds: float = 5 class SseServerParams(BaseModel): """Parameters for connecting to an MCP server over SSE.""" + type: Literal["SseServerParams"] = "SseServerParams" + url: str headers: dict[str, Any] | None = None timeout: float = 5 sse_read_timeout: float = 60 * 5 -McpServerParams: TypeAlias = StdioServerParams | SseServerParams +McpServerParams = Annotated[StdioServerParams | SseServerParams, Field(discriminator="type")] diff --git a/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_workbench.py b/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_workbench.py new file mode 100644 index 000000000..3272a86ee --- /dev/null +++ b/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_workbench.py @@ -0,0 +1,200
@@ +import builtins +import warnings +from typing import Any, List, Literal, Mapping + +from autogen_core import CancellationToken, Component, Image +from autogen_core.tools import ( + ImageResultContent, + ParametersSchema, + TextResultContent, + ToolResult, + ToolSchema, + Workbench, +) +from mcp.types import CallToolResult, EmbeddedResource, ImageContent, ListToolsResult, TextContent +from pydantic import BaseModel +from typing_extensions import Self + +from ._actor import McpSessionActor +from ._config import McpServerParams, SseServerParams, StdioServerParams + + +class McpWorkbenchConfig(BaseModel): + server_params: McpServerParams + + +class McpWorkbenchState(BaseModel): + type: Literal["McpWorkBenchState"] = "McpWorkBenchState" + + +class McpWorkbench(Workbench, Component[McpWorkbenchConfig]): + """ + A workbench that wraps an MCP server and provides an interface + to list and call tools provided by the server. + + Args: + server_params (McpServerParams): The parameters to connect to the MCP server. + This can be either a :class:`StdioServerParams` or :class:`SseServerParams`. + + Example: + + .. code-block:: python + + import asyncio + + from autogen_ext.tools.mcp import McpWorkbench, StdioServerParams + + + async def main() -> None: + params = StdioServerParams( + command="uvx", + args=["mcp-server-fetch"], + read_timeout_seconds=60, + ) + + # You can also use `start()` and `stop()` to manage the session. 
+ async with McpWorkbench(server_params=params) as workbench: + tools = await workbench.list_tools() + print(tools) + result = await workbench.call_tool(tools[0]["name"], {"url": "https://github.com/"}) + print(result) + + + asyncio.run(main()) + + """ + + component_provider_override = "autogen_ext.tools.mcp.McpWorkbench" + component_config_schema = McpWorkbenchConfig + + def __init__(self, server_params: McpServerParams) -> None: + self._server_params = server_params + # self._session: ClientSession | None = None + self._actor: McpSessionActor | None = None + self._read = None + self._write = None + + @property + def server_params(self) -> McpServerParams: + return self._server_params + + async def list_tools(self) -> List[ToolSchema]: + if not self._actor: + await self.start() # fallback to start the actor if not initialized instead of raising an error + # Why? Because when deserializing the workbench, the actor might not be initialized yet. + # raise RuntimeError("Actor is not initialized. Call start() first.") + if self._actor is None: + raise RuntimeError("Actor is not initialized. 
Please check the server connection.") + result_future = await self._actor.call("list_tools", None) + list_tool_result = await result_future + assert isinstance( + list_tool_result, ListToolsResult + ), f"list_tools must return a CallToolResult, instead of : {str(type(list_tool_result))}" + schema: List[ToolSchema] = [] + for tool in list_tool_result.tools: + name = tool.name + description = tool.description or "" + parameters = ParametersSchema( + type="object", + properties=tool.inputSchema["properties"], + required=tool.inputSchema.get("required", []), + additionalProperties=tool.inputSchema.get("additionalProperties", False), + ) + tool_schema = ToolSchema( + name=name, + description=description, + parameters=parameters, + ) + schema.append(tool_schema) + return schema + + async def call_tool( + self, name: str, arguments: Mapping[str, Any] | None = None, cancellation_token: CancellationToken | None = None + ) -> ToolResult: + if not self._actor: + await self.start() # fallback to start the actor if not initialized instead of raising an error + # Why? Because when deserializing the workbench, the actor might not be initialized yet. + # raise RuntimeError("Actor is not initialized. Call start() first.") + if self._actor is None: + raise RuntimeError("Actor is not initialized. 
Please check the server connection.") + if not cancellation_token: + cancellation_token = CancellationToken() + if not arguments: + arguments = {} + try: + result_future = await self._actor.call("call_tool", {"name": name, "kargs": arguments}) + cancellation_token.link_future(result_future) + result = await result_future + assert isinstance( + result, CallToolResult + ), f"call_tool must return a CallToolResult, instead of : {str(type(result))}" + result_parts: List[TextResultContent | ImageResultContent] = [] + is_error = result.isError + for content in result.content: + if isinstance(content, TextContent): + result_parts.append(TextResultContent(content=content.text)) + elif isinstance(content, ImageContent): + result_parts.append(ImageResultContent(content=Image.from_base64(content.data))) + elif isinstance(content, EmbeddedResource): + # TODO: how to handle embedded resources? + # For now we just use text representation. + result_parts.append(TextResultContent(content=content.model_dump_json())) + else: + raise ValueError(f"Unknown content type from server: {type(content)}") + except Exception as e: + error_message = self._format_errors(e) + is_error = True + result_parts = [TextResultContent(content=error_message)] + return ToolResult(name=name, result=result_parts, is_error=is_error) + + def _format_errors(self, error: Exception) -> str: + """Recursively format errors into a string.""" + + error_message = "" + if hasattr(builtins, "ExceptionGroup") and isinstance(error, builtins.ExceptionGroup): + # ExceptionGroup is available in Python 3.11+. + # TODO: how to make this compatible with Python 3.10? + for sub_exception in error.exceptions: # type: ignore + error_message += self._format_errors(sub_exception) # type: ignore + else: + error_message += f"{str(error)}\n" + return error_message + + async def start(self) -> None: + if self._actor: + warnings.warn( + "McpWorkbench is already started. 
No need to start again.", + UserWarning, + stacklevel=2, + ) + return # Already initialized, no need to start again + + if isinstance(self._server_params, (StdioServerParams, SseServerParams)): + self._actor = McpSessionActor(self._server_params) + await self._actor.initialize() + else: + raise ValueError(f"Unsupported server params type: {type(self._server_params)}") + + async def stop(self) -> None: + if self._actor: + # Close the actor + await self._actor.close() + self._actor = None + else: + raise RuntimeError("McpWorkbench is not started. Call start() first.") + + async def reset(self) -> None: + pass + + async def save_state(self) -> Mapping[str, Any]: + return McpWorkbenchState().model_dump() + + async def load_state(self, state: Mapping[str, Any]) -> None: + pass + + def _to_config(self) -> McpWorkbenchConfig: + return McpWorkbenchConfig(server_params=self._server_params) + + @classmethod + def _from_config(cls, config: McpWorkbenchConfig) -> Self: + return cls(server_params=config.server_params) diff --git a/python/packages/autogen-ext/tests/tools/test_mcp_tools.py b/python/packages/autogen-ext/tests/tools/test_mcp_tools.py index 998d678ab..10f17c2c3 100644 --- a/python/packages/autogen-ext/tests/tools/test_mcp_tools.py +++ b/python/packages/autogen-ext/tests/tools/test_mcp_tools.py @@ -4,8 +4,10 @@ from unittest.mock import AsyncMock, MagicMock import pytest from autogen_core import CancellationToken +from autogen_core.tools import Workbench from autogen_core.utils import schema_to_pydantic_model from autogen_ext.tools.mcp import ( + McpWorkbench, SseMcpToolAdapter, SseServerParams, StdioMcpToolAdapter, @@ -422,3 +424,80 @@ async def test_mcp_server_github() -> None: {"owner": "microsoft", "repo": "autogen", "path": "python", "branch": "main"}, CancellationToken() ) assert result is not None + + +@pytest.mark.asyncio +async def test_mcp_workbench_start_stop() -> None: + params = StdioServerParams( + command="uvx", + args=["mcp-server-fetch"], + 
read_timeout_seconds=60, + ) + + workbench = McpWorkbench(params) + assert workbench is not None + assert workbench.server_params == params + await workbench.start() + assert workbench._actor is not None # type: ignore[reportPrivateUsage] + await workbench.stop() + assert workbench._actor is None # type: ignore[reportPrivateUsage] + + +@pytest.mark.asyncio +async def test_mcp_workbench_server_fetch() -> None: + params = StdioServerParams( + command="uvx", + args=["mcp-server-fetch"], + read_timeout_seconds=60, + ) + + workbench = McpWorkbench(server_params=params) + await workbench.start() + + tools = await workbench.list_tools() + assert tools is not None + assert tools[0]["name"] == "fetch" + + result = await workbench.call_tool(tools[0]["name"], {"url": "https://github.com/"}, CancellationToken()) + assert result is not None + + await workbench.stop() + + +@pytest.mark.asyncio +async def test_mcp_workbench_server_filesystem() -> None: + params = StdioServerParams( + command="npx", + args=[ + "-y", + "@modelcontextprotocol/server-filesystem", + ".", + ], + read_timeout_seconds=60, + ) + + workbench = McpWorkbench(server_params=params) + await workbench.start() + + tools = await workbench.list_tools() + assert tools is not None + tools = [tool for tool in tools if tool["name"] == "read_file"] + assert len(tools) == 1 + tool = tools[0] + result = await workbench.call_tool(tool["name"], {"path": "README.md"}, CancellationToken()) + assert result is not None + + await workbench.stop() + + # Serialize the workbench. + config = workbench.dump_component() + + # Deserialize the workbench. + async with Workbench.load_component(config) as new_workbench: + tools = await new_workbench.list_tools() + assert tools is not None + tools = [tool for tool in tools if tool["name"] == "read_file"] + assert len(tools) == 1 + tool = tools[0] + result = await new_workbench.call_tool(tool["name"], {"path": "README.md"}, CancellationToken()) + assert result is not None