Victor Dibia b2cef7f47c
Update AGS (Support Workbenches ++) (#6736)
Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
2025-07-16 10:03:02 -07:00

241 lines
8.9 KiB
Python

import base64
import json
import uuid
from datetime import datetime, timezone
from typing import Any, Dict
from autogen_ext.tools.mcp._config import (
McpServerParams,
SseServerParams,
StdioServerParams,
StreamableHttpServerParams,
)
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from loguru import logger
from mcp import ClientSession, StdioServerParameters
from mcp.client.sse import sse_client
from mcp.client.stdio import stdio_client
from mcp.client.streamable_http import streamablehttp_client
from pydantic import BaseModel
from ...mcp.callbacks import (
create_elicitation_callback,
create_message_handler,
create_sampling_callback,
)
from ...mcp.client import MCPClient
from ...mcp.utils import extract_real_error, is_websocket_disconnect, serialize_for_json
from ...mcp.wsbridge import MCPWebSocketBridge
# Router that the WebSocket, connect and status endpoints in this module register on.
router = APIRouter()
# Global session tracking for the status endpoint: maps a session_id to a dict holding
# "created_at", "last_activity" and "capabilities" for each session that is currently running.
active_sessions: Dict[str, Dict[str, Any]] = {}
# Request body accepted by the POST "/ws/connect" endpoint.
class CreateWebSocketConnectionRequest(BaseModel):
    # Parameters of the MCP server to connect to; they are serialized into the
    # WebSocket URL that the connect endpoint returns.
    server_params: McpServerParams
async def _run_mcp_client_session(
    read: Any,
    write: Any,
    bridge: MCPWebSocketBridge,
    session_id: str,
    message_handler: Any,
    sampling_callback: Any,
    elicitation_callback: Any,
) -> None:
    """Run one MCP client session over an already-opened transport.

    Opens a ``ClientSession`` on the given streams, attaches an ``MCPClient``
    to the bridge, initializes it, records the session in ``active_sessions``
    and then runs the bridge message loop until it returns.

    Args:
        read: Read stream yielded by the transport context manager.
        write: Write stream yielded by the transport context manager.
        bridge: WebSocket bridge that the MCP client is attached to.
        session_id: Key under which the session is stored in ``active_sessions``.
        message_handler: Message handler passed to ``ClientSession``.
        sampling_callback: Sampling callback passed to ``ClientSession``.
        elicitation_callback: Elicitation callback passed to ``ClientSession``.
    """
    async with ClientSession(
        read,
        write,
        message_handler=message_handler,
        sampling_callback=sampling_callback,
        elicitation_callback=elicitation_callback,
    ) as session:
        mcp_client = MCPClient(session, session_id, bridge)
        bridge.set_mcp_client(mcp_client)

        # Initialize before publishing the session so capabilities are available.
        await mcp_client.initialize()

        # Store session info for the status endpoint.
        active_sessions[session_id] = {
            "created_at": datetime.now(timezone.utc),
            "last_activity": datetime.now(timezone.utc),
            "capabilities": serialize_for_json(mcp_client.capabilities.model_dump())
            if mcp_client.capabilities
            else None,
        }

        # Run the bridge message loop; this returns when the bridge stops.
        await bridge.run()


async def create_mcp_session(bridge: MCPWebSocketBridge, server_params: McpServerParams, session_id: str) -> None:
    """Create an MCP session based on the server parameters and run it.

    Selects the transport (stdio, SSE or streamable HTTP) from the concrete
    type of ``server_params``, opens it, and runs the session on top of it.
    The coroutine only returns once the bridge message loop has finished.

    Args:
        bridge: WebSocket bridge used to build the callbacks and to run the loop.
        server_params: One of ``StdioServerParams``, ``SseServerParams`` or
            ``StreamableHttpServerParams``.
        session_id: Identifier under which the session is tracked.

    Raises:
        ValueError: If ``server_params`` is not one of the supported types.
    """
    # Create callbacks using the bridge (done up front, before any transport is opened).
    message_handler = create_message_handler(bridge)
    sampling_callback = create_sampling_callback(bridge)
    elicitation_callback, _ = create_elicitation_callback(bridge)

    async def run_on(read: Any, write: Any) -> None:
        # Every transport runs the exact same session logic once its streams exist.
        await _run_mcp_client_session(
            read,
            write,
            bridge,
            session_id,
            message_handler,
            sampling_callback,
            elicitation_callback,
        )

    if isinstance(server_params, StdioServerParams):
        stdio_params = StdioServerParameters(
            command=server_params.command, args=server_params.args, env=server_params.env
        )
        async with stdio_client(stdio_params) as (read, write):
            await run_on(read, write)
    elif isinstance(server_params, SseServerParams):
        # NOTE(review): only the URL is forwarded here; if the params object carries
        # headers or timeouts they are not applied - confirm whether that is intended.
        async with sse_client(server_params.url) as (read, write):
            await run_on(read, write)
    elif isinstance(server_params, StreamableHttpServerParams):
        # NOTE(review): as with SSE, only the URL is forwarded - confirm.
        # The third yielded value is not used by this module.
        async with streamablehttp_client(server_params.url) as (read, write, _unused):
            await run_on(read, write)
    else:
        raise ValueError(f"Unsupported server params type: {type(server_params)}")
@router.websocket("/ws/{session_id}")
async def mcp_websocket(websocket: WebSocket, session_id: str):
    """Main WebSocket endpoint - a thin layer over the MCP bridge.

    Accepts the socket, decodes the base64-encoded JSON ``server_params`` query
    parameter, builds the matching server-params object and hands control to
    ``create_mcp_session`` until the session ends. On exit the session is
    removed from ``active_sessions`` and the bridge (if any) is stopped.

    The socket is closed with code 4000 when ``server_params`` is missing,
    cannot be decoded/parsed, or does not name a supported params type.

    Args:
        websocket: The incoming WebSocket connection.
        session_id: Session identifier taken from the URL path.
    """
    await websocket.accept()
    logger.info(f"MCP WebSocket connection established for session {session_id}")

    bridge = None
    try:
        # Parse server parameters
        query_params = dict(websocket.query_params)
        server_params_encoded = query_params.get("server_params")
        if not server_params_encoded:
            await websocket.close(code=4000, reason="Missing server_params")
            return

        params_classes = {
            "StdioServerParams": StdioServerParams,
            "SseServerParams": SseServerParams,
            "StreamableHttpServerParams": StreamableHttpServerParams,
        }
        server_params = None
        try:
            # A literal '+' in a query string is decoded as a space; a space is never
            # valid base64, so mapping it back to '+' restores the original payload.
            decoded_params = base64.b64decode(server_params_encoded.replace(" ", "+")).decode("utf-8")
            server_params_dict = json.loads(decoded_params)
            # Create appropriate server params object
            if isinstance(server_params_dict, dict):
                params_cls = params_classes.get(server_params_dict.get("type"))
                if params_cls is not None:
                    server_params = params_cls(**server_params_dict)
        except (ValueError, TypeError) as parse_error:
            # binascii.Error, UnicodeDecodeError and json.JSONDecodeError are all
            # ValueError subclasses; treat any of them as bad client input.
            logger.warning(f"Invalid server_params for MCP session {session_id}: {parse_error}")

        if server_params is None:
            await websocket.close(code=4000, reason="Invalid server parameters")
            return

        # Create bridge and run MCP session
        bridge = MCPWebSocketBridge(websocket, session_id)
        await create_mcp_session(bridge, server_params, session_id)

    except WebSocketDisconnect:
        logger.info(f"MCP WebSocket session {session_id} disconnected normally")
    except Exception as e:
        real_error = extract_real_error(e)
        disconnected = is_websocket_disconnect(e)
        if disconnected:
            logger.info(f"MCP WebSocket session {session_id} disconnected (wrapped)")
        else:
            logger.error(f"MCP WebSocket error for session {session_id}: {real_error}")

        if bridge and not disconnected:
            try:
                await bridge.send_message(
                    {
                        "type": "error",
                        "error": f"Connection error: {real_error}",
                        "timestamp": datetime.now(timezone.utc).isoformat(),
                    }
                )
            except Exception as send_error:
                # Best effort only: the socket may already be unusable at this point.
                logger.debug(f"Could not report error to MCP session {session_id}: {send_error}")
    finally:
        # Cleanup
        session_info = active_sessions.pop(session_id, None)
        if session_info:
            duration = datetime.now(timezone.utc) - session_info["created_at"]
            logger.info(f"MCP session {session_id} ended after {duration.total_seconds():.2f} seconds")

        if bridge:
            bridge.stop()
@router.post("/ws/connect")
async def create_mcp_websocket_connection(request: CreateWebSocketConnectionRequest):
    """Create a WebSocket connection URL for the given MCP server parameters.

    Generates a fresh session id and embeds the server parameters (JSON, then
    base64) in the ``server_params`` query parameter of the returned URL.

    Args:
        request: Request body carrying the MCP server parameters.

    Returns:
        A dict with ``status``, ``message``, ``session_id``, ``websocket_url``
        and ``timestamp`` on success, or a dict with ``status`` False and a
        generic ``message`` if anything fails.
    """
    # Function-scope import keeps this block self-contained.
    from urllib.parse import quote

    try:
        session_id = str(uuid.uuid4())
        server_params_json = json.dumps(serialize_for_json(request.server_params.model_dump()))
        server_params_encoded = base64.b64encode(server_params_json.encode("utf-8")).decode("utf-8")
        # Standard base64 may contain '+', '/' and '='. An unescaped '+' in a query
        # string is read back as a space, so percent-encode the value for the URL.
        server_params_query = quote(server_params_encoded, safe="")

        return {
            "status": True,
            "message": "WebSocket connection URL created",
            "session_id": session_id,
            "websocket_url": f"/api/mcp/ws/{session_id}?server_params={server_params_query}",
            "timestamp": datetime.now(timezone.utc).isoformat(),
        }
    except Exception as e:
        # Log the details server-side; the client only gets a generic message.
        real_error = extract_real_error(e)
        logger.error(f"Error creating WebSocket connection: {real_error}")
        return {"status": False, "message": "An internal error occurred while creating the WebSocket connection."}
@router.get("/ws/status/{session_id}")
async def get_mcp_session_status(session_id: str):
    """Report whether an MCP session is currently tracked, with its metadata.

    Looks the session up in ``active_sessions``. For a tracked session the
    ``last_activity`` timestamp is refreshed and the stored capabilities and
    timestamps are returned; otherwise a "Session not found" payload is returned.
    """
    tracked = active_sessions.get(session_id)

    if tracked:
        # Querying the status counts as activity, so refresh the timestamp first.
        active_sessions[session_id]["last_activity"] = datetime.now(timezone.utc)

        capabilities = tracked.get("capabilities")
        created_at = tracked["created_at"].isoformat()
        last_activity = tracked["last_activity"].isoformat()
        return {
            "status": True,
            "message": "Session active",
            "session_id": session_id,
            "connected": True,
            "capabilities": capabilities,
            "created_at": created_at,
            "last_activity": last_activity,
        }

    return {"status": False, "message": "Session not found", "session_id": session_id}