Return message history in agentchat (#661)

* update TeamRunResult

* fix line ending in test

* lint

* update team result to list[chatmessage]

---------

Co-authored-by: Leonardo Pinheiro <lpinheiro@microsoft.com>
Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
This commit is contained in:
Leonardo Pinheiro 2024-10-01 10:03:20 +10:00 committed by Jack Gerrits
parent e7342d558c
commit 7fade2d5e7
3 changed files with 28 additions and 22 deletions

View File

@ -1,10 +1,12 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import Protocol from typing import List, Protocol
from autogen_agentchat.agents._base_chat_agent import ChatMessage
@dataclass @dataclass
class TeamRunResult: class TeamRunResult:
result: str messages: List[ChatMessage]
class BaseTeam(Protocol): class BaseTeam(Protocol):

View File

@ -1,7 +1,7 @@
import asyncio
import uuid import uuid
from typing import Callable, List from typing import Callable, List
from autogen_agentchat.agents._base_chat_agent import ChatMessage
from autogen_core.application import SingleThreadedAgentRuntime from autogen_core.application import SingleThreadedAgentRuntime
from autogen_core.base import AgentId, AgentInstantiationContext, AgentRuntime, AgentType, MessageContext, TopicId from autogen_core.base import AgentId, AgentInstantiationContext, AgentRuntime, AgentType, MessageContext, TopicId
from autogen_core.components import ClosureAgent, TypeSubscription from autogen_core.components import ClosureAgent, TypeSubscription
@ -132,19 +132,20 @@ class RoundRobinGroupChat(BaseTeam):
TypeSubscription(topic_type=team_topic_type, agent_type=group_chat_manager_agent_type.type) TypeSubscription(topic_type=team_topic_type, agent_type=group_chat_manager_agent_type.type)
) )
# Create a closure agent to recieve the final result. group_chat_messages: List[ChatMessage] = []
team_messages = asyncio.Queue[ContentPublishEvent]()
async def output_result( async def collect_group_chat_messages(
_runtime: AgentRuntime, id: AgentId, message: ContentPublishEvent, ctx: MessageContext _runtime: AgentRuntime, id: AgentId, message: ContentPublishEvent, ctx: MessageContext
) -> None: ) -> None:
await team_messages.put(message) group_chat_messages.append(message.agent_message)
await ClosureAgent.register( await ClosureAgent.register(
runtime, runtime,
type="output_result", type="collect_group_chat_messages",
closure=output_result, closure=collect_group_chat_messages,
subscriptions=lambda: [TypeSubscription(topic_type=team_topic_type, agent_type="output_result")], subscriptions=lambda: [
TypeSubscription(topic_type=group_topic_type, agent_type="collect_group_chat_messages")
],
) )
# Start the runtime. # Start the runtime.
@ -162,14 +163,4 @@ class RoundRobinGroupChat(BaseTeam):
# Wait for the runtime to stop. # Wait for the runtime to stop.
await runtime.stop_when_idle() await runtime.stop_when_idle()
# Get the last message from the team. return TeamRunResult(messages=group_chat_messages)
last_message = None
while not team_messages.empty():
last_message = await team_messages.get()
assert (
last_message is not None
and isinstance(last_message.agent_message, TextMessage)
and isinstance(last_message.agent_message.content, str)
)
return TeamRunResult(last_message.agent_message.content)

View File

@ -95,7 +95,20 @@ async def test_round_robin_group_chat(monkeypatch: pytest.MonkeyPatch) -> None:
) )
team = RoundRobinGroupChat(participants=[coding_assistant_agent, code_executor_agent]) team = RoundRobinGroupChat(participants=[coding_assistant_agent, code_executor_agent])
result = await team.run("Write a program that prints 'Hello, world!'") result = await team.run("Write a program that prints 'Hello, world!'")
assert result.result == "TERMINATE" expected_messages = [
"Write a program that prints 'Hello, world!'",
'Here is the program\n ```python\nprint("Hello, world!")\n```',
"Hello, world!",
"TERMINATE",
]
# Normalize the messages to remove \r\n and any leading/trailing whitespace.
normalized_messages = [
msg.content.replace("\r\n", "\n").rstrip("\n") if isinstance(msg.content, str) else msg.content
for msg in result.messages
]
# Assert that all expected messages are in the collected messages
assert normalized_messages == expected_messages
@pytest.mark.asyncio @pytest.mark.asyncio