mirror of
				https://github.com/microsoft/autogen.git
				synced 2025-10-25 23:10:01 +00:00 
			
		
		
		
	Add sample chat application with FastAPI (#5433)
Introduce a sample chat application using AgentChat and FastAPI, demonstrating single-agent and team chat functionalities, along with state persistence and conversation history management. Resolves #5423 --------- Co-authored-by: Victor Dibia <victor.dibia@gmail.com> Co-authored-by: Victor Dibia <victordibia@microsoft.com>
This commit is contained in:
		
							parent
							
								
									f20ba9127d
								
							
						
					
					
						commit
						abdc0da4f1
					
				
							
								
								
									
										5
									
								
								python/samples/agentchat_fastapi/.gitignore
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								python/samples/agentchat_fastapi/.gitignore
									
									
									
									
										vendored
									
									
										Normal file
									
								
							| @ -0,0 +1,5 @@ | ||||
| model_config.yaml | ||||
| agent_state.json | ||||
| agent_history.json | ||||
| team_state.json | ||||
| team_history.json | ||||
							
								
								
									
										70
									
								
								python/samples/agentchat_fastapi/README.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										70
									
								
								python/samples/agentchat_fastapi/README.md
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,70 @@ | ||||
| # AgentChat App with FastAPI | ||||
| 
 | ||||
| This sample demonstrates how to create a simple chat application using | ||||
| [AgentChat](https://microsoft.github.io/autogen/stable/user-guide/agentchat-user-guide/index.html) | ||||
| and [FastAPI](https://fastapi.tiangolo.com/). | ||||
| 
 | ||||
| You will be using the following features of AgentChat: | ||||
| 
 | ||||
| 1. Agent: | ||||
|    - `AssistantAgent` | ||||
|    - `UserProxyAgent` with a custom websocket input function | ||||
| 2. Team: `RoundRobinGroupChat` | ||||
| 3. State persistence: `save_state` and `load_state` methods of both agent and team. | ||||
| 
 | ||||
| ## Setup | ||||
| 
 | ||||
| Install the required packages with OpenAI support: | ||||
| 
 | ||||
| ```bash | ||||
| pip install -U "autogen-ext[openai]" "fastapi" "uvicorn" "PyYAML" | ||||
| ``` | ||||
| 
 | ||||
| To use models other than OpenAI, see the [Models](https://microsoft.github.io/autogen/stable/user-guide/agentchat-user-guide/tutorial/models.html) documentation. | ||||
| 
 | ||||
| Create a new file named `model_config.yaml` in the same directory as this README file to configure your model settings. | ||||
| See `model_config_template.yaml` for an example. | ||||
| 
 | ||||
| ## Chat with a single agent | ||||
| 
 | ||||
| To start the FastAPI server for single-agent chat, run: | ||||
| 
 | ||||
| ```bash | ||||
| python app_agent.py | ||||
| ``` | ||||
| 
 | ||||
| Visit http://localhost:8001 in your browser to start chatting. | ||||
| 
 | ||||
| ## Chat with a team of agents | ||||
| 
 | ||||
| To start the FastAPI server for team chat, run: | ||||
| 
 | ||||
| ```bash | ||||
| python app_team.py | ||||
| ``` | ||||
| 
 | ||||
| Visit http://localhost:8002 in your browser to start chatting. | ||||
| 
 | ||||
| The team also includes a `UserProxyAgent` agent with a custom websocket input function | ||||
| that allows the user to send messages to the team from the browser. | ||||
| 
 | ||||
| The team follows a round-robin strategy so each agent will take turns to respond. | ||||
| When it is the user's turn, the input box will be enabled. | ||||
| Once the user sends a message, the input box will be disabled and the agents | ||||
| will take turns to respond. | ||||
| 
 | ||||
| ## State persistence | ||||
| 
 | ||||
| The agents and team use the `load_state` and `save_state` methods to load and save | ||||
| their state from and to files on each turn. | ||||
| For the agent, the state is saved to and loaded from `agent_state.json`. | ||||
| For the team, the state is saved to and loaded from `team_state.json`. | ||||
| You can inspect the state files to see the state of the agents and team | ||||
| once you have chatted with them. | ||||
| 
 | ||||
| When the server restarts, the agents and team will load their state from the state files | ||||
| to maintain their state across restarts. | ||||
| 
 | ||||
| Additionally, the apps uses separate JSON files, | ||||
| `agent_history.json` and `team_history.json`, to store the conversation history | ||||
| for display in the browser. | ||||
							
								
								
									
										195
									
								
								python/samples/agentchat_fastapi/app_agent.html
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										195
									
								
								python/samples/agentchat_fastapi/app_agent.html
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,195 @@ | ||||
| <!DOCTYPE html> | ||||
| <html lang="en"> | ||||
| 
 | ||||
| <head> | ||||
|     <meta charset="UTF-8"> | ||||
|     <meta name="viewport" content="width=device-width, initial-scale=1.0"> | ||||
|     <title>AutoGen FastAPI Sample: Agent</title> | ||||
|     <style> | ||||
|         body { | ||||
|             font-family: Arial, sans-serif; | ||||
|             margin: 0; | ||||
|             padding: 0; | ||||
|             display: flex; | ||||
|             flex-direction: column; | ||||
|             align-items: center; | ||||
|             justify-content: center; | ||||
|             height: 100vh; | ||||
|             background-color: #f0f0f0; | ||||
|         } | ||||
| 
 | ||||
|         #chat-container { | ||||
|             width: 90%; | ||||
|             max-width: 600px; | ||||
|             background-color: #fff; | ||||
|             border-radius: 8px; | ||||
|             box-shadow: 0 0 10px rgba(0, 0, 0, 0.1); | ||||
|             padding: 20px; | ||||
|         } | ||||
| 
 | ||||
|         #messages { | ||||
|             height: 600px; | ||||
|             overflow-y: auto; | ||||
|             border-bottom: 1px solid #ddd; | ||||
|             margin-bottom: 20px; | ||||
|         } | ||||
| 
 | ||||
|         .message { | ||||
|             margin: 10px 0; | ||||
|         } | ||||
| 
 | ||||
|         .message.user { | ||||
|             text-align: right; | ||||
|         } | ||||
| 
 | ||||
|         .message.assistant { | ||||
|             text-align: left; | ||||
|         } | ||||
| 
 | ||||
|         .message.error { | ||||
|             color: #721c24; | ||||
|             background-color: #f8d7da; | ||||
|             border: 1px solid #f5c6cb; | ||||
|             padding: 10px; | ||||
|             border-radius: 4px; | ||||
|             margin: 10px 0; | ||||
|         } | ||||
| 
 | ||||
|         .message.system { | ||||
|             color: #0c5460; | ||||
|             background-color: #d1ecf1; | ||||
|             border: 1px solid #bee5eb; | ||||
|             padding: 10px; | ||||
|             border-radius: 4px; | ||||
|             margin: 10px 0; | ||||
|         } | ||||
| 
 | ||||
|         #input-container input:disabled, | ||||
|         #input-container button:disabled { | ||||
|             background-color: #e0e0e0; | ||||
|             cursor: not-allowed; | ||||
|         } | ||||
| 
 | ||||
|         #input-container { | ||||
|             display: flex; | ||||
|         } | ||||
| 
 | ||||
