Formalize ChatAgent response as a dataclass with inner messages (#3990)

Eric Zhu 2024-10-30 10:27:57 -07:00 committed by GitHub
parent e63fd17ed5
commit 3d51ab76ae
13 changed files with 157 additions and 72 deletions
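In short, this commit makes ChatAgent.on_messages return a Response dataclass, pairing the final chat_message with optional inner_messages (tool calls and their results), instead of a bare ChatMessage. A minimal sketch of the new pattern, assembled from the diffs below; the EchoAgent class is hypothetical and only illustrates the shape of the API:

    from typing import List, Sequence

    from autogen_agentchat.agents import BaseChatAgent
    from autogen_agentchat.base import Response
    from autogen_agentchat.messages import ChatMessage, TextMessage
    from autogen_core.base import CancellationToken


    class EchoAgent(BaseChatAgent):
        """Hypothetical agent showing the new Response return type."""

        @property
        def produced_message_types(self) -> List[type[ChatMessage]]:
            return [TextMessage]

        async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
            # The final reply goes in chat_message; tool-call traffic would go in inner_messages.
            return Response(
                chat_message=TextMessage(content=f"echo: {messages[-1].content}", source=self.name),
                inner_messages=None,
            )

Teams and the single-agent run() path both consume this Response, forwarding inner_messages to the output stream so tool activity shows up in TaskResult.messages.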

View File

@@ -18,12 +18,16 @@ from autogen_core.components.tools import FunctionTool, Tool
 from pydantic import BaseModel, ConfigDict, Field, model_validator
 
 from .. import EVENT_LOGGER_NAME
+from ..base import Response
 from ..messages import (
     ChatMessage,
     HandoffMessage,
+    InnerMessage,
     ResetMessage,
     StopMessage,
     TextMessage,
+    ToolCallMessage,
+    ToolCallResultMessages,
 )
 from ._base_chat_agent import BaseChatAgent
@@ -214,7 +218,7 @@ class AssistantAgent(BaseChatAgent):
             return [TextMessage, HandoffMessage, StopMessage]
         return [TextMessage, StopMessage]
 
-    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
+    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
         # Add messages to the model context.
         for msg in messages:
             if isinstance(msg, ResetMessage):
@@ -222,6 +226,9 @@ class AssistantAgent(BaseChatAgent):
             else:
                 self._model_context.append(UserMessage(content=msg.content, source=msg.source))
 
+        # Inner messages.
+        inner_messages: List[InnerMessage] = []
+
         # Generate an inference result based on the current model context.
         llm_messages = self._system_messages + self._model_context
         result = await self._model_client.create(
@@ -234,12 +241,16 @@
         # Run tool calls until the model produces a string response.
         while isinstance(result.content, list) and all(isinstance(item, FunctionCall) for item in result.content):
             event_logger.debug(ToolCallEvent(tool_calls=result.content, source=self.name))
+            # Add the tool call message to the output.
+            inner_messages.append(ToolCallMessage(content=result.content, source=self.name))
+
             # Execute the tool calls.
             results = await asyncio.gather(
                 *[self._execute_tool_call(call, cancellation_token) for call in result.content]
             )
             event_logger.debug(ToolCallResultEvent(tool_call_results=results, source=self.name))
             self._model_context.append(FunctionExecutionResultMessage(content=results))
+            inner_messages.append(ToolCallResultMessages(content=results, source=self.name))
 
             # Detect handoff requests.
             handoffs: List[Handoff] = []
@@ -249,8 +260,13 @@ class AssistantAgent(BaseChatAgent):
             if len(handoffs) > 0:
                 if len(handoffs) > 1:
                     raise ValueError(f"Multiple handoffs detected: {[handoff.name for handoff in handoffs]}")
-                # Respond with a handoff message.
-                return HandoffMessage(content=handoffs[0].message, target=handoffs[0].target, source=self.name)
+                # Return the output messages to signal the handoff.
+                return Response(
+                    chat_message=HandoffMessage(
+                        content=handoffs[0].message, target=handoffs[0].target, source=self.name
+                    ),
+                    inner_messages=inner_messages,
+                )
 
             # Generate an inference result based on the current model context.
             result = await self._model_client.create(
@@ -262,9 +278,13 @@
         # Detect stop request.
         request_stop = "terminate" in result.content.strip().lower()
         if request_stop:
-            return StopMessage(content=result.content, source=self.name)
+            return Response(
+                chat_message=StopMessage(content=result.content, source=self.name), inner_messages=inner_messages
+            )
 
-        return TextMessage(content=result.content, source=self.name)
+        return Response(
+            chat_message=TextMessage(content=result.content, source=self.name), inner_messages=inner_messages
+        )
 
     async def _execute_tool_call(
         self, tool_call: FunctionCall, cancellation_token: CancellationToken

View File

@@ -3,9 +3,8 @@ from typing import List, Sequence
 
 from autogen_core.base import CancellationToken
 
-from ..base import ChatAgent, TaskResult, TerminationCondition
-from ..messages import ChatMessage
-from ..teams import RoundRobinGroupChat
+from ..base import ChatAgent, Response, TaskResult, TerminationCondition
+from ..messages import ChatMessage, InnerMessage, TextMessage
 
 
 class BaseChatAgent(ChatAgent, ABC):
@@ -37,8 +36,8 @@ class BaseChatAgent(ChatAgent, ABC):
         ...
 
     @abstractmethod
-    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
-        """Handle incoming messages and return a response message."""
+    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
+        """Handles incoming messages and returns a response."""
         ...
 
     async def run(
@@ -49,10 +48,12 @@ class BaseChatAgent(ChatAgent, ABC):
         termination_condition: TerminationCondition | None = None,
     ) -> TaskResult:
         """Run the agent with the given task and return the result."""
-        group_chat = RoundRobinGroupChat(participants=[self])
-        result = await group_chat.run(
-            task=task,
-            cancellation_token=cancellation_token,
-            termination_condition=termination_condition,
-        )
-        return result
+        if cancellation_token is None:
+            cancellation_token = CancellationToken()
+        first_message = TextMessage(content=task, source="user")
+        response = await self.on_messages([first_message], cancellation_token)
+        messages: List[InnerMessage | ChatMessage] = [first_message]
+        if response.inner_messages is not None:
+            messages += response.inner_messages
+        messages.append(response.chat_message)
+        return TaskResult(messages=messages)
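With this change, BaseChatAgent.run() drives on_messages directly instead of wrapping the agent in a RoundRobinGroupChat, so a single agent can be run standalone and its TaskResult carries the task message, any inner messages, and the final reply. A rough usage sketch, assuming an AssistantAgent backed by OpenAIChatCompletionClient as in the tests; the model name and credential handling are placeholders:

    import asyncio

    from autogen_agentchat.agents import AssistantAgent
    from autogen_ext.models import OpenAIChatCompletionClient


    async def main() -> None:
        # Placeholder model/credentials; configure the client for your environment.
        model_client = OpenAIChatCompletionClient(model="gpt-4o")
        agent = AssistantAgent("assistant", model_client=model_client)
        result = await agent.run("Say hello and then terminate.")
        # result.messages interleaves the user task, inner tool-call messages, and the final chat message.
        for message in result.messages:
            print(type(message).__name__, getattr(message, "content", message))


    asyncio.run(main())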

View File

@@ -3,6 +3,7 @@ from typing import List, Sequence
 from autogen_core.base import CancellationToken
 from autogen_core.components.code_executor import CodeBlock, CodeExecutor, extract_markdown_code_blocks
 
+from ..base import Response
 from ..messages import ChatMessage, TextMessage
 from ._base_chat_agent import BaseChatAgent
@@ -25,7 +26,7 @@ class CodeExecutorAgent(BaseChatAgent):
         """The types of messages that the code executor agent produces."""
         return [TextMessage]
 
-    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
+    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
         # Extract code blocks from the messages.
         code_blocks: List[CodeBlock] = []
         for msg in messages:
@@ -34,6 +35,6 @@ class CodeExecutorAgent(BaseChatAgent):
         if code_blocks:
             # Execute the code blocks.
             result = await self._code_executor.execute_code_blocks(code_blocks, cancellation_token=cancellation_token)
-            return TextMessage(content=result.output, source=self.name)
+            return Response(chat_message=TextMessage(content=result.output, source=self.name))
         else:
-            return TextMessage(content="No code blocks found in the thread.", source=self.name)
+            return Response(chat_message=TextMessage(content="No code blocks found in the thread.", source=self.name))

View File

@@ -1,10 +1,11 @@
-from ._chat_agent import ChatAgent
+from ._chat_agent import ChatAgent, Response
 from ._task import TaskResult, TaskRunner
 from ._team import Team
 from ._termination import TerminatedException, TerminationCondition
 
 __all__ = [
     "ChatAgent",
+    "Response",
     "Team",
     "TerminatedException",
     "TerminationCondition",

View File

@@ -1,12 +1,24 @@
+from dataclasses import dataclass
 from typing import List, Protocol, Sequence, runtime_checkable
 
 from autogen_core.base import CancellationToken
 
-from ..messages import ChatMessage
+from ..messages import ChatMessage, InnerMessage
 from ._task import TaskResult, TaskRunner
 from ._termination import TerminationCondition
 
 
+@dataclass(kw_only=True)
+class Response:
+    """A response from calling :meth:`ChatAgent.on_messages`."""
+
+    chat_message: ChatMessage
+    """A chat message produced by the agent as the response."""
+
+    inner_messages: List[InnerMessage] | None = None
+    """Inner messages produced by the agent."""
+
+
 @runtime_checkable
 class ChatAgent(TaskRunner, Protocol):
     """Protocol for a chat agent."""
@@ -29,8 +41,8 @@ class ChatAgent(TaskRunner, Protocol):
         """The types of messages that the agent produces."""
         ...
 
-    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
-        """Handle incoming messages and return a response message."""
+    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
+        """Handles incoming messages and returns a response."""
         ...
 
     async def run(

View File

@@ -3,7 +3,7 @@ from typing import Protocol, Sequence
 
 from autogen_core.base import CancellationToken
 
-from ..messages import ChatMessage
+from ..messages import ChatMessage, InnerMessage
 from ._termination import TerminationCondition
 
@@ -11,7 +11,7 @@ from ._termination import TerminationCondition
 class TaskResult:
     """Result of running a task."""
 
-    messages: Sequence[ChatMessage]
+    messages: Sequence[InnerMessage | ChatMessage]
     """Messages produced by the task."""

View File

@@ -1,6 +1,7 @@
 from typing import List
 
-from autogen_core.components import Image
+from autogen_core.components import FunctionCall, Image
+from autogen_core.components.models import FunctionExecutionResult
 from pydantic import BaseModel
@@ -49,8 +50,26 @@
     """The content for the reset message."""
 
 
+class ToolCallMessage(BaseMessage):
+    """A message signaling the use of tools."""
+
+    content: List[FunctionCall]
+    """The tool calls."""
+
+
+class ToolCallResultMessages(BaseMessage):
+    """A message signaling the results of tool calls."""
+
+    content: List[FunctionExecutionResult]
+    """The tool call results."""
+
+
+InnerMessage = ToolCallMessage | ToolCallResultMessages
+"""Messages for intra-agent monologues."""
+
+
 ChatMessage = TextMessage | MultiModalMessage | StopMessage | HandoffMessage | ResetMessage
-"""A message used by agents in a team."""
+"""Messages for agent-to-agent communication."""
 
 
 __all__ = [
@@ -60,5 +79,7 @@ __all__ = [
     "StopMessage",
     "HandoffMessage",
     "ResetMessage",
+    "ToolCallMessage",
+    "ToolCallResultMessages",
     "ChatMessage",
 ]
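The new ToolCallMessage / ToolCallResultMessages types, grouped under the InnerMessage union, let downstream code distinguish an agent's internal tool traffic from ordinary chat. A small sketch of filtering a result's messages on that basis; the summarize helper is illustrative only, not part of the library:

    from autogen_agentchat.messages import ToolCallMessage, ToolCallResultMessages


    def summarize(messages) -> None:
        """Illustrative helper: separate inner tool traffic from agent-to-agent chat."""
        for msg in messages:
            if isinstance(msg, (ToolCallMessage, ToolCallResultMessages)):
                print(f"[inner] {type(msg).__name__} from {msg.source}")
            else:
                print(f"[chat]  {msg.source}: {msg.content}")

    # e.g. summarize(result.messages) after result = await team.run("task", ...)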

View File

@@ -15,7 +15,7 @@ from autogen_core.base import (
 from autogen_core.components import ClosureAgent, TypeSubscription
 
 from ...base import ChatAgent, TaskResult, Team, TerminationCondition
-from ...messages import ChatMessage, TextMessage
+from ...messages import ChatMessage, InnerMessage, TextMessage
 from .._events import GroupChatPublishEvent, GroupChatRequestPublishEvent
 from ._base_group_chat_manager import BaseGroupChatManager
 from ._chat_agent_container import ChatAgentContainer
@@ -56,12 +56,13 @@ class BaseGroupChat(Team, ABC):
     def _create_participant_factory(
         self,
         parent_topic_type: str,
+        output_topic_type: str,
         agent: ChatAgent,
     ) -> Callable[[], ChatAgentContainer]:
         def _factory() -> ChatAgentContainer:
             id = AgentInstantiationContext.current_agent_id()
             assert id == AgentId(type=agent.name, key=self._team_id)
-            container = ChatAgentContainer(parent_topic_type, agent)
+            container = ChatAgentContainer(parent_topic_type, output_topic_type, agent)
             assert container.id == id
             return container
@@ -85,6 +86,7 @@ class BaseGroupChat(Team, ABC):
         group_chat_manager_topic_type = group_chat_manager_agent_type.type
         group_topic_type = "round_robin_group_topic"
         team_topic_type = "team_topic"
+        output_topic_type = "output_topic"
 
         # Register participants.
         participant_topic_types: List[str] = []
@@ -97,7 +99,7 @@
             await ChatAgentContainer.register(
                 runtime,
                 type=agent_type,
-                factory=self._create_participant_factory(group_topic_type, participant),
+                factory=self._create_participant_factory(group_topic_type, output_topic_type, participant),
             )
             # Add subscriptions for the participant.
             await runtime.add_subscription(TypeSubscription(topic_type=topic_type, agent_type=agent_type))
@@ -129,22 +131,22 @@
             TypeSubscription(topic_type=team_topic_type, agent_type=group_chat_manager_agent_type.type)
         )
 
-        group_chat_messages: List[ChatMessage] = []
+        output_messages: List[InnerMessage | ChatMessage] = []
 
-        async def collect_group_chat_messages(
+        async def collect_output_messages(
            _runtime: AgentRuntime,
            id: AgentId,
-            message: GroupChatPublishEvent,
+            message: InnerMessage | ChatMessage,
            ctx: MessageContext,
         ) -> None:
-            group_chat_messages.append(message.agent_message)
+            output_messages.append(message)
 
         await ClosureAgent.register(
             runtime,
-            type="collect_group_chat_messages",
-            closure=collect_group_chat_messages,
+            type="collect_output_messages",
+            closure=collect_output_messages,
             subscriptions=lambda: [
-                TypeSubscription(topic_type=group_topic_type, agent_type="collect_group_chat_messages"),
+                TypeSubscription(topic_type=output_topic_type, agent_type="collect_output_messages"),
             ],
         )
@@ -154,8 +156,10 @@
         # Run the team by publishing the task to the team topic and then requesting the result.
         team_topic_id = TopicId(type=team_topic_type, source=self._team_id)
         group_chat_manager_topic_id = TopicId(type=group_chat_manager_topic_type, source=self._team_id)
+        first_chat_message = TextMessage(content=task, source="user")
+        output_messages.append(first_chat_message)
         await runtime.publish_message(
-            GroupChatPublishEvent(agent_message=TextMessage(content=task, source="user")),
+            GroupChatPublishEvent(agent_message=first_chat_message),
             topic_id=team_topic_id,
         )
         await runtime.publish_message(GroupChatRequestPublishEvent(), topic_id=group_chat_manager_topic_id)
@@ -164,4 +168,4 @@
         await runtime.stop_when_idle()
 
         # Return the result.
-        return TaskResult(messages=group_chat_messages)
+        return TaskResult(messages=output_messages)

View File

@@ -16,12 +16,14 @@ class ChatAgentContainer(SequentialRoutedAgent):
 
     Args:
         parent_topic_type (str): The topic type of the parent orchestrator.
+        output_topic_type (str): The topic type for the output.
         agent (ChatAgent): The agent to delegate message handling to.
     """
 
-    def __init__(self, parent_topic_type: str, agent: ChatAgent) -> None:
+    def __init__(self, parent_topic_type: str, output_topic_type: str, agent: ChatAgent) -> None:
         super().__init__(description=agent.description)
         self._parent_topic_type = parent_topic_type
+        self._output_topic_type = output_topic_type
         self._agent = agent
         self._message_buffer: List[ChatMessage] = []
@@ -36,18 +38,27 @@ class ChatAgentContainer(SequentialRoutedAgent):
         to the delegate agent and publish the response."""
         # Pass the messages in the buffer to the delegate agent.
         response = await self._agent.on_messages(self._message_buffer, ctx.cancellation_token)
-        if not any(isinstance(response, msg_type) for msg_type in self._agent.produced_message_types):
+        if not any(isinstance(response.chat_message, msg_type) for msg_type in self._agent.produced_message_types):
             raise ValueError(
                 f"The agent {self._agent.name} produced an unexpected message type: {type(response)}. "
-                f"Expected one of: {self._agent.produced_message_types}"
+                f"Expected one of: {self._agent.produced_message_types}. "
+                f"Check the agent's produced_message_types property."
             )
 
+        # Publish inner messages to the output topic.
+        if response.inner_messages is not None:
+            for inner_message in response.inner_messages:
+                await self.publish_message(inner_message, topic_id=DefaultTopicId(type=self._output_topic_type))
+
         # Publish the response.
         self._message_buffer.clear()
         await self.publish_message(
-            GroupChatPublishEvent(agent_message=response, source=self.id),
+            GroupChatPublishEvent(agent_message=response.chat_message, source=self.id),
             topic_id=DefaultTopicId(type=self._parent_topic_type),
         )
 
+        # Publish the response to the output topic.
+        await self.publish_message(response.chat_message, topic_id=DefaultTopicId(type=self._output_topic_type))
+
     async def on_unhandled_message(self, message: Any, ctx: MessageContext) -> None:
         raise ValueError(f"Unhandled message in agent container: {type(message)}")

View File

@@ -7,7 +7,7 @@ import pytest
 from autogen_agentchat import EVENT_LOGGER_NAME
 from autogen_agentchat.agents import AssistantAgent, Handoff
 from autogen_agentchat.logging import FileLogHandler
-from autogen_agentchat.messages import HandoffMessage, StopMessage, TextMessage
+from autogen_agentchat.messages import HandoffMessage, TextMessage, ToolCallMessage, ToolCallResultMessages
 from autogen_core.base import CancellationToken
 from autogen_core.components.tools import FunctionTool
 from autogen_ext.models import OpenAIChatCompletionClient
@@ -111,10 +111,11 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
         tools=[_pass_function, _fail_function, FunctionTool(_echo_function, description="Echo")],
     )
     result = await tool_use_agent.run("task")
-    assert len(result.messages) == 3
+    assert len(result.messages) == 4
     assert isinstance(result.messages[0], TextMessage)
-    assert isinstance(result.messages[1], TextMessage)
-    assert isinstance(result.messages[2], StopMessage)
+    assert isinstance(result.messages[1], ToolCallMessage)
+    assert isinstance(result.messages[2], ToolCallResultMessages)
+    assert isinstance(result.messages[3], TextMessage)
 
 
 @pytest.mark.asyncio
@@ -162,5 +163,5 @@ async def test_handoffs(monkeypatch: pytest.MonkeyPatch) -> None:
     response = await tool_use_agent.on_messages(
        [TextMessage(content="task", source="user")], cancellation_token=CancellationToken()
     )
-    assert isinstance(response, HandoffMessage)
-    assert response.target == "agent2"
+    assert isinstance(response.chat_message, HandoffMessage)
+    assert response.chat_message.target == "agent2"

View File

@@ -12,12 +12,15 @@ from autogen_agentchat.agents import (
     CodeExecutorAgent,
     Handoff,
 )
+from autogen_agentchat.base import Response
 from autogen_agentchat.logging import FileLogHandler
 from autogen_agentchat.messages import (
     ChatMessage,
     HandoffMessage,
     StopMessage,
     TextMessage,
+    ToolCallMessage,
+    ToolCallResultMessages,
 )
 from autogen_agentchat.task import MaxMessageTermination, StopMessageTermination
 from autogen_agentchat.teams import (
@@ -66,14 +69,14 @@ class _EchoAgent(BaseChatAgent):
     def produced_message_types(self) -> List[type[ChatMessage]]:
         return [TextMessage]
 
-    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
+    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
         if len(messages) > 0:
             assert isinstance(messages[0], TextMessage)
             self._last_message = messages[0].content
-            return TextMessage(content=messages[0].content, source=self.name)
+            return Response(chat_message=TextMessage(content=messages[0].content, source=self.name))
         else:
             assert self._last_message is not None
-            return TextMessage(content=self._last_message, source=self.name)
+            return Response(chat_message=TextMessage(content=self._last_message, source=self.name))
 
 
 class _StopAgent(_EchoAgent):
@@ -86,11 +89,11 @@ class _StopAgent(_EchoAgent):
     def produced_message_types(self) -> List[type[ChatMessage]]:
         return [TextMessage, StopMessage]
 
-    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
+    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
         self._count += 1
         if self._count < self._stop_at:
             return await super().on_messages(messages, cancellation_token)
-        return StopMessage(content="TERMINATE", source=self.name)
+        return Response(chat_message=StopMessage(content="TERMINATE", source=self.name))
 
 
 def _pass_function(input: str) -> str:
@@ -230,11 +233,13 @@ async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch
         "Write a program that prints 'Hello, world!'", termination_condition=StopMessageTermination()
     )
-    assert len(result.messages) == 4
+    assert len(result.messages) == 6
     assert isinstance(result.messages[0], TextMessage)  # task
-    assert isinstance(result.messages[1], TextMessage)  # tool use agent response
-    assert isinstance(result.messages[2], TextMessage)  # echo agent response
-    assert isinstance(result.messages[3], StopMessage)  # tool use agent response
+    assert isinstance(result.messages[1], ToolCallMessage)  # tool call
+    assert isinstance(result.messages[2], ToolCallResultMessages)  # tool call result
+    assert isinstance(result.messages[3], TextMessage)  # tool use agent response
+    assert isinstance(result.messages[4], TextMessage)  # echo agent response
+    assert isinstance(result.messages[5], StopMessage)  # tool use agent response
 
     context = tool_use_agent._model_context  # pyright: ignore
     assert context[0].content == "Write a program that prints 'Hello, world!'"
@@ -427,8 +432,12 @@ class _HandOffAgent(BaseChatAgent):
     def produced_message_types(self) -> List[type[ChatMessage]]:
         return [HandoffMessage]
 
-    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:
-        return HandoffMessage(content=f"Transferred to {self._next_agent}.", target=self._next_agent, source=self.name)
+    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
+        return Response(
+            chat_message=HandoffMessage(
+                content=f"Transferred to {self._next_agent}.", target=self._next_agent, source=self.name
+            )
+        )
 
 
 @pytest.mark.asyncio
@@ -513,9 +522,11 @@ async def test_swarm_handoff_using_tool_calls(monkeypatch: pytest.MonkeyPatch) -
     agent2 = _HandOffAgent("agent2", description="agent 2", next_agent="agent1")
     team = Swarm([agnet1, agent2])
     result = await team.run("task", termination_condition=StopMessageTermination())
-    assert len(result.messages) == 5
+    assert len(result.messages) == 7
     assert result.messages[0].content == "task"
-    assert result.messages[1].content == "handoff to agent2"
-    assert result.messages[2].content == "Transferred to agent1."
-    assert result.messages[3].content == "Hello"
-    assert result.messages[4].content == "TERMINATE"
+    assert isinstance(result.messages[1], ToolCallMessage)
+    assert isinstance(result.messages[2], ToolCallResultMessages)
+    assert result.messages[3].content == "handoff to agent2"
+    assert result.messages[4].content == "Transferred to agent1."
+    assert result.messages[5].content == "Hello"
+    assert result.messages[6].content == "TERMINATE"

View File

@@ -251,6 +251,7 @@
     "from typing import List, Sequence\n",
     "\n",
     "from autogen_agentchat.agents import BaseChatAgent\n",
+    "from autogen_agentchat.base import Response\n",
     "from autogen_agentchat.messages import (\n",
     "    ChatMessage,\n",
     "    StopMessage,\n",
@@ -266,11 +267,11 @@
     "    def produced_message_types(self) -> List[type[ChatMessage]]:\n",
     "        return [TextMessage, StopMessage]\n",
     "\n",
-    "    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:\n",
+    "    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:\n",
     "        user_input = await asyncio.get_event_loop().run_in_executor(None, input, \"Enter your response: \")\n",
     "        if \"TERMINATE\" in user_input:\n",
-    "            return StopMessage(content=\"User has terminated the conversation.\", source=self.name)\n",
-    "        return TextMessage(content=user_input, source=self.name)\n",
+    "            return Response(chat_message=StopMessage(content=\"User has terminated the conversation.\", source=self.name))\n",
+    "        return Response(chat_message=TextMessage(content=user_input, source=self.name))\n",
     "\n",
     "\n",
     "user_proxy_agent = UserProxyAgent(name=\"user_proxy_agent\")\n",

View File

@@ -45,6 +45,7 @@
     "    CodingAssistantAgent,\n",
     "    ToolUseAssistantAgent,\n",
     ")\n",
+    "from autogen_agentchat.base import Response\n",
     "from autogen_agentchat.messages import ChatMessage, StopMessage, TextMessage\n",
     "from autogen_agentchat.task import StopMessageTermination\n",
     "from autogen_agentchat.teams import SelectorGroupChat\n",
@@ -75,11 +76,11 @@
     "    def produced_message_types(self) -> List[type[ChatMessage]]:\n",
     "        return [TextMessage, StopMessage]\n",
     "\n",
-    "    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> ChatMessage:\n",
+    "    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:\n",
     "        user_input = await asyncio.get_event_loop().run_in_executor(None, input, \"Enter your response: \")\n",
     "        if \"TERMINATE\" in user_input:\n",
-    "            return StopMessage(content=\"User has terminated the conversation.\", source=self.name)\n",
-    "        return TextMessage(content=user_input, source=self.name)"
+    "            return Response(chat_message=StopMessage(content=\"User has terminated the conversation.\", source=self.name))\n",
+    "        return Response(chat_message=TextMessage(content=user_input, source=self.name))"
    ]
   },
   {