diff --git a/python/packages/autogen-core/src/autogen_core/tools/__init__.py b/python/packages/autogen-core/src/autogen_core/tools/__init__.py index d6bbaf577..aee634e1f 100644 --- a/python/packages/autogen-core/src/autogen_core/tools/__init__.py +++ b/python/packages/autogen-core/src/autogen_core/tools/__init__.py @@ -1,4 +1,13 @@ -from ._base import BaseStreamTool, BaseTool, BaseToolWithState, ParametersSchema, StreamTool, Tool, ToolSchema +from ._base import ( + BaseStreamTool, + BaseTool, + BaseToolWithState, + ParametersSchema, + StreamTool, + Tool, + ToolOverride, + ToolSchema, +) from ._function_tool import FunctionTool from ._static_workbench import StaticStreamWorkbench, StaticWorkbench from ._workbench import ImageResultContent, TextResultContent, ToolResult, Workbench @@ -18,4 +27,5 @@ __all__ = [ "ImageResultContent", "StaticWorkbench", "StaticStreamWorkbench", + "ToolOverride", ] diff --git a/python/packages/autogen-core/src/autogen_core/tools/_base.py b/python/packages/autogen-core/src/autogen_core/tools/_base.py index 8936c9361..d2ea76e21 100644 --- a/python/packages/autogen-core/src/autogen_core/tools/_base.py +++ b/python/packages/autogen-core/src/autogen_core/tools/_base.py @@ -2,7 +2,19 @@ import json import logging from abc import ABC, abstractmethod from collections.abc import Sequence -from typing import Any, AsyncGenerator, Dict, Generic, Mapping, Protocol, Type, TypeVar, cast, runtime_checkable +from typing import ( + Any, + AsyncGenerator, + Dict, + Generic, + Mapping, + Optional, + Protocol, + Type, + TypeVar, + cast, + runtime_checkable, +) import jsonref from pydantic import BaseModel @@ -33,6 +45,13 @@ class ToolSchema(TypedDict): strict: NotRequired[bool] +class ToolOverride(BaseModel): + """Override configuration for a tool's name and/or description.""" + + name: Optional[str] = None + description: Optional[str] = None + + @runtime_checkable class Tool(Protocol): @property diff --git 
a/python/packages/autogen-core/src/autogen_core/tools/_static_workbench.py b/python/packages/autogen-core/src/autogen_core/tools/_static_workbench.py index 71e9ca4af..40b1ce47d 100644 --- a/python/packages/autogen-core/src/autogen_core/tools/_static_workbench.py +++ b/python/packages/autogen-core/src/autogen_core/tools/_static_workbench.py @@ -1,18 +1,19 @@ import asyncio import builtins -from typing import Any, AsyncGenerator, Dict, List, Literal, Mapping +from typing import Any, AsyncGenerator, Dict, List, Literal, Mapping, Optional -from pydantic import BaseModel +from pydantic import BaseModel, Field from typing_extensions import Self from .._cancellation_token import CancellationToken from .._component_config import Component, ComponentModel -from ._base import BaseTool, StreamTool, ToolSchema +from ._base import BaseTool, StreamTool, ToolOverride, ToolSchema from ._workbench import StreamWorkbench, TextResultContent, ToolResult, Workbench class StaticWorkbenchConfig(BaseModel): tools: List[ComponentModel] = [] + tool_overrides: Dict[str, ToolOverride] = Field(default_factory=dict) class StateicWorkbenchState(BaseModel): @@ -28,16 +29,67 @@ class StaticWorkbench(Workbench, Component[StaticWorkbenchConfig]): Args: tools (List[BaseTool[Any, Any]]): A list of tools to be included in the workbench. The tools should be subclasses of :class:`~autogen_core.tools.BaseTool`. + tool_overrides (Optional[Dict[str, ToolOverride]]): Optional mapping of original tool + names to override configurations for name and/or description. This allows + customizing how tools appear to consumers while maintaining the underlying + tool functionality. 
""" component_provider_override = "autogen_core.tools.StaticWorkbench" component_config_schema = StaticWorkbenchConfig - def __init__(self, tools: List[BaseTool[Any, Any]]) -> None: + def __init__( + self, tools: List[BaseTool[Any, Any]], tool_overrides: Optional[Dict[str, ToolOverride]] = None + ) -> None: self._tools = tools + self._tool_overrides = tool_overrides or {} + + # Build reverse mapping from override names to original names for call_tool + self._override_name_to_original: Dict[str, str] = {} + existing_tool_names = {tool.name for tool in self._tools} + + for original_name, override in self._tool_overrides.items(): + if override.name and override.name != original_name: + # Check for conflicts with existing tool names + if override.name in existing_tool_names and override.name != original_name: + raise ValueError( + f"Tool override name '{override.name}' conflicts with existing tool name. " + f"Override names must not conflict with any tool names." + ) + # Check for conflicts with other override names + if override.name in self._override_name_to_original: + existing_original = self._override_name_to_original[override.name] + raise ValueError( + f"Tool override name '{override.name}' is used by multiple tools: " + f"'{existing_original}' and '{original_name}'. Override names must be unique." 
+ ) + self._override_name_to_original[override.name] = original_name async def list_tools(self) -> List[ToolSchema]: - return [tool.schema for tool in self._tools] + result_schemas: List[ToolSchema] = [] + for tool in self._tools: + original_schema = tool.schema + + # Apply overrides if they exist for this tool + if tool.name in self._tool_overrides: + override = self._tool_overrides[tool.name] + # Create a new ToolSchema with overrides applied + schema: ToolSchema = { + "name": override.name if override.name is not None else original_schema["name"], + "description": override.description + if override.description is not None + else original_schema.get("description", ""), + } + # Copy optional fields + if "parameters" in original_schema: + schema["parameters"] = original_schema["parameters"] + if "strict" in original_schema: + schema["strict"] = original_schema["strict"] + else: + schema = original_schema + + result_schemas.append(schema) + return result_schemas async def call_tool( self, @@ -46,10 +98,13 @@ class StaticWorkbench(Workbench, Component[StaticWorkbenchConfig]): cancellation_token: CancellationToken | None = None, call_id: str | None = None, ) -> ToolResult: - tool = next((tool for tool in self._tools if tool.name == name), None) + # Check if the name is an override name and map it back to the original + original_name = self._override_name_to_original.get(name, name) + + tool = next((tool for tool in self._tools if tool.name == original_name), None) if tool is None: return ToolResult( - name=name, + name=name, # Return the requested name (which might be overridden) result=[TextResultContent(content=f"Tool {name} not found.")], is_error=True, ) @@ -66,7 +121,7 @@ class StaticWorkbench(Workbench, Component[StaticWorkbenchConfig]): except Exception as e: result_str = self._format_errors(e) is_error = True - return ToolResult(name=tool.name, result=[TextResultContent(content=result_str)], is_error=is_error) + return ToolResult(name=name, 
result=[TextResultContent(content=result_str)], is_error=is_error) async def start(self) -> None: return None @@ -90,11 +145,13 @@ class StaticWorkbench(Workbench, Component[StaticWorkbenchConfig]): await tool.load_state_json(parsed_state.tools[tool.name]) def _to_config(self) -> StaticWorkbenchConfig: - return StaticWorkbenchConfig(tools=[tool.dump_component() for tool in self._tools]) + return StaticWorkbenchConfig( + tools=[tool.dump_component() for tool in self._tools], tool_overrides=self._tool_overrides + ) @classmethod def _from_config(cls, config: StaticWorkbenchConfig) -> Self: - return cls(tools=[BaseTool.load_component(tool) for tool in config.tools]) + return cls(tools=[BaseTool.load_component(tool) for tool in config.tools], tool_overrides=config.tool_overrides) def _format_errors(self, error: Exception) -> str: """Recursively format errors into a string.""" diff --git a/python/packages/autogen-core/tests/test_static_workbench_overrides.py b/python/packages/autogen-core/tests/test_static_workbench_overrides.py new file mode 100644 index 000000000..37cf1b752 --- /dev/null +++ b/python/packages/autogen-core/tests/test_static_workbench_overrides.py @@ -0,0 +1,285 @@ +from typing import Annotated, Dict + +import pytest +from autogen_core.code_executor import ImportFromModule +from autogen_core.tools import FunctionTool, StaticWorkbench, ToolOverride, Workbench + + +@pytest.mark.asyncio +async def test_static_workbench_with_tool_overrides() -> None: + """Test StaticWorkbench with tool name and description overrides.""" + + def test_tool_func_1(x: Annotated[int, "The number to double."]) -> int: + return x * 2 + + def test_tool_func_2(a: Annotated[int, "First number"], b: Annotated[int, "Second number"]) -> int: + return a + b + + test_tool_1 = FunctionTool( + test_tool_func_1, + name="double", + description="A test tool that doubles a number.", + global_imports=[ImportFromModule(module="typing_extensions", imports=["Annotated"])], + ) + test_tool_2 = 
FunctionTool( + test_tool_func_2, + name="add", + description="A test tool that adds two numbers.", + global_imports=[ImportFromModule(module="typing_extensions", imports=["Annotated"])], + ) + + # Define tool overrides + overrides: Dict[str, ToolOverride] = { + "double": ToolOverride(name="multiply_by_two", description="Multiplies a number by 2"), + "add": ToolOverride(description="Performs addition of two integers"), # Only override description + } + + # Create a StaticWorkbench instance with tool overrides + async with StaticWorkbench(tools=[test_tool_1, test_tool_2], tool_overrides=overrides) as workbench: + # List tools and verify overrides are applied + tools = await workbench.list_tools() + assert len(tools) == 2 + + # Check first tool has name and description overridden + assert tools[0]["name"] == "multiply_by_two" + assert tools[0].get("description") == "Multiplies a number by 2" + assert tools[0].get("parameters") == { + "type": "object", + "properties": {"x": {"type": "integer", "title": "X", "description": "The number to double."}}, + "required": ["x"], + "additionalProperties": False, + } + + # Check second tool has only description overridden + assert tools[1]["name"] == "add" # Original name + assert tools[1].get("description") == "Performs addition of two integers" # Overridden description + assert tools[1].get("parameters") == { + "type": "object", + "properties": { + "a": {"type": "integer", "title": "A", "description": "First number"}, + "b": {"type": "integer", "title": "B", "description": "Second number"}, + }, + "required": ["a", "b"], + "additionalProperties": False, + } + + # Call tools using override names + result_1 = await workbench.call_tool("multiply_by_two", {"x": 5}) + assert result_1.name == "multiply_by_two" # Should return the override name + assert result_1.result[0].type == "TextResultContent" + assert result_1.result[0].content == "10" + assert result_1.to_text() == "10" + assert result_1.is_error is False + + # Call tool using 
original name (should still work for description-only override) + result_2 = await workbench.call_tool("add", {"a": 3, "b": 7}) + assert result_2.name == "add" + assert result_2.result[0].type == "TextResultContent" + assert result_2.result[0].content == "10" + assert result_2.to_text() == "10" + assert result_2.is_error is False + + # Test calling non-existent tool + result_3 = await workbench.call_tool("nonexistent", {"x": 5}) + assert result_3.name == "nonexistent" + assert result_3.is_error is True + assert result_3.result[0].type == "TextResultContent" + assert "Tool nonexistent not found" in result_3.result[0].content + + +@pytest.mark.asyncio +async def test_static_workbench_without_overrides() -> None: + """Test StaticWorkbench without overrides (original behavior).""" + + def test_tool_func(x: Annotated[int, "The number to double."]) -> int: + return x * 2 + + test_tool = FunctionTool( + test_tool_func, + name="double", + description="A test tool that doubles a number.", + global_imports=[ImportFromModule(module="typing_extensions", imports=["Annotated"])], + ) + + # Create workbench without overrides + async with StaticWorkbench(tools=[test_tool]) as workbench: + tools = await workbench.list_tools() + assert len(tools) == 1 + assert tools[0].get("name") == "double" + assert tools[0].get("description") == "A test tool that doubles a number." 
+ + +@pytest.mark.asyncio +async def test_static_workbench_serialization_with_overrides() -> None: + """Test that StaticWorkbench can be serialized and deserialized with overrides.""" + + def test_tool_func(x: Annotated[int, "The number to double."]) -> int: + return x * 2 + + test_tool = FunctionTool( + test_tool_func, + name="double", + description="A test tool that doubles a number.", + global_imports=[ImportFromModule(module="typing_extensions", imports=["Annotated"])], + ) + + overrides: Dict[str, ToolOverride] = { + "double": ToolOverride(name="multiply_by_two", description="Multiplies a number by 2") + } + + # Create workbench with overrides + workbench = StaticWorkbench(tools=[test_tool], tool_overrides=overrides) + + # Save configuration + config = workbench.dump_component() + assert "tool_overrides" in config.config + + # Load workbench from configuration + async with Workbench.load_component(config) as new_workbench: + tools = await new_workbench.list_tools() + assert len(tools) == 1 + assert tools[0]["name"] == "multiply_by_two" + assert tools[0].get("description") == "Multiplies a number by 2" + + # Test calling tool with override name + result = await new_workbench.call_tool("multiply_by_two", {"x": 5}) + assert result.name == "multiply_by_two" + assert result.result[0].content == "10" + assert result.is_error is False + + +@pytest.mark.asyncio +async def test_static_workbench_partial_overrides() -> None: + """Test StaticWorkbench with partial overrides (name only, description only).""" + + def tool1_func(x: Annotated[int, "Number"]) -> int: + return x + + def tool2_func(x: Annotated[int, "Number"]) -> int: + return x + + tool1 = FunctionTool( + tool1_func, + name="tool1", + description="Original description 1", + global_imports=[ImportFromModule(module="typing_extensions", imports=["Annotated"])], + ) + tool2 = FunctionTool( + tool2_func, + name="tool2", + description="Original description 2", + 
global_imports=[ImportFromModule(module="typing_extensions", imports=["Annotated"])], + ) + + overrides: Dict[str, ToolOverride] = { + "tool1": ToolOverride(name="renamed_tool1"), # Only name override + "tool2": ToolOverride(description="New description 2"), # Only description override + } + + async with StaticWorkbench(tools=[tool1, tool2], tool_overrides=overrides) as workbench: + tools = await workbench.list_tools() + + # tool1: name overridden, description unchanged + assert tools[0].get("name") == "renamed_tool1" + assert tools[0].get("description") == "Original description 1" + + # tool2: name unchanged, description overridden + assert tools[1].get("name") == "tool2" + assert tools[1].get("description") == "New description 2" + + # Test calling with override name + result1 = await workbench.call_tool("renamed_tool1", {"x": 42}) + assert result1.name == "renamed_tool1" + assert result1.result[0].content == "42" + + # Test calling with original name + result2 = await workbench.call_tool("tool2", {"x": 42}) + assert result2.name == "tool2" + assert result2.result[0].content == "42" + + +def test_tool_override_model() -> None: + """Test ToolOverride model functionality.""" + + # Test with both fields + override1 = ToolOverride(name="new_name", description="new_desc") + assert override1.name == "new_name" + assert override1.description == "new_desc" + + # Test with only name + override2 = ToolOverride(name="new_name") + assert override2.name == "new_name" + assert override2.description is None + + # Test with only description + override3 = ToolOverride(description="new_desc") + assert override3.name is None + assert override3.description == "new_desc" + + # Test empty + override4 = ToolOverride() + assert override4.name is None + assert override4.description is None + + +def test_static_workbench_conflict_detection() -> None: + """Test that StaticWorkbench detects conflicts in tool override names.""" + + def test_tool_func_1(x: Annotated[int, "Number"]) -> int: + 
return x + + def test_tool_func_2(x: Annotated[int, "Number"]) -> int: + return x + + def test_tool_func_3(x: Annotated[int, "Number"]) -> int: + return x + + tool1 = FunctionTool( + test_tool_func_1, + name="tool1", + description="Tool 1", + global_imports=[ImportFromModule(module="typing_extensions", imports=["Annotated"])], + ) + tool2 = FunctionTool( + test_tool_func_2, + name="tool2", + description="Tool 2", + global_imports=[ImportFromModule(module="typing_extensions", imports=["Annotated"])], + ) + tool3 = FunctionTool( + test_tool_func_3, + name="tool3", + description="Tool 3", + global_imports=[ImportFromModule(module="typing_extensions", imports=["Annotated"])], + ) + + # Test 1: Valid overrides - should work + overrides_valid: Dict[str, ToolOverride] = { + "tool1": ToolOverride(name="renamed_tool1"), + "tool2": ToolOverride(name="renamed_tool2"), + } + workbench_valid = StaticWorkbench(tools=[tool1, tool2, tool3], tool_overrides=overrides_valid) + assert "renamed_tool1" in workbench_valid._override_name_to_original # type: ignore[reportPrivateUsage] + assert "renamed_tool2" in workbench_valid._override_name_to_original # type: ignore[reportPrivateUsage] + + # Test 2: Conflict with existing tool name - should fail + overrides_conflict: Dict[str, ToolOverride] = { + "tool1": ToolOverride(name="tool2") # tool2 already exists + } + with pytest.raises(ValueError): + StaticWorkbench(tools=[tool1, tool2, tool3], tool_overrides=overrides_conflict) + + # Test 3: Duplicate override names - should fail + overrides_duplicate: Dict[str, ToolOverride] = { + "tool1": ToolOverride(name="same_name"), + "tool2": ToolOverride(name="same_name"), # Duplicate + } + with pytest.raises(ValueError): + StaticWorkbench(tools=[tool1, tool2, tool3], tool_overrides=overrides_duplicate) + + # Test 4: Self-renaming - should work but not add to reverse mapping + overrides_self: Dict[str, ToolOverride] = { + "tool1": ToolOverride(name="tool1") # Renaming to itself + } + workbench_self = 
StaticWorkbench(tools=[tool1, tool2, tool3], tool_overrides=overrides_self) + assert "tool1" not in workbench_self._override_name_to_original # type: ignore[reportPrivateUsage] diff --git a/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_workbench.py b/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_workbench.py index 26c9434e3..cc3541d77 100644 --- a/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_workbench.py +++ b/python/packages/autogen-ext/src/autogen_ext/tools/mcp/_workbench.py @@ -1,19 +1,20 @@ import asyncio import builtins import warnings -from typing import Any, List, Literal, Mapping +from typing import Any, Dict, List, Literal, Mapping, Optional from autogen_core import CancellationToken, Component, Image, trace_tool_span from autogen_core.tools import ( ImageResultContent, ParametersSchema, TextResultContent, + ToolOverride, ToolResult, ToolSchema, Workbench, ) from mcp.types import CallToolResult, EmbeddedResource, ImageContent, ListToolsResult, TextContent -from pydantic import BaseModel +from pydantic import BaseModel, Field from typing_extensions import Self from ._actor import McpSessionActor @@ -22,6 +23,7 @@ from ._config import McpServerParams, SseServerParams, StdioServerParams, Stream class McpWorkbenchConfig(BaseModel): server_params: McpServerParams + tool_overrides: Dict[str, ToolOverride] = Field(default_factory=dict) class McpWorkbenchState(BaseModel): @@ -36,6 +38,10 @@ class McpWorkbench(Workbench, Component[McpWorkbenchConfig]): Args: server_params (McpServerParams): The parameters to connect to the MCP server. This can be either a :class:`StdioServerParams` or :class:`SseServerParams`. + tool_overrides (Optional[Dict[str, ToolOverride]]): Optional mapping of original tool + names to override configurations for name and/or description. This allows + customizing how server tools appear to consumers while maintaining the underlying + tool functionality. 
Examples: @@ -65,6 +71,38 @@ class McpWorkbench(Workbench, Component[McpWorkbenchConfig]): asyncio.run(main()) + Example of using tool overrides: + + .. code-block:: python + + import asyncio + from autogen_ext.tools.mcp import McpWorkbench, StdioServerParams + from autogen_core.tools import ToolOverride + + + async def main() -> None: + params = StdioServerParams( + command="uvx", + args=["mcp-server-fetch"], + read_timeout_seconds=60, + ) + + # Override the fetch tool's name and description + overrides = { + "fetch": ToolOverride(name="web_fetch", description="Enhanced web fetching tool with better error handling") + } + + async with McpWorkbench(server_params=params, tool_overrides=overrides) as workbench: + tools = await workbench.list_tools() + # The tool will now appear as "web_fetch" with the new description + print(tools) + # Call the overridden tool + result = await workbench.call_tool("web_fetch", {"url": "https://github.com/"}) + print(result) + + + asyncio.run(main()) + Example of using the workbench with the `GitHub MCP Server <https://github.com/github/github-mcp-server>`_: .. 
code-block:: python @@ -149,8 +187,26 @@ class McpWorkbench(Workbench, Component[McpWorkbenchConfig]): component_provider_override = "autogen_ext.tools.mcp.McpWorkbench" component_config_schema = McpWorkbenchConfig - def __init__(self, server_params: McpServerParams) -> None: + def __init__( + self, server_params: McpServerParams, tool_overrides: Optional[Dict[str, ToolOverride]] = None + ) -> None: self._server_params = server_params + self._tool_overrides = tool_overrides or {} + + # Build reverse mapping from override names to original names for call_tool + self._override_name_to_original: Dict[str, str] = {} + for original_name, override in self._tool_overrides.items(): + override_name = override.name + if override_name and override_name != original_name: + # Check for conflicts with other override names + if override_name in self._override_name_to_original: + existing_original = self._override_name_to_original[override_name] + raise ValueError( + f"Tool override name '{override_name}' is used by multiple tools: " + f"'{existing_original}' and '{original_name}'. Override names must be unique." 
+ ) + self._override_name_to_original[override_name] = original_name + # self._session: ClientSession | None = None self._actor: McpSessionActor | None = None self._actor_loop: asyncio.AbstractEventLoop | None = None @@ -175,8 +231,18 @@ class McpWorkbench(Workbench, Component[McpWorkbenchConfig]): ), f"list_tools must return a CallToolResult, instead of : {str(type(list_tool_result))}" schema: List[ToolSchema] = [] for tool in list_tool_result.tools: - name = tool.name + original_name = tool.name + name = original_name description = tool.description or "" + + # Apply overrides if they exist for this tool + if original_name in self._tool_overrides: + override = self._tool_overrides[original_name] + if override.name is not None: + name = override.name + if override.description is not None: + description = override.description + parameters = ParametersSchema( type="object", properties=tool.inputSchema.get("properties", {}), @@ -208,12 +274,16 @@ class McpWorkbench(Workbench, Component[McpWorkbenchConfig]): cancellation_token = CancellationToken() if not arguments: arguments = {} + + # Check if the name is an override name and map it back to the original + original_name = self._override_name_to_original.get(name, name) + with trace_tool_span( - tool_name=name, + tool_name=name, # Use the requested name for tracing tool_call_id=call_id, ): try: - result_future = await self._actor.call("call_tool", {"name": name, "kargs": arguments}) + result_future = await self._actor.call("call_tool", {"name": original_name, "kargs": arguments}) cancellation_token.link_future(result_future) result = await result_future assert isinstance( @@ -236,7 +306,7 @@ class McpWorkbench(Workbench, Component[McpWorkbenchConfig]): error_message = self._format_errors(e) is_error = True result_parts = [TextResultContent(content=error_message)] - return ToolResult(name=name, result=result_parts, is_error=is_error) + return ToolResult(name=name, result=result_parts, is_error=is_error) # Return the 
requested name def _format_errors(self, error: Exception) -> str: """Recursively format errors into a string.""" @@ -285,18 +355,21 @@ class McpWorkbench(Workbench, Component[McpWorkbenchConfig]): pass def _to_config(self) -> McpWorkbenchConfig: - return McpWorkbenchConfig(server_params=self._server_params) + return McpWorkbenchConfig(server_params=self._server_params, tool_overrides=self._tool_overrides) @classmethod def _from_config(cls, config: McpWorkbenchConfig) -> Self: - return cls(server_params=config.server_params) + return cls(server_params=config.server_params, tool_overrides=config.tool_overrides) def __del__(self) -> None: # Ensure the actor is stopped when the workbench is deleted - if self._actor and self._actor_loop: - loop = self._actor_loop - if loop.is_running() and not loop.is_closed(): - loop.call_soon_threadsafe(lambda: asyncio.create_task(self.stop())) + # Use getattr to safely handle cases where attributes may not be set (e.g., if __init__ failed) + actor = getattr(self, "_actor", None) + actor_loop = getattr(self, "_actor_loop", None) + + if actor and actor_loop: + if actor_loop.is_running() and not actor_loop.is_closed(): + actor_loop.call_soon_threadsafe(lambda: asyncio.create_task(self.stop())) else: msg = "Cannot safely stop actor at [McpWorkbench.__del__]: loop is closed or not running" warnings.warn(msg, RuntimeWarning, stacklevel=2) diff --git a/python/packages/autogen-ext/tests/tools/test_mcp_workbench_overrides.py b/python/packages/autogen-ext/tests/tools/test_mcp_workbench_overrides.py new file mode 100644 index 000000000..76fcef1c1 --- /dev/null +++ b/python/packages/autogen-ext/tests/tools/test_mcp_workbench_overrides.py @@ -0,0 +1,298 @@ +import asyncio +from typing import Any, Dict +from unittest.mock import AsyncMock + +import pytest +from autogen_core.tools import ToolOverride +from autogen_ext.tools.mcp import McpWorkbench, StdioServerParams +from mcp import Tool +from mcp.types import ListToolsResult + + +@pytest.fixture 
+def sample_mcp_tools() -> list[Tool]: + """Create sample MCP tools for testing.""" + return [ + Tool( + name="fetch", + description="Fetches content from a URL", + inputSchema={ + "type": "object", + "properties": {"url": {"type": "string"}}, + "required": ["url"], + }, + ), + Tool( + name="search", + description="Searches for information", + inputSchema={ + "type": "object", + "properties": {"query": {"type": "string"}}, + "required": ["query"], + }, + ), + ] + + +@pytest.fixture +def mock_mcp_actor() -> AsyncMock: + """Mock MCP session actor.""" + actor = AsyncMock() + return actor + + +@pytest.fixture +def sample_server_params() -> StdioServerParams: + """Sample server parameters for testing.""" + return StdioServerParams(command="echo", args=["test"]) + + +@pytest.mark.asyncio +async def test_mcp_workbench_with_tool_overrides( + sample_mcp_tools: list[Tool], mock_mcp_actor: AsyncMock, sample_server_params: StdioServerParams +) -> None: + """Test McpWorkbench with tool name and description overrides.""" + + # Define tool overrides + overrides: Dict[str, ToolOverride] = { + "fetch": ToolOverride(name="web_fetch", description="Enhanced web fetching tool"), + "search": ToolOverride(description="Advanced search functionality"), # Only override description + } + + # Create workbench with overrides + workbench = McpWorkbench(server_params=sample_server_params, tool_overrides=overrides) + workbench._actor = mock_mcp_actor # type: ignore[reportPrivateUsage] + + # Mock list_tools response + list_tools_result = ListToolsResult(tools=sample_mcp_tools) + + # The actor.call() should return a Future that when awaited returns the list_tools_result + future_result: asyncio.Future[ListToolsResult] = asyncio.Future() + future_result.set_result(list_tools_result) + mock_mcp_actor.call.return_value = future_result + + try: + # List tools and verify overrides are applied + tools = await workbench.list_tools() + assert len(tools) == 2 + + # Check first tool has name and description 
overridden + assert tools[0].get("name") == "web_fetch" + assert tools[0].get("description") == "Enhanced web fetching tool" + + # Check second tool has only description overridden + assert tools[1].get("name") == "search" # Original name + assert tools[1].get("description") == "Advanced search functionality" # Overridden description + + # Verify actor was called correctly + mock_mcp_actor.call.assert_called_with("list_tools", None) + + finally: + workbench._actor = None # type: ignore[reportPrivateUsage] + + +@pytest.mark.asyncio +async def test_mcp_workbench_call_tool_with_overrides( + sample_mcp_tools: list[Tool], mock_mcp_actor: AsyncMock, sample_server_params: StdioServerParams +) -> None: + """Test calling tools with override names maps back to original names.""" + + overrides: Dict[str, ToolOverride] = { + "fetch": ToolOverride(name="web_fetch", description="Enhanced web fetching tool") + } + + workbench = McpWorkbench(server_params=sample_server_params, tool_overrides=overrides) + workbench._actor = mock_mcp_actor # type: ignore[reportPrivateUsage] + + # Mock successful tool call response + from mcp.types import CallToolResult, TextContent + + mock_result = CallToolResult(content=[TextContent(text="Mock response", type="text")], isError=False) + + # Create futures for each call + def mock_call_side_effect( + method: str, args: dict[str, Any] | None = None + ) -> asyncio.Future[ListToolsResult | CallToolResult]: + future_result: asyncio.Future[ListToolsResult | CallToolResult] = asyncio.Future() + if method == "list_tools": + future_result.set_result(ListToolsResult(tools=sample_mcp_tools)) + elif method == "call_tool": + future_result.set_result(mock_result) + else: + future_result.set_exception(ValueError(f"Unexpected method: {method}")) + return future_result + + mock_mcp_actor.call.side_effect = mock_call_side_effect + + try: + # Call tool using override name + result = await workbench.call_tool("web_fetch", {"url": "https://example.com"}) + + # Verify 
the result + assert result.name == "web_fetch" # Should return the override name + assert result.result[0].content == "Mock response" + assert result.is_error is False + + # Verify the actor was called with the original tool name + call_args = mock_mcp_actor.call.call_args_list[-1] + assert call_args[0][0] == "call_tool" + assert call_args[0][1]["name"] == "fetch" # Original name should be used + assert call_args[0][1]["kargs"] == {"url": "https://example.com"} + + finally: + workbench._actor = None # type: ignore[reportPrivateUsage] + + +@pytest.mark.asyncio +async def test_mcp_workbench_without_overrides( + sample_mcp_tools: list[Tool], mock_mcp_actor: AsyncMock, sample_server_params: StdioServerParams +) -> None: + """Test McpWorkbench without overrides (original behavior).""" + + workbench = McpWorkbench(server_params=sample_server_params) + workbench._actor = mock_mcp_actor # type: ignore[reportPrivateUsage] + + # Mock list_tools response + list_tools_result = ListToolsResult(tools=sample_mcp_tools) + future_result: asyncio.Future[ListToolsResult] = asyncio.Future() + future_result.set_result(list_tools_result) + mock_mcp_actor.call.return_value = future_result + + try: + tools = await workbench.list_tools() + assert len(tools) == 2 + + # Verify original names and descriptions are preserved + assert tools[0].get("name") == "fetch" + assert tools[0].get("description") == "Fetches content from a URL" + assert tools[1].get("name") == "search" + assert tools[1].get("description") == "Searches for information" + + finally: + workbench._actor = None # type: ignore[reportPrivateUsage] + + +@pytest.mark.asyncio +async def test_mcp_workbench_serialization_with_overrides(sample_server_params: StdioServerParams) -> None: + """Test that McpWorkbench can be serialized and deserialized with overrides.""" + + overrides: Dict[str, ToolOverride] = { + "fetch": ToolOverride(name="web_fetch", description="Enhanced web fetching tool") + } + + # Create workbench with overrides + 
workbench = McpWorkbench(server_params=sample_server_params, tool_overrides=overrides) + + # Save configuration + config = workbench.dump_component() + assert "tool_overrides" in config.config + assert "fetch" in config.config["tool_overrides"] + assert config.config["tool_overrides"]["fetch"]["name"] == "web_fetch" + assert config.config["tool_overrides"]["fetch"]["description"] == "Enhanced web fetching tool" + + # Load workbench from configuration + new_workbench = McpWorkbench.load_component(config) + assert len(new_workbench._tool_overrides) == 1 # type: ignore[reportPrivateUsage] + assert new_workbench._tool_overrides["fetch"].name == "web_fetch" # type: ignore[reportPrivateUsage] + assert new_workbench._tool_overrides["fetch"].description == "Enhanced web fetching tool" # type: ignore[reportPrivateUsage] + + +@pytest.mark.asyncio +async def test_mcp_workbench_partial_overrides( + sample_mcp_tools: list[Tool], mock_mcp_actor: AsyncMock, sample_server_params: StdioServerParams +) -> None: + """Test McpWorkbench with partial overrides (name only, description only).""" + + overrides: Dict[str, ToolOverride] = { + "fetch": ToolOverride(name="web_fetch"), # Only name override + "search": ToolOverride(description="Advanced search"), # Only description override + } + + workbench = McpWorkbench(server_params=sample_server_params, tool_overrides=overrides) + workbench._actor = mock_mcp_actor # type: ignore[reportPrivateUsage] + + # Mock list_tools response + list_tools_result = ListToolsResult(tools=sample_mcp_tools) + future_result: asyncio.Future[ListToolsResult] = asyncio.Future() + future_result.set_result(list_tools_result) + mock_mcp_actor.call.return_value = future_result + + try: + tools = await workbench.list_tools() + + # fetch: name overridden, description unchanged + assert tools[0].get("name") == "web_fetch" + assert tools[0].get("description") == "Fetches content from a URL" # Original description + + # search: name unchanged, description overridden + 
assert tools[1].get("name") == "search" # Original name + assert tools[1].get("description") == "Advanced search" # Overridden description + + finally: + workbench._actor = None # type: ignore[reportPrivateUsage] + + +def test_mcp_tool_override_model() -> None: + """Test ToolOverride model functionality for MCP.""" + + # Test with both fields + override1 = ToolOverride(name="new_name", description="new_desc") + assert override1.name == "new_name" + assert override1.description == "new_desc" + + # Test with only name + override2 = ToolOverride(name="new_name") + assert override2.name == "new_name" + assert override2.description is None + + # Test with only description + override3 = ToolOverride(description="new_desc") + assert override3.name is None + assert override3.description == "new_desc" + + # Test empty + override4 = ToolOverride() + assert override4.name is None + assert override4.description is None + + +@pytest.mark.asyncio +async def test_mcp_workbench_override_name_to_original_mapping(sample_server_params: StdioServerParams) -> None: + """Test that the reverse mapping from override names to original names works correctly.""" + + overrides: Dict[str, ToolOverride] = { + "original1": ToolOverride(name="override1"), + "original2": ToolOverride(name="override2"), + "original3": ToolOverride(description="only description override"), # No name change, only description + } + + workbench = McpWorkbench(server_params=sample_server_params, tool_overrides=overrides) + + # Check reverse mapping is built correctly + assert workbench._override_name_to_original["override1"] == "original1" # type: ignore[reportPrivateUsage] + assert workbench._override_name_to_original["override2"] == "original2" # type: ignore[reportPrivateUsage] + assert "original3" not in workbench._override_name_to_original # type: ignore[reportPrivateUsage] + assert len(workbench._override_name_to_original) == 2 # type: ignore[reportPrivateUsage] + + +def test_mcp_workbench_conflict_detection() -> 
None: + """Test that McpWorkbench detects conflicts in tool override names.""" + + server_params = StdioServerParams(command="echo", args=["test"]) + + # Test 1: Valid overrides - should work + overrides_valid: Dict[str, ToolOverride] = { + "fetch": ToolOverride(name="web_fetch"), + "search": ToolOverride(name="advanced_search"), + } + workbench_valid = McpWorkbench(server_params=server_params, tool_overrides=overrides_valid) + assert workbench_valid._override_name_to_original["web_fetch"] == "fetch" # type: ignore[reportPrivateUsage] + assert workbench_valid._override_name_to_original["advanced_search"] == "search" # type: ignore[reportPrivateUsage] + + # Test 2: Duplicate override names - should fail + overrides_duplicate: Dict[str, ToolOverride] = { + "fetch": ToolOverride(name="same_name"), + "search": ToolOverride(name="same_name"), # Duplicate + } + with pytest.raises(ValueError): + McpWorkbench(server_params=server_params, tool_overrides=overrides_duplicate)