Add created_at to BaseChatMessage and BaseAgentEvent (#6557)

## Why are these changes needed?

I added `created_at` to both BaseChatMessage and BaseAgentEvent classes
that store the time these Pydantic model instances are generated. And
then users will be able to use `created_at` to build up a customized
external persisting state management layer for their case.

## Related issue number


https://github.com/microsoft/autogen/discussions/6169#discussioncomment-13151540

## Checks

- [x] I've included any doc changes needed for
<https://microsoft.github.io/autogen/>. See
<https://github.com/microsoft/autogen/blob/main/CONTRIBUTING.md> to
build and test documentation locally.
- [x] I've added tests (if relevant) corresponding to the changes
introduced in this PR.
- [x] I've made sure all auto checks have passed.

---------

Co-authored-by: Jack Gerrits <jackgerrits@users.noreply.github.com>
Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
This commit is contained in:
Sungjun.Kim 2025-05-23 14:29:24 +09:00 committed by GitHub
parent 726e0be110
commit db125fbd2d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 113 additions and 64 deletions

View File

@ -5,6 +5,7 @@ class and includes specific fields relevant to the type of message being sent.
"""
from abc import ABC, abstractmethod
from datetime import datetime, timezone
from typing import Any, Dict, Generic, List, Literal, Mapping, Optional, Type, TypeVar
from autogen_core import Component, ComponentBase, FunctionCall, Image
@ -85,6 +86,9 @@ class BaseChatMessage(BaseMessage, ABC):
metadata: Dict[str, str] = {}
"""Additional metadata about the message."""
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
"""The time when the message was created."""
@abstractmethod
def to_model_text(self) -> str:
"""Convert the content of the message to text-only representation.
@ -154,6 +158,9 @@ class BaseAgentEvent(BaseMessage, ABC):
metadata: Dict[str, str] = {}
"""Additional metadata about the message."""
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
"""The time when the message was created."""
StructuredContentType = TypeVar("StructuredContentType", bound=BaseModel, covariant=True)
"""Type variable for structured content types."""

View File

@ -41,7 +41,7 @@ from autogen_ext.tools.mcp import (
SseServerParams,
)
from pydantic import BaseModel, ValidationError
from utils import FileLogHandler
from utils import FileLogHandler, compare_messages, compare_task_results
logger = logging.getLogger(EVENT_LOGGER_NAME)
logger.setLevel(logging.DEBUG)
@ -180,9 +180,9 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
index = 0
async for message in agent.run_stream(task="task"):
if isinstance(message, TaskResult):
assert message == result
assert compare_task_results(message, result)
else:
assert message == result.messages[index]
assert compare_messages(message, result.messages[index])
index += 1
# Test state saving and loading.
@ -273,9 +273,9 @@ async def test_run_with_tools_and_reflection() -> None:
index = 0
async for message in agent.run_stream(task="task"):
if isinstance(message, TaskResult):
assert message == result
assert compare_task_results(message, result)
else:
assert message == result.messages[index]
assert compare_messages(message, result.messages[index])
index += 1
# Test state saving and loading.
@ -363,9 +363,9 @@ async def test_run_with_parallel_tools() -> None:
index = 0
async for message in agent.run_stream(task="task"):
if isinstance(message, TaskResult):
assert message == result
assert compare_task_results(message, result)
else:
assert message == result.messages[index]
assert compare_messages(message, result.messages[index])
index += 1
# Test state saving and loading.
@ -446,9 +446,9 @@ async def test_run_with_parallel_tools_with_empty_call_ids() -> None:
index = 0
async for message in agent.run_stream(task="task"):
if isinstance(message, TaskResult):
assert message == result
assert compare_task_results(message, result)
else:
assert message == result.messages[index]
assert compare_messages(message, result.messages[index])
index += 1
# Test state saving and loading.
@ -560,9 +560,9 @@ async def test_run_with_workbench() -> None:
index = 0
async for message in agent.run_stream(task="task"):
if isinstance(message, TaskResult):
assert message == result
assert compare_task_results(message, result)
else:
assert message == result.messages[index]
assert compare_messages(message, result.messages[index])
index += 1
# Test state saving and loading.
@ -779,9 +779,9 @@ async def test_handoffs() -> None:
index = 0
async for message in tool_use_agent.run_stream(task="task"):
if isinstance(message, TaskResult):
assert message == result
assert compare_task_results(message, result)
else:
assert message == result.messages[index]
assert compare_messages(message, result.messages[index])
index += 1
@ -852,9 +852,9 @@ async def test_handoff_with_tool_call_context() -> None:
index = 0
async for message in tool_use_agent.run_stream(task="task"):
if isinstance(message, TaskResult):
assert message == result
assert compare_task_results(message, result)
else:
assert message == result.messages[index]
assert compare_messages(message, result.messages[index])
index += 1
@ -927,9 +927,9 @@ async def test_custom_handoffs() -> None:
index = 0
async for message in tool_use_agent.run_stream(task="task"):
if isinstance(message, TaskResult):
assert message == result
assert compare_task_results(message, result)
else:
assert message == result.messages[index]
assert compare_messages(message, result.messages[index])
index += 1
@ -1004,9 +1004,9 @@ async def test_custom_object_handoffs() -> None:
index = 0
async for message in tool_use_agent.run_stream(task="task"):
if isinstance(message, TaskResult):
assert message == result
assert compare_task_results(message, result)
else:
assert message == result.messages[index]
assert compare_messages(message, result.messages[index])
index += 1
@ -1161,9 +1161,9 @@ async def test_list_chat_messages(monkeypatch: pytest.MonkeyPatch) -> None:
index = 0
async for message in agent.run_stream(task=messages):
if isinstance(message, TaskResult):
assert message == result
assert compare_task_results(message, result)
else:
assert message == result.messages[index]
assert compare_messages(message, result.messages[index])
index += 1

View File

@ -55,7 +55,7 @@ from autogen_ext.code_executors.local import LocalCommandLineCodeExecutor
from autogen_ext.models.openai import OpenAIChatCompletionClient
from autogen_ext.models.replay import ReplayChatCompletionClient
from pydantic import BaseModel
from utils import FileLogHandler
from utils import FileLogHandler, compare_messages, compare_task_results
logger = logging.getLogger(EVENT_LOGGER_NAME)
logger.setLevel(logging.DEBUG)
@ -269,9 +269,9 @@ async def test_round_robin_group_chat(runtime: AgentRuntime | None) -> None:
task="Write a program that prints 'Hello, world!'",
):
if isinstance(message, TaskResult):
assert message == result
assert compare_task_results(message, result)
else:
assert message == result.messages[index]
assert compare_messages(message, result.messages[index])
index += 1
# Test message input.
@ -282,7 +282,7 @@ async def test_round_robin_group_chat(runtime: AgentRuntime | None) -> None:
result_2 = await team.run(
task=TextMessage(content="Write a program that prints 'Hello, world!'", source="user")
)
assert result == result_2
assert compare_task_results(result, result_2)
# Test multi-modal message.
model_client.reset()
@ -293,7 +293,9 @@ async def test_round_robin_group_chat(runtime: AgentRuntime | None) -> None:
assert isinstance(result.messages[0], TextMessage)
assert isinstance(result_2.messages[0], MultiModalMessage)
assert result.messages[0].content == task.content[0]
assert result.messages[1:] == result_2.messages[1:]
assert len(result.messages[1:]) == len(result_2.messages[1:])
for i in range(1, len(result.messages)):
assert compare_messages(result.messages[i], result_2.messages[i])
@pytest.mark.asyncio
@ -339,9 +341,9 @@ async def test_round_robin_group_chat_with_team_event(runtime: AgentRuntime | No
task="Write a program that prints 'Hello, world!'",
):
if isinstance(message, TaskResult):
assert message == result
assert compare_task_results(message, result)
else:
assert message == result.messages[index]
assert compare_messages(message, result.messages[index])
index += 1
@ -496,9 +498,9 @@ async def test_round_robin_group_chat_with_tools(runtime: AgentRuntime | None) -
task="Write a program that prints 'Hello, world!'",
):
if isinstance(message, TaskResult):
assert message == result
assert compare_task_results(message, result)
else:
assert message == result.messages[index]
assert compare_messages(message, result.messages[index])
index += 1
# Test Console.
@ -507,7 +509,7 @@ async def test_round_robin_group_chat_with_tools(runtime: AgentRuntime | None) -
index = 0
await team.reset()
result2 = await Console(team.run_stream(task="Write a program that prints 'Hello, world!'"))
assert result2 == result
assert compare_task_results(result2, result)
@pytest.mark.asyncio
@ -685,9 +687,9 @@ async def test_selector_group_chat(runtime: AgentRuntime | None) -> None:
task="Write a program that prints 'Hello, world!'",
):
if isinstance(message, TaskResult):
assert message == result
assert compare_task_results(message, result)
else:
assert message == result.messages[index]
assert compare_messages(message, result.messages[index])
index += 1
# Test Console.
@ -696,7 +698,7 @@ async def test_selector_group_chat(runtime: AgentRuntime | None) -> None:
index = 0
await team.reset()
result2 = await Console(team.run_stream(task="Write a program that prints 'Hello, world!'"))
assert result2 == result
assert compare_task_results(result2, result)
@pytest.mark.asyncio
@ -806,9 +808,9 @@ async def test_selector_group_chat_with_team_event(runtime: AgentRuntime | None)
task="Write a program that prints 'Hello, world!'",
):
if isinstance(message, TaskResult):
assert message == result
assert compare_task_results(message, result)
else:
assert message == result.messages[index]
assert compare_messages(message, result.messages[index])
index += 1
@ -910,9 +912,9 @@ async def test_selector_group_chat_two_speakers(runtime: AgentRuntime | None) ->
await team.reset()
async for message in team.run_stream(task="Write a program that prints 'Hello, world!'"):
if isinstance(message, TaskResult):
assert message == result
assert compare_task_results(message, result)
else:
assert message == result.messages[index]
assert compare_messages(message, result.messages[index])
index += 1
# Test Console.
@ -921,7 +923,7 @@ async def test_selector_group_chat_two_speakers(runtime: AgentRuntime | None) ->
index = 0
await team.reset()
result2 = await Console(team.run_stream(task="Write a program that prints 'Hello, world!'"))
assert result2 == result
assert compare_task_results(result2, result)
@pytest.mark.asyncio
@ -958,9 +960,9 @@ async def test_selector_group_chat_two_speakers_allow_repeated(runtime: AgentRun
await team.reset()
async for message in team.run_stream(task="Write a program that prints 'Hello, world!'"):
if isinstance(message, TaskResult):
assert message == result
assert compare_task_results(message, result)
else:
assert message == result.messages[index]
assert compare_messages(message, result.messages[index])
index += 1
# Test Console.
@ -968,7 +970,7 @@ async def test_selector_group_chat_two_speakers_allow_repeated(runtime: AgentRun
index = 0
await team.reset()
result2 = await Console(team.run_stream(task="Write a program that prints 'Hello, world!'"))
assert result2 == result
assert compare_task_results(result2, result)
@pytest.mark.asyncio
@ -1174,9 +1176,9 @@ async def test_swarm_handoff(runtime: AgentRuntime | None) -> None:
stream = team.run_stream(task="task")
async for message in stream:
if isinstance(message, TaskResult):
assert message == result
assert compare_task_results(message, result)
else:
assert message == result.messages[index]
assert compare_messages(message, result.messages[index])
index += 1
# Test save and load.
@ -1248,9 +1250,9 @@ async def test_swarm_handoff_with_team_events(runtime: AgentRuntime | None) -> N
stream = team.run_stream(task="task")
async for message in stream:
if isinstance(message, TaskResult):
assert message == result
assert compare_task_results(message, result)
else:
assert message == result.messages[index]
assert compare_messages(message, result.messages[index])
index += 1
@ -1366,9 +1368,9 @@ async def test_swarm_handoff_using_tool_calls(runtime: AgentRuntime | None) -> N
stream = team.run_stream(task="task")
async for message in stream:
if isinstance(message, TaskResult):
assert message == result
assert compare_task_results(message, result)
else:
assert message == result.messages[index]
assert compare_messages(message, result.messages[index])
index += 1
# Test Console
@ -1377,7 +1379,7 @@ async def test_swarm_handoff_using_tool_calls(runtime: AgentRuntime | None) -> N
index = 0
await team.reset()
result2 = await Console(team.run_stream(task="task"))
assert result2 == result
assert compare_task_results(result2, result)
@pytest.mark.asyncio
@ -1471,14 +1473,17 @@ async def test_swarm_with_parallel_tool_calls(runtime: AgentRuntime | None) -> N
team = Swarm([agent1, agent2], termination_condition=termination, runtime=runtime)
result = await team.run(task="task")
assert len(result.messages) == 6
assert result.messages[0] == TextMessage(content="task", source="user")
assert compare_messages(result.messages[0], TextMessage(content="task", source="user"))
assert isinstance(result.messages[1], ToolCallRequestEvent)
assert isinstance(result.messages[2], ToolCallExecutionEvent)
assert result.messages[3] == HandoffMessage(
content="handoff to agent2",
target="agent2",
source="agent1",
context=expected_handoff_context,
assert compare_messages(
result.messages[3],
HandoffMessage(
content="handoff to agent2",
target="agent2",
source="agent1",
context=expected_handoff_context,
),
)
assert isinstance(result.messages[4], TextMessage)
assert result.messages[4].content == "Hello"
@ -1598,9 +1603,9 @@ async def test_round_robin_group_chat_with_message_list(runtime: AgentRuntime |
index = 0
async for message in team.run_stream(task=messages):
if isinstance(message, TaskResult):
assert message == result
assert compare_task_results(message, result)
else:
assert message == result.messages[index]
assert compare_messages(message, result.messages[index])
index += 1
# Test with invalid message list
@ -1803,7 +1808,7 @@ async def test_selector_group_chat_streaming(runtime: AgentRuntime | None) -> No
streaming: List[str] = []
async for message in team.run_stream(task="Write a program that prints 'Hello, world!'"):
if isinstance(message, TaskResult):
assert message == result
assert compare_task_results(message, result)
elif isinstance(message, ModelClientStreamingChunkEvent):
streaming.append(message.content)
else:
@ -1811,5 +1816,5 @@ async def test_selector_group_chat_streaming(runtime: AgentRuntime | None) -> No
assert isinstance(message, SelectorEvent)
assert message.content == "".join([chunk for chunk in streaming])
streaming = []
assert message == result.messages[index]
assert compare_messages(message, result.messages[index])
index += 1

View File

@ -32,6 +32,7 @@ from autogen_agentchat.teams._group_chat._graph._digraph_group_chat import (
from autogen_core import AgentRuntime, CancellationToken, Component, SingleThreadedAgentRuntime
from autogen_ext.models.replay import ReplayChatCompletionClient
from pydantic import BaseModel
from utils import compare_message_lists, compare_task_results
def test_create_digraph() -> None:
@ -1207,10 +1208,10 @@ async def test_graph_flow_serialize_deserialize() -> None:
de_results = await deserialized_team.run(task="Start")
assert serialized == serialized_deserialized
assert results == de_results
assert compare_task_results(results, de_results)
assert results.stop_reason is not None
assert results.stop_reason == de_results.stop_reason
assert results.messages == de_results.messages
assert compare_message_lists(results.messages, de_results.messages)
assert isinstance(results.messages[0], TextMessage)
assert results.messages[0].source == "user"
assert results.messages[0].content == "Start"

View File

@ -2,7 +2,10 @@ import json
import logging
import sys
from datetime import datetime
from typing import Sequence
from autogen_agentchat.base._task import TaskResult
from autogen_agentchat.messages import BaseAgentEvent, BaseChatMessage, BaseTextChatMessage
from pydantic import BaseModel
@ -18,7 +21,7 @@ class FileLogHandler(logging.Handler):
record.msg = json.dumps(
{
"timestamp": ts,
"message": record.msg.model_dump(),
"message": record.msg.model_dump_json(indent=2),
"type": record.msg.__class__.__name__,
},
)
@ -37,3 +40,36 @@ class ConsoleLogHandler(logging.Handler):
},
)
sys.stdout.write(f"{record.msg}\n")
def compare_messages(
msg1: BaseAgentEvent | BaseChatMessage | BaseTextChatMessage,
msg2: BaseAgentEvent | BaseChatMessage | BaseTextChatMessage,
) -> bool:
if isinstance(msg1, BaseTextChatMessage) and isinstance(msg2, BaseTextChatMessage):
if msg1.content != msg2.content:
return False
return (
(msg1.source == msg2.source) and (msg1.models_usage == msg2.models_usage) and (msg1.metadata == msg2.metadata)
)
def compare_message_lists(
msgs1: Sequence[BaseAgentEvent | BaseChatMessage],
msgs2: Sequence[BaseAgentEvent | BaseChatMessage],
) -> bool:
if len(msgs1) != len(msgs2):
return False
for i in range(len(msgs1)):
if not compare_messages(msgs1[i], msgs2[i]):
return False
return True
def compare_task_results(
res1: TaskResult,
res2: TaskResult,
) -> bool:
if res1.stop_reason != res2.stop_reason:
return False
return compare_message_lists(res1.messages, res2.messages)

View File

@ -32,7 +32,7 @@ class FileLogHandler(logging.Handler):
record.msg = json.dumps(
{
"timestamp": ts,
"message": record.msg.model_dump(),
"message": record.msg.model_dump_json(indent=2),
"type": record.msg.__class__.__name__,
},
)

View File

@ -33,7 +33,7 @@ class FileLogHandler(logging.Handler):
record.msg = json.dumps(
{
"timestamp": ts,
"message": record.msg.model_dump(),
"message": record.msg.model_dump_json(indent=2),
"type": record.msg.__class__.__name__,
},
)