Victor Dibia fe1feb3906
Enable Auth in AGS (#5928)
<!-- Thank you for your contribution! Please review
https://microsoft.github.io/autogen/docs/Contribute before opening a
pull request. -->

<!-- Please add a reviewer to the assignee section when you create a PR.
If you don't have the access to it, we will shortly find a reviewer and
assign them to your PR. -->

## Why are these changes needed?


https://github.com/user-attachments/assets/b649053b-c377-40c7-aa51-ee64af766fc2

<img width="100%" alt="image"
src="https://github.com/user-attachments/assets/03ba1df5-c9a2-4734-b6a2-0eb97ec0b0e0"
/>


## Authentication

This PR implements an experimental authentication feature to enable
personalized experiences (multiple users). Currently, only GitHub
authentication is supported. You can extend the base authentication
class to add support for other authentication methods.

By default, authentication is disabled. It is enabled only when you pass in
the `--auth-config` argument when running the application (or set the
`AUTOGENSTUDIO_AUTH_CONFIG` environment variable, as shown below).

### Enable GitHub Authentication

To enable GitHub authentication, create an `auth.yaml` file in your app
directory:

```yaml
type: github
jwt_secret: "your-secret-key"
token_expiry_minutes: 60
github:
  client_id: "your-github-client-id"
  client_secret: "your-github-client-secret"
  callback_url: "http://localhost:8081/api/auth/callback"
  scopes: ["user:email"]
```

Please see the documentation on [GitHub
OAuth](https://docs.github.com/en/apps/oauth-apps/building-oauth-apps/authenticating-to-the-rest-api-with-an-oauth-app)
for more details on obtaining the `client_id` and `client_secret`.

To pass in this configuration you can use the `--auth-config` argument
when running the application:

```bash
autogenstudio ui --auth-config /path/to/auth.yaml
```

Or set the environment variable:

```bash
export AUTOGENSTUDIO_AUTH_CONFIG="/path/to/auth.yaml"
```

```{note}
- Authentication is currently experimental and may change in future releases
- User data is stored in your configured database
- When enabled, all API endpoints require authentication except for the authentication endpoints
- WebSocket connections require the token to be passed as a query parameter (`?token=your-jwt-token`)

```

## Related issue number

<!-- For example: "Closes #1234" -->
Closes #4350  

## Checks

- [ ] I've included any doc changes needed for
<https://microsoft.github.io/autogen/>. See
<https://github.com/microsoft/autogen/blob/main/CONTRIBUTING.md> to
build and test documentation locally.
- [ ] I've added tests (if relevant) corresponding to the changes
introduced in this PR.
- [ ] I've made sure all auto checks have passed.

---------

Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
2025-03-14 15:02:05 -07:00

145 lines
5.9 KiB
Python

# api/ws.py
import asyncio
import json
from datetime import datetime
from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect
from fastapi.websockets import WebSocketState
from loguru import logger
from ...datamodel import Run, RunStatus
from ..auth.dependencies import get_ws_auth_manager
from ..auth.wsauth import WebSocketAuthHandler
from ..deps import get_db, get_websocket_manager
from ..managers import WebSocketManager
router = APIRouter()

# Strong references to in-flight stream tasks. The event loop only keeps weak
# references to tasks, so a task created with asyncio.create_task() and not
# stored anywhere may be garbage-collected before it finishes. Each task removes
# itself from this set when done.
_background_tasks: set[asyncio.Task] = set()


@router.websocket("/runs/{run_id}")
async def run_websocket(
    websocket: WebSocket,
    run_id: int,
    ws_manager: WebSocketManager = Depends(get_websocket_manager),
    db=Depends(get_db),
    auth_manager=Depends(get_ws_auth_manager),
):
    """WebSocket endpoint for run communication.

    Flow:
    1. Load the run; close with code 4004 if it does not exist, or 4003 if its
       status is not CREATED or ACTIVE.
    2. Register the socket via ``ws_manager.connect`` (which accepts it).
    3. If ``auth_manager`` is set, authenticate the socket and require that the
       user owns the run or has the "admin" role.
    4. Process incoming JSON messages by their ``"type"`` field:
       ``start`` (launch the stream in a background task), ``stop`` (stop the
       run and end the loop), ``ping`` (reply ``pong``) and ``input_response``
       (forward the client's response to the manager).

    ``ws_manager.disconnect(run_id)`` is always called on exit.

    Args:
        websocket: The client connection.
        run_id: Id of the run this socket is bound to.
        ws_manager: Manager that owns connections and run streams.
        db: Database manager used to load the run.
        auth_manager: Auth manager, or ``None`` when authentication is disabled.
    """

    def _timestamp() -> str:
        # Naive-UTC ISO-8601 string included in every outgoing message.
        return datetime.utcnow().isoformat()

    async def _send_error(error: str) -> None:
        # Error payload shape shared by every failure path in this endpoint.
        await websocket.send_json({"type": "error", "error": error, "timestamp": _timestamp()})

    async def start_stream_wrapper(run_id, task, team_config):
        """Run the stream and report failures to the client instead of raising."""
        try:
            await ws_manager.start_stream(run_id, task, team_config)
        except Exception as e:
            logger.error(f"Error in start_stream for run {run_id}: {str(e)}")
            # Best-effort notification: the socket may close between the state
            # check and the send, and this runs in a background task where an
            # escaping exception would go unreported.
            if websocket.client_state == WebSocketState.CONNECTED:
                try:
                    await _send_error(f"Stream processing error: {str(e)}")
                except Exception as send_error:
                    logger.warning(f"Could not report stream error to client for run {run_id}: {send_error}")

    try:
        # Verify run exists before connecting
        run_response = db.get(Run, filters={"id": run_id}, return_json=False)
        if not run_response.status or not run_response.data:
            await websocket.close(code=4004, reason="Run not found")
            return

        run = run_response.data[0]
        if run.status not in [RunStatus.CREATED, RunStatus.ACTIVE]:
            await websocket.close(code=4003, reason="Run not in valid state")
            return

        # Connect websocket (this handles acceptance internally)
        connected = await ws_manager.connect(websocket, run_id)
        if not connected:
            return  # No need to close here as connect() failure would have closed it

        # Handle authentication if enabled
        if auth_manager is not None:
            ws_auth = WebSocketAuthHandler(auth_manager)
            success, user = await ws_auth.authenticate(websocket)
            if not success:
                logger.warning(f"Authentication failed for WebSocket connection to run {run_id}")
                await _send_error("Authentication failed")
                # NOTE(review): the socket is not closed explicitly here (e.g. with
                # code 4001); cleanup is left to ws_manager.disconnect() in `finally`.
                return

            if user and run.user_id != user.id and "admin" not in (user.roles or []):
                await _send_error("Authentication failed")
                logger.warning(f"User {user.id} not authorized to access run {run_id}")
                # NOTE(review): as above, no explicit close (e.g. code 4003) is sent.
                return

        logger.info(f"WebSocket connection established for run {run_id}")

        while True:
            try:
                raw_message = await websocket.receive_text()
                message = json.loads(raw_message)

                # Valid JSON that is not an object (e.g. a list or a number) has no
                # .get(); treat it as a malformed message rather than letting the
                # AttributeError tear down the whole connection.
                if not isinstance(message, dict):
                    logger.warning(f"Non-object JSON message received for run {run_id}")
                    await _send_error("Invalid message format")
                    continue

                message_type = message.get("type")

                if message_type == "start":
                    # Handle start message
                    logger.info(f"Received start request for run {run_id}")
                    task = message.get("task")
                    team_config = message.get("team_config")
                    if task and team_config:
                        # Start the stream in a separate task so this loop keeps
                        # handling stop/ping/input messages while it runs. Keep a
                        # strong reference until it finishes.
                        stream_task = asyncio.create_task(start_stream_wrapper(run_id, task, team_config))
                        _background_tasks.add(stream_task)
                        stream_task.add_done_callback(_background_tasks.discard)
                    else:
                        logger.warning(f"Invalid start message format for run {run_id}")
                        await _send_error("Invalid start message format")

                elif message_type == "stop":
                    logger.info(f"Received stop request for run {run_id}")
                    reason = message.get("reason") or "User requested stop/cancellation"
                    await ws_manager.stop_run(run_id, reason=reason)
                    break

                elif message_type == "ping":
                    await websocket.send_json({"type": "pong", "timestamp": _timestamp()})

                elif message_type == "input_response":
                    # Handle input response from client
                    response = message.get("response")
                    if response is not None:
                        await ws_manager.handle_input_response(run_id, response)
                    else:
                        logger.warning(f"Invalid input response format for run {run_id}")

                else:
                    # Unknown message types are ignored, as before; log for debugging.
                    logger.debug(f"Ignoring message of unknown type {message_type!r} for run {run_id}")

            except json.JSONDecodeError:
                logger.warning(f"Invalid JSON received: {raw_message}")
                await _send_error("Invalid message format")

    except WebSocketDisconnect:
        logger.info(f"WebSocket disconnected for run {run_id}")
    except Exception as e:
        # logger.exception records the traceback in addition to the message.
        logger.exception(f"WebSocket error: {str(e)}")
    finally:
        await ws_manager.disconnect(run_id)