# Mirror of https://github.com/microsoft/autogen.git
# Synced 2025-07-05 16:10:50 +00:00 — 75 lines, 2.5 KiB, Python
# api/ws.py
|
||
|
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends, HTTPException
|
||
|
from typing import Dict
|
||
|
from uuid import UUID
|
||
|
import logging
|
||
|
import json
|
||
|
from datetime import datetime
|
||
|
|
||
|
from ..deps import get_websocket_manager, get_db, get_team_manager
|
||
|
from ...datamodel import Run, RunStatus
|
||
|
|
||
|
# Logger for this module, named after its dotted import path.
logger = logging.getLogger(__name__)

# Router that collects this module's WebSocket endpoints.
router = APIRouter()
|
||
|
|
||
|
|
||
|
async def _send_error(websocket: WebSocket, error: str) -> None:
    """Send an ``error`` frame to the client without closing the socket."""
    await websocket.send_json({
        "type": "error",
        "error": error,
        # Naive UTC timestamp (no offset suffix), same format as the pong frame.
        "timestamp": datetime.utcnow().isoformat()
    })


@router.websocket("/runs/{run_id}")
async def run_websocket(
    websocket: WebSocket,
    run_id: UUID,
    ws_manager=Depends(get_websocket_manager),
    db=Depends(get_db),
    team_manager=Depends(get_team_manager)
):
    """WebSocket endpoint for run communication.

    Looks up the run by ``run_id`` and closes the socket immediately when it
    cannot be used:

    * 4004 -- the run was not found,
    * 4003 -- the run is not CREATED or ACTIVE,
    * 4002 -- ``ws_manager.connect`` reported failure.

    Otherwise it reads text frames from the client until the client
    disconnects or sends a stop request:

    * ``{"type": "stop"}`` -- calls ``ws_manager.stop_run(run_id)`` and ends
      the loop.
    * ``{"type": "ping"}`` -- answered with a ``pong`` frame carrying a
      timestamp.
    * Frames that are not a JSON object get an ``error`` frame ("Invalid
      message format") and the connection stays open.
    * Any other message type is ignored.

    ``ws_manager.disconnect(run_id)`` always runs once the loop exits.
    ``team_manager`` is resolved as a dependency but is not used by this
    handler.
    """
    # Verify run exists and is in valid state
    run_response = db.get(Run, filters={"id": run_id}, return_json=False)
    if not run_response.status or not run_response.data:
        await websocket.close(code=4004, reason="Run not found")
        return

    run = run_response.data[0]
    if run.status not in [RunStatus.CREATED, RunStatus.ACTIVE]:
        await websocket.close(code=4003, reason="Run not in valid state")
        return

    # Connect websocket
    connected = await ws_manager.connect(websocket, run_id)
    if not connected:
        await websocket.close(code=4002, reason="Failed to establish connection")
        return

    try:
        logger.info("WebSocket connection established for run %s", run_id)

        while True:
            raw_message = await websocket.receive_text()

            # Only the decode step can raise JSONDecodeError; keep the try narrow.
            try:
                message = json.loads(raw_message)
            except json.JSONDecodeError:
                logger.warning("Invalid JSON received: %r", raw_message)
                await _send_error(websocket, "Invalid message format")
                continue

            # Valid JSON is not necessarily an object ("[]", "42", "\"x\"").
            # Calling .get() on those would raise AttributeError and drop the
            # connection, so answer with the same error frame instead.
            if not isinstance(message, dict):
                logger.warning(
                    "Non-object JSON message received for run %s: %r",
                    run_id,
                    raw_message,
                )
                await _send_error(websocket, "Invalid message format")
                continue

            message_type = message.get("type")
            if message_type == "stop":
                logger.info("Received stop request for run %s", run_id)
                await ws_manager.stop_run(run_id)
                break
            elif message_type == "ping":
                await websocket.send_json({
                    "type": "pong",
                    "timestamp": datetime.utcnow().isoformat()
                })

    except WebSocketDisconnect:
        logger.info("WebSocket disconnected for run %s", run_id)
    except Exception:
        # logger.exception keeps the traceback that logger.error(str(e)) dropped.
        logger.exception("WebSocket error for run %s", run_id)
    finally:
        await ws_manager.disconnect(run_id)