Refactor agent chat to prepare for handoff/swarm (#3949)

Add handoff message type to chat message types
Add Swarm group chat that uses handoff message to select next speaker
Remove tool call and tool call result message types from chat message types
Remove BaseToolUseChatAgent, move tool call handling from group chat's chat agent container upward to the ToolUseAssistantAgent implementation, which subclasses BaseChatAgent directly.
Renaming for better clarity

---------

Co-authored-by: Victor Dibia <victordibia@microsoft.com>
This commit is contained in:
Eric Zhu 2024-10-25 10:57:04 -07:00 committed by GitHub
parent 0756ebd63d
commit f31ff66368
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
19 changed files with 363 additions and 293 deletions

View File

@ -1,11 +1,10 @@
from ._base_chat_agent import BaseChatAgent, BaseToolUseChatAgent
from ._base_chat_agent import BaseChatAgent
from ._code_executor_agent import CodeExecutorAgent
from ._coding_assistant_agent import CodingAssistantAgent
from ._tool_use_assistant_agent import ToolUseAssistantAgent
__all__ = [
"BaseChatAgent",
"BaseToolUseChatAgent",
"CodeExecutorAgent",
"CodingAssistantAgent",
"ToolUseAssistantAgent",

View File

@ -1,10 +1,9 @@
from abc import ABC, abstractmethod
from typing import List, Sequence
from typing import Sequence
from autogen_core.base import CancellationToken
from autogen_core.components.tools import Tool
from ..base import ChatAgent, TaskResult, TerminationCondition, ToolUseChatAgent
from ..base import ChatAgent, TaskResult, TerminationCondition
from ..messages import ChatMessage
from ..teams import RoundRobinGroupChat
@ -51,21 +50,3 @@ class BaseChatAgent(ChatAgent, ABC):
termination_condition=termination_condition,
)
return result
class BaseToolUseChatAgent(BaseChatAgent, ToolUseChatAgent):
"""Base class for a chat agent that can use tools.
Subclass this base class to create an agent class that uses tools by returning
ToolCallMessage message from the :meth:`on_messages` method and receiving
ToolCallResultMessage message from the input to the :meth:`on_messages` method.
"""
def __init__(self, name: str, description: str, registered_tools: List[Tool]) -> None:
super().__init__(name, description)
self._registered_tools = registered_tools
@property
def registered_tools(self) -> List[Tool]:
"""The list of tools that the agent can use."""
return self._registered_tools

View File

@ -1,3 +1,6 @@
import asyncio
import json
import logging
from typing import Any, Awaitable, Callable, List, Sequence
from autogen_core.base import CancellationToken
@ -5,25 +8,45 @@ from autogen_core.components import FunctionCall
from autogen_core.components.models import (
AssistantMessage,
ChatCompletionClient,
FunctionExecutionResult,
FunctionExecutionResultMessage,
LLMMessage,
SystemMessage,
UserMessage,
)
from autogen_core.components.tools import FunctionTool, Tool
from pydantic import BaseModel, ConfigDict
from .. import EVENT_LOGGER_NAME
from ..messages import (
ChatMessage,
MultiModalMessage,
StopMessage,
TextMessage,
ToolCallMessage,
ToolCallResultMessage,
)
from ._base_chat_agent import BaseToolUseChatAgent
from ._base_chat_agent import BaseChatAgent
event_logger = logging.getLogger(EVENT_LOGGER_NAME)
class ToolUseAssistantAgent(BaseToolUseChatAgent):
class ToolCallEvent(BaseModel):
"""A tool call event."""
tool_calls: List[FunctionCall]
"""The tool call message."""
model_config = ConfigDict(arbitrary_types_allowed=True)
class ToolCallResultEvent(BaseModel):
"""A tool call result event."""
tool_call_results: List[FunctionExecutionResult]
"""The tool call result message."""
model_config = ConfigDict(arbitrary_types_allowed=True)
class ToolUseAssistantAgent(BaseChatAgent):
"""An agent that provides assistance with tool use.
It responds with a StopMessage when 'terminate' is detected in the response.
@ -45,46 +68,50 @@ class ToolUseAssistantAgent(BaseToolUseChatAgent):
description: str = "An agent that provides assistance with ability to use tools.",
system_message: str = "You are a helpful AI assistant. Solve tasks using your tools. Reply with 'TERMINATE' when the task has been completed.",
):
tools: List[Tool] = []
super().__init__(name=name, description=description)
self._model_client = model_client
self._system_messages = [SystemMessage(content=system_message)]
self._tools: List[Tool] = []
for tool in registered_tools:
if isinstance(tool, Tool):
tools.append(tool)
self._tools.append(tool)
elif callable(tool):
if hasattr(tool, "__doc__") and tool.__doc__ is not None:
description = tool.__doc__
else:
description = ""
tools.append(FunctionTool(tool, description=description))
self._tools.append(FunctionTool(tool, description=description))
else:
raise ValueError(f"Unsupported tool type: {type(tool)}")
super().__init__(name=name, description=description, registered_tools=tools)
self._model_client = model_client
self._system_messages = [SystemMessage(content=system_message)]
self._tool_schema = [tool.schema for tool in tools]
self._model_context: List[LLMMessage] = []
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
# Add messages to the model context.
for msg in messages:
if isinstance(msg, ToolCallResultMessage):
self._model_context.append(FunctionExecutionResultMessage(content=msg.content))
elif not isinstance(msg, TextMessage | MultiModalMessage | StopMessage):
raise ValueError(f"Unsupported message type: {type(msg)}")
else:
self._model_context.append(UserMessage(content=msg.content, source=msg.source))
# TODO: add special handling for handoff messages
self._model_context.append(UserMessage(content=msg.content, source=msg.source))
# Generate an inference result based on the current model context.
llm_messages = self._system_messages + self._model_context
result = await self._model_client.create(
llm_messages, tools=self._tool_schema, cancellation_token=cancellation_token
)
result = await self._model_client.create(llm_messages, tools=self._tools, cancellation_token=cancellation_token)
# Add the response to the model context.
self._model_context.append(AssistantMessage(content=result.content, source=self.name))
# Detect tool calls.
if isinstance(result.content, list) and all(isinstance(item, FunctionCall) for item in result.content):
return ToolCallMessage(content=result.content, source=self.name)
# Run tool calls until the model produces a string response.
while isinstance(result.content, list) and all(isinstance(item, FunctionCall) for item in result.content):
event_logger.debug(ToolCallEvent(tool_calls=result.content))
# Execute the tool calls.
results = await asyncio.gather(
*[self._execute_tool_call(call, cancellation_token) for call in result.content]
)
event_logger.debug(ToolCallResultEvent(tool_call_results=results))
self._model_context.append(FunctionExecutionResultMessage(content=results))
# Generate an inference result based on the current model context.
result = await self._model_client.create(
self._model_context, tools=self._tools, cancellation_token=cancellation_token
)
self._model_context.append(AssistantMessage(content=result.content, source=self.name))
assert isinstance(result.content, str)
# Detect stop request.
@ -93,3 +120,20 @@ class ToolUseAssistantAgent(BaseToolUseChatAgent):
return StopMessage(content=result.content, source=self.name)
return TextMessage(content=result.content, source=self.name)
async def _execute_tool_call(
self, tool_call: FunctionCall, cancellation_token: CancellationToken
) -> FunctionExecutionResult:
"""Execute a tool call and return the result."""
try:
if not self._tools:
raise ValueError("No tools are available.")
tool = next((t for t in self._tools if t.name == tool_call.name), None)
if tool is None:
raise ValueError(f"The tool '{tool_call.name}' is not available.")
arguments = json.loads(tool_call.arguments)
result = await tool.run_json(arguments, cancellation_token)
result_as_str = tool.return_value_as_string(result)
return FunctionExecutionResult(content=result_as_str, call_id=tool_call.id)
except Exception as e:
return FunctionExecutionResult(content=f"Error: {e}", call_id=tool_call.id)

View File

@ -1,11 +1,10 @@
from ._chat_agent import ChatAgent, ToolUseChatAgent
from ._chat_agent import ChatAgent
from ._task import TaskResult, TaskRunner
from ._team import Team
from ._termination import TerminatedException, TerminationCondition
__all__ = [
"ChatAgent",
"ToolUseChatAgent",
"Team",
"TerminatedException",
"TerminationCondition",

View File

@ -1,7 +1,6 @@
from typing import List, Protocol, Sequence, runtime_checkable
from typing import Protocol, Sequence, runtime_checkable
from autogen_core.base import CancellationToken
from autogen_core.components.tools import Tool
from ..messages import ChatMessage
from ._task import TaskResult, TaskRunner
@ -38,13 +37,3 @@ class ChatAgent(TaskRunner, Protocol):
) -> TaskResult:
"""Run the agent with the given task and return the result."""
...
@runtime_checkable
class ToolUseChatAgent(ChatAgent, Protocol):
"""Protocol for a chat agent that can use tools."""
@property
def registered_tools(self) -> List[Tool]:
"""The list of tools that the agent can use."""
...

View File

@ -3,13 +3,12 @@ import logging
import sys
from datetime import datetime
from ..agents._tool_use_assistant_agent import ToolCallEvent, ToolCallResultEvent
from ..messages import ChatMessage, StopMessage, TextMessage
from ..teams._events import (
ContentPublishEvent,
SelectSpeakerEvent,
GroupChatPublishEvent,
GroupChatSelectSpeakerEvent,
TerminationEvent,
ToolCallEvent,
ToolCallResultEvent,
)
@ -25,7 +24,7 @@ class ConsoleLogHandler(logging.Handler):
def emit(self, record: logging.LogRecord) -> None:
ts = datetime.fromtimestamp(record.created).isoformat()
if isinstance(record.msg, ContentPublishEvent):
if isinstance(record.msg, GroupChatPublishEvent):
if record.msg.source is None:
sys.stdout.write(
f"\n{'-'*75} \n"
@ -41,19 +40,15 @@ class ConsoleLogHandler(logging.Handler):
sys.stdout.flush()
elif isinstance(record.msg, ToolCallEvent):
sys.stdout.write(
f"\n{'-'*75} \n"
f"\033[91m[{ts}], Tool Call:\033[0m\n"
f"\n{self.serialize_chat_message(record.msg.agent_message)}"
f"\n{'-'*75} \n" f"\033[91m[{ts}], Tool Call:\033[0m\n" f"\n{str(record.msg.model_dump())}"
)
sys.stdout.flush()
elif isinstance(record.msg, ToolCallResultEvent):
sys.stdout.write(
f"\n{'-'*75} \n"
f"\033[91m[{ts}], Tool Call Result:\033[0m\n"
f"\n{self.serialize_chat_message(record.msg.agent_message)}"
f"\n{'-'*75} \n" f"\033[91m[{ts}], Tool Call Result:\033[0m\n" f"\n{str(record.msg.model_dump())}"
)
sys.stdout.flush()
elif isinstance(record.msg, SelectSpeakerEvent):
elif isinstance(record.msg, GroupChatSelectSpeakerEvent):
sys.stdout.write(
f"\n{'-'*75} \n" f"\033[91m[{ts}], Selected Next Speaker:\033[0m\n" f"\n{record.msg.selected_speaker}"
)

View File

@ -4,12 +4,11 @@ from dataclasses import asdict, is_dataclass
from datetime import datetime
from typing import Any
from ..agents._tool_use_assistant_agent import ToolCallEvent, ToolCallResultEvent
from ..teams._events import (
ContentPublishEvent,
SelectSpeakerEvent,
GroupChatPublishEvent,
GroupChatSelectSpeakerEvent,
TerminationEvent,
ToolCallEvent,
ToolCallResultEvent,
)
@ -21,7 +20,7 @@ class FileLogHandler(logging.Handler):
def emit(self, record: logging.LogRecord) -> None:
ts = datetime.fromtimestamp(record.created).isoformat()
if isinstance(record.msg, ContentPublishEvent | ToolCallEvent | ToolCallResultEvent | TerminationEvent):
if isinstance(record.msg, GroupChatPublishEvent | TerminationEvent):
log_entry = json.dumps(
{
"timestamp": ts,
@ -31,7 +30,7 @@ class FileLogHandler(logging.Handler):
},
default=self.json_serializer,
)
elif isinstance(record.msg, SelectSpeakerEvent):
elif isinstance(record.msg, GroupChatSelectSpeakerEvent):
log_entry = json.dumps(
{
"timestamp": ts,
@ -41,6 +40,24 @@ class FileLogHandler(logging.Handler):
},
default=self.json_serializer,
)
elif isinstance(record.msg, ToolCallEvent):
log_entry = json.dumps(
{
"timestamp": ts,
"tool_calls": record.msg.model_dump(),
"type": "ToolCallEvent",
},
default=self.json_serializer,
)
elif isinstance(record.msg, ToolCallResultEvent):
log_entry = json.dumps(
{
"timestamp": ts,
"tool_call_results": record.msg.model_dump(),
"type": "ToolCallResultEvent",
},
default=self.json_serializer,
)
else:
raise ValueError(f"Unexpected log record: {record.msg}")
file_record = logging.LogRecord(

View File

@ -1,7 +1,6 @@
from typing import List
from autogen_core.components import FunctionCall, Image
from autogen_core.components.models import FunctionExecutionResult
from autogen_core.components import Image
from pydantic import BaseModel
@ -26,20 +25,6 @@ class MultiModalMessage(BaseMessage):
"""The content of the message."""
class ToolCallMessage(BaseMessage):
"""A message containing a list of function calls."""
content: List[FunctionCall]
"""The list of function calls."""
class ToolCallResultMessage(BaseMessage):
"""A message containing the results of function calls."""
content: List[FunctionExecutionResult]
"""The list of function execution results."""
class StopMessage(BaseMessage):
"""A message requesting stop of a conversation."""
@ -47,7 +32,14 @@ class StopMessage(BaseMessage):
"""The content for the stop message."""
ChatMessage = TextMessage | MultiModalMessage | StopMessage | ToolCallMessage | ToolCallResultMessage
class HandoffMessage(BaseMessage):
"""A message requesting handoff of a conversation to another agent."""
content: str
"""The agent name to handoff the conversation to."""
ChatMessage = TextMessage | MultiModalMessage | StopMessage | HandoffMessage
"""A message used by agents in a team."""
@ -55,8 +47,7 @@ __all__ = [
"BaseMessage",
"TextMessage",
"MultiModalMessage",
"ToolCallMessage",
"ToolCallResultMessage",
"StopMessage",
"HandoffMessage",
"ChatMessage",
]

View File

@ -1,7 +1,9 @@
from ._group_chat._round_robin_group_chat import RoundRobinGroupChat
from ._group_chat._selector_group_chat import SelectorGroupChat
from ._group_chat._swarm_group_chat import Swarm
__all__ = [
"RoundRobinGroupChat",
"SelectorGroupChat",
"Swarm",
]

View File

@ -1,16 +1,16 @@
from autogen_core.base import AgentId
from pydantic import BaseModel, ConfigDict
from ..messages import MultiModalMessage, StopMessage, TextMessage, ToolCallMessage, ToolCallResultMessage
from ..messages import ChatMessage, StopMessage
class ContentPublishEvent(BaseModel):
"""An event for sharing some data. Agents receive this event should
class GroupChatPublishEvent(BaseModel):
"""An group chat event for sharing some data. Agents receive this event should
update their internal state (e.g., append to message history) with the
content of the event.
"""
agent_message: TextMessage | MultiModalMessage | StopMessage
agent_message: ChatMessage
"""The message published by the agent."""
source: AgentId | None = None
@ -19,39 +19,15 @@ class ContentPublishEvent(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
class ContentRequestEvent(BaseModel):
"""An event for requesting to publish a content event.
Upon receiving this event, the agent should publish a ContentPublishEvent.
class GroupChatRequestPublishEvent(BaseModel):
"""An event for requesting to publish a group chat publish event.
Upon receiving this event, the agent should publish a group chat publish event.
"""
...
class ToolCallEvent(BaseModel):
"""An event produced when requesting a tool call."""
agent_message: ToolCallMessage
"""The tool call message."""
source: AgentId
"""The sender of the tool call message."""
model_config = ConfigDict(arbitrary_types_allowed=True)
class ToolCallResultEvent(BaseModel):
"""An event produced when a tool call is completed."""
agent_message: ToolCallResultMessage
"""The tool call result message."""
source: AgentId
"""The sender of the tool call result message."""
model_config = ConfigDict(arbitrary_types_allowed=True)
class SelectSpeakerEvent(BaseModel):
class GroupChatSelectSpeakerEvent(BaseModel):
"""An event for selecting the next speaker in a group chat."""
selected_speaker: str

View File

@ -1,92 +0,0 @@
import asyncio
import logging
from typing import List
from autogen_core.base import AgentId, AgentType, MessageContext
from autogen_core.components import DefaultTopicId, event
from autogen_core.components.models import FunctionExecutionResult
from autogen_core.components.tool_agent import ToolException
from ... import EVENT_LOGGER_NAME
from ...base import ChatAgent
from ...messages import MultiModalMessage, StopMessage, TextMessage, ToolCallMessage, ToolCallResultMessage
from .._events import ContentPublishEvent, ContentRequestEvent, ToolCallEvent, ToolCallResultEvent
from ._sequential_routed_agent import SequentialRoutedAgent
event_logger = logging.getLogger(EVENT_LOGGER_NAME)
class BaseChatAgentContainer(SequentialRoutedAgent):
"""A core agent class that delegates message handling to an
:class:`autogen_agentchat.agents.BaseChatAgent` so that it can be used in a
group chat team.
Args:
parent_topic_type (str): The topic type of the parent orchestrator.
agent (BaseChatAgent): The agent to delegate message handling to.
tool_agent_type (AgentType, optional): The agent type of the tool agent. Defaults to None.
"""
def __init__(self, parent_topic_type: str, agent: ChatAgent, tool_agent_type: AgentType | None = None) -> None:
super().__init__(description=agent.description)
self._parent_topic_type = parent_topic_type
self._agent = agent
self._message_buffer: List[TextMessage | MultiModalMessage | StopMessage] = []
self._tool_agent_id = AgentId(type=tool_agent_type, key=self.id.key) if tool_agent_type else None
@event
async def handle_content_publish(self, message: ContentPublishEvent, ctx: MessageContext) -> None:
"""Handle a content publish event by appending the content to the buffer."""
if not isinstance(message.agent_message, TextMessage | MultiModalMessage | StopMessage):
raise ValueError(
f"Unexpected message type: {type(message.agent_message)}. "
"The message must be a text, multimodal, or stop message."
)
self._message_buffer.append(message.agent_message)
@event
async def handle_content_request(self, message: ContentRequestEvent, ctx: MessageContext) -> None:
"""Handle a content request event by passing the messages in the buffer
to the delegate agent and publish the response."""
response = await self._agent.on_messages(self._message_buffer, ctx.cancellation_token)
if self._tool_agent_id is not None:
# Handle tool calls.
while isinstance(response, ToolCallMessage):
# Log the tool call.
event_logger.debug(ToolCallEvent(agent_message=response, source=self.id))
results: List[FunctionExecutionResult | BaseException] = await asyncio.gather(
*[
self.send_message(
message=call,
recipient=self._tool_agent_id,
cancellation_token=ctx.cancellation_token,
)
for call in response.content
]
)
# Combine the results in to a single response and handle exceptions.
function_results: List[FunctionExecutionResult] = []
for result in results:
if isinstance(result, FunctionExecutionResult):
function_results.append(result)
elif isinstance(result, ToolException):
function_results.append(
FunctionExecutionResult(content=f"Error: {result}", call_id=result.call_id)
)
elif isinstance(result, BaseException):
raise result # Unexpected exception.
# Create a new tool call result message.
feedback = ToolCallResultMessage(content=function_results, source=self._tool_agent_id.type)
# Log the feedback.
event_logger.debug(ToolCallResultEvent(agent_message=feedback, source=self._tool_agent_id))
response = await self._agent.on_messages([feedback], ctx.cancellation_token)
# Publish the response.
assert isinstance(response, TextMessage | MultiModalMessage | StopMessage)
self._message_buffer.clear()
await self.publish_message(
ContentPublishEvent(agent_message=response, source=self.id),
topic_id=DefaultTopicId(type=self._parent_topic_type),
)

View File

@ -13,14 +13,12 @@ from autogen_core.base import (
TopicId,
)
from autogen_core.components import ClosureAgent, TypeSubscription
from autogen_core.components.tool_agent import ToolAgent
from autogen_core.components.tools import Tool
from ...base import ChatAgent, TaskResult, Team, TerminationCondition, ToolUseChatAgent
from ...base import ChatAgent, TaskResult, Team, TerminationCondition
from ...messages import ChatMessage, TextMessage
from .._events import ContentPublishEvent, ContentRequestEvent
from ._base_chat_agent_container import BaseChatAgentContainer
from .._events import GroupChatPublishEvent, GroupChatRequestPublishEvent
from ._base_group_chat_manager import BaseGroupChatManager
from ._chat_agent_container import ChatAgentContainer
class BaseGroupChat(Team, ABC):
@ -35,11 +33,6 @@ class BaseGroupChat(Team, ABC):
raise ValueError("At least one participant is required.")
if len(participants) != len(set(participant.name for participant in participants)):
raise ValueError("The participant names must be unique.")
for participant in participants:
if isinstance(participant, ToolUseChatAgent) and not participant.registered_tools:
raise ValueError(
f"Participant '{participant.name}' is a tool use agent so it must have registered tools."
)
self._participants = participants
self._team_id = str(uuid.uuid4())
self._base_group_chat_manager_class = group_chat_manager_class
@ -55,27 +48,19 @@ class BaseGroupChat(Team, ABC):
) -> Callable[[], BaseGroupChatManager]: ...
def _create_participant_factory(
self, parent_topic_type: str, agent: ChatAgent, tool_agent_type: AgentType | None
) -> Callable[[], BaseChatAgentContainer]:
def _factory() -> BaseChatAgentContainer:
self,
parent_topic_type: str,
agent: ChatAgent,
) -> Callable[[], ChatAgentContainer]:
def _factory() -> ChatAgentContainer:
id = AgentInstantiationContext.current_agent_id()
assert id == AgentId(type=agent.name, key=self._team_id)
container = BaseChatAgentContainer(parent_topic_type, agent, tool_agent_type)
container = ChatAgentContainer(parent_topic_type, agent)
assert container.id == id
return container
return _factory
def _create_tool_agent_factory(
self,
caller_name: str,
tools: List[Tool],
) -> Callable[[], ToolAgent]:
def _factory() -> ToolAgent:
return ToolAgent(f"Tool agent for {caller_name}", tools)
return _factory
async def run(
self,
task: str,
@ -99,27 +84,14 @@ class BaseGroupChat(Team, ABC):
participant_topic_types: List[str] = []
participant_descriptions: List[str] = []
for participant in self._participants:
if isinstance(participant, ToolUseChatAgent):
assert participant.registered_tools is not None and len(participant.registered_tools) > 0
# Register the tool agent.
tool_agent_type = await ToolAgent.register(
runtime,
f"tool_agent_for_{participant.name}",
self._create_tool_agent_factory(participant.name, participant.registered_tools),
)
# No subscriptions are needed for the tool agent, which will be called via direct messages.
else:
# No tool agent is needed.
tool_agent_type = None
# Use the participant name as the agent type and topic type.
agent_type = participant.name
topic_type = participant.name
# Register the participant factory.
await BaseChatAgentContainer.register(
await ChatAgentContainer.register(
runtime,
type=agent_type,
factory=self._create_participant_factory(group_topic_type, participant, tool_agent_type),
factory=self._create_participant_factory(group_topic_type, participant),
)
# Add subscriptions for the participant.
await runtime.add_subscription(TypeSubscription(topic_type=topic_type, agent_type=agent_type))
@ -154,7 +126,10 @@ class BaseGroupChat(Team, ABC):
group_chat_messages: List[ChatMessage] = []
async def collect_group_chat_messages(
_runtime: AgentRuntime, id: AgentId, message: ContentPublishEvent, ctx: MessageContext
_runtime: AgentRuntime,
id: AgentId,
message: GroupChatPublishEvent,
ctx: MessageContext,
) -> None:
group_chat_messages.append(message.agent_message)
@ -174,10 +149,10 @@ class BaseGroupChat(Team, ABC):
team_topic_id = TopicId(type=team_topic_type, source=self._team_id)
group_chat_manager_topic_id = TopicId(type=group_chat_manager_topic_type, source=self._team_id)
await runtime.publish_message(
ContentPublishEvent(agent_message=TextMessage(content=task, source="user")),
GroupChatPublishEvent(agent_message=TextMessage(content=task, source="user")),
topic_id=team_topic_id,
)
await runtime.publish_message(ContentRequestEvent(), topic_id=group_chat_manager_topic_id)
await runtime.publish_message(GroupChatRequestPublishEvent(), topic_id=group_chat_manager_topic_id)
# Wait for the runtime to stop.
await runtime.stop_when_idle()

View File

@ -1,13 +1,17 @@
import logging
from abc import ABC, abstractmethod
from typing import List
from typing import Any, List
from autogen_core.base import MessageContext, TopicId
from autogen_core.components import event
from autogen_core.base import MessageContext
from autogen_core.components import DefaultTopicId, event
from ... import EVENT_LOGGER_NAME
from ...base import TerminationCondition
from .._events import ContentPublishEvent, ContentRequestEvent, TerminationEvent
from .._events import (
GroupChatPublishEvent,
GroupChatRequestPublishEvent,
TerminationEvent,
)
from ._sequential_routed_agent import SequentialRoutedAgent
event_logger = logging.getLogger(EVENT_LOGGER_NAME)
@ -33,6 +37,10 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC):
Raises:
ValueError: If the number of participant topic types, agent types, and descriptions are not the same.
ValueError: If the participant topic types are not unique.
ValueError: If the group topic type is in the participant topic types.
ValueError: If the parent topic type is in the participant topic types.
ValueError: If the group topic type is the same as the parent topic type.
"""
def __init__(
@ -58,11 +66,11 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC):
raise ValueError("The group topic type must not be the same as the parent topic type.")
self._participant_topic_types = participant_topic_types
self._participant_descriptions = participant_descriptions
self._message_thread: List[ContentPublishEvent] = []
self._message_thread: List[GroupChatPublishEvent] = []
self._termination_condition = termination_condition
@event
async def handle_content_publish(self, message: ContentPublishEvent, ctx: MessageContext) -> None:
async def handle_content_publish(self, message: GroupChatPublishEvent, ctx: MessageContext) -> None:
"""Handle a content publish event.
If the event is from the parent topic, add the message to the thread.
@ -70,16 +78,25 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC):
If the event is from the group chat topic, add the message to the thread and select a speaker to continue the conversation.
If the event from the group chat session requests a pause, publish the last message to the parent topic."""
assert ctx.topic_id is not None
group_chat_topic_id = TopicId(type=self._group_topic_type, source=ctx.topic_id.source)
event_logger.info(message)
if self._termination_condition is not None and self._termination_condition.terminated:
# The group chat has been terminated.
return
# Process event from parent.
if ctx.topic_id.type == self._parent_topic_type:
self._message_thread.append(message)
await self.publish_message(
ContentPublishEvent(agent_message=message.agent_message, source=self.id), topic_id=group_chat_topic_id
GroupChatPublishEvent(agent_message=message.agent_message, source=self.id),
topic_id=DefaultTopicId(type=self._group_topic_type),
)
if self._termination_condition is not None:
stop_message = await self._termination_condition([message.agent_message])
if stop_message is not None:
event_logger.info(TerminationEvent(agent_message=stop_message, source=self.id))
# Stop the group chat.
return
# Process event from the group chat this agent manages.
@ -91,8 +108,6 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC):
stop_message = await self._termination_condition([message.agent_message])
if stop_message is not None:
event_logger.info(TerminationEvent(agent_message=stop_message, source=self.id))
# Reset the termination condition.
await self._termination_condition.reset()
# Stop the group chat.
# TODO: this should be different if the group chat is nested.
return
@ -100,24 +115,28 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC):
# Select a speaker to continue the conversation.
speaker_topic_type = await self.select_speaker(self._message_thread)
participant_topic_id = TopicId(type=speaker_topic_type, source=ctx.topic_id.source)
group_chat_topic_id = TopicId(type=self._group_topic_type, source=ctx.topic_id.source)
await self.publish_message(ContentRequestEvent(), topic_id=participant_topic_id)
await self.publish_message(GroupChatRequestPublishEvent(), topic_id=DefaultTopicId(type=speaker_topic_type))
@event
async def handle_content_request(self, message: ContentRequestEvent, ctx: MessageContext) -> None:
async def handle_content_request(self, message: GroupChatRequestPublishEvent, ctx: MessageContext) -> None:
"""Handle a content request by selecting a speaker to start the conversation."""
assert ctx.topic_id is not None
if ctx.topic_id.type == self._group_topic_type:
raise RuntimeError("Content request event from the group chat topic is not allowed.")
if self._termination_condition is not None and self._termination_condition.terminated:
# The group chat has been terminated.
return
speaker_topic_type = await self.select_speaker(self._message_thread)
participant_topic_id = TopicId(type=speaker_topic_type, source=ctx.topic_id.source)
await self.publish_message(ContentRequestEvent(), topic_id=participant_topic_id)
await self.publish_message(GroupChatRequestPublishEvent(), topic_id=DefaultTopicId(type=speaker_topic_type))
@abstractmethod
async def select_speaker(self, thread: List[ContentPublishEvent]) -> str:
async def select_speaker(self, thread: List[GroupChatPublishEvent]) -> str:
"""Select a speaker from the participants and return the
topic type of the selected speaker."""
...
async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> None:
raise ValueError(f"Unhandled message in group chat manager: {type(message)}")

