98 lines
3.8 KiB
Python
Raw Normal View History

# api/ws.py
import asyncio
import json
from datetime import datetime, timezone
from uuid import UUID

from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect
from loguru import logger

from ...datamodel import Run, RunStatus
from ..deps import get_db, get_websocket_manager
from ..managers import WebSocketManager
# Router carrying the WebSocket endpoint(s) defined in this module.
router = APIRouter()

# Strong references to fire-and-forget stream tasks. The asyncio event loop
# only keeps weak references to tasks, so a task whose result is discarded can
# be garbage-collected before it finishes. Tasks remove themselves when done.
_background_tasks: "set[asyncio.Task]" = set()

# Upper bound on how much of an untrusted client payload is echoed into logs.
_MAX_LOGGED_MESSAGE_CHARS = 200


def _utc_now_iso() -> str:
    """Return the current UTC time as an ISO-8601 string with an explicit offset."""
    return datetime.now(timezone.utc).isoformat()


def _on_stream_task_done(task: "asyncio.Task") -> None:
    """Release a finished stream task and log any exception it raised.

    Retrieving the exception here also prevents asyncio's
    "Task exception was never retrieved" warning at garbage collection.
    """
    _background_tasks.discard(task)
    if task.cancelled():
        return
    exc = task.exception()
    if exc is not None:
        logger.opt(exception=exc).error(f"Background stream task {task.get_name()} failed")


async def _send_error(websocket: WebSocket, error: str) -> None:
    """Send an ``{"type": "error"}`` frame carrying *error* and a UTC timestamp."""
    await websocket.send_json({"type": "error", "error": error, "timestamp": _utc_now_iso()})


@router.websocket("/runs/{run_id}")
async def run_websocket(
    websocket: WebSocket,
    run_id: UUID,
    ws_manager: WebSocketManager = Depends(get_websocket_manager),
    db=Depends(get_db),
):
    """WebSocket endpoint for run communication.

    Before registering the socket with ``ws_manager``, the endpoint checks that
    the run exists (otherwise it closes with code 4004) and that its status is
    ``CREATED`` or ``ACTIVE`` (otherwise code 4003). If ``ws_manager.connect``
    reports failure, it closes with code 4002.

    Once connected, it reads JSON text frames in a loop:

    * ``{"type": "start", "task": ..., "team_config": ...}`` schedules
      ``ws_manager.start_stream`` in the background. If either field is
      missing, an error frame is sent instead.
    * ``{"type": "stop", "reason": ...}`` calls ``ws_manager.stop_run`` and
      ends the loop.
    * ``{"type": "ping"}`` is answered with ``{"type": "pong"}``.
    * ``{"type": "input_response", "response": ...}`` forwards the response
      to ``ws_manager.handle_input_response``.

    Malformed frames (invalid JSON, or JSON that is not an object) get an
    error frame and the connection stays open. ``ws_manager.disconnect`` is
    always called once the loop exits.

    Timestamps in outgoing frames are timezone-aware UTC ISO-8601 strings
    (they carry an explicit ``+00:00`` offset).
    """
    # Verify run exists and is in valid state.
    # NOTE(review): db.get looks synchronous; if it does blocking I/O it stalls
    # the event loop for every connecting client -- confirm.
    run_response = db.get(Run, filters={"id": run_id}, return_json=False)
    if not run_response.status or not run_response.data:
        logger.warning(f"Run not found: {run_id}")
        await websocket.close(code=4004, reason="Run not found")
        return

    run = run_response.data[0]
    if run.status not in [RunStatus.CREATED, RunStatus.ACTIVE]:
        await websocket.close(code=4003, reason="Run not in valid state")
        return

    # Connect websocket
    connected = await ws_manager.connect(websocket, run_id)
    if not connected:
        await websocket.close(code=4002, reason="Failed to establish connection")
        return

    try:
        logger.info(f"WebSocket connection established for run {run_id}")
        while True:
            # WebSocketDisconnect raised here propagates to the outer handler.
            raw_message = await websocket.receive_text()
            try:
                message = json.loads(raw_message)
            except json.JSONDecodeError:
                logger.warning(f"Invalid JSON received: {raw_message[:_MAX_LOGGED_MESSAGE_CHARS]}")
                await _send_error(websocket, "Invalid message format")
                continue

            # Valid JSON that is not an object (e.g. `[]`, `"x"`, `5`) has no
            # .get(). Reject it here instead of letting AttributeError tear
            # down the whole connection.
            if not isinstance(message, dict):
                logger.warning(f"Non-object JSON message received for run {run_id}")
                await _send_error(websocket, "Invalid message format")
                continue

            message_type = message.get("type")
            if message_type == "start":
                logger.info(f"Received start request for run {run_id}")
                task = message.get("task")
                team_config = message.get("team_config")
                if task and team_config:
                    # Run the stream in the background so this loop can keep
                    # handling stop/ping/input_response frames meanwhile. Keep
                    # a strong reference until the task completes.
                    stream_task = asyncio.create_task(
                        ws_manager.start_stream(run_id, task, team_config),
                        name=f"start_stream:{run_id}",
                    )
                    _background_tasks.add(stream_task)
                    stream_task.add_done_callback(_on_stream_task_done)
                else:
                    logger.warning(f"Invalid start message format for run {run_id}")
                    await _send_error(websocket, "Invalid start message format")
            elif message_type == "stop":
                logger.info(f"Received stop request for run {run_id}")
                reason = message.get("reason") or "User requested stop/cancellation"
                await ws_manager.stop_run(run_id, reason=reason)
                break
            elif message_type == "ping":
                await websocket.send_json({"type": "pong", "timestamp": _utc_now_iso()})
            elif message_type == "input_response":
                response = message.get("response")
                if response is not None:
                    await ws_manager.handle_input_response(run_id, response)
                else:
                    logger.warning(f"Invalid input response format for run {run_id}")
            else:
                logger.warning(f"Unknown message type {message_type!r} for run {run_id}")
    except WebSocketDisconnect:
        logger.info(f"WebSocket disconnected for run {run_id}")
    except Exception:
        # Top-level boundary for this connection: record the full traceback.
        logger.exception(f"WebSocket error for run {run_id}")
    finally:
        await ws_manager.disconnect(run_id)