|         #input-container input { | ||||
|             flex: 1; | ||||
|             padding: 10px; | ||||
|             border: 1px solid #ddd; | ||||
|             border-radius: 4px; | ||||
|         } | ||||
| 
 | ||||
|         #input-container button { | ||||
|             padding: 10px 20px; | ||||
|             border: none; | ||||
|             background-color: #007bff; | ||||
|             color: #fff; | ||||
|             border-radius: 4px; | ||||
|             cursor: pointer; | ||||
|         } | ||||
|     </style> | ||||
| </head> | ||||
| 
 | ||||
| <body> | ||||
|     <div id="chat-container"> | ||||
|         <div id="messages"></div> | ||||
|         <div id="input-container"> | ||||
|             <input type="text" id="message-input" placeholder="Type a message..."> | ||||
|             <button onclick="sendMessage()">Send</button> | ||||
|         </div> | ||||
|     </div> | ||||
| 
 | ||||
|     <script> | ||||
|         document.getElementById('message-input').addEventListener('keydown', function (event) { | ||||
|             if (event.key === 'Enter') { | ||||
|                 sendMessage(); | ||||
|             } | ||||
|         }); | ||||
| 
 | ||||
|         async function sendMessage() { | ||||
|             const input = document.getElementById('message-input'); | ||||
|             const button = document.querySelector('#input-container button'); | ||||
|             const message = input.value; | ||||
|             if (!message) return; | ||||
| 
 | ||||
|             // Display user message | ||||
|             displayMessage(message, 'user'); | ||||
| 
 | ||||
|             // Clear input and disable controls | ||||
|             input.value = ''; | ||||
|             input.disabled = true; | ||||
|             button.disabled = true; | ||||
| 
 | ||||
|             try { | ||||
|                 const response = await fetch('http://localhost:8001/chat', { | ||||
|                     method: 'POST', | ||||
|                     headers: { | ||||
|                         'Content-Type': 'application/json' | ||||
|                     }, | ||||
|                     body: JSON.stringify({ content: message, source: 'user' }) | ||||
|                 }); | ||||
| 
 | ||||
|                 const data = await response.json(); | ||||
|                 if (!response.ok) { | ||||
|                     // Handle error response | ||||
|                     if (data.detail && data.detail.type === 'error') { | ||||
|                         displayMessage(data.detail.content, 'error'); | ||||
|                     } else { | ||||
|                         displayMessage('Error: ' + (data.detail || 'Unknown error'), 'error'); | ||||
|                     } | ||||
|                 } else { | ||||
|                     displayMessage(data.content, 'assistant'); | ||||
|                 } | ||||
|             } catch (error) { | ||||
|                 console.error('Error:', error); | ||||
|                 displayMessage('Error: Could not reach the server.', 'error'); | ||||
|             } finally { | ||||
|                 // Re-enable controls | ||||
|                 input.disabled = false; | ||||
|                 button.disabled = false; | ||||
|                 input.focus(); | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|         function displayMessage(content, source) { | ||||
|             const messagesContainer = document.getElementById('messages'); | ||||
|             const messageElement = document.createElement('div'); | ||||
|             messageElement.className = `message ${source}`; | ||||
| 
 | ||||
|             const labelElement = document.createElement('span'); | ||||
|             labelElement.className = 'label'; | ||||
|             labelElement.textContent = source; | ||||
| 
 | ||||
|             const contentElement = document.createElement('div'); | ||||
|             contentElement.className = 'content'; | ||||
|             contentElement.textContent = content; | ||||
| 
 | ||||
|             messageElement.appendChild(labelElement); | ||||
|             messageElement.appendChild(contentElement); | ||||
|             messagesContainer.appendChild(messageElement); | ||||
|             messagesContainer.scrollTop = messagesContainer.scrollHeight; | ||||
|         } | ||||
| 
 | ||||
|         async function loadHistory() { | ||||
|             try { | ||||
|                 const response = await fetch('http://localhost:8001/history'); | ||||
|                 if (!response.ok) { | ||||
|                     throw new Error('Network response was not ok'); | ||||
|                 } | ||||
|                 const history = await response.json(); | ||||
|                 history.forEach(message => { | ||||
|                     displayMessage(message.content, message.source); | ||||
|                 }); | ||||
|             } catch (error) { | ||||
|                 console.error('Error loading history:', error); | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|         // Load chat history when the page loads | ||||
|         window.onload = loadHistory; | ||||
|     </script> | ||||
| </body> | ||||
| 
 | ||||
| </html> | ||||
							
								
								
									
										111
									
								
								python/samples/agentchat_fastapi/app_agent.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										111
									
								
								python/samples/agentchat_fastapi/app_agent.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,111 @@ | ||||
| import json | ||||
| import os | ||||
| from typing import Any | ||||
| 
 | ||||
| import aiofiles | ||||
| import yaml | ||||
| from autogen_agentchat.agents import AssistantAgent | ||||
| from autogen_agentchat.messages import TextMessage | ||||
| from autogen_core import CancellationToken | ||||
| from autogen_core.models import ChatCompletionClient | ||||
| from fastapi import FastAPI, HTTPException | ||||
| from fastapi.middleware.cors import CORSMiddleware | ||||
| from fastapi.responses import FileResponse | ||||
| from fastapi.staticfiles import StaticFiles | ||||
| 
 | ||||
| app = FastAPI() | ||||
| 
 | ||||
| # Add CORS middleware | ||||
| app.add_middleware( | ||||
|     CORSMiddleware, | ||||
|     allow_origins=["*"],  # Allows all origins | ||||
|     allow_credentials=True, | ||||
|     allow_methods=["*"],  # Allows all methods | ||||
|     allow_headers=["*"],  # Allows all headers | ||||
| ) | ||||
| 
 | ||||
| # Serve static files | ||||
| app.mount("/static", StaticFiles(directory="."), name="static") | ||||
| 
 | ||||
| @app.get("/") | ||||
| async def root(): | ||||
|     """Serve the chat interface HTML file.""" | ||||
|     return FileResponse("app_agent.html") | ||||
| 
 | ||||
| model_config_path = "model_config.yaml" | ||||
| state_path = "agent_state.json" | ||||
| history_path = "agent_history.json" | ||||
| 
 | ||||
| 
 | ||||
| async def get_agent() -> AssistantAgent: | ||||
|     """Get the assistant agent, load state from file.""" | ||||
|     # Get model client from config. | ||||
|     async with aiofiles.open(model_config_path, "r") as file: | ||||
|         model_config = yaml.safe_load(await file.read()) | ||||
|     model_client = ChatCompletionClient.load_component(model_config) | ||||
|     # Create the assistant agent. | ||||
|     agent = AssistantAgent( | ||||
|         name="assistant", | ||||
|         model_client=model_client, | ||||
|         system_message="You are a helpful assistant.", | ||||
|     ) | ||||
|     # Load state from file. | ||||
|     if not os.path.exists(state_path): | ||||
|         return agent  # Return agent without loading state. | ||||
|     async with aiofiles.open(state_path, "r") as file: | ||||
|         state = json.loads(await file.read()) | ||||
|     await agent.load_state(state) | ||||
|     return agent | ||||
| 
 | ||||
| 
 | ||||
| async def get_history() -> list[dict[str, Any]]: | ||||
|     """Get chat history from file.""" | ||||
|     if not os.path.exists(history_path): | ||||
|         return [] | ||||
|     async with aiofiles.open(history_path, "r") as file: | ||||
|         return json.loads(await file.read()) | ||||
| 
 | ||||
| 
 | ||||
| @app.get("/history") | ||||
| async def history() -> list[dict[str, Any]]: | ||||
|     try: | ||||
|         return await get_history() | ||||
|     except Exception as e: | ||||
|         raise HTTPException(status_code=500, detail=str(e)) from e | ||||
| 
 | ||||
| 
 | ||||
| @app.post("/chat", response_model=TextMessage) | ||||
| async def chat(request: TextMessage) -> TextMessage: | ||||
|     try: | ||||
|         # Get the agent and respond to the message. | ||||
|         agent = await get_agent() | ||||
|         response = await agent.on_messages(messages=[request], cancellation_token=CancellationToken()) | ||||
| 
 | ||||
|         # Save agent state to file. | ||||
|         state = await agent.save_state() | ||||
|         async with aiofiles.open(state_path, "w") as file: | ||||
|             await file.write(json.dumps(state)) | ||||
| 
 | ||||
|         # Save chat history to file. | ||||
|         history = await get_history() | ||||
|         history.append(request.model_dump()) | ||||
|         history.append(response.chat_message.model_dump()) | ||||
|         async with aiofiles.open(history_path, "w") as file: | ||||
|             await file.write(json.dumps(history)) | ||||
| 
 | ||||
|         assert isinstance(response.chat_message, TextMessage) | ||||
|         return response.chat_message | ||||
|     except Exception as e: | ||||
|         error_message = { | ||||
|             "type": "error", | ||||
|             "content": f"Error: {str(e)}", | ||||
|             "source": "system" | ||||
|         } | ||||
|         raise HTTPException(status_code=500, detail=error_message) from e | ||||
| 
 | ||||
| 
 | ||||
| # Example usage | ||||
| if __name__ == "__main__": | ||||
|     import uvicorn | ||||
| 
 | ||||
|     uvicorn.run(app, host="0.0.0.0", port=8001) | ||||
							
								
								
									
										217
									
								
								python/samples/agentchat_fastapi/app_team.html
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										217
									
								
								python/samples/agentchat_fastapi/app_team.html
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,217 @@ | ||||
| <!DOCTYPE html> | ||||
| <html lang="en"> | ||||
| 
 | ||||
| <head> | ||||
|     <meta charset="UTF-8"> | ||||
|     <meta name="viewport" content="width=device-width, initial-scale=1.0"> | ||||
|     <title>AutoGen FastAPI Sample: Team</title> | ||||
|     <style> | ||||
|         body { | ||||
|             font-family: Arial, sans-serif; | ||||
|             margin: 0; | ||||
|             padding: 0; | ||||
|             display: flex; | ||||
|             flex-direction: column; | ||||
|             align-items: center; | ||||
|             justify-content: center; | ||||
|             height: 100vh; | ||||
|             background-color: #f0f0f0; | ||||
|         } | ||||
| 
 | ||||
|         #chat-container { | ||||
|             width: 90%; | ||||
|             max-width: 600px; | ||||
|             background-color: #fff; | ||||
|             border-radius: 8px; | ||||
|             box-shadow: 0 0 10px rgba(0, 0, 0, 0.1); | ||||
|             padding: 20px; | ||||
|         } | ||||
| 
 | ||||
|         #messages { | ||||
|             height: 600px; | ||||
|             overflow-y: auto; | ||||
|             border-bottom: 1px solid #ddd; | ||||
|             margin-bottom: 20px; | ||||
|         } | ||||
| 
 | ||||
|         .message { | ||||
|             margin: 10px 0; | ||||
|         } | ||||
| 
 | ||||
|         .message.user { | ||||
|             text-align: right; | ||||
|         } | ||||
| 
 | ||||
|         .message.assistant { | ||||
|             text-align: left; | ||||
|         } | ||||
| 
 | ||||
|         .label { | ||||
|             font-weight: bold; | ||||
|             display: block; | ||||
|         } | ||||
| 
 | ||||
|         .content { | ||||
|             margin-top: 5px; | ||||
|         } | ||||
| 
 | ||||
|         #input-container { | ||||
|             display: flex; | ||||
|         } | ||||
| 
 | ||||
|         #input-container input { | ||||
|             flex: 1; | ||||
|             padding: 10px; | ||||
|             border: 1px solid #ddd; | ||||
|             border-radius: 4px; | ||||
|         } | ||||
| 
 | ||||
