diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py b/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py index 398be978e..fd78be50d 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/messages.py @@ -5,6 +5,7 @@ class and includes specific fields relevant to the type of message being sent. """ from abc import ABC, abstractmethod +from datetime import datetime, timezone from typing import Any, Dict, Generic, List, Literal, Mapping, Optional, Type, TypeVar from autogen_core import Component, ComponentBase, FunctionCall, Image @@ -85,6 +86,9 @@ class BaseChatMessage(BaseMessage, ABC): metadata: Dict[str, str] = {} """Additional metadata about the message.""" + created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + """The time when the message was created.""" + @abstractmethod def to_model_text(self) -> str: """Convert the content of the message to text-only representation. @@ -154,6 +158,9 @@ class BaseAgentEvent(BaseMessage, ABC): metadata: Dict[str, str] = {} """Additional metadata about the message.""" + created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc)) + """The time when the message was created.""" + StructuredContentType = TypeVar("StructuredContentType", bound=BaseModel, covariant=True) """Type variable for structured content types.""" diff --git a/python/packages/autogen-agentchat/tests/test_assistant_agent.py b/python/packages/autogen-agentchat/tests/test_assistant_agent.py index 988707d95..4565dd28d 100644 --- a/python/packages/autogen-agentchat/tests/test_assistant_agent.py +++ b/python/packages/autogen-agentchat/tests/test_assistant_agent.py @@ -41,7 +41,7 @@ from autogen_ext.tools.mcp import ( SseServerParams, ) from pydantic import BaseModel, ValidationError -from utils import FileLogHandler +from utils import FileLogHandler, compare_messages, compare_task_results logger = logging.getLogger(EVENT_LOGGER_NAME) logger.setLevel(logging.DEBUG) @@ -180,9 +180,9 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None: index = 0 async for message in agent.run_stream(task="task"): if isinstance(message, TaskResult): - assert message == result + assert compare_task_results(message, result) else: - assert message == result.messages[index] + assert compare_messages(message, result.messages[index]) index += 1 # Test state saving and loading. @@ -273,9 +273,9 @@ async def test_run_with_tools_and_reflection() -> None: index = 0 async for message in agent.run_stream(task="task"): if isinstance(message, TaskResult): - assert message == result + assert compare_task_results(message, result) else: - assert message == result.messages[index] + assert compare_messages(message, result.messages[index]) index += 1 # Test state saving and loading. @@ -363,9 +363,9 @@ async def test_run_with_parallel_tools() -> None: index = 0 async for message in agent.run_stream(task="task"): if isinstance(message, TaskResult): - assert message == result + assert compare_task_results(message, result) else: - assert message == result.messages[index] + assert compare_messages(message, result.messages[index]) index += 1 # Test state saving and loading. @@ -446,9 +446,9 @@ async def test_run_with_parallel_tools_with_empty_call_ids() -> None: index = 0 async for message in agent.run_stream(task="task"): if isinstance(message, TaskResult): - assert message == result + assert compare_task_results(message, result) else: - assert message == result.messages[index] + assert compare_messages(message, result.messages[index]) index += 1 # Test state saving and loading. @@ -560,9 +560,9 @@ async def test_run_with_workbench() -> None: index = 0 async for message in agent.run_stream(task="task"): if isinstance(message, TaskResult): - assert message == result + assert compare_task_results(message, result) else: - assert message == result.messages[index] + assert compare_messages(message, result.messages[index]) index += 1 # Test state saving and loading. @@ -779,9 +779,9 @@ async def test_handoffs() -> None: index = 0 async for message in tool_use_agent.run_stream(task="task"): if isinstance(message, TaskResult): - assert message == result + assert compare_task_results(message, result) else: - assert message == result.messages[index] + assert compare_messages(message, result.messages[index]) index += 1 @@ -852,9 +852,9 @@ async def test_handoff_with_tool_call_context() -> None: index = 0 async for message in tool_use_agent.run_stream(task="task"): if isinstance(message, TaskResult): - assert message == result + assert compare_task_results(message, result) else: - assert message == result.messages[index] + assert compare_messages(message, result.messages[index]) index += 1 @@ -927,9 +927,9 @@ async def test_custom_handoffs() -> None: index = 0 async for message in tool_use_agent.run_stream(task="task"): if isinstance(message, TaskResult): - assert message == result + assert compare_task_results(message, result) else: - assert message == result.messages[index] + assert compare_messages(message, result.messages[index]) index += 1 @@ -1004,9 +1004,9 @@ async def test_custom_object_handoffs() -> None: index = 0 async for message in tool_use_agent.run_stream(task="task"): if isinstance(message, TaskResult): - assert message == result + assert compare_task_results(message, result) else: - assert message == result.messages[index] + assert compare_messages(message, result.messages[index]) index += 1 @@ -1161,9 +1161,9 @@ async def test_list_chat_messages(monkeypatch: pytest.MonkeyPatch) -> None: index = 0 async for message in agent.run_stream(task=messages): if isinstance(message, TaskResult): - assert message == result + assert compare_task_results(message, result) else: - assert message == result.messages[index] + assert compare_messages(message, result.messages[index]) index += 1 diff --git a/python/packages/autogen-agentchat/tests/test_group_chat.py b/python/packages/autogen-agentchat/tests/test_group_chat.py index 889b67d87..9e15eb789 100644 --- a/python/packages/autogen-agentchat/tests/test_group_chat.py +++ b/python/packages/autogen-agentchat/tests/test_group_chat.py @@ -55,7 +55,7 @@ from autogen_ext.code_executors.local import LocalCommandLineCodeExecutor from autogen_ext.models.openai import OpenAIChatCompletionClient from autogen_ext.models.replay import ReplayChatCompletionClient from pydantic import BaseModel -from utils import FileLogHandler +from utils import FileLogHandler, compare_messages, compare_task_results logger = logging.getLogger(EVENT_LOGGER_NAME) logger.setLevel(logging.DEBUG) @@ -269,9 +269,9 @@ async def test_round_robin_group_chat(runtime: AgentRuntime | None) -> None: task="Write a program that prints 'Hello, world!'", ): if isinstance(message, TaskResult): - assert message == result + assert compare_task_results(message, result) else: - assert message == result.messages[index] + assert compare_messages(message, result.messages[index]) index += 1 # Test message input. @@ -282,7 +282,7 @@ async def test_round_robin_group_chat(runtime: AgentRuntime | None) -> None: result_2 = await team.run( task=TextMessage(content="Write a program that prints 'Hello, world!'", source="user") ) - assert result == result_2 + assert compare_task_results(result, result_2) # Test multi-modal message. model_client.reset() @@ -293,7 +293,9 @@ async def test_round_robin_group_chat(runtime: AgentRuntime | None) -> None: assert isinstance(result.messages[0], TextMessage) assert isinstance(result_2.messages[0], MultiModalMessage) assert result.messages[0].content == task.content[0] - assert result.messages[1:] == result_2.messages[1:] + assert len(result.messages[1:]) == len(result_2.messages[1:]) + for i in range(1, len(result.messages)): + assert compare_messages(result.messages[i], result_2.messages[i]) @pytest.mark.asyncio @@ -339,9 +341,9 @@ async def test_round_robin_group_chat_with_team_event(runtime: AgentRuntime | No task="Write a program that prints 'Hello, world!'", ): if isinstance(message, TaskResult): - assert message == result + assert compare_task_results(message, result) else: - assert message == result.messages[index] + assert compare_messages(message, result.messages[index]) index += 1 @@ -496,9 +498,9 @@ async def test_round_robin_group_chat_with_tools(runtime: AgentRuntime | None) - task="Write a program that prints 'Hello, world!'", ): if isinstance(message, TaskResult): - assert message == result + assert compare_task_results(message, result) else: - assert message == result.messages[index] + assert compare_messages(message, result.messages[index]) index += 1 # Test Console. @@ -507,7 +509,7 @@ async def test_round_robin_group_chat_with_tools(runtime: AgentRuntime | None) - index = 0 await team.reset() result2 = await Console(team.run_stream(task="Write a program that prints 'Hello, world!'")) - assert result2 == result + assert compare_task_results(result2, result) @pytest.mark.asyncio @@ -685,9 +687,9 @@ async def test_selector_group_chat(runtime: AgentRuntime | None) -> None: task="Write a program that prints 'Hello, world!'", ): if isinstance(message, TaskResult): - assert message == result + assert compare_task_results(message, result) else: - assert message == result.messages[index] + assert compare_messages(message, result.messages[index]) index += 1 # Test Console. @@ -696,7 +698,7 @@ async def test_selector_group_chat(runtime: AgentRuntime | None) -> None: index = 0 await team.reset() result2 = await Console(team.run_stream(task="Write a program that prints 'Hello, world!'")) - assert result2 == result + assert compare_task_results(result2, result) @pytest.mark.asyncio @@ -806,9 +808,9 @@ async def test_selector_group_chat_with_team_event(runtime: AgentRuntime | None) task="Write a program that prints 'Hello, world!'", ): if isinstance(message, TaskResult): - assert message == result + assert compare_task_results(message, result) else: - assert message == result.messages[index] + assert compare_messages(message, result.messages[index]) index += 1 @@ -910,9 +912,9 @@ async def test_selector_group_chat_two_speakers(runtime: AgentRuntime | None) -> await team.reset() async for message in team.run_stream(task="Write a program that prints 'Hello, world!'"): if isinstance(message, TaskResult): - assert message == result + assert compare_task_results(message, result) else: - assert message == result.messages[index] + assert compare_messages(message, result.messages[index]) index += 1 # Test Console. @@ -921,7 +923,7 @@ async def test_selector_group_chat_two_speakers(runtime: AgentRuntime | None) -> index = 0 await team.reset() result2 = await Console(team.run_stream(task="Write a program that prints 'Hello, world!'")) - assert result2 == result + assert compare_task_results(result2, result) @pytest.mark.asyncio @@ -958,9 +960,9 @@ async def test_selector_group_chat_two_speakers_allow_repeated(runtime: AgentRun await team.reset() async for message in team.run_stream(task="Write a program that prints 'Hello, world!'"): if isinstance(message, TaskResult): - assert message == result + assert compare_task_results(message, result) else: - assert message == result.messages[index] + assert compare_messages(message, result.messages[index]) index += 1 # Test Console. @@ -968,7 +970,7 @@ async def test_selector_group_chat_two_speakers_allow_repeated(runtime: AgentRun index = 0 await team.reset() result2 = await Console(team.run_stream(task="Write a program that prints 'Hello, world!'")) - assert result2 == result + assert compare_task_results(result2, result) @pytest.mark.asyncio @@ -1174,9 +1176,9 @@ async def test_swarm_handoff(runtime: AgentRuntime | None) -> None: stream = team.run_stream(task="task") async for message in stream: if isinstance(message, TaskResult): - assert message == result + assert compare_task_results(message, result) else: - assert message == result.messages[index] + assert compare_messages(message, result.messages[index]) index += 1 # Test save and load. @@ -1248,9 +1250,9 @@ async def test_swarm_handoff_with_team_events(runtime: AgentRuntime | None) -> N stream = team.run_stream(task="task") async for message in stream: if isinstance(message, TaskResult): - assert message == result + assert compare_task_results(message, result) else: - assert message == result.messages[index] + assert compare_messages(message, result.messages[index]) index += 1 @@ -1366,9 +1368,9 @@ async def test_swarm_handoff_using_tool_calls(runtime: AgentRuntime | None) -> N stream = team.run_stream(task="task") async for message in stream: if isinstance(message, TaskResult): - assert message == result + assert compare_task_results(message, result) else: - assert message == result.messages[index] + assert compare_messages(message, result.messages[index]) index += 1 # Test Console @@ -1377,7 +1379,7 @@ async def test_swarm_handoff_using_tool_calls(runtime: AgentRuntime | None) -> N index = 0 await team.reset() result2 = await Console(team.run_stream(task="task")) - assert result2 == result + assert compare_task_results(result2, result) @pytest.mark.asyncio @@ -1471,14 +1473,17 @@ async def test_swarm_with_parallel_tool_calls(runtime: AgentRuntime | None) -> N team = Swarm([agent1, agent2], termination_condition=termination, runtime=runtime) result = await team.run(task="task") assert len(result.messages) == 6 - assert result.messages[0] == TextMessage(content="task", source="user") + assert compare_messages(result.messages[0], TextMessage(content="task", source="user")) assert isinstance(result.messages[1], ToolCallRequestEvent) assert isinstance(result.messages[2], ToolCallExecutionEvent) - assert result.messages[3] == HandoffMessage( - content="handoff to agent2", - target="agent2", - source="agent1", - context=expected_handoff_context, + assert compare_messages( + result.messages[3], + HandoffMessage( + content="handoff to agent2", + target="agent2", + source="agent1", + context=expected_handoff_context, + ), ) assert isinstance(result.messages[4], TextMessage) assert result.messages[4].content == "Hello" @@ -1598,9 +1603,9 @@ async def test_round_robin_group_chat_with_message_list(runtime: AgentRuntime | index = 0 async for message in team.run_stream(task=messages): if isinstance(message, TaskResult): - assert message == result + assert compare_task_results(message, result) else: - assert message == result.messages[index] + assert compare_messages(message, result.messages[index]) index += 1 # Test with invalid message list @@ -1803,7 +1808,7 @@ async def test_selector_group_chat_streaming(runtime: AgentRuntime | None) -> No streaming: List[str] = [] async for message in team.run_stream(task="Write a program that prints 'Hello, world!'"): if isinstance(message, TaskResult): - assert message == result + assert compare_task_results(message, result) elif isinstance(message, ModelClientStreamingChunkEvent): streaming.append(message.content) else: @@ -1811,5 +1816,5 @@ async def test_selector_group_chat_streaming(runtime: AgentRuntime | None) -> No assert isinstance(message, SelectorEvent) assert message.content == "".join([chunk for chunk in streaming]) streaming = [] - assert message == result.messages[index] + assert compare_messages(message, result.messages[index]) index += 1 diff --git a/python/packages/autogen-agentchat/tests/test_group_chat_graph.py b/python/packages/autogen-agentchat/tests/test_group_chat_graph.py index 2f381528c..216256bcf 100644 --- a/python/packages/autogen-agentchat/tests/test_group_chat_graph.py +++ b/python/packages/autogen-agentchat/tests/test_group_chat_graph.py @@ -32,6 +32,7 @@ from autogen_agentchat.teams._group_chat._graph._digraph_group_chat import ( from autogen_core import AgentRuntime, CancellationToken, Component, SingleThreadedAgentRuntime from autogen_ext.models.replay import ReplayChatCompletionClient from pydantic import BaseModel +from utils import compare_message_lists, compare_task_results def test_create_digraph() -> None: @@ -1207,10 +1208,10 @@ async def test_graph_flow_serialize_deserialize() -> None: de_results = await deserialized_team.run(task="Start") assert serialized == serialized_deserialized - assert results == de_results + assert compare_task_results(results, de_results) assert results.stop_reason is not None assert results.stop_reason == de_results.stop_reason - assert results.messages == de_results.messages + assert compare_message_lists(results.messages, de_results.messages) assert isinstance(results.messages[0], TextMessage) assert results.messages[0].source == "user" assert results.messages[0].content == "Start" diff --git a/python/packages/autogen-agentchat/tests/utils.py b/python/packages/autogen-agentchat/tests/utils.py index 90d66f0ba..85d0d41f1 100644 --- a/python/packages/autogen-agentchat/tests/utils.py +++ b/python/packages/autogen-agentchat/tests/utils.py @@ -2,7 +2,10 @@ import json import logging import sys from datetime import datetime +from typing import Sequence +from autogen_agentchat.base._task import TaskResult +from autogen_agentchat.messages import BaseAgentEvent, BaseChatMessage, BaseTextChatMessage from pydantic import BaseModel @@ -18,7 +21,7 @@ class FileLogHandler(logging.Handler): record.msg = json.dumps( { "timestamp": ts, - "message": record.msg.model_dump(), + "message": record.msg.model_dump_json(indent=2), "type": record.msg.__class__.__name__, }, ) @@ -37,3 +40,36 @@ class ConsoleLogHandler(logging.Handler): }, ) sys.stdout.write(f"{record.msg}\n") + + +def compare_messages( + msg1: BaseAgentEvent | BaseChatMessage | BaseTextChatMessage, + msg2: BaseAgentEvent | BaseChatMessage | BaseTextChatMessage, +) -> bool: + if isinstance(msg1, BaseTextChatMessage) and isinstance(msg2, BaseTextChatMessage): + if msg1.content != msg2.content: + return False + return ( + (msg1.source == msg2.source) and (msg1.models_usage == msg2.models_usage) and (msg1.metadata == msg2.metadata) + ) + + +def compare_message_lists( + msgs1: Sequence[BaseAgentEvent | BaseChatMessage], + msgs2: Sequence[BaseAgentEvent | BaseChatMessage], +) -> bool: + if len(msgs1) != len(msgs2): + return False + for i in range(len(msgs1)): + if not compare_messages(msgs1[i], msgs2[i]): + return False + return True + + +def compare_task_results( + res1: TaskResult, + res2: TaskResult, +) -> bool: + if res1.stop_reason != res2.stop_reason: + return False + return compare_message_lists(res1.messages, res2.messages) diff --git a/python/packages/autogen-ext/tests/test_filesurfer_agent.py b/python/packages/autogen-ext/tests/test_filesurfer_agent.py index ef7d58e71..de2bbfec8 100644 --- a/python/packages/autogen-ext/tests/test_filesurfer_agent.py +++ b/python/packages/autogen-ext/tests/test_filesurfer_agent.py @@ -32,7 +32,7 @@ class FileLogHandler(logging.Handler): record.msg = json.dumps( { "timestamp": ts, - "message": record.msg.model_dump(), + "message": record.msg.model_dump_json(indent=2), "type": record.msg.__class__.__name__, }, ) diff --git a/python/packages/autogen-ext/tests/test_websurfer_agent.py b/python/packages/autogen-ext/tests/test_websurfer_agent.py index e9fb222d7..371a8833b 100644 --- a/python/packages/autogen-ext/tests/test_websurfer_agent.py +++ b/python/packages/autogen-ext/tests/test_websurfer_agent.py @@ -33,7 +33,7 @@ class FileLogHandler(logging.Handler): record.msg = json.dumps( { "timestamp": ts, - "message": record.msg.model_dump(), + "message": record.msg.model_dump_json(indent=2), "type": record.msg.__class__.__name__, }, )