View File

@ -0,0 +1,48 @@
from typing import Any, List
from autogen_core.base import MessageContext
from autogen_core.components import DefaultTopicId, event
from ...base import ChatAgent
from ...messages import ChatMessage
from .._events import GroupChatPublishEvent, GroupChatRequestPublishEvent
from ._sequential_routed_agent import SequentialRoutedAgent
class ChatAgentContainer(SequentialRoutedAgent):
"""A core agent class that delegates message handling to an
:class:`autogen_agentchat.base.ChatAgent` so that it can be used in a
group chat team.
Args:
parent_topic_type (str): The topic type of the parent orchestrator.
agent (ChatAgent): The agent to delegate message handling to.
"""
def __init__(self, parent_topic_type: str, agent: ChatAgent) -> None:
super().__init__(description=agent.description)
self._parent_topic_type = parent_topic_type
self._agent = agent
self._message_buffer: List[ChatMessage] = []
@event
async def handle_message(self, message: GroupChatPublishEvent, ctx: MessageContext) -> None:
"""Handle an event by appending the content to the buffer."""
self._message_buffer.append(message.agent_message)
@event
async def handle_content_request(self, message: GroupChatRequestPublishEvent, ctx: MessageContext) -> None:
"""Handle a content request event by passing the messages in the buffer
to the delegate agent and publish the response."""
# Pass the messages in the buffer to the delegate agent.
response = await self._agent.on_messages(self._message_buffer, ctx.cancellation_token)
# Publish the response.
self._message_buffer.clear()
await self.publish_message(
GroupChatPublishEvent(agent_message=response, source=self.id),
topic_id=DefaultTopicId(type=self._parent_topic_type),
)
async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> None:
raise ValueError(f"Unhandled message in agent container: {type(message)}")

