From 7fade2d5e7108e32847173f457b8cc101193714e Mon Sep 17 00:00:00 2001 From: Leonardo Pinheiro Date: Tue, 1 Oct 2024 10:03:20 +1000 Subject: [PATCH] Return message history in agentchat (#661) * update TeamRunResult * fix line ending in test * lint * update team result to list[chatmessage] --------- Co-authored-by: Leonardo Pinheiro Co-authored-by: Eric Zhu --- .../src/autogen_agentchat/teams/_base_team.py | 6 ++-- .../group_chat/_round_robin_group_chat.py | 29 +++++++------------ .../tests/test_group_chat.py | 15 +++++++++- 3 files changed, 28 insertions(+), 22 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_base_team.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_base_team.py index 8653372d1..4ef19eb3d 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_base_team.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_base_team.py @@ -1,10 +1,12 @@ from dataclasses import dataclass -from typing import Protocol +from typing import List, Protocol + +from autogen_agentchat.agents._base_chat_agent import ChatMessage @dataclass class TeamRunResult: - result: str + messages: List[ChatMessage] class BaseTeam(Protocol): diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/group_chat/_round_robin_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/group_chat/_round_robin_group_chat.py index e1d50a6cb..d2f21d8ef 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/group_chat/_round_robin_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/group_chat/_round_robin_group_chat.py @@ -1,7 +1,7 @@ -import asyncio import uuid from typing import Callable, List +from autogen_agentchat.agents._base_chat_agent import ChatMessage from autogen_core.application import SingleThreadedAgentRuntime from autogen_core.base import AgentId, AgentInstantiationContext, AgentRuntime, AgentType, MessageContext, TopicId from autogen_core.components import ClosureAgent, TypeSubscription @@ -132,19 +132,20 @@ class RoundRobinGroupChat(BaseTeam): TypeSubscription(topic_type=team_topic_type, agent_type=group_chat_manager_agent_type.type) ) - # Create a closure agent to recieve the final result. - team_messages = asyncio.Queue[ContentPublishEvent]() + group_chat_messages: List[ChatMessage] = [] - async def output_result( + async def collect_group_chat_messages( _runtime: AgentRuntime, id: AgentId, message: ContentPublishEvent, ctx: MessageContext ) -> None: - await team_messages.put(message) + group_chat_messages.append(message.agent_message) await ClosureAgent.register( runtime, - type="output_result", - closure=output_result, - subscriptions=lambda: [TypeSubscription(topic_type=team_topic_type, agent_type="output_result")], + type="collect_group_chat_messages", + closure=collect_group_chat_messages, + subscriptions=lambda: [ + TypeSubscription(topic_type=group_topic_type, agent_type="collect_group_chat_messages") + ], ) # Start the runtime. @@ -162,14 +163,4 @@ class RoundRobinGroupChat(BaseTeam): # Wait for the runtime to stop. await runtime.stop_when_idle() - # Get the last message from the team. - last_message = None - while not team_messages.empty(): - last_message = await team_messages.get() - - assert ( - last_message is not None - and isinstance(last_message.agent_message, TextMessage) - and isinstance(last_message.agent_message.content, str) - ) - return TeamRunResult(last_message.agent_message.content) + return TeamRunResult(messages=group_chat_messages) diff --git a/python/packages/autogen-agentchat/tests/test_group_chat.py b/python/packages/autogen-agentchat/tests/test_group_chat.py index 2008f2c26..f728d5736 100644 --- a/python/packages/autogen-agentchat/tests/test_group_chat.py +++ b/python/packages/autogen-agentchat/tests/test_group_chat.py @@ -95,7 +95,20 @@ async def test_round_robin_group_chat(monkeypatch: pytest.MonkeyPatch) -> None: ) team = RoundRobinGroupChat(participants=[coding_assistant_agent, code_executor_agent]) result = await team.run("Write a program that prints 'Hello, world!'") - assert result.result == "TERMINATE" + expected_messages = [ + "Write a program that prints 'Hello, world!'", + 'Here is the program\n ```python\nprint("Hello, world!")\n```', + "Hello, world!", + "TERMINATE", + ] + # Normalize the messages to remove \r\n and any leading/trailing whitespace. + normalized_messages = [ + msg.content.replace("\r\n", "\n").rstrip("\n") if isinstance(msg.content, str) else msg.content + for msg in result.messages + ] + + # Assert that all expected messages are in the collected messages + assert normalized_messages == expected_messages @pytest.mark.asyncio