autogen/src/agnext/chat/agents/oai_assistant.py

76 lines
3.0 KiB
Python
Raw Normal View History

2024-05-28 15:49:30 -07:00
from typing import Callable, Dict
2024-05-23 08:23:24 -07:00
import openai
from agnext.agent_components.type_routed_agent import TypeRoutedAgent, message_handler
from agnext.chat.agents.base import BaseChatAgent
from agnext.chat.types import Reset, RespondNow, TextMessage
2024-05-27 17:10:56 -04:00
from agnext.core import AgentRuntime, CancellationToken
2024-05-23 08:23:24 -07:00
class OpenAIAssistantAgent(BaseChatAgent, TypeRoutedAgent):
    """Chat agent that delegates to an OpenAI assistant running on a thread.

    ``TextMessage`` messages are appended to the configured thread,
    ``Reset`` restarts the session window, and ``RespondNow`` runs the
    assistant on the thread and returns the text of the newest message
    produced by that run.
    """

    def __init__(
        self,
        name: str,
        description: str,
        runtime: AgentRuntime,
        client: openai.AsyncClient,
        assistant_id: str,
        thread_id: str,
        tools: Dict[str, Callable[..., str]] | None = None,
    ) -> None:
        """Initialize the agent.

        Args:
            name: Agent name, forwarded to the base class.
            description: Agent description, forwarded to the base class.
            runtime: Agent runtime, forwarded to the base class.
            client: Async OpenAI client used for every thread and run call.
            assistant_id: ID of the assistant that is run on the thread.
            thread_id: ID of the thread that messages are written to and
                read from.
            tools: Optional mapping of tool name to callable. It is stored on
                the instance; none of the handlers in this class read it.
        """
        super().__init__(name, description, runtime)
        self._client = client
        self._assistant_id = assistant_id
        self._thread_id = thread_id
        # TODO: investigate why this is 1, as setting this to 0 causes the earliest message in the window to be ignored.
        self._current_session_window_length = 1
        self._tools = tools or {}

    @message_handler(TextMessage)
    async def on_text_message(self, message: TextMessage, cancellation_token: CancellationToken) -> None:
        """Append an incoming text message to the thread and grow the window.

        Args:
            message: The message to store; its ``source`` is recorded as the
                ``sender`` metadata entry.
            cancellation_token: Not used by this handler.
        """
        # Save the message to the thread.
        _ = await self._client.beta.threads.messages.create(
            thread_id=self._thread_id,
            content=message.content,
            role="user",
            metadata={"sender": message.source},
        )
        # One more message belongs to the current session window.
        self._current_session_window_length += 1

    @message_handler(Reset)
    async def on_reset(self, message: Reset, cancellation_token: CancellationToken) -> None:
        """Reset the session window to its initial length.

        Only the local counter is reset; no call is made to the client here.
        """
        self._current_session_window_length = 1

    @message_handler(RespondNow)
    async def on_respond_now(self, message: RespondNow, cancellation_token: CancellationToken) -> TextMessage:
        """Run the assistant on the thread and return its latest text reply.

        Args:
            message: The trigger message; its contents are not read.
            cancellation_token: Not used by this handler.

        Returns:
            A ``TextMessage`` whose content is the first text part of the
            newest message produced by the run and whose source is this
            agent's name.

        Raises:
            ValueError: If the run does not finish with status
                ``"completed"``, if the run produced no message, or if the
                newest message has no text content.
        """
        # Create a run and wait until it finishes. Only the messages in the
        # current session window are passed via the truncation strategy.
        run = await self._client.beta.threads.runs.create_and_poll(
            thread_id=self._thread_id,
            assistant_id=self._assistant_id,
            truncation_strategy={
                "type": "last_messages",
                "last_messages": self._current_session_window_length,
            },
        )
        if run.status != "completed":
            # TODO: handle other statuses.
            raise ValueError(f"Run did not complete successfully: {run}")

        # Increment the current session window length (presumably to account
        # for the reply the run just added to the thread).
        self._current_session_window_length += 1

        # Get the last message from the run.
        response = await self._client.beta.threads.messages.list(self._thread_id, run_id=run.id, order="desc", limit=1)
        if not response.data:
            # Fail with a descriptive error instead of an IndexError below.
            raise ValueError(f"Run completed without producing a message: {run}")
        last_message_content = response.data[0].content

        # TODO: handle array of content.
        text_content = [content for content in last_message_content if content.type == "text"]
        if not text_content:
            raise ValueError(f"Expected text content in the last message: {last_message_content}")

        # TODO: handle multiple text content.
        return TextMessage(content=text_content[0].text.value, source=self.name)