mirror of
https://github.com/microsoft/autogen.git
synced 2025-07-10 18:41:30 +00:00
120 lines
4.1 KiB
Python
120 lines
4.1 KiB
Python
![]() |
import json
|
||
|
import re
|
||
|
|
||
|
from fastapi import Request, Response, WebSocket
|
||
|
from loguru import logger
|
||
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||
|
from starlette.status import HTTP_401_UNAUTHORIZED
|
||
|
|
||
|
from .exceptions import AuthException
|
||
|
from .manager import AuthManager
|
||
|
|
||
|
|
||
|
class AuthMiddleware(BaseHTTPMiddleware):
    """
    Middleware for handling authentication for all routes.

    CORS preflight requests, public paths and the WebSocket API prefix are
    forwarded without authentication. Every other request is authenticated
    through the ``AuthManager``; on success the resulting user is stored on
    ``request.state.user`` before the request is forwarded.
    """

    # Exact paths that are always forwarded without authentication.
    _PUBLIC_PATHS = frozenset({"/", "/login", "/callback", "/images"})
    # Files with a static-asset extension located directly under the root.
    _ROOT_ASSET_RE = re.compile(r"/[^/]+\.(js|css|png|ico|svg|jpg|webmanifest|json)$")
    # Source maps and SVG files at any depth.
    _NESTED_ASSET_RE = re.compile(r".*\.(js\.map|svg)$")

    def __init__(self, app, auth_manager: AuthManager):
        """
        Store the auth manager used to authenticate incoming requests.

        Args:
            app: The ASGI application wrapped by this middleware.
            auth_manager: Provides ``config`` and ``authenticate_request``.
        """
        super().__init__(app)
        self.auth_manager = auth_manager

    def _is_public_path(self, path: str) -> bool:
        """Return True when ``path`` must be served without authentication."""
        return bool(
            path in self._PUBLIC_PATHS
            or path.startswith("/page-data/")
            or path in self.auth_manager.config.exclude_paths
            or self._ROOT_ASSET_RE.match(path)
            or self._NESTED_ASSET_RE.match(path)
        )

    @staticmethod
    def _unauthorized(detail, headers=None) -> Response:
        """Build the JSON 401 response returned for failed authentication."""
        return Response(
            status_code=HTTP_401_UNAUTHORIZED,
            content=json.dumps({"status": False, "detail": detail}),
            media_type="application/json",
            headers=headers or {},
        )

    async def dispatch(self, request: Request, call_next):
        """
        Process each request, authenticating as needed.

        Args:
            request: The incoming HTTP request.
            call_next: Callable that forwards the request down the stack.

        Returns:
            The downstream response, or a JSON 401 response when
            authentication fails.
        """
        # Skip auth for OPTIONS requests (CORS preflight)
        if request.method == "OPTIONS":
            return await call_next(request)

        path = request.url.path
        if self._is_public_path(path):
            return await call_next(request)

        # Auth disabled: the manager is still asked for a user so that
        # request.state.user is populated for route handlers.
        if self.auth_manager.config.type == "none":
            request.state.user = await self.auth_manager.authenticate_request(request)
            return await call_next(request)

        # WebSocket handling (special case): auth is performed in the
        # WebSocket accept handler, so just pass through here.
        if path.startswith("/api/ws"):
            return await call_next(request)

        # Handle authentication for all other requests. Only the
        # authentication call sits under the broad handler, so an unexpected
        # error raised by a downstream route is not misreported as a 401.
        try:
            user = await self.auth_manager.authenticate_request(request)
        except AuthException as e:
            return self._unauthorized(e.detail, e.headers)
        except Exception as e:
            # Log unexpected errors
            logger.error(f"Unexpected error in auth middleware: {str(e)}")
            return self._unauthorized("Authentication failed")

        # Add user to request state for use in route handlers
        request.state.user = user
        try:
            return await call_next(request)
        except AuthException as e:
            # An auth error surfacing from downstream is still a 401.
            return self._unauthorized(e.detail, e.headers)
|
||
|
|
||
|
|
||
|
class WebSocketAuthMiddleware:
    """
    Helper for authenticating WebSocket connections.

    Not a middleware in the traditional sense - used in WebSocket endpoint.
    """

    # Scheme prefix expected at the start of the Authorization header value.
    _BEARER_PREFIX = "Bearer "

    def __init__(self, auth_manager: AuthManager):
        """Store the auth manager used to validate tokens."""
        self.auth_manager = auth_manager

    async def authenticate(self, websocket: WebSocket) -> bool:
        """
        Authenticate a WebSocket connection.

        The token is read from the ``token`` query parameter when that
        parameter is present, otherwise from an ``authorization`` header that
        starts with ``"Bearer "``.

        Returns True if authenticated, False otherwise. Any exception raised
        while extracting or validating the token is logged and reported as
        False.
        """
        if self.auth_manager.config.type == "none":
            return True

        try:
            # Extract token from query params or the Authorization header
            token = None
            if "token" in websocket.query_params:
                token = websocket.query_params["token"]
            elif "authorization" in websocket.headers:
                auth_header = websocket.headers["authorization"]
                if auth_header.startswith(self._BEARER_PREFIX):
                    # Strip only the leading scheme; str.replace would also
                    # remove any later "Bearer " inside the token itself.
                    token = auth_header[len(self._BEARER_PREFIX):]

            if not token:
                logger.warning("No token found for WebSocket connection")
                return False

            # Validate token
            if not self.auth_manager.is_valid_token(token):
                logger.warning("Invalid token for WebSocket connection")
                return False

            return True

        except Exception as e:
            logger.error(f"WebSocket auth error: {str(e)}")
            return False
|