Victor Dibia b2cef7f47c
Update AGS (Support Workbenches ++) (#6736)
Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
2025-07-16 10:03:02 -07:00

187 lines
6.8 KiB
Python

import asyncio
import uuid
from datetime import datetime, timezone
from typing import Any, Dict, Tuple
from loguru import logger
from mcp.shared.context import RequestContext
from mcp.shared.session import RequestResponder
from mcp.types import (
ClientResult,
CreateMessageRequestParams,
CreateMessageResult,
ElicitRequestParams,
ElicitResult,
ErrorData,
ServerNotification,
ServerRequest,
TextContent,
)
from .utils import extract_real_error, serialize_for_json
from .wsbridge import MCPWebSocketBridge
def create_message_handler(bridge: MCPWebSocketBridge):
"""Create a message handler callback that streams MCP protocol messages to the UI"""
async def message_handler(
message: RequestResponder[ServerRequest, ClientResult] | ServerNotification | Exception,
) -> None:
try:
if isinstance(message, Exception):
await bridge.on_mcp_activity(
"error", f"Protocol error: {str(message)}", {"details": extract_real_error(message)}
)
elif hasattr(message, "method"):
method = getattr(message, "method", "unknown")
params = getattr(message, "params", None)
await bridge.on_mcp_activity(
"protocol",
f"MCP {method}",
{
"method": method,
"params": serialize_for_json(params) if params else None,
"message_type": type(message).__name__,
},
)
else:
await bridge.on_mcp_activity(
"protocol",
f"{type(message).__name__}",
{
"message_type": type(message).__name__,
"content": serialize_for_json(message) if hasattr(message, "model_dump") else str(message),
},
)
except Exception as e:
logger.error(f"Error in message handler: {extract_real_error(e)}")
return message_handler
def create_sampling_callback(bridge: MCPWebSocketBridge):
"""Create a sampling callback that handles AI sampling requests from tools"""
async def sampling_callback(
context: RequestContext[Any, Any, Any],
params: CreateMessageRequestParams,
) -> CreateMessageResult | ErrorData:
try:
request_id = str(uuid.uuid4())
await bridge.on_mcp_activity(
"sampling",
f"Tool requested AI sampling for {len(params.messages)} message(s)",
{
"request_id": request_id,
"params": serialize_for_json(params.model_dump()),
"context": "Tool is requesting AI to generate a response",
},
)
dummy_response = CreateMessageResult(
role="assistant",
content=TextContent(
type="text",
text="[AutoGen Studio Default Sampling Response - This is a placeholder response for AI sampling requests. In a production setup, this would be handled by your configured LLM.]",
),
model="autogen-studio-default",
)
await bridge.on_mcp_activity(
"sampling",
"Provided default sampling response to tool",
{
"request_id": request_id,
"response": serialize_for_json(dummy_response.model_dump()),
"note": "This is a placeholder response - configure an LLM for real sampling",
},
)
logger.info("Handled sampling request with default response")
return dummy_response
except Exception as e:
error_msg = extract_real_error(e)
logger.error(f"Error in sampling callback: {error_msg}")
await bridge.on_mcp_activity("error", f"Sampling callback error: {error_msg}", {"error": error_msg})
return ErrorData(code=-32603, message=f"Sampling failed: {error_msg}")
return sampling_callback
def create_elicitation_callback(
bridge: MCPWebSocketBridge,
) -> Tuple[Any, Dict[str, asyncio.Future[ElicitResult | ErrorData]]]:
"""Create an elicitation callback that handles user input requests from tools"""
async def elicitation_callback(
context: RequestContext[Any, Any, Any],
params: ElicitRequestParams,
) -> ElicitResult | ErrorData:
try:
request_id = str(uuid.uuid4())
await bridge.on_mcp_activity(
"elicitation",
f"Tool requesting user input: {params.message}",
{
"request_id": request_id,
"message": params.message,
"requestedSchema": serialize_for_json(params.requestedSchema) if params.requestedSchema else None,
"context": "Tool is requesting additional information from user",
},
)
await bridge.on_elicitation_request(
request_id,
params.message,
serialize_for_json(params.requestedSchema) if params.requestedSchema else None,
)
response_future: asyncio.Future[ElicitResult | ErrorData] = asyncio.Future()
bridge.pending_elicitations[request_id] = response_future
try:
user_response = await asyncio.wait_for(response_future, timeout=60.0)
await bridge.on_mcp_activity(
"elicitation",
"User responded to elicitation request",
{"request_id": request_id, "response": serialize_for_json(user_response), "status": "completed"},
)
return user_response
except asyncio.TimeoutError:
logger.warning(f"User did not respond to elicitation request {request_id} within 60 seconds")
error_msg = "User did not respond to elicitation request within 60 seconds"
await bridge.on_mcp_activity(
"error",
f"Elicitation timeout: {error_msg}",
{"request_id": request_id, "error": error_msg, "timeout": 60},
)
return ErrorData(code=-32603, message=error_msg)
finally:
bridge.pending_elicitations.pop(request_id, None)
except Exception as e:
error_msg = extract_real_error(e)
logger.error(f"Error in elicitation callback: {error_msg}")
await bridge.on_mcp_activity("error", f"Elicitation callback error: {error_msg}", {"error": error_msg})
return ErrorData(code=-32603, message=f"Elicitation failed: {error_msg}")
return elicitation_callback, bridge.pending_elicitations