|         #input-container button { | ||||
|             padding: 10px 20px; | ||||
|             border: none; | ||||
|             background-color: #007bff; | ||||
|             color: #fff; | ||||
|             border-radius: 4px; | ||||
|             cursor: pointer; | ||||
|         } | ||||
| 
 | ||||
|         #input-container input:disabled, | ||||
|         #input-container button:disabled { | ||||
|             background-color: #e0e0e0; | ||||
|             cursor: not-allowed; | ||||
|         } | ||||
| 
 | ||||
|         .message.error { | ||||
|             color: #721c24; | ||||
|             background-color: #f8d7da; | ||||
|             border: 1px solid #f5c6cb; | ||||
|             padding: 10px; | ||||
|             border-radius: 4px; | ||||
|             margin: 10px 0; | ||||
|         } | ||||
| 
 | ||||
|         .message.system { | ||||
|             color: #0c5460; | ||||
|             background-color: #d1ecf1; | ||||
|             border: 1px solid #bee5eb; | ||||
|             padding: 10px; | ||||
|             border-radius: 4px; | ||||
|             margin: 10px 0; | ||||
|         } | ||||
|     </style> | ||||
| </head> | ||||
| 
 | ||||
| <body> | ||||
|     <div id="chat-container"> | ||||
|         <div id="messages"></div> | ||||
|         <div id="input-container"> | ||||
|             <input type="text" id="message-input" placeholder="Type a message..."> | ||||
|             <button id="send-button" onclick="sendMessage()">Send</button> | ||||
|         </div> | ||||
|     </div> | ||||
| 
 | ||||
