mirror of
https://github.com/microsoft/autogen.git
synced 2025-08-01 21:32:38 +00:00

Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
187 lines
6.8 KiB
Python
187 lines
6.8 KiB
Python
import asyncio
|
|
import uuid
|
|
from datetime import datetime, timezone
|
|
from typing import Any, Dict, Tuple
|
|
|
|
from loguru import logger
|
|
from mcp.shared.context import RequestContext
|
|
from mcp.shared.session import RequestResponder
|
|
from mcp.types import (
|
|
ClientResult,
|
|
CreateMessageRequestParams,
|
|
CreateMessageResult,
|
|
ElicitRequestParams,
|
|
ElicitResult,
|
|
ErrorData,
|
|
ServerNotification,
|
|
ServerRequest,
|
|
TextContent,
|
|
)
|
|
|
|
from .utils import extract_real_error, serialize_for_json
|
|
from .wsbridge import MCPWebSocketBridge
|
|
|
|
|
|
def create_message_handler(bridge: MCPWebSocketBridge):
    """Build the callback that relays incoming MCP messages to the bridge.

    Args:
        bridge: Object whose ``on_mcp_activity`` coroutine is awaited once per
            message with an activity type, a short summary and a details dict.

    Returns:
        An async ``message_handler(message)`` callable. Any exception raised
        while handling a message is logged via ``logger.error`` and not
        re-raised.
    """

    async def message_handler(
        message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception,
    ) -> None:
        try:
            if isinstance(message, Exception):
                # An exception delivered as a message is reported as an error activity.
                activity_type = "error"
                summary = f"Protocol error: {str(message)}"
                details = {"details": extract_real_error(message)}
            elif hasattr(message, "method"):
                # Anything exposing ``method`` is reported under that method's name.
                method = getattr(message, "method", "unknown")
                params = getattr(message, "params", None)

                serialized_params = None
                if params:
                    serialized_params = serialize_for_json(params)

                activity_type = "protocol"
                summary = f"MCP {method}"
                details = {
                    "method": method,
                    "params": serialized_params,
                    "message_type": type(message).__name__,
                }
            else:
                # Fallback: report the message under its class name.
                type_name = type(message).__name__
                if hasattr(message, "model_dump"):
                    content = serialize_for_json(message)
                else:
                    content = str(message)

                activity_type = "protocol"
                summary = f"{type_name}"
                details = {
                    "message_type": type_name,
                    "content": content,
                }

            await bridge.on_mcp_activity(activity_type, summary, details)

        except Exception as e:
            logger.error(f"Error in message handler: {extract_real_error(e)}")

    return message_handler
|
|
|
|
|
|
def create_sampling_callback(bridge: MCPWebSocketBridge):
    """Build the callback that answers sampling requests with a placeholder.

    Args:
        bridge: Object whose ``on_mcp_activity`` coroutine is awaited to report
            the incoming request, the placeholder response, and any error.

    Returns:
        An async ``sampling_callback(context, params)`` callable. It returns a
        fixed ``CreateMessageResult`` placeholder, or an ``ErrorData`` with
        code ``-32603`` if anything raises while handling the request.
    """

    async def sampling_callback(
        context: RequestContext[Any, Any, Any],
        params: CreateMessageRequestParams,
    ) -> CreateMessageResult | ErrorData:
        try:
            # One id ties the request and response activity entries together.
            sampling_id = str(uuid.uuid4())

            request_details = {
                "request_id": sampling_id,
                "params": serialize_for_json(params.model_dump()),
                "context": "Tool is requesting AI to generate a response",
            }
            await bridge.on_mcp_activity(
                "sampling",
                f"Tool requested AI sampling for {len(params.messages)} message(s)",
                request_details,
            )

            # No model is called here; the reply is a fixed placeholder text.
            placeholder_text = TextContent(
                type="text",
                text="[AutoGen Studio Default Sampling Response - This is a placeholder response for AI sampling requests. In a production setup, this would be handled by your configured LLM.]",
            )
            placeholder = CreateMessageResult(
                role="assistant",
                content=placeholder_text,
                model="autogen-studio-default",
            )

            response_details = {
                "request_id": sampling_id,
                "response": serialize_for_json(placeholder.model_dump()),
                "note": "This is a placeholder response - configure an LLM for real sampling",
            }
            await bridge.on_mcp_activity(
                "sampling",
                "Provided default sampling response to tool",
                response_details,
            )

            logger.info("Handled sampling request with default response")
            return placeholder

        except Exception as e:
            failure = extract_real_error(e)
            logger.error(f"Error in sampling callback: {failure}")

            await bridge.on_mcp_activity("error", f"Sampling callback error: {failure}", {"error": failure})

            return ErrorData(code=-32603, message=f"Sampling failed: {failure}")

    return sampling_callback
