mirror of
https://github.com/microsoft/autogen.git
synced 2025-11-01 18:29:49 +00:00
Add tool name and description override functionality to Workbench implementations (#6690)
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: ekzhu <320302+ekzhu@users.noreply.github.com> Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
This commit is contained in:
parent
e10767421f
commit
0bd99ee516
@ -1,4 +1,13 @@
|
||||
from ._base import BaseStreamTool, BaseTool, BaseToolWithState, ParametersSchema, StreamTool, Tool, ToolSchema
|
||||
from ._base import (
|
||||
BaseStreamTool,
|
||||
BaseTool,
|
||||
BaseToolWithState,
|
||||
ParametersSchema,
|
||||
StreamTool,
|
||||
Tool,
|
||||
ToolOverride,
|
||||
ToolSchema,
|
||||
)
|
||||
from ._function_tool import FunctionTool
|
||||
from ._static_workbench import StaticStreamWorkbench, StaticWorkbench
|
||||
from ._workbench import ImageResultContent, TextResultContent, ToolResult, Workbench
|
||||
@ -18,4 +27,5 @@ __all__ = [
|
||||
"ImageResultContent",
|
||||
"StaticWorkbench",
|
||||
"StaticStreamWorkbench",
|
||||
"ToolOverride",
|
||||
]
|
||||
|
||||
@ -2,7 +2,19 @@ import json
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, AsyncGenerator, Dict, Generic, Mapping, Protocol, Type, TypeVar, cast, runtime_checkable
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
Dict,
|
||||
Generic,
|
||||
Mapping,
|
||||
Optional,
|
||||
Protocol,
|
||||
Type,
|
||||
TypeVar,
|
||||
cast,
|
||||
runtime_checkable,
|
||||
)
|
||||
|
||||
import jsonref
|
||||
from pydantic import BaseModel
|
||||
@ -33,6 +45,13 @@ class ToolSchema(TypedDict):
|
||||
strict: NotRequired[bool]
|
||||
|
||||
|
||||
class ToolOverride(BaseModel):
|
||||
"""Override configuration for a tool's name and/or description."""
|
||||
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Tool(Protocol):
|
||||
@property
|
||||
|
||||
@ -1,18 +1,19 @@
|
||||
import asyncio
|
||||
import builtins
|
||||
from typing import Any, AsyncGenerator, Dict, List, Literal, Mapping
|
||||
from typing import Any, AsyncGenerator, Dict, List, Literal, Mapping, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Self
|
||||
|
||||
from .._cancellation_token import CancellationToken
|
||||
from .._component_config import Component, ComponentModel
|
||||
from ._base import BaseTool, StreamTool, ToolSchema
|
||||
from ._base import BaseTool, StreamTool, ToolOverride, ToolSchema
|
||||
from ._workbench import StreamWorkbench, TextResultContent, ToolResult, Workbench
|
||||
|
||||
|
||||
class StaticWorkbenchConfig(BaseModel):
|
||||
tools: List[ComponentModel] = []
|
||||
tool_overrides: Dict[str, ToolOverride] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class StateicWorkbenchState(BaseModel):
|
||||
@ -28,16 +29,67 @@ class StaticWorkbench(Workbench, Component[StaticWorkbenchConfig]):
|
||||
Args:
|
||||
tools (List[BaseTool[Any, Any]]): A list of tools to be included in the workbench.
|
||||
The tools should be subclasses of :class:`~autogen_core.tools.BaseTool`.
|
||||
tool_overrides (Optional[Dict[str, ToolOverride]]): Optional mapping of original tool
|
||||
names to override configurations for name and/or description. This allows
|
||||
customizing how tools appear to consumers while maintaining the underlying
|
||||
tool functionality.
|
||||
"""
|
||||
|
||||
component_provider_override = "autogen_core.tools.StaticWorkbench"
|
||||
component_config_schema = StaticWorkbenchConfig
|
||||
|
||||
def __init__(self, tools: List[BaseTool[Any, Any]]) -> None:
|
||||
def __init__(
|
||||
self, tools: List[BaseTool[Any, Any]], tool_overrides: Optional[Dict[str, ToolOverride]] = None
|
||||
) -> None:
|
||||
self._tools = tools
|
||||
self._tool_overrides = tool_overrides or {}
|
||||
|
||||
# Build reverse mapping from override names to original names for call_tool
|
||||
self._override_name_to_original: Dict[str, str] = {}
|
||||
existing_tool_names = {tool.name for tool in self._tools}
|
||||
|
||||
for original_name, override in self._tool_overrides.items():
|
||||
if override.name and override.name != original_name:
|
||||
# Check for conflicts with existing tool names
|
||||
if override.name in existing_tool_names and override.name != original_name:
|
||||
raise ValueError(
|
||||
f"Tool override name '{override.name}' conflicts with existing tool name. "
|
||||
f"Override names must not conflict with any tool names."
|
||||
)
|
||||
# Check for conflicts with other override names
|
||||
if override.name in self._override_name_to_original:
|
||||
existing_original = self._override_name_to_original[override.name]
|
||||
raise ValueError(
|
||||
f"Tool override name '{override.name}' is used by multiple tools: "
|
||||
f"'{existing_original}' and '{original_name}'. Override names must be unique."
|
||||
)
|
||||
self._override_name_to_original[override.name] = original_name
|
||||
|
||||
async def list_tools(self) -> List[ToolSchema]:
|
||||
return [tool.schema for tool in self._tools]
|
||||
result_schemas: List[ToolSchema] = []
|
||||
for tool in self._tools:
|
||||
original_schema = tool.schema
|
||||
|
||||
# Apply overrides if they exist for this tool
|
||||
if tool.name in self._tool_overrides:
|
||||
override = self._tool_overrides[tool.name]
|
||||
# Create a new ToolSchema with overrides applied
|
||||
schema: ToolSchema = {
|
||||
"name": override.name if override.name is not None else original_schema["name"],
|
||||
"description": override.description
|
||||
if override.description is not None
|
||||
else original_schema.get("description", ""),
|
||||
}
|
||||
# Copy optional fields
|
||||
if "parameters" in original_schema:
|
||||
schema["parameters"] = original_schema["parameters"]
|
||||
if "strict" in original_schema:
|
||||
schema["strict"] = original_schema["strict"]
|
||||
else:
|
||||
schema = original_schema
|
||||
|
||||
result_schemas.append(schema)
|
||||
return result_schemas
|
||||
|
||||
async def call_tool(
|
||||
self,
|
||||
@ -46,10 +98,13 @@ class StaticWorkbench(Workbench, Component[StaticWorkbenchConfig]):
|
||||
cancellation_token: CancellationToken | None = None,
|
||||
call_id: str | None = None,
|
||||
) -> ToolResult:
|
||||
tool = next((tool for tool in self._tools if tool.name == name), None)
|
||||
# Check if the name is an override name and map it back to the original
|
||||
original_name = self._override_name_to_original.get(name, name)
|
||||
|
||||
tool = next((tool for tool in self._tools if tool.name == original_name), None)
|
||||
if tool is None:
|
||||
return ToolResult(
|
||||
name=name,
|
||||
name=name, # Return the requested name (which might be overridden)
|
||||
result=[TextResultContent(content=f"Tool {name} not found.")],
|
||||
is_error=True,
|
||||
)
|
||||
@ -66,7 +121,7 @@ class StaticWorkbench(Workbench, Component[StaticWorkbenchConfig]):
|
||||
except Exception as e:
|
||||
result_str = self._format_errors(e)
|
||||
is_error = True
|
||||
return ToolResult(name=tool.name, result=[TextResultContent(content=result_str)], is_error=is_error)
|
||||
return ToolResult(name=name, result=[TextResultContent(content=result_str)], is_error=is_error)
|
||||
|
||||
async def start(self) -> None:
|
||||
return None
|
||||
@ -90,11 +145,13 @@ class StaticWorkbench(Workbench, Component[StaticWorkbenchConfig]):
|
||||
await tool.load_state_json(parsed_state.tools[tool.name])
|
||||
|
||||
def _to_config(self) -> StaticWorkbenchConfig:
|
||||
return StaticWorkbenchConfig(tools=[tool.dump_component() for tool in self._tools])
|
||||
return StaticWorkbenchConfig(
|
||||
tools=[tool.dump_component() for tool in self._tools], tool_overrides=self._tool_overrides
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: StaticWorkbenchConfig) -> Self:
|
||||
return cls(tools=[BaseTool.load_component(tool) for tool in config.tools])
|
||||
return cls(tools=[BaseTool.load_component(tool) for tool in config.tools], tool_overrides=config.tool_overrides)
|
||||
|
||||
def _format_errors(self, error: Exception) -> str:
|
||||
"""Recursively format errors into a string."""
|
||||
|
||||
@ -0,0 +1,285 @@
|
||||
from typing import Annotated, Dict
|
||||
|
||||
import pytest
|
||||
from autogen_core.code_executor import ImportFromModule
|
||||
from autogen_core.tools import FunctionTool, StaticWorkbench, ToolOverride, Workbench
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_static_workbench_with_tool_overrides() -> None:
|
||||
"""Test StaticWorkbench with tool name and description overrides."""
|
||||
|
||||
def test_tool_func_1(x: Annotated[int, "The number to double."]) -> int:
|
||||
return x * 2
|
||||
|
||||
def test_tool_func_2(a: Annotated[int, "First number"], b: Annotated[int, "Second number"]) -> int:
|
||||
return a + b
|
||||
|
||||
test_tool_1 = FunctionTool(
|
||||
test_tool_func_1,
|
||||
name="double",
|
||||
description="A test tool that doubles a number.",
|
||||
global_imports=[ImportFromModule(module="typing_extensions", imports=["Annotated"])],
|
||||
)
|
||||
test_tool_2 = FunctionTool(
|
||||
test_tool_func_2,
|
||||
name="add",
|
||||
description="A test tool that adds two numbers.",
|
||||
global_imports=[ImportFromModule(module="typing_extensions", imports=["Annotated"])],
|
||||
)
|
||||
|
||||
# Define tool overrides
|
||||
overrides: Dict[str, ToolOverride] = {
|
||||
"double": ToolOverride(name="multiply_by_two", description="Multiplies a number by 2"),
|
||||
"add": ToolOverride(description="Performs addition of two integers"), # Only override description
|
||||
}
|
||||
|
||||
# Create a StaticWorkbench instance with tool overrides
|
||||
async with StaticWorkbench(tools=[test_tool_1, test_tool_2], tool_overrides=overrides) as workbench:
|
||||
# List tools and verify overrides are applied
|
||||
tools = await workbench.list_tools()
|
||||
assert len(tools) == 2
|
||||
|
||||
# Check first tool has name and description overridden
|
||||
assert tools[0]["name"] == "multiply_by_two"
|
||||
assert tools[0].get("description") == "Multiplies a number by 2"
|
||||
assert tools[0].get("parameters") == {
|
||||
"type": "object",
|
||||
"properties": {"x": {"type": "integer", "title": "X", "description": "The number to double."}},
|
||||
"required": ["x"],
|
||||
"additionalProperties": False,
|
||||
}
|
||||
|
||||
# Check second tool has only description overridden
|
||||
assert tools[1]["name"] == "add" # Original name
|
||||
assert tools[1].get("description") == "Performs addition of two integers" # Overridden description
|
||||
assert tools[1].get("parameters") == {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"a": {"type": "integer", "title": "A", "description": "First number"},
|
||||
"b": {"type": "integer", "title": "B", "description": "Second number"},
|
||||
},
|
||||
"required": ["a", "b"],
|
||||
"additionalProperties": False,
|
||||
}
|
||||
|
||||
# Call tools using override names
|
||||
result_1 = await workbench.call_tool("multiply_by_two", {"x": 5})
|
||||
assert result_1.name == "multiply_by_two" # Should return the override name
|
||||
assert result_1.result[0].type == "TextResultContent"
|
||||
assert result_1.result[0].content == "10"
|
||||
assert result_1.to_text() == "10"
|
||||
assert result_1.is_error is False
|
||||
|
||||
# Call tool using original name (should still work for description-only override)
|
||||
result_2 = await workbench.call_tool("add", {"a": 3, "b": 7})
|
||||
assert result_2.name == "add"
|
||||
assert result_2.result[0].type == "TextResultContent"
|
||||
assert result_2.result[0].content == "10"
|
||||
assert result_2.to_text() == "10"
|
||||
assert result_2.is_error is False
|
||||
|
||||
# Test calling non-existent tool
|
||||
result_3 = await workbench.call_tool("nonexistent", {"x": 5})
|
||||
assert result_3.name == "nonexistent"
|
||||
assert result_3.is_error is True
|
||||
assert result_3.result[0].type == "TextResultContent"
|
||||
assert "Tool nonexistent not found" in result_3.result[0].content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_static_workbench_without_overrides() -> None:
|
||||
"""Test StaticWorkbench without overrides (original behavior)."""
|
||||
|
||||
def test_tool_func(x: Annotated[int, "The number to double."]) -> int:
|
||||
return x * 2
|
||||
|
||||
test_tool = FunctionTool(
|
||||
test_tool_func,
|
||||
name="double",
|
||||
description="A test tool that doubles a number.",
|
||||
global_imports=[ImportFromModule(module="typing_extensions", imports=["Annotated"])],
|
||||
)
|
||||
|
||||
# Create workbench without overrides
|
||||
async with StaticWorkbench(tools=[test_tool]) as workbench:
|
||||
tools = await workbench.list_tools()
|
||||
assert len(tools) == 1
|
||||
assert tools[0].get("name") == "double"
|
||||
assert tools[0].get("description") == "A test tool that doubles a number."
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_static_workbench_serialization_with_overrides() -> None:
|
||||
"""Test that StaticWorkbench can be serialized and deserialized with overrides."""
|
||||
|
||||
def test_tool_func(x: Annotated[int, "The number to double."]) -> int:
|
||||
return x * 2
|
||||
|
||||
test_tool = FunctionTool(
|
||||
test_tool_func,
|
||||
name="double",
|
||||
description="A test tool that doubles a number.",
|
||||
global_imports=[ImportFromModule(module="typing_extensions", imports=["Annotated"])],
|
||||
)
|
||||
|
||||
overrides: Dict[str, ToolOverride] = {
|
||||
"double": ToolOverride(name="multiply_by_two", description="Multiplies a number by 2")
|
||||
}
|
||||
|
||||
# Create workbench with overrides
|
||||
workbench = StaticWorkbench(tools=[test_tool], tool_overrides=overrides)
|
||||
|
||||
# Save configuration
|
||||
config = workbench.dump_component()
|
||||
assert "tool_overrides" in config.config
|
||||
|
||||
# Load workbench from configuration
|
||||
async with Workbench.load_component(config) as new_workbench:
|
||||
tools = await new_workbench.list_tools()
|
||||
assert len(tools) == 1
|
||||
assert tools[0]["name"] == "multiply_by_two"
|
||||
assert tools[0].get("description") == "Multiplies a number by 2"
|
||||
|
||||
# Test calling tool with override name
|
||||
result = await new_workbench.call_tool("multiply_by_two", {"x": 5})
|
||||
assert result.name == "multiply_by_two"
|
||||
assert result.result[0].content == "10"
|
||||
assert result.is_error is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_static_workbench_partial_overrides() -> None:
|
||||
"""Test StaticWorkbench with partial overrides (name only, description only)."""
|
||||
|
||||
def tool1_func(x: Annotated[int, "Number"]) -> int:
|
||||
return x
|
||||
|
||||
def tool2_func(x: Annotated[int, "Number"]) -> int:
|
||||
return x
|
||||
|
||||
tool1 = FunctionTool(
|
||||
tool1_func,
|
||||
name="tool1",
|
||||
description="Original description 1",
|
||||
global_imports=[ImportFromModule(module="typing_extensions", imports=["Annotated"])],
|
||||
)
|
||||
tool2 = FunctionTool(
|
||||
tool2_func,
|
||||
name="tool2",
|
||||
description="Original description 2",
|
||||
global_imports=[ImportFromModule(module="typing_extensions", imports=["Annotated"])],
|
||||
)
|
||||
|
||||
overrides: Dict[str, ToolOverride] = {
|
||||
"tool1": ToolOverride(name="renamed_tool1"), # Only name override
|
||||
"tool2": ToolOverride(description="New description 2"), # Only description override
|
||||
}
|
||||
|
||||
async with StaticWorkbench(tools=[tool1, tool2], tool_overrides=overrides) as workbench:
|
||||
tools = await workbench.list_tools()
|
||||
|
||||
# tool1: name overridden, description unchanged
|
||||
assert tools[0].get("name") == "renamed_tool1"
|
||||
assert tools[0].get("description") == "Original description 1"
|
||||
|
||||
# tool2: name unchanged, description overridden
|
||||
assert tools[1].get("name") == "tool2"
|
||||
assert tools[1].get("description") == "New description 2"
|
||||
|
||||
# Test calling with override name
|
||||
result1 = await workbench.call_tool("renamed_tool1", {"x": 42})
|
||||
assert result1.name == "renamed_tool1"
|
||||
assert result1.result[0].content == "42"
|
||||
|
||||
# Test calling with original name
|
||||
result2 = await workbench.call_tool("tool2", {"x": 42})
|
||||
assert result2.name == "tool2"
|
||||
assert result2.result[0].content == "42"
|
||||
|
||||
|
||||
def test_tool_override_model() -> None:
|
||||
"""Test ToolOverride model functionality."""
|
||||
|
||||
# Test with both fields
|
||||
override1 = ToolOverride(name="new_name", description="new_desc")
|
||||
assert override1.name == "new_name"
|
||||
assert override1.description == "new_desc"
|
||||
|
||||
# Test with only name
|
||||
override2 = ToolOverride(name="new_name")
|
||||
assert override2.name == "new_name"
|
||||
assert override2.description is None
|
||||
|
||||
# Test with only description
|
||||
override3 = ToolOverride(description="new_desc")
|
||||
assert override3.name is None
|
||||
assert override3.description == "new_desc"
|
||||
|
||||
# Test empty
|
||||
override4 = ToolOverride()
|
||||
assert override4.name is None
|
||||
assert override4.description is None
|
||||
|
||||
|
||||
def test_static_workbench_conflict_detection() -> None:
|
||||
"""Test that StaticWorkbench detects conflicts in tool override names."""
|
||||
|
||||
def test_tool_func_1(x: Annotated[int, "Number"]) -> int:
|
||||
return x
|
||||
|
||||
def test_tool_func_2(x: Annotated[int, "Number"]) -> int:
|
||||
return x
|
||||
|
||||
def test_tool_func_3(x: Annotated[int, "Number"]) -> int:
|
||||
return x
|
||||
|
||||
tool1 = FunctionTool(
|
||||
test_tool_func_1,
|
||||
name="tool1",
|
||||
description="Tool 1",
|
||||
global_imports=[ImportFromModule(module="typing_extensions", imports=["Annotated"])],
|
||||
)
|
||||
tool2 = FunctionTool(
|
||||
test_tool_func_2,
|
||||
name="tool2",
|
||||
description="Tool 2",
|
||||
global_imports=[ImportFromModule(module="typing_extensions", imports=["Annotated"])],
|
||||
)
|
||||
tool3 = FunctionTool(
|
||||
test_tool_func_3,
|
||||
name="tool3",
|
||||
description="Tool 3",
|
||||
global_imports=[ImportFromModule(module="typing_extensions", imports=["Annotated"])],
|
||||
)
|
||||
|
||||
# Test 1: Valid overrides - should work
|
||||
overrides_valid: Dict[str, ToolOverride] = {
|
||||
"tool1": ToolOverride(name="renamed_tool1"),
|
||||
"tool2": ToolOverride(name="renamed_tool2"),
|
||||
}
|
||||
workbench_valid = StaticWorkbench(tools=[tool1, tool2, tool3], tool_overrides=overrides_valid)
|
||||
assert "renamed_tool1" in workbench_valid._override_name_to_original # type: ignore[reportPrivateUsage]
|
||||
assert "renamed_tool2" in workbench_valid._override_name_to_original # type: ignore[reportPrivateUsage]
|
||||
|
||||
# Test 2: Conflict with existing tool name - should fail
|
||||
overrides_conflict: Dict[str, ToolOverride] = {
|
||||
"tool1": ToolOverride(name="tool2") # tool2 already exists
|
||||
}
|
||||
with pytest.raises(ValueError):
|
||||
StaticWorkbench(tools=[tool1, tool2, tool3], tool_overrides=overrides_conflict)
|
||||
|
||||
# Test 3: Duplicate override names - should fail
|
||||
overrides_duplicate: Dict[str, ToolOverride] = {
|
||||
"tool1": ToolOverride(name="same_name"),
|
||||
"tool2": ToolOverride(name="same_name"), # Duplicate
|
||||
}
|
||||
with pytest.raises(ValueError):
|
||||
StaticWorkbench(tools=[tool1, tool2, tool3], tool_overrides=overrides_duplicate)
|
||||
|
||||
# Test 4: Self-renaming - should work but not add to reverse mapping
|
||||
overrides_self: Dict[str, ToolOverride] = {
|
||||
"tool1": ToolOverride(name="tool1") # Renaming to itself
|
||||
}
|
||||
workbench_self = StaticWorkbench(tools=[tool1, tool2, tool3], tool_overrides=overrides_self)
|
||||
assert "tool1" not in workbench_self._override_name_to_original # type: ignore[reportPrivateUsage]
|
||||
@ -1,19 +1,20 @@
|
||||
import asyncio
|
||||
import builtins
|
||||
import warnings
|
||||
from typing import Any, List, Literal, Mapping
|
||||
from typing import Any, Dict, List, Literal, Mapping, Optional
|
||||
|
||||
from autogen_core import CancellationToken, Component, Image, trace_tool_span
|
||||
from autogen_core.tools import (
|
||||
ImageResultContent,
|
||||
ParametersSchema,
|
||||
TextResultContent,
|
||||
ToolOverride,
|
||||
ToolResult,
|
||||
ToolSchema,
|
||||
Workbench,
|
||||
)
|
||||
from mcp.types import CallToolResult, EmbeddedResource, ImageContent, ListToolsResult, TextContent
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Self
|
||||
|
||||
from ._actor import McpSessionActor
|
||||
@ -22,6 +23,7 @@ from ._config import McpServerParams, SseServerParams, StdioServerParams, Stream
|
||||
|
||||
class McpWorkbenchConfig(BaseModel):
|
||||
server_params: McpServerParams
|
||||
tool_overrides: Dict[str, ToolOverride] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class McpWorkbenchState(BaseModel):
|
||||
@ -36,6 +38,10 @@ class McpWorkbench(Workbench, Component[McpWorkbenchConfig]):
|
||||
Args:
|
||||
server_params (McpServerParams): The parameters to connect to the MCP server.
|
||||
This can be either a :class:`StdioServerParams` or :class:`SseServerParams`.
|
||||
tool_overrides (Optional[Dict[str, ToolOverride]]): Optional mapping of original tool
|
||||
names to override configurations for name and/or description. This allows
|
||||
customizing how server tools appear to consumers while maintaining the underlying
|
||||
tool functionality.
|
||||
|
||||
Examples:
|
||||
|
||||
@ -65,6 +71,38 @@ class McpWorkbench(Workbench, Component[McpWorkbenchConfig]):
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
Example of using tool overrides:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import asyncio
|
||||
from autogen_ext.tools.mcp import McpWorkbench, StdioServerParams
|
||||
from autogen_core.tools import ToolOverride
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
params = StdioServerParams(
|
||||
command="uvx",
|
||||
args=["mcp-server-fetch"],
|
||||
read_timeout_seconds=60,
|
||||
)
|
||||
|
||||
# Override the fetch tool's name and description
|
||||
overrides = {
|
||||
"fetch": ToolOverride(name="web_fetch", description="Enhanced web fetching tool with better error handling")
|
||||
}
|
||||
|
||||
async with McpWorkbench(server_params=params, tool_overrides=overrides) as workbench:
|
||||
tools = await workbench.list_tools()
|
||||
# The tool will now appear as "web_fetch" with the new description
|
||||
print(tools)
|
||||
# Call the overridden tool
|
||||
result = await workbench.call_tool("web_fetch", {"url": "https://github.com/"})
|
||||
print(result)
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
Example of using the workbench with the `GitHub MCP Server <https://github.com/github/github-mcp-server>`_:
|
||||
|
||||
.. code-block:: python
|
||||
@ -149,8 +187,26 @@ class McpWorkbench(Workbench, Component[McpWorkbenchConfig]):
|
||||
component_provider_override = "autogen_ext.tools.mcp.McpWorkbench"
|
||||
component_config_schema = McpWorkbenchConfig
|
||||
|
||||
def __init__(self, server_params: McpServerParams) -> None:
|
||||
def __init__(
|
||||
self, server_params: McpServerParams, tool_overrides: Optional[Dict[str, ToolOverride]] = None
|
||||
) -> None:
|
||||
self._server_params = server_params
|
||||
self._tool_overrides = tool_overrides or {}
|
||||
|
||||
# Build reverse mapping from override names to original names for call_tool
|
||||
self._override_name_to_original: Dict[str, str] = {}
|
||||
for original_name, override in self._tool_overrides.items():
|
||||
override_name = override.name
|
||||
if override_name and override_name != original_name:
|
||||
# Check for conflicts with other override names
|
||||
if override_name in self._override_name_to_original:
|
||||
existing_original = self._override_name_to_original[override_name]
|
||||
raise ValueError(
|
||||
f"Tool override name '{override_name}' is used by multiple tools: "
|
||||
f"'{existing_original}' and '{original_name}'. Override names must be unique."
|
||||
)
|
||||
self._override_name_to_original[override_name] = original_name
|
||||
|
||||
# self._session: ClientSession | None = None
|
||||
self._actor: McpSessionActor | None = None
|
||||
self._actor_loop: asyncio.AbstractEventLoop | None = None
|
||||
@ -175,8 +231,18 @@ class McpWorkbench(Workbench, Component[McpWorkbenchConfig]):
|
||||
), f"list_tools must return a CallToolResult, instead of : {str(type(list_tool_result))}"
|
||||
schema: List[ToolSchema] = []
|
||||
for tool in list_tool_result.tools:
|
||||
name = tool.name
|
||||
original_name = tool.name
|
||||
name = original_name
|
||||
description = tool.description or ""
|
||||
|
||||
# Apply overrides if they exist for this tool
|
||||
if original_name in self._tool_overrides:
|
||||
override = self._tool_overrides[original_name]
|
||||
if override.name is not None:
|
||||
name = override.name
|
||||
if override.description is not None:
|
||||
description = override.description
|
||||
|
||||
parameters = ParametersSchema(
|
||||
type="object",
|
||||
properties=tool.inputSchema.get("properties", {}),
|
||||
@ -208,12 +274,16 @@ class McpWorkbench(Workbench, Component[McpWorkbenchConfig]):
|
||||
cancellation_token = CancellationToken()
|
||||
if not arguments:
|
||||
arguments = {}
|
||||
|
||||
# Check if the name is an override name and map it back to the original
|
||||
original_name = self._override_name_to_original.get(name, name)
|
||||
|
||||
with trace_tool_span(
|
||||
tool_name=name,
|
||||
tool_name=name, # Use the requested name for tracing
|
||||
tool_call_id=call_id,
|
||||
):
|
||||
try:
|
||||
result_future = await self._actor.call("call_tool", {"name": name, "kargs": arguments})
|
||||
result_future = await self._actor.call("call_tool", {"name": original_name, "kargs": arguments})
|
||||
cancellation_token.link_future(result_future)
|
||||
result = await result_future
|
||||
assert isinstance(
|
||||
@ -236,7 +306,7 @@ class McpWorkbench(Workbench, Component[McpWorkbenchConfig]):
|
||||
error_message = self._format_errors(e)
|
||||
is_error = True
|
||||
result_parts = [TextResultContent(content=error_message)]
|
||||
return ToolResult(name=name, result=result_parts, is_error=is_error)
|
||||
return ToolResult(name=name, result=result_parts, is_error=is_error) # Return the requested name
|
||||
|
||||
def _format_errors(self, error: Exception) -> str:
|
||||
"""Recursively format errors into a string."""
|
||||
@ -285,18 +355,21 @@ class McpWorkbench(Workbench, Component[McpWorkbenchConfig]):
|
||||
pass
|
||||
|
||||
def _to_config(self) -> McpWorkbenchConfig:
|
||||
return McpWorkbenchConfig(server_params=self._server_params)
|
||||
return McpWorkbenchConfig(server_params=self._server_params, tool_overrides=self._tool_overrides)
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config: McpWorkbenchConfig) -> Self:
|
||||
return cls(server_params=config.server_params)
|
||||
return cls(server_params=config.server_params, tool_overrides=config.tool_overrides)
|
||||
|
||||
def __del__(self) -> None:
|
||||
# Ensure the actor is stopped when the workbench is deleted
|
||||
if self._actor and self._actor_loop:
|
||||
loop = self._actor_loop
|
||||
if loop.is_running() and not loop.is_closed():
|
||||
loop.call_soon_threadsafe(lambda: asyncio.create_task(self.stop()))
|
||||
# Use getattr to safely handle cases where attributes may not be set (e.g., if __init__ failed)
|
||||
actor = getattr(self, "_actor", None)
|
||||
actor_loop = getattr(self, "_actor_loop", None)
|
||||
|
||||
if actor and actor_loop:
|
||||
if actor_loop.is_running() and not actor_loop.is_closed():
|
||||
actor_loop.call_soon_threadsafe(lambda: asyncio.create_task(self.stop()))
|
||||
else:
|
||||
msg = "Cannot safely stop actor at [McpWorkbench.__del__]: loop is closed or not running"
|
||||
warnings.warn(msg, RuntimeWarning, stacklevel=2)
|
||||
|
||||
@ -0,0 +1,298 @@
|
||||
import asyncio
|
||||
from typing import Any, Dict
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
from autogen_core.tools import ToolOverride
|
||||
from autogen_ext.tools.mcp import McpWorkbench, StdioServerParams
|
||||
from mcp import Tool
|
||||
from mcp.types import ListToolsResult
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_mcp_tools() -> list[Tool]:
|
||||
"""Create sample MCP tools for testing."""
|
||||
return [
|
||||
Tool(
|
||||
name="fetch",
|
||||
description="Fetches content from a URL",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {"url": {"type": "string"}},
|
||||
"required": ["url"],
|
||||
},
|
||||
),
|
||||
Tool(
|
||||
name="search",
|
||||
description="Searches for information",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {"query": {"type": "string"}},
|
||||
"required": ["query"],
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_mcp_actor() -> AsyncMock:
|
||||
"""Mock MCP session actor."""
|
||||
actor = AsyncMock()
|
||||
return actor
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_server_params() -> StdioServerParams:
|
||||
"""Sample server parameters for testing."""
|
||||
return StdioServerParams(command="echo", args=["test"])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mcp_workbench_with_tool_overrides(
|
||||
sample_mcp_tools: list[Tool], mock_mcp_actor: AsyncMock, sample_server_params: StdioServerParams
|
||||
) -> None:
|
||||
"""Test McpWorkbench with tool name and description overrides."""
|
||||
|
||||
# Define tool overrides
|
||||
overrides: Dict[str, ToolOverride] = {
|
||||
"fetch": ToolOverride(name="web_fetch", description="Enhanced web fetching tool"),
|
||||
"search": ToolOverride(description="Advanced search functionality"), # Only override description
|
||||
}
|
||||
|
||||
# Create workbench with overrides
|
||||
workbench = McpWorkbench(server_params=sample_server_params, tool_overrides=overrides)
|
||||
workbench._actor = mock_mcp_actor # type: ignore[reportPrivateUsage]
|
||||
|
||||
# Mock list_tools response
|
||||
list_tools_result = ListToolsResult(tools=sample_mcp_tools)
|
||||
|
||||
# The actor.call() should return a Future that when awaited returns the list_tools_result
|
||||
future_result: asyncio.Future[ListToolsResult] = asyncio.Future()
|
||||
future_result.set_result(list_tools_result)
|
||||
mock_mcp_actor.call.return_value = future_result
|
||||
|
||||
try:
|
||||
# List tools and verify overrides are applied
|
||||
tools = await workbench.list_tools()
|
||||
assert len(tools) == 2
|
||||
|
||||
# Check first tool has name and description overridden
|
||||
assert tools[0].get("name") == "web_fetch"
|
||||
assert tools[0].get("description") == "Enhanced web fetching tool"
|
||||
|
||||
# Check second tool has only description overridden
|
||||
assert tools[1].get("name") == "search" # Original name
|
||||
assert tools[1].get("description") == "Advanced search functionality" # Overridden description
|
||||
|
||||
# Verify actor was called correctly
|
||||
mock_mcp_actor.call.assert_called_with("list_tools", None)
|
||||
|
||||
finally:
|
||||
workbench._actor = None # type: ignore[reportPrivateUsage]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mcp_workbench_call_tool_with_overrides(
|
||||
sample_mcp_tools: list[Tool], mock_mcp_actor: AsyncMock, sample_server_params: StdioServerParams
|
||||
) -> None:
|
||||
"""Test calling tools with override names maps back to original names."""
|
||||
|
||||
overrides: Dict[str, ToolOverride] = {
|
||||
"fetch": ToolOverride(name="web_fetch", description="Enhanced web fetching tool")
|
||||
}
|
||||
|
||||
workbench = McpWorkbench(server_params=sample_server_params, tool_overrides=overrides)
|
||||
workbench._actor = mock_mcp_actor # type: ignore[reportPrivateUsage]
|
||||
|
||||
# Mock successful tool call response
|
||||
from mcp.types import CallToolResult, TextContent
|
||||
|
||||
mock_result = CallToolResult(content=[TextContent(text="Mock response", type="text")], isError=False)
|
||||
|
||||
# Create futures for each call
|
||||
def mock_call_side_effect(
|
||||
method: str, args: dict[str, Any] | None = None
|
||||
) -> asyncio.Future[ListToolsResult | CallToolResult]:
|
||||
future_result: asyncio.Future[ListToolsResult | CallToolResult] = asyncio.Future()
|
||||
if method == "list_tools":
|
||||
future_result.set_result(ListToolsResult(tools=sample_mcp_tools))
|
||||
elif method == "call_tool":
|
||||
future_result.set_result(mock_result)
|
||||
else:
|
||||
future_result.set_exception(ValueError(f"Unexpected method: {method}"))
|
||||
return future_result
|
||||
|
||||
mock_mcp_actor.call.side_effect = mock_call_side_effect
|
||||
|
||||
try:
|
||||
# Call tool using override name
|
||||
result = await workbench.call_tool("web_fetch", {"url": "https://example.com"})
|
||||
|
||||
# Verify the result
|
||||
assert result.name == "web_fetch" # Should return the override name
|
||||
assert result.result[0].content == "Mock response"
|
||||
assert result.is_error is False
|
||||
|
||||
# Verify the actor was called with the original tool name
|
||||
call_args = mock_mcp_actor.call.call_args_list[-1]
|
||||
assert call_args[0][0] == "call_tool"
|
||||
assert call_args[0][1]["name"] == "fetch" # Original name should be used
|
||||
assert call_args[0][1]["kargs"] == {"url": "https://example.com"}
|
||||
|
||||
finally:
|
||||
workbench._actor = None # type: ignore[reportPrivateUsage]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mcp_workbench_without_overrides(
|
||||
sample_mcp_tools: list[Tool], mock_mcp_actor: AsyncMock, sample_server_params: StdioServerParams
|
||||
) -> None:
|
||||
"""Test McpWorkbench without overrides (original behavior)."""
|
||||
|
||||
workbench = McpWorkbench(server_params=sample_server_params)
|
||||
workbench._actor = mock_mcp_actor # type: ignore[reportPrivateUsage]
|
||||
|
||||
# Mock list_tools response
|
||||
list_tools_result = ListToolsResult(tools=sample_mcp_tools)
|
||||
future_result: asyncio.Future[ListToolsResult] = asyncio.Future()
|
||||
future_result.set_result(list_tools_result)
|
||||
mock_mcp_actor.call.return_value = future_result
|
||||
|
||||
try:
|
||||
tools = await workbench.list_tools()
|
||||
assert len(tools) == 2
|
||||
|
||||
# Verify original names and descriptions are preserved
|
||||
assert tools[0].get("name") == "fetch"
|
||||
assert tools[0].get("description") == "Fetches content from a URL"
|
||||
assert tools[1].get("name") == "search"
|
||||
assert tools[1].get("description") == "Searches for information"
|
||||
|
||||
finally:
|
||||
workbench._actor = None # type: ignore[reportPrivateUsage]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mcp_workbench_serialization_with_overrides(sample_server_params: StdioServerParams) -> None:
|
||||
"""Test that McpWorkbench can be serialized and deserialized with overrides."""
|
||||
|
||||
overrides: Dict[str, ToolOverride] = {
|
||||
"fetch": ToolOverride(name="web_fetch", description="Enhanced web fetching tool")
|
||||
}
|
||||
|
||||
# Create workbench with overrides
|
||||
workbench = McpWorkbench(server_params=sample_server_params, tool_overrides=overrides)
|
||||
|
||||
# Save configuration
|
||||
config = workbench.dump_component()
|
||||
assert "tool_overrides" in config.config
|
||||
assert "fetch" in config.config["tool_overrides"]
|
||||
assert config.config["tool_overrides"]["fetch"]["name"] == "web_fetch"
|
||||
assert config.config["tool_overrides"]["fetch"]["description"] == "Enhanced web fetching tool"
|
||||
|
||||
# Load workbench from configuration
|
||||
new_workbench = McpWorkbench.load_component(config)
|
||||
assert len(new_workbench._tool_overrides) == 1 # type: ignore[reportPrivateUsage]
|
||||
assert new_workbench._tool_overrides["fetch"].name == "web_fetch" # type: ignore[reportPrivateUsage]
|
||||
assert new_workbench._tool_overrides["fetch"].description == "Enhanced web fetching tool" # type: ignore[reportPrivateUsage]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mcp_workbench_partial_overrides(
|
||||
sample_mcp_tools: list[Tool], mock_mcp_actor: AsyncMock, sample_server_params: StdioServerParams
|
||||
) -> None:
|
||||
"""Test McpWorkbench with partial overrides (name only, description only)."""
|
||||
|
||||
overrides: Dict[str, ToolOverride] = {
|
||||
"fetch": ToolOverride(name="web_fetch"), # Only name override
|
||||
"search": ToolOverride(description="Advanced search"), # Only description override
|
||||
}
|
||||
|
||||
workbench = McpWorkbench(server_params=sample_server_params, tool_overrides=overrides)
|
||||
workbench._actor = mock_mcp_actor # type: ignore[reportPrivateUsage]
|
||||
|
||||
# Mock list_tools response
|
||||
list_tools_result = ListToolsResult(tools=sample_mcp_tools)
|
||||
future_result: asyncio.Future[ListToolsResult] = asyncio.Future()
|
||||
future_result.set_result(list_tools_result)
|
||||
mock_mcp_actor.call.return_value = future_result
|
||||
|
||||
try:
|
||||
tools = await workbench.list_tools()
|
||||
|
||||
# fetch: name overridden, description unchanged
|
||||
assert tools[0].get("name") == "web_fetch"
|
||||
assert tools[0].get("description") == "Fetches content from a URL" # Original description
|
||||
|
||||
# search: name unchanged, description overridden
|
||||
assert tools[1].get("name") == "search" # Original name
|
||||
assert tools[1].get("description") == "Advanced search" # Overridden description
|
||||
|
||||
finally:
|
||||
workbench._actor = None # type: ignore[reportPrivateUsage]
|
||||
|
||||
|
||||
def test_mcp_tool_override_model() -> None:
|
||||
"""Test ToolOverride model functionality for MCP."""
|
||||
|
||||
# Test with both fields
|
||||
override1 = ToolOverride(name="new_name", description="new_desc")
|
||||
assert override1.name == "new_name"
|
||||
assert override1.description == "new_desc"
|
||||
|
||||
# Test with only name
|
||||
override2 = ToolOverride(name="new_name")
|
||||
assert override2.name == "new_name"
|
||||
assert override2.description is None
|
||||
|
||||
# Test with only description
|
||||
override3 = ToolOverride(description="new_desc")
|
||||
assert override3.name is None
|
||||
assert override3.description == "new_desc"
|
||||
|
||||
# Test empty
|
||||
override4 = ToolOverride()
|
||||
assert override4.name is None
|
||||
assert override4.description is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mcp_workbench_override_name_to_original_mapping(sample_server_params: StdioServerParams) -> None:
|
||||
"""Test that the reverse mapping from override names to original names works correctly."""
|
||||
|
||||
overrides: Dict[str, ToolOverride] = {
|
||||
"original1": ToolOverride(name="override1"),
|
||||
"original2": ToolOverride(name="override2"),
|
||||
"original3": ToolOverride(description="only description override"), # No name change, only description
|
||||
}
|
||||
|
||||
workbench = McpWorkbench(server_params=sample_server_params, tool_overrides=overrides)
|
||||
|
||||
# Check reverse mapping is built correctly
|
||||
assert workbench._override_name_to_original["override1"] == "original1" # type: ignore[reportPrivateUsage]
|
||||
assert workbench._override_name_to_original["override2"] == "original2" # type: ignore[reportPrivateUsage]
|
||||
assert "original3" not in workbench._override_name_to_original # type: ignore[reportPrivateUsage]
|
||||
assert len(workbench._override_name_to_original) == 2 # type: ignore[reportPrivateUsage]
|
||||
|
||||
|
||||
def test_mcp_workbench_conflict_detection() -> None:
|
||||
"""Test that McpWorkbench detects conflicts in tool override names."""
|
||||
|
||||
server_params = StdioServerParams(command="echo", args=["test"])
|
||||
|
||||
# Test 1: Valid overrides - should work
|
||||
overrides_valid: Dict[str, ToolOverride] = {
|
||||
"fetch": ToolOverride(name="web_fetch"),
|
||||
"search": ToolOverride(name="advanced_search"),
|
||||
}
|
||||
workbench_valid = McpWorkbench(server_params=server_params, tool_overrides=overrides_valid)
|
||||
assert workbench_valid._override_name_to_original["web_fetch"] == "fetch" # type: ignore[reportPrivateUsage]
|
||||
assert workbench_valid._override_name_to_original["advanced_search"] == "search" # type: ignore[reportPrivateUsage]
|
||||
|
||||
# Test 2: Duplicate override names - should fail
|
||||
overrides_duplicate: Dict[str, ToolOverride] = {
|
||||
"fetch": ToolOverride(name="same_name"),
|
||||
"search": ToolOverride(name="same_name"), # Duplicate
|
||||
}
|
||||
with pytest.raises(ValueError):
|
||||
McpWorkbench(server_params=server_params, tool_overrides=overrides_duplicate)
|
||||
Loading…
x
Reference in New Issue
Block a user