From c99aa7416d44e16b9cd5bafcd709f6e2484ffe31 Mon Sep 17 00:00:00 2001 From: Eric Zhu Date: Wed, 4 Jun 2025 22:05:16 -0700 Subject: [PATCH] Fix graph validation logic and add tests (#6630) Follow up to #6629 --- .../_group_chat/_graph/_digraph_group_chat.py | 6 +- .../tests/test_group_chat_graph.py | 78 +++++++++++++++++++ 2 files changed, 82 insertions(+), 2 deletions(-) diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_graph/_digraph_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_graph/_digraph_group_chat.py index ff4a01f48..7a8ee1c1b 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_graph/_digraph_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_graph/_digraph_group_chat.py @@ -198,8 +198,10 @@ class DiGraph(BaseModel): # Outgoing edge condition validation (per node) for node in self.nodes.values(): # Check that if a node has an outgoing conditional edge, then all outgoing edges are conditional - has_condition = any(edge.condition is not None for edge in node.edges) - has_unconditioned = any(edge.condition is None for edge in node.edges) + has_condition = any( + edge.condition is not None or edge.condition_function is not None for edge in node.edges + ) + has_unconditioned = any(edge.condition is None and edge.condition_function is None for edge in node.edges) if has_condition and has_unconditioned: raise ValueError(f"Node '{node.name}' has a mix of conditional and unconditional edges.") diff --git a/python/packages/autogen-agentchat/tests/test_group_chat_graph.py b/python/packages/autogen-agentchat/tests/test_group_chat_graph.py index a18c0ae61..2fe2b4dc5 100644 --- a/python/packages/autogen-agentchat/tests/test_group_chat_graph.py +++ b/python/packages/autogen-agentchat/tests/test_group_chat_graph.py @@ -259,6 +259,19 @@ def test_validate_graph_success() -> None: graph.graph_validate() assert not graph.get_has_cycles() + # Use a lambda condition + graph_with_lambda = DiGraph( + nodes={ + "A": DiGraphNode( + name="A", edges=[DiGraphEdge(target="B", condition=lambda msg: "test" in msg.to_model_text())] + ), + "B": DiGraphNode(name="B", edges=[]), + } + ) + # No error should be raised + graph_with_lambda.graph_validate() + assert not graph_with_lambda.get_has_cycles() + def test_validate_graph_missing_start_node() -> None: """Test validation failure when no start node exists.""" @@ -298,6 +311,23 @@ def test_validate_graph_mixed_conditions() -> None: with pytest.raises(ValueError, match="Node 'A' has a mix of conditional and unconditional edges"): graph.graph_validate() + # Use lambda for condition + graph_with_lambda = DiGraph( + nodes={ + "A": DiGraphNode( + name="A", + edges=[ + DiGraphEdge(target="B", condition=lambda msg: "test" in msg.to_model_text()), + DiGraphEdge(target="C"), + ], + ), + "B": DiGraphNode(name="B", edges=[]), + "C": DiGraphNode(name="C", edges=[]), + } + ) + with pytest.raises(ValueError, match="Node 'A' has a mix of conditional and unconditional edges"): + graph_with_lambda.graph_validate() + @pytest.mark.asyncio async def test_invalid_digraph_manager_cycle_without_termination() -> None: @@ -603,6 +633,29 @@ async def test_digraph_group_chat_conditional_branch(runtime: AgentRuntime | Non result = await team.run(task="Trigger yes") assert result.messages[2].source == "B" + # Use lambda conditions + graph_with_lambda = DiGraph( + nodes={ + "A": DiGraphNode( + name="A", + edges=[ + DiGraphEdge(target="B", condition=lambda msg: "yes" in msg.to_model_text()), + DiGraphEdge(target="C", condition=lambda msg: "no" in msg.to_model_text()), + ], + ), + "B": DiGraphNode(name="B", edges=[], activation="any"), + "C": DiGraphNode(name="C", edges=[], activation="any"), + } + ) + team_with_lambda = GraphFlow( + participants=[agent_a, agent_b, agent_c], + graph=graph_with_lambda, + runtime=runtime, + termination_condition=MaxMessageTermination(5), + ) + result_with_lambda = await team_with_lambda.run(task="Trigger no") + assert result_with_lambda.messages[2].source == "C" + @pytest.mark.asyncio async def test_digraph_group_chat_loop_with_exit_condition(runtime: AgentRuntime | None) -> None: @@ -785,6 +838,31 @@ async def test_digraph_group_chat_multiple_conditional(runtime: AgentRuntime | N result = await team.run(task="banana") assert result.messages[2].source == "C" + # Use lambda conditions + graph_with_lambda = DiGraph( + nodes={ + "A": DiGraphNode( + name="A", + edges=[ + DiGraphEdge(target="B", condition=lambda msg: "apple" in msg.to_model_text()), + DiGraphEdge(target="C", condition=lambda msg: "banana" in msg.to_model_text()), + DiGraphEdge(target="D", condition=lambda msg: "cherry" in msg.to_model_text()), + ], + ), + "B": DiGraphNode(name="B", edges=[]), + "C": DiGraphNode(name="C", edges=[]), + "D": DiGraphNode(name="D", edges=[]), + } + ) + team_with_lambda = GraphFlow( + participants=[agent_a, agent_b, agent_c, agent_d], + graph=graph_with_lambda, + runtime=runtime, + termination_condition=MaxMessageTermination(5), + ) + result_with_lambda = await team_with_lambda.run(task="cherry") + assert result_with_lambda.messages[2].source == "D" + class _TestMessageFilterAgentConfig(BaseModel): name: str