|     <script> | ||||
|         const ws = new WebSocket('ws://localhost:8002/ws/chat'); | ||||
| 
 | ||||
|         ws.onmessage = function (event) { | ||||
|             const message = JSON.parse(event.data); | ||||
| 
 | ||||
|             if (message.type === 'UserInputRequestedEvent') { | ||||
|                 // Re-enable input and send button if UserInputRequestedEvent is received | ||||
|                 enableInput(); | ||||
|             } | ||||
|             else if (message.type === 'error') { | ||||
|                 // Display error message | ||||
|                 displayMessage(message.content, 'error'); | ||||
|                 enableInput(); | ||||
|             } | ||||
|             else { | ||||
|                 // Display regular message | ||||
|                 displayMessage(message.content, message.source); | ||||
|             } | ||||
|         }; | ||||
| 
 | ||||
|         ws.onerror = function(error) { | ||||
|             displayMessage("WebSocket error occurred. Please refresh the page.", 'error'); | ||||
|             enableInput(); | ||||
|         }; | ||||
| 
 | ||||
|         ws.onclose = function() { | ||||
|             displayMessage("Connection closed. Please refresh the page.", 'system'); | ||||
|             disableInput(); | ||||
|         }; | ||||
| 
 | ||||
|         document.getElementById('message-input').addEventListener('keydown', function (event) { | ||||
|             if (event.key === 'Enter' && !event.target.disabled) { | ||||
|                 sendMessage(); | ||||
|             } | ||||
|         }); | ||||
| 
 | ||||
