mirror of
https://github.com/microsoft/autogen.git
synced 2025-12-12 15:31:21 +00:00
Introduce workbench (#6340)
This PR introduces `WorkBench`.
A workbench provides a group of tools that share the same resource and
state. For example, `McpWorkbench` provides the underlying tools on the
MCP server. A workbench allows tools to be managed together and abstract
away the lifecycle of individual tools under a single entity. This makes
it possible to create agents with stateful tools from serializable
configuration (component configs), and it also supports dynamic tools:
tools change after each execution.
Here is how a workbench may be used with AssistantAgent (not included in
this PR):
```python
workbench = McpWorkbench(server_params)
agent = AssistantAgent("assistant", tools=workbench)
result = await agent.run(task="do task...")
```
TODOs:
1. In a subsequent PR, update `AssistantAgent` to use workbench as an
alternative in the `tools` parameter. Use `StaticWorkbench` to manage
individual tools.
2. In another PR, add documentation on workbench.
---------
Co-authored-by: EeS <chiyoung.song@motov.co.kr>
Co-authored-by: Minh Đăng <74671798+perfogic@users.noreply.github.com>
This commit is contained in:
parent
a283d268df
commit
8fcba01704
@ -38,7 +38,6 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"from pydantic import BaseModel\n",
|
||||
"from typing import List, Optional\n",
|
||||
"\n",
|
||||
"from autogen_core import AgentId, MessageContext, RoutedAgent, SingleThreadedAgentRuntime, message_handler\n",
|
||||
@ -57,7 +56,8 @@
|
||||
"from llama_index.embeddings.openai import OpenAIEmbedding\n",
|
||||
"from llama_index.llms.azure_openai import AzureOpenAI\n",
|
||||
"from llama_index.llms.openai import OpenAI\n",
|
||||
"from llama_index.tools.wikipedia import WikipediaToolSpec"
|
||||
"from llama_index.tools.wikipedia import WikipediaToolSpec\n",
|
||||
"from pydantic import BaseModel"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@ -7,7 +7,7 @@ from typing import Any, ClassVar, Dict, Generic, Literal, Type, TypeGuard, cast,
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Self, TypeVar
|
||||
|
||||
ComponentType = Literal["model", "agent", "tool", "termination", "token_provider"] | str
|
||||
ComponentType = Literal["model", "agent", "tool", "termination", "token_provider", "workbench"] | str
|
||||
ConfigT = TypeVar("ConfigT", bound=BaseModel)
|
||||
FromConfigT = TypeVar("FromConfigT", bound=BaseModel, contravariant=True)
|
||||
ToConfigT = TypeVar("ToConfigT", bound=BaseModel, covariant=True)
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
from ._base import BaseTool, BaseToolWithState, ParametersSchema, Tool, ToolSchema
|
||||
from ._function_tool import FunctionTool
|
||||
from ._static_workbench import StaticWorkbench
|
||||
from ._workbench import ImageResultContent, TextResultContent, ToolResult, Workbench
|
||||
|
||||
__all__ = [
|
||||
"Tool",
|
||||
@ -8,4 +10,9 @@ __all__ = [
|
||||
"BaseTool",
|
||||
"BaseToolWithState",
|
||||
"FunctionTool",
|
||||
"Workbench",
|
||||
"ToolResult",
|
||||
"TextResultContent",
|
||||
"ImageResultContent",
|
||||
"StaticWorkbench",
|
||||
]
|
||||
|
||||
@ -178,4 +178,4 @@ class FunctionTool(BaseTool[BaseModel, BaseModel], Component[FunctionToolConfig]
|
||||
if not callable(func):
|
||||
raise TypeError(f"Expected function but got {type(func)}")
|
||||
|
||||
return cls(func, "", None)
|
||||
return cls(func, name=config.name, description=config.description, global_imports=config.global_imports)
|
||||
|
||||
@ -0,0 +1,88 @@
|
||||
import asyncio
|
||||
from typing import Any, Dict, List, Literal, Mapping
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Self
|
||||
|
||||
from .._cancellation_token import CancellationToken
|
||||
from .._component_config import Component, ComponentModel
|
||||
from ._base import BaseTool, ToolSchema
|
||||
from ._workbench import TextResultContent, ToolResult, Workbench
|
||||
|
||||
|
||||
class StaticWorkbenchConfig(BaseModel):
|
||||
tools: List[ComponentModel] = []
|
||||
|
||||
|
||||
class StateicWorkbenchState(BaseModel):
|
||||
type: Literal["StaticWorkbenchState"] = "StaticWorkbenchState"
|
||||
tools: Dict[str, Mapping[str, Any]] = {}
|
||||
|
||||
|
||||
class StaticWorkbench(Workbench, Component[StaticWorkbenchConfig]):
|
||||
"""
|
||||
A workbench that provides a static set of tools that do not change after
|
||||
each tool execution.
|
||||
|
||||
Args:
|
||||
tools (List[BaseTool[Any, Any]]): A list of tools to be included in the workbench.
|
||||
The tools should be subclasses of :class:`~autogen_core.tools.BaseTool`.
|
||||
"""
|
||||
|
||||
component_provider_override = "autogen_core.tools.StaticWorkbench"
|
||||
component_config_schema = StaticWorkbenchConfig
|
||||
|
||||
def __init__(self, tools: List[BaseTool[Any, Any]]) -> None:
|
||||
self._tools = tools
|
||||
|
||||
async def list_tools(self) -> List[ToolSchema]:
|
||||
return [tool.schema for tool in self._tools]
|
||||
|
||||
async def call_tool(
|
||||
self, name: str, arguments: Mapping[str, Any] | None = None, cancellation_token: CancellationToken | None = None
|
||||
) -> ToolResult:
|
||||
tool = next((tool for tool in self._tools if tool.name == name), None)
|
||||
if tool is None:
|
||||
raise ValueError(f"Tool {name} not found in workbench.")
|
||||
if not cancellation_token:
|
||||
cancellation_token = CancellationToken()
|
||||
if not arguments:
|
||||
arguments = {}
|
||||
try:
|
||||
result_future = asyncio.ensure_future(tool.run_json(arguments, cancellation_token))
|
||||
cancellation_token.link_future(result_future)
|
||||
result = await result_future
|
||||
is_error = False
|
||||
except Exception as e:
|
||||
result = str(e)
|
||||
is_error = True
|
||||
result_str = tool.return_value_as_string(result)
|
||||
return ToolResult(name=tool.name, result=[TextResultContent(content=result_str)], is_error=is_error)
|
||||
|
||||
async def start(self) -> None:
|
||||
return None
|
||||
|
||||
async def stop(self) -> None:
|
||||
return None
|
||||
|
||||
async def reset(self) -> None:
|
||||
return None
|
||||
|
||||
async def save_state(self) -> Mapping[str, Any]:
|
||||
tool_states = StateicWorkbenchState()
|
||||
for tool in self._tools:
|
||||
tool_states.tools[tool.name] = await tool.save_state_json()
|
||||
return tool_states.model_dump()
|
||||
|
||||
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
parsed_state = StateicWorkbenchState.model_validate(state)
|
||||
for tool in self._tools:
|
||||
if tool.name in parsed_state.tools:
|
||||
await tool.load_state_json(parsed_state.tools[tool.name])
|
||||
|
||||
def _to_config(self) -> StaticWorkbenchConfig:
|
||||
return StaticWorkbenchConfig(tools=[tool.dump_component() for tool in self._tools])
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: StaticWorkbenchConfig) -> Self:
|
||||
return cls(tools=[BaseTool.load_component(tool) for tool in config.tools])
|
||||
@ -0,0 +1,164 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from types import TracebackType
|
||||
from typing import Any, List, Literal, Mapping, Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Annotated, Self
|
||||
|
||||
from .._cancellation_token import CancellationToken
|
||||
from .._component_config import ComponentBase
|
||||
from .._image import Image
|
||||
from ._base import ToolSchema
|
||||
|
||||
|
||||
class TextResultContent(BaseModel):
|
||||
"""
|
||||
Text result content of a tool execution.
|
||||
"""
|
||||
|
||||
type: Literal["TextResultContent"] = "TextResultContent"
|
||||
|
||||
content: str
|
||||
"""The text content of the result."""
|
||||
|
||||
|
||||
class ImageResultContent(BaseModel):
|
||||
"""
|
||||
Image result content of a tool execution.
|
||||
"""
|
||||
|
||||
type: Literal["ImageResultContent"] = "ImageResultContent"
|
||||
|
||||
content: Image
|
||||
"""The image content of the result."""
|
||||
|
||||
|
||||
ResultContent = Annotated[TextResultContent | ImageResultContent, Field(discriminator="type")]
|
||||
|
||||
|
||||
class ToolResult(BaseModel):
|
||||
"""
|
||||
A result of a tool execution by a workbench.
|
||||
"""
|
||||
|
||||
type: Literal["ToolResult"] = "ToolResult"
|
||||
|
||||
name: str
|
||||
"""The name of the tool that was executed."""
|
||||
|
||||
result: List[ResultContent]
|
||||
"""The result of the tool execution."""
|
||||
|
||||
is_error: bool = False
|
||||
"""Whether the tool execution resulted in an error."""
|
||||
|
||||
|
||||
class Workbench(ABC, ComponentBase[BaseModel]):
|
||||
"""
|
||||
A workbench is a component that provides a set of tools that may share
|
||||
resources and state.
|
||||
|
||||
A workbench is responsible for managing the lifecycle of the tools and
|
||||
providing a single interface to call them. The tools provided by the workbench
|
||||
may be dynamic and their availabilities may change after each tool execution.
|
||||
|
||||
A workbench can be started by calling the :meth:`~autogen_core.tools.Workbench.start` method
|
||||
and stopped by calling the :meth:`~autogen_core.tools.Workbench.stop` method.
|
||||
It can also be used as an asynchronous context manager, which will automatically
|
||||
start and stop the workbench when entering and exiting the context.
|
||||
"""
|
||||
|
||||
component_type = "workbench"
|
||||
|
||||
@abstractmethod
|
||||
async def list_tools(self) -> List[ToolSchema]:
|
||||
"""
|
||||
List the currently available tools in the workbench as :class:`ToolSchema`
|
||||
objects.
|
||||
|
||||
The list of tools may be dynamic, and their content may change after
|
||||
tool execution.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def call_tool(
|
||||
self, name: str, arguments: Mapping[str, Any] | None = None, cancellation_token: CancellationToken | None = None
|
||||
) -> ToolResult:
|
||||
"""
|
||||
Call a tool in the workbench.
|
||||
|
||||
Args:
|
||||
name (str): The name of the tool to call.
|
||||
arguments (Mapping[str, Any] | None): The arguments to pass to the tool.
|
||||
If None, the tool will be called with no arguments.
|
||||
cancellation_token (CancellationToken | None): An optional cancellation token
|
||||
to cancel the tool execution.
|
||||
Returns:
|
||||
ToolResult: The result of the tool execution.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def start(self) -> None:
|
||||
"""
|
||||
Start the workbench and initialize any resources.
|
||||
|
||||
This method should be called before using the workbench.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def stop(self) -> None:
|
||||
"""
|
||||
Stop the workbench and release any resources.
|
||||
|
||||
This method should be called when the workbench is no longer needed.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def reset(self) -> None:
|
||||
"""
|
||||
Reset the workbench to its initialized, started state.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def save_state(self) -> Mapping[str, Any]:
|
||||
"""
|
||||
Save the state of the workbench.
|
||||
|
||||
This method should be called to persist the state of the workbench.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
"""
|
||||
Load the state of the workbench.
|
||||
|
||||
Args:
|
||||
state (Mapping[str, Any]): The state to load into the workbench.
|
||||
"""
|
||||
...
|
||||
|
||||
async def __aenter__(self) -> Self:
|
||||
"""
|
||||
Enter the workbench context manager.
|
||||
|
||||
This method is called when the workbench is used in a `with` statement.
|
||||
It calls the :meth:`~autogen_core.tools.WorkBench.start` method to start the workbench.
|
||||
"""
|
||||
await self.start()
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType]
|
||||
) -> None:
|
||||
"""
|
||||
Exit the workbench context manager.
|
||||
This method is called when the workbench is used in a `with` statement.
|
||||
It calls the :meth:`~autogen_core.tools.WorkBench.stop` method to stop the workbench.
|
||||
"""
|
||||
await self.stop()
|
||||
119
python/packages/autogen-core/tests/test_workbench.py
Normal file
119
python/packages/autogen-core/tests/test_workbench.py
Normal file
@ -0,0 +1,119 @@
|
||||
from typing import Annotated
|
||||
|
||||
import pytest
|
||||
from autogen_core.code_executor import ImportFromModule
|
||||
from autogen_core.tools import FunctionTool, StaticWorkbench, Workbench
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_static_workbench() -> None:
|
||||
def test_tool_func_1(x: Annotated[int, "The number to double."]) -> int:
|
||||
return x * 2
|
||||
|
||||
def test_tool_func_2(x: Annotated[int, "The number to add 2."]) -> int:
|
||||
raise ValueError("This is a test error") # Simulate an error
|
||||
|
||||
test_tool_1 = FunctionTool(
|
||||
test_tool_func_1,
|
||||
name="test_tool_1",
|
||||
description="A test tool that doubles a number.",
|
||||
global_imports=[ImportFromModule(module="typing_extensions", imports=["Annotated"])],
|
||||
)
|
||||
test_tool_2 = FunctionTool(
|
||||
test_tool_func_2,
|
||||
name="test_tool_2",
|
||||
description="A test tool that adds 2 to a number.",
|
||||
global_imports=[ImportFromModule(module="typing_extensions", imports=["Annotated"])],
|
||||
)
|
||||
|
||||
# Create a StaticWorkbench instance with the test tools.
|
||||
async with StaticWorkbench(tools=[test_tool_1, test_tool_2]) as workbench:
|
||||
# List tools
|
||||
tools = await workbench.list_tools()
|
||||
assert len(tools) == 2
|
||||
assert "description" in tools[0]
|
||||
assert "parameters" in tools[0]
|
||||
assert tools[0]["name"] == "test_tool_1"
|
||||
assert tools[0]["description"] == "A test tool that doubles a number."
|
||||
assert tools[0]["parameters"] == {
|
||||
"type": "object",
|
||||
"properties": {"x": {"type": "integer", "title": "X", "description": "The number to double."}},
|
||||
"required": ["x"],
|
||||
"additionalProperties": False,
|
||||
}
|
||||
assert "description" in tools[1]
|
||||
assert "parameters" in tools[1]
|
||||
assert tools[1]["name"] == "test_tool_2"
|
||||
assert tools[1]["description"] == "A test tool that adds 2 to a number."
|
||||
assert tools[1]["parameters"] == {
|
||||
"type": "object",
|
||||
"properties": {"x": {"type": "integer", "title": "X", "description": "The number to add 2."}},
|
||||
"required": ["x"],
|
||||
"additionalProperties": False,
|
||||
}
|
||||
|
||||
# Call tools
|
||||
result_1 = await workbench.call_tool("test_tool_1", {"x": 5})
|
||||
assert result_1.name == "test_tool_1"
|
||||
assert result_1.result[0].type == "TextResultContent"
|
||||
assert result_1.result[0].content == "10"
|
||||
assert result_1.is_error is False
|
||||
|
||||
# Call tool with error
|
||||
result_2 = await workbench.call_tool("test_tool_2", {"x": 5})
|
||||
assert result_2.name == "test_tool_2"
|
||||
assert result_2.result[0].type == "TextResultContent"
|
||||
assert result_2.result[0].content == "This is a test error"
|
||||
assert result_2.is_error is True
|
||||
|
||||
# Save state.
|
||||
state = await workbench.save_state()
|
||||
assert state["type"] == "StaticWorkbenchState"
|
||||
assert "tools" in state
|
||||
assert len(state["tools"]) == 2
|
||||
|
||||
# Dump config.
|
||||
config = workbench.dump_component()
|
||||
|
||||
# Load the workbench from the config.
|
||||
async with Workbench.load_component(config) as new_workbench:
|
||||
# Load state.
|
||||
await new_workbench.load_state(state)
|
||||
|
||||
# Verify that the tools are still available after loading the state.
|
||||
tools = await new_workbench.list_tools()
|
||||
assert len(tools) == 2
|
||||
assert "description" in tools[0]
|
||||
assert "parameters" in tools[0]
|
||||
assert tools[0]["name"] == "test_tool_1"
|
||||
assert tools[0]["description"] == "A test tool that doubles a number."
|
||||
assert tools[0]["parameters"] == {
|
||||
"type": "object",
|
||||
"properties": {"x": {"type": "integer", "title": "X", "description": "The number to double."}},
|
||||
"required": ["x"],
|
||||
"additionalProperties": False,
|
||||
}
|
||||
assert "description" in tools[1]
|
||||
assert "parameters" in tools[1]
|
||||
assert tools[1]["name"] == "test_tool_2"
|
||||
assert tools[1]["description"] == "A test tool that adds 2 to a number."
|
||||
assert tools[1]["parameters"] == {
|
||||
"type": "object",
|
||||
"properties": {"x": {"type": "integer", "title": "X", "description": "The number to add 2."}},
|
||||
"required": ["x"],
|
||||
"additionalProperties": False,
|
||||
}
|
||||
|
||||
# Call tools
|
||||
result_1 = await new_workbench.call_tool("test_tool_1", {"x": 5})
|
||||
assert result_1.name == "test_tool_1"
|
||||
assert result_1.result[0].type == "TextResultContent"
|
||||
assert result_1.result[0].content == "10"
|
||||
assert result_1.is_error is False
|
||||
|
||||
# Call tool with error
|
||||
result_2 = await new_workbench.call_tool("test_tool_2", {"x": 5})
|
||||
assert result_2.name == "test_tool_2"
|
||||
assert result_2.result[0].type == "TextResultContent"
|
||||
assert result_2.result[0].content == "This is a test error"
|
||||
assert result_2.is_error is True
|
||||
@ -1,15 +1,19 @@
|
||||
from ._actor import McpSessionActor
|
||||
from ._config import McpServerParams, SseServerParams, StdioServerParams
|
||||
from ._factory import mcp_server_tools
|
||||
from ._session import create_mcp_server_session
|
||||
from ._sse import SseMcpToolAdapter
|
||||
from ._stdio import StdioMcpToolAdapter
|
||||
from ._workbench import McpWorkbench
|
||||
|
||||
__all__ = [
|
||||
"create_mcp_server_session",
|
||||
"McpSessionActor",
|
||||
"StdioMcpToolAdapter",
|
||||
"StdioServerParams",
|
||||
"SseMcpToolAdapter",
|
||||
"SseServerParams",
|
||||
"McpServerParams",
|
||||
"mcp_server_tools",
|
||||
"McpWorkbench",
|
||||
]
|
||||
|
||||
147
python/packages/autogen-ext/src/autogen_ext/tools/mcp/_actor.py
Normal file
147
python/packages/autogen-ext/src/autogen_ext/tools/mcp/_actor.py
Normal file
@ -0,0 +1,147 @@
|
||||
import asyncio
|
||||
import atexit
|
||||
from typing import Any, Coroutine, Dict, Mapping, TypedDict
|
||||
|
||||
from autogen_core import Component, ComponentBase
|
||||
from mcp.types import CallToolResult, ListToolsResult
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Self
|
||||
|
||||
from ._config import McpServerParams
|
||||
from ._session import create_mcp_server_session
|
||||
|
||||
McpResult = Coroutine[Any, Any, ListToolsResult] | Coroutine[Any, Any, CallToolResult]
|
||||
McpFuture = asyncio.Future[McpResult]
|
||||
|
||||
|
||||
class McpActorArgs(TypedDict):
|
||||
name: str | None
|
||||
kargs: Mapping[str, Any]
|
||||
|
||||
|
||||
class McpSessionActorConfig(BaseModel):
|
||||
server_params: McpServerParams
|
||||
|
||||
|
||||
class McpSessionActor(ComponentBase[BaseModel], Component[McpSessionActorConfig]):
|
||||
component_type = "mcp_session_actor"
|
||||
component_config_schema = McpSessionActorConfig
|
||||
component_provider_override = "autogen_ext.tools.mcp.McpSessionActor"
|
||||
|
||||
server_params: McpServerParams
|
||||
|
||||
# model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
def __init__(self, server_params: McpServerParams) -> None:
|
||||
self.server_params: McpServerParams = server_params
|
||||
self.name = "mcp_session_actor"
|
||||
self.description = "MCP session actor"
|
||||
self._command_queue: asyncio.Queue[Dict[str, Any]] = asyncio.Queue()
|
||||
self._actor_task: asyncio.Task[Any] | None = None
|
||||
self._shutdown_future: asyncio.Future[Any] | None = None
|
||||
self._active = False
|
||||
atexit.register(self._sync_shutdown)
|
||||
|
||||
async def initialize(self) -> None:
|
||||
if not self._active:
|
||||
self._active = True
|
||||
self._actor_task = asyncio.create_task(self._run_actor())
|
||||
|
||||
async def call(self, type: str, args: McpActorArgs | None = None) -> McpFuture:
|
||||
if not self._active:
|
||||
raise RuntimeError("MCP Actor not running, call initialize() first")
|
||||
if self._actor_task and self._actor_task.done():
|
||||
raise RuntimeError("MCP actor task crashed", self._actor_task.exception())
|
||||
fut: asyncio.Future[McpFuture] = asyncio.Future()
|
||||
if type in {"list_tools", "shutdown"}:
|
||||
await self._command_queue.put({"type": type, "future": fut})
|
||||
res = await fut
|
||||
elif type == "call_tool":
|
||||
if args is None:
|
||||
raise ValueError("args is required for call_tool")
|
||||
name = args.get("name", None)
|
||||
kwargs = args.get("kargs", {})
|
||||
if name is None:
|
||||
raise ValueError("name is required for call_tool")
|
||||
await self._command_queue.put({"type": type, "name": name, "args": kwargs, "future": fut})
|
||||
res = await fut
|
||||
else:
|
||||
raise ValueError(f"Unknown command type: {type}")
|
||||
return res
|
||||
|
||||
async def close(self) -> None:
|
||||
if not self._active or self._actor_task is None:
|
||||
return
|
||||
self._shutdown_future = asyncio.Future()
|
||||
await self._command_queue.put({"type": "shutdown", "future": self._shutdown_future})
|
||||
await self._shutdown_future
|
||||
await self._actor_task
|
||||
self._active = False
|
||||
|
||||
async def _run_actor(self) -> None:
|
||||
result: McpResult
|
||||
try:
|
||||
async with create_mcp_server_session(self.server_params) as session:
|
||||
await session.initialize()
|
||||
while True:
|
||||
cmd = await self._command_queue.get()
|
||||
if cmd["type"] == "shutdown":
|
||||
cmd["future"].set_result("ok")
|
||||
break
|
||||
elif cmd["type"] == "call_tool":
|
||||
try:
|
||||
result = session.call_tool(name=cmd["name"], arguments=cmd["args"])
|
||||
cmd["future"].set_result(result)
|
||||
except Exception as e:
|
||||
cmd["future"].set_exception(e)
|
||||
elif cmd["type"] == "list_tools":
|
||||
try:
|
||||
result = session.list_tools()
|
||||
cmd["future"].set_result(result)
|
||||
except Exception as e:
|
||||
cmd["future"].set_exception(e)
|
||||
except Exception as e:
|
||||
if self._shutdown_future and not self._shutdown_future.done():
|
||||
self._shutdown_future.set_exception(e)
|
||||
finally:
|
||||
self._active = False
|
||||
self._actor_task = None
|
||||
|
||||
def _sync_shutdown(self) -> None:
|
||||
if not self._active or self._actor_task is None:
|
||||
return
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
# No loop available — interpreter is likely shutting down
|
||||
return
|
||||
|
||||
if loop.is_closed():
|
||||
return
|
||||
|
||||
if loop.is_running():
|
||||
loop.create_task(self.close())
|
||||
else:
|
||||
loop.run_until_complete(self.close())
|
||||
|
||||
def _to_config(self) -> McpSessionActorConfig:
|
||||
"""
|
||||
Convert the adapter to its configuration representation.
|
||||
|
||||
Returns:
|
||||
McpSessionConfig: The configuration of the adapter.
|
||||
"""
|
||||
return McpSessionActorConfig(server_params=self.server_params)
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: McpSessionActorConfig) -> Self:
|
||||
"""
|
||||
Create an instance of McpSessionActor from its configuration.
|
||||
|
||||
Args:
|
||||
config (McpSessionConfig): The configuration of the adapter.
|
||||
|
||||
Returns:
|
||||
McpSessionActor: An instance of SseMcpToolAdapter.
|
||||
"""
|
||||
return cls(server_params=config.server_params)
|
||||
@ -1,22 +1,27 @@
|
||||
from typing import Any, TypeAlias
|
||||
from typing import Any, Literal
|
||||
|
||||
from mcp import StdioServerParameters
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Annotated
|
||||
|
||||
|
||||
class StdioServerParams(StdioServerParameters):
|
||||
"""Parameters for connecting to an MCP server over STDIO."""
|
||||
|
||||
type: Literal["StdioServerParams"] = "StdioServerParams"
|
||||
|
||||
read_timeout_seconds: float = 5
|
||||
|
||||
|
||||
class SseServerParams(BaseModel):
|
||||
"""Parameters for connecting to an MCP server over SSE."""
|
||||
|
||||
type: Literal["SseServerParams"] = "SseServerParams"
|
||||
|
||||
url: str
|
||||
headers: dict[str, Any] | None = None
|
||||
timeout: float = 5
|
||||
sse_read_timeout: float = 60 * 5
|
||||
|
||||
|
||||
McpServerParams: TypeAlias = StdioServerParams | SseServerParams
|
||||
McpServerParams = Annotated[StdioServerParams | SseServerParams, Field(discriminator="type")]
|
||||
|
||||
@ -0,0 +1,200 @@
|
||||
import builtins
|
||||
import warnings
|
||||
from typing import Any, List, Literal, Mapping
|
||||
|
||||
from autogen_core import CancellationToken, Component, Image
|
||||
from autogen_core.tools import (
|
||||
ImageResultContent,
|
||||
ParametersSchema,
|
||||
TextResultContent,
|
||||
ToolResult,
|
||||
ToolSchema,
|
||||
Workbench,
|
||||
)
|
||||
from mcp.types import CallToolResult, EmbeddedResource, ImageContent, ListToolsResult, TextContent
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Self
|
||||
|
||||
from ._actor import McpSessionActor
|
||||
from ._config import McpServerParams, SseServerParams, StdioServerParams
|
||||
|
||||
|
||||
class McpWorkbenchConfig(BaseModel):
|
||||
server_params: McpServerParams
|
||||
|
||||
|
||||
class McpWorkbenchState(BaseModel):
|
||||
type: Literal["McpWorkBenchState"] = "McpWorkBenchState"
|
||||
|
||||
|
||||
class McpWorkbench(Workbench, Component[McpWorkbenchConfig]):
|
||||
"""
|
||||
A workbench that wraps an MCP server and provides an interface
|
||||
to list and call tools provided by the server.
|
||||
|
||||
Args:
|
||||
server_params (McpServerParams): The parameters to connect to the MCP server.
|
||||
This can be either a :class:`StdioServerParams` or :class:`SseServerParams`.
|
||||
|
||||
Example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
|
||||
from autogen_ext.tools.mcp import McpWorkbench, StdioServerParams
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
params = StdioServerParams(
|
||||
command="uvx",
|
||||
args=["mcp-server-fetch"],
|
||||
read_timeout_seconds=60,
|
||||
)
|
||||
|
||||
# You can also use `start()` and `stop()` to manage the session.
|
||||
async with McpWorkbench(server_params=params) as workbench:
|
||||
tools = await workbench.list_tools()
|
||||
print(tools)
|
||||
result = await workbench.call_tool(tools[0]["name"], {"url": "https://github.com/"})
|
||||
print(result)
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
"""
|
||||
|
||||
component_provider_override = "autogen_ext.tools.mcp.McpWorkbench"
|
||||
component_config_schema = McpWorkbenchConfig
|
||||
|
||||
def __init__(self, server_params: McpServerParams) -> None:
|
||||
self._server_params = server_params
|
||||
# self._session: ClientSession | None = None
|
||||
self._actor: McpSessionActor | None = None
|
||||
self._read = None
|
||||
self._write = None
|
||||
|
||||
@property
|
||||
def server_params(self) -> McpServerParams:
|
||||
return self._server_params
|
||||
|
||||
async def list_tools(self) -> List[ToolSchema]:
|
||||
if not self._actor:
|
||||
await self.start() # fallback to start the actor if not initialized instead of raising an error
|
||||
# Why? Because when deserializing the workbench, the actor might not be initialized yet.
|
||||
# raise RuntimeError("Actor is not initialized. Call start() first.")
|
||||
if self._actor is None:
|
||||
raise RuntimeError("Actor is not initialized. Please check the server connection.")
|
||||
result_future = await self._actor.call("list_tools", None)
|
||||
list_tool_result = await result_future
|
||||
assert isinstance(
|
||||
list_tool_result, ListToolsResult
|
||||
), f"list_tools must return a CallToolResult, instead of : {str(type(list_tool_result))}"
|
||||
schema: List[ToolSchema] = []
|
||||
for tool in list_tool_result.tools:
|
||||
name = tool.name
|
||||
description = tool.description or ""
|
||||
parameters = ParametersSchema(
|
||||
type="object",
|
||||
properties=tool.inputSchema["properties"],
|
||||
required=tool.inputSchema.get("required", []),
|
||||
additionalProperties=tool.inputSchema.get("additionalProperties", False),
|
||||
)
|
||||
tool_schema = ToolSchema(
|
||||
name=name,
|
||||
description=description,
|
||||
parameters=parameters,
|
||||
)
|
||||
schema.append(tool_schema)
|
||||
return schema
|
||||
|
||||
async def call_tool(
|
||||
self, name: str, arguments: Mapping[str, Any] | None = None, cancellation_token: CancellationToken | None = None
|
||||
) -> ToolResult:
|
||||
if not self._actor:
|
||||
await self.start() # fallback to start the actor if not initialized instead of raising an error
|
||||
# Why? Because when deserializing the workbench, the actor might not be initialized yet.
|
||||
# raise RuntimeError("Actor is not initialized. Call start() first.")
|
||||
if self._actor is None:
|
||||
raise RuntimeError("Actor is not initialized. Please check the server connection.")
|
||||
if not cancellation_token:
|
||||
cancellation_token = CancellationToken()
|
||||
if not arguments:
|
||||
arguments = {}
|
||||
try:
|
||||
result_future = await self._actor.call("call_tool", {"name": name, "kargs": arguments})
|
||||
cancellation_token.link_future(result_future)
|
||||
result = await result_future
|
||||
assert isinstance(
|
||||
result, CallToolResult
|
||||
), f"call_tool must return a CallToolResult, instead of : {str(type(result))}"
|
||||
result_parts: List[TextResultContent | ImageResultContent] = []
|
||||
is_error = result.isError
|
||||
for content in result.content:
|
||||
if isinstance(content, TextContent):
|
||||
result_parts.append(TextResultContent(content=content.text))
|
||||
elif isinstance(content, ImageContent):
|
||||
result_parts.append(ImageResultContent(content=Image.from_base64(content.data)))
|
||||
elif isinstance(content, EmbeddedResource):
|
||||
# TODO: how to handle embedded resources?
|
||||
# For now we just use text representation.
|
||||
result_parts.append(TextResultContent(content=content.model_dump_json()))
|
||||
else:
|
||||
raise ValueError(f"Unknown content type from server: {type(content)}")
|
||||
except Exception as e:
|
||||
error_message = self._format_errors(e)
|
||||
is_error = True
|
||||
result_parts = [TextResultContent(content=error_message)]
|
||||
return ToolResult(name=name, result=result_parts, is_error=is_error)
|
||||
|
||||
def _format_errors(self, error: Exception) -> str:
|
||||
"""Recursively format errors into a string."""
|
||||
|
||||
error_message = ""
|
||||
if hasattr(builtins, "ExceptionGroup") and isinstance(error, builtins.ExceptionGroup):
|
||||
# ExceptionGroup is available in Python 3.11+.
|
||||
# TODO: how to make this compatible with Python 3.10?
|
||||
for sub_exception in error.exceptions: # type: ignore
|
||||
error_message += self._format_errors(sub_exception) # type: ignore
|
||||
else:
|
||||
error_message += f"{str(error)}\n"
|
||||
return error_message
|
||||
|
||||
async def start(self) -> None:
|
||||
if self._actor:
|
||||
warnings.warn(
|
||||
"McpWorkbench is already started. No need to start again.",
|
||||
UserWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return # Already initialized, no need to start again
|
||||
|
||||
if isinstance(self._server_params, (StdioServerParams, SseServerParams)):
|
||||
self._actor = McpSessionActor(self._server_params)
|
||||
await self._actor.initialize()
|
||||
else:
|
||||
raise ValueError(f"Unsupported server params type: {type(self._server_params)}")
|
||||
|
||||
async def stop(self) -> None:
|
||||
if self._actor:
|
||||
# Close the actor
|
||||
await self._actor.close()
|
||||
self._actor = None
|
||||
else:
|
||||
raise RuntimeError("McpWorkbench is not started. Call start() first.")
|
||||
|
||||
async def reset(self) -> None:
|
||||
pass
|
||||
|
||||
async def save_state(self) -> Mapping[str, Any]:
|
||||
return McpWorkbenchState().model_dump()
|
||||
|
||||
async def load_state(self, state: Mapping[str, Any]) -> None:
|
||||
pass
|
||||
|
||||
def _to_config(self) -> McpWorkbenchConfig:
|
||||
return McpWorkbenchConfig(server_params=self._server_params)
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: McpWorkbenchConfig) -> Self:
|
||||
return cls(server_params=config.server_params)
|
||||
@ -4,8 +4,10 @@ from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from autogen_core import CancellationToken
|
||||
from autogen_core.tools import Workbench
|
||||
from autogen_core.utils import schema_to_pydantic_model
|
||||
from autogen_ext.tools.mcp import (
|
||||
McpWorkbench,
|
||||
SseMcpToolAdapter,
|
||||
SseServerParams,
|
||||
StdioMcpToolAdapter,
|
||||
@ -422,3 +424,80 @@ async def test_mcp_server_github() -> None:
|
||||
{"owner": "microsoft", "repo": "autogen", "path": "python", "branch": "main"}, CancellationToken()
|
||||
)
|
||||
assert result is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mcp_workbench_start_stop() -> None:
|
||||
params = StdioServerParams(
|
||||
command="uvx",
|
||||
args=["mcp-server-fetch"],
|
||||
read_timeout_seconds=60,
|
||||
)
|
||||
|
||||
workbench = McpWorkbench(params)
|
||||
assert workbench is not None
|
||||
assert workbench.server_params == params
|
||||
await workbench.start()
|
||||
assert workbench._actor is not None # type: ignore[reportPrivateUsage]
|
||||
await workbench.stop()
|
||||
assert workbench._actor is None # type: ignore[reportPrivateUsage]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mcp_workbench_server_fetch() -> None:
|
||||
params = StdioServerParams(
|
||||
command="uvx",
|
||||
args=["mcp-server-fetch"],
|
||||
read_timeout_seconds=60,
|
||||
)
|
||||
|
||||
workbench = McpWorkbench(server_params=params)
|
||||
await workbench.start()
|
||||
|
||||
tools = await workbench.list_tools()
|
||||
assert tools is not None
|
||||
assert tools[0]["name"] == "fetch"
|
||||
|
||||
result = await workbench.call_tool(tools[0]["name"], {"url": "https://github.com/"}, CancellationToken())
|
||||
assert result is not None
|
||||
|
||||
await workbench.stop()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mcp_workbench_server_filesystem() -> None:
|
||||
params = StdioServerParams(
|
||||
command="npx",
|
||||
args=[
|
||||
"-y",
|
||||
"@modelcontextprotocol/server-filesystem",
|
||||
".",
|
||||
],
|
||||
read_timeout_seconds=60,
|
||||
)
|
||||
|
||||
workbench = McpWorkbench(server_params=params)
|
||||
await workbench.start()
|
||||
|
||||
tools = await workbench.list_tools()
|
||||
assert tools is not None
|
||||
tools = [tool for tool in tools if tool["name"] == "read_file"]
|
||||
assert len(tools) == 1
|
||||
tool = tools[0]
|
||||
result = await workbench.call_tool(tool["name"], {"path": "README.md"}, CancellationToken())
|
||||
assert result is not None
|
||||
|
||||
await workbench.stop()
|
||||
|
||||
# Serialize the workbench.
|
||||
config = workbench.dump_component()
|
||||
|
||||
# Deserialize the workbench.
|
||||
async with Workbench.load_component(config) as new_workbench:
|
||||
tools = await new_workbench.list_tools()
|
||||
assert tools is not None
|
||||
tools = [tool for tool in tools if tool["name"] == "read_file"]
|
||||
assert len(tools) == 1
|
||||
tool = tools[0]
|
||||
result = await new_workbench.call_tool(tool["name"], {"path": "README.md"}, CancellationToken())
|
||||
assert result is not None
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user