# autogen/python/packages/autogen-studio/tests/mcp/test_mcp_websocket.py

"""
Tests for MCP WebSocket functionality using the refactored architecture
(MCPClient and MCPWebSocketBridge).

These tests replace the legacy tests that previously lived in this module
(test_mcp_websocket.py) and were failing against the new architecture.
"""
import json
import base64
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import WebSocket
# Import the new architecture components
from autogenstudio.mcp.client import MCPClient
from autogenstudio.mcp.wsbridge import MCPWebSocketBridge
# Import MCP types for mocking
from mcp.types import (
Tool, Resource, Prompt, PromptArgument,
ListToolsResult, CallToolResult, ListResourcesResult,
ReadResourceResult, ListPromptsResult, GetPromptResult,
TextContent, TextResourceContents, PromptMessage,
ServerCapabilities, ToolsCapability, ResourcesCapability, PromptsCapability
)
from autogen_ext.tools.mcp._config import StdioServerParams
class TestMCPWebSocketUpdated:
    """Tests for the refactored MCP WebSocket stack.

    Covers ``MCPWebSocketBridge.send_message`` and
    ``MCPClient.handle_operation`` (success and error paths) against a mocked
    MCP client session and a mocked FastAPI ``WebSocket``, plus two checks
    preserved from the legacy tests.
    """

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _make_client(websocket, session, session_id="test_session"):
        """Return an ``MCPClient`` for ``session`` whose bridge writes to ``websocket``."""
        bridge = MCPWebSocketBridge(websocket, session_id)
        return MCPClient(session, session_id, bridge)

    @staticmethod
    def _single_sent_message(websocket):
        """Assert ``websocket.send_json`` was called exactly once and return its payload."""
        websocket.send_json.assert_called_once()
        return websocket.send_json.call_args.args[0]

    # ----------------------------------------------------------------- fixtures

    @pytest.fixture
    def mock_server_params(self):
        """Stdio server parameters for a fictional ``node server.js`` MCP server."""
        return StdioServerParams(
            command="node",
            args=["server.js"],
            env={"NODE_ENV": "test"},
        )

    @pytest.fixture
    def mock_client_session(self):
        """Create an ``AsyncMock`` MCP client session with canned results.

        ``initialize``, ``list_tools``, ``call_tool``, ``list_resources``,
        ``read_resource``, ``list_prompts`` and ``get_prompt`` each return a
        real ``mcp.types`` result object, so ``MCPClient`` sees genuine MCP
        types rather than bare mocks.
        """
        mock_session = AsyncMock()

        # Initialization: advertise tools, resources and prompts capabilities.
        mock_init_result = MagicMock()
        mock_init_result.capabilities = ServerCapabilities(
            tools=ToolsCapability(listChanged=False),
            resources=ResourcesCapability(subscribe=False, listChanged=False),
            prompts=PromptsCapability(listChanged=False),
        )
        mock_session.initialize.return_value = mock_init_result

        # Tools: one tool taking a required string "message".
        mock_tools = [
            Tool(
                name="test_tool",
                description="A test tool",
                inputSchema={
                    "type": "object",
                    "properties": {"message": {"type": "string"}},
                    "required": ["message"],
                },
            )
        ]
        mock_session.list_tools.return_value = ListToolsResult(tools=mock_tools)
        mock_session.call_tool.return_value = CallToolResult(
            content=[TextContent(type="text", text="Tool executed successfully")],
            isError=False,
        )

        # Resources: one text resource and its contents, sharing the same URI.
        from pydantic import HttpUrl

        test_uri = HttpUrl("https://example.com/test.txt")
        mock_resources = [
            Resource(
                uri=test_uri,
                name="test.txt",
                description="A test resource",
                mimeType="text/plain",
            )
        ]
        mock_session.list_resources.return_value = ListResourcesResult(resources=mock_resources)
        mock_session.read_resource.return_value = ReadResourceResult(
            contents=[
                TextResourceContents(
                    uri=test_uri,
                    text="This is test content",
                    mimeType="text/plain",
                )
            ]
        )

        # Prompts: one prompt with a single required argument, and its rendering.
        mock_prompts = [
            Prompt(
                name="test_prompt",
                description="A test prompt",
                arguments=[
                    PromptArgument(
                        name="input",
                        description="Input text",
                        required=True,
                    )
                ],
            )
        ]
        mock_session.list_prompts.return_value = ListPromptsResult(prompts=mock_prompts)
        mock_session.get_prompt.return_value = GetPromptResult(
            description="Test prompt result",
            messages=[
                PromptMessage(
                    role="user",
                    content=TextContent(type="text", text="Test message"),
                )
            ],
        )
        return mock_session

    @pytest.fixture
    def mock_websocket(self):
        """Create a mock FastAPI ``WebSocket`` whose ``client_state`` is CONNECTED."""
        from fastapi.websockets import WebSocketState

        mock_ws = AsyncMock(spec=WebSocket)
        # NOTE(review): presumably the bridge consults client_state before
        # sending; set it explicitly so sends are not skipped — confirm.
        mock_ws.client_state = WebSocketState.CONNECTED
        return mock_ws

    # -------------------------------------------------------------------- tests

    @pytest.mark.asyncio
    async def test_websocket_bridge_send_message(self, mock_websocket):
        """``MCPWebSocketBridge.send_message`` forwards the payload via ``send_json``."""
        bridge = MCPWebSocketBridge(mock_websocket, "test_session")
        test_message = {"type": "test", "data": "hello"}
        await bridge.send_message(test_message)
        mock_websocket.send_json.assert_called_once_with(test_message)

    @pytest.mark.asyncio
    async def test_mcp_client_list_tools_operation(self, mock_websocket, mock_client_session):
        """``list_tools`` queries the session and reports an ``operation_result`` with tools."""
        client = self._make_client(mock_websocket, mock_client_session)
        await client.handle_operation({"operation": "list_tools"})

        mock_client_session.list_tools.assert_called_once()
        sent_message = self._single_sent_message(mock_websocket)
        assert sent_message["type"] == "operation_result"
        assert sent_message["operation"] == "list_tools"
        assert "data" in sent_message
        assert "tools" in sent_message["data"]

    @pytest.mark.asyncio
    async def test_mcp_client_call_tool_operation(self, mock_websocket, mock_client_session):
        """``call_tool`` forwards the tool name and arguments to the session."""
        client = self._make_client(mock_websocket, mock_client_session)
        operation = {
            "operation": "call_tool",
            "tool_name": "test_tool",
            "arguments": {"message": "hello"},
        }
        await client.handle_operation(operation)

        mock_client_session.call_tool.assert_called_once_with("test_tool", {"message": "hello"})
        sent_message = self._single_sent_message(mock_websocket)
        assert sent_message["type"] == "operation_result"
        assert sent_message["operation"] == "call_tool"

    @pytest.mark.asyncio
    async def test_mcp_client_list_resources_operation(self, mock_websocket, mock_client_session):
        """``list_resources`` queries the session and reports an ``operation_result``."""
        client = self._make_client(mock_websocket, mock_client_session)
        await client.handle_operation({"operation": "list_resources"})

        mock_client_session.list_resources.assert_called_once()
        sent_message = self._single_sent_message(mock_websocket)
        assert sent_message["type"] == "operation_result"
        assert sent_message["operation"] == "list_resources"

    @pytest.mark.asyncio
    async def test_mcp_client_read_resource_operation(self, mock_websocket, mock_client_session):
        """``read_resource`` passes the URI to the session and reports the result."""
        client = self._make_client(mock_websocket, mock_client_session)
        operation = {
            "operation": "read_resource",
            "uri": "https://example.com/test.txt",
        }
        await client.handle_operation(operation)

        mock_client_session.read_resource.assert_called_once_with("https://example.com/test.txt")
        sent_message = self._single_sent_message(mock_websocket)
        assert sent_message["type"] == "operation_result"
        assert sent_message["operation"] == "read_resource"

    @pytest.mark.asyncio
    async def test_mcp_client_error_handling(self, mock_websocket, mock_client_session):
        """A session exception is reported as an ``operation_error`` carrying its message."""
        client = self._make_client(mock_websocket, mock_client_session)
        mock_client_session.list_tools.side_effect = Exception("Test error")

        await client.handle_operation({"operation": "list_tools"})

        sent_message = self._single_sent_message(mock_websocket)
        assert sent_message["type"] == "operation_error"
        assert sent_message["operation"] == "list_tools"
        assert "Test error" in sent_message["error"]

    def test_websocket_connection_url_generation(self, mock_server_params):
        """Server params survive base64 encoding into a WebSocket URL query string.

        Preserved from the original tests. The URL is assembled here (not by
        production code) in the shape
        ``ws://.../ws/mcp?session_id=...&server_params=<base64 JSON>``. The
        test then checks that the params decode back out of that URL.
        """
        session_id = "test-session-123"
        server_params_json = json.dumps(mock_server_params.model_dump())
        encoded_params = base64.b64encode(server_params_json.encode()).decode()
        url = f"ws://localhost:8000/ws/mcp?session_id={session_id}&server_params={encoded_params}"

        assert f"session_id={session_id}&" in url
        # Slice the raw value out rather than using urllib.parse: standard
        # base64 may contain '+', which query-string parsing turns into ' '.
        _, _, encoded_from_url = url.partition("server_params=")
        decoded_obj = json.loads(base64.b64decode(encoded_from_url).decode())

        assert decoded_obj["command"] == "node"
        assert decoded_obj["args"] == ["server.js"]
        assert decoded_obj["env"]["NODE_ENV"] == "test"

    def test_active_sessions_structure(self):
        """``active_sessions`` is a dict that accepts and returns session records.

        Preserved from the original tests.
        """
        from autogenstudio.web.routes.mcp import active_sessions

        assert isinstance(active_sessions, dict)

        session_id = "test-session"
        session_data = {
            "session_id": session_id,
            "server_params": {"command": "node", "args": ["server.js"]},
            "last_activity": "2023-01-01T00:00:00Z",
            "status": "active",
        }
        # active_sessions is shared module-level state. patch.dict snapshots it
        # and restores it in place on exit, even when an assertion fails.
        # Nothing leaks into other tests, and a pre-existing "test-session"
        # entry is not clobbered and then deleted.
        with patch.dict(active_sessions):
            active_sessions[session_id] = session_data
            assert session_id in active_sessions
            assert active_sessions[session_id] == session_data
class TestMCPRouteIntegrationUpdated:
    """Integration checks for objects exported by ``autogenstudio.web.routes.mcp``."""

    def test_router_exists(self):
        """The module-level ``router`` is a FastAPI ``APIRouter`` instance."""
        from autogenstudio.web.routes.mcp import router
        from fastapi import APIRouter

        assert isinstance(router, APIRouter)

    def test_create_websocket_connection_request_model(self):
        """``CreateWebSocketConnectionRequest`` accepts and keeps stdio server params.

        ``StdioServerParams`` is the module-level import from
        ``autogen_ext.tools.mcp._config``.
        """
        from autogenstudio.web.routes.mcp import CreateWebSocketConnectionRequest

        params = StdioServerParams(
            command="node",
            args=["server.js"],
            env={"NODE_ENV": "test"},
        )
        stored = CreateWebSocketConnectionRequest(server_params=params).server_params

        assert stored == params
        # The stored value should still be a StdioServerParams with the same fields.
        assert isinstance(stored, StdioServerParams)
        assert stored.command == "node"
        assert stored.args == ["server.js"]