Add tool name and description override functionality to Workbench implementations (#6690)

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: ekzhu <320302+ekzhu@users.noreply.github.com>
Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
This commit is contained in:
Copilot 2025-07-06 13:39:05 -07:00 committed by GitHub
parent e10767421f
commit 0bd99ee516
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 767 additions and 25 deletions

View File

@ -1,4 +1,13 @@
from ._base import BaseStreamTool, BaseTool, BaseToolWithState, ParametersSchema, StreamTool, Tool, ToolSchema
from ._base import (
BaseStreamTool,
BaseTool,
BaseToolWithState,
ParametersSchema,
StreamTool,
Tool,
ToolOverride,
ToolSchema,
)
from ._function_tool import FunctionTool
from ._static_workbench import StaticStreamWorkbench, StaticWorkbench
from ._workbench import ImageResultContent, TextResultContent, ToolResult, Workbench
@ -18,4 +27,5 @@ __all__ = [
"ImageResultContent",
"StaticWorkbench",
"StaticStreamWorkbench",
"ToolOverride",
]

View File

@ -2,7 +2,19 @@ import json
import logging
from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import Any, AsyncGenerator, Dict, Generic, Mapping, Protocol, Type, TypeVar, cast, runtime_checkable
from typing import (
Any,
AsyncGenerator,
Dict,
Generic,
Mapping,
Optional,
Protocol,
Type,
TypeVar,
cast,
runtime_checkable,
)
import jsonref
from pydantic import BaseModel
@ -33,6 +45,13 @@ class ToolSchema(TypedDict):
strict: NotRequired[bool]
class ToolOverride(BaseModel):
"""Override configuration for a tool's name and/or description."""
name: Optional[str] = None
description: Optional[str] = None
@runtime_checkable
class Tool(Protocol):
@property

View File

@ -1,18 +1,19 @@
import asyncio
import builtins
from typing import Any, AsyncGenerator, Dict, List, Literal, Mapping
from typing import Any, AsyncGenerator, Dict, List, Literal, Mapping, Optional
from pydantic import BaseModel
from pydantic import BaseModel, Field
from typing_extensions import Self
from .._cancellation_token import CancellationToken
from .._component_config import Component, ComponentModel
from ._base import BaseTool, StreamTool, ToolSchema
from ._base import BaseTool, StreamTool, ToolOverride, ToolSchema
from ._workbench import StreamWorkbench, TextResultContent, ToolResult, Workbench
class StaticWorkbenchConfig(BaseModel):
tools: List[ComponentModel] = []
tool_overrides: Dict[str, ToolOverride] = Field(default_factory=dict)
class StateicWorkbenchState(BaseModel):
@ -28,16 +29,67 @@ class StaticWorkbench(Workbench, Component[StaticWorkbenchConfig]):
Args:
tools (List[BaseTool[Any, Any]]): A list of tools to be included in the workbench.
The tools should be subclasses of :class:`~autogen_core.tools.BaseTool`.
tool_overrides (Optional[Dict[str, ToolOverride]]): Optional mapping of original tool
names to override configurations for name and/or description. This allows
customizing how tools appear to consumers while maintaining the underlying
tool functionality.
"""
component_provider_override = "autogen_core.tools.StaticWorkbench"
component_config_schema = StaticWorkbenchConfig
def __init__(self, tools: List[BaseTool[Any, Any]]) -> None:
def __init__(
self, tools: List[BaseTool[Any, Any]], tool_overrides: Optional[Dict[str, ToolOverride]] = None
) -> None:
self._tools = tools
self._tool_overrides = tool_overrides or {}
# Build reverse mapping from override names to original names for call_tool
self._override_name_to_original: Dict[str, str] = {}
existing_tool_names = {tool.name for tool in self._tools}
for original_name, override in self._tool_overrides.items():
if override.name and override.name != original_name:
# Check for conflicts with existing tool names
if override.name in existing_tool_names and override.name != original_name:
raise ValueError(
f"Tool override name '{override.name}' conflicts with existing tool name. "
f"Override names must not conflict with any tool names."
)
# Check for conflicts with other override names
if override.name in self._override_name_to_original:
existing_original = self._override_name_to_original[override.name]
raise ValueError(
f"Tool override name '{override.name}' is used by multiple tools: "
f"'{existing_original}' and '{original_name}'. Override names must be unique."
)
self._override_name_to_original[override.name] = original_name
async def list_tools(self) -> List[ToolSchema]:
return [tool.schema for tool in self._tools]
result_schemas: List[ToolSchema] = []
for tool in self._tools:
original_schema = tool.schema
# Apply overrides if they exist for this tool
if tool.name in self._tool_overrides:
override = self._tool_overrides[tool.name]
# Create a new ToolSchema with overrides applied
schema: ToolSchema = {
"name": override.name if override.name is not None else original_schema["name"],
"description": override.description
if override.description is not None
else original_schema.get("description", ""),
}
# Copy optional fields
if "parameters" in original_schema:
schema["parameters"] = original_schema["parameters"]
if "strict" in original_schema:
schema["strict"] = original_schema["strict"]
else:
schema = original_schema
result_schemas.append(schema)
return result_schemas
async def call_tool(
self,
@ -46,10 +98,13 @@ class StaticWorkbench(Workbench, Component[StaticWorkbenchConfig]):
cancellation_token: CancellationToken | None = None,
call_id: str | None = None,
) -> ToolResult:
tool = next((tool for tool in self._tools if tool.name == name), None)
# Check if the name is an override name and map it back to the original
original_name = self._override_name_to_original.get(name, name)
tool = next((tool for tool in self._tools if tool.name == original_name), None)
if tool is None:
return ToolResult(
name=name,
name=name, # Return the requested name (which might be overridden)
result=[TextResultContent(content=f"Tool {name} not found.")],
is_error=True,
)
@ -66,7 +121,7 @@ class StaticWorkbench(Workbench, Component[StaticWorkbenchConfig]):
except Exception as e:
result_str = self._format_errors(e)
is_error = True
return ToolResult(name=tool.name, result=[TextResultContent(content=result_str)], is_error=is_error)
return ToolResult(name=name, result=[TextResultContent(content=result_str)], is_error=is_error)
async def start(self) -> None:
return None
@ -90,11 +145,13 @@ class StaticWorkbench(Workbench, Component[StaticWorkbenchConfig]):
await tool.load_state_json(parsed_state.tools[tool.name])
def _to_config(self) -> StaticWorkbenchConfig:
return StaticWorkbenchConfig(tools=[tool.dump_component() for tool in self._tools])
return StaticWorkbenchConfig(
tools=[tool.dump_component() for tool in self._tools], tool_overrides=self._tool_overrides
)
@classmethod
def _from_config(cls, config: StaticWorkbenchConfig) -> Self:
return cls(tools=[BaseTool.load_component(tool) for tool in config.tools])
return cls(tools=[BaseTool.load_component(tool) for tool in config.tools], tool_overrides=config.tool_overrides)
def _format_errors(self, error: Exception) -> str:
"""Recursively format errors into a string."""

View File

@ -0,0 +1,285 @@
from typing import Annotated, Dict
import pytest
from autogen_core.code_executor import ImportFromModule
from autogen_core.tools import FunctionTool, StaticWorkbench, ToolOverride, Workbench
@pytest.mark.asyncio
async def test_static_workbench_with_tool_overrides() -> None:
"""Test StaticWorkbench with tool name and description overrides."""
def test_tool_func_1(x: Annotated[int, "The number to double."]) -> int:
return x * 2
def test_tool_func_2(a: Annotated[int, "First number"], b: Annotated[int, "Second number"]) -> int:
return a + b
test_tool_1 = FunctionTool(
test_tool_func_1,
name="double",
description="A test tool that doubles a number.",
global_imports=[ImportFromModule(module="typing_extensions", imports=["Annotated"])],
)
test_tool_2 = FunctionTool(
test_tool_func_2,
name="add",
description="A test tool that adds two numbers.",
global_imports=[ImportFromModule(module="typing_extensions", imports=["Annotated"])],
)
# Define tool overrides
overrides: Dict[str, ToolOverride] = {
"double": ToolOverride(name="multiply_by_two", description="Multiplies a number by 2"),
"add": ToolOverride(description="Performs addition of two integers"), # Only override description
}
# Create a StaticWorkbench instance with tool overrides
async with StaticWorkbench(tools=[test_tool_1, test_tool_2], tool_overrides=overrides) as workbench:
# List tools and verify overrides are applied
tools = await workbench.list_tools()
assert len(tools) == 2
# Check first tool has name and description overridden
assert tools[0]["name"] == "multiply_by_two"
assert tools[0].get("description") == "Multiplies a number by 2"
assert tools[0].get("parameters") == {
"type": "object",
"properties": {"x": {"type": "integer", "title": "X", "description": "The number to double."}},
"required": ["x"],
"additionalProperties": False,
}
# Check second tool has only description overridden
assert tools[1]["name"] == "add" # Original name
assert tools[1].get("description") == "Performs addition of two integers" # Overridden description
assert tools[1].get("parameters") == {
"type": "object",
"properties": {
"a": {"type": "integer", "title": "A", "description": "First number"},
"b": {"type": "integer", "title": "B", "description": "Second number"},
},
"required": ["a", "b"],
"additionalProperties": False,
}
# Call tools using override names
result_1 = await workbench.call_tool("multiply_by_two", {"x": 5})
assert result_1.name == "multiply_by_two" # Should return the override name
assert result_1.result[0].type == "TextResultContent"
assert result_1.result[0].content == "10"
assert result_1.to_text() == "10"
assert result_1.is_error is False
# Call tool using original name (should still work for description-only override)
result_2 = await workbench.call_tool("add", {"a": 3, "b": 7})
assert result_2.name == "add"
assert result_2.result[0].type == "TextResultContent"
assert result_2.result[0].content == "10"
assert result_2.to_text() == "10"
assert result_2.is_error is False
# Test calling non-existent tool
result_3 = await workbench.call_tool("nonexistent", {"x": 5})
assert result_3.name == "nonexistent"
assert result_3.is_error is True
assert result_3.result[0].type == "TextResultContent"
assert "Tool nonexistent not found" in result_3.result[0].content
@pytest.mark.asyncio
async def test_static_workbench_without_overrides() -> None:
"""Test StaticWorkbench without overrides (original behavior)."""
def test_tool_func(x: Annotated[int, "The number to double."]) -> int:
return x * 2
test_tool = FunctionTool(
test_tool_func,
name="double",
description="A test tool that doubles a number.",
global_imports=[ImportFromModule(module="typing_extensions", imports=["Annotated"])],
)
# Create workbench without overrides
async with StaticWorkbench(tools=[test_tool]) as workbench:
tools = await workbench.list_tools()
assert len(tools) == 1
assert tools[0].get("name") == "double"
assert tools[0].get("description") == "A test tool that doubles a number."
@pytest.mark.asyncio
async def test_static_workbench_serialization_with_overrides() -> None:
"""Test that StaticWorkbench can be serialized and deserialized with overrides."""
def test_tool_func(x: Annotated[int, "The number to double."]) -> int:
return x * 2
test_tool = FunctionTool(
test_tool_func,
name="double",
description="A test tool that doubles a number.",
global_imports=[ImportFromModule(module="typing_extensions", imports=["Annotated"])],
)
overrides: Dict[str, ToolOverride] = {
"double": ToolOverride(name="multiply_by_two", description="Multiplies a number by 2")
}
# Create workbench with overrides
workbench = StaticWorkbench(tools=[test_tool], tool_overrides=overrides)
# Save configuration
config = workbench.dump_component()
assert "tool_overrides" in config.config
# Load workbench from configuration
async with Workbench.load_component(config) as new_workbench:
tools = await new_workbench.list_tools()
assert len(tools) == 1
assert tools[0]["name"] == "multiply_by_two"
assert tools[0].get("description") == "Multiplies a number by 2"
# Test calling tool with override name
result = await new_workbench.call_tool("multiply_by_two", {"x": 5})
assert result.name == "multiply_by_two"
assert result.result[0].content == "10"
assert result.is_error is False
@pytest.mark.asyncio
async def test_static_workbench_partial_overrides() -> None:
"""Test StaticWorkbench with partial overrides (name only, description only)."""
def tool1_func(x: Annotated[int, "Number"]) -> int:
return x
def tool2_func(x: Annotated[int, "Number"]) -> int:
return x
tool1 = FunctionTool(
tool1_func,
name="tool1",
description="Original description 1",
global_imports=[ImportFromModule(module="typing_extensions", imports=["Annotated"])],
)
tool2 = FunctionTool(
tool2_func,
name="tool2",
description="Original description 2",
global_imports=[ImportFromModule(module="typing_extensions", imports=["Annotated"])],
)
overrides: Dict[str, ToolOverride] = {
"tool1": ToolOverride(name="renamed_tool1"), # Only name override
"tool2": ToolOverride(description="New description 2"), # Only description override
}
async with StaticWorkbench(tools=[tool1, tool2], tool_overrides=overrides) as workbench:
tools = await workbench.list_tools()
# tool1: name overridden, description unchanged
assert tools[0].get("name") == "renamed_tool1"
assert tools[0].get("description") == "Original description 1"
# tool2: name unchanged, description overridden
assert tools[1].get("name") == "tool2"
assert tools[1].get("description") == "New description 2"
# Test calling with override name
result1 = await workbench.call_tool("renamed_tool1", {"x": 42})
assert result1.name == "renamed_tool1"
assert result1.result[0].content == "42"
# Test calling with original name
result2 = await workbench.call_tool("tool2", {"x": 42})
assert result2.name == "tool2"
assert result2.result[0].content == "42"
def test_tool_override_model() -> None:
"""Test ToolOverride model functionality."""
# Test with both fields
override1 = ToolOverride(name="new_name", description="new_desc")
assert override1.name == "new_name"
assert override1.description == "new_desc"
# Test with only name
override2 = ToolOverride(name="new_name")
assert override2.name == "new_name"
assert override2.description is None
# Test with only description
override3 = ToolOverride(description="new_desc")
assert override3.name is None
assert override3.description == "new_desc"
# Test empty
override4 = ToolOverride()
assert override4.name is None
assert override4.description is None
def test_static_workbench_conflict_detection() -> None:
"""Test that StaticWorkbench detects conflicts in tool override names."""
def test_tool_func_1(x: Annotated[int, "Number"]) -> int:
return x
def test_tool_func_2(x: Annotated[int, "Number"]) -> int:
return x
def test_tool_func_3(x: Annotated[int, "Number"]) -> int:
return x
tool1 = FunctionTool(
test_tool_func_1,
name="tool1",
description="Tool 1",
global_imports=[ImportFromModule(module="typing_extensions", imports=["Annotated"])],
)
tool2 = FunctionTool(
test_tool_func_2,
name="tool2",
description="Tool 2",
global_imports=[ImportFromModule(module="typing_extensions", imports=["Annotated"])],
)
tool3 = FunctionTool(
test_tool_func_3,
name="tool3",
description="Tool 3",
global_imports=[ImportFromModule(module="typing_extensions", imports=["Annotated"])],
)
# Test 1: Valid overrides - should work
overrides_valid: Dict[str, ToolOverride] = {
"tool1": ToolOverride(name="renamed_tool1"),
"tool2": ToolOverride(name="renamed_tool2"),
}
workbench_valid = StaticWorkbench(tools=[tool1, tool2, tool3], tool_overrides=overrides_valid)
assert "renamed_tool1" in workbench_valid._override_name_to_original # type: ignore[reportPrivateUsage]
assert "renamed_tool2" in workbench_valid._override_name_to_original # type: ignore[reportPrivateUsage]
# Test 2: Conflict with existing tool name - should fail
overrides_conflict: Dict[str, ToolOverride] = {
"tool1": ToolOverride(name="tool2") # tool2 already exists
}
with pytest.raises(ValueError):
StaticWorkbench(tools=[tool1, tool2, tool3], tool_overrides=overrides_conflict)
# Test 3: Duplicate override names - should fail
overrides_duplicate: Dict[str, ToolOverride] = {
"tool1": ToolOverride(name="same_name"),
"tool2": ToolOverride(name="same_name"), # Duplicate
}
with pytest.raises(ValueError):
StaticWorkbench(tools=[tool1, tool2, tool3], tool_overrides=overrides_duplicate)
# Test 4: Self-renaming - should work but not add to reverse mapping
overrides_self: Dict[str, ToolOverride] = {
"tool1": ToolOverride(name="tool1") # Renaming to itself
}
workbench_self = StaticWorkbench(tools=[tool1, tool2, tool3], tool_overrides=overrides_self)
assert "tool1" not in workbench_self._override_name_to_original # type: ignore[reportPrivateUsage]

View File

@ -1,19 +1,20 @@
import asyncio
import builtins
import warnings
from typing import Any, List, Literal, Mapping
from typing import Any, Dict, List, Literal, Mapping, Optional
from autogen_core import CancellationToken, Component, Image, trace_tool_span
from autogen_core.tools import (
ImageResultContent,
ParametersSchema,
TextResultContent,
ToolOverride,
ToolResult,
ToolSchema,
Workbench,
)
from mcp.types import CallToolResult, EmbeddedResource, ImageContent, ListToolsResult, TextContent
from pydantic import BaseModel
from pydantic import BaseModel, Field
from typing_extensions import Self
from ._actor import McpSessionActor
@ -22,6 +23,7 @@ from ._config import McpServerParams, SseServerParams, StdioServerParams, Stream
class McpWorkbenchConfig(BaseModel):
server_params: McpServerParams
tool_overrides: Dict[str, ToolOverride] = Field(default_factory=dict)
class McpWorkbenchState(BaseModel):
@ -36,6 +38,10 @@ class McpWorkbench(Workbench, Component[McpWorkbenchConfig]):
Args:
server_params (McpServerParams): The parameters to connect to the MCP server.
This can be either a :class:`StdioServerParams` or :class:`SseServerParams`.
tool_overrides (Optional[Dict[str, ToolOverride]]): Optional mapping of original tool
names to override configurations for name and/or description. This allows
customizing how server tools appear to consumers while maintaining the underlying
tool functionality.
Examples:
@ -65,6 +71,38 @@ class McpWorkbench(Workbench, Component[McpWorkbenchConfig]):
asyncio.run(main())
Example of using tool overrides:
.. code-block:: python
import asyncio
from autogen_ext.tools.mcp import McpWorkbench, StdioServerParams
from autogen_core.tools import ToolOverride
async def main() -> None:
params = StdioServerParams(
command="uvx",
args=["mcp-server-fetch"],
read_timeout_seconds=60,
)
# Override the fetch tool's name and description
overrides = {
"fetch": ToolOverride(name="web_fetch", description="Enhanced web fetching tool with better error handling")
}
async with McpWorkbench(server_params=params, tool_overrides=overrides) as workbench:
tools = await workbench.list_tools()
# The tool will now appear as "web_fetch" with the new description
print(tools)
# Call the overridden tool
result = await workbench.call_tool("web_fetch", {"url": "https://github.com/"})
print(result)
asyncio.run(main())
Example of using the workbench with the `GitHub MCP Server <https://github.com/github/github-mcp-server>`_:
.. code-block:: python
@ -149,8 +187,26 @@ class McpWorkbench(Workbench, Component[McpWorkbenchConfig]):
component_provider_override = "autogen_ext.tools.mcp.McpWorkbench"
component_config_schema = McpWorkbenchConfig
def __init__(self, server_params: McpServerParams) -> None:
def __init__(
self, server_params: McpServerParams, tool_overrides: Optional[Dict[str, ToolOverride]] = None
) -> None:
self._server_params = server_params
self._tool_overrides = tool_overrides or {}
# Build reverse mapping from override names to original names for call_tool
self._override_name_to_original: Dict[str, str] = {}
for original_name, override in self._tool_overrides.items():
override_name = override.name
if override_name and override_name != original_name:
# Check for conflicts with other override names
if override_name in self._override_name_to_original:
existing_original = self._override_name_to_original[override_name]
raise ValueError(
f"Tool override name '{override_name}' is used by multiple tools: "
f"'{existing_original}' and '{original_name}'. Override names must be unique."
)
self._override_name_to_original[override_name] = original_name
# self._session: ClientSession | None = None
self._actor: McpSessionActor | None = None
self._actor_loop: asyncio.AbstractEventLoop | None = None
@ -175,8 +231,18 @@ class McpWorkbench(Workbench, Component[McpWorkbenchConfig]):
), f"list_tools must return a CallToolResult, instead of : {str(type(list_tool_result))}"
schema: List[ToolSchema] = []
for tool in list_tool_result.tools:
name = tool.name
original_name = tool.name
name = original_name
description = tool.description or ""
# Apply overrides if they exist for this tool
if original_name in self._tool_overrides:
override = self._tool_overrides[original_name]
if override.name is not None:
name = override.name
if override.description is not None:
description = override.description
parameters = ParametersSchema(
type="object",
properties=tool.inputSchema.get("properties", {}),
@ -208,12 +274,16 @@ class McpWorkbench(Workbench, Component[McpWorkbenchConfig]):
cancellation_token = CancellationToken()
if not arguments:
arguments = {}
# Check if the name is an override name and map it back to the original
original_name = self._override_name_to_original.get(name, name)
with trace_tool_span(
tool_name=name,
tool_name=name, # Use the requested name for tracing
tool_call_id=call_id,
):
try:
result_future = await self._actor.call("call_tool", {"name": name, "kargs": arguments})
result_future = await self._actor.call("call_tool", {"name": original_name, "kargs": arguments})
cancellation_token.link_future(result_future)
result = await result_future
assert isinstance(
@ -236,7 +306,7 @@ class McpWorkbench(Workbench, Component[McpWorkbenchConfig]):
error_message = self._format_errors(e)
is_error = True
result_parts = [TextResultContent(content=error_message)]
return ToolResult(name=name, result=result_parts, is_error=is_error)
return ToolResult(name=name, result=result_parts, is_error=is_error) # Return the requested name
def _format_errors(self, error: Exception) -> str:
"""Recursively format errors into a string."""
@ -285,18 +355,21 @@ class McpWorkbench(Workbench, Component[McpWorkbenchConfig]):
pass
def _to_config(self) -> McpWorkbenchConfig:
return McpWorkbenchConfig(server_params=self._server_params)
return McpWorkbenchConfig(server_params=self._server_params, tool_overrides=self._tool_overrides)
@classmethod
def _from_config(cls, config: McpWorkbenchConfig) -> Self:
return cls(server_params=config.server_params)
return cls(server_params=config.server_params, tool_overrides=config.tool_overrides)
def __del__(self) -> None:
# Ensure the actor is stopped when the workbench is deleted
if self._actor and self._actor_loop:
loop = self._actor_loop
if loop.is_running() and not loop.is_closed():
loop.call_soon_threadsafe(lambda: asyncio.create_task(self.stop()))
# Use getattr to safely handle cases where attributes may not be set (e.g., if __init__ failed)
actor = getattr(self, "_actor", None)
actor_loop = getattr(self, "_actor_loop", None)
if actor and actor_loop:
if actor_loop.is_running() and not actor_loop.is_closed():
actor_loop.call_soon_threadsafe(lambda: asyncio.create_task(self.stop()))
else:
msg = "Cannot safely stop actor at [McpWorkbench.__del__]: loop is closed or not running"
warnings.warn(msg, RuntimeWarning, stacklevel=2)

View File

@ -0,0 +1,298 @@
import asyncio
from typing import Any, Dict
from unittest.mock import AsyncMock
import pytest
from autogen_core.tools import ToolOverride
from autogen_ext.tools.mcp import McpWorkbench, StdioServerParams
from mcp import Tool
from mcp.types import ListToolsResult
@pytest.fixture
def sample_mcp_tools() -> list[Tool]:
"""Create sample MCP tools for testing."""
return [
Tool(
name="fetch",
description="Fetches content from a URL",
inputSchema={
"type": "object",
"properties": {"url": {"type": "string"}},
"required": ["url"],
},
),
Tool(
name="search",
description="Searches for information",
inputSchema={
"type": "object",
"properties": {"query": {"type": "string"}},
"required": ["query"],
},
),
]
@pytest.fixture
def mock_mcp_actor() -> AsyncMock:
"""Mock MCP session actor."""
actor = AsyncMock()
return actor
@pytest.fixture
def sample_server_params() -> StdioServerParams:
"""Sample server parameters for testing."""
return StdioServerParams(command="echo", args=["test"])
@pytest.mark.asyncio
async def test_mcp_workbench_with_tool_overrides(
sample_mcp_tools: list[Tool], mock_mcp_actor: AsyncMock, sample_server_params: StdioServerParams
) -> None:
"""Test McpWorkbench with tool name and description overrides."""
# Define tool overrides
overrides: Dict[str, ToolOverride] = {
"fetch": ToolOverride(name="web_fetch", description="Enhanced web fetching tool"),
"search": ToolOverride(description="Advanced search functionality"), # Only override description
}
# Create workbench with overrides
workbench = McpWorkbench(server_params=sample_server_params, tool_overrides=overrides)
workbench._actor = mock_mcp_actor # type: ignore[reportPrivateUsage]
# Mock list_tools response
list_tools_result = ListToolsResult(tools=sample_mcp_tools)
# The actor.call() should return a Future that when awaited returns the list_tools_result
future_result: asyncio.Future[ListToolsResult] = asyncio.Future()
future_result.set_result(list_tools_result)
mock_mcp_actor.call.return_value = future_result
try:
# List tools and verify overrides are applied
tools = await workbench.list_tools()
assert len(tools) == 2
# Check first tool has name and description overridden
assert tools[0].get("name") == "web_fetch"
assert tools[0].get("description") == "Enhanced web fetching tool"
# Check second tool has only description overridden
assert tools[1].get("name") == "search" # Original name
assert tools[1].get("description") == "Advanced search functionality" # Overridden description
# Verify actor was called correctly
mock_mcp_actor.call.assert_called_with("list_tools", None)
finally:
workbench._actor = None # type: ignore[reportPrivateUsage]
@pytest.mark.asyncio
async def test_mcp_workbench_call_tool_with_overrides(
sample_mcp_tools: list[Tool], mock_mcp_actor: AsyncMock, sample_server_params: StdioServerParams
) -> None:
"""Test calling tools with override names maps back to original names."""
overrides: Dict[str, ToolOverride] = {
"fetch": ToolOverride(name="web_fetch", description="Enhanced web fetching tool")
}
workbench = McpWorkbench(server_params=sample_server_params, tool_overrides=overrides)
workbench._actor = mock_mcp_actor # type: ignore[reportPrivateUsage]
# Mock successful tool call response
from mcp.types import CallToolResult, TextContent
mock_result = CallToolResult(content=[TextContent(text="Mock response", type="text")], isError=False)
# Create futures for each call
def mock_call_side_effect(
method: str, args: dict[str, Any] | None = None
) -> asyncio.Future[ListToolsResult | CallToolResult]:
future_result: asyncio.Future[ListToolsResult | CallToolResult] = asyncio.Future()
if method == "list_tools":
future_result.set_result(ListToolsResult(tools=sample_mcp_tools))
elif method == "call_tool":
future_result.set_result(mock_result)
else:
future_result.set_exception(ValueError(f"Unexpected method: {method}"))
return future_result
mock_mcp_actor.call.side_effect = mock_call_side_effect
try:
# Call tool using override name
result = await workbench.call_tool("web_fetch", {"url": "https://example.com"})
# Verify the result
assert result.name == "web_fetch" # Should return the override name
assert result.result[0].content == "Mock response"
assert result.is_error is False
# Verify the actor was called with the original tool name
call_args = mock_mcp_actor.call.call_args_list[-1]
assert call_args[0][0] == "call_tool"
assert call_args[0][1]["name"] == "fetch" # Original name should be used
assert call_args[0][1]["kargs"] == {"url": "https://example.com"}
finally:
workbench._actor = None # type: ignore[reportPrivateUsage]
@pytest.mark.asyncio
async def test_mcp_workbench_without_overrides(
sample_mcp_tools: list[Tool], mock_mcp_actor: AsyncMock, sample_server_params: StdioServerParams
) -> None:
"""Test McpWorkbench without overrides (original behavior)."""
workbench = McpWorkbench(server_params=sample_server_params)
workbench._actor = mock_mcp_actor # type: ignore[reportPrivateUsage]
# Mock list_tools response
list_tools_result = ListToolsResult(tools=sample_mcp_tools)
future_result: asyncio.Future[ListToolsResult] = asyncio.Future()
future_result.set_result(list_tools_result)
mock_mcp_actor.call.return_value = future_result
try:
tools = await workbench.list_tools()
assert len(tools) == 2
# Verify original names and descriptions are preserved
assert tools[0].get("name") == "fetch"
assert tools[0].get("description") == "Fetches content from a URL"
assert tools[1].get("name") == "search"
assert tools[1].get("description") == "Searches for information"
finally:
workbench._actor = None # type: ignore[reportPrivateUsage]
@pytest.mark.asyncio
async def test_mcp_workbench_serialization_with_overrides(sample_server_params: StdioServerParams) -> None:
"""Test that McpWorkbench can be serialized and deserialized with overrides."""
overrides: Dict[str, ToolOverride] = {
"fetch": ToolOverride(name="web_fetch", description="Enhanced web fetching tool")
}
# Create workbench with overrides
workbench = McpWorkbench(server_params=sample_server_params, tool_overrides=overrides)
# Save configuration
config = workbench.dump_component()
assert "tool_overrides" in config.config
assert "fetch" in config.config["tool_overrides"]
assert config.config["tool_overrides"]["fetch"]["name"] == "web_fetch"
assert config.config["tool_overrides"]["fetch"]["description"] == "Enhanced web fetching tool"
# Load workbench from configuration
new_workbench = McpWorkbench.load_component(config)
assert len(new_workbench._tool_overrides) == 1 # type: ignore[reportPrivateUsage]
assert new_workbench._tool_overrides["fetch"].name == "web_fetch" # type: ignore[reportPrivateUsage]
assert new_workbench._tool_overrides["fetch"].description == "Enhanced web fetching tool" # type: ignore[reportPrivateUsage]
@pytest.mark.asyncio
async def test_mcp_workbench_partial_overrides(
sample_mcp_tools: list[Tool], mock_mcp_actor: AsyncMock, sample_server_params: StdioServerParams
) -> None:
"""Test McpWorkbench with partial overrides (name only, description only)."""
overrides: Dict[str, ToolOverride] = {
"fetch": ToolOverride(name="web_fetch"), # Only name override
"search": ToolOverride(description="Advanced search"), # Only description override
}
workbench = McpWorkbench(server_params=sample_server_params, tool_overrides=overrides)
workbench._actor = mock_mcp_actor # type: ignore[reportPrivateUsage]
# Mock list_tools response
list_tools_result = ListToolsResult(tools=sample_mcp_tools)
future_result: asyncio.Future[ListToolsResult] = asyncio.Future()
future_result.set_result(list_tools_result)
mock_mcp_actor.call.return_value = future_result
try:
tools = await workbench.list_tools()
# fetch: name overridden, description unchanged
assert tools[0].get("name") == "web_fetch"
assert tools[0].get("description") == "Fetches content from a URL" # Original description
# search: name unchanged, description overridden
assert tools[1].get("name") == "search" # Original name
assert tools[1].get("description") == "Advanced search" # Overridden description
finally:
workbench._actor = None # type: ignore[reportPrivateUsage]
def test_mcp_tool_override_model() -> None:
"""Test ToolOverride model functionality for MCP."""
# Test with both fields
override1 = ToolOverride(name="new_name", description="new_desc")
assert override1.name == "new_name"
assert override1.description == "new_desc"
# Test with only name
override2 = ToolOverride(name="new_name")
assert override2.name == "new_name"
assert override2.description is None
# Test with only description
override3 = ToolOverride(description="new_desc")
assert override3.name is None
assert override3.description == "new_desc"
# Test empty
override4 = ToolOverride()
assert override4.name is None
assert override4.description is None
@pytest.mark.asyncio
async def test_mcp_workbench_override_name_to_original_mapping(sample_server_params: StdioServerParams) -> None:
"""Test that the reverse mapping from override names to original names works correctly."""
overrides: Dict[str, ToolOverride] = {
"original1": ToolOverride(name="override1"),
"original2": ToolOverride(name="override2"),
"original3": ToolOverride(description="only description override"), # No name change, only description
}
workbench = McpWorkbench(server_params=sample_server_params, tool_overrides=overrides)
# Check reverse mapping is built correctly
assert workbench._override_name_to_original["override1"] == "original1" # type: ignore[reportPrivateUsage]
assert workbench._override_name_to_original["override2"] == "original2" # type: ignore[reportPrivateUsage]
assert "original3" not in workbench._override_name_to_original # type: ignore[reportPrivateUsage]
assert len(workbench._override_name_to_original) == 2 # type: ignore[reportPrivateUsage]
def test_mcp_workbench_conflict_detection() -> None:
"""Test that McpWorkbench detects conflicts in tool override names."""
server_params = StdioServerParams(command="echo", args=["test"])
# Test 1: Valid overrides - should work
overrides_valid: Dict[str, ToolOverride] = {
"fetch": ToolOverride(name="web_fetch"),
"search": ToolOverride(name="advanced_search"),
}
workbench_valid = McpWorkbench(server_params=server_params, tool_overrides=overrides_valid)
assert workbench_valid._override_name_to_original["web_fetch"] == "fetch" # type: ignore[reportPrivateUsage]
assert workbench_valid._override_name_to_original["advanced_search"] == "search" # type: ignore[reportPrivateUsage]
# Test 2: Duplicate override names - should fail
overrides_duplicate: Dict[str, ToolOverride] = {
"fetch": ToolOverride(name="same_name"),
"search": ToolOverride(name="same_name"), # Duplicate
}
with pytest.raises(ValueError):
McpWorkbench(server_params=server_params, tool_overrides=overrides_duplicate)