import json
from typing import Any, Callable, Dict, List

import openai
from agnext.agent_components.type_routed_agent import TypeRoutedAgent, message_handler
from agnext.chat.agents.base import BaseChatAgent
from agnext.chat.types import Reset, RespondNow, TextMessage
from agnext.core import AgentRuntime, CancellationToken


class OpenAIAssistantAgent(BaseChatAgent, TypeRoutedAgent):
    """Chat agent backed by an OpenAI Assistant and one persistent thread.

    - ``TextMessage``: appended to the thread as a user message.
    - ``RespondNow``: starts a run of the assistant on the thread and returns
      the assistant's reply as a ``TextMessage``.
    - ``Reset``: only resets the session-window counter. Runs then use a
      ``last_messages`` truncation strategy sized by that counter. Earlier
      thread messages are not deleted.
    """

    def __init__(
        self,
        name: str,
        description: str,
        runtime: AgentRuntime,
        client: openai.AsyncClient,
        assistant_id: str,
        thread_id: str,
        tools: Dict[str, Callable[..., str]] | None = None,
    ) -> None:
        """Create the agent.

        Args:
            name: Agent name, forwarded to the base agent. It is also used as
                the ``source`` of replies.
            description: Agent description, forwarded to the base agent.
            runtime: Runtime the agent is registered with, forwarded to the
                base agent.
            client: Async OpenAI client used for all Assistants API calls.
            assistant_id: ID of the assistant that runs are created for.
            thread_id: ID of the thread that messages are stored in.
            tools: Optional mapping from function-tool name to a callable
                returning a string. These callables answer function calls
                the assistant makes during a run. Presumably the names must
                match the function tools configured on the assistant
                server-side; confirm against the assistant definition.
        """
        super().__init__(name, description, runtime)
        self._client = client
        self._assistant_id = assistant_id
        self._thread_id = thread_id
        # TODO: investigate why this is 1, as setting this to 0 causes the earliest message in the window to be ignored.
        self._current_session_window_length = 1
        self._tools = tools or {}

    @message_handler(TextMessage)
    async def on_text_message(self, message: TextMessage, cancellation_token: CancellationToken) -> None:
        """Append ``message`` to the thread as a user message and grow the session window by one."""
        # Save the message to the thread.
        _ = await self._client.beta.threads.messages.create(
            thread_id=self._thread_id,
            content=message.content,
            role="user",
            metadata={"sender": message.source},
        )
        self._current_session_window_length += 1

    @message_handler(Reset)
    async def on_reset(self, message: Reset, cancellation_token: CancellationToken) -> None:
        """Reset the session window so earlier thread messages are excluded from later runs."""
        # Reset the current session window.
        self._current_session_window_length = 1

    @message_handler(RespondNow)
    async def on_respond_now(self, message: RespondNow, cancellation_token: CancellationToken) -> TextMessage:
        """Run the assistant on the current session window and return its reply.

        Function tool calls requested by the assistant are answered with the
        callables registered in ``tools``.

        Returns:
            The first text part of the run's most recent message, with
            ``source`` set to this agent's name.

        Raises:
            ValueError: In any of these cases:
                - the run does not end in ``completed``;
                - the assistant requests an unknown or non-function tool;
                - the run produced no message;
                - the last message has no text content.
        """
        # Create a run and wait until it finishes or needs tool outputs.
        run = await self._client.beta.threads.runs.create_and_poll(
            thread_id=self._thread_id,
            assistant_id=self._assistant_id,
            truncation_strategy={
                "type": "last_messages",
                "last_messages": self._current_session_window_length,
            },
        )

        # Answer tool calls until the run leaves the `requires_action` state.
        # Without this, any run that calls a tool could never complete.
        while run.status == "requires_action" and run.required_action is not None:
            tool_calls = run.required_action.submit_tool_outputs.tool_calls
            run = await self._client.beta.threads.runs.submit_tool_outputs_and_poll(
                thread_id=self._thread_id,
                run_id=run.id,
                tool_outputs=self._execute_tool_calls(tool_calls),
            )

        if run.status != "completed":
            # TODO: handle other statuses.
            raise ValueError(f"Run did not complete successfully: {run}")

        # Count the assistant's reply as part of the session window.
        self._current_session_window_length += 1

        # Get the last message from the run.
        response = await self._client.beta.threads.messages.list(self._thread_id, run_id=run.id, order="desc", limit=1)
        if not response.data:
            raise ValueError(f"Run produced no messages: {run}")
        last_message_content = response.data[0].content
        # TODO: handle array of content.
        text_content = [content for content in last_message_content if content.type == "text"]
        if not text_content:
            raise ValueError(f"Expected text content in the last message: {last_message_content}")
        # TODO: handle multiple text content.
        return TextMessage(content=text_content[0].text.value, source=self.name)

    def _execute_tool_calls(self, tool_calls: List[Any]) -> List[Dict[str, str]]:
        """Invoke the registered callable for each requested function tool call.

        Args:
            tool_calls: Tool calls taken from
                ``run.required_action.submit_tool_outputs.tool_calls``.

        Returns:
            One ``{"tool_call_id": ..., "output": ...}`` entry per call, in the
            order received. This is the shape ``submit_tool_outputs`` expects.

        Raises:
            ValueError: If a call is not a function call, or names a tool
                that was not registered with this agent.
        """
        tool_outputs: List[Dict[str, str]] = []
        for tool_call in tool_calls:
            if tool_call.type != "function":
                raise ValueError(f"Unsupported tool call type: {tool_call.type}")
            tool_name = tool_call.function.name
            if tool_name not in self._tools:
                raise ValueError(f"Assistant requested unknown tool: {tool_name}")
            # Function arguments arrive as a JSON-encoded object string.
            arguments = json.loads(tool_call.function.arguments or "{}")
            output = self._tools[tool_name](**arguments)
            tool_outputs.append({"tool_call_id": tool_call.id, "output": output})
        return tool_outputs