From f31ff663685a37f7960c4911b1837d36f1f32a13 Mon Sep 17 00:00:00 2001 From: Eric Zhu Date: Fri, 25 Oct 2024 10:57:04 -0700 Subject: [PATCH] Refactor agent chat to prepare for handoff/swarm (#3949) Add handoff message type to chat message types Add Swarm group chat that uses handoff message to select next speaker Remove tool call and tool call result message types from chat message types Remove BaseToolUseChatAgent, move tool call handling from group chat's chat agent container upward to the ToolUseAssistantAgent implementation, which subclasses BaseChatAgent directly. Renaming for better clarity --------- Co-authored-by: Victor Dibia --- .../src/autogen_agentchat/agents/__init__.py | 3 +- .../agents/_base_chat_agent.py | 23 +---- .../agents/_tool_use_assistant_agent.py | 92 ++++++++++++++----- .../src/autogen_agentchat/base/__init__.py | 3 +- .../src/autogen_agentchat/base/_chat_agent.py | 13 +-- .../logging/_console_log_handler.py | 19 ++-- .../logging/_file_log_handler.py | 29 ++++-- .../src/autogen_agentchat/messages.py | 29 ++---- .../src/autogen_agentchat/teams/__init__.py | 2 + .../src/autogen_agentchat/teams/_events.py | 40 ++------ .../_group_chat/_base_chat_agent_container.py | 92 ------------------- .../teams/_group_chat/_base_group_chat.py | 59 ++++-------- .../_group_chat/_base_group_chat_manager.py | 53 +++++++---- .../_group_chat/_chat_agent_container.py | 48 ++++++++++ .../_group_chat/_round_robin_group_chat.py | 15 ++- .../teams/_group_chat/_selector_group_chat.py | 11 ++- .../teams/_group_chat/_swarm_group_chat.py | 72 +++++++++++++++ .../tests/test_group_chat.py | 47 +++++++++- .../tests/test_tool_use_assistant_agent.py | 6 +- 19 files changed, 363 insertions(+), 293 deletions(-) delete mode 100644 python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_chat_agent_container.py create mode 100644 python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_chat_agent_container.py create mode 100644 python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_swarm_group_chat.py diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/__init__.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/__init__.py index 1c3078d01..19cdec548 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/__init__.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/__init__.py @@ -1,11 +1,10 @@ -from ._base_chat_agent import BaseChatAgent, BaseToolUseChatAgent +from ._base_chat_agent import BaseChatAgent from ._code_executor_agent import CodeExecutorAgent from ._coding_assistant_agent import CodingAssistantAgent from ._tool_use_assistant_agent import ToolUseAssistantAgent __all__ = [ "BaseChatAgent", - "BaseToolUseChatAgent", "CodeExecutorAgent", "CodingAssistantAgent", "ToolUseAssistantAgent", diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py index 62bac59d5..77bf4c02c 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py @@ -1,10 +1,9 @@ from abc import ABC, abstractmethod -from typing import List, Sequence +from typing import Sequence from autogen_core.base import CancellationToken -from autogen_core.components.tools import Tool -from ..base import ChatAgent, TaskResult, TerminationCondition, ToolUseChatAgent +from ..base import ChatAgent, TaskResult, TerminationCondition from ..messages import ChatMessage from ..teams import RoundRobinGroupChat @@ -51,21 +50,3 @@ class BaseChatAgent(ChatAgent, ABC): termination_condition=termination_condition, ) return result - - -class BaseToolUseChatAgent(BaseChatAgent, ToolUseChatAgent): - """Base class for a chat agent that can use tools. - - Subclass this base class to create an agent class that uses tools by returning - ToolCallMessage message from the :meth:`on_messages` method and receiving - ToolCallResultMessage message from the input to the :meth:`on_messages` method. - """ - - def __init__(self, name: str, description: str, registered_tools: List[Tool]) -> None: - super().__init__(name, description) - self._registered_tools = registered_tools - - @property - def registered_tools(self) -> List[Tool]: - """The list of tools that the agent can use.""" - return self._registered_tools diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_tool_use_assistant_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_tool_use_assistant_agent.py index 37022acd0..fde1a3e49 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_tool_use_assistant_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_tool_use_assistant_agent.py @@ -1,3 +1,6 @@ +import asyncio +import json +import logging from typing import Any, Awaitable, Callable, List, Sequence from autogen_core.base import CancellationToken @@ -5,25 +8,45 @@ from autogen_core.components import FunctionCall from autogen_core.components.models import ( AssistantMessage, ChatCompletionClient, + FunctionExecutionResult, FunctionExecutionResultMessage, LLMMessage, SystemMessage, UserMessage, ) from autogen_core.components.tools import FunctionTool, Tool +from pydantic import BaseModel, ConfigDict +from .. import EVENT_LOGGER_NAME from ..messages import ( ChatMessage, - MultiModalMessage, StopMessage, TextMessage, - ToolCallMessage, - ToolCallResultMessage, ) -from ._base_chat_agent import BaseToolUseChatAgent +from ._base_chat_agent import BaseChatAgent + +event_logger = logging.getLogger(EVENT_LOGGER_NAME) -class ToolUseAssistantAgent(BaseToolUseChatAgent): +class ToolCallEvent(BaseModel): + """A tool call event.""" + + tool_calls: List[FunctionCall] + """The tool call message.""" + + model_config = ConfigDict(arbitrary_types_allowed=True) + + +class ToolCallResultEvent(BaseModel): + """A tool call result event.""" + + tool_call_results: List[FunctionExecutionResult] + """The tool call result message.""" + + model_config = ConfigDict(arbitrary_types_allowed=True) + + +class ToolUseAssistantAgent(BaseChatAgent): """An agent that provides assistance with tool use. It responds with a StopMessage when 'terminate' is detected in the response. @@ -45,46 +68,50 @@ class ToolUseAssistantAgent(BaseToolUseChatAgent): description: str = "An agent that provides assistance with ability to use tools.", system_message: str = "You are a helpful AI assistant. Solve tasks using your tools. Reply with 'TERMINATE' when the task has been completed.", ): - tools: List[Tool] = [] + super().__init__(name=name, description=description) + self._model_client = model_client + self._system_messages = [SystemMessage(content=system_message)] + self._tools: List[Tool] = [] for tool in registered_tools: if isinstance(tool, Tool): - tools.append(tool) + self._tools.append(tool) elif callable(tool): if hasattr(tool, "__doc__") and tool.__doc__ is not None: description = tool.__doc__ else: description = "" - tools.append(FunctionTool(tool, description=description)) + self._tools.append(FunctionTool(tool, description=description)) else: raise ValueError(f"Unsupported tool type: {type(tool)}") - super().__init__(name=name, description=description, registered_tools=tools) - self._model_client = model_client - self._system_messages = [SystemMessage(content=system_message)] - self._tool_schema = [tool.schema for tool in tools] self._model_context: List[LLMMessage] = [] async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage: # Add messages to the model context. for msg in messages: - if isinstance(msg, ToolCallResultMessage): - self._model_context.append(FunctionExecutionResultMessage(content=msg.content)) - elif not isinstance(msg, TextMessage | MultiModalMessage | StopMessage): - raise ValueError(f"Unsupported message type: {type(msg)}") - else: - self._model_context.append(UserMessage(content=msg.content, source=msg.source)) + # TODO: add special handling for handoff messages + self._model_context.append(UserMessage(content=msg.content, source=msg.source)) # Generate an inference result based on the current model context. llm_messages = self._system_messages + self._model_context - result = await self._model_client.create( - llm_messages, tools=self._tool_schema, cancellation_token=cancellation_token - ) + result = await self._model_client.create(llm_messages, tools=self._tools, cancellation_token=cancellation_token) # Add the response to the model context. self._model_context.append(AssistantMessage(content=result.content, source=self.name)) - # Detect tool calls. - if isinstance(result.content, list) and all(isinstance(item, FunctionCall) for item in result.content): - return ToolCallMessage(content=result.content, source=self.name) + # Run tool calls until the model produces a string response. + while isinstance(result.content, list) and all(isinstance(item, FunctionCall) for item in result.content): + event_logger.debug(ToolCallEvent(tool_calls=result.content)) + # Execute the tool calls. + results = await asyncio.gather( + *[self._execute_tool_call(call, cancellation_token) for call in result.content] + ) + event_logger.debug(ToolCallResultEvent(tool_call_results=results)) + self._model_context.append(FunctionExecutionResultMessage(content=results)) + # Generate an inference result based on the current model context. + result = await self._model_client.create( + self._model_context, tools=self._tools, cancellation_token=cancellation_token + ) + self._model_context.append(AssistantMessage(content=result.content, source=self.name)) assert isinstance(result.content, str) # Detect stop request. @@ -93,3 +120,20 @@ class ToolUseAssistantAgent(BaseToolUseChatAgent): return StopMessage(content=result.content, source=self.name) return TextMessage(content=result.content, source=self.name) + + async def _execute_tool_call( + self, tool_call: FunctionCall, cancellation_token: CancellationToken + ) -> FunctionExecutionResult: + """Execute a tool call and return the result.""" + try: + if not self._tools: + raise ValueError("No tools are available.") + tool = next((t for t in self._tools if t.name == tool_call.name), None) + if tool is None: + raise ValueError(f"The tool '{tool_call.name}' is not available.") + arguments = json.loads(tool_call.arguments) + result = await tool.run_json(arguments, cancellation_token) + result_as_str = tool.return_value_as_string(result) + return FunctionExecutionResult(content=result_as_str, call_id=tool_call.id) + except Exception as e: + return FunctionExecutionResult(content=f"Error: {e}", call_id=tool_call.id) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/base/__init__.py b/python/packages/autogen-agentchat/src/autogen_agentchat/base/__init__.py index 36845eb82..436d69fb0 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/base/__init__.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/base/__init__.py @@ -1,11 +1,10 @@ -from ._chat_agent import ChatAgent, ToolUseChatAgent +from ._chat_agent import ChatAgent from ._task import TaskResult, TaskRunner from ._team import Team from ._termination import TerminatedException, TerminationCondition __all__ = [ "ChatAgent", - "ToolUseChatAgent", "Team", "TerminatedException", "TerminationCondition", diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/base/_chat_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/base/_chat_agent.py index 6200050d4..d82539540 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/base/_chat_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/base/_chat_agent.py @@ -1,7 +1,6 @@ -from typing import List, Protocol, Sequence, runtime_checkable +from typing import Protocol, Sequence, runtime_checkable from autogen_core.base import CancellationToken -from autogen_core.components.tools import Tool from ..messages import ChatMessage from ._task import TaskResult, TaskRunner @@ -38,13 +37,3 @@ class ChatAgent(TaskRunner, Protocol): ) -> TaskResult: """Run the agent with the given task and return the result.""" ... - - -@runtime_checkable -class ToolUseChatAgent(ChatAgent, Protocol): - """Protocol for a chat agent that can use tools.""" - - @property - def registered_tools(self) -> List[Tool]: - """The list of tools that the agent can use.""" - ... diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/logging/_console_log_handler.py b/python/packages/autogen-agentchat/src/autogen_agentchat/logging/_console_log_handler.py index d0fb4ab08..95200e284 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/logging/_console_log_handler.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/logging/_console_log_handler.py @@ -3,13 +3,12 @@ import logging import sys from datetime import datetime +from ..agents._tool_use_assistant_agent import ToolCallEvent, ToolCallResultEvent from ..messages import ChatMessage, StopMessage, TextMessage from ..teams._events import ( - ContentPublishEvent, - SelectSpeakerEvent, + GroupChatPublishEvent, + GroupChatSelectSpeakerEvent, TerminationEvent, - ToolCallEvent, - ToolCallResultEvent, ) @@ -25,7 +24,7 @@ class ConsoleLogHandler(logging.Handler): def emit(self, record: logging.LogRecord) -> None: ts = datetime.fromtimestamp(record.created).isoformat() - if isinstance(record.msg, ContentPublishEvent): + if isinstance(record.msg, GroupChatPublishEvent): if record.msg.source is None: sys.stdout.write( f"\n{'-'*75} \n" @@ -41,19 +40,15 @@ class ConsoleLogHandler(logging.Handler): sys.stdout.flush() elif isinstance(record.msg, ToolCallEvent): sys.stdout.write( - f"\n{'-'*75} \n" - f"\033[91m[{ts}], Tool Call:\033[0m\n" - f"\n{self.serialize_chat_message(record.msg.agent_message)}" + f"\n{'-'*75} \n" f"\033[91m[{ts}], Tool Call:\033[0m\n" f"\n{str(record.msg.model_dump())}" ) sys.stdout.flush() elif isinstance(record.msg, ToolCallResultEvent): sys.stdout.write( - f"\n{'-'*75} \n" - f"\033[91m[{ts}], Tool Call Result:\033[0m\n" - f"\n{self.serialize_chat_message(record.msg.agent_message)}" + f"\n{'-'*75} \n" f"\033[91m[{ts}], Tool Call Result:\033[0m\n" f"\n{str(record.msg.model_dump())}" ) sys.stdout.flush() - elif isinstance(record.msg, SelectSpeakerEvent): + elif isinstance(record.msg, GroupChatSelectSpeakerEvent): sys.stdout.write( f"\n{'-'*75} \n" f"\033[91m[{ts}], Selected Next Speaker:\033[0m\n" f"\n{record.msg.selected_speaker}" ) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/logging/_file_log_handler.py b/python/packages/autogen-agentchat/src/autogen_agentchat/logging/_file_log_handler.py index ca64b0d68..24fd09418 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/logging/_file_log_handler.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/logging/_file_log_handler.py @@ -4,12 +4,11 @@ from dataclasses import asdict, is_dataclass from datetime import datetime from typing import Any +from ..agents._tool_use_assistant_agent import ToolCallEvent, ToolCallResultEvent from ..teams._events import ( - ContentPublishEvent, - SelectSpeakerEvent, + GroupChatPublishEvent, + GroupChatSelectSpeakerEvent, TerminationEvent, - ToolCallEvent, - ToolCallResultEvent, ) @@ -21,7 +20,7 @@ class FileLogHandler(logging.Handler): def emit(self, record: logging.LogRecord) -> None: ts = datetime.fromtimestamp(record.created).isoformat() - if isinstance(record.msg, ContentPublishEvent | ToolCallEvent | ToolCallResultEvent | TerminationEvent): + if isinstance(record.msg, GroupChatPublishEvent | TerminationEvent): log_entry = json.dumps( { "timestamp": ts, @@ -31,7 +30,7 @@ class FileLogHandler(logging.Handler): }, default=self.json_serializer, ) - elif isinstance(record.msg, SelectSpeakerEvent): + elif isinstance(record.msg, GroupChatSelectSpeakerEvent): log_entry = json.dumps( { "timestamp": ts, @@ -41,6 +40,24 @@ class FileLogHandler(logging.Handler): }, default=self.json_serializer, ) + elif isinstance(record.msg, ToolCallEvent): + log_entry = json.dumps( + { + "timestamp": ts, + "tool_calls": record.msg.model_dump(), + "type": "ToolCallEvent", + }, + default=self.json_serializer, + ) + elif isinstance(record.msg, ToolCallResultEvent): + log_entry = json.dumps( + { + "timestamp": ts, + "tool_call_results": record.msg.model_dump(), + "type": "ToolCallResultEvent", + }, + default=self.json_serializer, + ) else: raise ValueError(f"Unexpected log record: {record.msg}") file_record = logging.LogRecord( diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py b/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py index 6aac22248..99bd0c888 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py @@ -1,7 +1,6 @@ from typing import List -from autogen_core.components import FunctionCall, Image -from autogen_core.components.models import FunctionExecutionResult +from autogen_core.components import Image from pydantic import BaseModel @@ -26,20 +25,6 @@ class MultiModalMessage(BaseMessage): """The content of the message.""" -class ToolCallMessage(BaseMessage): - """A message containing a list of function calls.""" - - content: List[FunctionCall] - """The list of function calls.""" - - -class ToolCallResultMessage(BaseMessage): - """A message containing the results of function calls.""" - - content: List[FunctionExecutionResult] - """The list of function execution results.""" - - class StopMessage(BaseMessage): """A message requesting stop of a conversation.""" @@ -47,7 +32,14 @@ class StopMessage(BaseMessage): """The content for the stop message.""" -ChatMessage = TextMessage | MultiModalMessage | StopMessage | ToolCallMessage | ToolCallResultMessage +class HandoffMessage(BaseMessage): + """A message requesting handoff of a conversation to another agent.""" + + content: str + """The agent name to handoff the conversation to.""" + + +ChatMessage = TextMessage | MultiModalMessage | StopMessage | HandoffMessage """A message used by agents in a team.""" @@ -55,8 +47,7 @@ __all__ = [ "BaseMessage", "TextMessage", "MultiModalMessage", - "ToolCallMessage", - "ToolCallResultMessage", "StopMessage", + "HandoffMessage", "ChatMessage", ] diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/__init__.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/__init__.py index 836f15012..a2dd74d61 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/__init__.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/__init__.py @@ -1,7 +1,9 @@ from ._group_chat._round_robin_group_chat import RoundRobinGroupChat from ._group_chat._selector_group_chat import SelectorGroupChat +from ._group_chat._swarm_group_chat import Swarm __all__ = [ "RoundRobinGroupChat", "SelectorGroupChat", + "Swarm", ] diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_events.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_events.py index 5bfeff417..3442b35ce 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_events.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_events.py @@ -1,16 +1,16 @@ from autogen_core.base import AgentId from pydantic import BaseModel, ConfigDict -from ..messages import MultiModalMessage, StopMessage, TextMessage, ToolCallMessage, ToolCallResultMessage +from ..messages import ChatMessage, StopMessage -class ContentPublishEvent(BaseModel): - """An event for sharing some data. Agents receive this event should +class GroupChatPublishEvent(BaseModel): + """An group chat event for sharing some data. Agents receive this event should update their internal state (e.g., append to message history) with the content of the event. """ - agent_message: TextMessage | MultiModalMessage | StopMessage + agent_message: ChatMessage """The message published by the agent.""" source: AgentId | None = None @@ -19,39 +19,15 @@ class ContentPublishEvent(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) -class ContentRequestEvent(BaseModel): - """An event for requesting to publish a content event. - Upon receiving this event, the agent should publish a ContentPublishEvent. +class GroupChatRequestPublishEvent(BaseModel): + """An event for requesting to publish a group chat publish event. + Upon receiving this event, the agent should publish a group chat publish event. """ ... -class ToolCallEvent(BaseModel): - """An event produced when requesting a tool call.""" - - agent_message: ToolCallMessage - """The tool call message.""" - - source: AgentId - """The sender of the tool call message.""" - - model_config = ConfigDict(arbitrary_types_allowed=True) - - -class ToolCallResultEvent(BaseModel): - """An event produced when a tool call is completed.""" - - agent_message: ToolCallResultMessage - """The tool call result message.""" - - source: AgentId - """The sender of the tool call result message.""" - - model_config = ConfigDict(arbitrary_types_allowed=True) - - -class SelectSpeakerEvent(BaseModel): +class GroupChatSelectSpeakerEvent(BaseModel): """An event for selecting the next speaker in a group chat.""" selected_speaker: str diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_chat_agent_container.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_chat_agent_container.py deleted file mode 100644 index 4f3e902af..000000000 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_chat_agent_container.py +++ /dev/null @@ -1,92 +0,0 @@ -import asyncio -import logging -from typing import List - -from autogen_core.base import AgentId, AgentType, MessageContext -from autogen_core.components import DefaultTopicId, event -from autogen_core.components.models import FunctionExecutionResult -from autogen_core.components.tool_agent import ToolException - -from ... import EVENT_LOGGER_NAME -from ...base import ChatAgent -from ...messages import MultiModalMessage, StopMessage, TextMessage, ToolCallMessage, ToolCallResultMessage -from .._events import ContentPublishEvent, ContentRequestEvent, ToolCallEvent, ToolCallResultEvent -from ._sequential_routed_agent import SequentialRoutedAgent - -event_logger = logging.getLogger(EVENT_LOGGER_NAME) - - -class BaseChatAgentContainer(SequentialRoutedAgent): - """A core agent class that delegates message handling to an - :class:`autogen_agentchat.agents.BaseChatAgent` so that it can be used in a - group chat team. - - Args: - parent_topic_type (str): The topic type of the parent orchestrator. - agent (BaseChatAgent): The agent to delegate message handling to. - tool_agent_type (AgentType, optional): The agent type of the tool agent. Defaults to None. - """ - - def __init__(self, parent_topic_type: str, agent: ChatAgent, tool_agent_type: AgentType | None = None) -> None: - super().__init__(description=agent.description) - self._parent_topic_type = parent_topic_type - self._agent = agent - self._message_buffer: List[TextMessage | MultiModalMessage | StopMessage] = [] - self._tool_agent_id = AgentId(type=tool_agent_type, key=self.id.key) if tool_agent_type else None - - @event - async def handle_content_publish(self, message: ContentPublishEvent, ctx: MessageContext) -> None: - """Handle a content publish event by appending the content to the buffer.""" - if not isinstance(message.agent_message, TextMessage | MultiModalMessage | StopMessage): - raise ValueError( - f"Unexpected message type: {type(message.agent_message)}. " - "The message must be a text, multimodal, or stop message." - ) - self._message_buffer.append(message.agent_message) - - @event - async def handle_content_request(self, message: ContentRequestEvent, ctx: MessageContext) -> None: - """Handle a content request event by passing the messages in the buffer - to the delegate agent and publish the response.""" - response = await self._agent.on_messages(self._message_buffer, ctx.cancellation_token) - - if self._tool_agent_id is not None: - # Handle tool calls. - while isinstance(response, ToolCallMessage): - # Log the tool call. - event_logger.debug(ToolCallEvent(agent_message=response, source=self.id)) - - results: List[FunctionExecutionResult | BaseException] = await asyncio.gather( - *[ - self.send_message( - message=call, - recipient=self._tool_agent_id, - cancellation_token=ctx.cancellation_token, - ) - for call in response.content - ] - ) - # Combine the results in to a single response and handle exceptions. - function_results: List[FunctionExecutionResult] = [] - for result in results: - if isinstance(result, FunctionExecutionResult): - function_results.append(result) - elif isinstance(result, ToolException): - function_results.append( - FunctionExecutionResult(content=f"Error: {result}", call_id=result.call_id) - ) - elif isinstance(result, BaseException): - raise result # Unexpected exception. - # Create a new tool call result message. - feedback = ToolCallResultMessage(content=function_results, source=self._tool_agent_id.type) - # Log the feedback. - event_logger.debug(ToolCallResultEvent(agent_message=feedback, source=self._tool_agent_id)) - response = await self._agent.on_messages([feedback], ctx.cancellation_token) - - # Publish the response. - assert isinstance(response, TextMessage | MultiModalMessage | StopMessage) - self._message_buffer.clear() - await self.publish_message( - ContentPublishEvent(agent_message=response, source=self.id), - topic_id=DefaultTopicId(type=self._parent_topic_type), - ) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py index c599d269b..6ee79d4de 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py @@ -13,14 +13,12 @@ from autogen_core.base import ( TopicId, ) from autogen_core.components import ClosureAgent, TypeSubscription -from autogen_core.components.tool_agent import ToolAgent -from autogen_core.components.tools import Tool -from ...base import ChatAgent, TaskResult, Team, TerminationCondition, ToolUseChatAgent +from ...base import ChatAgent, TaskResult, Team, TerminationCondition from ...messages import ChatMessage, TextMessage -from .._events import ContentPublishEvent, ContentRequestEvent -from ._base_chat_agent_container import BaseChatAgentContainer +from .._events import GroupChatPublishEvent, GroupChatRequestPublishEvent from ._base_group_chat_manager import BaseGroupChatManager +from ._chat_agent_container import ChatAgentContainer class BaseGroupChat(Team, ABC): @@ -35,11 +33,6 @@ class BaseGroupChat(Team, ABC): raise ValueError("At least one participant is required.") if len(participants) != len(set(participant.name for participant in participants)): raise ValueError("The participant names must be unique.") - for participant in participants: - if isinstance(participant, ToolUseChatAgent) and not participant.registered_tools: - raise ValueError( - f"Participant '{participant.name}' is a tool use agent so it must have registered tools." - ) self._participants = participants self._team_id = str(uuid.uuid4()) self._base_group_chat_manager_class = group_chat_manager_class @@ -55,27 +48,19 @@ class BaseGroupChat(Team, ABC): ) -> Callable[[], BaseGroupChatManager]: ... def _create_participant_factory( - self, parent_topic_type: str, agent: ChatAgent, tool_agent_type: AgentType | None - ) -> Callable[[], BaseChatAgentContainer]: - def _factory() -> BaseChatAgentContainer: + self, + parent_topic_type: str, + agent: ChatAgent, + ) -> Callable[[], ChatAgentContainer]: + def _factory() -> ChatAgentContainer: id = AgentInstantiationContext.current_agent_id() assert id == AgentId(type=agent.name, key=self._team_id) - container = BaseChatAgentContainer(parent_topic_type, agent, tool_agent_type) + container = ChatAgentContainer(parent_topic_type, agent) assert container.id == id return container return _factory - def _create_tool_agent_factory( - self, - caller_name: str, - tools: List[Tool], - ) -> Callable[[], ToolAgent]: - def _factory() -> ToolAgent: - return ToolAgent(f"Tool agent for {caller_name}", tools) - - return _factory - async def run( self, task: str, @@ -99,27 +84,14 @@ class BaseGroupChat(Team, ABC): participant_topic_types: List[str] = [] participant_descriptions: List[str] = [] for participant in self._participants: - if isinstance(participant, ToolUseChatAgent): - assert participant.registered_tools is not None and len(participant.registered_tools) > 0 - # Register the tool agent. - tool_agent_type = await ToolAgent.register( - runtime, - f"tool_agent_for_{participant.name}", - self._create_tool_agent_factory(participant.name, participant.registered_tools), - ) - # No subscriptions are needed for the tool agent, which will be called via direct messages. - else: - # No tool agent is needed. - tool_agent_type = None - # Use the participant name as the agent type and topic type. agent_type = participant.name topic_type = participant.name # Register the participant factory. - await BaseChatAgentContainer.register( + await ChatAgentContainer.register( runtime, type=agent_type, - factory=self._create_participant_factory(group_topic_type, participant, tool_agent_type), + factory=self._create_participant_factory(group_topic_type, participant), ) # Add subscriptions for the participant. await runtime.add_subscription(TypeSubscription(topic_type=topic_type, agent_type=agent_type)) @@ -154,7 +126,10 @@ class BaseGroupChat(Team, ABC): group_chat_messages: List[ChatMessage] = [] async def collect_group_chat_messages( - _runtime: AgentRuntime, id: AgentId, message: ContentPublishEvent, ctx: MessageContext + _runtime: AgentRuntime, + id: AgentId, + message: GroupChatPublishEvent, + ctx: MessageContext, ) -> None: group_chat_messages.append(message.agent_message) @@ -174,10 +149,10 @@ class BaseGroupChat(Team, ABC): team_topic_id = TopicId(type=team_topic_type, source=self._team_id) group_chat_manager_topic_id = TopicId(type=group_chat_manager_topic_type, source=self._team_id) await runtime.publish_message( - ContentPublishEvent(agent_message=TextMessage(content=task, source="user")), + GroupChatPublishEvent(agent_message=TextMessage(content=task, source="user")), topic_id=team_topic_id, ) - await runtime.publish_message(ContentRequestEvent(), topic_id=group_chat_manager_topic_id) + await runtime.publish_message(GroupChatRequestPublishEvent(), topic_id=group_chat_manager_topic_id) # Wait for the runtime to stop. await runtime.stop_when_idle() diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py index 5f59e9e63..68eb76c06 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py @@ -1,13 +1,17 @@ import logging from abc import ABC, abstractmethod -from typing import List +from typing import Any, List -from autogen_core.base import MessageContext, TopicId -from autogen_core.components import event +from autogen_core.base import MessageContext +from autogen_core.components import DefaultTopicId, event from ... import EVENT_LOGGER_NAME from ...base import TerminationCondition -from .._events import ContentPublishEvent, ContentRequestEvent, TerminationEvent +from .._events import ( + GroupChatPublishEvent, + GroupChatRequestPublishEvent, + TerminationEvent, +) from ._sequential_routed_agent import SequentialRoutedAgent event_logger = logging.getLogger(EVENT_LOGGER_NAME) @@ -33,6 +37,10 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC): Raises: ValueError: If the number of participant topic types, agent types, and descriptions are not the same. + ValueError: If the participant topic types are not unique. + ValueError: If the group topic type is in the participant topic types. + ValueError: If the parent topic type is in the participant topic types. + ValueError: If the group topic type is the same as the parent topic type. """ def __init__( @@ -58,11 +66,11 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC): raise ValueError("The group topic type must not be the same as the parent topic type.") self._participant_topic_types = participant_topic_types self._participant_descriptions = participant_descriptions - self._message_thread: List[ContentPublishEvent] = [] + self._message_thread: List[GroupChatPublishEvent] = [] self._termination_condition = termination_condition @event - async def handle_content_publish(self, message: ContentPublishEvent, ctx: MessageContext) -> None: + async def handle_content_publish(self, message: GroupChatPublishEvent, ctx: MessageContext) -> None: """Handle a content publish event. If the event is from the parent topic, add the message to the thread. @@ -70,16 +78,25 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC): If the event is from the group chat topic, add the message to the thread and select a speaker to continue the conversation. If the event from the group chat session requests a pause, publish the last message to the parent topic.""" assert ctx.topic_id is not None - group_chat_topic_id = TopicId(type=self._group_topic_type, source=ctx.topic_id.source) event_logger.info(message) + if self._termination_condition is not None and self._termination_condition.terminated: + # The group chat has been terminated. + return + # Process event from parent. if ctx.topic_id.type == self._parent_topic_type: self._message_thread.append(message) await self.publish_message( - ContentPublishEvent(agent_message=message.agent_message, source=self.id), topic_id=group_chat_topic_id + GroupChatPublishEvent(agent_message=message.agent_message, source=self.id), + topic_id=DefaultTopicId(type=self._group_topic_type), ) + if self._termination_condition is not None: + stop_message = await self._termination_condition([message.agent_message]) + if stop_message is not None: + event_logger.info(TerminationEvent(agent_message=stop_message, source=self.id)) + # Stop the group chat. return # Process event from the group chat this agent manages. @@ -91,8 +108,6 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC): stop_message = await self._termination_condition([message.agent_message]) if stop_message is not None: event_logger.info(TerminationEvent(agent_message=stop_message, source=self.id)) - # Reset the termination condition. - await self._termination_condition.reset() # Stop the group chat. # TODO: this should be different if the group chat is nested. return @@ -100,24 +115,28 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC): # Select a speaker to continue the conversation. speaker_topic_type = await self.select_speaker(self._message_thread) - participant_topic_id = TopicId(type=speaker_topic_type, source=ctx.topic_id.source) - group_chat_topic_id = TopicId(type=self._group_topic_type, source=ctx.topic_id.source) - await self.publish_message(ContentRequestEvent(), topic_id=participant_topic_id) + await self.publish_message(GroupChatRequestPublishEvent(), topic_id=DefaultTopicId(type=speaker_topic_type)) @event - async def handle_content_request(self, message: ContentRequestEvent, ctx: MessageContext) -> None: + async def handle_content_request(self, message: GroupChatRequestPublishEvent, ctx: MessageContext) -> None: """Handle a content request by selecting a speaker to start the conversation.""" assert ctx.topic_id is not None if ctx.topic_id.type == self._group_topic_type: raise RuntimeError("Content request event from the group chat topic is not allowed.") + if self._termination_condition is not None and self._termination_condition.terminated: + # The group chat has been terminated. + return + speaker_topic_type = await self.select_speaker(self._message_thread) - participant_topic_id = TopicId(type=speaker_topic_type, source=ctx.topic_id.source) - await self.publish_message(ContentRequestEvent(), topic_id=participant_topic_id) + await self.publish_message(GroupChatRequestPublishEvent(), topic_id=DefaultTopicId(type=speaker_topic_type)) @abstractmethod - async def select_speaker(self, thread: List[ContentPublishEvent]) -> str: + async def select_speaker(self, thread: List[GroupChatPublishEvent]) -> str: """Select a speaker from the participants and return the topic type of the selected speaker.""" ... + + async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> None: + raise ValueError(f"Unhandled message in group chat manager: {type(message)}") diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_chat_agent_container.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_chat_agent_container.py new file mode 100644 index 000000000..acf5e9d2f --- /dev/null +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_chat_agent_container.py @@ -0,0 +1,48 @@ +from typing import Any, List + +from autogen_core.base import MessageContext +from autogen_core.components import DefaultTopicId, event + +from ...base import ChatAgent +from ...messages import ChatMessage +from .._events import GroupChatPublishEvent, GroupChatRequestPublishEvent +from ._sequential_routed_agent import SequentialRoutedAgent + + +class ChatAgentContainer(SequentialRoutedAgent): + """A core agent class that delegates message handling to an + :class:`autogen_agentchat.base.ChatAgent` so that it can be used in a + group chat team. + + Args: + parent_topic_type (str): The topic type of the parent orchestrator. + agent (ChatAgent): The agent to delegate message handling to. + """ + + def __init__(self, parent_topic_type: str, agent: ChatAgent) -> None: + super().__init__(description=agent.description) + self._parent_topic_type = parent_topic_type + self._agent = agent + self._message_buffer: List[ChatMessage] = [] + + @event + async def handle_message(self, message: GroupChatPublishEvent, ctx: MessageContext) -> None: + """Handle an event by appending the content to the buffer.""" + self._message_buffer.append(message.agent_message) + + @event + async def handle_content_request(self, message: GroupChatRequestPublishEvent, ctx: MessageContext) -> None: + """Handle a content request event by passing the messages in the buffer + to the delegate agent and publish the response.""" + # Pass the messages in the buffer to the delegate agent. + response = await self._agent.on_messages(self._message_buffer, ctx.cancellation_token) + + # Publish the response. + self._message_buffer.clear() + await self.publish_message( + GroupChatPublishEvent(agent_message=response, source=self.id), + topic_id=DefaultTopicId(type=self._parent_topic_type), + ) + + async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> None: + raise ValueError(f"Unhandled message in agent container: {type(message)}") diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_round_robin_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_round_robin_group_chat.py index 529314fec..fff872dd8 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_round_robin_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_round_robin_group_chat.py @@ -1,10 +1,17 @@ +import logging from typing import Callable, List +from ... import EVENT_LOGGER_NAME from ...base import ChatAgent, TerminationCondition -from .._events import ContentPublishEvent +from .._events import ( + GroupChatPublishEvent, + GroupChatSelectSpeakerEvent, +) from ._base_group_chat import BaseGroupChat from ._base_group_chat_manager import BaseGroupChatManager +event_logger = logging.getLogger(EVENT_LOGGER_NAME) + class RoundRobinGroupChatManager(BaseGroupChatManager): """A group chat manager that selects the next speaker in a round-robin fashion.""" @@ -26,11 +33,13 @@ class RoundRobinGroupChatManager(BaseGroupChatManager): ) self._next_speaker_index = 0 - async def select_speaker(self, thread: List[ContentPublishEvent]) -> str: + async def select_speaker(self, thread: List[GroupChatPublishEvent]) -> str: """Select a speaker from the participants in a round-robin fashion.""" current_speaker_index = self._next_speaker_index self._next_speaker_index = (current_speaker_index + 1) % len(self._participant_topic_types) - return self._participant_topic_types[current_speaker_index] + current_speaker = self._participant_topic_types[current_speaker_index] + event_logger.debug(GroupChatSelectSpeakerEvent(selected_speaker=current_speaker, source=self.id)) + return current_speaker class RoundRobinGroupChat(BaseGroupChat): diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py index 1eaf0ddd5..79f8b60de 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py @@ -7,7 +7,10 @@ from autogen_core.components.models import ChatCompletionClient, SystemMessage from ... import EVENT_LOGGER_NAME, TRACE_LOGGER_NAME from ...base import ChatAgent, TerminationCondition from ...messages import MultiModalMessage, StopMessage, TextMessage -from .._events import ContentPublishEvent, SelectSpeakerEvent +from .._events import ( + GroupChatPublishEvent, + GroupChatSelectSpeakerEvent, +) from ._base_group_chat import BaseGroupChat from ._base_group_chat_manager import BaseGroupChatManager @@ -42,7 +45,7 @@ class SelectorGroupChatManager(BaseGroupChatManager): self._previous_speaker: str | None = None self._allow_repeated_speaker = allow_repeated_speaker - async def select_speaker(self, thread: List[ContentPublishEvent]) -> str: + async def select_speaker(self, thread: List[GroupChatPublishEvent]) -> str: """Selects the next speaker in a group chat using a ChatCompletion client. A key assumption is that the agent type is the same as the topic type, which we use as the agent name. @@ -107,7 +110,7 @@ class SelectorGroupChatManager(BaseGroupChatManager): else: agent_name = participants[0] self._previous_speaker = agent_name - event_logger.debug(SelectSpeakerEvent(selected_speaker=agent_name, source=self.id)) + event_logger.debug(GroupChatSelectSpeakerEvent(selected_speaker=agent_name, source=self.id)) return agent_name def _mentioned_agents(self, message_content: str, agent_names: List[str]) -> Dict[str, int]: @@ -148,7 +151,7 @@ class SelectorGroupChat(BaseGroupChat): to all, using a ChatCompletion model to select the next speaker after each message. Args: - participants (List[BaseChatAgent]): The participants in the group chat, + participants (List[ChatAgent]): The participants in the group chat, must have unique names and at least two participants. model_client (ChatCompletionClient): The ChatCompletion model client used to select the next speaker. diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_swarm_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_swarm_group_chat.py new file mode 100644 index 000000000..4f2d08afc --- /dev/null +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_swarm_group_chat.py @@ -0,0 +1,72 @@ +import logging +from typing import Callable, List + +from ... import EVENT_LOGGER_NAME +from ...base import ChatAgent, TerminationCondition +from ...messages import HandoffMessage +from .._events import ( + GroupChatPublishEvent, + GroupChatSelectSpeakerEvent, +) +from ._base_group_chat import BaseGroupChat +from ._base_group_chat_manager import BaseGroupChatManager + +event_logger = logging.getLogger(EVENT_LOGGER_NAME) + + +class SwarmGroupChatManager(BaseGroupChatManager): + """A group chat manager that selects the next speaker based on handoff message only.""" + + def __init__( + self, + parent_topic_type: str, + group_topic_type: str, + participant_topic_types: List[str], + participant_descriptions: List[str], + termination_condition: TerminationCondition | None, + ) -> None: + super().__init__( + parent_topic_type, + group_topic_type, + participant_topic_types, + participant_descriptions, + termination_condition, + ) + self._current_speaker = participant_topic_types[0] + + async def select_speaker(self, thread: List[GroupChatPublishEvent]) -> str: + """Select a speaker from the participants based on handoff message.""" + if len(thread) > 0 and isinstance(thread[-1].agent_message, HandoffMessage): + self._current_speaker = thread[-1].agent_message.content + if self._current_speaker not in self._participant_topic_types: + raise ValueError("The selected speaker in the handoff message is not a participant.") + event_logger.debug(GroupChatSelectSpeakerEvent(selected_speaker=self._current_speaker, source=self.id)) + return self._current_speaker + else: + return self._current_speaker + + +class Swarm(BaseGroupChat): + """(Experimental) A group chat that selects the next speaker based on handoff message only.""" + + def __init__(self, participants: List[ChatAgent]): + super().__init__(participants, group_chat_manager_class=SwarmGroupChatManager) + + def _create_group_chat_manager_factory( + self, + parent_topic_type: str, + group_topic_type: str, + participant_topic_types: List[str], + participant_descriptions: List[str], + termination_condition: TerminationCondition | None, + ) -> Callable[[], SwarmGroupChatManager]: + def _factory() -> SwarmGroupChatManager: + return SwarmGroupChatManager( + parent_topic_type, + group_topic_type, + participant_topic_types, + participant_descriptions, + termination_condition, + ) + + return _factory diff --git a/python/packages/autogen-agentchat/tests/test_group_chat.py b/python/packages/autogen-agentchat/tests/test_group_chat.py index 97cf65c2d..9f740eb64 100644 --- a/python/packages/autogen-agentchat/tests/test_group_chat.py +++ b/python/packages/autogen-agentchat/tests/test_group_chat.py @@ -13,11 +13,17 @@ from autogen_agentchat.agents import ( ToolUseAssistantAgent, ) from autogen_agentchat.logging import FileLogHandler -from autogen_agentchat.messages import ChatMessage, StopMessage, TextMessage -from autogen_agentchat.task import StopMessageTermination +from autogen_agentchat.messages import ( + ChatMessage, + HandoffMessage, + StopMessage, + TextMessage, +) +from autogen_agentchat.task import MaxMessageTermination, StopMessageTermination from autogen_agentchat.teams import ( RoundRobinGroupChat, SelectorGroupChat, + Swarm, ) from autogen_core.base import CancellationToken from autogen_core.components import FunctionCall @@ -212,7 +218,16 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch ) echo_agent = _EchoAgent("echo_agent", description="echo agent") team = RoundRobinGroupChat(participants=[tool_use_agent, echo_agent]) - await team.run("Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination()) + result = await team.run( + "Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination() + ) + + assert len(result.messages) == 4 + assert isinstance(result.messages[0], TextMessage) # task + assert isinstance(result.messages[1], TextMessage) # tool use agent response + assert isinstance(result.messages[2], TextMessage) # echo agent response + assert isinstance(result.messages[3], StopMessage) # tool use agent response + context = tool_use_agent._model_context # pyright: ignore assert context[0].content == "Write a program that prints 'Hello, world!'" assert isinstance(context[1].content, list) @@ -393,3 +408,29 @@ async def test_selector_group_chat_two_speakers_allow_repeated(monkeypatch: pyte assert result.messages[1].source == "agent2" assert result.messages[2].source == "agent2" assert result.messages[3].source == "agent1" + + +class _HandOffAgent(BaseChatAgent): + def __init__(self, name: str, description: str, next_agent: str) -> None: + super().__init__(name, description) + self._next_agent = next_agent + + async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage: + return HandoffMessage(content=self._next_agent, source=self.name) + + +@pytest.mark.asyncio +async def test_swarm() -> None: + first_agent = _HandOffAgent("first_agent", description="first agent", next_agent="second_agent") + second_agent = _HandOffAgent("second_agent", description="second agent", next_agent="third_agent") + third_agent = _HandOffAgent("third_agent", description="third agent", next_agent="first_agent") + + team = Swarm([second_agent, first_agent, third_agent]) + result = await team.run("task", termination_condition=MaxMessageTermination(6)) + assert len(result.messages) == 6 + assert result.messages[0].content == "task" + assert result.messages[1].content == "third_agent" + assert result.messages[2].content == "first_agent" + assert result.messages[3].content == "second_agent" + assert result.messages[4].content == "third_agent" + assert result.messages[5].content == "first_agent" diff --git a/python/packages/autogen-agentchat/tests/test_tool_use_assistant_agent.py b/python/packages/autogen-agentchat/tests/test_tool_use_assistant_agent.py index 3a1734b3f..d5ec31a12 100644 --- a/python/packages/autogen-agentchat/tests/test_tool_use_assistant_agent.py +++ b/python/packages/autogen-agentchat/tests/test_tool_use_assistant_agent.py @@ -4,6 +4,7 @@ from typing import Any, AsyncGenerator, List import pytest from autogen_agentchat.agents import ToolUseAssistantAgent +from autogen_agentchat.messages import StopMessage, TextMessage from autogen_core.components.models import OpenAIChatCompletionClient from autogen_core.components.tools import FunctionTool from openai.resources.chat.completions import AsyncCompletions @@ -103,5 +104,6 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch ) result = await tool_use_agent.run("task") assert len(result.messages) == 3 - # assert isinstance(result.messages[1], ToolCallMessage) - # assert isinstance(result.messages[2], TextMessage) + assert isinstance(result.messages[0], TextMessage) + assert isinstance(result.messages[1], TextMessage) + assert isinstance(result.messages[2], StopMessage)