From ecbc3b7806a931a1b5bb4e185ef50c41e3b9bc5e Mon Sep 17 00:00:00 2001 From: Eric Zhu Date: Tue, 28 May 2024 15:49:30 -0700 Subject: [PATCH] oai assistant and pattern fixes (#30) --- examples/patterns.py | 46 +++++++--- src/agnext/chat/agents/oai_assistant.py | 21 ++--- src/agnext/chat/patterns/group_chat.py | 39 ++++---- src/agnext/chat/patterns/orchestrator.py | 108 ++++++++--------------- 4 files changed, 103 insertions(+), 111 deletions(-) diff --git a/examples/patterns.py b/examples/patterns.py index f12adca79..bfcf0cae5 100644 --- a/examples/patterns.py +++ b/examples/patterns.py @@ -1,6 +1,6 @@ import argparse import asyncio -from typing import Any +import logging import openai from agnext.agent_components.model_client import OpenAI @@ -8,13 +8,18 @@ from agnext.application_components import ( SingleThreadedAgentRuntime, ) from agnext.chat.agents.oai_assistant import OpenAIAssistantAgent -from agnext.chat.messages import ChatMessage -from agnext.chat.patterns.group_chat import GroupChat, Output +from agnext.chat.patterns.group_chat import GroupChat, GroupChatOutput from agnext.chat.patterns.orchestrator import Orchestrator from agnext.chat.types import TextMessage +from agnext.core._agent import Agent +from agnext.core.intervention import DefaultInterventionHandler, DropMessage +from typing_extensions import Any, override + +logging.basicConfig(level=logging.WARNING) +logging.getLogger("agnext").setLevel(logging.DEBUG) -class ConcatOutput(Output): +class ConcatOutput(GroupChatOutput): def __init__(self) -> None: self._output = "" @@ -32,8 +37,26 @@ class ConcatOutput(Output): self._output = "" +class LoggingHandler(DefaultInterventionHandler): + @override + async def on_send(self, message: Any, *, sender: Agent | None, recipient: Agent) -> Any | type[DropMessage]: + if sender is None: + print(f"Sending message to {recipient.name}: {message}") + else: + print(f"Sending message from {sender.name} to {recipient.name}: {message}") + return message + + @override 
+ async def on_response(self, message: Any, *, sender: Agent, recipient: Agent | None) -> Any | type[DropMessage]: + if recipient is None: + print(f"Received response from {sender.name}: {message}") + else: + print(f"Received response from {sender.name} to {recipient.name}: {message}") + return message + + async def group_chat(message: str) -> None: - runtime = SingleThreadedAgentRuntime() + runtime = SingleThreadedAgentRuntime(before_send=LoggingHandler()) joe_oai_assistant = openai.beta.assistants.create( model="gpt-3.5-turbo", @@ -67,16 +90,16 @@ async def group_chat(message: str) -> None: chat = GroupChat("Host", "A round-robin chat room.", runtime, [joe, cathy], num_rounds=5, output=ConcatOutput()) - response = runtime.send_message(ChatMessage(body=message, sender="host"), chat) + response = runtime.send_message(TextMessage(content=message, source="host"), chat) while not response.done(): await runtime.process_next() - print((await response).body) # type: ignore + await response async def orchestrator(message: str) -> None: - runtime = SingleThreadedAgentRuntime() + runtime = SingleThreadedAgentRuntime(before_send=LoggingHandler()) developer_oai_assistant = openai.beta.assistants.create( model="gpt-3.5-turbo", @@ -117,17 +140,14 @@ async def orchestrator(message: str) -> None: ) response = runtime.send_message( - ChatMessage( - body=message, - sender="customer", - ), + TextMessage(content=message, source="customer"), chat, ) while not response.done(): await runtime.process_next() - print((await response).body) # type: ignore + print((await response).content) # type: ignore if __name__ == "__main__": diff --git a/src/agnext/chat/agents/oai_assistant.py b/src/agnext/chat/agents/oai_assistant.py index 7108f454a..8ece9ac84 100644 --- a/src/agnext/chat/agents/oai_assistant.py +++ b/src/agnext/chat/agents/oai_assistant.py @@ -1,3 +1,5 @@ +from typing import Callable, Dict + import openai from agnext.agent_components.type_routed_agent import TypeRoutedAgent, 
message_handler @@ -15,22 +17,18 @@ class OpenAIAssistantAgent(BaseChatAgent, TypeRoutedAgent): client: openai.AsyncClient, assistant_id: str, thread_id: str, + tools: Dict[str, Callable[..., str]] | None = None, ) -> None: super().__init__(name, description, runtime) self._client = client self._assistant_id = assistant_id self._thread_id = thread_id - self._current_session_window_length = 0 + # TODO: investigate why this is 1, as setting this to 0 causes the earliest message in the window to be ignored. + self._current_session_window_length = 1 + self._tools = tools or {} - # TODO: use require_response @message_handler(TextMessage) - async def on_chat_message_with_cancellation( - self, message: TextMessage, cancellation_token: CancellationToken - ) -> None: - print("---------------") - print(f"{self.name} received message from {message.source}: {message.content}") - print("---------------") - + async def on_text_message(self, message: TextMessage, cancellation_token: CancellationToken) -> None: # Save the message to the thread. _ = await self._client.beta.threads.messages.create( thread_id=self._thread_id, @@ -43,7 +41,7 @@ class OpenAIAssistantAgent(BaseChatAgent, TypeRoutedAgent): @message_handler(Reset) async def on_reset(self, message: Reset, cancellation_token: CancellationToken) -> None: # Reset the current session window. - self._current_session_window_length = 0 + self._current_session_window_length = 1 @message_handler(RespondNow) async def on_respond_now(self, message: RespondNow, cancellation_token: CancellationToken) -> TextMessage: @@ -61,6 +59,9 @@ class OpenAIAssistantAgent(BaseChatAgent, TypeRoutedAgent): # TODO: handle other statuses. raise ValueError(f"Run did not complete successfully: {run}") + # Increment the current session window length. + self._current_session_window_length += 1 + # Get the last message from the run. 
response = await self._client.beta.threads.messages.list(self._thread_id, run_id=run.id, order="desc", limit=1) last_message_content = response.data[0].content diff --git a/src/agnext/chat/patterns/group_chat.py b/src/agnext/chat/patterns/group_chat.py index 83d9f1ec8..098d6934b 100644 --- a/src/agnext/chat/patterns/group_chat.py +++ b/src/agnext/chat/patterns/group_chat.py @@ -1,12 +1,12 @@ from typing import Any, List, Protocol, Sequence -from agnext.chat.types import Reset, RespondNow - +from ...agent_components.type_routed_agent import TypeRoutedAgent, message_handler from ...core import AgentRuntime, CancellationToken from ..agents.base import BaseChatAgent +from ..types import Reset, RespondNow, TextMessage -class Output(Protocol): +class GroupChatOutput(Protocol): def on_message_received(self, message: Any) -> None: ... def get_output(self) -> Any: ... @@ -14,7 +14,7 @@ class Output(Protocol): def reset(self) -> None: ... -class GroupChat(BaseChatAgent): +class GroupChat(BaseChatAgent, TypeRoutedAgent): def __init__( self, name: str, @@ -22,42 +22,43 @@ class GroupChat(BaseChatAgent): runtime: AgentRuntime, agents: Sequence[BaseChatAgent], num_rounds: int, - output: Output, + output: GroupChatOutput, ) -> None: - super().__init__(name, description, runtime) self._agents = agents self._num_rounds = num_rounds self._history: List[Any] = [] self._output = output + super().__init__(name, description, runtime) @property def subscriptions(self) -> Sequence[type]: agent_sublists = [agent.subscriptions for agent in self._agents] return [Reset, RespondNow] + [item for sublist in agent_sublists for item in sublist] - async def on_message(self, message: Any, cancellation_token: CancellationToken) -> Any | None: - if isinstance(message, Reset): - # Reset the history. - self._history = [] - # TODO: reset sub-agents? 
+ @message_handler(Reset) + async def on_reset(self, message: Reset, cancellation_token: CancellationToken) -> None: + self._history.clear() - if isinstance(message, RespondNow): - # TODO reset... - return self._output.get_output() + @message_handler(RespondNow) + async def on_respond_now(self, message: RespondNow, cancellation_token: CancellationToken) -> Any: + return self._output.get_output() + @message_handler(TextMessage) + async def on_text_message(self, message: Any, cancellation_token: CancellationToken) -> Any: # TODO: how should we handle the group chat receiving a message while in the middle of a conversation? # Should this class disallow it? self._history.append(message) round = 0 + prev_speaker = None while round < self._num_rounds: # TODO: add support for advanced speaker selection. # Select speaker (round-robin for now). speaker = self._agents[round % len(self._agents)] - # Send the last message to all agents. - for agent in [agent for agent in self._agents]: + # Send the last message to all agents except the previous speaker. + for agent in [agent for agent in self._agents if agent is not prev_speaker]: # TODO gather and await _ = await self._send_message( self._history[-1], @@ -66,6 +67,7 @@ class GroupChat(BaseChatAgent): ) # TODO handle if response is not None + # Request the speaker to speak. response = await self._send_message( RespondNow(), speaker, @@ -73,12 +75,13 @@ class GroupChat(BaseChatAgent): ) if response is not None: - # 4. Append the response to the history. + # Append the response to the history. self._history.append(response) self._output.on_message_received(response) - # 6. Increment the round. + # Increment the round. 
round += 1 + prev_speaker = speaker output = self._output.get_output() self._output.reset() diff --git a/src/agnext/chat/patterns/orchestrator.py b/src/agnext/chat/patterns/orchestrator.py index 21eb03f49..f3dc3b6d3 100644 --- a/src/agnext/chat/patterns/orchestrator.py +++ b/src/agnext/chat/patterns/orchestrator.py @@ -6,7 +6,7 @@ from ...agent_components.type_routed_agent import TypeRoutedAgent, message_handl from ...agent_components.types import AssistantMessage, LLMMessage, UserMessage from ...core import AgentRuntime, CancellationToken from ..agents.base import BaseChatAgent -from ..messages import ChatMessage +from ..types import RespondNow, TextMessage class Orchestrator(BaseChatAgent, TypeRoutedAgent): @@ -27,26 +27,19 @@ class Orchestrator(BaseChatAgent, TypeRoutedAgent): self._max_turns = max_turns self._max_stalled_turns_before_retry = max_stalled_turns_before_retry self._max_retry_attempts_before_educated_guess = max_retry_attempts - self._history: List[ChatMessage] = [] + self._history: List[TextMessage] = [] - @message_handler(ChatMessage) - async def on_chat_message( + @message_handler(TextMessage) + async def on_text_message( self, - message: ChatMessage, + message: TextMessage, cancellation_token: CancellationToken, - ) -> ChatMessage | None: + ) -> TextMessage | None: # A task is received. - task = message.body - - if message.reset: - # Reset the history. - self._history = [] - if message.save_message_only: - # TODO: what should we do with save_message_only messages for this pattern? - return ChatMessage(body="OK", sender=self.name) + task = message.content # Prepare the task. - team, names, facts, plan = await self._prepare_task(task, message.sender) + team, names, facts, plan = await self._prepare_task(task, message.source) # Main loop. total_turns = 0 @@ -74,11 +67,9 @@ Some additional points to consider: # Send the task specs to the team and signal a reset. 
for agent in self._agents: self._send_message( - ChatMessage( - body=task_specs, - sender=self.name, - save_message_only=True, - reset=True, + TextMessage( + content=task_specs, + source=self.name, ), agent, ) @@ -96,18 +87,13 @@ Some additional points to consider: stalled_turns = 0 while total_turns < self._max_turns: # Reflect on the task. - data = await self._reflect_on_task(task, team, names, ledger, message.sender) + data = await self._reflect_on_task(task, team, names, ledger, message.source) # Check if the request is satisfied. if data["is_request_satisfied"]["answer"]: - return ChatMessage( - body="The task has been successfully addressed.", - sender=self.name, - payload={ - "ledgers": ledgers, - "status": "success", - "reason": data["is_request_satisfied"]["reason"], - }, + return TextMessage( + content=f"The task has been successfully addressed. {data['is_request_satisfied']['reason']}", + source=self.name, ) # Update stalled turns. @@ -121,7 +107,7 @@ Some additional points to consider: # In a retry, we need to rewrite the facts and the plan. # Rewrite the facts. - facts = await self._rewrite_facts(facts, ledger, message.sender) + facts = await self._rewrite_facts(facts, ledger, message.source) # Increment the retry attempts. retry_attempts += 1 @@ -129,20 +115,15 @@ Some additional points to consider: # Check if we should just guess. if retry_attempts > self._max_retry_attempts_before_educated_guess: # Make an educated guess. - educated_guess = await self._educated_guess(facts, ledger, message.sender) + educated_guess = await self._educated_guess(facts, ledger, message.source) if educated_guess["has_educated_guesses"]["answer"]: - return ChatMessage( - body="The task is addressed with an educated guess.", - sender=self.name, - payload={ - "ledgers": ledgers, - "status": "educated_guess", - "reason": educated_guess["has_educated_guesses"]["reason"], - }, + return TextMessage( + content=f"The task is addressed with an educated guess. 
{educated_guess['has_educated_guesses']['reason']}", + source=self.name, ) # Come up with a new plan. - plan = await self._rewrite_plan(team, ledger, message.sender) + plan = await self._rewrite_plan(team, ledger, message.source) # Exit the inner loop. break @@ -152,28 +133,21 @@ Some additional points to consider: if subtask is None: subtask = "" + # Update agents. + for agent in [agent for agent in self._agents]: + _ = await self._send_message( + TextMessage(content=subtask, source=self.name), + agent, + ) + # Find the speaker. try: speaker = next(agent for agent in self._agents if agent.name == data["next_speaker"]["answer"]) except StopIteration as e: raise ValueError(f"Invalid next speaker: {data['next_speaker']['answer']}") from e - # Update all other agents. - for agent in [agent for agent in self._agents if agent != speaker]: - _ = await self._send_message( - ChatMessage( - body=subtask, - sender=self.name, - save_message_only=True, - ), - agent, - ) - - # Update the speaker and ask to speak. - speaker_response = await self._send_message( - ChatMessage(body=subtask, sender=self.name), - speaker, - ) + # Ask speaker to speak. + speaker_response = await self._send_message(RespondNow(), speaker) assert speaker_response is not None @@ -188,10 +162,9 @@ Some additional points to consider: # Update all other agents with the speaker's response. for agent in [agent for agent in self._agents if agent != speaker]: _ = await self._send_message( - ChatMessage( - body=speaker_response.body, - sender=speaker_response.sender, - save_message_only=True, + TextMessage( + content=speaker_response.content, + source=speaker_response.source, ), agent, ) @@ -199,22 +172,17 @@ Some additional points to consider: # Update the ledger. ledger.append( UserMessage( - content=speaker_response.body, - source=speaker_response.sender, + content=speaker_response.content, + source=speaker_response.source, ) ) # Increment the total turns. 
total_turns += 1 - return ChatMessage( - body="The task was not addressed", - sender=self.name, - payload={ - "ledgers": ledgers, - "status": "failure", - "reason": "The maximum number of turns was reached.", - }, + return TextMessage( + content="The task was not addressed. The maximum number of turns was reached.", + source=self.name, ) async def _prepare_task(self, task: str, sender: str) -> Tuple[str, str, str, str]: