diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_graph/_digraph_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_graph/_digraph_group_chat.py index 9917a2a8b..87b083b3d 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_graph/_digraph_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_graph/_digraph_group_chat.py @@ -646,6 +646,9 @@ class GraphFlow(BaseGroupChat, Component[GraphFlowConfig]): runtime: AgentRuntime | None = None, custom_message_types: List[type[BaseAgentEvent | BaseChatMessage]] | None = None, ) -> None: + self._input_participants = participants + self._input_termination_condition = termination_condition + stop_agent = _StopAgent() stop_agent_termination = StopMessageTermination() termination_condition = ( @@ -700,8 +703,10 @@ class GraphFlow(BaseGroupChat, Component[GraphFlowConfig]): def _to_config(self) -> GraphFlowConfig: """Converts the instance into a configuration object.""" - participants = [participant.dump_component() for participant in self._participants] - termination_condition = self._termination_condition.dump_component() if self._termination_condition else None + participants = [participant.dump_component() for participant in self._input_participants] + termination_condition = ( + self._input_termination_condition.dump_component() if self._input_termination_condition else None + ) return GraphFlowConfig( participants=participants, termination_condition=termination_condition, diff --git a/python/packages/autogen-agentchat/tests/test_group_chat_graph.py b/python/packages/autogen-agentchat/tests/test_group_chat_graph.py index c6aa5362d..7c8baf4b1 100644 --- a/python/packages/autogen-agentchat/tests/test_group_chat_graph.py +++ b/python/packages/autogen-agentchat/tests/test_group_chat_graph.py @@ -13,7 +13,7 @@ from autogen_agentchat.agents import ( ) from autogen_agentchat.base import Response, TaskResult from autogen_agentchat.conditions import MaxMessageTermination -from autogen_agentchat.messages import BaseChatMessage, ChatMessage, MessageFactory, TextMessage +from autogen_agentchat.messages import BaseChatMessage, ChatMessage, MessageFactory, StopMessage, TextMessage from autogen_agentchat.messages import BaseTextChatMessage as TextChatMessage from autogen_agentchat.teams import ( DiGraphBuilder, @@ -1399,3 +1399,51 @@ async def test_graph_builder_with_filter_agent(runtime: AgentRuntime | None) -> result = await team.run(task="Hello") assert any(m.source == "X" and m.content == "Hello" for m in result.messages) # type: ignore[union-attr] assert result.stop_reason is not None + + +@pytest.mark.asyncio +async def test_graph_flow_serialize_deserialize() -> None: + client_a = ReplayChatCompletionClient(list(map(str, range(10)))) + client_b = ReplayChatCompletionClient(list(map(str, range(10)))) + a = AssistantAgent("A", model_client=client_a) + b = AssistantAgent("B", model_client=client_b) + + builder = DiGraphBuilder() + builder.add_node(a).add_node(b) + builder.add_edge(a, b) + builder.set_entry_point(a) + + team = GraphFlow( + participants=builder.get_participants(), + graph=builder.build(), + runtime=None, + termination_condition=MaxMessageTermination(5), + ) + + serialized = team.dump_component() + deserialized_team = GraphFlow.load_component(serialized) + serialized_deserialized = deserialized_team.dump_component() + + results = await team.run(task="Start") + de_results = await deserialized_team.run(task="Start") + + assert serialized == serialized_deserialized + assert results == de_results + assert results.stop_reason is not None + assert results.stop_reason == de_results.stop_reason + assert results.messages == de_results.messages + assert isinstance(results.messages[0], TextMessage) + assert results.messages[0].source == "user" + assert results.messages[0].content == "Start" + assert isinstance(results.messages[1], TextMessage) + assert results.messages[1].source == "A" + assert results.messages[1].content == "0" + assert isinstance(results.messages[2], TextMessage) + assert results.messages[2].source == "A" + assert results.messages[2].content == "1" + assert isinstance(results.messages[3], TextMessage) + assert results.messages[3].source == "B" + assert results.messages[3].content == "0" + assert isinstance(results.messages[-1], StopMessage) + assert results.messages[-1].source == _DIGRAPH_STOP_AGENT_NAME + assert results.messages[-1].content == "Digraph execution is complete"