mirror of
https://github.com/microsoft/autogen.git
synced 2025-08-02 05:42:13 +00:00
187 lines
6.8 KiB
Python
187 lines
6.8 KiB
Python
![]() |
import asyncio
|
||
|
import uuid
|
||
|
from datetime import datetime, timezone
|
||
|
from typing import Any, Dict, Tuple
|
||
|
|
||
|
from loguru import logger
|
||
|
from mcp.shared.context import RequestContext
|
||
|
from mcp.shared.session import RequestResponder
|
||
|
from mcp.types import (
|
||
|
ClientResult,
|
||
|
CreateMessageRequestParams,
|
||
|
CreateMessageResult,
|
||
|
ElicitRequestParams,
|
||
|
ElicitResult,
|
||
|
ErrorData,
|
||
|
ServerNotification,
|
||
|
ServerRequest,
|
||
|
TextContent,
|
||
|
)
|
||
|
|
||
|
from .utils import extract_real_error, serialize_for_json
|
||
|
from .wsbridge import MCPWebSocketBridge
|
||
|
|
||
|
|
||
|
def create_message_handler(bridge: MCPWebSocketBridge):
|
||
|
"""Create a message handler callback that streams MCP protocol messages to the UI"""
|
||
|
|
||
|
async def message_handler(
|
||
|
message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception,
|
||
|
) -> None:
|
||
|
try:
|
||
|
if isinstance(message, Exception):
|
||
|
await bridge.on_mcp_activity(
|
||
|
"error", f"Protocol error: {str(message)}", {"details": extract_real_error(message)}
|
||
|
)
|
||
|
|
||
|
elif hasattr(message, "method"):
|
||
|
method = getattr(message, "method", "unknown")
|
||
|
params = getattr(message, "params", None)
|
||
|
|
||
|
await bridge.on_mcp_activity(
|
||
|
"protocol",
|
||
|
f"MCP {method}",
|
||
|
{
|
||
|
"method": method,
|
||
|
"params": serialize_for_json(params) if params else None,
|
||
|
"message_type": type(message).__name__,
|
||
|
},
|
||
|
)
|
||
|
|
||
|
else:
|
||
|
await bridge.on_mcp_activity(
|
||
|
"protocol",
|
||
|
f"{type(message).__name__}",
|
||
|
{
|
||
|
"message_type": type(message).__name__,
|
||
|
"content": serialize_for_json(message) if hasattr(message, "model_dump") else str(message),
|
||
|
},
|
||
|
)
|
||
|
|
||
|
except Exception as e:
|
||
|
logger.error(f"Error in message handler: {extract_real_error(e)}")
|
||
|
|
||
|
return message_handler
|
||
|
|
||
|
|
||
|
def create_sampling_callback(bridge: MCPWebSocketBridge):
|
||
|
"""Create a sampling callback that handles AI sampling requests from tools"""
|
||
|
|
||
|
async def sampling_callback(
|
||
|
context: RequestContext[Any, Any, Any],
|
||
|
params: CreateMessageRequestParams,
|
||
|
) -> CreateMessageResult | ErrorData:
|
||
|
try:
|
||
|
request_id = str(uuid.uuid4())
|
||
|
|
||
|
await bridge.on_mcp_activity(
|
||
|
"sampling",
|
||
|
f"Tool requested AI sampling for {len(params.messages)} message(s)",
|
||
|
{
|
||
|
"request_id": request_id,
|
||
|
"params": serialize_for_json(params.model_dump()),
|
||
|
"context": "Tool is requesting AI to generate a response",
|
||
|
},
|
||
|
)
|
||
|
|
||
|
dummy_response = CreateMessageResult(
|
||
|
role="assistant",
|
||
|
content=TextContent(
|
||
|
type="text",
|
||
|
text="[AutoGen Studio Default Sampling Response - This is a placeholder response for AI sampling requests. In a production setup, this would be handled by your configured LLM.]",
|
||
|
),
|
||
|
model="autogen-studio-default",
|
||
|
)
|
||
|
|
||
|
await bridge.on_mcp_activity(
|
||
|
"sampling",
|
||
|
"Provided default sampling response to tool",
|
||
|
{
|
||
|
"request_id": request_id,
|
||
|
"response": serialize_for_json(dummy_response.model_dump()),
|
||
|
"note": "This is a placeholder response - configure an LLM for real sampling",
|
||
|
},
|
||
|
)
|
||
|
|
||
|
logger.info("Handled sampling request with default response")
|
||
|
return dummy_response
|
||
|
|
||
|
except Exception as e:
|
||
|
error_msg = extract_real_error(e)
|
||
|
logger.error(f"Error in sampling callback: {error_msg}")
|
||
|
|
||
|
await bridge.on_mcp_activity("error", f"Sampling callback error: {error_msg}", {"error": error_msg})
|
||
|
|
||
|
return ErrorData(code=-32603, message=f"Sampling failed: {error_msg}")
|
||
|
|
||
|
return sampling_callback
|
||
|
|
||
|
|
||
|
def create_elicitation_callback(
|
||
|
bridge: MCPWebSocketBridge,
|
||
|
) -> Tuple[Any, Dict[str, asyncio.Future[ElicitResult | ErrorData]]]:
|
||
|
"""Create an elicitation callback that handles user input requests from tools"""
|
||
|
|
||
|
async def elicitation_callback(
|
||
|
context: RequestContext[Any, Any, Any],
|
||
|
params: ElicitRequestParams,
|
||
|
) -> ElicitResult | ErrorData:
|
||
|
try:
|
||
|
request_id = str(uuid.uuid4())
|
||
|
|
||
|
await bridge.on_mcp_activity(
|
||
|
"elicitation",
|
||
|
f"Tool requesting user input: {params.message}",
|
||
|
{
|
||
|
"request_id": request_id,
|
||
|
"message": params.message,
|
||
|
"requestedSchema": serialize_for_json(params.requestedSchema) if params.requestedSchema else None,
|
||
|
"context": "Tool is requesting additional information from user",
|
||
|
},
|
||
|
)
|
||
|
|
||
|
await bridge.on_elicitation_request(
|
||
|
request_id,
|
||
|
params.message,
|
||
|
serialize_for_json(params.requestedSchema) if params.requestedSchema else None,
|
||
|
)
|
||
|
|
||
|
response_future: asyncio.Future[ElicitResult | ErrorData] = asyncio.Future()
|
||
|
bridge.pending_elicitations[request_id] = response_future
|
||
|
|
||
|
try:
|
||
|
user_response = await asyncio.wait_for(response_future, timeout=60.0)
|
||
|
|
||
|
await bridge.on_mcp_activity(
|
||
|
"elicitation",
|
||
|
"User responded to elicitation request",
|
||
|
{"request_id": request_id, "response": serialize_for_json(user_response), "status": "completed"},
|
||
|
)
|
||
|
|
||
|
return user_response
|
||
|
|
||
|
except asyncio.TimeoutError:
|
||
|
logger.warning(f"User did not respond to elicitation request {request_id} within 60 seconds")
|
||
|
error_msg = "User did not respond to elicitation request within 60 seconds"
|
||
|
|
||
|
await bridge.on_mcp_activity(
|
||
|
"error",
|
||
|
f"Elicitation timeout: {error_msg}",
|
||
|
{"request_id": request_id, "error": error_msg, "timeout": 60},
|
||
|
)
|
||
|
|
||
|
return ErrorData(code=-32603, message=error_msg)
|
||
|
|
||
|
finally:
|
||
|
bridge.pending_elicitations.pop(request_id, None)
|
||
|
|
||
|
except Exception as e:
|
||
|
error_msg = extract_real_error(e)
|
||
|
logger.error(f"Error in elicitation callback: {error_msg}")
|
||
|
|
||
|
await bridge.on_mcp_activity("error", f"Elicitation callback error: {error_msg}", {"error": error_msg})
|
||
|
|
||
|
return ErrorData(code=-32603, message=f"Elicitation failed: {error_msg}")
|
||
|
|
||
|
return elicitation_callback, bridge.pending_elicitations
|