|         async function sendMessage() { | ||||
|             const input = document.getElementById('message-input'); | ||||
|             const button = document.getElementById('send-button'); | ||||
|             const message = input.value; | ||||
|             if (!message) return; | ||||
| 
 | ||||
|             // Clear input and disable input and send button | ||||
|             input.value = ''; | ||||
|             disableInput(); | ||||
| 
 | ||||
|             // Send message to WebSocket | ||||
|             ws.send(JSON.stringify({ content: message, source: 'user' })); | ||||
|         } | ||||
| 
 | ||||
|         function displayMessage(content, source) { | ||||
|             const messagesContainer = document.getElementById('messages'); | ||||
|             const messageElement = document.createElement('div'); | ||||
|             messageElement.className = `message ${source}`; | ||||
| 
 | ||||
|             const labelElement = document.createElement('span'); | ||||
|             labelElement.className = 'label'; | ||||
|             labelElement.textContent = source; | ||||
| 
 | ||||
|             const contentElement = document.createElement('div'); | ||||
|             contentElement.className = 'content'; | ||||
|             contentElement.textContent = content; | ||||
| 
 | ||||
|             messageElement.appendChild(labelElement); | ||||
|             messageElement.appendChild(contentElement); | ||||
|             messagesContainer.appendChild(messageElement); | ||||
|             messagesContainer.scrollTop = messagesContainer.scrollHeight; | ||||
|         } | ||||
| 
 | ||||
