Introduce workbench (#6340)

This PR introduces `Workbench`.

A workbench provides a group of tools that share the same resource and
state. For example, `McpWorkbench` provides the underlying tools on the
MCP server. A workbench allows tools to be managed together and abstracts
away the lifecycle of individual tools under a single entity. This makes
it possible to create agents with stateful tools from serializable
configuration (component configs), and it also supports dynamic tools:
tools that may change after each execution.

Here is how a workbench may be used with AssistantAgent (not included in
this PR):

```python
workbench = McpWorkbench(server_params)
agent = AssistantAgent("assistant", tools=workbench)
result = await agent.run(task="do task...")
```


TODOs:
1. In a subsequent PR, update `AssistantAgent` to use workbench as an
alternative in the `tools` parameter. Use `StaticWorkbench` to manage
individual tools.
2. In another PR, add documentation on workbench.

---------

Co-authored-by: EeS <chiyoung.song@motov.co.kr>
Co-authored-by: Minh Đăng <74671798+perfogic@users.noreply.github.com>
This commit is contained in:
Eric Zhu 2025-04-24 10:37:41 -07:00 committed by GitHub
parent a283d268df
commit 8fcba01704
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 820 additions and 7 deletions

View File

@ -38,7 +38,6 @@
"outputs": [],
"source": [
"import os\n",
"from pydantic import BaseModel\n",
"from typing import List, Optional\n",
"\n",
"from autogen_core import AgentId, MessageContext, RoutedAgent, SingleThreadedAgentRuntime, message_handler\n",
@ -57,7 +56,8 @@
"from llama_index.embeddings.openai import OpenAIEmbedding\n",
"from llama_index.llms.azure_openai import AzureOpenAI\n",
"from llama_index.llms.openai import OpenAI\n",
"from llama_index.tools.wikipedia import WikipediaToolSpec"
"from llama_index.tools.wikipedia import WikipediaToolSpec\n",
"from pydantic import BaseModel"
]
},
{

View File

@ -7,7 +7,7 @@ from typing import Any, ClassVar, Dict, Generic, Literal, Type, TypeGuard, cast,
from pydantic import BaseModel
from typing_extensions import Self, TypeVar
ComponentType = Literal["model", "agent", "tool", "termination", "token_provider"] | str
ComponentType = Literal["model", "agent", "tool", "termination", "token_provider", "workbench"] | str
ConfigT = TypeVar("ConfigT", bound=BaseModel)
FromConfigT = TypeVar("FromConfigT", bound=BaseModel, contravariant=True)
ToConfigT = TypeVar("ToConfigT", bound=BaseModel, covariant=True)

View File

@ -1,5 +1,7 @@
from ._base import BaseTool, BaseToolWithState, ParametersSchema, Tool, ToolSchema
from ._function_tool import FunctionTool
from ._static_workbench import StaticWorkbench
from ._workbench import ImageResultContent, TextResultContent, ToolResult, Workbench
__all__ = [
"Tool",
@ -8,4 +10,9 @@ __all__ = [
"BaseTool",
"BaseToolWithState",
"FunctionTool",
"Workbench",
"ToolResult",
"TextResultContent",
"ImageResultContent",
"StaticWorkbench",
]

View File

@ -178,4 +178,4 @@ class FunctionTool(BaseTool[BaseModel, BaseModel], Component[FunctionToolConfig]
if not callable(func):
raise TypeError(f"Expected function but got {type(func)}")
return cls(func, "", None)
return cls(func, name=config.name, description=config.description, global_imports=config.global_imports)

View File

@ -0,0 +1,88 @@
import asyncio
from typing import Any, Dict, List, Literal, Mapping
from pydantic import BaseModel
from typing_extensions import Self
from .._cancellation_token import CancellationToken
from .._component_config import Component, ComponentModel
from ._base import BaseTool, ToolSchema
from ._workbench import TextResultContent, ToolResult, Workbench
# Declarative configuration of a StaticWorkbench (its `component_config_schema`).
class StaticWorkbenchConfig(BaseModel):
    # One serialized component model per tool; produced by
    # StaticWorkbench._to_config and consumed by StaticWorkbench._from_config.
    tools: List[ComponentModel] = []
# Persisted state of a StaticWorkbench.
# NOTE(review): the class name is misspelled ("Stateic"). It is left unchanged here
# because StaticWorkbench.save_state/load_state refer to it by this exact name.
class StateicWorkbenchState(BaseModel):
    # Constant tag identifying this kind of state payload.
    type: Literal["StaticWorkbenchState"] = "StaticWorkbenchState"
    # Saved state of each tool, keyed by the tool's name.
    tools: Dict[str, Mapping[str, Any]] = {}
class StaticWorkbench(Workbench, Component[StaticWorkbenchConfig]):
    """
    A workbench that provides a static set of tools that do not change after
    each tool execution.

    Args:
        tools (List[BaseTool[Any, Any]]): A list of tools to be included in the workbench.
            The tools should be subclasses of :class:`~autogen_core.tools.BaseTool`.
    """

    component_provider_override = "autogen_core.tools.StaticWorkbench"
    component_config_schema = StaticWorkbenchConfig

    def __init__(self, tools: List[BaseTool[Any, Any]]) -> None:
        self._tools = tools

    async def list_tools(self) -> List[ToolSchema]:
        """Return the schema of every tool, in the order the tools were supplied."""
        return [tool.schema for tool in self._tools]

    async def call_tool(
        self, name: str, arguments: Mapping[str, Any] | None = None, cancellation_token: CancellationToken | None = None
    ) -> ToolResult:
        """
        Run the tool called ``name`` and wrap its output in a :class:`ToolResult`.

        Args:
            name (str): Name of the tool to run.
            arguments (Mapping[str, Any] | None): Arguments for the tool; ``None`` or an
                empty mapping means "no arguments".
            cancellation_token (CancellationToken | None): Token used to cancel the run.
                A fresh token is created when none is given.

        Returns:
            ToolResult: A single text content item. On success it holds the tool's
            return value rendered by ``tool.return_value_as_string``; on failure it
            holds ``str(exception)`` and ``is_error`` is ``True``.

        Raises:
            ValueError: If no tool with the given name is in the workbench.
        """
        tool = next((tool for tool in self._tools if tool.name == name), None)
        if tool is None:
            raise ValueError(f"Tool {name} not found in workbench.")
        if not cancellation_token:
            cancellation_token = CancellationToken()
        if not arguments:
            arguments = {}
        try:
            # Run as a future so the cancellation token can cancel it.
            result_future = asyncio.ensure_future(tool.run_json(arguments, cancellation_token))
            cancellation_token.link_future(result_future)
            result = await result_future
            # Only a genuine return value goes through the tool's own formatter.
            result_str = tool.return_value_as_string(result)
            is_error = False
        except Exception as e:
            # The error text is reported verbatim: it is not a value of the tool's
            # return type, so it must not be handed to `return_value_as_string`.
            result_str = str(e)
            is_error = True
        return ToolResult(name=tool.name, result=[TextResultContent(content=result_str)], is_error=is_error)

    async def start(self) -> None:
        """No-op: this workbench has nothing to start."""
        return None

    async def stop(self) -> None:
        """No-op: this workbench has nothing to stop."""
        return None

    async def reset(self) -> None:
        """No-op: this workbench keeps no state of its own to reset."""
        return None

    async def save_state(self) -> Mapping[str, Any]:
        """Collect ``save_state_json()`` of every tool, keyed by tool name."""
        tool_states = StateicWorkbenchState()
        for tool in self._tools:
            tool_states.tools[tool.name] = await tool.save_state_json()
        return tool_states.model_dump()

    async def load_state(self, state: Mapping[str, Any]) -> None:
        """
        Restore tool states previously produced by :meth:`save_state`.

        Tools without an entry in ``state`` are left untouched.
        """
        parsed_state = StateicWorkbenchState.model_validate(state)
        for tool in self._tools:
            if tool.name in parsed_state.tools:
                await tool.load_state_json(parsed_state.tools[tool.name])

    def _to_config(self) -> StaticWorkbenchConfig:
        """Serialize the workbench by dumping each tool as a component model."""
        return StaticWorkbenchConfig(tools=[tool.dump_component() for tool in self._tools])

    @classmethod
    def _from_config(cls, config: StaticWorkbenchConfig) -> Self:
        """Rebuild the workbench by loading each tool from its component model."""
        return cls(tools=[BaseTool.load_component(tool) for tool in config.tools])

View File

@ -0,0 +1,164 @@
from abc import ABC, abstractmethod
from types import TracebackType
from typing import Any, List, Literal, Mapping, Optional, Type
from pydantic import BaseModel, Field
from typing_extensions import Annotated, Self
from .._cancellation_token import CancellationToken
from .._component_config import ComponentBase
from .._image import Image
from ._base import ToolSchema
class TextResultContent(BaseModel):
    """
    Text result content of a tool execution.
    """

    # Constant tag; the `ResultContent` union uses this field as its discriminator.
    type: Literal["TextResultContent"] = "TextResultContent"

    content: str
    """The text content of the result."""
class ImageResultContent(BaseModel):
    """
    Image result content of a tool execution.
    """

    # Constant tag; the `ResultContent` union uses this field as its discriminator.
    type: Literal["ImageResultContent"] = "ImageResultContent"

    content: Image
    """The image content of the result."""
# Tagged union of the supported result payloads; the concrete model is selected
# from the value of the `type` field.
ResultContent = Annotated[TextResultContent | ImageResultContent, Field(discriminator="type")]
class ToolResult(BaseModel):
    """
    A result of a tool execution by a workbench.
    """

    # Constant tag identifying this model in serialized form.
    type: Literal["ToolResult"] = "ToolResult"

    name: str
    """The name of the tool that was executed."""

    # Each item is either a TextResultContent or an ImageResultContent.
    result: List[ResultContent]
    """The result of the tool execution."""

    is_error: bool = False
    """Whether the tool execution resulted in an error."""
class Workbench(ABC, ComponentBase[BaseModel]):
    """
    A component that exposes a group of tools which may share resources and state.

    The workbench owns the lifecycle of its tools and offers one interface for
    listing and calling them. The set of tools may be dynamic: which tools are
    available can change after any tool execution.

    Use :meth:`~autogen_core.tools.Workbench.start` and
    :meth:`~autogen_core.tools.Workbench.stop` to manage the lifecycle explicitly,
    or use the workbench in an ``async with`` block, which starts it on entry and
    stops it on exit.
    """

    component_type = "workbench"

    @abstractmethod
    async def list_tools(self) -> List[ToolSchema]:
        """
        Return the tools currently available, as :class:`ToolSchema` objects.

        The returned list may differ between calls, because tool availability
        can change after a tool is executed.
        """
        ...

    @abstractmethod
    async def call_tool(
        self, name: str, arguments: Mapping[str, Any] | None = None, cancellation_token: CancellationToken | None = None
    ) -> ToolResult:
        """
        Execute one tool of the workbench.

        Args:
            name (str): Name of the tool to execute.
            arguments (Mapping[str, Any] | None): Arguments for the tool.
                ``None`` means the tool is called without arguments.
            cancellation_token (CancellationToken | None): Optional token that
                can be used to cancel the execution.

        Returns:
            ToolResult: The outcome of the execution.
        """
        ...

    @abstractmethod
    async def start(self) -> None:
        """
        Start the workbench and set up whatever resources it needs.

        Call this before using the workbench.
        """
        ...

    @abstractmethod
    async def stop(self) -> None:
        """
        Stop the workbench and release its resources.

        Call this once the workbench is no longer needed.
        """
        ...

    @abstractmethod
    async def reset(self) -> None:
        """
        Bring the workbench back to its initialized, started state.
        """
        ...

    @abstractmethod
    async def save_state(self) -> Mapping[str, Any]:
        """
        Return the workbench state as a mapping so that it can be persisted.
        """
        ...

    @abstractmethod
    async def load_state(self, state: Mapping[str, Any]) -> None:
        """
        Restore the workbench from a previously saved state.

        Args:
            state (Mapping[str, Any]): The state to load into the workbench.
        """
        ...

    async def __aenter__(self) -> Self:
        """
        Enter the asynchronous context manager.

        Runs when the workbench is used in an ``async with`` statement: it awaits
        :meth:`~autogen_core.tools.Workbench.start` and returns the workbench.
        """
        await self.start()
        return self

    async def __aexit__(
        self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType]
    ) -> None:
        """
        Exit the asynchronous context manager.

        Runs when the ``async with`` block ends: it awaits
        :meth:`~autogen_core.tools.Workbench.stop`. Exceptions raised inside the
        block are not suppressed.
        """
        await self.stop()

View File

@ -0,0 +1,119 @@
from typing import Annotated
import pytest
from autogen_core.code_executor import ImportFromModule
from autogen_core.tools import FunctionTool, StaticWorkbench, Workbench
@pytest.mark.asyncio
async def test_static_workbench() -> None:
    """Exercise StaticWorkbench directly and again after a config/state round trip."""

    # The tool functions are kept verbatim: FunctionTool serializes their source.
    def test_tool_func_1(x: Annotated[int, "The number to double."]) -> int:
        return x * 2

    def test_tool_func_2(x: Annotated[int, "The number to add 2."]) -> int:
        raise ValueError("This is a test error")  # Simulate an error

    test_tool_1 = FunctionTool(
        test_tool_func_1,
        name="test_tool_1",
        description="A test tool that doubles a number.",
        global_imports=[ImportFromModule(module="typing_extensions", imports=["Annotated"])],
    )
    test_tool_2 = FunctionTool(
        test_tool_func_2,
        name="test_tool_2",
        description="A test tool that adds 2 to a number.",
        global_imports=[ImportFromModule(module="typing_extensions", imports=["Annotated"])],
    )

    # (tool name, tool description, description of parameter `x`) in listing order.
    expected_tools = [
        ("test_tool_1", "A test tool that doubles a number.", "The number to double."),
        ("test_tool_2", "A test tool that adds 2 to a number.", "The number to add 2."),
    ]

    async def check_workbench(bench: Workbench) -> None:
        """Verify the tool listing and both tool calls on ``bench``."""
        listed = await bench.list_tools()
        assert len(listed) == 2
        for schema, (tool_name, tool_description, x_description) in zip(listed, expected_tools):
            assert "description" in schema
            assert "parameters" in schema
            assert schema["name"] == tool_name
            assert schema["description"] == tool_description
            assert schema["parameters"] == {
                "type": "object",
                "properties": {"x": {"type": "integer", "title": "X", "description": x_description}},
                "required": ["x"],
                "additionalProperties": False,
            }

        # Successful call.
        doubled = await bench.call_tool("test_tool_1", {"x": 5})
        assert doubled.name == "test_tool_1"
        assert doubled.result[0].type == "TextResultContent"
        assert doubled.result[0].content == "10"
        assert doubled.is_error is False

        # Call whose tool raises.
        failed = await bench.call_tool("test_tool_2", {"x": 5})
        assert failed.name == "test_tool_2"
        assert failed.result[0].type == "TextResultContent"
        assert failed.result[0].content == "This is a test error"
        assert failed.is_error is True

    # Create a StaticWorkbench instance with the test tools.
    async with StaticWorkbench(tools=[test_tool_1, test_tool_2]) as workbench:
        await check_workbench(workbench)

        # Save state.
        state = await workbench.save_state()
        assert state["type"] == "StaticWorkbenchState"
        assert "tools" in state
        assert len(state["tools"]) == 2

        # Dump config.
        config = workbench.dump_component()

        # Load the workbench from the config, restore the state, and verify that
        # the tools behave exactly as before.
        async with Workbench.load_component(config) as new_workbench:
            await new_workbench.load_state(state)
            await check_workbench(new_workbench)

View File

@ -1,15 +1,19 @@
from ._actor import McpSessionActor
from ._config import McpServerParams, SseServerParams, StdioServerParams
from ._factory import mcp_server_tools
from ._session import create_mcp_server_session
from ._sse import SseMcpToolAdapter
from ._stdio import StdioMcpToolAdapter
from ._workbench import McpWorkbench
__all__ = [
"create_mcp_server_session",
"McpSessionActor",
"StdioMcpToolAdapter",
"StdioServerParams",
"SseMcpToolAdapter",
"SseServerParams",
"McpServerParams",
"mcp_server_tools",
"McpWorkbench",
]

View File

@ -0,0 +1,147 @@
import asyncio
import atexit
from typing import Any, Coroutine, Dict, Mapping, TypedDict
from autogen_core import Component, ComponentBase
from mcp.types import CallToolResult, ListToolsResult
from pydantic import BaseModel
from typing_extensions import Self
from ._config import McpServerParams
from ._session import create_mcp_server_session
# A not-yet-awaited session request: awaiting it yields the list-tools or call-tool result.
McpResult = Coroutine[Any, Any, ListToolsResult] | Coroutine[Any, Any, CallToolResult]
# Future through which the actor hands such a request object back to the caller.
McpFuture = asyncio.Future[McpResult]
# Arguments accepted by McpSessionActor.call for the "call_tool" command.
class McpActorArgs(TypedDict):
    # Name of the tool to invoke; "call_tool" rejects None.
    name: str | None
    # Tool arguments; forwarded to the session as `arguments`.
    kargs: Mapping[str, Any]
# Declarative configuration of a McpSessionActor (its `component_config_schema`).
class McpSessionActorConfig(BaseModel):
    # Connection parameters of the MCP server the actor opens a session to.
    server_params: McpServerParams
# Owns one MCP server session inside a dedicated asyncio task. Callers enqueue
# commands ("list_tools", "call_tool", "shutdown"); the task opens the session,
# serves the commands in order and closes the session when it exits.
# (Documented with comments rather than a class docstring so that nothing derived
# from `__doc__` changes.)
class McpSessionActor(ComponentBase[BaseModel], Component[McpSessionActorConfig]):
    component_type = "mcp_session_actor"
    component_config_schema = McpSessionActorConfig
    component_provider_override = "autogen_ext.tools.mcp.McpSessionActor"

    server_params: McpServerParams

    # model_config = ConfigDict(arbitrary_types_allowed=True)

    def __init__(self, server_params: McpServerParams) -> None:
        self.server_params: McpServerParams = server_params
        self.name = "mcp_session_actor"
        self.description = "MCP session actor"
        # Commands are dicts with a "type", a "future" to resolve and, for
        # "call_tool", a "name" and "args".
        self._command_queue: asyncio.Queue[Dict[str, Any]] = asyncio.Queue()
        self._actor_task: asyncio.Task[Any] | None = None
        self._shutdown_future: asyncio.Future[Any] | None = None
        self._active = False
        # Best-effort shutdown of a still-running actor at interpreter exit.
        atexit.register(self._sync_shutdown)

    async def initialize(self) -> None:
        """Start the actor task; does nothing if the actor is already active."""
        if not self._active:
            self._active = True
            self._actor_task = asyncio.create_task(self._run_actor())

    async def call(self, type: str, args: McpActorArgs | None = None) -> McpFuture:
        """
        Send a command to the actor and wait until the actor has picked it up.

        Args:
            type (str): ``"list_tools"``, ``"call_tool"`` or ``"shutdown"``.
            args (McpActorArgs | None): Required for ``"call_tool"``; supplies the
                tool ``name`` and its arguments under ``kargs``.

        Returns:
            What the actor resolved the command with. For ``"list_tools"`` and
            ``"call_tool"`` this is the not-yet-awaited session request created in
            ``_run_actor``; the caller awaits it to obtain the actual result.

        Raises:
            RuntimeError: If the actor is not running, its task has finished, or
                the actor stopped before processing the command.
            ValueError: If ``type`` is unknown, or ``args``/``name`` is missing
                for ``"call_tool"``.
        """
        if not self._active:
            raise RuntimeError("MCP Actor not running, call initialize() first")
        if self._actor_task and self._actor_task.done():
            raise RuntimeError("MCP actor task crashed", self._actor_task.exception())
        fut: asyncio.Future[McpFuture] = asyncio.Future()
        if type in {"list_tools", "shutdown"}:
            await self._command_queue.put({"type": type, "future": fut})
            res = await fut
        elif type == "call_tool":
            if args is None:
                raise ValueError("args is required for call_tool")
            name = args.get("name", None)
            kwargs = args.get("kargs", {})
            if name is None:
                raise ValueError("name is required for call_tool")
            await self._command_queue.put({"type": type, "name": name, "args": kwargs, "future": fut})
            res = await fut
        else:
            raise ValueError(f"Unknown command type: {type}")
        return res

    async def close(self) -> None:
        """
        Ask the actor to shut down and wait until its task has finished.

        Does nothing if the actor is not running. If the actor failed, the
        failure is raised from here.
        """
        # Snapshot the task first: `_run_actor` clears `self._actor_task` in its
        # `finally` block, so re-reading the attribute after the awaits below
        # could yield None and turn the final wait into `await None`.
        actor_task = self._actor_task
        if not self._active or actor_task is None:
            return
        self._shutdown_future = asyncio.Future()
        await self._command_queue.put({"type": "shutdown", "future": self._shutdown_future})
        await self._shutdown_future
        await actor_task
        self._active = False

    async def _run_actor(self) -> None:
        """Open the session and serve queued commands until "shutdown" or a failure."""
        result: McpResult
        failure: Exception | None = None
        try:
            async with create_mcp_server_session(self.server_params) as session:
                await session.initialize()
                while True:
                    cmd = await self._command_queue.get()
                    if cmd["type"] == "shutdown":
                        cmd["future"].set_result("ok")
                        break
                    elif cmd["type"] == "call_tool":
                        try:
                            # Not awaited here: the request object is handed to the caller.
                            result = session.call_tool(name=cmd["name"], arguments=cmd["args"])
                            cmd["future"].set_result(result)
                        except Exception as e:
                            cmd["future"].set_exception(e)
                    elif cmd["type"] == "list_tools":
                        try:
                            # Not awaited here: the request object is handed to the caller.
                            result = session.list_tools()
                            cmd["future"].set_result(result)
                        except Exception as e:
                            cmd["future"].set_exception(e)
        except Exception as e:
            failure = e
            if self._shutdown_future and not self._shutdown_future.done():
                self._shutdown_future.set_exception(e)
        finally:
            self._active = False
            self._actor_task = None
            # Nobody will serve the queue any more; unblock whoever is still waiting.
            self._fail_pending_commands(failure)

    def _fail_pending_commands(self, cause: Exception | None) -> None:
        """Fail every command still in the queue so that its caller does not wait forever."""
        while not self._command_queue.empty():
            pending = self._command_queue.get_nowait()["future"]
            if pending.done():
                continue
            error = RuntimeError("MCP actor stopped before the command was processed")
            error.__cause__ = cause
            pending.set_exception(error)

    def _sync_shutdown(self) -> None:
        """Synchronous shutdown hook registered with ``atexit``; best effort only."""
        if not self._active or self._actor_task is None:
            return
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            # No loop available — interpreter is likely shutting down
            return
        if loop.is_closed():
            return
        if loop.is_running():
            loop.create_task(self.close())
        else:
            loop.run_until_complete(self.close())

    def _to_config(self) -> McpSessionActorConfig:
        """
        Convert the actor to its configuration representation.

        Returns:
            McpSessionActorConfig: The configuration of the actor.
        """
        return McpSessionActorConfig(server_params=self.server_params)

    @classmethod
    def _from_config(cls, config: McpSessionActorConfig) -> Self:
        """
        Create an instance of McpSessionActor from its configuration.

        Args:
            config (McpSessionActorConfig): The configuration of the actor.

        Returns:
            McpSessionActor: A new, not yet initialized actor.
        """
        return cls(server_params=config.server_params)

View File

@ -1,22 +1,27 @@
from typing import Any, TypeAlias
from typing import Any, Literal
from mcp import StdioServerParameters
from pydantic import BaseModel
from pydantic import BaseModel, Field
from typing_extensions import Annotated
class StdioServerParams(StdioServerParameters):
    """Parameters for connecting to an MCP server over STDIO."""

    # Constant tag; `McpServerParams` uses this field as its discriminator.
    type: Literal["StdioServerParams"] = "StdioServerParams"

    # Read timeout, in seconds (per the field name).
    read_timeout_seconds: float = 5
class SseServerParams(BaseModel):
    """Parameters for connecting to an MCP server over SSE."""

    # Constant tag; `McpServerParams` uses this field as its discriminator.
    type: Literal["SseServerParams"] = "SseServerParams"

    url: str
    headers: dict[str, Any] | None = None
    # NOTE(review): presumably both timeouts are in seconds — confirm against the SSE client.
    timeout: float = 5
    sse_read_timeout: float = 60 * 5
McpServerParams: TypeAlias = StdioServerParams | SseServerParams
McpServerParams = Annotated[StdioServerParams | SseServerParams, Field(discriminator="type")]

View File

@ -0,0 +1,200 @@
import builtins
import warnings
from typing import Any, List, Literal, Mapping
from autogen_core import CancellationToken, Component, Image
from autogen_core.tools import (
ImageResultContent,
ParametersSchema,
TextResultContent,
ToolResult,
ToolSchema,
Workbench,
)
from mcp.types import CallToolResult, EmbeddedResource, ImageContent, ListToolsResult, TextContent
from pydantic import BaseModel
from typing_extensions import Self
from ._actor import McpSessionActor
from ._config import McpServerParams, SseServerParams, StdioServerParams
# Declarative configuration of a McpWorkbench (its `component_config_schema`).
class McpWorkbenchConfig(BaseModel):
    # Connection parameters of the MCP server wrapped by the workbench.
    server_params: McpServerParams
# State payload returned by McpWorkbench.save_state; it carries only the tag below.
class McpWorkbenchState(BaseModel):
    # NOTE(review): the tag is spelled "WorkBench", unlike the class name. It is a
    # serialized value, so changing it would alter saved states.
    type: Literal["McpWorkBenchState"] = "McpWorkBenchState"
class McpWorkbench(Workbench, Component[McpWorkbenchConfig]):
    """
    A workbench that wraps an MCP server and provides an interface
    to list and call tools provided by the server.

    Args:
        server_params (McpServerParams): The parameters to connect to the MCP server.
            This can be either a :class:`StdioServerParams` or :class:`SseServerParams`.

    Example:

        .. code-block:: python

            import asyncio

            from autogen_ext.tools.mcp import McpWorkbench, StdioServerParams


            async def main() -> None:
                params = StdioServerParams(
                    command="uvx",
                    args=["mcp-server-fetch"],
                    read_timeout_seconds=60,
                )

                # You can also use `start()` and `stop()` to manage the session.
                async with McpWorkbench(server_params=params) as workbench:
                    tools = await workbench.list_tools()
                    print(tools)
                    result = await workbench.call_tool(tools[0]["name"], {"url": "https://github.com/"})
                    print(result)


            asyncio.run(main())
    """

    component_provider_override = "autogen_ext.tools.mcp.McpWorkbench"
    component_config_schema = McpWorkbenchConfig

    def __init__(self, server_params: McpServerParams) -> None:
        self._server_params = server_params
        # self._session: ClientSession | None = None
        self._actor: McpSessionActor | None = None
        self._read = None
        self._write = None

    @property
    def server_params(self) -> McpServerParams:
        """The parameters used to connect to the MCP server."""
        return self._server_params

    async def list_tools(self) -> List[ToolSchema]:
        """
        Ask the server for its tools and convert them to :class:`ToolSchema` objects.

        The workbench is started on demand if it has not been started yet.

        Raises:
            RuntimeError: If the actor is still missing after the start attempt.
        """
        if not self._actor:
            await self.start()  # fallback to start the actor if not initialized instead of raising an error
            # Why? Because when deserializing the workbench, the actor might not be initialized yet.
            # raise RuntimeError("Actor is not initialized. Call start() first.")
        if self._actor is None:
            raise RuntimeError("Actor is not initialized. Please check the server connection.")
        result_future = await self._actor.call("list_tools", None)
        list_tool_result = await result_future
        assert isinstance(
            list_tool_result, ListToolsResult
        ), f"list_tools must return a ListToolsResult, instead of : {str(type(list_tool_result))}"
        schema: List[ToolSchema] = []
        for tool in list_tool_result.tools:
            name = tool.name
            description = tool.description or ""
            parameters = ParametersSchema(
                type="object",
                # An input schema may omit "properties" (e.g. a tool without
                # parameters); treat that as "no properties" instead of failing.
                properties=tool.inputSchema.get("properties", {}),
                required=tool.inputSchema.get("required", []),
                additionalProperties=tool.inputSchema.get("additionalProperties", False),
            )
            tool_schema = ToolSchema(
                name=name,
                description=description,
                parameters=parameters,
            )
            schema.append(tool_schema)
        return schema

    async def call_tool(
        self, name: str, arguments: Mapping[str, Any] | None = None, cancellation_token: CancellationToken | None = None
    ) -> ToolResult:
        """
        Call a tool on the server and convert its answer to a :class:`ToolResult`.

        The workbench is started on demand if it has not been started yet.
        Text and image contents are converted to the matching result content;
        embedded resources are returned as their JSON text. Any exception raised
        while calling the tool or converting its answer is reported as an error
        result instead of being raised.

        Raises:
            RuntimeError: If the actor is still missing after the start attempt.
        """
        # Function-scope import: this module does not import asyncio at the top.
        import asyncio

        if not self._actor:
            await self.start()  # fallback to start the actor if not initialized instead of raising an error
            # Why? Because when deserializing the workbench, the actor might not be initialized yet.
            # raise RuntimeError("Actor is not initialized. Call start() first.")
        if self._actor is None:
            raise RuntimeError("Actor is not initialized. Please check the server connection.")
        if not cancellation_token:
            cancellation_token = CancellationToken()
        if not arguments:
            arguments = {}
        try:
            pending_call = await self._actor.call("call_tool", {"name": name, "kargs": arguments})
            # The actor hands back the un-awaited session request. Wrap it so the
            # cancellation token gets a real future it can cancel; `ensure_future`
            # returns the object unchanged if it already is a future.
            result_future = asyncio.ensure_future(pending_call)
            cancellation_token.link_future(result_future)
            result = await result_future
            assert isinstance(
                result, CallToolResult
            ), f"call_tool must return a CallToolResult, instead of : {str(type(result))}"
            result_parts: List[TextResultContent | ImageResultContent] = []
            is_error = result.isError
            for content in result.content:
                if isinstance(content, TextContent):
                    result_parts.append(TextResultContent(content=content.text))
                elif isinstance(content, ImageContent):
                    result_parts.append(ImageResultContent(content=Image.from_base64(content.data)))
                elif isinstance(content, EmbeddedResource):
                    # TODO: how to handle embedded resources?
                    # For now we just use text representation.
                    result_parts.append(TextResultContent(content=content.model_dump_json()))
                else:
                    raise ValueError(f"Unknown content type from server: {type(content)}")
        except Exception as e:
            error_message = self._format_errors(e)
            is_error = True
            result_parts = [TextResultContent(content=error_message)]
        return ToolResult(name=name, result=result_parts, is_error=is_error)

    def _format_errors(self, error: Exception) -> str:
        """Recursively format errors into a string."""
        error_message = ""
        if hasattr(builtins, "ExceptionGroup") and isinstance(error, builtins.ExceptionGroup):
            # ExceptionGroup is available in Python 3.11+.
            # TODO: how to make this compatible with Python 3.10?
            for sub_exception in error.exceptions:  # type: ignore
                error_message += self._format_errors(sub_exception)  # type: ignore
        else:
            error_message += f"{str(error)}\n"
        return error_message

    async def start(self) -> None:
        """
        Create and initialize the session actor.

        Emits a ``UserWarning`` and returns if the workbench is already started.

        Raises:
            ValueError: If the server parameters are of an unsupported type.
        """
        if self._actor:
            warnings.warn(
                "McpWorkbench is already started. No need to start again.",
                UserWarning,
                stacklevel=2,
            )
            return  # Already initialized, no need to start again
        if isinstance(self._server_params, (StdioServerParams, SseServerParams)):
            self._actor = McpSessionActor(self._server_params)
            await self._actor.initialize()
        else:
            raise ValueError(f"Unsupported server params type: {type(self._server_params)}")

    async def stop(self) -> None:
        """
        Close the session actor.

        Raises:
            RuntimeError: If the workbench has not been started.
        """
        if self._actor:
            # Close the actor
            await self._actor.close()
            self._actor = None
        else:
            raise RuntimeError("McpWorkbench is not started. Call start() first.")

    async def reset(self) -> None:
        """No-op."""
        pass

    async def save_state(self) -> Mapping[str, Any]:
        """Return a state mapping that carries only the state type tag."""
        return McpWorkbenchState().model_dump()

    async def load_state(self, state: Mapping[str, Any]) -> None:
        """No-op: the given state is ignored."""
        pass

    def _to_config(self) -> McpWorkbenchConfig:
        """Serialize the workbench as its server parameters."""
        return McpWorkbenchConfig(server_params=self._server_params)

    @classmethod
    def _from_config(cls, config: McpWorkbenchConfig) -> Self:
        """Create a workbench (not yet started) from its configuration."""
        return cls(server_params=config.server_params)

View File

@ -4,8 +4,10 @@ from unittest.mock import AsyncMock, MagicMock
import pytest
from autogen_core import CancellationToken
from autogen_core.tools import Workbench
from autogen_core.utils import schema_to_pydantic_model
from autogen_ext.tools.mcp import (
McpWorkbench,
SseMcpToolAdapter,
SseServerParams,
StdioMcpToolAdapter,
@ -422,3 +424,80 @@ async def test_mcp_server_github() -> None:
{"owner": "microsoft", "repo": "autogen", "path": "python", "branch": "main"}, CancellationToken()
)
assert result is not None
@pytest.mark.asyncio
async def test_mcp_workbench_start_stop() -> None:
    """start() creates the session actor and stop() discards it."""
    server_params = StdioServerParams(command="uvx", args=["mcp-server-fetch"], read_timeout_seconds=60)

    bench = McpWorkbench(server_params)
    assert bench is not None
    assert bench.server_params == server_params

    await bench.start()
    assert bench._actor is not None  # type: ignore[reportPrivateUsage]

    await bench.stop()
    assert bench._actor is None  # type: ignore[reportPrivateUsage]
@pytest.mark.asyncio
async def test_mcp_workbench_server_fetch() -> None:
    """List and call the tool of the `mcp-server-fetch` server through a workbench."""
    server_params = StdioServerParams(command="uvx", args=["mcp-server-fetch"], read_timeout_seconds=60)
    bench = McpWorkbench(server_params=server_params)
    await bench.start()

    listed = await bench.list_tools()
    assert listed is not None
    first_tool_name = listed[0]["name"]
    assert first_tool_name == "fetch"

    fetched = await bench.call_tool(first_tool_name, {"url": "https://github.com/"}, CancellationToken())
    assert fetched is not None

    await bench.stop()
@pytest.mark.asyncio
async def test_mcp_workbench_server_filesystem() -> None:
    """Read a file through the filesystem server, directly and after a config round trip."""
    server_params = StdioServerParams(
        command="npx",
        args=["-y", "@modelcontextprotocol/server-filesystem", "."],
        read_timeout_seconds=60,
    )

    bench = McpWorkbench(server_params=server_params)
    await bench.start()
    listed = await bench.list_tools()
    assert listed is not None
    read_file_tools = [schema for schema in listed if schema["name"] == "read_file"]
    assert len(read_file_tools) == 1
    read_file = read_file_tools[0]
    outcome = await bench.call_tool(read_file["name"], {"path": "README.md"}, CancellationToken())
    assert outcome is not None
    await bench.stop()

    # Serialize the workbench.
    config = bench.dump_component()

    # Deserialize the workbench and repeat the same checks on the copy.
    async with Workbench.load_component(config) as new_workbench:
        listed = await new_workbench.list_tools()
        assert listed is not None
        read_file_tools = [schema for schema in listed if schema["name"] == "read_file"]
        assert len(read_file_tools) == 1
        read_file = read_file_tools[0]
        outcome = await new_workbench.call_tool(read_file["name"], {"path": "README.md"}, CancellationToken())
        assert outcome is not None