diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/__init__.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/__init__.py index 1c3078d01..19cdec548 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/__init__.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/__init__.py @@ -1,11 +1,10 @@ -from ._base_chat_agent import BaseChatAgent, BaseToolUseChatAgent +from ._base_chat_agent import BaseChatAgent from ._code_executor_agent import CodeExecutorAgent from ._coding_assistant_agent import CodingAssistantAgent from ._tool_use_assistant_agent import ToolUseAssistantAgent __all__ = [ "BaseChatAgent", - "BaseToolUseChatAgent", "CodeExecutorAgent", "CodingAssistantAgent", "ToolUseAssistantAgent", diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py index 62bac59d5..77bf4c02c 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_base_chat_agent.py @@ -1,10 +1,9 @@ from abc import ABC, abstractmethod -from typing import List, Sequence +from typing import Sequence from autogen_core.base import CancellationToken -from autogen_core.components.tools import Tool -from ..base import ChatAgent, TaskResult, TerminationCondition, ToolUseChatAgent +from ..base import ChatAgent, TaskResult, TerminationCondition from ..messages import ChatMessage from ..teams import RoundRobinGroupChat @@ -51,21 +50,3 @@ class BaseChatAgent(ChatAgent, ABC): termination_condition=termination_condition, ) return result - - -class BaseToolUseChatAgent(BaseChatAgent, ToolUseChatAgent): - """Base class for a chat agent that can use tools. - - Subclass this base class to create an agent class that uses tools by returning - ToolCallMessage message from the :meth:`on_messages` method and receiving - ToolCallResultMessage message from the input to the :meth:`on_messages` method. - """ - - def __init__(self, name: str, description: str, registered_tools: List[Tool]) -> None: - super().__init__(name, description) - self._registered_tools = registered_tools - - @property - def registered_tools(self) -> List[Tool]: - """The list of tools that the agent can use.""" - return self._registered_tools diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_tool_use_assistant_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_tool_use_assistant_agent.py index 37022acd0..fde1a3e49 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_tool_use_assistant_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/agents/_tool_use_assistant_agent.py @@ -1,3 +1,6 @@ +import asyncio +import json +import logging from typing import Any, Awaitable, Callable, List, Sequence from autogen_core.base import CancellationToken @@ -5,25 +8,45 @@ from autogen_core.components import FunctionCall from autogen_core.components.models import ( AssistantMessage, ChatCompletionClient, + FunctionExecutionResult, FunctionExecutionResultMessage, LLMMessage, SystemMessage, UserMessage, ) from autogen_core.components.tools import FunctionTool, Tool +from pydantic import BaseModel, ConfigDict +from .. import EVENT_LOGGER_NAME from ..messages import ( ChatMessage, - MultiModalMessage, StopMessage, TextMessage, - ToolCallMessage, - ToolCallResultMessage, ) -from ._base_chat_agent import BaseToolUseChatAgent +from ._base_chat_agent import BaseChatAgent + +event_logger = logging.getLogger(EVENT_LOGGER_NAME) -class ToolUseAssistantAgent(BaseToolUseChatAgent): +class ToolCallEvent(BaseModel): + """A tool call event.""" + + tool_calls: List[FunctionCall] + """The tool call message.""" + + model_config = ConfigDict(arbitrary_types_allowed=True) + + +class ToolCallResultEvent(BaseModel): + """A tool call result event.""" + + tool_call_results: List[FunctionExecutionResult] + """The tool call result message.""" + + model_config = ConfigDict(arbitrary_types_allowed=True) + + +class ToolUseAssistantAgent(BaseChatAgent): """An agent that provides assistance with tool use. It responds with a StopMessage when 'terminate' is detected in the response. @@ -45,46 +68,50 @@ class ToolUseAssistantAgent(BaseToolUseChatAgent): description: str = "An agent that provides assistance with ability to use tools.", system_message: str = "You are a helpful AI assistant. Solve tasks using your tools. Reply with 'TERMINATE' when the task has been completed.", ): - tools: List[Tool] = [] + super().__init__(name=name, description=description) + self._model_client = model_client + self._system_messages = [SystemMessage(content=system_message)] + self._tools: List[Tool] = [] for tool in registered_tools: if isinstance(tool, Tool): - tools.append(tool) + self._tools.append(tool) elif callable(tool): if hasattr(tool, "__doc__") and tool.__doc__ is not None: description = tool.__doc__ else: description = "" - tools.append(FunctionTool(tool, description=description)) + self._tools.append(FunctionTool(tool, description=description)) else: raise ValueError(f"Unsupported tool type: {type(tool)}") - super().__init__(name=name, description=description, registered_tools=tools) - self._model_client = model_client - self._system_messages = [SystemMessage(content=system_message)] - self._tool_schema = [tool.schema for tool in tools] self._model_context: List[LLMMessage] = [] async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage: # Add messages to the model context. for msg in messages: - if isinstance(msg, ToolCallResultMessage): - self._model_context.append(FunctionExecutionResultMessage(content=msg.content)) - elif not isinstance(msg, TextMessage | MultiModalMessage | StopMessage): - raise ValueError(f"Unsupported message type: {type(msg)}") - else: - self._model_context.append(UserMessage(content=msg.content, source=msg.source)) + # TODO: add special handling for handoff messages + self._model_context.append(UserMessage(content=msg.content, source=msg.source)) # Generate an inference result based on the current model context. llm_messages = self._system_messages + self._model_context - result = await self._model_client.create( - llm_messages, tools=self._tool_schema, cancellation_token=cancellation_token - ) + result = await self._model_client.create(llm_messages, tools=self._tools, cancellation_token=cancellation_token) # Add the response to the model context. self._model_context.append(AssistantMessage(content=result.content, source=self.name)) - # Detect tool calls. - if isinstance(result.content, list) and all(isinstance(item, FunctionCall) for item in result.content): - return ToolCallMessage(content=result.content, source=self.name) + # Run tool calls until the model produces a string response. + while isinstance(result.content, list) and all(isinstance(item, FunctionCall) for item in result.content): + event_logger.debug(ToolCallEvent(tool_calls=result.content)) + # Execute the tool calls. + results = await asyncio.gather( + *[self._execute_tool_call(call, cancellation_token) for call in result.content] + ) + event_logger.debug(ToolCallResultEvent(tool_call_results=results)) + self._model_context.append(FunctionExecutionResultMessage(content=results)) + # Generate an inference result based on the current model context. + result = await self._model_client.create( + self._model_context, tools=self._tools, cancellation_token=cancellation_token + ) + self._model_context.append(AssistantMessage(content=result.content, source=self.name)) assert isinstance(result.content, str) # Detect stop request. @@ -93,3 +120,20 @@ class ToolUseAssistantAgent(BaseToolUseChatAgent): return StopMessage(content=result.content, source=self.name) return TextMessage(content=result.content, source=self.name) + + async def _execute_tool_call( + self, tool_call: FunctionCall, cancellation_token: CancellationToken + ) -> FunctionExecutionResult: + """Execute a tool call and return the result.""" + try: + if not self._tools: + raise ValueError("No tools are available.") + tool = next((t for t in self._tools if t.name == tool_call.name), None) + if tool is None: + raise ValueError(f"The tool '{tool_call.name}' is not available.") + arguments = json.loads(tool_call.arguments) + result = await tool.run_json(arguments, cancellation_token) + result_as_str = tool.return_value_as_string(result) + return FunctionExecutionResult(content=result_as_str, call_id=tool_call.id) + except Exception as e: + return FunctionExecutionResult(content=f"Error: {e}", call_id=tool_call.id) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/base/__init__.py b/python/packages/autogen-agentchat/src/autogen_agentchat/base/__init__.py index 36845eb82..436d69fb0 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/base/__init__.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/base/__init__.py @@ -1,11 +1,10 @@ -from ._chat_agent import ChatAgent, ToolUseChatAgent +from ._chat_agent import ChatAgent from ._task import TaskResult, TaskRunner from ._team import Team from ._termination import TerminatedException, TerminationCondition __all__ = [ "ChatAgent", - "ToolUseChatAgent", "Team", "TerminatedException", "TerminationCondition", diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/base/_chat_agent.py b/python/packages/autogen-agentchat/src/autogen_agentchat/base/_chat_agent.py index 6200050d4..d82539540 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/base/_chat_agent.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/base/_chat_agent.py @@ -1,7 +1,6 @@ -from typing import List, Protocol, Sequence, runtime_checkable +from typing import Protocol, Sequence, runtime_checkable from autogen_core.base import CancellationToken -from autogen_core.components.tools import Tool from ..messages import ChatMessage from ._task import TaskResult, TaskRunner @@ -38,13 +37,3 @@ class ChatAgent(TaskRunner, Protocol): ) -> TaskResult: """Run the agent with the given task and return the result.""" ... - - -@runtime_checkable -class ToolUseChatAgent(ChatAgent, Protocol): - """Protocol for a chat agent that can use tools.""" - - @property - def registered_tools(self) -> List[Tool]: - """The list of tools that the agent can use.""" - ... diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/logging/_console_log_handler.py b/python/packages/autogen-agentchat/src/autogen_agentchat/logging/_console_log_handler.py index d0fb4ab08..95200e284 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/logging/_console_log_handler.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/logging/_console_log_handler.py @@ -3,13 +3,12 @@ import logging import sys from datetime import datetime +from ..agents._tool_use_assistant_agent import ToolCallEvent, ToolCallResultEvent from ..messages import ChatMessage, StopMessage, TextMessage from ..teams._events import ( - ContentPublishEvent, - SelectSpeakerEvent, + GroupChatPublishEvent, + GroupChatSelectSpeakerEvent, TerminationEvent, - ToolCallEvent, - ToolCallResultEvent, ) @@ -25,7 +24,7 @@ class ConsoleLogHandler(logging.Handler): def emit(self, record: logging.LogRecord) -> None: ts = datetime.fromtimestamp(record.created).isoformat() - if isinstance(record.msg, ContentPublishEvent): + if isinstance(record.msg, GroupChatPublishEvent): if record.msg.source is None: sys.stdout.write( f"\n{'-'*75} \n" @@ -41,19 +40,15 @@ class ConsoleLogHandler(logging.Handler): sys.stdout.flush() elif isinstance(record.msg, ToolCallEvent): sys.stdout.write( - f"\n{'-'*75} \n" - f"\033[91m[{ts}], Tool Call:\033[0m\n" - f"\n{self.serialize_chat_message(record.msg.agent_message)}" + f"\n{'-'*75} \n" f"\033[91m[{ts}], Tool Call:\033[0m\n" f"\n{str(record.msg.model_dump())}" ) sys.stdout.flush() elif isinstance(record.msg, ToolCallResultEvent): sys.stdout.write( - f"\n{'-'*75} \n" - f"\033[91m[{ts}], Tool Call Result:\033[0m\n" - f"\n{self.serialize_chat_message(record.msg.agent_message)}" + f"\n{'-'*75} \n" f"\033[91m[{ts}], Tool Call Result:\033[0m\n" f"\n{str(record.msg.model_dump())}" ) sys.stdout.flush() - elif isinstance(record.msg, SelectSpeakerEvent): + elif isinstance(record.msg, GroupChatSelectSpeakerEvent): sys.stdout.write( f"\n{'-'*75} \n" f"\033[91m[{ts}], Selected Next Speaker:\033[0m\n" f"\n{record.msg.selected_speaker}" ) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/logging/_file_log_handler.py b/python/packages/autogen-agentchat/src/autogen_agentchat/logging/_file_log_handler.py index ca64b0d68..24fd09418 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/logging/_file_log_handler.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/logging/_file_log_handler.py @@ -4,12 +4,11 @@ from dataclasses import asdict, is_dataclass from datetime import datetime from typing import Any +from ..agents._tool_use_assistant_agent import ToolCallEvent, ToolCallResultEvent from ..teams._events import ( - ContentPublishEvent, - SelectSpeakerEvent, + GroupChatPublishEvent, + GroupChatSelectSpeakerEvent, TerminationEvent, - ToolCallEvent, - ToolCallResultEvent, ) @@ -21,7 +20,7 @@ class FileLogHandler(logging.Handler): def emit(self, record: logging.LogRecord) -> None: ts = datetime.fromtimestamp(record.created).isoformat() - if isinstance(record.msg, ContentPublishEvent | ToolCallEvent | ToolCallResultEvent | TerminationEvent): + if isinstance(record.msg, GroupChatPublishEvent | TerminationEvent): log_entry = json.dumps( { "timestamp": ts, @@ -31,7 +30,7 @@ class FileLogHandler(logging.Handler): }, default=self.json_serializer, ) - elif isinstance(record.msg, SelectSpeakerEvent): + elif isinstance(record.msg, GroupChatSelectSpeakerEvent): log_entry = json.dumps( { "timestamp": ts, @@ -41,6 +40,24 @@ class FileLogHandler(logging.Handler): }, default=self.json_serializer, ) + elif isinstance(record.msg, ToolCallEvent): + log_entry = json.dumps( + { + "timestamp": ts, + "tool_calls": record.msg.model_dump(), + "type": "ToolCallEvent", + }, + default=self.json_serializer, + ) + elif isinstance(record.msg, ToolCallResultEvent): + log_entry = json.dumps( + { + "timestamp": ts, + "tool_call_results": record.msg.model_dump(), + "type": "ToolCallResultEvent", + }, + default=self.json_serializer, + ) else: raise ValueError(f"Unexpected log record: {record.msg}") file_record = logging.LogRecord( diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py b/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py index 6aac22248..99bd0c888 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py @@ -1,7 +1,6 @@ from typing import List -from autogen_core.components import FunctionCall, Image -from autogen_core.components.models import FunctionExecutionResult +from autogen_core.components import Image from pydantic import BaseModel @@ -26,20 +25,6 @@ class MultiModalMessage(BaseMessage): """The content of the message.""" -class ToolCallMessage(BaseMessage): - """A message containing a list of function calls.""" - - content: List[FunctionCall] - """The list of function calls.""" - - -class ToolCallResultMessage(BaseMessage): - """A message containing the results of function calls.""" - - content: List[FunctionExecutionResult] - """The list of function execution results.""" - - class StopMessage(BaseMessage): """A message requesting stop of a conversation.""" @@ -47,7 +32,14 @@ class StopMessage(BaseMessage): """The content for the stop message.""" -ChatMessage = TextMessage | MultiModalMessage | StopMessage | ToolCallMessage | ToolCallResultMessage +class HandoffMessage(BaseMessage): + """A message requesting handoff of a conversation to another agent.""" + + content: str + """The agent name to handoff the conversation to.""" + + +ChatMessage = TextMessage | MultiModalMessage | StopMessage | HandoffMessage """A message used by agents in a team.""" @@ -55,8 +47,7 @@ __all__ = [ "BaseMessage", "TextMessage", "MultiModalMessage", - "ToolCallMessage", - "ToolCallResultMessage", "StopMessage", + "HandoffMessage", "ChatMessage", ] diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/__init__.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/__init__.py index 836f15012..a2dd74d61 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/__init__.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/__init__.py @@ -1,7 +1,9 @@ from ._group_chat._round_robin_group_chat import RoundRobinGroupChat from ._group_chat._selector_group_chat import SelectorGroupChat +from ._group_chat._swarm_group_chat import Swarm __all__ = [ "RoundRobinGroupChat", "SelectorGroupChat", + "Swarm", ] diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_events.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_events.py index 5bfeff417..3442b35ce 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_events.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_events.py @@ -1,16 +1,16 @@ from autogen_core.base import AgentId from pydantic import BaseModel, ConfigDict -from ..messages import MultiModalMessage, StopMessage, TextMessage, ToolCallMessage, ToolCallResultMessage +from ..messages import ChatMessage, StopMessage -class ContentPublishEvent(BaseModel): - """An event for sharing some data. Agents receive this event should +class GroupChatPublishEvent(BaseModel): + """An group chat event for sharing some data. Agents receive this event should update their internal state (e.g., append to message history) with the content of the event. """ - agent_message: TextMessage | MultiModalMessage | StopMessage + agent_message: ChatMessage """The message published by the agent.""" source: AgentId | None = None @@ -19,39 +19,15 @@ class ContentPublishEvent(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) -class ContentRequestEvent(BaseModel): - """An event for requesting to publish a content event. - Upon receiving this event, the agent should publish a ContentPublishEvent. +class GroupChatRequestPublishEvent(BaseModel): + """An event for requesting to publish a group chat publish event. + Upon receiving this event, the agent should publish a group chat publish event. """ ... -class ToolCallEvent(BaseModel): - """An event produced when requesting a tool call.""" - - agent_message: ToolCallMessage - """The tool call message.""" - - source: AgentId - """The sender of the tool call message.""" - - model_config = ConfigDict(arbitrary_types_allowed=True) - - -class ToolCallResultEvent(BaseModel): - """An event produced when a tool call is completed.""" - - agent_message: ToolCallResultMessage - """The tool call result message.""" - - source: AgentId - """The sender of the tool call result message.""" - - model_config = ConfigDict(arbitrary_types_allowed=True) - - -class SelectSpeakerEvent(BaseModel): +class GroupChatSelectSpeakerEvent(BaseModel): """An event for selecting the next speaker in a group chat.""" selected_speaker: str diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_chat_agent_container.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_chat_agent_container.py deleted file mode 100644 index 4f3e902af..000000000 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_chat_agent_container.py +++ /dev/null @@ -1,92 +0,0 @@ -import asyncio -import logging -from typing import List - -from autogen_core.base import AgentId, AgentType, MessageContext -from autogen_core.components import DefaultTopicId, event -from autogen_core.components.models import FunctionExecutionResult -from autogen_core.components.tool_agent import ToolException - -from ... import EVENT_LOGGER_NAME -from ...base import ChatAgent -from ...messages import MultiModalMessage, StopMessage, TextMessage, ToolCallMessage, ToolCallResultMessage -from .._events import ContentPublishEvent, ContentRequestEvent, ToolCallEvent, ToolCallResultEvent -from ._sequential_routed_agent import SequentialRoutedAgent - -event_logger = logging.getLogger(EVENT_LOGGER_NAME) - - -class BaseChatAgentContainer(SequentialRoutedAgent): - """A core agent class that delegates message handling to an - :class:`autogen_agentchat.agents.BaseChatAgent` so that it can be used in a - group chat team. - - Args: - parent_topic_type (str): The topic type of the parent orchestrator. - agent (BaseChatAgent): The agent to delegate message handling to. - tool_agent_type (AgentType, optional): The agent type of the tool agent. Defaults to None. - """ - - def __init__(self, parent_topic_type: str, agent: ChatAgent, tool_agent_type: AgentType | None = None) -> None: - super().__init__(description=agent.description) - self._parent_topic_type = parent_topic_type - self._agent = agent - self._message_buffer: List[TextMessage | MultiModalMessage | StopMessage] = [] - self._tool_agent_id = AgentId(type=tool_agent_type, key=self.id.key) if tool_agent_type else None - - @event - async def handle_content_publish(self, message: ContentPublishEvent, ctx: MessageContext) -> None: - """Handle a content publish event by appending the content to the buffer.""" - if not isinstance(message.agent_message, TextMessage | MultiModalMessage | StopMessage): - raise ValueError( - f"Unexpected message type: {type(message.agent_message)}. " - "The message must be a text, multimodal, or stop message." - ) - self._message_buffer.append(message.agent_message) - - @event - async def handle_content_request(self, message: ContentRequestEvent, ctx: MessageContext) -> None: - """Handle a content request event by passing the messages in the buffer - to the delegate agent and publish the response.""" - response = await self._agent.on_messages(self._message_buffer, ctx.cancellation_token) - - if self._tool_agent_id is not None: - # Handle tool calls. - while isinstance(response, ToolCallMessage): - # Log the tool call. - event_logger.debug(ToolCallEvent(agent_message=response, source=self.id)) - - results: List[FunctionExecutionResult | BaseException] = await asyncio.gather( - *[ - self.send_message( - message=call, - recipient=self._tool_agent_id, - cancellation_token=ctx.cancellation_token, - ) - for call in response.content - ] - ) - # Combine the results in to a single response and handle exceptions. - function_results: List[FunctionExecutionResult] = [] - for result in results: - if isinstance(result, FunctionExecutionResult): - function_results.append(result) - elif isinstance(result, ToolException): - function_results.append( - FunctionExecutionResult(content=f"Error: {result}", call_id=result.call_id) - ) - elif isinstance(result, BaseException): - raise result # Unexpected exception. - # Create a new tool call result message. - feedback = ToolCallResultMessage(content=function_results, source=self._tool_agent_id.type) - # Log the feedback. - event_logger.debug(ToolCallResultEvent(agent_message=feedback, source=self._tool_agent_id)) - response = await self._agent.on_messages([feedback], ctx.cancellation_token) - - # Publish the response. - assert isinstance(response, TextMessage | MultiModalMessage | StopMessage) - self._message_buffer.clear() - await self.publish_message( - ContentPublishEvent(agent_message=response, source=self.id), - topic_id=DefaultTopicId(type=self._parent_topic_type), - ) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py index c599d269b..6ee79d4de 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py @@ -13,14 +13,12 @@ from autogen_core.base import ( TopicId, ) from autogen_core.components import ClosureAgent, TypeSubscription -from autogen_core.components.tool_agent import ToolAgent -from autogen_core.components.tools import Tool -from ...base import ChatAgent, TaskResult, Team, TerminationCondition, ToolUseChatAgent +from ...base import ChatAgent, TaskResult, Team, TerminationCondition from ...messages import ChatMessage, TextMessage -from .._events import ContentPublishEvent, ContentRequestEvent -from ._base_chat_agent_container import BaseChatAgentContainer +from .._events import GroupChatPublishEvent, GroupChatRequestPublishEvent from ._base_group_chat_manager import BaseGroupChatManager +from ._chat_agent_container import ChatAgentContainer class BaseGroupChat(Team, ABC): @@ -35,11 +33,6 @@ class BaseGroupChat(Team, ABC): raise ValueError("At least one participant is required.") if len(participants) != len(set(participant.name for participant in participants)): raise ValueError("The participant names must be unique.") - for participant in participants: - if isinstance(participant, ToolUseChatAgent) and not participant.registered_tools: - raise ValueError( - f"Participant '{participant.name}' is a tool use agent so it must have registered tools." - ) self._participants = participants self._team_id = str(uuid.uuid4()) self._base_group_chat_manager_class = group_chat_manager_class @@ -55,27 +48,19 @@ class BaseGroupChat(Team, ABC): ) -> Callable[[], BaseGroupChatManager]: ... def _create_participant_factory( - self, parent_topic_type: str, agent: ChatAgent, tool_agent_type: AgentType | None - ) -> Callable[[], BaseChatAgentContainer]: - def _factory() -> BaseChatAgentContainer: + self, + parent_topic_type: str, + agent: ChatAgent, + ) -> Callable[[], ChatAgentContainer]: + def _factory() -> ChatAgentContainer: id = AgentInstantiationContext.current_agent_id() assert id == AgentId(type=agent.name, key=self._team_id) - container = BaseChatAgentContainer(parent_topic_type, agent, tool_agent_type) + container = ChatAgentContainer(parent_topic_type, agent) assert container.id == id return container return _factory - def _create_tool_agent_factory( - self, - caller_name: str, - tools: List[Tool], - ) -> Callable[[], ToolAgent]: - def _factory() -> ToolAgent: - return ToolAgent(f"Tool agent for {caller_name}", tools) - - return _factory - async def run( self, task: str, @@ -99,27 +84,14 @@ class BaseGroupChat(Team, ABC): participant_topic_types: List[str] = [] participant_descriptions: List[str] = [] for participant in self._participants: - if isinstance(participant, ToolUseChatAgent): - assert participant.registered_tools is not None and len(participant.registered_tools) > 0 - # Register the tool agent. - tool_agent_type = await ToolAgent.register( - runtime, - f"tool_agent_for_{participant.name}", - self._create_tool_agent_factory(participant.name, participant.registered_tools), - ) - # No subscriptions are needed for the tool agent, which will be called via direct messages. - else: - # No tool agent is needed. - tool_agent_type = None - # Use the participant name as the agent type and topic type. agent_type = participant.name topic_type = participant.name # Register the participant factory. - await BaseChatAgentContainer.register( + await ChatAgentContainer.register( runtime, type=agent_type, - factory=self._create_participant_factory(group_topic_type, participant, tool_agent_type), + factory=self._create_participant_factory(group_topic_type, participant), ) # Add subscriptions for the participant. await runtime.add_subscription(TypeSubscription(topic_type=topic_type, agent_type=agent_type)) @@ -154,7 +126,10 @@ class BaseGroupChat(Team, ABC): group_chat_messages: List[ChatMessage] = [] async def collect_group_chat_messages( - _runtime: AgentRuntime, id: AgentId, message: ContentPublishEvent, ctx: MessageContext + _runtime: AgentRuntime, + id: AgentId, + message: GroupChatPublishEvent, + ctx: MessageContext, ) -> None: group_chat_messages.append(message.agent_message) @@ -174,10 +149,10 @@ class BaseGroupChat(Team, ABC): team_topic_id = TopicId(type=team_topic_type, source=self._team_id) group_chat_manager_topic_id = TopicId(type=group_chat_manager_topic_type, source=self._team_id) await runtime.publish_message( - ContentPublishEvent(agent_message=TextMessage(content=task, source="user")), + GroupChatPublishEvent(agent_message=TextMessage(content=task, source="user")), topic_id=team_topic_id, ) - await runtime.publish_message(ContentRequestEvent(), topic_id=group_chat_manager_topic_id) + await runtime.publish_message(GroupChatRequestPublishEvent(), topic_id=group_chat_manager_topic_id) # Wait for the runtime to stop. await runtime.stop_when_idle() diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py index 5f59e9e63..68eb76c06 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat_manager.py @@ -1,13 +1,17 @@ import logging from abc import ABC, abstractmethod -from typing import List +from typing import Any, List -from autogen_core.base import MessageContext, TopicId -from autogen_core.components import event +from autogen_core.base import MessageContext +from autogen_core.components import DefaultTopicId, event from ... import EVENT_LOGGER_NAME from ...base import TerminationCondition -from .._events import ContentPublishEvent, ContentRequestEvent, TerminationEvent +from .._events import ( + GroupChatPublishEvent, + GroupChatRequestPublishEvent, + TerminationEvent, +) from ._sequential_routed_agent import SequentialRoutedAgent event_logger = logging.getLogger(EVENT_LOGGER_NAME) @@ -33,6 +37,10 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC): Raises: ValueError: If the number of participant topic types, agent types, and descriptions are not the same. + ValueError: If the participant topic types are not unique. + ValueError: If the group topic type is in the participant topic types. + ValueError: If the parent topic type is in the participant topic types. + ValueError: If the group topic type is the same as the parent topic type. """ def __init__( @@ -58,11 +66,11 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC): raise ValueError("The group topic type must not be the same as the parent topic type.") self._participant_topic_types = participant_topic_types self._participant_descriptions = participant_descriptions - self._message_thread: List[ContentPublishEvent] = [] + self._message_thread: List[GroupChatPublishEvent] = [] self._termination_condition = termination_condition @event - async def handle_content_publish(self, message: ContentPublishEvent, ctx: MessageContext) -> None: + async def handle_content_publish(self, message: GroupChatPublishEvent, ctx: MessageContext) -> None: """Handle a content publish event. If the event is from the parent topic, add the message to the thread. @@ -70,16 +78,25 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC): If the event is from the group chat topic, add the message to the thread and select a speaker to continue the conversation. If the event from the group chat session requests a pause, publish the last message to the parent topic.""" assert ctx.topic_id is not None - group_chat_topic_id = TopicId(type=self._group_topic_type, source=ctx.topic_id.source) event_logger.info(message) + if self._termination_condition is not None and self._termination_condition.terminated: + # The group chat has been terminated. + return + # Process event from parent. if ctx.topic_id.type == self._parent_topic_type: self._message_thread.append(message) await self.publish_message( - ContentPublishEvent(agent_message=message.agent_message, source=self.id), topic_id=group_chat_topic_id + GroupChatPublishEvent(agent_message=message.agent_message, source=self.id), + topic_id=DefaultTopicId(type=self._group_topic_type), ) + if self._termination_condition is not None: + stop_message = await self._termination_condition([message.agent_message]) + if stop_message is not None: + event_logger.info(TerminationEvent(agent_message=stop_message, source=self.id)) + # Stop the group chat. return # Process event from the group chat this agent manages. @@ -91,8 +108,6 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC): stop_message = await self._termination_condition([message.agent_message]) if stop_message is not None: event_logger.info(TerminationEvent(agent_message=stop_message, source=self.id)) - # Reset the termination condition. - await self._termination_condition.reset() # Stop the group chat. # TODO: this should be different if the group chat is nested. return @@ -100,24 +115,28 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC): # Select a speaker to continue the conversation. speaker_topic_type = await self.select_speaker(self._message_thread) - participant_topic_id = TopicId(type=speaker_topic_type, source=ctx.topic_id.source) - group_chat_topic_id = TopicId(type=self._group_topic_type, source=ctx.topic_id.source) - await self.publish_message(ContentRequestEvent(), topic_id=participant_topic_id) + await self.publish_message(GroupChatRequestPublishEvent(), topic_id=DefaultTopicId(type=speaker_topic_type)) @event - async def handle_content_request(self, message: ContentRequestEvent, ctx: MessageContext) -> None: + async def handle_content_request(self, message: GroupChatRequestPublishEvent, ctx: MessageContext) -> None: """Handle a content request by selecting a speaker to start the conversation.""" assert ctx.topic_id is not None if ctx.topic_id.type == self._group_topic_type: raise RuntimeError("Content request event from the group chat topic is not allowed.") + if self._termination_condition is not None and self._termination_condition.terminated: + # The group chat has been terminated. + return + speaker_topic_type = await self.select_speaker(self._message_thread) - participant_topic_id = TopicId(type=speaker_topic_type, source=ctx.topic_id.source) - await self.publish_message(ContentRequestEvent(), topic_id=participant_topic_id) + await self.publish_message(GroupChatRequestPublishEvent(), topic_id=DefaultTopicId(type=speaker_topic_type)) @abstractmethod - async def select_speaker(self, thread: List[ContentPublishEvent]) -> str: + async def select_speaker(self, thread: List[GroupChatPublishEvent]) -> str: """Select a speaker from the participants and return the topic type of the selected speaker.""" ... + + async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> None: + raise ValueError(f"Unhandled message in group chat manager: {type(message)}") diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_chat_agent_container.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_chat_agent_container.py new file mode 100644 index 000000000..acf5e9d2f --- /dev/null +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_chat_agent_container.py @@ -0,0 +1,48 @@ +from typing import Any, List + +from autogen_core.base import MessageContext +from autogen_core.components import DefaultTopicId, event + +from ...base import ChatAgent +from ...messages import ChatMessage +from .._events import GroupChatPublishEvent, GroupChatRequestPublishEvent +from ._sequential_routed_agent import SequentialRoutedAgent + + +class ChatAgentContainer(SequentialRoutedAgent): + """A core agent class that delegates message handling to an + :class:`autogen_agentchat.base.ChatAgent` so that it can be used in a + group chat team. + + Args: + parent_topic_type (str): The topic type of the parent orchestrator. + agent (ChatAgent): The agent to delegate message handling to. + """ + + def __init__(self, parent_topic_type: str, agent: ChatAgent) -> None: + super().__init__(description=agent.description) + self._parent_topic_type = parent_topic_type + self._agent = agent + self._message_buffer: List[ChatMessage] = [] + + @event + async def handle_message(self, message: GroupChatPublishEvent, ctx: MessageContext) -> None: + """Handle an event by appending the content to the buffer.""" + self._message_buffer.append(message.agent_message) + + @event + async def handle_content_request(self, message: GroupChatRequestPublishEvent, ctx: MessageContext) -> None: + """Handle a content request event by passing the messages in the buffer + to the delegate agent and publish the response.""" + # Pass the messages in the buffer to the delegate agent. + response = await self._agent.on_messages(self._message_buffer, ctx.cancellation_token) + + # Publish the response. + self._message_buffer.clear() + await self.publish_message( + GroupChatPublishEvent(agent_message=response, source=self.id), + topic_id=DefaultTopicId(type=self._parent_topic_type), + ) + + async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> None: + raise ValueError(f"Unhandled message in agent container: {type(message)}") diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_round_robin_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_round_robin_group_chat.py index 529314fec..fff872dd8 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_round_robin_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_round_robin_group_chat.py @@ -1,10 +1,17 @@ +import logging from typing import Callable, List +from ... import EVENT_LOGGER_NAME from ...base import ChatAgent, TerminationCondition -from .._events import ContentPublishEvent +from .._events import ( + GroupChatPublishEvent, + GroupChatSelectSpeakerEvent, +) from ._base_group_chat import BaseGroupChat from ._base_group_chat_manager import BaseGroupChatManager +event_logger = logging.getLogger(EVENT_LOGGER_NAME) + class RoundRobinGroupChatManager(BaseGroupChatManager): """A group chat manager that selects the next speaker in a round-robin fashion.""" @@ -26,11 +33,13 @@ class RoundRobinGroupChatManager(BaseGroupChatManager): ) self._next_speaker_index = 0 - async def select_speaker(self, thread: List[ContentPublishEvent]) -> str: + async def select_speaker(self, thread: List[GroupChatPublishEvent]) -> str: """Select a speaker from the participants in a round-robin fashion.""" current_speaker_index = self._next_speaker_index self._next_speaker_index = (current_speaker_index + 1) % len(self._participant_topic_types) - return self._participant_topic_types[current_speaker_index] + current_speaker = self._participant_topic_types[current_speaker_index] + event_logger.debug(GroupChatSelectSpeakerEvent(selected_speaker=current_speaker, source=self.id)) + return current_speaker class RoundRobinGroupChat(BaseGroupChat): diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py index 1eaf0ddd5..79f8b60de 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_selector_group_chat.py @@ -7,7 +7,10 @@ from autogen_core.components.models import ChatCompletionClient, SystemMessage from ... import EVENT_LOGGER_NAME, TRACE_LOGGER_NAME from ...base import ChatAgent, TerminationCondition from ...messages import MultiModalMessage, StopMessage, TextMessage -from .._events import ContentPublishEvent, SelectSpeakerEvent +from .._events import ( + GroupChatPublishEvent, + GroupChatSelectSpeakerEvent, +) from ._base_group_chat import BaseGroupChat from ._base_group_chat_manager import BaseGroupChatManager @@ -42,7 +45,7 @@ class SelectorGroupChatManager(BaseGroupChatManager): self._previous_speaker: str | None = None self._allow_repeated_speaker = allow_repeated_speaker - async def select_speaker(self, thread: List[ContentPublishEvent]) -> str: + async def select_speaker(self, thread: List[GroupChatPublishEvent]) -> str: """Selects the next speaker in a group chat using a ChatCompletion client. A key assumption is that the agent type is the same as the topic type, which we use as the agent name. @@ -107,7 +110,7 @@ class SelectorGroupChatManager(BaseGroupChatManager): else: agent_name = participants[0] self._previous_speaker = agent_name - event_logger.debug(SelectSpeakerEvent(selected_speaker=agent_name, source=self.id)) + event_logger.debug(GroupChatSelectSpeakerEvent(selected_speaker=agent_name, source=self.id)) return agent_name def _mentioned_agents(self, message_content: str, agent_names: List[str]) -> Dict[str, int]: @@ -148,7 +151,7 @@ class SelectorGroupChat(BaseGroupChat): to all, using a ChatCompletion model to select the next speaker after each message. Args: - participants (List[BaseChatAgent]): The participants in the group chat, + participants (List[ChatAgent]): The participants in the group chat, must have unique names and at least two participants. model_client (ChatCompletionClient): The ChatCompletion model client used to select the next speaker. diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_swarm_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_swarm_group_chat.py new file mode 100644 index 000000000..4f2d08afc --- /dev/null +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_swarm_group_chat.py @@ -0,0 +1,72 @@ +import logging +from typing import Callable, List + +from ... import EVENT_LOGGER_NAME +from ...base import ChatAgent, TerminationCondition +from ...messages import HandoffMessage +from .._events import ( + GroupChatPublishEvent, + GroupChatSelectSpeakerEvent, +) +from ._base_group_chat import BaseGroupChat +from ._base_group_chat_manager import BaseGroupChatManager + +event_logger = logging.getLogger(EVENT_LOGGER_NAME) + + +class SwarmGroupChatManager(BaseGroupChatManager): + """A group chat manager that selects the next speaker based on handoff message only.""" + + def __init__( + self, + parent_topic_type: str, + group_topic_type: str, + participant_topic_types: List[str], + participant_descriptions: List[str], + termination_condition: TerminationCondition | None, + ) -> None: + super().__init__( + parent_topic_type, + group_topic_type, + participant_topic_types, + participant_descriptions, + termination_condition, + ) + self._current_speaker = participant_topic_types[0] + + async def select_speaker(self, thread: List[GroupChatPublishEvent]) -> str: + """Select a speaker from the participants based on handoff message.""" + if len(thread) > 0 and isinstance(thread[-1].agent_message, HandoffMessage): + self._current_speaker = thread[-1].agent_message.content + if self._current_speaker not in self._participant_topic_types: + raise ValueError("The selected speaker in the handoff message is not a participant.") + event_logger.debug(GroupChatSelectSpeakerEvent(selected_speaker=self._current_speaker, source=self.id)) + return self._current_speaker + else: + return self._current_speaker + + +class Swarm(BaseGroupChat): + """(Experimental) A group chat that selects the next speaker based on handoff message only.""" + + def __init__(self, participants: List[ChatAgent]): + super().__init__(participants, group_chat_manager_class=SwarmGroupChatManager) + + def _create_group_chat_manager_factory( + self, + parent_topic_type: str, + group_topic_type: str, + participant_topic_types: List[str], + participant_descriptions: List[str], + termination_condition: TerminationCondition | None, + ) -> Callable[[], SwarmGroupChatManager]: + def _factory() -> SwarmGroupChatManager: + return SwarmGroupChatManager( + parent_topic_type, + group_topic_type, + participant_topic_types, + participant_descriptions, + termination_condition, + ) + + return _factory diff --git a/python/packages/autogen-agentchat/tests/test_group_chat.py b/python/packages/autogen-agentchat/tests/test_group_chat.py index 97cf65c2d..9f740eb64 100644 --- a/python/packages/autogen-agentchat/tests/test_group_chat.py +++ b/python/packages/autogen-agentchat/tests/test_group_chat.py @@ -13,11 +13,17 @@ from autogen_agentchat.agents import ( ToolUseAssistantAgent, ) from autogen_agentchat.logging import FileLogHandler -from autogen_agentchat.messages import ChatMessage, StopMessage, TextMessage -from autogen_agentchat.task import StopMessageTermination +from autogen_agentchat.messages import ( + ChatMessage, + HandoffMessage, + StopMessage, + TextMessage, +) +from autogen_agentchat.task import MaxMessageTermination, StopMessageTermination from autogen_agentchat.teams import ( RoundRobinGroupChat, SelectorGroupChat, + Swarm, ) from autogen_core.base import CancellationToken from autogen_core.components import FunctionCall @@ -212,7 +218,16 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch ) echo_agent = _EchoAgent("echo_agent", description="echo agent") team = RoundRobinGroupChat(participants=[tool_use_agent, echo_agent]) - await team.run("Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination()) + result = await team.run( + "Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination() + ) + + assert len(result.messages) == 4 + assert isinstance(result.messages[0], TextMessage) # task + assert isinstance(result.messages[1], TextMessage) # tool use agent response + assert isinstance(result.messages[2], TextMessage) # echo agent response + assert isinstance(result.messages[3], StopMessage) # tool use agent response + context = tool_use_agent._model_context # pyright: ignore assert context[0].content == "Write a program that prints 'Hello, world!'" assert isinstance(context[1].content, list) @@ -393,3 +408,29 @@ async def test_selector_group_chat_two_speakers_allow_repeated(monkeypatch: pyte assert result.messages[1].source == "agent2" assert result.messages[2].source == "agent2" assert result.messages[3].source == "agent1" + + +class _HandOffAgent(BaseChatAgent): + def __init__(self, name: str, description: str, next_agent: str) -> None: + super().__init__(name, description) + self._next_agent = next_agent + + async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage: + return HandoffMessage(content=self._next_agent, source=self.name) + + +@pytest.mark.asyncio +async def test_swarm() -> None: + first_agent = _HandOffAgent("first_agent", description="first agent", next_agent="second_agent") + second_agent = _HandOffAgent("second_agent", description="second agent", next_agent="third_agent") + third_agent = _HandOffAgent("third_agent", description="third agent", next_agent="first_agent") + + team = Swarm([second_agent, first_agent, third_agent]) + result = await team.run("task", termination_condition=MaxMessageTermination(6)) + assert len(result.messages) == 6 + assert result.messages[0].content == "task" + assert result.messages[1].content == "third_agent" + assert result.messages[2].content == "first_agent" + assert result.messages[3].content == "second_agent" + assert result.messages[4].content == "third_agent" + assert result.messages[5].content == "first_agent" diff --git a/python/packages/autogen-agentchat/tests/test_tool_use_assistant_agent.py b/python/packages/autogen-agentchat/tests/test_tool_use_assistant_agent.py index 3a1734b3f..d5ec31a12 100644 --- a/python/packages/autogen-agentchat/tests/test_tool_use_assistant_agent.py +++ b/python/packages/autogen-agentchat/tests/test_tool_use_assistant_agent.py @@ -4,6 +4,7 @@ from typing import Any, AsyncGenerator, List import pytest from autogen_agentchat.agents import ToolUseAssistantAgent +from autogen_agentchat.messages import StopMessage, TextMessage from autogen_core.components.models import OpenAIChatCompletionClient from autogen_core.components.tools import FunctionTool from openai.resources.chat.completions import AsyncCompletions @@ -103,5 +104,6 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch ) result = await tool_use_agent.run("task") assert len(result.messages) == 3 - # assert isinstance(result.messages[1], ToolCallMessage) - # assert isinstance(result.messages[2], TextMessage) + assert isinstance(result.messages[0], TextMessage) + assert isinstance(result.messages[1], TextMessage) + assert isinstance(result.messages[2], StopMessage)