|         function disableInput() { | ||||
|             const input = document.getElementById('message-input'); | ||||
|             const button = document.getElementById('send-button'); | ||||
|             input.disabled = true; | ||||
|             button.disabled = true; | ||||
|         } | ||||
| 
 | ||||
|         function enableInput() { | ||||
|             const input = document.getElementById('message-input'); | ||||
|             const button = document.getElementById('send-button'); | ||||
|             input.disabled = false; | ||||
|             button.disabled = false; | ||||
|         } | ||||
| 
 | ||||
|         async function loadHistory() { | ||||
|             try { | ||||
|                 const response = await fetch('http://localhost:8002/history'); | ||||
|                 if (!response.ok) { | ||||
|                     throw new Error('Network response was not ok'); | ||||
|                 } | ||||
|                 const history = await response.json(); | ||||
|                 history.forEach(message => { | ||||
|                     displayMessage(message.content, message.source); | ||||
|                 }); | ||||
|             } catch (error) { | ||||
|                 console.error('Error loading history:', error); | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|         // Load chat history when the page loads | ||||
|         window.onload = loadHistory; | ||||
|     </script> | ||||
| </body> | ||||
| 
 | ||||
| </html> | ||||
							
								
								
									
										166
									
								
								python/samples/agentchat_fastapi/app_team.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										166
									
								
								python/samples/agentchat_fastapi/app_team.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,166 @@ | ||||
| import json | ||||
| import logging | ||||
| import os | ||||
| from typing import Any, Awaitable, Callable, Optional | ||||
| 
 | ||||
| import aiofiles | ||||
| import yaml | ||||
| from autogen_agentchat.agents import AssistantAgent, UserProxyAgent | ||||
| from autogen_agentchat.base import TaskResult | ||||
| from autogen_agentchat.messages import TextMessage, UserInputRequestedEvent | ||||
| from autogen_agentchat.teams import RoundRobinGroupChat | ||||
| from autogen_core import CancellationToken | ||||
| from autogen_core.models import ChatCompletionClient | ||||
| from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect | ||||
| from fastapi.middleware.cors import CORSMiddleware | ||||
| from fastapi.responses import FileResponse | ||||
| from fastapi.staticfiles import StaticFiles | ||||
| 
 | ||||
| logger = logging.getLogger(__name__) | ||||
| 
 | ||||
| app = FastAPI() | ||||
| 
 | ||||
| # Add CORS middleware | ||||
| app.add_middleware( | ||||
|     CORSMiddleware, | ||||
|     allow_origins=["*"],  # Allows all origins | ||||
|     allow_credentials=True, | ||||
|     allow_methods=["*"],  # Allows all methods | ||||
|     allow_headers=["*"],  # Allows all headers | ||||
| ) | ||||
| 
 | ||||
| model_config_path = "model_config.yaml" | ||||
| state_path = "team_state.json" | ||||
| history_path = "team_history.json" | ||||
| 
 | ||||
| # Serve static files | ||||
| app.mount("/static", StaticFiles(directory="."), name="static") | ||||
| 
 | ||||
| @app.get("/") | ||||
| async def root(): | ||||
|     """Serve the chat interface HTML file.""" | ||||
|     return FileResponse("app_team.html") | ||||
| 
 | ||||
| 
 | ||||
| async def get_team( | ||||
|     user_input_func: Callable[[str, Optional[CancellationToken]], Awaitable[str]], | ||||
| ) -> RoundRobinGroupChat: | ||||
|     # Get model client from config. | ||||
|     async with aiofiles.open(model_config_path, "r") as file: | ||||
|         model_config = yaml.safe_load(await file.read()) | ||||
|     model_client = ChatCompletionClient.load_component(model_config) | ||||
|     # Create the team. | ||||
|     agent = AssistantAgent( | ||||
|         name="assistant", | ||||
|         model_client=model_client, | ||||
|         system_message="You are a helpful assistant.", | ||||
|     ) | ||||
|     yoda = AssistantAgent( | ||||
|         name="yoda", | ||||
|         model_client=model_client, | ||||
|         system_message="Repeat the same message in the tone of Yoda.", | ||||
|     ) | ||||
|     user_proxy = UserProxyAgent( | ||||
|         name="user", | ||||
|         input_func=user_input_func,  # Use the user input function. | ||||
|     ) | ||||
|     team = RoundRobinGroupChat( | ||||
|         [agent, yoda, user_proxy], | ||||
|     ) | ||||
|     # Load state from file. | ||||
|     if not os.path.exists(state_path): | ||||
|         return team | ||||
|     async with aiofiles.open(state_path, "r") as file: | ||||
|         state = json.loads(await file.read()) | ||||
|     await team.load_state(state) | ||||
|     return team | ||||
| 
 | ||||
| 
 | ||||
| async def get_history() -> list[dict[str, Any]]: | ||||
|     """Get chat history from file.""" | ||||
|     if not os.path.exists(history_path): | ||||
|         return [] | ||||
|     async with aiofiles.open(history_path, "r") as file: | ||||
|         return json.loads(await file.read()) | ||||
| 
 | ||||
| 
 | ||||
| @app.get("/history") | ||||
| async def history() -> list[dict[str, Any]]: | ||||
|     try: | ||||
|         return await get_history() | ||||
|     except Exception as e: | ||||
|         raise HTTPException(status_code=500, detail=str(e)) from e | ||||
| 
 | ||||
| 
 | ||||
| @app.websocket("/ws/chat") | ||||
| async def chat(websocket: WebSocket): | ||||
|     await websocket.accept() | ||||
| 
 | ||||
|     # User input function used by the team. | ||||
|     async def _user_input(prompt: str, cancellation_token: CancellationToken | None) -> str: | ||||
|         data = await websocket.receive_json() | ||||
|         message = TextMessage.model_validate(data) | ||||
|         return message.content | ||||
| 
 | ||||
|     try: | ||||
|         while True: | ||||
|             # Get user message. | ||||
|             data = await websocket.receive_json() | ||||
|             request = TextMessage.model_validate(data) | ||||
| 
 | ||||
|             try: | ||||
|                 # Get the team and respond to the message. | ||||
|                 team = await get_team(_user_input) | ||||
|                 history = await get_history() | ||||
|                 stream = team.run_stream(task=request) | ||||
|                 async for message in stream: | ||||
|                     if isinstance(message, TaskResult): | ||||
|                         continue | ||||
|                     await websocket.send_json(message.model_dump()) | ||||
|                     if not isinstance(message, UserInputRequestedEvent): | ||||
|                         # Don't save user input events to history. | ||||
|                         history.append(message.model_dump()) | ||||
| 
 | ||||
|                 # Save team state to file. | ||||
|                 async with aiofiles.open(state_path, "w") as file: | ||||
|                     state = await team.save_state() | ||||
|                     await file.write(json.dumps(state)) | ||||
| 
 | ||||
|                 # Save chat history to file. | ||||
|                 async with aiofiles.open(history_path, "w") as file: | ||||
|                     await file.write(json.dumps(history)) | ||||
|                      | ||||
|             except Exception as e: | ||||
|                 # Send error message to client | ||||
|                 error_message = { | ||||
|                     "type": "error", | ||||
|                     "content": f"Error: {str(e)}", | ||||
|                     "source": "system" | ||||
|                 } | ||||
|                 await websocket.send_json(error_message) | ||||
|                 # Re-enable input after error | ||||
|                 await websocket.send_json({ | ||||
|                     "type": "UserInputRequestedEvent", | ||||
|                     "content": "An error occurred. Please try again.", | ||||
|                     "source": "system" | ||||
|                 }) | ||||
|                  | ||||
|     except WebSocketDisconnect: | ||||
|         logger.info("Client disconnected") | ||||
|     except Exception as e: | ||||
|         logger.error(f"Unexpected error: {str(e)}") | ||||
|         try: | ||||
|             await websocket.send_json({ | ||||
|                 "type": "error", | ||||
|                 "content": f"Unexpected error: {str(e)}", | ||||
|                 "source": "system" | ||||
|             }) | ||||
|         except: | ||||
|             pass | ||||
| 
 | ||||
| 
 | ||||
| # Example usage | ||||
| if __name__ == "__main__": | ||||
|     import uvicorn | ||||
| 
 | ||||
|     uvicorn.run(app, host="0.0.0.0", port=8002) | ||||
							
								
								
									
										26
									
								
								python/samples/agentchat_fastapi/model_config_template.yaml
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										26
									
								
								python/samples/agentchat_fastapi/model_config_template.yaml
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,26 @@ | ||||
| # Use Open AI with key | ||||
| provider: autogen_ext.models.openai.OpenAIChatCompletionClient | ||||
| config: | ||||
|   model: gpt-4o | ||||
|   api_key: REPLACE_WITH_YOUR_API_KEY | ||||
| # Use Azure Open AI with key | ||||
| # provider: autogen_ext.models.openai.AzureOpenAIChatCompletionClient | ||||
| # config: | ||||
| #   model: gpt-4o | ||||
| #   azure_endpoint: https://{your-custom-endpoint}.openai.azure.com/ | ||||
| #   azure_deployment: {your-azure-deployment} | ||||
| #   api_version: {your-api-version} | ||||
| #   api_key: REPLACE_WITH_YOUR_API_KEY | ||||
| # Use Azure OpenAI with AD token provider. | ||||
| # provider: autogen_ext.models.openai.AzureOpenAIChatCompletionClient | ||||
| # config: | ||||
| #   model: gpt-4o | ||||
| #   azure_endpoint: https://{your-custom-endpoint}.openai.azure.com/ | ||||
| #   azure_deployment: {your-azure-deployment} | ||||
| #   api_version: {your-api-version} | ||||
| #   azure_ad_token_provider: | ||||
| #     provider: autogen_ext.auth.azure.AzureTokenProvider | ||||
| #     config: | ||||
| #       provider_kind: DefaultAzureCredential | ||||
| #       scopes: | ||||
| #         - https://cognitiveservices.azure.com/.default | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Eric Zhu
						Eric Zhu