| 
									
										
										
										
											2024-06-10 19:51:51 -07:00
										 |  |  | """This is an example of a terminal-based ChatGPT clone
 | 
					
						
							|  |  |  | using an OpenAIAssistantAgent and event-based orchestration."""
 | 
					
						
							| 
									
										
										
										
											2024-06-08 01:27:27 -07:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-06-10 19:51:51 -07:00
										 |  |  | import argparse | 
					
						
							|  |  |  | import asyncio | 
					
						
							|  |  |  | import logging | 
					
						
							| 
									
										
										
										
											2024-06-08 01:27:27 -07:00
										 |  |  | import os | 
					
						
							|  |  |  | import re | 
					
						
							| 
									
										
										
										
											2024-06-25 13:23:29 -07:00
										 |  |  | import sys | 
					
						
							| 
									
										
										
										
											2024-06-10 19:51:51 -07:00
										 |  |  | from typing import List | 
					
						
							| 
									
										
										
										
											2024-06-08 01:27:27 -07:00
										 |  |  | 
 | 
					
						
							|  |  |  | import aiofiles | 
					
						
							|  |  |  | import openai | 
					
						
							|  |  |  | from agnext.application import SingleThreadedAgentRuntime | 
					
						
							|  |  |  | from agnext.components import TypeRoutedAgent, message_handler | 
					
						
							| 
									
										
										
										
											2024-06-18 14:53:18 -04:00
										 |  |  | from agnext.core import AgentId, AgentRuntime, CancellationToken | 
					
						
							| 
									
										
										
										
											2024-06-08 01:27:27 -07:00
										 |  |  | from openai import AsyncAssistantEventHandler | 
					
						
							|  |  |  | from openai.types.beta.thread import ToolResources | 
					
						
							|  |  |  | from openai.types.beta.threads import Message, Text, TextDelta | 
					
						
							|  |  |  | from openai.types.beta.threads.runs import RunStep, RunStepDelta | 
					
						
							|  |  |  | from typing_extensions import override | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-06-25 13:23:29 -07:00
										 |  |  | sys.path.append(os.path.join(os.path.dirname(__file__), "..")) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from common.agents import OpenAIAssistantAgent | 
					
						
							|  |  |  | from common.memory import BufferedChatMemory | 
					
						
							|  |  |  | from common.patterns._group_chat_manager import GroupChatManager | 
					
						
							|  |  |  | from common.types import PublishNow, TextMessage | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-06-08 01:27:27 -07:00
										 |  |  | sep = "-" * 50 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-06-09 12:11:36 -07:00
										 |  |  | class UserProxyAgent(TypeRoutedAgent):  # type: ignore | 
					
						
							| 
									
										
										
										
											2024-06-10 19:51:51 -07:00
										 |  |  |     def __init__(  # type: ignore | 
					
						
							| 
									
										
										
										
											2024-06-08 01:27:27 -07:00
										 |  |  |         self, | 
					
						
							| 
									
										
										
										
											2024-06-10 19:51:51 -07:00
										 |  |  |         client: openai.AsyncClient,  # type: ignore | 
					
						
							| 
									
										
										
										
											2024-06-08 01:27:27 -07:00
										 |  |  |         assistant_id: str, | 
					
						
							|  |  |  |         thread_id: str, | 
					
						
							|  |  |  |         vector_store_id: str, | 
					
						
							|  |  |  |     ) -> None:  # type: ignore | 
					
						
							|  |  |  |         super().__init__( | 
					
						
							|  |  |  |             description="A human user", | 
					
						
							| 
									
										
										
										
											2024-06-09 12:11:36 -07:00
										 |  |  |         )  # type: ignore | 
					
						
							| 
									
										
										
										
											2024-06-08 01:27:27 -07:00
										 |  |  |         self._client = client | 
					
						
							|  |  |  |         self._assistant_id = assistant_id | 
					
						
							|  |  |  |         self._thread_id = thread_id | 
					
						
							|  |  |  |         self._vector_store_id = vector_store_id | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @message_handler()  # type: ignore | 
					
						
							|  |  |  |     async def on_text_message(self, message: TextMessage, cancellation_token: CancellationToken) -> None:  # type: ignore | 
					
						
							|  |  |  |         # TODO: render image if message has image. | 
					
						
							|  |  |  |         # print(f"{message.source}: {message.content}") | 
					
						
							|  |  |  |         pass | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-06-10 19:51:51 -07:00
										 |  |  |     async def _get_user_input(self, prompt: str) -> str: | 
					
						
							|  |  |  |         loop = asyncio.get_event_loop() | 
					
						
							|  |  |  |         return await loop.run_in_executor(None, input, prompt) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-06-08 01:27:27 -07:00
										 |  |  |     @message_handler()  # type: ignore | 
					
						
							| 
									
										
										
										
											2024-06-10 19:51:51 -07:00
										 |  |  |     async def on_publish_now(self, message: PublishNow, cancellation_token: CancellationToken) -> None:  # type: ignore | 
					
						
							| 
									
										
										
										
											2024-06-08 01:27:27 -07:00
										 |  |  |         while True: | 
					
						
							| 
									
										
										
										
											2024-06-10 19:51:51 -07:00
										 |  |  |             user_input = await self._get_user_input(f"\n{sep}\nYou: ") | 
					
						
							| 
									
										
										
										
											2024-06-08 01:27:27 -07:00
										 |  |  |             # Parse upload file command '[upload code_interpreter | file_search filename]'. | 
					
						
							|  |  |  |             match = re.search(r"\[upload\s+(code_interpreter|file_search)\s+(.+)\]", user_input) | 
					
						
							|  |  |  |             if match: | 
					
						
							|  |  |  |                 # Purpose of the file. | 
					
						
							|  |  |  |                 purpose = match.group(1) | 
					
						
							|  |  |  |                 # Extract file path. | 
					
						
							|  |  |  |                 file_path = match.group(2) | 
					
						
							|  |  |  |                 if not os.path.exists(file_path): | 
					
						
							|  |  |  |                     print(f"File not found: {file_path}") | 
					
						
							|  |  |  |                     continue | 
					
						
							|  |  |  |                 # Filename. | 
					
						
							|  |  |  |                 file_name = os.path.basename(file_path) | 
					
						
							|  |  |  |                 # Read file content. | 
					
						
							|  |  |  |                 async with aiofiles.open(file_path, "rb") as f: | 
					
						
							|  |  |  |                     file_content = await f.read() | 
					
						
							|  |  |  |                 if purpose == "code_interpreter": | 
					
						
							|  |  |  |                     # Upload file. | 
					
						
							|  |  |  |                     file = await self._client.files.create(file=(file_name, file_content), purpose="assistants") | 
					
						
							|  |  |  |                     # Get existing file ids from tool resources. | 
					
						
							|  |  |  |                     thread = await self._client.beta.threads.retrieve(thread_id=self._thread_id) | 
					
						
							|  |  |  |                     tool_resources: ToolResources = thread.tool_resources if thread.tool_resources else ToolResources() | 
					
						
							|  |  |  |                     assert tool_resources.code_interpreter is not None | 
					
						
							|  |  |  |                     if tool_resources.code_interpreter.file_ids: | 
					
						
							|  |  |  |                         file_ids = tool_resources.code_interpreter.file_ids | 
					
						
							|  |  |  |                     else: | 
					
						
							|  |  |  |                         file_ids = [file.id] | 
					
						
							|  |  |  |                     # Update thread with new file. | 
					
						
							|  |  |  |                     await self._client.beta.threads.update( | 
					
						
							|  |  |  |                         thread_id=self._thread_id, | 
					
						
							|  |  |  |                         tool_resources={"code_interpreter": {"file_ids": file_ids}}, | 
					
						
							|  |  |  |                     ) | 
					
						
							|  |  |  |                 elif purpose == "file_search": | 
					
						
							|  |  |  |                     # Upload file to vector store. | 
					
						
							|  |  |  |                     file_batch = await self._client.beta.vector_stores.file_batches.upload_and_poll( | 
					
						
							|  |  |  |                         vector_store_id=self._vector_store_id, | 
					
						
							|  |  |  |                         files=[(file_name, file_content)], | 
					
						
							|  |  |  |                     ) | 
					
						
							|  |  |  |                     assert file_batch.status == "completed" | 
					
						
							|  |  |  |                 print(f"Uploaded file: {file_name}") | 
					
						
							|  |  |  |                 continue | 
					
						
							|  |  |  |             elif user_input.startswith("[upload"): | 
					
						
							|  |  |  |                 print("Invalid upload command. Please use '[upload code_interpreter | file_search filename]'.") | 
					
						
							|  |  |  |                 continue | 
					
						
							| 
									
										
										
										
											2024-06-10 19:51:51 -07:00
										 |  |  |             elif user_input.strip().lower() == "exit": | 
					
						
							|  |  |  |                 # Exit handler. | 
					
						
							|  |  |  |                 return | 
					
						
							| 
									
										
										
										
											2024-06-08 01:27:27 -07:00
										 |  |  |             else: | 
					
						
							| 
									
										
										
										
											2024-06-10 19:51:51 -07:00
										 |  |  |                 # Publish user input and exit handler. | 
					
						
							| 
									
										
										
										
											2024-06-18 15:51:02 -04:00
										 |  |  |                 await self.publish_message(TextMessage(content=user_input, source=self.metadata["name"])) | 
					
						
							| 
									
										
										
										
											2024-06-10 19:51:51 -07:00
										 |  |  |                 return | 
					
						
							| 
									
										
										
										
											2024-06-08 01:27:27 -07:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class EventHandler(AsyncAssistantEventHandler): | 
					
						
							|  |  |  |     @override | 
					
						
							|  |  |  |     async def on_text_delta(self, delta: TextDelta, snapshot: Text) -> None: | 
					
						
							|  |  |  |         print(delta.value, end="", flush=True) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @override | 
					
						
							|  |  |  |     async def on_run_step_created(self, run_step: RunStep) -> None: | 
					
						
							|  |  |  |         details = run_step.step_details | 
					
						
							|  |  |  |         if details.type == "tool_calls": | 
					
						
							|  |  |  |             for tool in details.tool_calls: | 
					
						
							|  |  |  |                 if tool.type == "code_interpreter": | 
					
						
							|  |  |  |                     print("\nGenerating code to interpret:\n\n```python") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @override | 
					
						
							|  |  |  |     async def on_run_step_done(self, run_step: RunStep) -> None: | 
					
						
							|  |  |  |         details = run_step.step_details | 
					
						
							|  |  |  |         if details.type == "tool_calls": | 
					
						
							|  |  |  |             for tool in details.tool_calls: | 
					
						
							|  |  |  |                 if tool.type == "code_interpreter": | 
					
						
							|  |  |  |                     print("\n```\nExecuting code...") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @override | 
					
						
							|  |  |  |     async def on_run_step_delta(self, delta: RunStepDelta, snapshot: RunStep) -> None: | 
					
						
							|  |  |  |         details = delta.step_details | 
					
						
							|  |  |  |         if details is not None and details.type == "tool_calls": | 
					
						
							|  |  |  |             for tool in details.tool_calls or []: | 
					
						
							|  |  |  |                 if tool.type == "code_interpreter" and tool.code_interpreter and tool.code_interpreter.input: | 
					
						
							|  |  |  |                     print(tool.code_interpreter.input, end="", flush=True) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @override | 
					
						
							|  |  |  |     async def on_message_created(self, message: Message) -> None: | 
					
						
							|  |  |  |         print(f"{sep}\nAssistant:\n") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     @override | 
					
						
							|  |  |  |     async def on_message_done(self, message: Message) -> None: | 
					
						
							|  |  |  |         # print a citation to the file searched | 
					
						
							|  |  |  |         if not message.content: | 
					
						
							|  |  |  |             return | 
					
						
							|  |  |  |         content = message.content[0] | 
					
						
							|  |  |  |         if not content.type == "text": | 
					
						
							|  |  |  |             return | 
					
						
							|  |  |  |         text_content = content.text | 
					
						
							|  |  |  |         annotations = text_content.annotations | 
					
						
							|  |  |  |         citations: List[str] = [] | 
					
						
							|  |  |  |         for index, annotation in enumerate(annotations): | 
					
						
							|  |  |  |             text_content.value = text_content.value.replace(annotation.text, f"[{index}]") | 
					
						
							|  |  |  |             if file_citation := getattr(annotation, "file_citation", None): | 
					
						
							|  |  |  |                 client = openai.AsyncClient() | 
					
						
							|  |  |  |                 cited_file = await client.files.retrieve(file_citation.file_id) | 
					
						
							|  |  |  |                 citations.append(f"[{index}] {cited_file.filename}") | 
					
						
							|  |  |  |         if citations: | 
					
						
							|  |  |  |             print("\n".join(citations)) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-07-23 11:49:38 -07:00
										 |  |  | async def assistant_chat(runtime: AgentRuntime) -> AgentId: | 
					
						
							| 
									
										
										
										
											2024-06-08 01:27:27 -07:00
										 |  |  |     oai_assistant = openai.beta.assistants.create( | 
					
						
							|  |  |  |         model="gpt-4-turbo", | 
					
						
							|  |  |  |         description="An AI assistant that helps with everyday tasks.", | 
					
						
							|  |  |  |         instructions="Help the user with their task.", | 
					
						
							|  |  |  |         tools=[{"type": "code_interpreter"}, {"type": "file_search"}], | 
					
						
							|  |  |  |     ) | 
					
						
							|  |  |  |     vector_store = openai.beta.vector_stores.create() | 
					
						
							|  |  |  |     thread = openai.beta.threads.create( | 
					
						
							|  |  |  |         tool_resources={"file_search": {"vector_store_ids": [vector_store.id]}}, | 
					
						
							|  |  |  |     ) | 
					
						
							| 
									
										
										
										
											2024-07-23 11:49:38 -07:00
										 |  |  |     assistant = await runtime.register_and_get( | 
					
						
							| 
									
										
										
										
											2024-06-18 14:53:18 -04:00
										 |  |  |         "Assistant", | 
					
						
							|  |  |  |         lambda: OpenAIAssistantAgent( | 
					
						
							|  |  |  |             description="An AI assistant that helps with everyday tasks.", | 
					
						
							|  |  |  |             client=openai.AsyncClient(), | 
					
						
							|  |  |  |             assistant_id=oai_assistant.id, | 
					
						
							|  |  |  |             thread_id=thread.id, | 
					
						
							|  |  |  |             assistant_event_handler_factory=lambda: EventHandler(), | 
					
						
							|  |  |  |         ), | 
					
						
							| 
									
										
										
										
											2024-06-08 01:27:27 -07:00
										 |  |  |     ) | 
					
						
							| 
									
										
										
										
											2024-06-18 14:53:18 -04:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-07-23 11:49:38 -07:00
										 |  |  |     user = await runtime.register_and_get( | 
					
						
							| 
									
										
										
										
											2024-06-18 14:53:18 -04:00
										 |  |  |         "User", | 
					
						
							|  |  |  |         lambda: UserProxyAgent( | 
					
						
							|  |  |  |             client=openai.AsyncClient(), | 
					
						
							|  |  |  |             assistant_id=oai_assistant.id, | 
					
						
							|  |  |  |             thread_id=thread.id, | 
					
						
							|  |  |  |             vector_store_id=vector_store.id, | 
					
						
							|  |  |  |         ), | 
					
						
							| 
									
										
										
										
											2024-06-08 01:27:27 -07:00
										 |  |  |     ) | 
					
						
							| 
									
										
										
										
											2024-06-10 19:51:51 -07:00
										 |  |  |     # Create a group chat manager to facilitate a turn-based conversation. | 
					
						
							| 
									
										
										
										
											2024-07-23 11:49:38 -07:00
										 |  |  |     await runtime.register( | 
					
						
							| 
									
										
										
										
											2024-06-18 14:53:18 -04:00
										 |  |  |         "GroupChatManager", | 
					
						
							|  |  |  |         lambda: GroupChatManager( | 
					
						
							|  |  |  |             description="A group chat manager.", | 
					
						
							|  |  |  |             memory=BufferedChatMemory(buffer_size=10), | 
					
						
							|  |  |  |             participants=[assistant, user], | 
					
						
							|  |  |  |         ), | 
					
						
							| 
									
										
										
										
											2024-06-08 01:27:27 -07:00
										 |  |  |     ) | 
					
						
							| 
									
										
										
										
											2024-06-10 19:51:51 -07:00
										 |  |  |     return user | 
					
						
							| 
									
										
										
										
											2024-06-08 01:27:27 -07:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | async def main() -> None: | 
					
						
							|  |  |  |     usage = """Chat with an AI assistant backed by OpenAI Assistant API.
 | 
					
						
							|  |  |  | You can upload files to the assistant using the command: | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | [upload code_interpreter | file_search filename] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | where 'code_interpreter' or 'file_search' is the purpose of the file and | 
					
						
							|  |  |  | 'filename' is the path to the file. For example: | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | [upload code_interpreter data.csv] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | This will upload data.csv to the assistant for use with the code interpreter tool. | 
					
						
							| 
									
										
										
										
											2024-06-10 19:51:51 -07:00
										 |  |  | 
 | 
					
						
							|  |  |  | Type "exit" to exit the chat. | 
					
						
							| 
									
										
										
										
											2024-06-08 01:27:27 -07:00
										 |  |  | """
 | 
					
						
							|  |  |  |     runtime = SingleThreadedAgentRuntime() | 
					
						
							| 
									
										
										
										
											2024-07-23 11:49:38 -07:00
										 |  |  |     user = await assistant_chat(runtime) | 
					
						
							| 
									
										
										
										
											2024-07-01 11:53:45 -04:00
										 |  |  |     _run_context = runtime.start() | 
					
						
							| 
									
										
										
										
											2024-06-08 01:27:27 -07:00
										 |  |  |     print(usage) | 
					
						
							| 
									
										
										
										
											2024-06-10 19:51:51 -07:00
										 |  |  |     # Request the user to start the conversation. | 
					
						
							| 
									
										
										
										
											2024-06-27 13:40:12 -04:00
										 |  |  |     await runtime.send_message(PublishNow(), user) | 
					
						
							| 
									
										
										
										
											2024-07-01 11:53:45 -04:00
										 |  |  | 
 | 
					
						
							|  |  |  |     # TODO: have a way to exit the loop. | 
					
						
							| 
									
										
										
										
											2024-06-08 01:27:27 -07:00
										 |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | if __name__ == "__main__": | 
					
						
							| 
									
										
										
										
											2024-06-10 19:51:51 -07:00
										 |  |  |     parser = argparse.ArgumentParser(description="Chat with an AI assistant.") | 
					
						
							|  |  |  |     parser.add_argument("--verbose", action="store_true", help="Enable verbose logging.") | 
					
						
							|  |  |  |     args = parser.parse_args() | 
					
						
							|  |  |  |     if args.verbose: | 
					
						
							|  |  |  |         logging.basicConfig(level=logging.WARNING) | 
					
						
							|  |  |  |         logging.getLogger("agnext").setLevel(logging.DEBUG) | 
					
						
							| 
									
										
										
										
											2024-06-17 17:54:27 -07:00
										 |  |  |         handler = logging.FileHandler("assistant.log") | 
					
						
							|  |  |  |         logging.getLogger("agnext").addHandler(handler) | 
					
						
							| 
									
										
										
										
											2024-06-08 01:27:27 -07:00
										 |  |  |     asyncio.run(main()) |