mirror of
https://github.com/microsoft/autogen.git
synced 2025-08-05 07:12:46 +00:00

Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
271 lines
10 KiB
Python
271 lines
10 KiB
Python
#!/usr/bin/env python3
|
|
"""Test the refactored MCP callback functions"""
|
|
|
|
import asyncio
|
|
import pytest
|
|
import uuid
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
from datetime import datetime, timezone
|
|
|
|
from typing import Any
|
|
|
|
from mcp.types import (
|
|
CreateMessageRequestParams,
|
|
CreateMessageResult,
|
|
ElicitRequestParams,
|
|
ElicitResult,
|
|
ErrorData,
|
|
TextContent
|
|
)
|
|
from mcp.shared.context import RequestContext
|
|
|
|
from autogenstudio.mcp.callbacks import (
|
|
create_message_handler,
|
|
create_sampling_callback,
|
|
create_elicitation_callback
|
|
)
|
|
from autogenstudio.mcp.wsbridge import MCPWebSocketBridge
|
|
|
|
|
|
class MockBridge(MCPWebSocketBridge):
|
|
"""Mock bridge for testing callbacks"""
|
|
|
|
def __init__(self):
|
|
# Don't call parent __init__ to avoid WebSocket dependency
|
|
self.session_id = "test-session"
|
|
self.pending_elicitations = {}
|
|
self.events = []
|
|
|
|
async def on_mcp_activity(self, activity_type: str, message: str, details: dict) -> None:
|
|
self.events.append(("mcp_activity", activity_type, message, details))
|
|
|
|
async def on_elicitation_request(self, request_id: str, message: str, requested_schema: Any) -> None:
|
|
self.events.append(("elicitation_request", request_id, message, requested_schema))
|
|
|
|
|
|
class TestMCPCallbacks:
|
|
"""Test MCP callback functions"""
|
|
|
|
@pytest.fixture
|
|
def mock_bridge(self):
|
|
"""Create a mock bridge"""
|
|
return MockBridge()
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_message_handler_with_exception(self, mock_bridge):
|
|
"""Test message handler with exception"""
|
|
handler = create_message_handler(mock_bridge)
|
|
|
|
test_exception = Exception("Test protocol error")
|
|
await handler(test_exception)
|
|
|
|
# Verify activity was logged
|
|
assert len(mock_bridge.events) == 1
|
|
event = mock_bridge.events[0]
|
|
assert event[0] == "mcp_activity"
|
|
assert event[1] == "error"
|
|
assert "Protocol error" in event[2]
|
|
assert "Test protocol error" in event[3]["details"]
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_message_handler_with_method_message(self, mock_bridge):
|
|
"""Test message handler with method-based message"""
|
|
handler = create_message_handler(mock_bridge)
|
|
|
|
# Create a mock message with method attribute
|
|
mock_message = MagicMock()
|
|
mock_message.method = "notifications/initialized"
|
|
mock_message.params = {"capabilities": {"tools": True}}
|
|
|
|
await handler(mock_message)
|
|
|
|
# Verify activity was logged
|
|
assert len(mock_bridge.events) == 1
|
|
event = mock_bridge.events[0]
|
|
assert event[0] == "mcp_activity"
|
|
assert event[1] == "protocol"
|
|
assert "notifications/initialized" in event[2]
|
|
assert event[3]["method"] == "notifications/initialized"
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_message_handler_with_other_message(self, mock_bridge):
|
|
"""Test message handler with other message types"""
|
|
handler = create_message_handler(mock_bridge)
|
|
|
|
# Create a simple mock message without method, avoiding recursion issues
|
|
class SimpleMockMessage:
|
|
def model_dump(self):
|
|
return {"type": "response", "data": "test"}
|
|
|
|
mock_message = SimpleMockMessage()
|
|
|
|
# Type ignore for test purposes - we're testing edge case handling
|
|
await handler(mock_message) # type: ignore
|
|
|
|
# Verify activity was logged
|
|
assert len(mock_bridge.events) == 1
|
|
event = mock_bridge.events[0]
|
|
assert event[0] == "mcp_activity"
|
|
assert event[1] == "protocol"
|
|
assert "SimpleMockMessage" in event[2] # Type name
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_sampling_callback_success(self, mock_bridge):
|
|
"""Test sampling callback success case"""
|
|
callback = create_sampling_callback(mock_bridge)
|
|
|
|
# Create mock context and params
|
|
mock_context = AsyncMock(spec=RequestContext)
|
|
mock_params = CreateMessageRequestParams(
|
|
messages=[], # Empty messages array for test
|
|
maxTokens=100
|
|
)
|
|
|
|
result = await callback(mock_context, mock_params)
|
|
|
|
# Verify result is CreateMessageResult
|
|
assert isinstance(result, CreateMessageResult)
|
|
assert result.role == "assistant"
|
|
assert result.model == "autogen-studio-default"
|
|
assert isinstance(result.content, TextContent)
|
|
assert "AutoGen Studio Default Sampling Response" in result.content.text
|
|
|
|
# Verify activities were logged
|
|
assert len(mock_bridge.events) == 2
|
|
# First event: sampling request
|
|
assert mock_bridge.events[0][1] == "sampling"
|
|
assert "Tool requested AI sampling" in mock_bridge.events[0][2]
|
|
# Second event: sampling response
|
|
assert mock_bridge.events[1][1] == "sampling"
|
|
assert "Provided default sampling response" in mock_bridge.events[1][2]
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_sampling_callback_exception(self, mock_bridge):
|
|
"""Test sampling callback with exception"""
|
|
callback = create_sampling_callback(mock_bridge)
|
|
|
|
# Create mock context that raises exception
|
|
mock_context = AsyncMock(spec=RequestContext)
|
|
|
|
# Create params that will cause an exception when accessing
|
|
mock_params = MagicMock()
|
|
mock_params.messages = None # This should cause an error
|
|
|
|
# Mock the model_dump to raise exception
|
|
mock_params.model_dump.side_effect = Exception("Test sampling error")
|
|
|
|
result = await callback(mock_context, mock_params)
|
|
|
|
# Verify result is ErrorData
|
|
assert isinstance(result, ErrorData)
|
|
assert result.code == -32603
|
|
assert "Sampling failed" in result.message
|
|
|
|
# Verify error was logged
|
|
error_events = [e for e in mock_bridge.events if e[1] == "error"]
|
|
assert len(error_events) == 1
|
|
assert "Sampling callback error" in error_events[0][2]
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_elicitation_callback_success(self, mock_bridge):
|
|
"""Test elicitation callback success case"""
|
|
callback, pending_dict = create_elicitation_callback(mock_bridge)
|
|
|
|
# Verify that pending_dict is the same as bridge's pending_elicitations
|
|
assert pending_dict is mock_bridge.pending_elicitations
|
|
|
|
# Create mock context and params
|
|
mock_context = AsyncMock(spec=RequestContext)
|
|
mock_params = ElicitRequestParams(
|
|
message="Please provide your name",
|
|
requestedSchema={"type": "string"}
|
|
)
|
|
|
|
# Create a task to simulate user response
|
|
async def simulate_user_response():
|
|
await asyncio.sleep(0.1) # Let elicitation setup
|
|
|
|
# Find the request ID from events
|
|
elicit_events = [e for e in mock_bridge.events if e[0] == "elicitation_request"]
|
|
assert len(elicit_events) == 1
|
|
request_id = elicit_events[0][1]
|
|
|
|
# Simulate user accepting
|
|
if request_id in mock_bridge.pending_elicitations:
|
|
future = mock_bridge.pending_elicitations[request_id]
|
|
result = ElicitResult(action="accept", content={"name": "John Doe"})
|
|
future.set_result(result)
|
|
|
|
# Run both the callback and the response simulation
|
|
callback_task = asyncio.create_task(callback(mock_context, mock_params))
|
|
response_task = asyncio.create_task(simulate_user_response())
|
|
|
|
result, _ = await asyncio.gather(callback_task, response_task)
|
|
|
|
# Verify result
|
|
assert isinstance(result, ElicitResult)
|
|
assert result.action == "accept"
|
|
assert result.content == {"name": "John Doe"}
|
|
|
|
# Verify events were logged
|
|
activity_events = [e for e in mock_bridge.events if e[0] == "mcp_activity"]
|
|
elicit_events = [e for e in mock_bridge.events if e[0] == "elicitation_request"]
|
|
|
|
assert len(elicit_events) == 1
|
|
assert len(activity_events) >= 2 # Request and response activities
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_elicitation_callback_timeout(self, mock_bridge):
|
|
"""Test elicitation callback timeout"""
|
|
callback, _ = create_elicitation_callback(mock_bridge)
|
|
|
|
# Create mock context and params
|
|
mock_context = AsyncMock(spec=RequestContext)
|
|
mock_params = ElicitRequestParams(
|
|
message="Please provide input",
|
|
requestedSchema={"type": "string"}
|
|
)
|
|
|
|
# Mock asyncio.wait_for to raise TimeoutError
|
|
with patch('asyncio.wait_for', side_effect=asyncio.TimeoutError):
|
|
result = await callback(mock_context, mock_params)
|
|
|
|
# Verify result is ErrorData
|
|
assert isinstance(result, ErrorData)
|
|
assert result.code == -32603
|
|
assert "60 seconds" in result.message
|
|
|
|
# Verify timeout was logged
|
|
error_events = [e for e in mock_bridge.events if e[1] == "error"]
|
|
assert len(error_events) == 1
|
|
assert "Elicitation timeout" in error_events[0][2]
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_elicitation_callback_exception(self, mock_bridge):
|
|
"""Test elicitation callback with exception"""
|
|
callback, _ = create_elicitation_callback(mock_bridge)
|
|
|
|
# Create mock context and params that will cause exception
|
|
mock_context = AsyncMock(spec=RequestContext)
|
|
mock_params = MagicMock()
|
|
mock_params.message = "Test message"
|
|
mock_params.requestedSchema = None
|
|
|
|
# Mock uuid.uuid4 to raise exception
|
|
with patch('uuid.uuid4', side_effect=Exception("UUID generation failed")):
|
|
result = await callback(mock_context, mock_params)
|
|
|
|
# Verify result is ErrorData
|
|
assert isinstance(result, ErrorData)
|
|
assert result.code == -32603
|
|
assert "Elicitation failed" in result.message
|
|
|
|
# Verify error was logged
|
|
error_events = [e for e in mock_bridge.events if e[1] == "error"]
|
|
assert len(error_events) == 1
|
|
assert "Elicitation callback error" in error_events[0][2]
|
|
|
|
|
|
if __name__ == "__main__":
|
|
pytest.main([__file__, "-v"])
|