From c7757de59eaa6cf3a7cd039ccddcdb7ceadd724f Mon Sep 17 00:00:00 2001 From: EeS Date: Thu, 1 May 2025 03:25:20 +0900 Subject: [PATCH] FIX: GraphFlow serialize/deserialize and adding test (#6434) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Why are these changes needed? ❗ Before Previously, GraphFlow.__init__() modified the inner_chats and termination_condition for internal execution logic (e.g., constructing _StopAgent or composing OrTerminationCondition). However, these modified values were also used during dump_component(), meaning the serialized config no longer matched the original inputs. As a result: 1. dump_component() → load_component() → dump_component() produced non-idempotent configs. 2. Internal-only constructs like _StopAgent were mistakenly serialized, even though they should only exist in runtime. ⸻ ✅ After This patch changes the behavior to: • Store original inner_chats and termination_condition as-is at initialization. • During to_config(), serialize only the original unmodified versions. • Avoid serializing _StopAgent or other dynamically built agents. • Ensure deserialization (from_config) produces a logically equivalent object without additional nesting or duplication. This ensures that: • GraphFlow.dump_component() → load_component() round-trip produces consistent, minimal configs. • Internal execution logic and serialized component structure are properly separated. ## Related issue number Closes #6431 ## Checks - [ ] I've included any doc changes needed for . See to build and test documentation locally. - [x] I've added tests (if relevant) corresponding to the changes introduced in this PR. - [x] I've made sure all auto checks have passed. --- .../_group_chat/_graph/_digraph_group_chat.py | 9 +++- .../tests/test_group_chat_graph.py | 50 ++++++++++++++++++- 2 files changed, 56 insertions(+), 3 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_graph/_digraph_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_graph/_digraph_group_chat.py index 9917a2a8b..87b083b3d 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_graph/_digraph_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_graph/_digraph_group_chat.py @@ -646,6 +646,9 @@ class GraphFlow(BaseGroupChat, Component[GraphFlowConfig]): runtime: AgentRuntime | None = None, custom_message_types: List[type[BaseAgentEvent | BaseChatMessage]] | None = None, ) -> None: + self._input_participants = participants + self._input_termination_condition = termination_condition + stop_agent = _StopAgent() stop_agent_termination = StopMessageTermination() termination_condition = ( @@ -700,8 +703,10 @@ class GraphFlow(BaseGroupChat, Component[GraphFlowConfig]): def _to_config(self) -> GraphFlowConfig: """Converts the instance into a configuration object.""" - participants = [participant.dump_component() for participant in self._participants] - termination_condition = self._termination_condition.dump_component() if self._termination_condition else None + participants = [participant.dump_component() for participant in self._input_participants] + termination_condition = ( + self._input_termination_condition.dump_component() if self._input_termination_condition else None + ) return GraphFlowConfig( participants=participants, termination_condition=termination_condition, diff --git a/python/packages/autogen-agentchat/tests/test_group_chat_graph.py b/python/packages/autogen-agentchat/tests/test_group_chat_graph.py index c6aa5362d..7c8baf4b1 100644 --- a/python/packages/autogen-agentchat/tests/test_group_chat_graph.py +++ b/python/packages/autogen-agentchat/tests/test_group_chat_graph.py @@ -13,7 +13,7 @@ from autogen_agentchat.agents import ( ) from autogen_agentchat.base import Response, TaskResult from autogen_agentchat.conditions import MaxMessageTermination -from autogen_agentchat.messages import BaseChatMessage, ChatMessage, MessageFactory, TextMessage +from autogen_agentchat.messages import BaseChatMessage, ChatMessage, MessageFactory, StopMessage, TextMessage from autogen_agentchat.messages import BaseTextChatMessage as TextChatMessage from autogen_agentchat.teams import ( DiGraphBuilder, @@ -1399,3 +1399,51 @@ async def test_graph_builder_with_filter_agent(runtime: AgentRuntime | None) -> result = await team.run(task="Hello") assert any(m.source == "X" and m.content == "Hello" for m in result.messages) # type: ignore[union-attr] assert result.stop_reason is not None + + +@pytest.mark.asyncio +async def test_graph_flow_serialize_deserialize() -> None: + client_a = ReplayChatCompletionClient(list(map(str, range(10)))) + client_b = ReplayChatCompletionClient(list(map(str, range(10)))) + a = AssistantAgent("A", model_client=client_a) + b = AssistantAgent("B", model_client=client_b) + + builder = DiGraphBuilder() + builder.add_node(a).add_node(b) + builder.add_edge(a, b) + builder.set_entry_point(a) + + team = GraphFlow( + participants=builder.get_participants(), + graph=builder.build(), + runtime=None, + termination_condition=MaxMessageTermination(5), + ) + + serialized = team.dump_component() + deserialized_team = GraphFlow.load_component(serialized) + serialized_deserialized = deserialized_team.dump_component() + + results = await team.run(task="Start") + de_results = await deserialized_team.run(task="Start") + + assert serialized == serialized_deserialized + assert results == de_results + assert results.stop_reason is not None + assert results.stop_reason == de_results.stop_reason + assert results.messages == de_results.messages + assert isinstance(results.messages[0], TextMessage) + assert results.messages[0].source == "user" + assert results.messages[0].content == "Start" + assert isinstance(results.messages[1], TextMessage) + assert results.messages[1].source == "A" + assert results.messages[1].content == "0" + assert isinstance(results.messages[2], TextMessage) + assert results.messages[2].source == "A" + assert results.messages[2].content == "1" + assert isinstance(results.messages[3], TextMessage) + assert results.messages[3].source == "B" + assert results.messages[3].content == "0" + assert isinstance(results.messages[-1], StopMessage) + assert results.messages[-1].source == _DIGRAPH_STOP_AGENT_NAME + assert results.messages[-1].content == "Digraph execution is complete"