| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  | # api/ws.py | 
					
						
							| 
									
										
										
										
											2024-11-26 15:39:36 -08:00
										 |  |  | import asyncio | 
					
						
							| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  | import json | 
					
						
from datetime import datetime, timezone
					
						
							| 
									
										
										
										
											2024-11-26 15:39:36 -08:00
										 |  |  | 
 | 
					
						
							|  |  |  | from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect | 
					
						
							|  |  |  | from loguru import logger | 
					
						
							| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  | 
 | 
					
						
							|  |  |  | from ...datamodel import Run, RunStatus | 
					
						
							| 
									
										
										
										
											2024-11-26 15:39:36 -08:00
										 |  |  | from ..deps import get_db, get_websocket_manager | 
					
						
							| 
									
										
										
										
											2024-11-15 14:51:57 -08:00
										 |  |  | from ..managers import WebSocketManager | 
					
						
							| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  | 
 | 
					
						
# APIRouter collecting this module's WebSocket routes (registered via @router.websocket);
# presumably included into the FastAPI application elsewhere — confirm.
router = APIRouter()
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | @router.websocket("/runs/{run_id}") | 
					
						
							|  |  |  | async def run_websocket( | 
					
						
							|  |  |  |     websocket: WebSocket, | 
					
						
							| 
									
										
										
										
											2025-03-06 10:52:42 -08:00
										 |  |  |     run_id: int, | 
					
						
							| 
									
										
										
										
											2024-11-15 14:51:57 -08:00
										 |  |  |     ws_manager: WebSocketManager = Depends(get_websocket_manager), | 
					
						
							| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  |     db=Depends(get_db), | 
					
						
							|  |  |  | ): | 
					
						
							|  |  |  |     """WebSocket endpoint for run communication""" | 
					
						
							|  |  |  |     # Verify run exists and is in valid state | 
					
						
							|  |  |  |     run_response = db.get(Run, filters={"id": run_id}, return_json=False) | 
					
						
							|  |  |  |     if not run_response.status or not run_response.data: | 
					
						
							| 
									
										
										
										
											2024-11-26 15:39:36 -08:00
										 |  |  |         logger.warning(f"Run not found: {run_id}") | 
					
						
							| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  |         await websocket.close(code=4004, reason="Run not found") | 
					
						
							|  |  |  |         return | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     run = run_response.data[0] | 
					
						
							|  |  |  |     if run.status not in [RunStatus.CREATED, RunStatus.ACTIVE]: | 
					
						
							|  |  |  |         await websocket.close(code=4003, reason="Run not in valid state") | 
					
						
							|  |  |  |         return | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     # Connect websocket | 
					
						
							|  |  |  |     connected = await ws_manager.connect(websocket, run_id) | 
					
						
							|  |  |  |     if not connected: | 
					
						
							|  |  |  |         await websocket.close(code=4002, reason="Failed to establish connection") | 
					
						
							|  |  |  |         return | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     try: | 
					
						
							|  |  |  |         logger.info(f"WebSocket connection established for run {run_id}") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         while True: | 
					
						
							|  |  |  |             try: | 
					
						
							|  |  |  |                 raw_message = await websocket.receive_text() | 
					
						
							|  |  |  |                 message = json.loads(raw_message) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-11-26 15:39:36 -08:00
										 |  |  |                 if message.get("type") == "start": | 
					
						
							|  |  |  |                     # Handle start message | 
					
						
							|  |  |  |                     logger.info(f"Received start request for run {run_id}") | 
					
						
							|  |  |  |                     task = message.get("task") | 
					
						
							|  |  |  |                     team_config = message.get("team_config") | 
					
						
							|  |  |  |                     if task and team_config: | 
					
						
							|  |  |  |                         # await ws_manager.start_stream(run_id, task, team_config) | 
					
						
							|  |  |  |                         asyncio.create_task(ws_manager.start_stream(run_id, task, team_config)) | 
					
						
							|  |  |  |                     else: | 
					
						
							|  |  |  |                         logger.warning(f"Invalid start message format for run {run_id}") | 
					
						
							|  |  |  |                         await websocket.send_json( | 
					
						
							|  |  |  |                             { | 
					
						
							|  |  |  |                                 "type": "error", | 
					
						
							|  |  |  |                                 "error": "Invalid start message format", | 
					
						
							|  |  |  |                                 "timestamp": datetime.utcnow().isoformat(), | 
					
						
							|  |  |  |                             } | 
					
						
							|  |  |  |                         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 elif message.get("type") == "stop": | 
					
						
							|  |  |  |                     logger.info(f"Received stop request for run {run_id}") | 
					
						
							|  |  |  |                     reason = message.get("reason") or "User requested stop/cancellation" | 
					
						
							| 
									
										
										
										
											2024-11-15 14:51:57 -08:00
										 |  |  |                     await ws_manager.stop_run(run_id, reason=reason) | 
					
						
							| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  |                     break | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 elif message.get("type") == "ping": | 
					
						
							| 
									
										
										
										
											2024-11-26 15:39:36 -08:00
										 |  |  |                     await websocket.send_json({"type": "pong", "timestamp": datetime.utcnow().isoformat()}) | 
					
						
							| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-11-15 14:51:57 -08:00
										 |  |  |                 elif message.get("type") == "input_response": | 
					
						
							|  |  |  |                     # Handle input response from client | 
					
						
							|  |  |  |                     response = message.get("response") | 
					
						
							|  |  |  |                     if response is not None: | 
					
						
							|  |  |  |                         await ws_manager.handle_input_response(run_id, response) | 
					
						
							|  |  |  |                     else: | 
					
						
							| 
									
										
										
										
											2024-11-26 15:39:36 -08:00
										 |  |  |                         logger.warning(f"Invalid input response format for run {run_id}") | 
					
						
							| 
									
										
										
										
											2024-11-15 14:51:57 -08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  |             except json.JSONDecodeError: | 
					
						
							|  |  |  |                 logger.warning(f"Invalid JSON received: {raw_message}") | 
					
						
							| 
									
										
										
										
											2024-11-26 15:39:36 -08:00
										 |  |  |                 await websocket.send_json( | 
					
						
							|  |  |  |                     {"type": "error", "error": "Invalid message format", "timestamp": datetime.utcnow().isoformat()} | 
					
						
							|  |  |  |                 ) | 
					
						
							| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     except WebSocketDisconnect: | 
					
						
							|  |  |  |         logger.info(f"WebSocket disconnected for run {run_id}") | 
					
						
							|  |  |  |     except Exception as e: | 
					
						
							|  |  |  |         logger.error(f"WebSocket error: {str(e)}") | 
					
						
							|  |  |  |     finally: | 
					
						
							|  |  |  |         await ws_manager.disconnect(run_id) |