View File

@ -1,10 +1,17 @@
import logging
from typing import Callable, List
from ... import EVENT_LOGGER_NAME
from ...base import ChatAgent, TerminationCondition
from .._events import ContentPublishEvent
from .._events import (
GroupChatPublishEvent,
GroupChatSelectSpeakerEvent,
)
from ._base_group_chat import BaseGroupChat
from ._base_group_chat_manager import BaseGroupChatManager
event_logger = logging.getLogger(EVENT_LOGGER_NAME)
class RoundRobinGroupChatManager(BaseGroupChatManager):
"""A group chat manager that selects the next speaker in a round-robin fashion."""
@ -26,11 +33,13 @@ class RoundRobinGroupChatManager(BaseGroupChatManager):
)
self._next_speaker_index = 0
async def select_speaker(self, thread: List[ContentPublishEvent]) -> str:
async def select_speaker(self, thread: List[GroupChatPublishEvent]) -> str:
"""Select a speaker from the participants in a round-robin fashion."""
current_speaker_index = self._next_speaker_index
self._next_speaker_index = (current_speaker_index + 1) % len(self._participant_topic_types)
return self._participant_topic_types[current_speaker_index]
current_speaker = self._participant_topic_types[current_speaker_index]
event_logger.debug(GroupChatSelectSpeakerEvent(selected_speaker=current_speaker, source=self.id))
return current_speaker
class RoundRobinGroupChat(BaseGroupChat):

