mirror of
https://github.com/microsoft/autogen.git
synced 2025-10-27 15:59:35 +00:00
Refactor agent chat to prepare for handoff/swarm (#3949)
Add handoff message type to chat message types Add Swarm group chat that uses handoff message to select next speaker Remove tool call and tool call result message types from chat message types Remove BaseToolUseChatAgent, move tool call handling from group chat's chat agent container upward to the ToolUseAssistantAgent implementation, which subclasses BaseChatAgent directly. Renaming for better clarity --------- Co-authored-by: Victor Dibia <victordibia@microsoft.com>
This commit is contained in:
parent
0756ebd63d
commit
f31ff66368
@ -1,11 +1,10 @@
|
||||
from ._base_chat_agent import BaseChatAgent, BaseToolUseChatAgent
|
||||
from ._base_chat_agent import BaseChatAgent
|
||||
from ._code_executor_agent import CodeExecutorAgent
|
||||
from ._coding_assistant_agent import CodingAssistantAgent
|
||||
from ._tool_use_assistant_agent import ToolUseAssistantAgent
|
||||
|
||||
__all__ = [
|
||||
"BaseChatAgent",
|
||||
"BaseToolUseChatAgent",
|
||||
"CodeExecutorAgent",
|
||||
"CodingAssistantAgent",
|
||||
"ToolUseAssistantAgent",
|
||||
|
||||
@ -1,10 +1,9 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Sequence
|
||||
from typing import Sequence
|
||||
|
||||
from autogen_core.base import CancellationToken
|
||||
from autogen_core.components.tools import Tool
|
||||
|
||||
from ..base import ChatAgent, TaskResult, TerminationCondition, ToolUseChatAgent
|
||||
from ..base import ChatAgent, TaskResult, TerminationCondition
|
||||
from ..messages import ChatMessage
|
||||
from ..teams import RoundRobinGroupChat
|
||||
|
||||
@ -51,21 +50,3 @@ class BaseChatAgent(ChatAgent, ABC):
|
||||
termination_condition=termination_condition,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
class BaseToolUseChatAgent(BaseChatAgent, ToolUseChatAgent):
|
||||
"""Base class for a chat agent that can use tools.
|
||||
|
||||
Subclass this base class to create an agent class that uses tools by returning
|
||||
ToolCallMessage message from the :meth:`on_messages` method and receiving
|
||||
ToolCallResultMessage message from the input to the :meth:`on_messages` method.
|
||||
"""
|
||||
|
||||
def __init__(self, name: str, description: str, registered_tools: List[Tool]) -> None:
|
||||
super().__init__(name, description)
|
||||
self._registered_tools = registered_tools
|
||||
|
||||
@property
|
||||
def registered_tools(self) -> List[Tool]:
|
||||
"""The list of tools that the agent can use."""
|
||||
return self._registered_tools
|
||||
|
||||
@ -1,3 +1,6 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Awaitable, Callable, List, Sequence
|
||||
|
||||
from autogen_core.base import CancellationToken
|
||||
@ -5,25 +8,45 @@ from autogen_core.components import FunctionCall
|
||||
from autogen_core.components.models import (
|
||||
AssistantMessage,
|
||||
ChatCompletionClient,
|
||||
FunctionExecutionResult,
|
||||
FunctionExecutionResultMessage,
|
||||
LLMMessage,
|
||||
SystemMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from autogen_core.components.tools import FunctionTool, Tool
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from .. import EVENT_LOGGER_NAME
|
||||
from ..messages import (
|
||||
ChatMessage,
|
||||
MultiModalMessage,
|
||||
StopMessage,
|
||||
TextMessage,
|
||||
ToolCallMessage,
|
||||
ToolCallResultMessage,
|
||||
)
|
||||
from ._base_chat_agent import BaseToolUseChatAgent
|
||||
from ._base_chat_agent import BaseChatAgent
|
||||
|
||||
event_logger = logging.getLogger(EVENT_LOGGER_NAME)
|
||||
|
||||
|
||||
class ToolUseAssistantAgent(BaseToolUseChatAgent):
|
||||
class ToolCallEvent(BaseModel):
|
||||
"""A tool call event."""
|
||||
|
||||
tool_calls: List[FunctionCall]
|
||||
"""The tool call message."""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
class ToolCallResultEvent(BaseModel):
|
||||
"""A tool call result event."""
|
||||
|
||||
tool_call_results: List[FunctionExecutionResult]
|
||||
"""The tool call result message."""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
class ToolUseAssistantAgent(BaseChatAgent):
|
||||
"""An agent that provides assistance with tool use.
|
||||
|
||||
It responds with a StopMessage when 'terminate' is detected in the response.
|
||||
@ -45,46 +68,50 @@ class ToolUseAssistantAgent(BaseToolUseChatAgent):
|
||||
description: str = "An agent that provides assistance with ability to use tools.",
|
||||
system_message: str = "You are a helpful AI assistant. Solve tasks using your tools. Reply with 'TERMINATE' when the task has been completed.",
|
||||
):
|
||||
tools: List[Tool] = []
|
||||
super().__init__(name=name, description=description)
|
||||
self._model_client = model_client
|
||||
self._system_messages = [SystemMessage(content=system_message)]
|
||||
self._tools: List[Tool] = []
|
||||
for tool in registered_tools:
|
||||
if isinstance(tool, Tool):
|
||||
tools.append(tool)
|
||||
self._tools.append(tool)
|
||||
elif callable(tool):
|
||||
if hasattr(tool, "__doc__") and tool.__doc__ is not None:
|
||||
description = tool.__doc__
|
||||
else:
|
||||
description = ""
|
||||
tools.append(FunctionTool(tool, description=description))
|
||||
self._tools.append(FunctionTool(tool, description=description))
|
||||
else:
|
||||
raise ValueError(f"Unsupported tool type: {type(tool)}")
|
||||
super().__init__(name=name, description=description, registered_tools=tools)
|
||||
self._model_client = model_client
|
||||
self._system_messages = [SystemMessage(content=system_message)]
|
||||
self._tool_schema = [tool.schema for tool in tools]
|
||||
self._model_context: List[LLMMessage] = []
|
||||
|
||||
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
|
||||
# Add messages to the model context.
|
||||
for msg in messages:
|
||||
if isinstance(msg, ToolCallResultMessage):
|
||||
self._model_context.append(FunctionExecutionResultMessage(content=msg.content))
|
||||
elif not isinstance(msg, TextMessage | MultiModalMessage | StopMessage):
|
||||
raise ValueError(f"Unsupported message type: {type(msg)}")
|
||||
else:
|
||||
# TODO: add special handling for handoff messages
|
||||
self._model_context.append(UserMessage(content=msg.content, source=msg.source))
|
||||
|
||||
# Generate an inference result based on the current model context.
|
||||
llm_messages = self._system_messages + self._model_context
|
||||
result = await self._model_client.create(
|
||||
llm_messages, tools=self._tool_schema, cancellation_token=cancellation_token
|
||||
)
|
||||
result = await self._model_client.create(llm_messages, tools=self._tools, cancellation_token=cancellation_token)
|
||||
|
||||
# Add the response to the model context.
|
||||
self._model_context.append(AssistantMessage(content=result.content, source=self.name))
|
||||
|
||||
# Detect tool calls.
|
||||
if isinstance(result.content, list) and all(isinstance(item, FunctionCall) for item in result.content):
|
||||
return ToolCallMessage(content=result.content, source=self.name)
|
||||
# Run tool calls until the model produces a string response.
|
||||
while isinstance(result.content, list) and all(isinstance(item, FunctionCall) for item in result.content):
|
||||
event_logger.debug(ToolCallEvent(tool_calls=result.content))
|
||||
# Execute the tool calls.
|
||||
results = await asyncio.gather(
|
||||
*[self._execute_tool_call(call, cancellation_token) for call in result.content]
|
||||
)
|
||||
event_logger.debug(ToolCallResultEvent(tool_call_results=results))
|
||||
self._model_context.append(FunctionExecutionResultMessage(content=results))
|
||||
# Generate an inference result based on the current model context.
|
||||
result = await self._model_client.create(
|
||||
self._model_context, tools=self._tools, cancellation_token=cancellation_token
|
||||
)
|
||||
self._model_context.append(AssistantMessage(content=result.content, source=self.name))
|
||||
|
||||
assert isinstance(result.content, str)
|
||||
# Detect stop request.
|
||||
@ -93,3 +120,20 @@ class ToolUseAssistantAgent(BaseToolUseChatAgent):
|
||||
return StopMessage(content=result.content, source=self.name)
|
||||
|
||||
return TextMessage(content=result.content, source=self.name)
|
||||
|
||||
async def _execute_tool_call(
|
||||
self, tool_call: FunctionCall, cancellation_token: CancellationToken
|
||||
) -> FunctionExecutionResult:
|
||||
"""Execute a tool call and return the result."""
|
||||
try:
|
||||
if not self._tools:
|
||||
raise ValueError("No tools are available.")
|
||||
tool = next((t for t in self._tools if t.name == tool_call.name), None)
|
||||
if tool is None:
|
||||
raise ValueError(f"The tool '{tool_call.name}' is not available.")
|
||||
arguments = json.loads(tool_call.arguments)
|
||||
result = await tool.run_json(arguments, cancellation_token)
|
||||
result_as_str = tool.return_value_as_string(result)
|
||||
return FunctionExecutionResult(content=result_as_str, call_id=tool_call.id)
|
||||
except Exception as e:
|
||||
return FunctionExecutionResult(content=f"Error: {e}", call_id=tool_call.id)
|
||||
|
||||
@ -1,11 +1,10 @@
|
||||
from ._chat_agent import ChatAgent, ToolUseChatAgent
|
||||
from ._chat_agent import ChatAgent
|
||||
from ._task import TaskResult, TaskRunner
|
||||
from ._team import Team
|
||||
from ._termination import TerminatedException, TerminationCondition
|
||||
|
||||
__all__ = [
|
||||
"ChatAgent",
|
||||
"ToolUseChatAgent",
|
||||
"Team",
|
||||
"TerminatedException",
|
||||
"TerminationCondition",
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
from typing import List, Protocol, Sequence, runtime_checkable
|
||||
from typing import Protocol, Sequence, runtime_checkable
|
||||
|
||||
from autogen_core.base import CancellationToken
|
||||
from autogen_core.components.tools import Tool
|
||||
|
||||
from ..messages import ChatMessage
|
||||
from ._task import TaskResult, TaskRunner
|
||||
@ -38,13 +37,3 @@ class ChatAgent(TaskRunner, Protocol):
|
||||
) -> TaskResult:
|
||||
"""Run the agent with the given task and return the result."""
|
||||
...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class ToolUseChatAgent(ChatAgent, Protocol):
|
||||
"""Protocol for a chat agent that can use tools."""
|
||||
|
||||
@property
|
||||
def registered_tools(self) -> List[Tool]:
|
||||
"""The list of tools that the agent can use."""
|
||||
...
|
||||
|
||||
@ -3,13 +3,12 @@ import logging
|
||||
import sys
|
||||
from datetime import datetime
|
||||
|
||||
from ..agents._tool_use_assistant_agent import ToolCallEvent, ToolCallResultEvent
|
||||
from ..messages import ChatMessage, StopMessage, TextMessage
|
||||
from ..teams._events import (
|
||||
ContentPublishEvent,
|
||||
SelectSpeakerEvent,
|
||||
GroupChatPublishEvent,
|
||||
GroupChatSelectSpeakerEvent,
|
||||
TerminationEvent,
|
||||
ToolCallEvent,
|
||||
ToolCallResultEvent,
|
||||
)
|
||||
|
||||
|
||||
@ -25,7 +24,7 @@ class ConsoleLogHandler(logging.Handler):
|
||||
|
||||
def emit(self, record: logging.LogRecord) -> None:
|
||||
ts = datetime.fromtimestamp(record.created).isoformat()
|
||||
if isinstance(record.msg, ContentPublishEvent):
|
||||
if isinstance(record.msg, GroupChatPublishEvent):
|
||||
if record.msg.source is None:
|
||||
sys.stdout.write(
|
||||
f"\n{'-'*75} \n"
|
||||
@ -41,19 +40,15 @@ class ConsoleLogHandler(logging.Handler):
|
||||
sys.stdout.flush()
|
||||
elif isinstance(record.msg, ToolCallEvent):
|
||||
sys.stdout.write(
|
||||
f"\n{'-'*75} \n"
|
||||
f"\033[91m[{ts}], Tool Call:\033[0m\n"
|
||||
f"\n{self.serialize_chat_message(record.msg.agent_message)}"
|
||||
f"\n{'-'*75} \n" f"\033[91m[{ts}], Tool Call:\033[0m\n" f"\n{str(record.msg.model_dump())}"
|
||||
)
|
||||
sys.stdout.flush()
|
||||
elif isinstance(record.msg, ToolCallResultEvent):
|
||||
sys.stdout.write(
|
||||
f"\n{'-'*75} \n"
|
||||
f"\033[91m[{ts}], Tool Call Result:\033[0m\n"
|
||||
f"\n{self.serialize_chat_message(record.msg.agent_message)}"
|
||||
f"\n{'-'*75} \n" f"\033[91m[{ts}], Tool Call Result:\033[0m\n" f"\n{str(record.msg.model_dump())}"
|
||||
)
|
||||
sys.stdout.flush()
|
||||
elif isinstance(record.msg, SelectSpeakerEvent):
|
||||
elif isinstance(record.msg, GroupChatSelectSpeakerEvent):
|
||||
sys.stdout.write(
|
||||
f"\n{'-'*75} \n" f"\033[91m[{ts}], Selected Next Speaker:\033[0m\n" f"\n{record.msg.selected_speaker}"
|
||||
)
|
||||
|
||||
@ -4,12 +4,11 @@ from dataclasses import asdict, is_dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from ..agents._tool_use_assistant_agent import ToolCallEvent, ToolCallResultEvent
|
||||
from ..teams._events import (
|
||||
ContentPublishEvent,
|
||||
SelectSpeakerEvent,
|
||||
GroupChatPublishEvent,
|
||||
GroupChatSelectSpeakerEvent,
|
||||
TerminationEvent,
|
||||
ToolCallEvent,
|
||||
ToolCallResultEvent,
|
||||
)
|
||||
|
||||
|
||||
@ -21,7 +20,7 @@ class FileLogHandler(logging.Handler):
|
||||
|
||||
def emit(self, record: logging.LogRecord) -> None:
|
||||
ts = datetime.fromtimestamp(record.created).isoformat()
|
||||
if isinstance(record.msg, ContentPublishEvent | ToolCallEvent | ToolCallResultEvent | TerminationEvent):
|
||||
if isinstance(record.msg, GroupChatPublishEvent | TerminationEvent):
|
||||
log_entry = json.dumps(
|
||||
{
|
||||
"timestamp": ts,
|
||||
@ -31,7 +30,7 @@ class FileLogHandler(logging.Handler):
|
||||
},
|
||||
default=self.json_serializer,
|
||||
)
|
||||
elif isinstance(record.msg, SelectSpeakerEvent):
|
||||
elif isinstance(record.msg, GroupChatSelectSpeakerEvent):
|
||||
log_entry = json.dumps(
|
||||
{
|
||||
"timestamp": ts,
|
||||
@ -41,6 +40,24 @@ class FileLogHandler(logging.Handler):
|
||||
},
|
||||
default=self.json_serializer,
|
||||
)
|
||||
elif isinstance(record.msg, ToolCallEvent):
|
||||
log_entry = json.dumps(
|
||||
{
|
||||
"timestamp": ts,
|
||||
"tool_calls": record.msg.model_dump(),
|
||||
"type": "ToolCallEvent",
|
||||
},
|
||||
default=self.json_serializer,
|
||||
)
|
||||
elif isinstance(record.msg, ToolCallResultEvent):
|
||||
log_entry = json.dumps(
|
||||
{
|
||||
"timestamp": ts,
|
||||
"tool_call_results": record.msg.model_dump(),
|
||||
"type": "ToolCallResultEvent",
|
||||
},
|
||||
default=self.json_serializer,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unexpected log record: {record.msg}")
|
||||
file_record = logging.LogRecord(
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
from typing import List
|
||||
|
||||
from autogen_core.components import FunctionCall, Image
|
||||
from autogen_core.components.models import FunctionExecutionResult
|
||||
from autogen_core.components import Image
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
@ -26,20 +25,6 @@ class MultiModalMessage(BaseMessage):
|
||||
"""The content of the message."""
|
||||
|
||||
|
||||
class ToolCallMessage(BaseMessage):
|
||||
"""A message containing a list of function calls."""
|
||||
|
||||
content: List[FunctionCall]
|
||||
"""The list of function calls."""
|
||||
|
||||
|
||||
class ToolCallResultMessage(BaseMessage):
|
||||
"""A message containing the results of function calls."""
|
||||
|
||||
content: List[FunctionExecutionResult]
|
||||
"""The list of function execution results."""
|
||||
|
||||
|
||||
class StopMessage(BaseMessage):
|
||||
"""A message requesting stop of a conversation."""
|
||||
|
||||
@ -47,7 +32,14 @@ class StopMessage(BaseMessage):
|
||||
"""The content for the stop message."""
|
||||
|
||||
|
||||
ChatMessage = TextMessage | MultiModalMessage | StopMessage | ToolCallMessage | ToolCallResultMessage
|
||||
class HandoffMessage(BaseMessage):
|
||||
"""A message requesting handoff of a conversation to another agent."""
|
||||
|
||||
content: str
|
||||
"""The agent name to handoff the conversation to."""
|
||||
|
||||
|
||||
ChatMessage = TextMessage | MultiModalMessage | StopMessage | HandoffMessage
|
||||
"""A message used by agents in a team."""
|
||||
|
||||
|
||||
@ -55,8 +47,7 @@ __all__ = [
|
||||
"BaseMessage",
|
||||
"TextMessage",
|
||||
"MultiModalMessage",
|
||||
"ToolCallMessage",
|
||||
"ToolCallResultMessage",
|
||||
"StopMessage",
|
||||
"HandoffMessage",
|
||||
"ChatMessage",
|
||||
]
|
||||
|
||||
@ -1,7 +1,9 @@
|
||||
from ._group_chat._round_robin_group_chat import RoundRobinGroupChat
|
||||
from ._group_chat._selector_group_chat import SelectorGroupChat
|
||||
from ._group_chat._swarm_group_chat import Swarm
|
||||
|
||||
__all__ = [
|
||||
"RoundRobinGroupChat",
|
||||
"SelectorGroupChat",
|
||||
"Swarm",
|
||||
]
|
||||
|
||||
@ -1,16 +1,16 @@
|
||||
from autogen_core.base import AgentId
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from ..messages import MultiModalMessage, StopMessage, TextMessage, ToolCallMessage, ToolCallResultMessage
|
||||
from ..messages import ChatMessage, StopMessage
|
||||
|
||||
|
||||
class ContentPublishEvent(BaseModel):
|
||||
"""An event for sharing some data. Agents receive this event should
|
||||
class GroupChatPublishEvent(BaseModel):
|
||||
"""An group chat event for sharing some data. Agents receive this event should
|
||||
update their internal state (e.g., append to message history) with the
|
||||
content of the event.
|
||||
"""
|
||||
|
||||
agent_message: TextMessage | MultiModalMessage | StopMessage
|
||||
agent_message: ChatMessage
|
||||
"""The message published by the agent."""
|
||||
|
||||
source: AgentId | None = None
|
||||
@ -19,39 +19,15 @@ class ContentPublishEvent(BaseModel):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
class ContentRequestEvent(BaseModel):
|
||||
"""An event for requesting to publish a content event.
|
||||
Upon receiving this event, the agent should publish a ContentPublishEvent.
|
||||
class GroupChatRequestPublishEvent(BaseModel):
|
||||
"""An event for requesting to publish a group chat publish event.
|
||||
Upon receiving this event, the agent should publish a group chat publish event.
|
||||
"""
|
||||
|
||||
...
|
||||
|
||||
|
||||
class ToolCallEvent(BaseModel):
|
||||
"""An event produced when requesting a tool call."""
|
||||
|
||||
agent_message: ToolCallMessage
|
||||
"""The tool call message."""
|
||||
|
||||
source: AgentId
|
||||
"""The sender of the tool call message."""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
class ToolCallResultEvent(BaseModel):
|
||||
"""An event produced when a tool call is completed."""
|
||||
|
||||
agent_message: ToolCallResultMessage
|
||||
"""The tool call result message."""
|
||||
|
||||
source: AgentId
|
||||
"""The sender of the tool call result message."""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
class SelectSpeakerEvent(BaseModel):
|
||||
class GroupChatSelectSpeakerEvent(BaseModel):
|
||||
"""An event for selecting the next speaker in a group chat."""
|
||||
|
||||
selected_speaker: str
|
||||
|
||||
@ -1,92 +0,0 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from autogen_core.base import AgentId, AgentType, MessageContext
|
||||
from autogen_core.components import DefaultTopicId, event
|
||||
from autogen_core.components.models import FunctionExecutionResult
|
||||
from autogen_core.components.tool_agent import ToolException
|
||||
|
||||
from ... import EVENT_LOGGER_NAME
|
||||
from ...base import ChatAgent
|
||||
from ...messages import MultiModalMessage, StopMessage, TextMessage, ToolCallMessage, ToolCallResultMessage
|
||||
from .._events import ContentPublishEvent, ContentRequestEvent, ToolCallEvent, ToolCallResultEvent
|
||||
from ._sequential_routed_agent import SequentialRoutedAgent
|
||||
|
||||
event_logger = logging.getLogger(EVENT_LOGGER_NAME)
|
||||
|
||||
|
||||
class BaseChatAgentContainer(SequentialRoutedAgent):
|
||||
"""A core agent class that delegates message handling to an
|
||||
:class:`autogen_agentchat.agents.BaseChatAgent` so that it can be used in a
|
||||
group chat team.
|
||||
|
||||
Args:
|
||||
parent_topic_type (str): The topic type of the parent orchestrator.
|
||||
agent (BaseChatAgent): The agent to delegate message handling to.
|
||||
tool_agent_type (AgentType, optional): The agent type of the tool agent. Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self, parent_topic_type: str, agent: ChatAgent, tool_agent_type: AgentType | None = None) -> None:
|
||||
super().__init__(description=agent.description)
|
||||
self._parent_topic_type = parent_topic_type
|
||||
self._agent = agent
|
||||
self._message_buffer: List[TextMessage | MultiModalMessage | StopMessage] = []
|
||||
self._tool_agent_id = AgentId(type=tool_agent_type, key=self.id.key) if tool_agent_type else None
|
||||
|
||||
@event
|
||||
async def handle_content_publish(self, message: ContentPublishEvent, ctx: MessageContext) -> None:
|
||||
"""Handle a content publish event by appending the content to the buffer."""
|
||||
if not isinstance(message.agent_message, TextMessage | MultiModalMessage | StopMessage):
|
||||
raise ValueError(
|
||||
f"Unexpected message type: {type(message.agent_message)}. "
|
||||
"The message must be a text, multimodal, or stop message."
|
||||
)
|
||||
self._message_buffer.append(message.agent_message)
|
||||
|
||||
@event
|
||||
async def handle_content_request(self, message: ContentRequestEvent, ctx: MessageContext) -> None:
|
||||
"""Handle a content request event by passing the messages in the buffer
|
||||
to the delegate agent and publish the response."""
|
||||
response = await self._agent.on_messages(self._message_buffer, ctx.cancellation_token)
|
||||
|
||||
if self._tool_agent_id is not None:
|
||||
# Handle tool calls.
|
||||
while isinstance(response, ToolCallMessage):
|
||||
# Log the tool call.
|
||||
event_logger.debug(ToolCallEvent(agent_message=response, source=self.id))
|
||||
|
||||
results: List[FunctionExecutionResult | BaseException] = await asyncio.gather(
|
||||
*[
|
||||
self.send_message(
|
||||
message=call,
|
||||
recipient=self._tool_agent_id,
|
||||
cancellation_token=ctx.cancellation_token,
|
||||
)
|
||||
for call in response.content
|
||||
]
|
||||
)
|
||||
# Combine the results in to a single response and handle exceptions.
|
||||
function_results: List[FunctionExecutionResult] = []
|
||||
for result in results:
|
||||
if isinstance(result, FunctionExecutionResult):
|
||||
function_results.append(result)
|
||||
elif isinstance(result, ToolException):
|
||||
function_results.append(
|
||||
FunctionExecutionResult(content=f"Error: {result}", call_id=result.call_id)
|
||||
)
|
||||
elif isinstance(result, BaseException):
|
||||
raise result # Unexpected exception.
|
||||
# Create a new tool call result message.
|
||||
feedback = ToolCallResultMessage(content=function_results, source=self._tool_agent_id.type)
|
||||
# Log the feedback.
|
||||
event_logger.debug(ToolCallResultEvent(agent_message=feedback, source=self._tool_agent_id))
|
||||
response = await self._agent.on_messages([feedback], ctx.cancellation_token)
|
||||
|
||||
# Publish the response.
|
||||
assert isinstance(response, TextMessage | MultiModalMessage | StopMessage)
|
||||
self._message_buffer.clear()
|
||||
await self.publish_message(
|
||||
ContentPublishEvent(agent_message=response, source=self.id),
|
||||
topic_id=DefaultTopicId(type=self._parent_topic_type),
|
||||
)
|
||||
@ -13,14 +13,12 @@ from autogen_core.base import (
|
||||
TopicId,
|
||||
)
|
||||
from autogen_core.components import ClosureAgent, TypeSubscription
|
||||
from autogen_core.components.tool_agent import ToolAgent
|
||||
from autogen_core.components.tools import Tool
|
||||
|
||||
from ...base import ChatAgent, TaskResult, Team, TerminationCondition, ToolUseChatAgent
|
||||
from ...base import ChatAgent, TaskResult, Team, TerminationCondition
|
||||
from ...messages import ChatMessage, TextMessage
|
||||
from .._events import ContentPublishEvent, ContentRequestEvent
|
||||
from ._base_chat_agent_container import BaseChatAgentContainer
|
||||
from .._events import GroupChatPublishEvent, GroupChatRequestPublishEvent
|
||||
from ._base_group_chat_manager import BaseGroupChatManager
|
||||
from ._chat_agent_container import ChatAgentContainer
|
||||
|
||||
|
||||
class BaseGroupChat(Team, ABC):
|
||||
@ -35,11 +33,6 @@ class BaseGroupChat(Team, ABC):
|
||||
raise ValueError("At least one participant is required.")
|
||||
if len(participants) != len(set(participant.name for participant in participants)):
|
||||
raise ValueError("The participant names must be unique.")
|
||||
for participant in participants:
|
||||
if isinstance(participant, ToolUseChatAgent) and not participant.registered_tools:
|
||||
raise ValueError(
|
||||
f"Participant '{participant.name}' is a tool use agent so it must have registered tools."
|
||||
)
|
||||
self._participants = participants
|
||||
self._team_id = str(uuid.uuid4())
|
||||
self._base_group_chat_manager_class = group_chat_manager_class
|
||||
@ -55,27 +48,19 @@ class BaseGroupChat(Team, ABC):
|
||||
) -> Callable[[], BaseGroupChatManager]: ...
|
||||
|
||||
def _create_participant_factory(
|
||||
self, parent_topic_type: str, agent: ChatAgent, tool_agent_type: AgentType | None
|
||||
) -> Callable[[], BaseChatAgentContainer]:
|
||||
def _factory() -> BaseChatAgentContainer:
|
||||
self,
|
||||
parent_topic_type: str,
|
||||
agent: ChatAgent,
|
||||
) -> Callable[[], ChatAgentContainer]:
|
||||
def _factory() -> ChatAgentContainer:
|
||||
id = AgentInstantiationContext.current_agent_id()
|
||||
assert id == AgentId(type=agent.name, key=self._team_id)
|
||||
container = BaseChatAgentContainer(parent_topic_type, agent, tool_agent_type)
|
||||
container = ChatAgentContainer(parent_topic_type, agent)
|
||||
assert container.id == id
|
||||
return container
|
||||
|
||||
return _factory
|
||||
|
||||
def _create_tool_agent_factory(
|
||||
self,
|
||||
caller_name: str,
|
||||
tools: List[Tool],
|
||||
) -> Callable[[], ToolAgent]:
|
||||
def _factory() -> ToolAgent:
|
||||
return ToolAgent(f"Tool agent for {caller_name}", tools)
|
||||
|
||||
return _factory
|
||||
|
||||
async def run(
|
||||
self,
|
||||
task: str,
|
||||
@ -99,27 +84,14 @@ class BaseGroupChat(Team, ABC):
|
||||
participant_topic_types: List[str] = []
|
||||
participant_descriptions: List[str] = []
|
||||
for participant in self._participants:
|
||||
if isinstance(participant, ToolUseChatAgent):
|
||||
assert participant.registered_tools is not None and len(participant.registered_tools) > 0
|
||||
# Register the tool agent.
|
||||
tool_agent_type = await ToolAgent.register(
|
||||
runtime,
|
||||
f"tool_agent_for_{participant.name}",
|
||||
self._create_tool_agent_factory(participant.name, participant.registered_tools),
|
||||
)
|
||||
# No subscriptions are needed for the tool agent, which will be called via direct messages.
|
||||
else:
|
||||
# No tool agent is needed.
|
||||
tool_agent_type = None
|
||||
|
||||
# Use the participant name as the agent type and topic type.
|
||||
agent_type = participant.name
|
||||
topic_type = participant.name
|
||||
# Register the participant factory.
|
||||
await BaseChatAgentContainer.register(
|
||||
await ChatAgentContainer.register(
|
||||
runtime,
|
||||
type=agent_type,
|
||||
factory=self._create_participant_factory(group_topic_type, participant, tool_agent_type),
|
||||
factory=self._create_participant_factory(group_topic_type, participant),
|
||||
)
|
||||
# Add subscriptions for the participant.
|
||||
await runtime.add_subscription(TypeSubscription(topic_type=topic_type, agent_type=agent_type))
|
||||
@ -154,7 +126,10 @@ class BaseGroupChat(Team, ABC):
|
||||
group_chat_messages: List[ChatMessage] = []
|
||||
|
||||
async def collect_group_chat_messages(
|
||||
_runtime: AgentRuntime, id: AgentId, message: ContentPublishEvent, ctx: MessageContext
|
||||
_runtime: AgentRuntime,
|
||||
id: AgentId,
|
||||
message: GroupChatPublishEvent,
|
||||
ctx: MessageContext,
|
||||
) -> None:
|
||||
group_chat_messages.append(message.agent_message)
|
||||
|
||||
@ -174,10 +149,10 @@ class BaseGroupChat(Team, ABC):
|
||||
team_topic_id = TopicId(type=team_topic_type, source=self._team_id)
|
||||
group_chat_manager_topic_id = TopicId(type=group_chat_manager_topic_type, source=self._team_id)
|
||||
await runtime.publish_message(
|
||||
ContentPublishEvent(agent_message=TextMessage(content=task, source="user")),
|
||||
GroupChatPublishEvent(agent_message=TextMessage(content=task, source="user")),
|
||||
topic_id=team_topic_id,
|
||||
)
|
||||
await runtime.publish_message(ContentRequestEvent(), topic_id=group_chat_manager_topic_id)
|
||||
await runtime.publish_message(GroupChatRequestPublishEvent(), topic_id=group_chat_manager_topic_id)
|
||||
|
||||
# Wait for the runtime to stop.
|
||||
await runtime.stop_when_idle()
|
||||
|
||||
@ -1,13 +1,17 @@
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List
|
||||
from typing import Any, List
|
||||
|
||||
from autogen_core.base import MessageContext, TopicId
|
||||
from autogen_core.components import event
|
||||
from autogen_core.base import MessageContext
|
||||
from autogen_core.components import DefaultTopicId, event
|
||||
|
||||
from ... import EVENT_LOGGER_NAME
|
||||
from ...base import TerminationCondition
|
||||
from .._events import ContentPublishEvent, ContentRequestEvent, TerminationEvent
|
||||
from .._events import (
|
||||
GroupChatPublishEvent,
|
||||
GroupChatRequestPublishEvent,
|
||||
TerminationEvent,
|
||||
)
|
||||
from ._sequential_routed_agent import SequentialRoutedAgent
|
||||
|
||||
event_logger = logging.getLogger(EVENT_LOGGER_NAME)
|
||||
@ -33,6 +37,10 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC):
|
||||
|
||||
Raises:
|
||||
ValueError: If the number of participant topic types, agent types, and descriptions are not the same.
|
||||
ValueError: If the participant topic types are not unique.
|
||||
ValueError: If the group topic type is in the participant topic types.
|
||||
ValueError: If the parent topic type is in the participant topic types.
|
||||
ValueError: If the group topic type is the same as the parent topic type.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -58,11 +66,11 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC):
|
||||
raise ValueError("The group topic type must not be the same as the parent topic type.")
|
||||
self._participant_topic_types = participant_topic_types
|
||||
self._participant_descriptions = participant_descriptions
|
||||
self._message_thread: List[ContentPublishEvent] = []
|
||||
self._message_thread: List[GroupChatPublishEvent] = []
|
||||
self._termination_condition = termination_condition
|
||||
|
||||
@event
|
||||
async def handle_content_publish(self, message: ContentPublishEvent, ctx: MessageContext) -> None:
|
||||
async def handle_content_publish(self, message: GroupChatPublishEvent, ctx: MessageContext) -> None:
|
||||
"""Handle a content publish event.
|
||||
|
||||
If the event is from the parent topic, add the message to the thread.
|
||||
@ -70,16 +78,25 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC):
|
||||
If the event is from the group chat topic, add the message to the thread and select a speaker to continue the conversation.
|
||||
If the event from the group chat session requests a pause, publish the last message to the parent topic."""
|
||||
assert ctx.topic_id is not None
|
||||
group_chat_topic_id = TopicId(type=self._group_topic_type, source=ctx.topic_id.source)
|
||||
|
||||
event_logger.info(message)
|
||||
|
||||
if self._termination_condition is not None and self._termination_condition.terminated:
|
||||
# The group chat has been terminated.
|
||||
return
|
||||
|
||||
# Process event from parent.
|
||||
if ctx.topic_id.type == self._parent_topic_type:
|
||||
self._message_thread.append(message)
|
||||
await self.publish_message(
|
||||
ContentPublishEvent(agent_message=message.agent_message, source=self.id), topic_id=group_chat_topic_id
|
||||
GroupChatPublishEvent(agent_message=message.agent_message, source=self.id),
|
||||
topic_id=DefaultTopicId(type=self._group_topic_type),
|
||||
)
|
||||
if self._termination_condition is not None:
|
||||
stop_message = await self._termination_condition([message.agent_message])
|
||||
if stop_message is not None:
|
||||
event_logger.info(TerminationEvent(agent_message=stop_message, source=self.id))
|
||||
# Stop the group chat.
|
||||
return
|
||||
|
||||
# Process event from the group chat this agent manages.
|
||||
@ -91,8 +108,6 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC):
|
||||
stop_message = await self._termination_condition([message.agent_message])
|
||||
if stop_message is not None:
|
||||
event_logger.info(TerminationEvent(agent_message=stop_message, source=self.id))
|
||||
# Reset the termination condition.
|
||||
await self._termination_condition.reset()
|
||||
# Stop the group chat.
|
||||
# TODO: this should be different if the group chat is nested.
|
||||
return
|
||||
@ -100,24 +115,28 @@ class BaseGroupChatManager(SequentialRoutedAgent, ABC):
|
||||
# Select a speaker to continue the conversation.
|
||||
speaker_topic_type = await self.select_speaker(self._message_thread)
|
||||
|
||||
participant_topic_id = TopicId(type=speaker_topic_type, source=ctx.topic_id.source)
|
||||
group_chat_topic_id = TopicId(type=self._group_topic_type, source=ctx.topic_id.source)
|
||||
await self.publish_message(ContentRequestEvent(), topic_id=participant_topic_id)
|
||||
await self.publish_message(GroupChatRequestPublishEvent(), topic_id=DefaultTopicId(type=speaker_topic_type))
|
||||
|
||||
@event
|
||||
async def handle_content_request(self, message: ContentRequestEvent, ctx: MessageContext) -> None:
|
||||
async def handle_content_request(self, message: GroupChatRequestPublishEvent, ctx: MessageContext) -> None:
|
||||
"""Handle a content request by selecting a speaker to start the conversation."""
|
||||
assert ctx.topic_id is not None
|
||||
if ctx.topic_id.type == self._group_topic_type:
|
||||
raise RuntimeError("Content request event from the group chat topic is not allowed.")
|
||||
|
||||
if self._termination_condition is not None and self._termination_condition.terminated:
|
||||
# The group chat has been terminated.
|
||||
return
|
||||
|
||||
speaker_topic_type = await self.select_speaker(self._message_thread)
|
||||
|
||||
participant_topic_id = TopicId(type=speaker_topic_type, source=ctx.topic_id.source)
|
||||
await self.publish_message(ContentRequestEvent(), topic_id=participant_topic_id)
|
||||
await self.publish_message(GroupChatRequestPublishEvent(), topic_id=DefaultTopicId(type=speaker_topic_type))
|
||||
|
||||
@abstractmethod
|
||||
async def select_speaker(self, thread: List[ContentPublishEvent]) -> str:
|
||||
async def select_speaker(self, thread: List[GroupChatPublishEvent]) -> str:
|
||||
"""Select a speaker from the participants and return the
|
||||
topic type of the selected speaker."""
|
||||
...
|
||||
|
||||
async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> None:
|
||||
raise ValueError(f"Unhandled message in group chat manager: {type(message)}")
|
||||
|
||||
@ -0,0 +1,48 @@
|
||||
from typing import Any, List
|
||||
|
||||
from autogen_core.base import MessageContext
|
||||
from autogen_core.components import DefaultTopicId, event
|
||||
|
||||
from ...base import ChatAgent
|
||||
from ...messages import ChatMessage
|
||||
from .._events import GroupChatPublishEvent, GroupChatRequestPublishEvent
|
||||
from ._sequential_routed_agent import SequentialRoutedAgent
|
||||
|
||||
|
||||
class ChatAgentContainer(SequentialRoutedAgent):
|
||||
"""A core agent class that delegates message handling to an
|
||||
:class:`autogen_agentchat.base.ChatAgent` so that it can be used in a
|
||||
group chat team.
|
||||
|
||||
Args:
|
||||
parent_topic_type (str): The topic type of the parent orchestrator.
|
||||
agent (ChatAgent): The agent to delegate message handling to.
|
||||
"""
|
||||
|
||||
def __init__(self, parent_topic_type: str, agent: ChatAgent) -> None:
|
||||
super().__init__(description=agent.description)
|
||||
self._parent_topic_type = parent_topic_type
|
||||
self._agent = agent
|
||||
self._message_buffer: List[ChatMessage] = []
|
||||
|
||||
@event
|
||||
async def handle_message(self, message: GroupChatPublishEvent, ctx: MessageContext) -> None:
|
||||
"""Handle an event by appending the content to the buffer."""
|
||||
self._message_buffer.append(message.agent_message)
|
||||
|
||||
@event
|
||||
async def handle_content_request(self, message: GroupChatRequestPublishEvent, ctx: MessageContext) -> None:
|
||||
"""Handle a content request event by passing the messages in the buffer
|
||||
to the delegate agent and publish the response."""
|
||||
# Pass the messages in the buffer to the delegate agent.
|
||||
response = await self._agent.on_messages(self._message_buffer, ctx.cancellation_token)
|
||||
|
||||
# Publish the response.
|
||||
self._message_buffer.clear()
|
||||
await self.publish_message(
|
||||
GroupChatPublishEvent(agent_message=response, source=self.id),
|
||||
topic_id=DefaultTopicId(type=self._parent_topic_type),
|
||||
)
|
||||
|
||||
async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> None:
|
||||
raise ValueError(f"Unhandled message in agent container: {type(message)}")
|
||||
@ -1,10 +1,17 @@
|
||||
import logging
|
||||
from typing import Callable, List
|
||||
|
||||
from ... import EVENT_LOGGER_NAME
|
||||
from ...base import ChatAgent, TerminationCondition
|
||||
from .._events import ContentPublishEvent
|
||||
from .._events import (
|
||||
GroupChatPublishEvent,
|
||||
GroupChatSelectSpeakerEvent,
|
||||
)
|
||||
from ._base_group_chat import BaseGroupChat
|
||||
from ._base_group_chat_manager import BaseGroupChatManager
|
||||
|
||||
event_logger = logging.getLogger(EVENT_LOGGER_NAME)
|
||||
|
||||
|
||||
class RoundRobinGroupChatManager(BaseGroupChatManager):
|
||||
"""A group chat manager that selects the next speaker in a round-robin fashion."""
|
||||
@ -26,11 +33,13 @@ class RoundRobinGroupChatManager(BaseGroupChatManager):
|
||||
)
|
||||
self._next_speaker_index = 0
|
||||
|
||||
async def select_speaker(self, thread: List[ContentPublishEvent]) -> str:
|
||||
async def select_speaker(self, thread: List[GroupChatPublishEvent]) -> str:
|
||||
"""Select a speaker from the participants in a round-robin fashion."""
|
||||
current_speaker_index = self._next_speaker_index
|
||||
self._next_speaker_index = (current_speaker_index + 1) % len(self._participant_topic_types)
|
||||
return self._participant_topic_types[current_speaker_index]
|
||||
current_speaker = self._participant_topic_types[current_speaker_index]
|
||||
event_logger.debug(GroupChatSelectSpeakerEvent(selected_speaker=current_speaker, source=self.id))
|
||||
return current_speaker
|
||||
|
||||
|
||||
class RoundRobinGroupChat(BaseGroupChat):
|
||||
|
||||
@ -7,7 +7,10 @@ from autogen_core.components.models import ChatCompletionClient, SystemMessage
|
||||
from ... import EVENT_LOGGER_NAME, TRACE_LOGGER_NAME
|
||||
from ...base import ChatAgent, TerminationCondition
|
||||
from ...messages import MultiModalMessage, StopMessage, TextMessage
|
||||
from .._events import ContentPublishEvent, SelectSpeakerEvent
|
||||
from .._events import (
|
||||
GroupChatPublishEvent,
|
||||
GroupChatSelectSpeakerEvent,
|
||||
)
|
||||
from ._base_group_chat import BaseGroupChat
|
||||
from ._base_group_chat_manager import BaseGroupChatManager
|
||||
|
||||
@ -42,7 +45,7 @@ class SelectorGroupChatManager(BaseGroupChatManager):
|
||||
self._previous_speaker: str | None = None
|
||||
self._allow_repeated_speaker = allow_repeated_speaker
|
||||
|
||||
async def select_speaker(self, thread: List[ContentPublishEvent]) -> str:
|
||||
async def select_speaker(self, thread: List[GroupChatPublishEvent]) -> str:
|
||||
"""Selects the next speaker in a group chat using a ChatCompletion client.
|
||||
|
||||
A key assumption is that the agent type is the same as the topic type, which we use as the agent name.
|
||||
@ -107,7 +110,7 @@ class SelectorGroupChatManager(BaseGroupChatManager):
|
||||
else:
|
||||
agent_name = participants[0]
|
||||
self._previous_speaker = agent_name
|
||||
event_logger.debug(SelectSpeakerEvent(selected_speaker=agent_name, source=self.id))
|
||||
event_logger.debug(GroupChatSelectSpeakerEvent(selected_speaker=agent_name, source=self.id))
|
||||
return agent_name
|
||||
|
||||
def _mentioned_agents(self, message_content: str, agent_names: List[str]) -> Dict[str, int]:
|
||||
@ -148,7 +151,7 @@ class SelectorGroupChat(BaseGroupChat):
|
||||
to all, using a ChatCompletion model to select the next speaker after each message.
|
||||
|
||||
Args:
|
||||
participants (List[BaseChatAgent]): The participants in the group chat,
|
||||
participants (List[ChatAgent]): The participants in the group chat,
|
||||
must have unique names and at least two participants.
|
||||
model_client (ChatCompletionClient): The ChatCompletion model client used
|
||||
to select the next speaker.
|
||||
|
||||
@ -0,0 +1,72 @@
|
||||
import logging
|
||||
from typing import Callable, List
|
||||
|
||||
from ... import EVENT_LOGGER_NAME
|
||||
from ...base import ChatAgent, TerminationCondition
|
||||
from ...messages import HandoffMessage
|
||||
from .._events import (
|
||||
GroupChatPublishEvent,
|
||||
GroupChatSelectSpeakerEvent,
|
||||
)
|
||||
from ._base_group_chat import BaseGroupChat
|
||||
from ._base_group_chat_manager import BaseGroupChatManager
|
||||
|
||||
event_logger = logging.getLogger(EVENT_LOGGER_NAME)
|
||||
|
||||
|
||||
class SwarmGroupChatManager(BaseGroupChatManager):
|
||||
"""A group chat manager that selects the next speaker based on handoff message only."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
parent_topic_type: str,
|
||||
group_topic_type: str,
|
||||
participant_topic_types: List[str],
|
||||
participant_descriptions: List[str],
|
||||
termination_condition: TerminationCondition | None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
parent_topic_type,
|
||||
group_topic_type,
|
||||
participant_topic_types,
|
||||
participant_descriptions,
|
||||
termination_condition,
|
||||
)
|
||||
self._current_speaker = participant_topic_types[0]
|
||||
|
||||
async def select_speaker(self, thread: List[GroupChatPublishEvent]) -> str:
|
||||
"""Select a speaker from the participants based on handoff message."""
|
||||
if len(thread) > 0 and isinstance(thread[-1].agent_message, HandoffMessage):
|
||||
self._current_speaker = thread[-1].agent_message.content
|
||||
if self._current_speaker not in self._participant_topic_types:
|
||||
raise ValueError("The selected speaker in the handoff message is not a participant.")
|
||||
event_logger.debug(GroupChatSelectSpeakerEvent(selected_speaker=self._current_speaker, source=self.id))
|
||||
return self._current_speaker
|
||||
else:
|
||||
return self._current_speaker
|
||||
|
||||
|
||||
class Swarm(BaseGroupChat):
|
||||
"""(Experimental) A group chat that selects the next speaker based on handoff message only."""
|
||||
|
||||
def __init__(self, participants: List[ChatAgent]):
|
||||
super().__init__(participants, group_chat_manager_class=SwarmGroupChatManager)
|
||||
|
||||
def _create_group_chat_manager_factory(
|
||||
self,
|
||||
parent_topic_type: str,
|
||||
group_topic_type: str,
|
||||
participant_topic_types: List[str],
|
||||
participant_descriptions: List[str],
|
||||
termination_condition: TerminationCondition | None,
|
||||
) -> Callable[[], SwarmGroupChatManager]:
|
||||
def _factory() -> SwarmGroupChatManager:
|
||||
return SwarmGroupChatManager(
|
||||
parent_topic_type,
|
||||
group_topic_type,
|
||||
participant_topic_types,
|
||||
participant_descriptions,
|
||||
termination_condition,
|
||||
)
|
||||
|
||||
return _factory
|
||||
@ -13,11 +13,17 @@ from autogen_agentchat.agents import (
|
||||
ToolUseAssistantAgent,
|
||||
)
|
||||
from autogen_agentchat.logging import FileLogHandler
|
||||
from autogen_agentchat.messages import ChatMessage, StopMessage, TextMessage
|
||||
from autogen_agentchat.task import StopMessageTermination
|
||||
from autogen_agentchat.messages import (
|
||||
ChatMessage,
|
||||
HandoffMessage,
|
||||
StopMessage,
|
||||
TextMessage,
|
||||
)
|
||||
from autogen_agentchat.task import MaxMessageTermination, StopMessageTermination
|
||||
from autogen_agentchat.teams import (
|
||||
RoundRobinGroupChat,
|
||||
SelectorGroupChat,
|
||||
Swarm,
|
||||
)
|
||||
from autogen_core.base import CancellationToken
|
||||
from autogen_core.components import FunctionCall
|
||||
@ -212,7 +218,16 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch
|
||||
)
|
||||
echo_agent = _EchoAgent("echo_agent", description="echo agent")
|
||||
team = RoundRobinGroupChat(participants=[tool_use_agent, echo_agent])
|
||||
await team.run("Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination())
|
||||
result = await team.run(
|
||||
"Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination()
|
||||
)
|
||||
|
||||
assert len(result.messages) == 4
|
||||
assert isinstance(result.messages[0], TextMessage) # task
|
||||
assert isinstance(result.messages[1], TextMessage) # tool use agent response
|
||||
assert isinstance(result.messages[2], TextMessage) # echo agent response
|
||||
assert isinstance(result.messages[3], StopMessage) # tool use agent response
|
||||
|
||||
context = tool_use_agent._model_context # pyright: ignore
|
||||
assert context[0].content == "Write a program that prints 'Hello, world!'"
|
||||
assert isinstance(context[1].content, list)
|
||||
@ -393,3 +408,29 @@ async def test_selector_group_chat_two_speakers_allow_repeated(monkeypatch: pyte
|
||||
assert result.messages[1].source == "agent2"
|
||||
assert result.messages[2].source == "agent2"
|
||||
assert result.messages[3].source == "agent1"
|
||||
|
||||
|
||||
class _HandOffAgent(BaseChatAgent):
|
||||
def __init__(self, name: str, description: str, next_agent: str) -> None:
|
||||
super().__init__(name, description)
|
||||
self._next_agent = next_agent
|
||||
|
||||
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
|
||||
return HandoffMessage(content=self._next_agent, source=self.name)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_swarm() -> None:
|
||||
first_agent = _HandOffAgent("first_agent", description="first agent", next_agent="second_agent")
|
||||
second_agent = _HandOffAgent("second_agent", description="second agent", next_agent="third_agent")
|
||||
third_agent = _HandOffAgent("third_agent", description="third agent", next_agent="first_agent")
|
||||
|
||||
team = Swarm([second_agent, first_agent, third_agent])
|
||||
result = await team.run("task", termination_condition=MaxMessageTermination(6))
|
||||
assert len(result.messages) == 6
|
||||
assert result.messages[0].content == "task"
|
||||
assert result.messages[1].content == "third_agent"
|
||||
assert result.messages[2].content == "first_agent"
|
||||
assert result.messages[3].content == "second_agent"
|
||||
assert result.messages[4].content == "third_agent"
|
||||
assert result.messages[5].content == "first_agent"
|
||||
|
||||
@ -4,6 +4,7 @@ from typing import Any, AsyncGenerator, List
|
||||
|
||||
import pytest
|
||||
from autogen_agentchat.agents import ToolUseAssistantAgent
|
||||
from autogen_agentchat.messages import StopMessage, TextMessage
|
||||
from autogen_core.components.models import OpenAIChatCompletionClient
|
||||
from autogen_core.components.tools import FunctionTool
|
||||
from openai.resources.chat.completions import AsyncCompletions
|
||||
@ -103,5 +104,6 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch
|
||||
)
|
||||
result = await tool_use_agent.run("task")
|
||||
assert len(result.messages) == 3
|
||||
# assert isinstance(result.messages[1], ToolCallMessage)
|
||||
# assert isinstance(result.messages[2], TextMessage)
|
||||
assert isinstance(result.messages[0], TextMessage)
|
||||
assert isinstance(result.messages[1], TextMessage)
|
||||
assert isinstance(result.messages[2], StopMessage)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user