oai assistant and pattern fixes (#30)

This commit is contained in:
Eric Zhu 2024-05-28 15:49:30 -07:00 committed by GitHub
parent f4a5835772
commit ecbc3b7806
4 changed files with 103 additions and 111 deletions

View File

@ -1,6 +1,6 @@
import argparse
import asyncio
from typing import Any
import logging
import openai
from agnext.agent_components.model_client import OpenAI
@ -8,13 +8,18 @@ from agnext.application_components import (
SingleThreadedAgentRuntime,
)
from agnext.chat.agents.oai_assistant import OpenAIAssistantAgent
from agnext.chat.messages import ChatMessage
from agnext.chat.patterns.group_chat import GroupChat, Output
from agnext.chat.patterns.group_chat import GroupChat, GroupChatOutput
from agnext.chat.patterns.orchestrator import Orchestrator
from agnext.chat.types import TextMessage
from agnext.core._agent import Agent
from agnext.core.intervention import DefaultInterventionHandler, DropMessage
from typing_extensions import Any, override
logging.basicConfig(level=logging.WARNING)
logging.getLogger("agnext").setLevel(logging.DEBUG)
class ConcatOutput(Output):
class ConcatOutput(GroupChatOutput):
def __init__(self) -> None:
self._output = ""
@ -32,8 +37,26 @@ class ConcatOutput(Output):
self._output = ""
class LoggingHandler(DefaultInterventionHandler):
@override
async def on_send(self, message: Any, *, sender: Agent | None, recipient: Agent) -> Any | type[DropMessage]:
if sender is None:
print(f"Sending message to {recipient.name}: {message}")
else:
print(f"Sending message from {sender.name} to {recipient.name}: {message}")
return message
@override
async def on_response(self, message: Any, *, sender: Agent, recipient: Agent | None) -> Any | type[DropMessage]:
if recipient is None:
print(f"Received response from {sender.name}: {message}")
else:
print(f"Received response from {sender.name} to {recipient.name}: {message}")
return message
async def group_chat(message: str) -> None:
runtime = SingleThreadedAgentRuntime()
runtime = SingleThreadedAgentRuntime(before_send=LoggingHandler())
joe_oai_assistant = openai.beta.assistants.create(
model="gpt-3.5-turbo",
@ -67,16 +90,16 @@ async def group_chat(message: str) -> None:
chat = GroupChat("Host", "A round-robin chat room.", runtime, [joe, cathy], num_rounds=5, output=ConcatOutput())
response = runtime.send_message(ChatMessage(body=message, sender="host"), chat)
response = runtime.send_message(TextMessage(content=message, source="host"), chat)
while not response.done():
await runtime.process_next()
print((await response).body) # type: ignore
await response
async def orchestrator(message: str) -> None:
runtime = SingleThreadedAgentRuntime()
runtime = SingleThreadedAgentRuntime(before_send=LoggingHandler())
developer_oai_assistant = openai.beta.assistants.create(
model="gpt-3.5-turbo",
@ -117,17 +140,14 @@ async def orchestrator(message: str) -> None:
)
response = runtime.send_message(
ChatMessage(
body=message,
sender="customer",
),
TextMessage(content=message, source="customer"),
chat,
)
while not response.done():
await runtime.process_next()
print((await response).body) # type: ignore
print((await response).content) # type: ignore
if __name__ == "__main__":

View File

@ -1,3 +1,5 @@
from typing import Callable, Dict
import openai
from agnext.agent_components.type_routed_agent import TypeRoutedAgent, message_handler
@ -15,22 +17,18 @@ class OpenAIAssistantAgent(BaseChatAgent, TypeRoutedAgent):
client: openai.AsyncClient,
assistant_id: str,
thread_id: str,
tools: Dict[str, Callable[..., str]] | None = None,
) -> None:
super().__init__(name, description, runtime)
self._client = client
self._assistant_id = assistant_id
self._thread_id = thread_id
self._current_session_window_length = 0
# TODO: investigate why this is 1, as setting this to 0 causes the earliest message in the window to be ignored.
self._current_session_window_length = 1
self._tools = tools or {}
# TODO: use require_response
@message_handler(TextMessage)
async def on_chat_message_with_cancellation(
self, message: TextMessage, cancellation_token: CancellationToken
) -> None:
print("---------------")
print(f"{self.name} received message from {message.source}: {message.content}")
print("---------------")
async def on_text_message(self, message: TextMessage, cancellation_token: CancellationToken) -> None:
# Save the message to the thread.
_ = await self._client.beta.threads.messages.create(
thread_id=self._thread_id,
@ -43,7 +41,7 @@ class OpenAIAssistantAgent(BaseChatAgent, TypeRoutedAgent):
@message_handler(Reset)
async def on_reset(self, message: Reset, cancellation_token: CancellationToken) -> None:
# Reset the current session window.
self._current_session_window_length = 0
self._current_session_window_length = 1
@message_handler(RespondNow)
async def on_respond_now(self, message: RespondNow, cancellation_token: CancellationToken) -> TextMessage:
@ -61,6 +59,9 @@ class OpenAIAssistantAgent(BaseChatAgent, TypeRoutedAgent):
# TODO: handle other statuses.
raise ValueError(f"Run did not complete successfully: {run}")
# Increment the current session window length.
self._current_session_window_length += 1
# Get the last message from the run.
response = await self._client.beta.threads.messages.list(self._thread_id, run_id=run.id, order="desc", limit=1)
last_message_content = response.data[0].content

View File

@ -1,12 +1,12 @@
from typing import Any, List, Protocol, Sequence
from agnext.chat.types import Reset, RespondNow
from ...agent_components.type_routed_agent import TypeRoutedAgent, message_handler
from ...core import AgentRuntime, CancellationToken
from ..agents.base import BaseChatAgent
from ..types import Reset, RespondNow, TextMessage
class Output(Protocol):
class GroupChatOutput(Protocol):
def on_message_received(self, message: Any) -> None: ...
def get_output(self) -> Any: ...
@ -14,7 +14,7 @@ class Output(Protocol):
def reset(self) -> None: ...
class GroupChat(BaseChatAgent):
class GroupChat(BaseChatAgent, TypeRoutedAgent):
def __init__(
self,
name: str,
@ -22,42 +22,43 @@ class GroupChat(BaseChatAgent):
runtime: AgentRuntime,
agents: Sequence[BaseChatAgent],
num_rounds: int,
output: Output,
output: GroupChatOutput,
) -> None:
super().__init__(name, description, runtime)
self._agents = agents
self._num_rounds = num_rounds
self._history: List[Any] = []
self._output = output
super().__init__(name, description, runtime)
@property
def subscriptions(self) -> Sequence[type]:
agent_sublists = [agent.subscriptions for agent in self._agents]
return [Reset, RespondNow] + [item for sublist in agent_sublists for item in sublist]
async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any | None:
if isinstance(message, Reset):
# Reset the history.
self._history = []
# TODO: reset sub-agents?
@message_handler(Reset)
async def on_reset(self, message: Reset, cancellation_token: CancellationToken) -> None:
self._history.clear()
if isinstance(message, RespondNow):
# TODO reset...
return self._output.get_output()
@message_handler(RespondNow)
async def on_respond_now(self, message: RespondNow, cancellation_token: CancellationToken) -> Any:
return self._output.get_output()
@message_handler(TextMessage)
async def on_text_message(self, message: Any, cancellation_token: CancellationToken) -> Any:
# TODO: how should we handle the group chat receiving a message while in the middle of a conversation?
# Should this class disallow it?
self._history.append(message)
round = 0
prev_speaker = None
while round < self._num_rounds:
# TODO: add support for advanced speaker selection.
# Select speaker (round-robin for now).
speaker = self._agents[round % len(self._agents)]
# Send the last message to all agents.
for agent in [agent for agent in self._agents]:
# Send the last message to all agents except the previous speaker.
for agent in [agent for agent in self._agents if agent is not prev_speaker]:
# TODO gather and await
_ = await self._send_message(
self._history[-1],
@ -66,6 +67,7 @@ class GroupChat(BaseChatAgent):
)
# TODO handle if response is not None
# Request the speaker to speak.
response = await self._send_message(
RespondNow(),
speaker,
@ -73,12 +75,13 @@ class GroupChat(BaseChatAgent):
)
if response is not None:
# 4. Append the response to the history.
# Append the response to the history.
self._history.append(response)
self._output.on_message_received(response)
# 6. Increment the round.
# Increment the round.
round += 1
prev_speaker = speaker
output = self._output.get_output()
self._output.reset()

View File

@ -6,7 +6,7 @@ from ...agent_components.type_routed_agent import TypeRoutedAgent, message_handl
from ...agent_components.types import AssistantMessage, LLMMessage, UserMessage
from ...core import AgentRuntime, CancellationToken
from ..agents.base import BaseChatAgent
from ..messages import ChatMessage
from ..types import RespondNow, TextMessage
class Orchestrator(BaseChatAgent, TypeRoutedAgent):
@ -27,26 +27,19 @@ class Orchestrator(BaseChatAgent, TypeRoutedAgent):
self._max_turns = max_turns
self._max_stalled_turns_before_retry = max_stalled_turns_before_retry
self._max_retry_attempts_before_educated_guess = max_retry_attempts
self._history: List[ChatMessage] = []
self._history: List[TextMessage] = []
@message_handler(ChatMessage)
async def on_chat_message(
@message_handler(TextMessage)
async def on_text_message(
self,
message: ChatMessage,
message: TextMessage,
cancellation_token: CancellationToken,
) -> ChatMessage | None:
) -> TextMessage | None:
# A task is received.
task = message.body
if message.reset:
# Reset the history.
self._history = []
if message.save_message_only:
# TODO: what should we do with save_message_only messages for this pattern?
return ChatMessage(body="OK", sender=self.name)
task = message.content
# Prepare the task.
team, names, facts, plan = await self._prepare_task(task, message.sender)
team, names, facts, plan = await self._prepare_task(task, message.source)
# Main loop.
total_turns = 0
@ -74,11 +67,9 @@ Some additional points to consider:
# Send the task specs to the team and signal a reset.
for agent in self._agents:
self._send_message(
ChatMessage(
body=task_specs,
sender=self.name,
save_message_only=True,
reset=True,
TextMessage(
content=task_specs,
source=self.name,
),
agent,
)
@ -96,18 +87,13 @@ Some additional points to consider:
stalled_turns = 0
while total_turns < self._max_turns:
# Reflect on the task.
data = await self._reflect_on_task(task, team, names, ledger, message.sender)
data = await self._reflect_on_task(task, team, names, ledger, message.source)
# Check if the request is satisfied.
if data["is_request_satisfied"]["answer"]:
return ChatMessage(
body="The task has been successfully addressed.",
sender=self.name,
payload={
"ledgers": ledgers,
"status": "success",
"reason": data["is_request_satisfied"]["reason"],
},
return TextMessage(
content=f"The task has been successfully addressed. {data['is_request_satisfied']['reason']}",
source=self.name,
)
# Update stalled turns.
@ -121,7 +107,7 @@ Some additional points to consider:
# In a retry, we need to rewrite the facts and the plan.
# Rewrite the facts.
facts = await self._rewrite_facts(facts, ledger, message.sender)
facts = await self._rewrite_facts(facts, ledger, message.source)
# Increment the retry attempts.
retry_attempts += 1
@ -129,20 +115,15 @@ Some additional points to consider:
# Check if we should just guess.
if retry_attempts > self._max_retry_attempts_before_educated_guess:
# Make an educated guess.
educated_guess = await self._educated_guess(facts, ledger, message.sender)
educated_guess = await self._educated_guess(facts, ledger, message.source)
if educated_guess["has_educated_guesses"]["answer"]:
return ChatMessage(
body="The task is addressed with an educated guess.",
sender=self.name,
payload={
"ledgers": ledgers,
"status": "educated_guess",
"reason": educated_guess["has_educated_guesses"]["reason"],
},
return TextMessage(
content=f"The task is addressed with an educated guess. {educated_guess['has_educated_guesses']['reason']}",
source=self.name,
)
# Come up with a new plan.
plan = await self._rewrite_plan(team, ledger, message.sender)
plan = await self._rewrite_plan(team, ledger, message.source)
# Exit the inner loop.
break
@ -152,28 +133,21 @@ Some additional points to consider:
if subtask is None:
subtask = ""
# Update agents.
for agent in [agent for agent in self._agents]:
_ = await self._send_message(
TextMessage(content=subtask, source=self.name),
agent,
)
# Find the speaker.
try:
speaker = next(agent for agent in self._agents if agent.name == data["next_speaker"]["answer"])
except StopIteration as e:
raise ValueError(f"Invalid next speaker: {data['next_speaker']['answer']}") from e
# Update all other agents.
for agent in [agent for agent in self._agents if agent != speaker]:
_ = await self._send_message(
ChatMessage(
body=subtask,
sender=self.name,
save_message_only=True,
),
agent,
)
# Update the speaker and ask to speak.
speaker_response = await self._send_message(
ChatMessage(body=subtask, sender=self.name),
speaker,
)
# Ask speaker to speak.
speaker_response = await self._send_message(RespondNow(), speaker)
assert speaker_response is not None
@ -188,10 +162,9 @@ Some additional points to consider:
# Update all other agents with the speaker's response.
for agent in [agent for agent in self._agents if agent != speaker]:
_ = await self._send_message(
ChatMessage(
body=speaker_response.body,
sender=speaker_response.sender,
save_message_only=True,
TextMessage(
content=speaker_response.content,
source=speaker_response.source,
),
agent,
)
@ -199,22 +172,17 @@ Some additional points to consider:
# Update the ledger.
ledger.append(
UserMessage(
content=speaker_response.body,
source=speaker_response.sender,
content=speaker_response.content,
source=speaker_response.source,
)
)
# Increment the total turns.
total_turns += 1
return ChatMessage(
body="The task was not addressed",
sender=self.name,
payload={
"ledgers": ledgers,
"status": "failure",
"reason": "The maximum number of turns was reached.",
},
return TextMessage(
content="The task was not addressed. The maximum number of turns was reached.",
source=self.name,
)
async def _prepare_task(self, task: str, sender: str) -> Tuple[str, str, str, str]: