mirror of
https://github.com/microsoft/autogen.git
synced 2025-08-17 13:11:30 +00:00
Return message history in agentchat (#661)
* update TeamRunResult * fix line ending in test * lint * update team result to list[chatmessage] --------- Co-authored-by: Leonardo Pinheiro <lpinheiro@microsoft.com> Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
This commit is contained in:
parent
e7342d558c
commit
7fade2d5e7
@ -1,10 +1,12 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Protocol
|
from typing import List, Protocol
|
||||||
|
|
||||||
|
from autogen_agentchat.agents._base_chat_agent import ChatMessage
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TeamRunResult:
|
class TeamRunResult:
|
||||||
result: str
|
messages: List[ChatMessage]
|
||||||
|
|
||||||
|
|
||||||
class BaseTeam(Protocol):
|
class BaseTeam(Protocol):
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import asyncio
|
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Callable, List
|
from typing import Callable, List
|
||||||
|
|
||||||
|
from autogen_agentchat.agents._base_chat_agent import ChatMessage
|
||||||
from autogen_core.application import SingleThreadedAgentRuntime
|
from autogen_core.application import SingleThreadedAgentRuntime
|
||||||
from autogen_core.base import AgentId, AgentInstantiationContext, AgentRuntime, AgentType, MessageContext, TopicId
|
from autogen_core.base import AgentId, AgentInstantiationContext, AgentRuntime, AgentType, MessageContext, TopicId
|
||||||
from autogen_core.components import ClosureAgent, TypeSubscription
|
from autogen_core.components import ClosureAgent, TypeSubscription
|
||||||
@ -132,19 +132,20 @@ class RoundRobinGroupChat(BaseTeam):
|
|||||||
TypeSubscription(topic_type=team_topic_type, agent_type=group_chat_manager_agent_type.type)
|
TypeSubscription(topic_type=team_topic_type, agent_type=group_chat_manager_agent_type.type)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create a closure agent to recieve the final result.
|
group_chat_messages: List[ChatMessage] = []
|
||||||
team_messages = asyncio.Queue[ContentPublishEvent]()
|
|
||||||
|
|
||||||
async def output_result(
|
async def collect_group_chat_messages(
|
||||||
_runtime: AgentRuntime, id: AgentId, message: ContentPublishEvent, ctx: MessageContext
|
_runtime: AgentRuntime, id: AgentId, message: ContentPublishEvent, ctx: MessageContext
|
||||||
) -> None:
|
) -> None:
|
||||||
await team_messages.put(message)
|
group_chat_messages.append(message.agent_message)
|
||||||
|
|
||||||
await ClosureAgent.register(
|
await ClosureAgent.register(
|
||||||
runtime,
|
runtime,
|
||||||
type="output_result",
|
type="collect_group_chat_messages",
|
||||||
closure=output_result,
|
closure=collect_group_chat_messages,
|
||||||
subscriptions=lambda: [TypeSubscription(topic_type=team_topic_type, agent_type="output_result")],
|
subscriptions=lambda: [
|
||||||
|
TypeSubscription(topic_type=group_topic_type, agent_type="collect_group_chat_messages")
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Start the runtime.
|
# Start the runtime.
|
||||||
@ -162,14 +163,4 @@ class RoundRobinGroupChat(BaseTeam):
|
|||||||
# Wait for the runtime to stop.
|
# Wait for the runtime to stop.
|
||||||
await runtime.stop_when_idle()
|
await runtime.stop_when_idle()
|
||||||
|
|
||||||
# Get the last message from the team.
|
return TeamRunResult(messages=group_chat_messages)
|
||||||
last_message = None
|
|
||||||
while not team_messages.empty():
|
|
||||||
last_message = await team_messages.get()
|
|
||||||
|
|
||||||
assert (
|
|
||||||
last_message is not None
|
|
||||||
and isinstance(last_message.agent_message, TextMessage)
|
|
||||||
and isinstance(last_message.agent_message.content, str)
|
|
||||||
)
|
|
||||||
return TeamRunResult(last_message.agent_message.content)
|
|
||||||
|
@ -95,7 +95,20 @@ async def test_round_robin_group_chat(monkeypatch: pytest.MonkeyPatch) -> None:
|
|||||||
)
|
)
|
||||||
team = RoundRobinGroupChat(participants=[coding_assistant_agent, code_executor_agent])
|
team = RoundRobinGroupChat(participants=[coding_assistant_agent, code_executor_agent])
|
||||||
result = await team.run("Write a program that prints 'Hello, world!'")
|
result = await team.run("Write a program that prints 'Hello, world!'")
|
||||||
assert result.result == "TERMINATE"
|
expected_messages = [
|
||||||
|
"Write a program that prints 'Hello, world!'",
|
||||||
|
'Here is the program\n ```python\nprint("Hello, world!")\n```',
|
||||||
|
"Hello, world!",
|
||||||
|
"TERMINATE",
|
||||||
|
]
|
||||||
|
# Normalize the messages to remove \r\n and any leading/trailing whitespace.
|
||||||
|
normalized_messages = [
|
||||||
|
msg.content.replace("\r\n", "\n").rstrip("\n") if isinstance(msg.content, str) else msg.content
|
||||||
|
for msg in result.messages
|
||||||
|
]
|
||||||
|
|
||||||
|
# Assert that all expected messages are in the collected messages
|
||||||
|
assert normalized_messages == expected_messages
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
Loading…
x
Reference in New Issue
Block a user