| 
									
										
										
										
											2024-11-15 14:51:57 -08:00
										 |  |  | import asyncio | 
					
						
							| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  | import logging | 
					
						
							|  |  |  | from datetime import datetime, timezone | 
					
						
							| 
									
										
										
										
											2024-11-26 15:39:36 -08:00
										 |  |  | from typing import Any, Callable, Dict, Optional, Union | 
					
						
							|  |  |  | from uuid import UUID | 
					
						
							| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-11-26 15:39:36 -08:00
										 |  |  | from autogen_agentchat.base._task import TaskResult | 
					
						
							| 
									
										
										
										
											2024-12-14 15:33:14 -08:00
										 |  |  | from autogen_agentchat.messages import ( | 
					
						
							| 
									
										
										
										
											2024-12-18 14:09:19 -08:00
										 |  |  |     AgentEvent, | 
					
						
							| 
									
										
										
										
											2024-12-14 15:33:14 -08:00
										 |  |  |     ChatMessage, | 
					
						
							|  |  |  |     HandoffMessage, | 
					
						
							|  |  |  |     MultiModalMessage, | 
					
						
							|  |  |  |     StopMessage, | 
					
						
							|  |  |  |     TextMessage, | 
					
						
							| 
									
										
										
										
											2024-12-18 14:09:19 -08:00
										 |  |  |     ToolCallExecutionEvent, | 
					
						
							| 
									
										
										
										
											2024-12-18 18:57:11 -08:00
										 |  |  |     ToolCallRequestEvent, | 
					
						
							| 
									
										
										
										
											2024-12-14 15:33:14 -08:00
										 |  |  | ) | 
					
						
							| 
									
										
										
										
											2024-12-03 17:00:44 -08:00
										 |  |  | from autogen_core import CancellationToken | 
					
						
							|  |  |  | from autogen_core import Image as AGImage | 
					
						
							| 
									
										
										
										
											2024-11-26 15:39:36 -08:00
										 |  |  | from fastapi import WebSocket, WebSocketDisconnect | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from ...database import DatabaseManager | 
					
						
							|  |  |  | from ...datamodel import Message, MessageConfig, Run, RunStatus, TeamResult | 
					
						
							|  |  |  | from ...teammanager import TeamManager | 
					
						
							| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  | 
 | 
					
						
							|  |  |  | logger = logging.getLogger(__name__) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class WebSocketManager: | 
					
						
							|  |  |  |     """Manages WebSocket connections and message streaming for team task execution""" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def __init__(self, db_manager: DatabaseManager): | 
					
						
							|  |  |  |         self.db_manager = db_manager | 
					
						
							|  |  |  |         self._connections: Dict[UUID, WebSocket] = {} | 
					
						
							|  |  |  |         self._cancellation_tokens: Dict[UUID, CancellationToken] = {} | 
					
						
							| 
									
										
										
										
											2024-11-12 20:29:06 -08:00
										 |  |  |         # Track explicitly closed connections | 
					
						
							|  |  |  |         self._closed_connections: set[UUID] = set() | 
					
						
							| 
									
										
										
										
											2024-11-15 14:51:57 -08:00
										 |  |  |         self._input_responses: Dict[UUID, asyncio.Queue] = {} | 
					
						
							| 
									
										
										
										
											2024-11-12 20:29:06 -08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-11-26 15:39:36 -08:00
										 |  |  |         self._cancel_message = TeamResult( | 
					
						
							|  |  |  |             task_result=TaskResult( | 
					
						
							|  |  |  |                 messages=[TextMessage(source="user", content="Run cancelled by user")], stop_reason="cancelled by user" | 
					
						
							|  |  |  |             ), | 
					
						
							|  |  |  |             usage="", | 
					
						
							|  |  |  |             duration=0, | 
					
						
							|  |  |  |         ).model_dump() | 
					
						
							| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-11-15 14:51:57 -08:00
										 |  |  |     def _get_stop_message(self, reason: str) -> dict: | 
					
						
							| 
									
										
										
										
											2024-11-26 15:39:36 -08:00
										 |  |  |         return TeamResult( | 
					
						
							|  |  |  |             task_result=TaskResult(messages=[TextMessage(source="user", content=reason)], stop_reason=reason), | 
					
						
							|  |  |  |             usage="", | 
					
						
							|  |  |  |             duration=0, | 
					
						
							|  |  |  |         ).model_dump() | 
					
						
							| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-11-15 14:51:57 -08:00
										 |  |  |     async def connect(self, websocket: WebSocket, run_id: UUID) -> bool: | 
					
						
							| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  |         try: | 
					
						
							|  |  |  |             await websocket.accept() | 
					
						
							|  |  |  |             self._connections[run_id] = websocket | 
					
						
							| 
									
										
										
										
											2024-11-12 20:29:06 -08:00
										 |  |  |             self._closed_connections.discard(run_id) | 
					
						
							| 
									
										
										
										
											2024-11-15 14:51:57 -08:00
										 |  |  |             # Initialize input queue for this connection | 
					
						
							|  |  |  |             self._input_responses[run_id] = asyncio.Queue() | 
					
						
							| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-11-26 15:39:36 -08:00
										 |  |  |             await self._send_message( | 
					
						
							|  |  |  |                 run_id, {"type": "system", "status": "connected", "timestamp": datetime.now(timezone.utc).isoformat()} | 
					
						
							|  |  |  |             ) | 
					
						
							| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  | 
 | 
					
						
							|  |  |  |             return True | 
					
						
							|  |  |  |         except Exception as e: | 
					
						
							|  |  |  |             logger.error(f"Connection error for run {run_id}: {e}") | 
					
						
							|  |  |  |             return False | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-11-26 15:39:36 -08:00
										 |  |  |     async def start_stream(self, run_id: UUID, task: str, team_config: dict) -> None: | 
					
						
							|  |  |  |         """Start streaming task execution with proper run management""" | 
					
						
							| 
									
										
										
										
											2024-11-12 20:29:06 -08:00
										 |  |  |         if run_id not in self._connections or run_id in self._closed_connections: | 
					
						
							| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  |             raise ValueError(f"No active connection for run {run_id}") | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-11-26 15:39:36 -08:00
										 |  |  |         team_manager = TeamManager() | 
					
						
							| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  |         cancellation_token = CancellationToken() | 
					
						
							|  |  |  |         self._cancellation_tokens[run_id] = cancellation_token | 
					
						
							| 
									
										
										
										
											2024-11-26 15:39:36 -08:00
										 |  |  |         final_result = None | 
					
						
							| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         try: | 
					
						
							| 
									
										
										
										
											2024-11-26 15:39:36 -08:00
										 |  |  |             # Update run with task and status | 
					
						
							|  |  |  |             run = await self._get_run(run_id) | 
					
						
							|  |  |  |             if run: | 
					
						
							|  |  |  |                 run.task = MessageConfig(content=task, source="user").model_dump() | 
					
						
							|  |  |  |                 run.status = RunStatus.ACTIVE | 
					
						
							|  |  |  |                 self.db_manager.upsert(run) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-11-15 14:51:57 -08:00
										 |  |  |             input_func = self.create_input_func(run_id) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  |             async for message in team_manager.run_stream( | 
					
						
							| 
									
										
										
										
											2024-11-26 15:39:36 -08:00
										 |  |  |                 task=task, team_config=team_config, input_func=input_func, cancellation_token=cancellation_token | 
					
						
							| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  |             ): | 
					
						
							| 
									
										
										
										
											2024-11-12 20:29:06 -08:00
										 |  |  |                 if cancellation_token.is_cancelled() or run_id in self._closed_connections: | 
					
						
							| 
									
										
										
										
											2024-11-26 15:39:36 -08:00
										 |  |  |                     logger.info(f"Stream cancelled or connection closed for run {run_id}") | 
					
						
							| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  |                     break | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 formatted_message = self._format_message(message) | 
					
						
							|  |  |  |                 if formatted_message: | 
					
						
							|  |  |  |                     await self._send_message(run_id, formatted_message) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-14 15:33:14 -08:00
										 |  |  |                     # Save messages by concrete type | 
					
						
							|  |  |  |                     if isinstance( | 
					
						
							|  |  |  |                         message, | 
					
						
							|  |  |  |                         ( | 
					
						
							|  |  |  |                             TextMessage, | 
					
						
							|  |  |  |                             MultiModalMessage, | 
					
						
							|  |  |  |                             StopMessage, | 
					
						
							|  |  |  |                             HandoffMessage, | 
					
						
							| 
									
										
										
										
											2024-12-18 14:09:19 -08:00
										 |  |  |                             ToolCallRequestEvent, | 
					
						
							|  |  |  |                             ToolCallExecutionEvent, | 
					
						
							| 
									
										
										
										
											2024-12-14 15:33:14 -08:00
										 |  |  |                         ), | 
					
						
							|  |  |  |                     ): | 
					
						
							| 
									
										
										
										
											2024-11-26 15:39:36 -08:00
										 |  |  |                         await self._save_message(run_id, message) | 
					
						
							|  |  |  |                     # Capture final result if it's a TeamResult | 
					
						
							|  |  |  |                     elif isinstance(message, TeamResult): | 
					
						
							|  |  |  |                         final_result = message.model_dump() | 
					
						
							| 
									
										
										
										
											2024-11-12 20:29:06 -08:00
										 |  |  |             if not cancellation_token.is_cancelled() and run_id not in self._closed_connections: | 
					
						
							| 
									
										
										
										
											2024-11-26 15:39:36 -08:00
										 |  |  |                 if final_result: | 
					
						
							|  |  |  |                     await self._update_run(run_id, RunStatus.COMPLETE, team_result=final_result) | 
					
						
							|  |  |  |                 else: | 
					
						
							|  |  |  |                     logger.warning(f"No final result captured for completed run {run_id}") | 
					
						
							|  |  |  |                     await self._update_run_status(run_id, RunStatus.COMPLETE) | 
					
						
							| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  |             else: | 
					
						
							| 
									
										
										
										
											2024-11-26 15:39:36 -08:00
										 |  |  |                 await self._send_message( | 
					
						
							|  |  |  |                     run_id, | 
					
						
							|  |  |  |                     { | 
					
						
							|  |  |  |                         "type": "completion", | 
					
						
							|  |  |  |                         "status": "cancelled", | 
					
						
							|  |  |  |                         "data": self._cancel_message, | 
					
						
							|  |  |  |                         "timestamp": datetime.now(timezone.utc).isoformat(), | 
					
						
							|  |  |  |                     }, | 
					
						
							|  |  |  |                 ) | 
					
						
							|  |  |  |                 # Update run with cancellation result | 
					
						
							|  |  |  |                 await self._update_run(run_id, RunStatus.STOPPED, team_result=self._cancel_message) | 
					
						
							| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         except Exception as e: | 
					
						
							|  |  |  |             logger.error(f"Stream error for run {run_id}: {e}") | 
					
						
							|  |  |  |             await self._handle_stream_error(run_id, e) | 
					
						
							|  |  |  |         finally: | 
					
						
							|  |  |  |             self._cancellation_tokens.pop(run_id, None) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-12-18 14:09:19 -08:00
										 |  |  |     async def _save_message(self, run_id: UUID, message: Union[AgentEvent | ChatMessage, ChatMessage]) -> None: | 
					
						
							| 
									
										
										
										
											2024-11-26 15:39:36 -08:00
										 |  |  |         """Save a message to the database""" | 
					
						
							|  |  |  |         run = await self._get_run(run_id) | 
					
						
							|  |  |  |         if run: | 
					
						
							|  |  |  |             db_message = Message( | 
					
						
							|  |  |  |                 session_id=run.session_id, | 
					
						
							|  |  |  |                 run_id=run_id, | 
					
						
							|  |  |  |                 config=message.model_dump(), | 
					
						
							|  |  |  |                 user_id=None,  # You might want to pass this from somewhere | 
					
						
							|  |  |  |             ) | 
					
						
							|  |  |  |             self.db_manager.upsert(db_message) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     async def _update_run( | 
					
						
							|  |  |  |         self, run_id: UUID, status: RunStatus, team_result: Optional[dict] = None, error: Optional[str] = None | 
					
						
							|  |  |  |     ) -> None: | 
					
						
							|  |  |  |         """Update run status and result""" | 
					
						
							|  |  |  |         run = await self._get_run(run_id) | 
					
						
							|  |  |  |         if run: | 
					
						
							|  |  |  |             run.status = status | 
					
						
							|  |  |  |             if team_result: | 
					
						
							|  |  |  |                 run.team_result = team_result | 
					
						
							|  |  |  |             if error: | 
					
						
							|  |  |  |                 run.error_message = error | 
					
						
							|  |  |  |             self.db_manager.upsert(run) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-11-15 14:51:57 -08:00
										 |  |  |     def create_input_func(self, run_id: UUID) -> Callable: | 
					
						
							|  |  |  |         """Creates an input function for a specific run""" | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-11-26 15:39:36 -08:00
										 |  |  |         async def input_handler(prompt: str = "", cancellation_token: Optional[CancellationToken] = None) -> str: | 
					
						
							|  |  |  |             try: | 
					
						
							| 
									
										
										
										
											2024-11-15 14:51:57 -08:00
										 |  |  |                 # Send input request to client | 
					
						
							| 
									
										
										
										
											2024-11-26 15:39:36 -08:00
										 |  |  |                 await self._send_message( | 
					
						
							|  |  |  |                     run_id, | 
					
						
							|  |  |  |                     { | 
					
						
							|  |  |  |                         "type": "input_request", | 
					
						
							|  |  |  |                         "prompt": prompt, | 
					
						
							|  |  |  |                         "data": {"source": "system", "content": prompt}, | 
					
						
							|  |  |  |                         "timestamp": datetime.now(timezone.utc).isoformat(), | 
					
						
							| 
									
										
										
										
											2024-11-15 14:51:57 -08:00
										 |  |  |                     }, | 
					
						
							| 
									
										
										
										
											2024-11-26 15:39:36 -08:00
										 |  |  |                 ) | 
					
						
							| 
									
										
										
										
											2024-11-15 14:51:57 -08:00
										 |  |  | 
 | 
					
						
							|  |  |  |                 # Wait for response | 
					
						
							|  |  |  |                 if run_id in self._input_responses: | 
					
						
							|  |  |  |                     response = await self._input_responses[run_id].get() | 
					
						
							|  |  |  |                     return response | 
					
						
							|  |  |  |                 else: | 
					
						
							|  |  |  |                     raise ValueError(f"No input queue for run {run_id}") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             except Exception as e: | 
					
						
							|  |  |  |                 logger.error(f"Error handling input for run {run_id}: {e}") | 
					
						
							|  |  |  |                 raise | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return input_handler | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     async def handle_input_response(self, run_id: UUID, response: str) -> None: | 
					
						
							|  |  |  |         """Handle input response from client""" | 
					
						
							|  |  |  |         if run_id in self._input_responses: | 
					
						
							|  |  |  |             await self._input_responses[run_id].put(response) | 
					
						
							|  |  |  |         else: | 
					
						
							| 
									
										
										
										
											2024-11-26 15:39:36 -08:00
										 |  |  |             logger.warning(f"Received input response for inactive run {run_id}") | 
					
						
							| 
									
										
										
										
											2024-11-15 14:51:57 -08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     async def stop_run(self, run_id: UUID, reason: str) -> None: | 
					
						
							| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  |         if run_id in self._cancellation_tokens: | 
					
						
							|  |  |  |             logger.info(f"Stopping run {run_id}") | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-11-26 15:39:36 -08:00
										 |  |  |             stop_message = self._get_stop_message(reason) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             try: | 
					
						
							|  |  |  |                 # Update run record first | 
					
						
							|  |  |  |                 await self._update_run(run_id, status=RunStatus.STOPPED, team_result=stop_message) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 # Then handle websocket communication if connection is active | 
					
						
							|  |  |  |                 if run_id in self._connections and run_id not in self._closed_connections: | 
					
						
							|  |  |  |                     await self._send_message( | 
					
						
							|  |  |  |                         run_id, | 
					
						
							|  |  |  |                         { | 
					
						
							|  |  |  |                             "type": "completion", | 
					
						
							|  |  |  |                             "status": "cancelled", | 
					
						
							|  |  |  |                             "data": stop_message, | 
					
						
							|  |  |  |                             "timestamp": datetime.now(timezone.utc).isoformat(), | 
					
						
							|  |  |  |                         }, | 
					
						
							|  |  |  |                     ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 # Finally cancel the token | 
					
						
							|  |  |  |                 self._cancellation_tokens[run_id].cancel() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             except Exception as e: | 
					
						
							|  |  |  |                 logger.error(f"Error stopping run {run_id}: {e}") | 
					
						
							|  |  |  |                 # We might want to force disconnect here if db update failed | 
					
						
							|  |  |  |                 # await self.disconnect(run_id)  # Optional | 
					
						
							| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     async def disconnect(self, run_id: UUID) -> None: | 
					
						
							|  |  |  |         """Clean up connection and associated resources""" | 
					
						
							|  |  |  |         logger.info(f"Disconnecting run {run_id}") | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-11-12 20:29:06 -08:00
										 |  |  |         # Mark as closed before cleanup to prevent any new messages | 
					
						
							|  |  |  |         self._closed_connections.add(run_id) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         # Cancel any running tasks | 
					
						
							| 
									
										
										
										
											2024-11-15 14:51:57 -08:00
										 |  |  |         await self.stop_run(run_id, "Connection closed") | 
					
						
							| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-11-12 20:29:06 -08:00
										 |  |  |         # Clean up resources | 
					
						
							|  |  |  |         self._connections.pop(run_id, None) | 
					
						
							|  |  |  |         self._cancellation_tokens.pop(run_id, None) | 
					
						
							| 
									
										
										
										
											2024-11-15 14:51:57 -08:00
										 |  |  |         self._input_responses.pop(run_id, None) | 
					
						
							| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     async def _send_message(self, run_id: UUID, message: dict) -> None: | 
					
						
							| 
									
										
										
										
											2024-11-12 20:29:06 -08:00
										 |  |  |         """Send a message through the WebSocket with connection state checking
 | 
					
						
							| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         Args: | 
					
						
							|  |  |  |             run_id: UUID of the run | 
					
						
							|  |  |  |             message: Message dictionary to send | 
					
						
							|  |  |  |         """
 | 
					
						
							| 
									
										
										
										
											2024-11-12 20:29:06 -08:00
										 |  |  |         if run_id in self._closed_connections: | 
					
						
							| 
									
										
										
										
											2024-11-26 15:39:36 -08:00
										 |  |  |             logger.warning(f"Attempted to send message to closed connection for run {run_id}") | 
					
						
							| 
									
										
										
										
											2024-11-12 20:29:06 -08:00
										 |  |  |             return | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  |         try: | 
					
						
							|  |  |  |             if run_id in self._connections: | 
					
						
							| 
									
										
										
										
											2024-11-12 20:29:06 -08:00
										 |  |  |                 websocket = self._connections[run_id] | 
					
						
							|  |  |  |                 await websocket.send_json(message) | 
					
						
							| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  |         except WebSocketDisconnect: | 
					
						
							| 
									
										
										
										
											2024-11-26 15:39:36 -08:00
										 |  |  |             logger.warning(f"WebSocket disconnected while sending message for run {run_id}") | 
					
						
							| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  |             await self.disconnect(run_id) | 
					
						
							|  |  |  |         except Exception as e: | 
					
						
							| 
									
										
										
										
											2024-11-26 15:39:36 -08:00
										 |  |  |             logger.error(f"Error sending message for run {run_id}: {e}, {message}") | 
					
						
							| 
									
										
										
										
											2024-11-12 20:29:06 -08:00
										 |  |  |             # Don't try to send error message here to avoid potential recursive loop | 
					
						
							|  |  |  |             await self._update_run_status(run_id, RunStatus.ERROR, str(e)) | 
					
						
							|  |  |  |             await self.disconnect(run_id) | 
					
						
							| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     async def _handle_stream_error(self, run_id: UUID, error: Exception) -> None: | 
					
						
							| 
									
										
										
										
											2024-11-26 15:39:36 -08:00
										 |  |  |         """Handle stream errors with proper run updates""" | 
					
						
							| 
									
										
										
										
											2024-11-12 20:29:06 -08:00
										 |  |  |         if run_id not in self._closed_connections: | 
					
						
							| 
									
										
										
										
											2024-11-26 15:39:36 -08:00
										 |  |  |             error_result = TeamResult( | 
					
						
							|  |  |  |                 task_result=TaskResult( | 
					
						
							| 
									
										
										
										
											2024-12-16 13:17:42 -08:00
										 |  |  |                     messages=[TextMessage(source="system", content=str(error))], | 
					
						
							|  |  |  |                     stop_reason="An error occurred while processing this run", | 
					
						
							| 
									
										
										
										
											2024-11-26 15:39:36 -08:00
										 |  |  |                 ), | 
					
						
							|  |  |  |                 usage="", | 
					
						
							|  |  |  |                 duration=0, | 
					
						
							|  |  |  |             ).model_dump() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             await self._send_message( | 
					
						
							|  |  |  |                 run_id, | 
					
						
							|  |  |  |                 { | 
					
						
							| 
									
										
										
										
											2024-11-12 20:29:06 -08:00
										 |  |  |                     "type": "completion", | 
					
						
							|  |  |  |                     "status": "error", | 
					
						
							| 
									
										
										
										
											2024-11-26 15:39:36 -08:00
										 |  |  |                     "data": error_result, | 
					
						
							|  |  |  |                     "timestamp": datetime.now(timezone.utc).isoformat(), | 
					
						
							|  |  |  |                 }, | 
					
						
							|  |  |  |             ) | 
					
						
							| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-11-26 15:39:36 -08:00
										 |  |  |             await self._update_run(run_id, RunStatus.ERROR, team_result=error_result, error=str(error)) | 
					
						
							| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def _format_message(self, message: Any) -> Optional[dict]: | 
					
						
							|  |  |  |         """Format message for WebSocket transmission
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         Args: | 
					
						
							|  |  |  |             message: Message to format | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         Returns: | 
					
						
							|  |  |  |             Optional[dict]: Formatted message or None if formatting fails | 
					
						
							|  |  |  |         """
 | 
					
						
							| 
									
										
										
										
											2024-12-08 21:44:16 -08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  |         try: | 
					
						
							| 
									
										
										
										
											2024-11-26 21:24:45 -08:00
										 |  |  |             if isinstance(message, MultiModalMessage): | 
					
						
							|  |  |  |                 message_dump = message.model_dump() | 
					
						
							|  |  |  |                 message_dump["content"] = [ | 
					
						
							|  |  |  |                     message_dump["content"][0], | 
					
						
							|  |  |  |                     { | 
					
						
							|  |  |  |                         "url": f"data:image/png;base64,{message_dump['content'][1]['data']}", | 
					
						
							|  |  |  |                         "alt": "WebSurfer Screenshot", | 
					
						
							|  |  |  |                     }, | 
					
						
							|  |  |  |                 ] | 
					
						
							|  |  |  |                 return {"type": "message", "data": message_dump} | 
					
						
							| 
									
										
										
										
											2024-12-08 21:44:16 -08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  |             elif isinstance(message, TeamResult): | 
					
						
							|  |  |  |                 return { | 
					
						
							|  |  |  |                     "type": "result", | 
					
						
							|  |  |  |                     "data": message.model_dump(), | 
					
						
							|  |  |  |                     "status": "complete", | 
					
						
							|  |  |  |                 } | 
					
						
							| 
									
										
										
										
											2024-12-14 15:33:14 -08:00
										 |  |  | 
 | 
					
						
							|  |  |  |             elif isinstance( | 
					
						
							| 
									
										
										
										
											2024-12-18 14:09:19 -08:00
										 |  |  |                 message, (TextMessage, StopMessage, HandoffMessage, ToolCallRequestEvent, ToolCallExecutionEvent) | 
					
						
							| 
									
										
										
										
											2024-12-14 15:33:14 -08:00
										 |  |  |             ): | 
					
						
							| 
									
										
										
										
											2024-12-08 21:44:16 -08:00
										 |  |  |                 return {"type": "message", "data": message.model_dump()} | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  |             return None | 
					
						
							| 
									
										
										
										
											2024-12-14 15:33:14 -08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  |         except Exception as e: | 
					
						
							|  |  |  |             logger.error(f"Message formatting error: {e}") | 
					
						
							|  |  |  |             return None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     async def _get_run(self, run_id: UUID) -> Optional[Run]: | 
					
						
							|  |  |  |         """Get run from database
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         Args: | 
					
						
							|  |  |  |             run_id: UUID of the run to retrieve | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         Returns: | 
					
						
							|  |  |  |             Optional[Run]: Run object if found, None otherwise | 
					
						
							|  |  |  |         """
 | 
					
						
							| 
									
										
										
										
											2024-11-26 15:39:36 -08:00
										 |  |  |         response = self.db_manager.get(Run, filters={"id": run_id}, return_json=False) | 
					
						
							| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  |         return response.data[0] if response.status and response.data else None | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-11-26 15:39:36 -08:00
										 |  |  |     async def _update_run_status(self, run_id: UUID, status: RunStatus, error: Optional[str] = None) -> None: | 
					
						
							| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  |         """Update run status in database
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         Args: | 
					
						
							|  |  |  |             run_id: UUID of the run to update | 
					
						
							|  |  |  |             status: New status to set | 
					
						
							|  |  |  |             error: Optional error message | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         run = await self._get_run(run_id) | 
					
						
							|  |  |  |         if run: | 
					
						
							|  |  |  |             run.status = status | 
					
						
							|  |  |  |             run.error_message = error | 
					
						
							|  |  |  |             self.db_manager.upsert(run) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-11-15 14:51:57 -08:00
										 |  |  |     async def cleanup(self) -> None: | 
					
						
							|  |  |  |         """Clean up all active connections and resources when server is shutting down""" | 
					
						
							| 
									
										
										
										
											2024-11-26 15:39:36 -08:00
										 |  |  |         logger.info(f"Cleaning up {len(self.active_connections)} active connections") | 
					
						
							| 
									
										
										
										
											2024-11-15 14:51:57 -08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         try: | 
					
						
							|  |  |  |             # First cancel all running tasks | 
					
						
							|  |  |  |             for run_id in self.active_runs.copy(): | 
					
						
							|  |  |  |                 if run_id in self._cancellation_tokens: | 
					
						
							|  |  |  |                     self._cancellation_tokens[run_id].cancel() | 
					
						
							| 
									
										
										
										
											2024-11-26 15:39:36 -08:00
										 |  |  |                 run = await self._get_run(run_id) | 
					
						
							|  |  |  |                 if run and run.status == RunStatus.ACTIVE: | 
					
						
							|  |  |  |                     interrupted_result = TeamResult( | 
					
						
							|  |  |  |                         task_result=TaskResult( | 
					
						
							|  |  |  |                             messages=[TextMessage(source="system", content="Run interrupted by server shutdown")], | 
					
						
							|  |  |  |                             stop_reason="server_shutdown", | 
					
						
							|  |  |  |                         ), | 
					
						
							|  |  |  |                         usage="", | 
					
						
							|  |  |  |                         duration=0, | 
					
						
							|  |  |  |                     ).model_dump() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                     run.status = RunStatus.STOPPED | 
					
						
							|  |  |  |                     run.team_result = interrupted_result | 
					
						
							|  |  |  |                     self.db_manager.upsert(run) | 
					
						
							| 
									
										
										
										
											2024-11-15 14:51:57 -08:00
										 |  |  | 
 | 
					
						
							|  |  |  |             # Then disconnect all websockets with timeout | 
					
						
							|  |  |  |             # 10 second timeout for entire cleanup | 
					
						
							|  |  |  |             async with asyncio.timeout(10): | 
					
						
							|  |  |  |                 for run_id in self.active_connections.copy(): | 
					
						
							|  |  |  |                     try: | 
					
						
							|  |  |  |                         # Give each disconnect operation 2 seconds | 
					
						
							|  |  |  |                         async with asyncio.timeout(2): | 
					
						
							|  |  |  |                             await self.disconnect(run_id) | 
					
						
							|  |  |  |                     except asyncio.TimeoutError: | 
					
						
							|  |  |  |                         logger.warning(f"Timeout disconnecting run {run_id}") | 
					
						
							|  |  |  |                     except Exception as e: | 
					
						
							|  |  |  |                         logger.error(f"Error disconnecting run {run_id}: {e}") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         except asyncio.TimeoutError: | 
					
						
							|  |  |  |             logger.warning("WebSocketManager cleanup timed out") | 
					
						
							|  |  |  |         except Exception as e: | 
					
						
							|  |  |  |             logger.error(f"Error during WebSocketManager cleanup: {e}") | 
					
						
							|  |  |  |         finally: | 
					
						
							|  |  |  |             # Always clear internal state, even if cleanup had errors | 
					
						
							|  |  |  |             self._connections.clear() | 
					
						
							|  |  |  |             self._cancellation_tokens.clear() | 
					
						
							|  |  |  |             self._closed_connections.clear() | 
					
						
							|  |  |  |             self._input_responses.clear() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  |     @property | 
					
						
							|  |  |  |     def active_connections(self) -> set[UUID]: | 
					
						
							|  |  |  |         """Get set of active run IDs""" | 
					
						
							| 
									
										
										
										
											2024-11-12 20:29:06 -08:00
										 |  |  |         return set(self._connections.keys()) - self._closed_connections | 
					
						
							| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     @property | 
					
						
							|  |  |  |     def active_runs(self) -> set[UUID]: | 
					
						
							|  |  |  |         """Get set of runs with active cancellation tokens""" | 
					
						
							|  |  |  |         return set(self._cancellation_tokens.keys()) |