2024-05-23 08:23:24 -07:00
|
|
|
import openai
|
|
|
|
|
2024-05-24 17:25:17 -04:00
|
|
|
from agnext.agent_components.type_routed_agent import TypeRoutedAgent, message_handler
|
2024-05-23 16:00:05 -04:00
|
|
|
from agnext.chat.agents.base import BaseChatAgent
|
2024-05-24 17:25:17 -04:00
|
|
|
from agnext.chat.types import Reset, RespondNow, TextMessage
|
2024-05-27 17:10:56 -04:00
|
|
|
from agnext.core import AgentRuntime, CancellationToken
|
2024-05-23 16:00:05 -04:00
|
|
|
|
2024-05-23 08:23:24 -07:00
|
|
|
|
2024-05-24 17:25:17 -04:00
|
|
|
class OpenAIAssistantAgent(BaseChatAgent, TypeRoutedAgent):
    """Chat agent that delegates response generation to an OpenAI assistant.

    Every incoming ``TextMessage`` is appended to one OpenAI thread. A
    ``RespondNow`` message starts a run of the configured assistant on that
    thread and returns the text of the newest message belonging to that run.
    The agent counts the messages it has appended since the last ``Reset``
    and passes that count as the ``last_messages`` truncation window of
    each run.
    """

    def __init__(
        self,
        name: str,
        description: str,
        runtime: AgentRuntime,
        client: openai.AsyncClient,
        assistant_id: str,
        thread_id: str,
    ) -> None:
        """Store the OpenAI handles and start with an empty session window.

        Args:
            name: Forwarded to the base-class initializer.
            description: Forwarded to the base-class initializer.
            runtime: Forwarded to the base-class initializer.
            client: Async OpenAI client used for all thread, message and run
                calls.
            assistant_id: Identifier of the assistant that executes runs.
            thread_id: Identifier of the thread that messages are appended
                to and runs are executed on.
        """
        super().__init__(name, description, runtime)
        self._client = client
        self._assistant_id = assistant_id
        self._thread_id = thread_id
        # Number of messages this agent has appended to the thread since the
        # last Reset; used as the truncation window in on_respond_now.
        self._current_session_window_length = 0

    # TODO: use require_response
    @message_handler(TextMessage)
    async def on_chat_message_with_cancellation(
        self, message: TextMessage, cancellation_token: CancellationToken
    ) -> None:
        """Record an incoming text message on the OpenAI thread.

        The message is echoed to stdout, appended to the thread with role
        ``"user"`` and its source stored under the ``"sender"`` metadata
        key, and the session window counter is incremented.

        Args:
            message: The text message to record.
            cancellation_token: Not used by this handler.
        """
        print("---------------")
        print(f"{self.name} received message from {message.source}: {message.content}")
        print("---------------")

        # Save the message to the thread.
        _ = await self._client.beta.threads.messages.create(
            thread_id=self._thread_id,
            content=message.content,
            role="user",
            metadata={"sender": message.source},
        )
        # Only count the message once the API call above has succeeded.
        self._current_session_window_length += 1

    @message_handler(Reset)
    async def on_reset(self, message: Reset, cancellation_token: CancellationToken) -> None:
        """Start a new session window by zeroing the message counter.

        Only the local counter is changed; no API call is made here.

        Args:
            message: The reset request (its contents are not read).
            cancellation_token: Not used by this handler.
        """
        # Reset the current session window.
        self._current_session_window_length = 0

    @message_handler(RespondNow)
    async def on_respond_now(self, message: RespondNow, cancellation_token: CancellationToken) -> TextMessage:
        """Run the assistant on the thread and return its text reply.

        Args:
            message: The respond-now request (its contents are not read).
            cancellation_token: Not used by this handler.
                NOTE(review): the run is not cancelled through this token —
                confirm whether that is intended.

        Returns:
            A ``TextMessage`` whose content is the first text part of the
            newest message of the run and whose source is this agent's name.

        Raises:
            ValueError: If the run does not finish with status
                ``"completed"``, if no message is returned for the run, or
                if the newest message contains no text content.
        """
        # Create a run and wait until it finishes.
        # NOTE(review): if no message has been recorded since the last Reset,
        # this sends last_messages=0 — confirm the API accepts that value.
        run = await self._client.beta.threads.runs.create_and_poll(
            thread_id=self._thread_id,
            assistant_id=self._assistant_id,
            truncation_strategy={
                "type": "last_messages",
                "last_messages": self._current_session_window_length,
            },
        )

        if run.status != "completed":
            # TODO: handle other statuses.
            raise ValueError(f"Run did not complete successfully: {run}")

        # Get the last message from the run.
        response = await self._client.beta.threads.messages.list(self._thread_id, run_id=run.id, order="desc", limit=1)
        if not response.data:
            # Fail with a descriptive error instead of an IndexError on
            # response.data[0] when the listing comes back empty.
            raise ValueError(f"Run {run.id} completed but returned no messages")
        last_message_content = response.data[0].content

        # TODO: handle array of content.
        text_content = [content for content in last_message_content if content.type == "text"]
        if not text_content:
            raise ValueError(f"Expected text content in the last message: {last_message_content}")

        # TODO: handle multiple text content.
        return TextMessage(content=text_content[0].text.value, source=self.name)