mirror of
https://github.com/microsoft/autogen.git
synced 2025-08-02 22:02:21 +00:00

Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
216 lines
8.0 KiB
Python
216 lines
8.0 KiB
Python
import asyncio
|
|
import json
|
|
from datetime import datetime, timezone
|
|
from typing import Any, Dict, Optional
|
|
|
|
from fastapi import WebSocket
|
|
from loguru import logger
|
|
from mcp import ClientSession
|
|
from mcp.types import ElicitResult, ErrorData
|
|
|
|
from .client import MCPClient, MCPEventHandler
|
|
from .utils import extract_real_error, is_websocket_disconnect, serialize_for_json
|
|
|
|
|
|
class MCPWebSocketBridge(MCPEventHandler):
    """Bridges WebSocket connections to MCP operations.

    The bridge plays two roles:

    * As an ``MCPEventHandler`` it forwards events (initialization, operation
      results/errors, activity, elicitation requests) to the browser as JSON
      messages over the WebSocket.
    * As a message loop (``run``) it reads JSON messages from the WebSocket
      and dispatches them by their ``"type"`` field (``operation``, ``ping``,
      ``elicitation_response``).
    """

    def __init__(self, websocket: WebSocket, session_id: str):
        """Create a bridge for one WebSocket connection.

        Args:
            websocket: The WebSocket used for all traffic of this session.
            session_id: Identifier attached to outgoing activity/elicitation
                messages and used to update ``active_sessions``.
        """
        self.websocket = websocket
        self.session_id = session_id
        # Set later through set_mcp_client(); operations are rejected until then.
        self.mcp_client: Optional[MCPClient] = None
        # request_id -> future resolved when the user answers the elicitation.
        self.pending_elicitations: Dict[str, asyncio.Future[ElicitResult | ErrorData]] = {}
        self._running = True
        # Strong references to in-flight operation tasks. The event loop only
        # keeps weak references, so an unreferenced task may be garbage
        # collected before it finishes.
        self._background_tasks: set[asyncio.Task[Any]] = set()

    @staticmethod
    def _utc_timestamp() -> str:
        """Return the current UTC time as an ISO-8601 string."""
        return datetime.now(timezone.utc).isoformat()

    def _on_background_task_done(self, task: "asyncio.Task[Any]") -> None:
        """Release a finished operation task and log an unhandled failure."""
        self._background_tasks.discard(task)
        if task.cancelled():
            return
        exc = task.exception()
        if exc is not None:
            logger.error(f"Background MCP operation failed: {exc!r}")

    async def send_message(self, message: Dict[str, Any]) -> None:
        """Send a message through the WebSocket.

        The message is passed through ``serialize_for_json`` and only sent
        while the client side of the socket is in the CONNECTED state; in any
        other state the message is dropped. Errors are logged, never raised.

        Args:
            message: The payload to send.
        """
        try:
            from fastapi.websockets import WebSocketState

            if self.websocket.client_state == WebSocketState.CONNECTED:
                serialized_message = serialize_for_json(message)
                await self.websocket.send_json(serialized_message)
        except Exception as e:
            # Best effort: a failed send must not break the caller.
            real_error = extract_real_error(e)
            logger.error(f"Error sending WebSocket message: {real_error}")

    # Implement MCPEventHandler interface
    async def on_initialized(self, session_id: str, capabilities: Any) -> None:
        """Forward an ``initialized`` event with the given capabilities."""
        await self.send_message(
            {
                "type": "initialized",
                "session_id": session_id,
                "capabilities": capabilities,
                "timestamp": self._utc_timestamp(),
            }
        )

    async def on_operation_result(self, operation: str, data: Dict[str, Any]) -> None:
        """Forward the result ``data`` of a completed ``operation``."""
        await self.send_message(
            {
                "type": "operation_result",
                "operation": operation,
                "data": data,
                "timestamp": self._utc_timestamp(),
            }
        )

    async def on_operation_error(self, operation: str, error: str) -> None:
        """Forward an ``error`` description for a failed ``operation``."""
        await self.send_message(
            {
                "type": "operation_error",
                "operation": operation,
                "error": error,
                "timestamp": self._utc_timestamp(),
            }
        )

    async def on_mcp_activity(self, activity_type: str, message: str, details: Dict[str, Any]) -> None:
        """Forward an activity notification tagged with this bridge's session id."""
        await self.send_message(
            {
                "type": "mcp_activity",
                "activity_type": activity_type,
                "message": message,
                "details": details,
                "session_id": self.session_id,
                "timestamp": self._utc_timestamp(),
            }
        )

    async def on_elicitation_request(self, request_id: str, message: str, requested_schema: Any) -> None:
        """Forward an elicitation request so the user can answer it.

        The answer is expected back as an ``elicitation_response`` message
        carrying the same ``request_id``.
        """
        await self.send_message(
            {
                "type": "elicitation_request",
                "request_id": request_id,
                "message": message,
                "requestedSchema": requested_schema,
                "session_id": self.session_id,
                "timestamp": self._utc_timestamp(),
            }
        )

    def set_mcp_client(self, mcp_client: MCPClient) -> None:
        """Set the MCP client after initialization."""
        self.mcp_client = mcp_client

    async def handle_websocket_message(self, message: Dict[str, Any]) -> None:
        """Handle one incoming WebSocket message.

        Dispatches on ``message["type"]``:

        * ``"operation"``: handed to the MCP client in a background task.
        * ``"ping"``: answered with a ``pong``.
        * ``"elicitation_response"``: resolves the matching pending future.
        * anything else: answered with an ``error`` message.

        Args:
            message: The decoded JSON object received from the client.
        """
        message_type = message.get("type")

        # Update last activity for session tracking.
        # Imported here rather than at module level; presumably this avoids a
        # circular import with the routes module - verify before hoisting.
        from ..web.routes.mcp import active_sessions

        if self.session_id in active_sessions:
            active_sessions[self.session_id]["last_activity"] = datetime.now(timezone.utc)

        if message_type == "operation":
            # CRITICAL: Run in background task to avoid blocking message loop
            if self.mcp_client:
                task = asyncio.create_task(self.mcp_client.handle_operation(message))
                # Keep a strong reference until the task completes so it cannot
                # be garbage collected mid-execution.
                self._background_tasks.add(task)
                task.add_done_callback(self._on_background_task_done)
            else:
                await self.send_message(
                    {
                        "type": "error",
                        "error": "MCP client not initialized",
                        "timestamp": self._utc_timestamp(),
                    }
                )

        elif message_type == "ping":
            await self.send_message({"type": "pong", "timestamp": self._utc_timestamp()})

        elif message_type == "elicitation_response":
            await self._handle_elicitation_response(message)

        else:
            await self.send_message(
                {
                    "type": "error",
                    "error": f"Unknown message type: {message_type}",
                    "timestamp": self._utc_timestamp(),
                }
            )

    async def _handle_elicitation_response(self, message: Dict[str, Any]) -> None:
        """Handle user response to elicitation request.

        Reads ``request_id``, ``action`` (defaults to ``"cancel"``) and
        ``data`` (defaults to ``{}``) from ``message`` and resolves the
        matching future in ``pending_elicitations``. Any action other than
        ``"accept"`` or ``"decline"`` is treated as a cancel.

        Args:
            message: The decoded ``elicitation_response`` message.
        """
        request_id = message.get("request_id")

        if not request_id:
            await self.send_message(
                {
                    "type": "error",
                    "error": "Missing request_id in elicitation response",
                    "timestamp": self._utc_timestamp(),
                }
            )
            return

        if request_id not in self.pending_elicitations:
            logger.warning(f"Unknown elicitation request_id: {request_id}")
            await self.send_message(
                {
                    "type": "operation_error",
                    "error": f"Unknown elicitation request_id: {request_id}",
                    "timestamp": self._utc_timestamp(),
                }
            )
            return

        try:
            action = message.get("action", "cancel")
            data = message.get("data", {})

            if action == "accept":
                result = ElicitResult(action="accept", content=data)
            elif action == "decline":
                result = ElicitResult(action="decline")
            else:
                result = ElicitResult(action="cancel")

            future = self.pending_elicitations[request_id]
            if not future.done():
                future.set_result(result)
            else:
                logger.warning(f"Future for elicitation request {request_id} was already done")

        except Exception as e:
            error_msg = extract_real_error(e)
            logger.error(f"Error processing elicitation response: {error_msg}")

            # Resolve the waiter with an error so it is not left pending.
            future = self.pending_elicitations.get(request_id)
            if future and not future.done():
                future.set_result(
                    ErrorData(code=-32603, message=f"Error processing elicitation response: {error_msg}")
                )

    async def run(self) -> None:
        """Main message loop.

        Receives text frames until ``stop()`` clears the running flag or an
        exception ends the loop. Frames that are not valid JSON, or whose JSON
        value is not an object, are answered with an ``error`` message and the
        loop continues.

        Raises:
            Exception: Whatever ends the loop is re-raised to the caller.
        """
        try:
            while self._running:
                try:
                    raw_message = await self.websocket.receive_text()
                    message = json.loads(raw_message)
                except json.JSONDecodeError:
                    logger.warning(f"Invalid JSON received from session {self.session_id}")
                    await self.send_message(
                        {
                            "type": "error",
                            "error": "Invalid message format",
                            "timestamp": self._utc_timestamp(),
                        }
                    )
                    continue

                # Valid JSON that is not an object (list, string, number, ...)
                # cannot be dispatched; reject it instead of crashing the loop
                # on ``message.get``.
                if not isinstance(message, dict):
                    logger.warning(f"Non-object JSON message received from session {self.session_id}")
                    await self.send_message(
                        {
                            "type": "error",
                            "error": "Invalid message format",
                            "timestamp": self._utc_timestamp(),
                        }
                    )
                    continue

                await self.handle_websocket_message(message)

        except Exception as e:
            # Disconnects are expected, so only other failures are logged.
            # NOTE(review): the exception is re-raised in both cases; confirm
            # the caller expects disconnects to propagate.
            if not is_websocket_disconnect(e):
                real_error = extract_real_error(e)
                logger.error(f"Error in message loop: {real_error}")
            raise

    def stop(self) -> None:
        """Stop the bridge.

        Clears the running flag checked at the top of each ``run`` iteration.
        """
        self._running = False