autogen/src/agnext/chat/agents/oai_assistant.py

109 lines
4.4 KiB
Python
Raw Normal View History

from typing import Any, Callable, List, Mapping
2024-05-28 15:49:30 -07:00
2024-05-23 08:23:24 -07:00
import openai
from openai import AsyncAssistantEventHandler
from openai.types.beta import AssistantResponseFormatParam
2024-05-23 08:23:24 -07:00
2024-06-09 12:11:36 -07:00
from ...components import TypeRoutedAgent, message_handler
from ...core import AgentRuntime, CancellationToken
from ..types import Reset, RespondNow, ResponseFormat, TextMessage
2024-05-23 08:23:24 -07:00
2024-06-09 12:11:36 -07:00
class OpenAIAssistantAgent(TypeRoutedAgent):
    """Agent that delegates to an OpenAI Assistant running on a single thread.

    Message handling:
      * ``TextMessage`` -- appended to the thread as a ``user`` message.
      * ``Reset`` -- every message currently in the thread is deleted.
      * ``RespondNow`` -- a run of the assistant is executed on the thread and
        the first text part of its latest message is returned.

    Args:
        name: Agent name, forwarded to the base class; also used as the
            ``source`` of the messages this agent returns.
        description: Agent description, forwarded to the base class.
        runtime: Agent runtime, forwarded to the base class.
        client: Async OpenAI client used for every API call.
        assistant_id: ID of the assistant that executes runs.
        thread_id: ID of the thread holding the conversation.
        assistant_event_handler_factory: Optional zero-argument callable that
            produces a fresh event handler for each run. When provided, runs
            are executed in streaming mode; otherwise the run is created and
            polled until it reaches a terminal state.
    """

    def __init__(
        self,
        name: str,
        description: str,
        runtime: AgentRuntime,
        client: openai.AsyncClient,
        assistant_id: str,
        thread_id: str,
        assistant_event_handler_factory: Callable[[], AsyncAssistantEventHandler] | None = None,
    ) -> None:
        super().__init__(name, description, runtime)
        self._client = client
        self._assistant_id = assistant_id
        self._thread_id = thread_id
        self._assistant_event_handler_factory = assistant_event_handler_factory

    @message_handler()
    async def on_text_message(self, message: TextMessage, cancellation_token: CancellationToken) -> None:
        """Append the incoming text to the thread as a ``user`` message.

        The original sender is recorded in the message metadata under the
        ``"sender"`` key. ``cancellation_token`` is accepted but not used.
        """
        await self._client.beta.threads.messages.create(
            thread_id=self._thread_id,
            content=message.content,
            role="user",
            metadata={"sender": message.source},
        )

    @message_handler()
    async def on_reset(self, message: Reset, cancellation_token: CancellationToken) -> None:
        """Delete every message currently stored in the thread.

        ``cancellation_token`` is accepted but not used.

        Raises:
            ValueError: If the API reports that a message was not deleted.
        """
        # Collect all IDs before deleting anything: the client paginates with
        # a cursor based on the last seen message, so deleting while listing
        # could invalidate that cursor. ``async for`` fetches further pages
        # on demand.
        message_ids: List[str] = [
            msg.id async for msg in self._client.beta.threads.messages.list(self._thread_id)
        ]

        for msg_id in message_ids:
            status = await self._client.beta.threads.messages.delete(
                message_id=msg_id, thread_id=self._thread_id
            )
            # Explicit check instead of ``assert`` so it survives ``python -O``.
            if not status.deleted:
                raise ValueError(f"Failed to delete message {msg_id} from thread {self._thread_id}")

    @message_handler()
    async def on_respond_now(self, message: RespondNow, cancellation_token: CancellationToken) -> TextMessage:
        """Run the assistant on the thread and return its latest text reply.

        ``cancellation_token`` is accepted but not used.

        Returns:
            A ``TextMessage`` whose content is the first text part of the most
            recent message produced by the run, with this agent as source.

        Raises:
            ValueError: If the run does not finish with status ``"completed"``,
                produces no message, or its latest message has no text part.
        """
        # Handle response format.
        if message.response_format == ResponseFormat.json_object:
            response_format = AssistantResponseFormatParam(type="json_object")
        else:
            response_format = AssistantResponseFormatParam(type="text")

        if self._assistant_event_handler_factory is not None:
            # Use event handler and streaming mode if available.
            async with self._client.beta.threads.runs.stream(
                thread_id=self._thread_id,
                assistant_id=self._assistant_id,
                event_handler=self._assistant_event_handler_factory(),
                response_format=response_format,
            ) as stream:
                run = await stream.get_final_run()
        else:
            # Use blocking mode. A plain ``runs.create`` returns immediately
            # with a non-terminal status, so poll until the run has finished.
            run = await self._client.beta.threads.runs.create_and_poll(
                thread_id=self._thread_id,
                assistant_id=self._assistant_id,
                response_format=response_format,
            )

        if run.status != "completed":
            # TODO: handle other statuses.
            raise ValueError(f"Run did not complete successfully: {run}")

        # Get the last message from the run.
        response = await self._client.beta.threads.messages.list(
            self._thread_id, run_id=run.id, order="desc", limit=1
        )
        if not response.data:
            raise ValueError(f"Run completed without producing a message: {run}")
        last_message_content = response.data[0].content

        # TODO: handle array of content.
        text_content = [content for content in last_message_content if content.type == "text"]
        if not text_content:
            raise ValueError(f"Expected text content in the last message: {last_message_content}")

        # TODO: handle multiple text content.
        return TextMessage(content=text_content[0].text.value, source=self.name)

    def save_state(self) -> Mapping[str, Any]:
        """Return the description, assistant ID and thread ID as a mapping."""
        return {
            "description": self.description,
            "assistant_id": self._assistant_id,
            "thread_id": self._thread_id,
        }

    def load_state(self, state: Mapping[str, Any]) -> None:
        """Restore the fields written by ``save_state`` from ``state``."""
        self._description = state["description"]
        self._assistant_id = state["assistant_id"]
        self._thread_id = state["thread_id"]