# api/ws.py import asyncio import json from datetime import datetime from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect from loguru import logger from ...datamodel import Run, RunStatus from ..deps import get_db, get_websocket_manager from ..managers import WebSocketManager router = APIRouter() @router.websocket("/runs/{run_id}") async def run_websocket( websocket: WebSocket, run_id: int, ws_manager: WebSocketManager = Depends(get_websocket_manager), db=Depends(get_db), ): """WebSocket endpoint for run communication""" # Verify run exists and is in valid state run_response = db.get(Run, filters={"id": run_id}, return_json=False) if not run_response.status or not run_response.data: logger.warning(f"Run not found: {run_id}") await websocket.close(code=4004, reason="Run not found") return run = run_response.data[0] if run.status not in [RunStatus.CREATED, RunStatus.ACTIVE]: await websocket.close(code=4003, reason="Run not in valid state") return # Connect websocket connected = await ws_manager.connect(websocket, run_id) if not connected: await websocket.close(code=4002, reason="Failed to establish connection") return try: logger.info(f"WebSocket connection established for run {run_id}") while True: try: raw_message = await websocket.receive_text() message = json.loads(raw_message) if message.get("type") == "start": # Handle start message logger.info(f"Received start request for run {run_id}") task = message.get("task") team_config = message.get("team_config") if task and team_config: # await ws_manager.start_stream(run_id, task, team_config) asyncio.create_task(ws_manager.start_stream(run_id, task, team_config)) else: logger.warning(f"Invalid start message format for run {run_id}") await websocket.send_json( { "type": "error", "error": "Invalid start message format", "timestamp": datetime.utcnow().isoformat(), } ) elif message.get("type") == "stop": logger.info(f"Received stop request for run {run_id}") reason = message.get("reason") or "User requested stop/cancellation" await ws_manager.stop_run(run_id, reason=reason) break elif message.get("type") == "ping": await websocket.send_json({"type": "pong", "timestamp": datetime.utcnow().isoformat()}) elif message.get("type") == "input_response": # Handle input response from client response = message.get("response") if response is not None: await ws_manager.handle_input_response(run_id, response) else: logger.warning(f"Invalid input response format for run {run_id}") except json.JSONDecodeError: logger.warning(f"Invalid JSON received: {raw_message}") await websocket.send_json( {"type": "error", "error": "Invalid message format", "timestamp": datetime.utcnow().isoformat()} ) except WebSocketDisconnect: logger.info(f"WebSocket disconnected for run {run_id}") except Exception as e: logger.error(f"WebSocket error: {str(e)}") finally: await ws_manager.disconnect(run_id)