View File

@ -7,7 +7,10 @@ from autogen_core.components.models import ChatCompletionClient, SystemMessage
from ... import EVENT_LOGGER_NAME, TRACE_LOGGER_NAME
from ...base import ChatAgent, TerminationCondition
from ...messages import MultiModalMessage, StopMessage, TextMessage
from .._events import ContentPublishEvent, SelectSpeakerEvent
from .._events import (
GroupChatPublishEvent,
GroupChatSelectSpeakerEvent,
)
from ._base_group_chat import BaseGroupChat
from ._base_group_chat_manager import BaseGroupChatManager
@ -42,7 +45,7 @@ class SelectorGroupChatManager(BaseGroupChatManager):
self._previous_speaker: str | None = None
self._allow_repeated_speaker = allow_repeated_speaker
async def select_speaker(self, thread: List[ContentPublishEvent]) -> str:
async def select_speaker(self, thread: List[GroupChatPublishEvent]) -> str:
"""Selects the next speaker in a group chat using a ChatCompletion client.
A key assumption is that the agent type is the same as the topic type, which we use as the agent name.
@ -107,7 +110,7 @@ class SelectorGroupChatManager(BaseGroupChatManager):
else:
agent_name = participants[0]
self._previous_speaker = agent_name
event_logger.debug(SelectSpeakerEvent(selected_speaker=agent_name, source=self.id))
event_logger.debug(GroupChatSelectSpeakerEvent(selected_speaker=agent_name, source=self.id))
return agent_name
def _mentioned_agents(self, message_content: str, agent_names: List[str]) -> Dict[str, int]:
@ -148,7 +151,7 @@ class SelectorGroupChat(BaseGroupChat):
to all, using a ChatCompletion model to select the next speaker after each message.
Args:
participants (List[BaseChatAgent]): The participants in the group chat,
participants (List[ChatAgent]): The participants in the group chat,
must have unique names and at least two participants.
model_client (ChatCompletionClient): The ChatCompletion model client used
to select the next speaker.

View File

@ -0,0 +1,72 @@
import logging
from typing import Callable, List
from ... import EVENT_LOGGER_NAME
from ...base import ChatAgent, TerminationCondition
from ...messages import HandoffMessage
from .._events import (
GroupChatPublishEvent,
GroupChatSelectSpeakerEvent,
)
from ._base_group_chat import BaseGroupChat
from ._base_group_chat_manager import BaseGroupChatManager
event_logger = logging.getLogger(EVENT_LOGGER_NAME)
class SwarmGroupChatManager(BaseGroupChatManager):
"""A group chat manager that selects the next speaker based on handoff message only."""
def __init__(
self,
parent_topic_type: str,
group_topic_type: str,
participant_topic_types: List[str],
participant_descriptions: List[str],
termination_condition: TerminationCondition | None,
) -> None:
super().__init__(
parent_topic_type,
group_topic_type,
participant_topic_types,
participant_descriptions,
termination_condition,
)
self._current_speaker = participant_topic_types[0]
async def select_speaker(self, thread: List[GroupChatPublishEvent]) -> str:
"""Select a speaker from the participants based on handoff message."""
if len(thread) > 0 and isinstance(thread[-1].agent_message, HandoffMessage):
self._current_speaker = thread[-1].agent_message.content
if self._current_speaker not in self._participant_topic_types:
raise ValueError("The selected speaker in the handoff message is not a participant.")
event_logger.debug(GroupChatSelectSpeakerEvent(selected_speaker=self._current_speaker, source=self.id))
return self._current_speaker
else:
return self._current_speaker
class Swarm(BaseGroupChat):
"""(Experimental) A group chat that selects the next speaker based on handoff message only."""
def __init__(self, participants: List[ChatAgent]):
super().__init__(participants, group_chat_manager_class=SwarmGroupChatManager)
def _create_group_chat_manager_factory(
self,
parent_topic_type: str,
group_topic_type: str,
participant_topic_types: List[str],
participant_descriptions: List[str],
termination_condition: TerminationCondition | None,
) -> Callable[[], SwarmGroupChatManager]:
def _factory() -> SwarmGroupChatManager:
return SwarmGroupChatManager(
parent_topic_type,
group_topic_type,
participant_topic_types,
participant_descriptions,
termination_condition,
)
return _factory

View File

@ -13,11 +13,17 @@ from autogen_agentchat.agents import (
ToolUseAssistantAgent,
)
from autogen_agentchat.logging import FileLogHandler
from autogen_agentchat.messages import ChatMessage, StopMessage, TextMessage
from autogen_agentchat.task import StopMessageTermination
from autogen_agentchat.messages import (
ChatMessage,
HandoffMessage,
StopMessage,
TextMessage,
)
from autogen_agentchat.task import MaxMessageTermination, StopMessageTermination
from autogen_agentchat.teams import (
RoundRobinGroupChat,
SelectorGroupChat,
Swarm,
)
from autogen_core.base import CancellationToken
from autogen_core.components import FunctionCall
@ -212,7 +218,16 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch
)
echo_agent = _EchoAgent("echo_agent", description="echo agent")
team = RoundRobinGroupChat(participants=[tool_use_agent, echo_agent])
await team.run("Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination())
result = await team.run(
"Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination()
)
assert len(result.messages) == 4
assert isinstance(result.messages[0], TextMessage) # task
assert isinstance(result.messages[1], TextMessage) # tool use agent response
assert isinstance(result.messages[2], TextMessage) # echo agent response
assert isinstance(result.messages[3], StopMessage) # tool use agent response
context = tool_use_agent._model_context # pyright: ignore
assert context[0].content == "Write a program that prints 'Hello, world!'"
assert isinstance(context[1].content, list)
@ -393,3 +408,29 @@ async def test_selector_group_chat_two_speakers_allow_repeated(monkeypatch: pyte
assert result.messages[1].source == "agent2"
assert result.messages[2].source == "agent2"
assert result.messages[3].source == "agent1"
class _HandOffAgent(BaseChatAgent):
def __init__(self, name: str, description: str, next_agent: str) -> None:
super().__init__(name, description)
self._next_agent = next_agent
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
return HandoffMessage(content=self._next_agent, source=self.name)
@pytest.mark.asyncio
async def test_swarm() -> None:
first_agent = _HandOffAgent("first_agent", description="first agent", next_agent="second_agent")
second_agent = _HandOffAgent("second_agent", description="second agent", next_agent="third_agent")
third_agent = _HandOffAgent("third_agent", description="third agent", next_agent="first_agent")
team = Swarm([second_agent, first_agent, third_agent])
result = await team.run("task", termination_condition=MaxMessageTermination(6))
assert len(result.messages) == 6
assert result.messages[0].content == "task"
assert result.messages[1].content == "third_agent"
assert result.messages[2].content == "first_agent"
assert result.messages[3].content == "second_agent"
assert result.messages[4].content == "third_agent"
assert result.messages[5].content == "first_agent"

View File

@ -4,6 +4,7 @@ from typing import Any, AsyncGenerator, List
import pytest
from autogen_agentchat.agents import ToolUseAssistantAgent
from autogen_agentchat.messages import StopMessage, TextMessage
from autogen_core.components.models import OpenAIChatCompletionClient
from autogen_core.components.tools import FunctionTool
from openai.resources.chat.completions import AsyncCompletions
@ -103,5 +104,6 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch
)
result = await tool_use_agent.run("task")
assert len(result.messages) == 3
# assert isinstance(result.messages[1], ToolCallMessage)
# assert isinstance(result.messages[2], TextMessage)
assert isinstance(result.messages[0], TextMessage)
assert isinstance(result.messages[1], TextMessage)
assert isinstance(result.messages[2], StopMessage)