import asyncio import os import queue import threading import traceback from contextlib import asynccontextmanager from typing import Any, Union from fastapi import FastAPI, WebSocket, WebSocketDisconnect from fastapi.middleware.cors import CORSMiddleware from fastapi.staticfiles import StaticFiles from loguru import logger from openai import OpenAIError from ..chatmanager import AutoGenChatManager from ..database import workflow_from_id from ..database.dbmanager import DBManager from ..datamodel import Agent, Message, Model, Response, Session, Skill, Workflow from ..profiler import Profiler from ..utils import check_and_cast_datetime_fields, init_app_folders, md5_hash, test_model from ..version import VERSION from ..websocket_connection_manager import WebSocketConnectionManager profiler = Profiler() managers = {"chat": None} # manage calls to autogen # Create thread-safe queue for messages between api thread and autogen threads message_queue = queue.Queue() active_connections = [] active_connections_lock = asyncio.Lock() websocket_manager = WebSocketConnectionManager( active_connections=active_connections, active_connections_lock=active_connections_lock, ) def message_handler(): while True: message = message_queue.get() logger.info( "** Processing Agent Message on Queue: Active Connections: " + str([client_id for _, client_id in websocket_manager.active_connections]) + " **" ) for connection, socket_client_id in websocket_manager.active_connections: if message["connection_id"] == socket_client_id: logger.info( f"Sending message to connection_id: {message['connection_id']}. Connection ID: {socket_client_id}" ) asyncio.run(websocket_manager.send_message(message, connection)) else: logger.info( f"Skipping message for connection_id: {message['connection_id']}. Connection ID: {socket_client_id}" ) message_queue.task_done() message_handler_thread = threading.Thread(target=message_handler, daemon=True) message_handler_thread.start() app_file_path = os.path.dirname(os.path.abspath(__file__)) folders = init_app_folders(app_file_path) ui_folder_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "ui") database_engine_uri = folders["database_engine_uri"] dbmanager = DBManager(engine_uri=database_engine_uri) HUMAN_INPUT_TIMEOUT_SECONDS = 180 @asynccontextmanager async def lifespan(app: FastAPI): print("***** App started *****") managers["chat"] = AutoGenChatManager( message_queue=message_queue, websocket_manager=websocket_manager, human_input_timeout=HUMAN_INPUT_TIMEOUT_SECONDS, ) dbmanager.create_db_and_tables() yield # Close all active connections await websocket_manager.disconnect_all() print("***** App stopped *****") app = FastAPI(lifespan=lifespan) # allow cross origin requests for testing on localhost:800* ports only app.add_middleware( CORSMiddleware, allow_origins=[ "http://localhost:8000", "http://127.0.0.1:8000", "http://localhost:8001", "http://localhost:8081", ], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) show_docs = os.environ.get("AUTOGENSTUDIO_API_DOCS", "False").lower() == "true" docs_url = "/docs" if show_docs else None api = FastAPI( root_path="/api", title="AutoGen Studio API", version=VERSION, docs_url=docs_url, description="AutoGen Studio is a low-code tool for building and testing multi-agent workflows using AutoGen.", ) # mount an api route such that the main route serves the ui and the /api app.mount("/api", api) app.mount("/", StaticFiles(directory=ui_folder_path, html=True), name="ui") api.mount( "/files", StaticFiles(directory=folders["files_static_root"], html=True), name="files", ) # manage websocket connections def create_entity(model: Any, model_class: Any, filters: dict = None): """Create a new entity""" model = check_and_cast_datetime_fields(model) try: response: Response = dbmanager.upsert(model) return response.model_dump(mode="json") except Exception as ex_error: print(ex_error) return { "status": False, "message": f"Error occurred while creating {model_class.__name__}: " + str(ex_error), } def list_entity( model_class: Any, filters: dict = None, return_json: bool = True, order: str = "desc", ): """List all entities for a user""" return dbmanager.get(model_class, filters=filters, return_json=return_json, order=order) def delete_entity(model_class: Any, filters: dict = None): """Delete an entity""" return dbmanager.delete(filters=filters, model_class=model_class) @api.get("/skills") async def list_skills(user_id: str): """List all skills for a user""" filters = {"user_id": user_id} return list_entity(Skill, filters=filters) @api.post("/skills") async def create_skill(skill: Skill): """Create a new skill""" filters = {"user_id": skill.user_id} return create_entity(skill, Skill, filters=filters) @api.delete("/skills/delete") async def delete_skill(skill_id: int, user_id: str): """Delete a skill""" filters = {"id": skill_id, "user_id": user_id} return delete_entity(Skill, filters=filters) @api.get("/models") async def list_models(user_id: str): """List all models for a user""" filters = {"user_id": user_id} return list_entity(Model, filters=filters) @api.post("/models") async def create_model(model: Model): """Create a new model""" return create_entity(model, Model) @api.post("/models/test") async def test_model_endpoint(model: Model): """Test a model""" try: response = test_model(model) return { "status": True, "message": "Model tested successfully", "data": response, } except (OpenAIError, Exception) as ex_error: return { "status": False, "message": "Error occurred while testing model: " + str(ex_error), } @api.delete("/models/delete") async def delete_model(model_id: int, user_id: str): """Delete a model""" filters = {"id": model_id, "user_id": user_id} return delete_entity(Model, filters=filters) @api.get("/agents") async def list_agents(user_id: str): """List all agents for a user""" filters = {"user_id": user_id} return list_entity(Agent, filters=filters) @api.post("/agents") async def create_agent(agent: Agent): """Create a new agent""" return create_entity(agent, Agent) @api.delete("/agents/delete") async def delete_agent(agent_id: int, user_id: str): """Delete an agent""" filters = {"id": agent_id, "user_id": user_id} return delete_entity(Agent, filters=filters) @api.post("/agents/link/model/{agent_id}/{model_id}") async def link_agent_model(agent_id: int, model_id: int): """Link a model to an agent""" return dbmanager.link(link_type="agent_model", primary_id=agent_id, secondary_id=model_id) @api.delete("/agents/link/model/{agent_id}/{model_id}") async def unlink_agent_model(agent_id: int, model_id: int): """Unlink a model from an agent""" return dbmanager.unlink(link_type="agent_model", primary_id=agent_id, secondary_id=model_id) @api.get("/agents/link/model/{agent_id}") async def get_agent_models(agent_id: int): """Get all models linked to an agent""" return dbmanager.get_linked_entities("agent_model", agent_id, return_json=True) @api.post("/agents/link/skill/{agent_id}/{skill_id}") async def link_agent_skill(agent_id: int, skill_id: int): """Link an a skill to an agent""" return dbmanager.link(link_type="agent_skill", primary_id=agent_id, secondary_id=skill_id) @api.delete("/agents/link/skill/{agent_id}/{skill_id}") async def unlink_agent_skill(agent_id: int, skill_id: int): """Unlink an a skill from an agent""" return dbmanager.unlink(link_type="agent_skill", primary_id=agent_id, secondary_id=skill_id) @api.get("/agents/link/skill/{agent_id}") async def get_agent_skills(agent_id: int): """Get all skills linked to an agent""" return dbmanager.get_linked_entities("agent_skill", agent_id, return_json=True) @api.post("/agents/link/agent/{primary_agent_id}/{secondary_agent_id}") async def link_agent_agent(primary_agent_id: int, secondary_agent_id: int): """Link an agent to another agent""" return dbmanager.link( link_type="agent_agent", primary_id=primary_agent_id, secondary_id=secondary_agent_id, ) @api.delete("/agents/link/agent/{primary_agent_id}/{secondary_agent_id}") async def unlink_agent_agent(primary_agent_id: int, secondary_agent_id: int): """Unlink an agent from another agent""" return dbmanager.unlink( link_type="agent_agent", primary_id=primary_agent_id, secondary_id=secondary_agent_id, ) @api.get("/agents/link/agent/{agent_id}") async def get_linked_agents(agent_id: int): """Get all agents linked to an agent""" return dbmanager.get_linked_entities("agent_agent", agent_id, return_json=True) @api.get("/workflows") async def list_workflows(user_id: str): """List all workflows for a user""" filters = {"user_id": user_id} return list_entity(Workflow, filters=filters) @api.get("/workflows/{workflow_id}") async def get_workflow(workflow_id: int, user_id: str): """Get a workflow""" filters = {"id": workflow_id, "user_id": user_id} return list_entity(Workflow, filters=filters) @api.get("/workflows/export/{workflow_id}") async def export_workflow(workflow_id: int, user_id: str): """Export a user workflow""" response = Response(message="Workflow exported successfully", status=True, data=None) try: workflow_details = workflow_from_id(workflow_id, dbmanager=dbmanager) response.data = workflow_details except Exception as ex_error: response.message = "Error occurred while exporting workflow: " + str(ex_error) response.status = False return response.model_dump(mode="json") @api.post("/workflows") async def create_workflow(workflow: Workflow): """Create a new workflow""" return create_entity(workflow, Workflow) @api.delete("/workflows/delete") async def delete_workflow(workflow_id: int, user_id: str): """Delete a workflow""" filters = {"id": workflow_id, "user_id": user_id} return delete_entity(Workflow, filters=filters) @api.post("/workflows/link/agent/{workflow_id}/{agent_id}/{agent_type}") async def link_workflow_agent(workflow_id: int, agent_id: int, agent_type: str): """Link an agent to a workflow""" return dbmanager.link( link_type="workflow_agent", primary_id=workflow_id, secondary_id=agent_id, agent_type=agent_type, ) @api.post("/workflows/link/agent/{workflow_id}/{agent_id}/{agent_type}/{sequence_id}") async def link_workflow_agent_sequence(workflow_id: int, agent_id: int, agent_type: str, sequence_id: int): """Link an agent to a workflow""" print("Sequence ID: ", sequence_id) return dbmanager.link( link_type="workflow_agent", primary_id=workflow_id, secondary_id=agent_id, agent_type=agent_type, sequence_id=sequence_id, ) @api.delete("/workflows/link/agent/{workflow_id}/{agent_id}/{agent_type}") async def unlink_workflow_agent(workflow_id: int, agent_id: int, agent_type: str): """Unlink an agent from a workflow""" return dbmanager.unlink( link_type="workflow_agent", primary_id=workflow_id, secondary_id=agent_id, agent_type=agent_type, ) @api.delete("/workflows/link/agent/{workflow_id}/{agent_id}/{agent_type}/{sequence_id}") async def unlink_workflow_agent_sequence(workflow_id: int, agent_id: int, agent_type: str, sequence_id: int): """Unlink an agent from a workflow sequence""" return dbmanager.unlink( link_type="workflow_agent", primary_id=workflow_id, secondary_id=agent_id, agent_type=agent_type, sequence_id=sequence_id, ) @api.get("/workflows/link/agent/{workflow_id}") async def get_linked_workflow_agents(workflow_id: int): """Get all agents linked to a workflow""" return dbmanager.get_linked_entities( link_type="workflow_agent", primary_id=workflow_id, return_json=True, ) @api.get("/profiler/{message_id}") async def profile_agent_task_run(message_id: int): """Profile an agent task run""" try: agent_message = dbmanager.get(Message, filters={"id": message_id}).data[0] profile = profiler.profile(agent_message) return { "status": True, "message": "Agent task run profiled successfully", "data": profile, } except Exception as ex_error: return { "status": False, "message": "Error occurred while profiling agent task run: " + str(ex_error), } @api.get("/sessions") async def list_sessions(user_id: str): """List all sessions for a user""" filters = {"user_id": user_id} return list_entity(Session, filters=filters) @api.post("/sessions") async def create_session(session: Session): """Create a new session""" return create_entity(session, Session) @api.delete("/sessions/delete") async def delete_session(session_id: int, user_id: str): """Delete a session""" filters = {"id": session_id, "user_id": user_id} return delete_entity(Session, filters=filters) @api.get("/sessions/{session_id}/messages") async def list_messages(user_id: str, session_id: int): """List all messages for a use session""" filters = {"user_id": user_id, "session_id": session_id} return list_entity(Message, filters=filters, order="asc", return_json=True) @api.post("/sessions/{session_id}/workflow/{workflow_id}/run") async def run_session_workflow(message: Message, session_id: int, workflow_id: int): """Runs a workflow on provided message""" try: user_message_history = ( dbmanager.get( Message, filters={"user_id": message.user_id, "session_id": message.session_id}, return_json=True, ).data if session_id is not None else [] ) # save incoming message dbmanager.upsert(message) user_dir = os.path.join(folders["files_static_root"], "user", md5_hash(message.user_id)) os.makedirs(user_dir, exist_ok=True) workflow = workflow_from_id(workflow_id, dbmanager=dbmanager) agent_response: Message = await managers["chat"].a_chat( message=message, history=user_message_history, user_dir=user_dir, workflow=workflow, connection_id=message.connection_id, ) response: Response = dbmanager.upsert(agent_response) return response.model_dump(mode="json") except Exception as ex_error: return { "status": False, "message": "Error occurred while processing message: " + str(ex_error), } @api.get("/version") async def get_version(): return { "status": True, "message": "Version retrieved successfully", "data": {"version": VERSION}, } # websockets async def process_socket_message(data: dict, websocket: WebSocket, client_id: str): print(f"Client says: {data['type']}") if data["type"] == "user_message": user_message = Message(**data["data"]) session_id = data["data"].get("session_id", None) workflow_id = data["data"].get("workflow_id", None) response = await run_session_workflow(message=user_message, session_id=session_id, workflow_id=workflow_id) response_socket_message = { "type": "agent_response", "data": response, "connection_id": client_id, } await websocket_manager.send_message(response_socket_message, websocket) @api.websocket("/ws/{client_id}") async def websocket_endpoint(websocket: WebSocket, client_id: str): await websocket_manager.connect(websocket, client_id) try: while True: data = await websocket.receive_json() await process_socket_message(data, websocket, client_id) except WebSocketDisconnect: print(f"Client #{client_id} is disconnected") await websocket_manager.disconnect(websocket)