mirror of
https://github.com/microsoft/autogen.git
synced 2025-07-09 10:01:37 +00:00
88 lines
3.3 KiB
Python
88 lines
3.3 KiB
Python
![]() |
import jwt
|
||
|
from fastapi import WebSocket, WebSocketDisconnect, status
|
||
|
from loguru import logger
|
||
|
|
||
|
from .manager import AuthManager
|
||
|
from .models import User
|
||
|
|
||
|
|
||
|
class WebSocketAuthHandler:
    """
    Helper class for authenticating WebSocket connections.

    A token is read from the ``token`` query parameter or, when that
    parameter is absent, from an ``Authorization: Bearer <token>`` header.
    It is then decoded as an HS256 JWT using the secret found on the auth
    manager's config.
    """

    def __init__(self, auth_manager: AuthManager):
        """Store the auth manager whose ``config`` drives authentication."""
        self.auth_manager = auth_manager

    @staticmethod
    def _guest_user() -> User:
        """Build the placeholder user returned when no real auth is enforced."""
        return User(id="guestuser@gmail.com", name="Default User", provider="none")

    @staticmethod
    def _extract_token(websocket: WebSocket) -> str | None:
        """Return the raw token from the query string or bearer header, if any."""
        # The query parameter wins whenever it is present, even if it is empty;
        # the header is only consulted when there is no ``token`` parameter.
        if "token" in websocket.query_params:
            return websocket.query_params["token"]

        if "authorization" in websocket.headers:
            auth_header = websocket.headers["authorization"]
            if auth_header.startswith("Bearer "):
                # Strip only the leading scheme; str.replace() would also
                # remove any later "Bearer " substring inside the value.
                return auth_header.removeprefix("Bearer ")

        return None

    async def authenticate(self, websocket: WebSocket) -> tuple[bool, User | None]:
        """
        Authenticate a WebSocket connection.

        Returns a ``(success, user)`` tuple. ``user`` is ``None`` whenever
        ``success`` is ``False``. Failures inside the token-handling path are
        logged and reported as ``(False, None)`` instead of being raised.
        """
        if self.auth_manager.config.type == "none":
            # No authentication required
            return True, self._guest_user()

        try:
            token = self._extract_token(websocket)
            if not token:
                logger.warning("No token found for WebSocket connection")
                return False, None

            jwt_secret = self.auth_manager.config.jwt_secret
            if not jwt_secret:
                # Development mode with no JWT secret.
                # NOTE(review): any non-empty token is accepted as the guest
                # user here -- confirm this is never reachable in production.
                return True, self._guest_user()

            # Decode and validate JWT
            payload = jwt.decode(token, jwt_secret, algorithms=["HS256"])

            # Create User object from token payload
            user = User(
                id=payload.get("sub"),
                name=payload.get("name", "Unknown User"),
                email=payload.get("email"),
                provider=payload.get("provider", "jwt"),
                roles=payload.get("roles", ["user"]),
            )
            return True, user

        except jwt.ExpiredSignatureError:
            # Must precede InvalidTokenError so expiry gets its own message.
            logger.warning("Expired token for WebSocket connection")
            return False, None
        except jwt.InvalidTokenError:
            logger.warning("Invalid token for WebSocket connection")
            return False, None
        except Exception as e:
            # Boundary handler: any other failure is treated as an auth failure.
            logger.error(f"WebSocket auth error: {e}")
            return False, None

    async def on_connect(self, websocket: WebSocket) -> User | None:
        """
        Handle WebSocket connection with authentication.

        Returns the authenticated user if successful. Otherwise closes the
        connection with a policy-violation code and raises
        ``WebSocketDisconnect`` carrying the same code.
        """
        success, user = await self.authenticate(websocket)

        if not success:
            # Authentication failed, close the connection
            await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason="Authentication failed")
            raise WebSocketDisconnect(code=status.WS_1008_POLICY_VIOLATION)

        # Authentication successful, return the user
        return user
|