mirror of
https://github.com/microsoft/autogen.git
synced 2025-09-01 20:37:33 +00:00
Add created_at to BaseChatMessage and BaseAgentEvent (#6557)
## Why are these changes needed? I added `created_at` to both BaseChatMessage and BaseAgentEvent classes that store the time these Pydantic model instances are generated. And then users will be able to use `created_at` to build up a customized external persisting state management layer for their case. ## Related issue number https://github.com/microsoft/autogen/discussions/6169#discussioncomment-13151540 ## Checks - [x] I've included any doc changes needed for <https://microsoft.github.io/autogen/>. See <https://github.com/microsoft/autogen/blob/main/CONTRIBUTING.md> to build and test documentation locally. - [x] I've added tests (if relevant) corresponding to the changes introduced in this PR. - [x] I've made sure all auto checks have passed. --------- Co-authored-by: Jack Gerrits <jackgerrits@users.noreply.github.com> Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
This commit is contained in:
parent
726e0be110
commit
db125fbd2d
@ -5,6 +5,7 @@ class and includes specific fields relevant to the type of message being sent.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, Generic, List, Literal, Mapping, Optional, Type, TypeVar
|
||||
|
||||
from autogen_core import Component, ComponentBase, FunctionCall, Image
|
||||
@ -85,6 +86,9 @@ class BaseChatMessage(BaseMessage, ABC):
|
||||
metadata: Dict[str, str] = {}
|
||||
"""Additional metadata about the message."""
|
||||
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
"""The time when the message was created."""
|
||||
|
||||
@abstractmethod
|
||||
def to_model_text(self) -> str:
|
||||
"""Convert the content of the message to text-only representation.
|
||||
@ -154,6 +158,9 @@ class BaseAgentEvent(BaseMessage, ABC):
|
||||
metadata: Dict[str, str] = {}
|
||||
"""Additional metadata about the message."""
|
||||
|
||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
"""The time when the message was created."""
|
||||
|
||||
|
||||
StructuredContentType = TypeVar("StructuredContentType", bound=BaseModel, covariant=True)
|
||||
"""Type variable for structured content types."""
|
||||
|
@ -41,7 +41,7 @@ from autogen_ext.tools.mcp import (
|
||||
SseServerParams,
|
||||
)
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from utils import FileLogHandler
|
||||
from utils import FileLogHandler, compare_messages, compare_task_results
|
||||
|
||||
logger = logging.getLogger(EVENT_LOGGER_NAME)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
@ -180,9 +180,9 @@ async def test_run_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
index = 0
|
||||
async for message in agent.run_stream(task="task"):
|
||||
if isinstance(message, TaskResult):
|
||||
assert message == result
|
||||
assert compare_task_results(message, result)
|
||||
else:
|
||||
assert message == result.messages[index]
|
||||
assert compare_messages(message, result.messages[index])
|
||||
index += 1
|
||||
|
||||
# Test state saving and loading.
|
||||
@ -273,9 +273,9 @@ async def test_run_with_tools_and_reflection() -> None:
|
||||
index = 0
|
||||
async for message in agent.run_stream(task="task"):
|
||||
if isinstance(message, TaskResult):
|
||||
assert message == result
|
||||
assert compare_task_results(message, result)
|
||||
else:
|
||||
assert message == result.messages[index]
|
||||
assert compare_messages(message, result.messages[index])
|
||||
index += 1
|
||||
|
||||
# Test state saving and loading.
|
||||
@ -363,9 +363,9 @@ async def test_run_with_parallel_tools() -> None:
|
||||
index = 0
|
||||
async for message in agent.run_stream(task="task"):
|
||||
if isinstance(message, TaskResult):
|
||||
assert message == result
|
||||
assert compare_task_results(message, result)
|
||||
else:
|
||||
assert message == result.messages[index]
|
||||
assert compare_messages(message, result.messages[index])
|
||||
index += 1
|
||||
|
||||
# Test state saving and loading.
|
||||
@ -446,9 +446,9 @@ async def test_run_with_parallel_tools_with_empty_call_ids() -> None:
|
||||
index = 0
|
||||
async for message in agent.run_stream(task="task"):
|
||||
if isinstance(message, TaskResult):
|
||||
assert message == result
|
||||
assert compare_task_results(message, result)
|
||||
else:
|
||||
assert message == result.messages[index]
|
||||
assert compare_messages(message, result.messages[index])
|
||||
index += 1
|
||||
|
||||
# Test state saving and loading.
|
||||
@ -560,9 +560,9 @@ async def test_run_with_workbench() -> None:
|
||||
index = 0
|
||||
async for message in agent.run_stream(task="task"):
|
||||
if isinstance(message, TaskResult):
|
||||
assert message == result
|
||||
assert compare_task_results(message, result)
|
||||
else:
|
||||
assert message == result.messages[index]
|
||||
assert compare_messages(message, result.messages[index])
|
||||
index += 1
|
||||
|
||||
# Test state saving and loading.
|
||||
@ -779,9 +779,9 @@ async def test_handoffs() -> None:
|
||||
index = 0
|
||||
async for message in tool_use_agent.run_stream(task="task"):
|
||||
if isinstance(message, TaskResult):
|
||||
assert message == result
|
||||
assert compare_task_results(message, result)
|
||||
else:
|
||||
assert message == result.messages[index]
|
||||
assert compare_messages(message, result.messages[index])
|
||||
index += 1
|
||||
|
||||
|
||||
@ -852,9 +852,9 @@ async def test_handoff_with_tool_call_context() -> None:
|
||||
index = 0
|
||||
async for message in tool_use_agent.run_stream(task="task"):
|
||||
if isinstance(message, TaskResult):
|
||||
assert message == result
|
||||
assert compare_task_results(message, result)
|
||||
else:
|
||||
assert message == result.messages[index]
|
||||
assert compare_messages(message, result.messages[index])
|
||||
index += 1
|
||||
|
||||
|
||||
@ -927,9 +927,9 @@ async def test_custom_handoffs() -> None:
|
||||
index = 0
|
||||
async for message in tool_use_agent.run_stream(task="task"):
|
||||
if isinstance(message, TaskResult):
|
||||
assert message == result
|
||||
assert compare_task_results(message, result)
|
||||
else:
|
||||
assert message == result.messages[index]
|
||||
assert compare_messages(message, result.messages[index])
|
||||
index += 1
|
||||
|
||||
|
||||
@ -1004,9 +1004,9 @@ async def test_custom_object_handoffs() -> None:
|
||||
index = 0
|
||||
async for message in tool_use_agent.run_stream(task="task"):
|
||||
if isinstance(message, TaskResult):
|
||||
assert message == result
|
||||
assert compare_task_results(message, result)
|
||||
else:
|
||||
assert message == result.messages[index]
|
||||
assert compare_messages(message, result.messages[index])
|
||||
index += 1
|
||||
|
||||
|
||||
@ -1161,9 +1161,9 @@ async def test_list_chat_messages(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
index = 0
|
||||
async for message in agent.run_stream(task=messages):
|
||||
if isinstance(message, TaskResult):
|
||||
assert message == result
|
||||
assert compare_task_results(message, result)
|
||||
else:
|
||||
assert message == result.messages[index]
|
||||
assert compare_messages(message, result.messages[index])
|
||||
index += 1
|
||||
|
||||
|
||||
|
@ -55,7 +55,7 @@ from autogen_ext.code_executors.local import LocalCommandLineCodeExecutor
|
||||
from autogen_ext.models.openai import OpenAIChatCompletionClient
|
||||
from autogen_ext.models.replay import ReplayChatCompletionClient
|
||||
from pydantic import BaseModel
|
||||
from utils import FileLogHandler
|
||||
from utils import FileLogHandler, compare_messages, compare_task_results
|
||||
|
||||
logger = logging.getLogger(EVENT_LOGGER_NAME)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
@ -269,9 +269,9 @@ async def test_round_robin_group_chat(runtime: AgentRuntime | None) -> None:
|
||||
task="Write a program that prints 'Hello, world!'",
|
||||
):
|
||||
if isinstance(message, TaskResult):
|
||||
assert message == result
|
||||
assert compare_task_results(message, result)
|
||||
else:
|
||||
assert message == result.messages[index]
|
||||
assert compare_messages(message, result.messages[index])
|
||||
index += 1
|
||||
|
||||
# Test message input.
|
||||
@ -282,7 +282,7 @@ async def test_round_robin_group_chat(runtime: AgentRuntime | None) -> None:
|
||||
result_2 = await team.run(
|
||||
task=TextMessage(content="Write a program that prints 'Hello, world!'", source="user")
|
||||
)
|
||||
assert result == result_2
|
||||
assert compare_task_results(result, result_2)
|
||||
|
||||
# Test multi-modal message.
|
||||
model_client.reset()
|
||||
@ -293,7 +293,9 @@ async def test_round_robin_group_chat(runtime: AgentRuntime | None) -> None:
|
||||
assert isinstance(result.messages[0], TextMessage)
|
||||
assert isinstance(result_2.messages[0], MultiModalMessage)
|
||||
assert result.messages[0].content == task.content[0]
|
||||
assert result.messages[1:] == result_2.messages[1:]
|
||||
assert len(result.messages[1:]) == len(result_2.messages[1:])
|
||||
for i in range(1, len(result.messages)):
|
||||
assert compare_messages(result.messages[i], result_2.messages[i])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -339,9 +341,9 @@ async def test_round_robin_group_chat_with_team_event(runtime: AgentRuntime | No
|
||||
task="Write a program that prints 'Hello, world!'",
|
||||
):
|
||||
if isinstance(message, TaskResult):
|
||||
assert message == result
|
||||
assert compare_task_results(message, result)
|
||||
else:
|
||||
assert message == result.messages[index]
|
||||
assert compare_messages(message, result.messages[index])
|
||||
index += 1
|
||||
|
||||
|
||||
@ -496,9 +498,9 @@ async def test_round_robin_group_chat_with_tools(runtime: AgentRuntime | None) -
|
||||
task="Write a program that prints 'Hello, world!'",
|
||||
):
|
||||
if isinstance(message, TaskResult):
|
||||
assert message == result
|
||||
assert compare_task_results(message, result)
|
||||
else:
|
||||
assert message == result.messages[index]
|
||||
assert compare_messages(message, result.messages[index])
|
||||
index += 1
|
||||
|
||||
# Test Console.
|
||||
@ -507,7 +509,7 @@ async def test_round_robin_group_chat_with_tools(runtime: AgentRuntime | None) -
|
||||
index = 0
|
||||
await team.reset()
|
||||
result2 = await Console(team.run_stream(task="Write a program that prints 'Hello, world!'"))
|
||||
assert result2 == result
|
||||
assert compare_task_results(result2, result)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -685,9 +687,9 @@ async def test_selector_group_chat(runtime: AgentRuntime | None) -> None:
|
||||
task="Write a program that prints 'Hello, world!'",
|
||||
):
|
||||
if isinstance(message, TaskResult):
|
||||
assert message == result
|
||||
assert compare_task_results(message, result)
|
||||
else:
|
||||
assert message == result.messages[index]
|
||||
assert compare_messages(message, result.messages[index])
|
||||
index += 1
|
||||
|
||||
# Test Console.
|
||||
@ -696,7 +698,7 @@ async def test_selector_group_chat(runtime: AgentRuntime | None) -> None:
|
||||
index = 0
|
||||
await team.reset()
|
||||
result2 = await Console(team.run_stream(task="Write a program that prints 'Hello, world!'"))
|
||||
assert result2 == result
|
||||
assert compare_task_results(result2, result)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -806,9 +808,9 @@ async def test_selector_group_chat_with_team_event(runtime: AgentRuntime | None)
|
||||
task="Write a program that prints 'Hello, world!'",
|
||||
):
|
||||
if isinstance(message, TaskResult):
|
||||
assert message == result
|
||||
assert compare_task_results(message, result)
|
||||
else:
|
||||
assert message == result.messages[index]
|
||||
assert compare_messages(message, result.messages[index])
|
||||
index += 1
|
||||
|
||||
|
||||
@ -910,9 +912,9 @@ async def test_selector_group_chat_two_speakers(runtime: AgentRuntime | None) ->
|
||||
await team.reset()
|
||||
async for message in team.run_stream(task="Write a program that prints 'Hello, world!'"):
|
||||
if isinstance(message, TaskResult):
|
||||
assert message == result
|
||||
assert compare_task_results(message, result)
|
||||
else:
|
||||
assert message == result.messages[index]
|
||||
assert compare_messages(message, result.messages[index])
|
||||
index += 1
|
||||
|
||||
# Test Console.
|
||||
@ -921,7 +923,7 @@ async def test_selector_group_chat_two_speakers(runtime: AgentRuntime | None) ->
|
||||
index = 0
|
||||
await team.reset()
|
||||
result2 = await Console(team.run_stream(task="Write a program that prints 'Hello, world!'"))
|
||||
assert result2 == result
|
||||
assert compare_task_results(result2, result)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -958,9 +960,9 @@ async def test_selector_group_chat_two_speakers_allow_repeated(runtime: AgentRun
|
||||
await team.reset()
|
||||
async for message in team.run_stream(task="Write a program that prints 'Hello, world!'"):
|
||||
if isinstance(message, TaskResult):
|
||||
assert message == result
|
||||
assert compare_task_results(message, result)
|
||||
else:
|
||||
assert message == result.messages[index]
|
||||
assert compare_messages(message, result.messages[index])
|
||||
index += 1
|
||||
|
||||
# Test Console.
|
||||
@ -968,7 +970,7 @@ async def test_selector_group_chat_two_speakers_allow_repeated(runtime: AgentRun
|
||||
index = 0
|
||||
await team.reset()
|
||||
result2 = await Console(team.run_stream(task="Write a program that prints 'Hello, world!'"))
|
||||
assert result2 == result
|
||||
assert compare_task_results(result2, result)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -1174,9 +1176,9 @@ async def test_swarm_handoff(runtime: AgentRuntime | None) -> None:
|
||||
stream = team.run_stream(task="task")
|
||||
async for message in stream:
|
||||
if isinstance(message, TaskResult):
|
||||
assert message == result
|
||||
assert compare_task_results(message, result)
|
||||
else:
|
||||
assert message == result.messages[index]
|
||||
assert compare_messages(message, result.messages[index])
|
||||
index += 1
|
||||
|
||||
# Test save and load.
|
||||
@ -1248,9 +1250,9 @@ async def test_swarm_handoff_with_team_events(runtime: AgentRuntime | None) -> N
|
||||
stream = team.run_stream(task="task")
|
||||
async for message in stream:
|
||||
if isinstance(message, TaskResult):
|
||||
assert message == result
|
||||
assert compare_task_results(message, result)
|
||||
else:
|
||||
assert message == result.messages[index]
|
||||
assert compare_messages(message, result.messages[index])
|
||||
index += 1
|
||||
|
||||
|
||||
@ -1366,9 +1368,9 @@ async def test_swarm_handoff_using_tool_calls(runtime: AgentRuntime | None) -> N
|
||||
stream = team.run_stream(task="task")
|
||||
async for message in stream:
|
||||
if isinstance(message, TaskResult):
|
||||
assert message == result
|
||||
assert compare_task_results(message, result)
|
||||
else:
|
||||
assert message == result.messages[index]
|
||||
assert compare_messages(message, result.messages[index])
|
||||
index += 1
|
||||
|
||||
# Test Console
|
||||
@ -1377,7 +1379,7 @@ async def test_swarm_handoff_using_tool_calls(runtime: AgentRuntime | None) -> N
|
||||
index = 0
|
||||
await team.reset()
|
||||
result2 = await Console(team.run_stream(task="task"))
|
||||
assert result2 == result
|
||||
assert compare_task_results(result2, result)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -1471,14 +1473,17 @@ async def test_swarm_with_parallel_tool_calls(runtime: AgentRuntime | None) -> N
|
||||
team = Swarm([agent1, agent2], termination_condition=termination, runtime=runtime)
|
||||
result = await team.run(task="task")
|
||||
assert len(result.messages) == 6
|
||||
assert result.messages[0] == TextMessage(content="task", source="user")
|
||||
assert compare_messages(result.messages[0], TextMessage(content="task", source="user"))
|
||||
assert isinstance(result.messages[1], ToolCallRequestEvent)
|
||||
assert isinstance(result.messages[2], ToolCallExecutionEvent)
|
||||
assert result.messages[3] == HandoffMessage(
|
||||
content="handoff to agent2",
|
||||
target="agent2",
|
||||
source="agent1",
|
||||
context=expected_handoff_context,
|
||||
assert compare_messages(
|
||||
result.messages[3],
|
||||
HandoffMessage(
|
||||
content="handoff to agent2",
|
||||
target="agent2",
|
||||
source="agent1",
|
||||
context=expected_handoff_context,
|
||||
),
|
||||
)
|
||||
assert isinstance(result.messages[4], TextMessage)
|
||||
assert result.messages[4].content == "Hello"
|
||||
@ -1598,9 +1603,9 @@ async def test_round_robin_group_chat_with_message_list(runtime: AgentRuntime |
|
||||
index = 0
|
||||
async for message in team.run_stream(task=messages):
|
||||
if isinstance(message, TaskResult):
|
||||
assert message == result
|
||||
assert compare_task_results(message, result)
|
||||
else:
|
||||
assert message == result.messages[index]
|
||||
assert compare_messages(message, result.messages[index])
|
||||
index += 1
|
||||
|
||||
# Test with invalid message list
|
||||
@ -1803,7 +1808,7 @@ async def test_selector_group_chat_streaming(runtime: AgentRuntime | None) -> No
|
||||
streaming: List[str] = []
|
||||
async for message in team.run_stream(task="Write a program that prints 'Hello, world!'"):
|
||||
if isinstance(message, TaskResult):
|
||||
assert message == result
|
||||
assert compare_task_results(message, result)
|
||||
elif isinstance(message, ModelClientStreamingChunkEvent):
|
||||
streaming.append(message.content)
|
||||
else:
|
||||
@ -1811,5 +1816,5 @@ async def test_selector_group_chat_streaming(runtime: AgentRuntime | None) -> No
|
||||
assert isinstance(message, SelectorEvent)
|
||||
assert message.content == "".join([chunk for chunk in streaming])
|
||||
streaming = []
|
||||
assert message == result.messages[index]
|
||||
assert compare_messages(message, result.messages[index])
|
||||
index += 1
|
||||
|
@ -32,6 +32,7 @@ from autogen_agentchat.teams._group_chat._graph._digraph_group_chat import (
|
||||
from autogen_core import AgentRuntime, CancellationToken, Component, SingleThreadedAgentRuntime
|
||||
from autogen_ext.models.replay import ReplayChatCompletionClient
|
||||
from pydantic import BaseModel
|
||||
from utils import compare_message_lists, compare_task_results
|
||||
|
||||
|
||||
def test_create_digraph() -> None:
|
||||
@ -1207,10 +1208,10 @@ async def test_graph_flow_serialize_deserialize() -> None:
|
||||
de_results = await deserialized_team.run(task="Start")
|
||||
|
||||
assert serialized == serialized_deserialized
|
||||
assert results == de_results
|
||||
assert compare_task_results(results, de_results)
|
||||
assert results.stop_reason is not None
|
||||
assert results.stop_reason == de_results.stop_reason
|
||||
assert results.messages == de_results.messages
|
||||
assert compare_message_lists(results.messages, de_results.messages)
|
||||
assert isinstance(results.messages[0], TextMessage)
|
||||
assert results.messages[0].source == "user"
|
||||
assert results.messages[0].content == "Start"
|
||||
|
@ -2,7 +2,10 @@ import json
|
||||
import logging
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from typing import Sequence
|
||||
|
||||
from autogen_agentchat.base._task import TaskResult
|
||||
from autogen_agentchat.messages import BaseAgentEvent, BaseChatMessage, BaseTextChatMessage
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
@ -18,7 +21,7 @@ class FileLogHandler(logging.Handler):
|
||||
record.msg = json.dumps(
|
||||
{
|
||||
"timestamp": ts,
|
||||
"message": record.msg.model_dump(),
|
||||
"message": record.msg.model_dump_json(indent=2),
|
||||
"type": record.msg.__class__.__name__,
|
||||
},
|
||||
)
|
||||
@ -37,3 +40,36 @@ class ConsoleLogHandler(logging.Handler):
|
||||
},
|
||||
)
|
||||
sys.stdout.write(f"{record.msg}\n")
|
||||
|
||||
|
||||
def compare_messages(
|
||||
msg1: BaseAgentEvent | BaseChatMessage | BaseTextChatMessage,
|
||||
msg2: BaseAgentEvent | BaseChatMessage | BaseTextChatMessage,
|
||||
) -> bool:
|
||||
if isinstance(msg1, BaseTextChatMessage) and isinstance(msg2, BaseTextChatMessage):
|
||||
if msg1.content != msg2.content:
|
||||
return False
|
||||
return (
|
||||
(msg1.source == msg2.source) and (msg1.models_usage == msg2.models_usage) and (msg1.metadata == msg2.metadata)
|
||||
)
|
||||
|
||||
|
||||
def compare_message_lists(
|
||||
msgs1: Sequence[BaseAgentEvent | BaseChatMessage],
|
||||
msgs2: Sequence[BaseAgentEvent | BaseChatMessage],
|
||||
) -> bool:
|
||||
if len(msgs1) != len(msgs2):
|
||||
return False
|
||||
for i in range(len(msgs1)):
|
||||
if not compare_messages(msgs1[i], msgs2[i]):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def compare_task_results(
|
||||
res1: TaskResult,
|
||||
res2: TaskResult,
|
||||
) -> bool:
|
||||
if res1.stop_reason != res2.stop_reason:
|
||||
return False
|
||||
return compare_message_lists(res1.messages, res2.messages)
|
||||
|
@ -32,7 +32,7 @@ class FileLogHandler(logging.Handler):
|
||||
record.msg = json.dumps(
|
||||
{
|
||||
"timestamp": ts,
|
||||
"message": record.msg.model_dump(),
|
||||
"message": record.msg.model_dump_json(indent=2),
|
||||
"type": record.msg.__class__.__name__,
|
||||
},
|
||||
)
|
||||
|
@ -33,7 +33,7 @@ class FileLogHandler(logging.Handler):
|
||||
record.msg = json.dumps(
|
||||
{
|
||||
"timestamp": ts,
|
||||
"message": record.msg.model_dump(),
|
||||
"message": record.msg.model_dump_json(indent=2),
|
||||
"type": record.msg.__class__.__name__,
|
||||
},
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user