mirror of
https://github.com/microsoft/autogen.git
synced 2025-08-04 14:52:10 +00:00

Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
268 lines
9.3 KiB
Python
268 lines
9.3 KiB
Python
#!/usr/bin/env python3
|
|
"""Test the MCP client implementation"""
|
|
|
|
import asyncio
|
|
import pytest
|
|
from typing import Any
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
from datetime import datetime, timezone
|
|
|
|
from mcp.types import (
|
|
ListToolsResult,
|
|
Tool,
|
|
CallToolResult,
|
|
TextContent,
|
|
ListResourcesResult,
|
|
Resource,
|
|
ReadResourceResult,
|
|
TextResourceContents,
|
|
ListPromptsResult,
|
|
Prompt,
|
|
GetPromptResult,
|
|
PromptMessage,
|
|
InitializeResult,
|
|
ServerCapabilities,
|
|
Implementation
|
|
)
|
|
|
|
from autogenstudio.mcp.client import MCPClient, MCPEventHandler
|
|
from autogenstudio.mcp.utils import McpOperationError
|
|
|
|
|
|
class MockEventHandler(MCPEventHandler):
|
|
"""Mock event handler for testing"""
|
|
|
|
def __init__(self):
|
|
self.events = []
|
|
|
|
async def on_initialized(self, session_id: str, capabilities: Any) -> None:
|
|
self.events.append(("initialized", session_id, capabilities))
|
|
|
|
async def on_operation_result(self, operation: str, data: dict) -> None:
|
|
self.events.append(("operation_result", operation, data))
|
|
|
|
async def on_operation_error(self, operation: str, error: str) -> None:
|
|
self.events.append(("operation_error", operation, error))
|
|
|
|
async def on_mcp_activity(self, activity_type: str, message: str, details: dict) -> None:
|
|
self.events.append(("mcp_activity", activity_type, message, details))
|
|
|
|
async def on_elicitation_request(self, request_id: str, message: str, requested_schema: Any) -> None:
|
|
self.events.append(("elicitation_request", request_id, message, requested_schema))
|
|
|
|
|
|
class TestMCPClient:
|
|
"""Test the MCPClient class"""
|
|
|
|
@pytest.fixture
|
|
def mock_session(self):
|
|
"""Create a mock MCP session"""
|
|
session = AsyncMock()
|
|
|
|
# Mock initialization
|
|
session.initialize.return_value = InitializeResult(
|
|
protocolVersion="2024-11-05",
|
|
capabilities=ServerCapabilities(),
|
|
serverInfo=Implementation(name="test-server", version="1.0.0")
|
|
)
|
|
|
|
# Mock tools
|
|
session.list_tools.return_value = ListToolsResult(
|
|
tools=[
|
|
Tool(
|
|
name="test_tool",
|
|
description="A test tool",
|
|
inputSchema={
|
|
"type": "object",
|
|
"properties": {"message": {"type": "string"}},
|
|
"required": ["message"]
|
|
}
|
|
)
|
|
]
|
|
)
|
|
|
|
# Mock tool call
|
|
session.call_tool.return_value = CallToolResult(
|
|
content=[TextContent(type="text", text="Tool executed successfully")],
|
|
isError=False
|
|
)
|
|
|
|
# Mock resources
|
|
from pydantic import HttpUrl
|
|
test_uri = HttpUrl("https://example.com/test.txt")
|
|
session.list_resources.return_value = ListResourcesResult(
|
|
resources=[
|
|
Resource(
|
|
uri=test_uri,
|
|
name="test.txt",
|
|
description="A test resource",
|
|
mimeType="text/plain"
|
|
)
|
|
]
|
|
)
|
|
|
|
session.read_resource.return_value = ReadResourceResult(
|
|
contents=[TextResourceContents(
|
|
uri=test_uri,
|
|
text="This is test content",
|
|
mimeType="text/plain"
|
|
)]
|
|
)
|
|
|
|
# Mock prompts
|
|
session.list_prompts.return_value = ListPromptsResult(
|
|
prompts=[
|
|
Prompt(
|
|
name="test_prompt",
|
|
description="A test prompt"
|
|
)
|
|
]
|
|
)
|
|
|
|
session.get_prompt.return_value = GetPromptResult(
|
|
description="Test prompt result",
|
|
messages=[
|
|
PromptMessage(
|
|
role="user",
|
|
content=TextContent(type="text", text="Test prompt content")
|
|
)
|
|
]
|
|
)
|
|
|
|
return session
|
|
|
|
@pytest.fixture
|
|
def mock_event_handler(self):
|
|
"""Create a mock event handler"""
|
|
return MockEventHandler()
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_client_initialization(self, mock_session, mock_event_handler):
|
|
"""Test MCPClient initialization"""
|
|
client = MCPClient(mock_session, "test-session", mock_event_handler)
|
|
|
|
assert client.session == mock_session
|
|
assert client.session_id == "test-session"
|
|
assert client.event_handler == mock_event_handler
|
|
assert not client._initialized
|
|
assert client._capabilities is None
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_client_initialize(self, mock_session, mock_event_handler):
|
|
"""Test MCPClient.initialize()"""
|
|
client = MCPClient(mock_session, "test-session", mock_event_handler)
|
|
|
|
await client.initialize()
|
|
|
|
# Verify session.initialize was called
|
|
mock_session.initialize.assert_called_once()
|
|
|
|
# Verify client state
|
|
assert client._initialized
|
|
assert client.capabilities is not None
|
|
|
|
# Verify event handler was called
|
|
events = [e for e in mock_event_handler.events if e[0] == "initialized"]
|
|
assert len(events) == 1
|
|
assert events[0][1] == "test-session"
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_list_tools_operation(self, mock_session, mock_event_handler):
|
|
"""Test list_tools operation"""
|
|
client = MCPClient(mock_session, "test-session", mock_event_handler)
|
|
await client.initialize()
|
|
|
|
# Test list_tools operation
|
|
operation = {"operation": "list_tools"}
|
|
await client.handle_operation(operation)
|
|
|
|
# Verify session method was called
|
|
mock_session.list_tools.assert_called_once()
|
|
|
|
# Verify result event was fired
|
|
result_events = [e for e in mock_event_handler.events if e[0] == "operation_result"]
|
|
assert len(result_events) == 1
|
|
assert result_events[0][1] == "list_tools"
|
|
assert "tools" in result_events[0][2]
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_call_tool_operation(self, mock_session, mock_event_handler):
|
|
"""Test call_tool operation"""
|
|
client = MCPClient(mock_session, "test-session", mock_event_handler)
|
|
await client.initialize()
|
|
|
|
# Test call_tool operation
|
|
operation = {
|
|
"operation": "call_tool",
|
|
"tool_name": "test_tool",
|
|
"arguments": {"message": "test"}
|
|
}
|
|
await client.handle_operation(operation)
|
|
|
|
# Verify session method was called
|
|
mock_session.call_tool.assert_called_once_with("test_tool", {"message": "test"})
|
|
|
|
# Verify result event was fired
|
|
result_events = [e for e in mock_event_handler.events if e[0] == "operation_result"]
|
|
assert len(result_events) == 1
|
|
assert result_events[0][1] == "call_tool"
|
|
assert result_events[0][2]["tool_name"] == "test_tool"
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_call_tool_missing_name(self, mock_session, mock_event_handler):
|
|
"""Test call_tool operation with missing tool name"""
|
|
client = MCPClient(mock_session, "test-session", mock_event_handler)
|
|
await client.initialize()
|
|
|
|
# Test call_tool operation without tool_name
|
|
operation = {
|
|
"operation": "call_tool",
|
|
"arguments": {"message": "test"}
|
|
}
|
|
await client.handle_operation(operation)
|
|
|
|
# Verify error event was fired
|
|
error_events = [e for e in mock_event_handler.events if e[0] == "operation_error"]
|
|
assert len(error_events) == 1
|
|
assert error_events[0][1] == "call_tool"
|
|
assert "Tool name is required" in error_events[0][2]
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_unknown_operation(self, mock_session, mock_event_handler):
|
|
"""Test unknown operation handling"""
|
|
client = MCPClient(mock_session, "test-session", mock_event_handler)
|
|
await client.initialize()
|
|
|
|
# Test unknown operation
|
|
operation = {"operation": "unknown_op"}
|
|
await client.handle_operation(operation)
|
|
|
|
# Verify error event was fired
|
|
error_events = [e for e in mock_event_handler.events if e[0] == "operation_error"]
|
|
assert len(error_events) == 1
|
|
assert error_events[0][1] == "unknown_op"
|
|
assert "Unknown operation" in error_events[0][2]
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_operation_exception_handling(self, mock_session, mock_event_handler):
|
|
"""Test operation exception handling"""
|
|
client = MCPClient(mock_session, "test-session", mock_event_handler)
|
|
await client.initialize()
|
|
|
|
# Mock session to raise exception
|
|
mock_session.list_tools.side_effect = Exception("Test error")
|
|
|
|
# Test list_tools operation
|
|
operation = {"operation": "list_tools"}
|
|
await client.handle_operation(operation)
|
|
|
|
# Verify error event was fired
|
|
error_events = [e for e in mock_event_handler.events if e[0] == "operation_error"]
|
|
assert len(error_events) == 1
|
|
assert error_events[0][1] == "list_tools"
|
|
assert "Test error" in error_events[0][2]
|
|
|
|
|
|
if __name__ == "__main__":
|
|
pytest.main([__file__, "-v"])
|