|
|
|
|
|
|
def create_elicitation_callback(
    bridge: MCPWebSocketBridge,
    *,
    timeout: float = 60,
) -> Tuple[Any, Dict[str, asyncio.Future[ElicitResult | ErrorData]]]:
    """Create an elicitation callback that handles user input requests from tools.

    Args:
        bridge: Object providing the ``on_mcp_activity`` and
            ``on_elicitation_request`` coroutines and the
            ``pending_elicitations`` dict in which the awaiting future is
            stored under its request id.
        timeout: Seconds to wait for the pending future to be resolved before
            giving up. Defaults to 60.

    Returns:
        A tuple ``(elicitation_callback, bridge.pending_elicitations)``. The
        callback returns whatever the pending future was resolved with, or an
        ``ErrorData`` with code ``-32603`` on timeout or on any other failure.
    """

    async def elicitation_callback(
        context: RequestContext[Any, Any, Any],
        params: ElicitRequestParams,
    ) -> ElicitResult | ErrorData:
        try:
            request_id = str(uuid.uuid4())
            requested_schema = serialize_for_json(params.requestedSchema) if params.requestedSchema else None

            await bridge.on_mcp_activity(
                "elicitation",
                f"Tool requesting user input: {params.message}",
                {
                    "request_id": request_id,
                    "message": params.message,
                    "requestedSchema": requested_schema,
                    "context": "Tool is requesting additional information from user",
                },
            )

            # Register the future BEFORE the request is sent: sending yields to
            # the event loop, and a reply arriving in that window would
            # otherwise find no pending entry and be lost.
            response_future: asyncio.Future[ElicitResult | ErrorData] = (
                asyncio.get_running_loop().create_future()
            )
            bridge.pending_elicitations[request_id] = response_future

            try:
                await bridge.on_elicitation_request(request_id, params.message, requested_schema)

                try:
                    # Only the wait itself is treated as a user timeout, so a
                    # TimeoutError from the bridge is not misreported as one.
                    user_response = await asyncio.wait_for(response_future, timeout=timeout)
                except asyncio.TimeoutError:
                    logger.warning(
                        f"User did not respond to elicitation request {request_id} within {timeout:g} seconds"
                    )
                    error_msg = f"User did not respond to elicitation request within {timeout:g} seconds"

                    await bridge.on_mcp_activity(
                        "error",
                        f"Elicitation timeout: {error_msg}",
                        {"request_id": request_id, "error": error_msg, "timeout": timeout},
                    )

                    return ErrorData(code=-32603, message=error_msg)

                await bridge.on_mcp_activity(
                    "elicitation",
                    "User responded to elicitation request",
                    {"request_id": request_id, "response": serialize_for_json(user_response), "status": "completed"},
                )

                return user_response

            finally:
                # Always drop the entry (success, timeout, send failure or
                # cancellation) so no stale future is left in the dict.
                bridge.pending_elicitations.pop(request_id, None)

        except Exception as e:
            error_msg = extract_real_error(e)
            logger.error(f"Error in elicitation callback: {error_msg}")

            await bridge.on_mcp_activity("error", f"Elicitation callback error: {error_msg}", {"error": error_msg})

            return ErrorData(code=-32603, message=f"Elicitation failed: {error_msg}")

    return elicitation_callback, bridge.pending_elicitations
|