Add callable condition for GraphFlow edges (#6623)

This PR adds callable as an option to specify conditional edges in
GraphFlow.

```python
import asyncio

from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.conditions import MaxMessageTermination
from autogen_agentchat.teams import DiGraphBuilder, GraphFlow
from autogen_ext.models.openai import OpenAIChatCompletionClient


async def main():
    # Initialize agents with OpenAI model clients.
    model_client = OpenAIChatCompletionClient(model="gpt-4.1-nano")
    agent_a = AssistantAgent(
        "A",
        model_client=model_client,
        system_message="Detect if the input is in Chinese. If it is, say 'yes', else say 'no', and nothing else.",
    )
    agent_b = AssistantAgent("B", model_client=model_client, system_message="Translate input to English.")
    agent_c = AssistantAgent("C", model_client=model_client, system_message="Translate input to Chinese.")

    # Create a directed graph with conditional branching flow A -> B ("yes"), A -> C (otherwise).
    builder = DiGraphBuilder()
    builder.add_node(agent_a).add_node(agent_b).add_node(agent_c)
    # Create conditions as callables that check the message content.
    builder.add_edge(agent_a, agent_b, condition=lambda msg: "yes" in msg.to_model_text())
    builder.add_edge(agent_a, agent_c, condition=lambda msg: "yes" not in msg.to_model_text())
    graph = builder.build()

    # Create a GraphFlow team with the directed graph.
    team = GraphFlow(
        participants=[agent_a, agent_b, agent_c],
        graph=graph,
        termination_condition=MaxMessageTermination(5),
    )

    # Run the team and print the events.
    async for event in team.run_stream(task="AutoGen is a framework for building AI agents."):
        print(event)


asyncio.run(main())
```

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: ekzhu <320302+ekzhu@users.noreply.github.com>
This commit is contained in:
Eric Zhu 2025-06-04 15:43:26 -07:00 committed by GitHub
parent 9065c6f37b
commit b31b4e508d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 244 additions and 54 deletions

View File

@ -1,9 +1,9 @@
import asyncio
from collections import Counter, deque
from typing import Any, Callable, Deque, Dict, List, Literal, Mapping, Sequence, Set
from typing import Any, Callable, Deque, Dict, List, Literal, Mapping, Sequence, Set, Union
from autogen_core import AgentRuntime, CancellationToken, Component, ComponentModel
from pydantic import BaseModel
from pydantic import BaseModel, Field, model_validator
from typing_extensions import Self
from autogen_agentchat.agents import BaseChatAgent
@ -34,16 +34,50 @@ class DiGraphEdge(BaseModel):
This is an experimental feature, and the API will change in the future releases.
.. warning::
If the condition is a callable, it will not be serialized in the model.
"""
target: str # Target node name
condition: str | None = None # Optional execution condition (trigger-based)
condition: Union[str, Callable[[BaseChatMessage], bool], None] = Field(default=None)
"""(Experimental) Condition to execute this edge.
If None, the edge is unconditional. If a string, the edge is conditional on the presence of that string in the last agent chat message.
NOTE: This is an experimental feature WILL change in the future releases to allow for better spcification of branching conditions
similar to the `TerminationCondition` class.
If None, the edge is unconditional.
If a string, the edge is conditional on the presence of that string in the last agent chat message.
If a callable, the edge is conditional on the callable returning True when given the last message.
"""
# Using Field to exclude the condition in serialization if it's a callable
condition_function: Callable[[BaseChatMessage], bool] | None = Field(default=None, exclude=True)
@model_validator(mode="after")
def _validate_condition(self) -> "DiGraphEdge":
# Store callable in a separate field and set condition to None for serialization
if callable(self.condition):
self.condition_function = self.condition
# For serialization purposes, we'll set the condition to None
# when storing as a pydantic model/dict
object.__setattr__(self, "condition", None)
return self
def check_condition(self, message: BaseChatMessage) -> bool:
"""Check if the edge condition is satisfied for the given message.
Args:
message: The message to check the condition against.
Returns:
True if condition is satisfied (None condition always returns True),
False otherwise.
"""
if self.condition_function is not None:
return self.condition_function(message)
elif isinstance(self.condition, str):
# If it's a string, check if the string is in the message content
return self.condition in message.to_model_text()
return True # None condition is always satisfied
class DiGraphNode(BaseModel):
"""Represents a node (agent) in a :class:`DiGraph`, with its outgoing edges and activation type.
@ -125,7 +159,7 @@ class DiGraph(BaseModel):
cycle_edges: List[DiGraphEdge] = []
for n in cycle_nodes:
cycle_edges.extend(self.nodes[n].edges)
if not any(edge.condition for edge in cycle_edges):
if not any(edge.condition is not None for edge in cycle_edges):
raise ValueError(
f"Cycle detected without exit condition: {' -> '.join(cycle_nodes + cycle_nodes[:1])}"
)
@ -164,7 +198,7 @@ class DiGraph(BaseModel):
# Outgoing edge condition validation (per node)
for node in self.nodes.values():
# Check that if a node has an outgoing conditional edge, then all outgoing edges are conditional
has_condition = any(edge.condition for edge in node.edges)
has_condition = any(edge.condition is not None for edge in node.edges)
has_unconditioned = any(edge.condition is None for edge in node.edges)
if has_condition and has_unconditioned:
raise ValueError(f"Node '{node.name}' has a mix of conditional and unconditional edges.")
@ -239,11 +273,11 @@ class GraphFlowManager(BaseGroupChatManager):
return
assert isinstance(message, BaseChatMessage)
source = message.source
content = message.to_model_text()
# Propagate the update to the children of the node.
for edge in self._edges[source]:
if edge.condition and edge.condition not in content:
# Use the new check_condition method that handles both string and callable conditions
if not edge.check_condition(message):
continue
if self._activation[edge.target] == "all":
self._remaining[edge.target] -= 1
@ -360,6 +394,11 @@ class GraphFlow(BaseGroupChat, Component[GraphFlowConfig]):
See the :class:`DiGraphBuilder` documentation for more details.
The :class:`GraphFlow` class is designed to be used with the :class:`DiGraphBuilder` for creating complex workflows.
.. warning::
When using callable conditions in edges, they will not be serialized
when calling :meth:`dump_component`. This will be addressed in future releases.
Args:
participants (List[ChatAgent]): The participants in the group chat.
@ -450,7 +489,7 @@ class GraphFlow(BaseGroupChat, Component[GraphFlowConfig]):
asyncio.run(main())
**Conditional Branching: A B (if 'yes') or C (if 'no')**
**Conditional Branching: A B (if 'yes') or C (otherwise)**
.. code-block:: python
@ -473,11 +512,12 @@ class GraphFlow(BaseGroupChat, Component[GraphFlowConfig]):
agent_b = AssistantAgent("B", model_client=model_client, system_message="Translate input to English.")
agent_c = AssistantAgent("C", model_client=model_client, system_message="Translate input to Chinese.")
# Create a directed graph with conditional branching flow A -> B ("yes"), A -> C ("no").
# Create a directed graph with conditional branching flow A -> B ("yes"), A -> C (otherwise).
builder = DiGraphBuilder()
builder.add_node(agent_a).add_node(agent_b).add_node(agent_c)
builder.add_edge(agent_a, agent_b, condition="yes")
builder.add_edge(agent_a, agent_c, condition="no")
# Create conditions as callables that check the message content.
builder.add_edge(agent_a, agent_b, condition=lambda msg: "yes" in msg.to_model_text())
builder.add_edge(agent_a, agent_c, condition=lambda msg: "yes" not in msg.to_model_text())
graph = builder.build()
# Create a GraphFlow team with the directed graph.
@ -494,7 +534,7 @@ class GraphFlow(BaseGroupChat, Component[GraphFlowConfig]):
asyncio.run(main())
**Loop with exit condition: A B C (if 'APPROVE') or A (if 'REJECT')**
**Loop with exit condition: A B C (if 'APPROVE') or A (otherwise)**
.. code-block:: python
@ -518,17 +558,21 @@ class GraphFlow(BaseGroupChat, Component[GraphFlowConfig]):
"B",
model_client=model_client,
system_message="Provide feedback on the input, if your feedback has been addressed, "
"say 'APPROVE', else say 'REJECT' and provide a reason.",
"say 'APPROVE', otherwise provide a reason for rejection.",
)
agent_c = AssistantAgent(
"C", model_client=model_client, system_message="Translate the final product to Korean."
)
# Create a loop graph with conditional exit: A -> B -> C ("APPROVE"), B -> A ("REJECT").
# Create a loop graph with conditional exit: A -> B -> C ("APPROVE"), B -> A (otherwise).
builder = DiGraphBuilder()
builder.add_node(agent_a).add_node(agent_b).add_node(agent_c)
builder.add_edge(agent_a, agent_b)
builder.add_conditional_edges(agent_b, {"APPROVE": agent_c, "REJECT": agent_a})
# Create conditional edges using strings
builder.add_edge(agent_b, agent_c, condition=lambda msg: "APPROVE" in msg.to_model_text())
builder.add_edge(agent_b, agent_a, condition=lambda msg: "APPROVE" not in msg.to_model_text())
builder.set_entry_point(agent_a)
graph = builder.build()

View File

@ -1,6 +1,8 @@
from typing import Dict, Literal, Optional, Union
import warnings
from typing import Callable, Dict, Literal, Optional, Union
from autogen_agentchat.base import ChatAgent
from autogen_agentchat.messages import BaseChatMessage
from ._digraph_group_chat import DiGraph, DiGraphEdge, DiGraphNode
@ -22,7 +24,7 @@ class DiGraphBuilder:
- Cyclic loops with safe exits
Each node in the graph represents an agent. Edges define execution paths between agents,
and can optionally be conditioned on message content.
and can optionally be conditioned on message content using callable functions.
The builder is compatible with the `Graph` runner and supports both standard and filtered agents.
@ -49,16 +51,29 @@ class DiGraphBuilder:
>>> builder.add_node(agent_a).add_node(agent_b).add_node(agent_c)
>>> builder.add_edge(agent_a, agent_b).add_edge(agent_a, agent_c)
Example Conditional Branching A B ("yes"), A C ("no"):
Example Conditional Branching A B or A C:
>>> builder = GraphBuilder()
>>> builder.add_node(agent_a).add_node(agent_b).add_node(agent_c)
>>> builder.add_conditional_edges(agent_a, {"yes": agent_b, "no": agent_c})
>>> # Add conditional edges using keyword check
>>> builder.add_edge(agent_a, agent_b, condition="keyword1")
>>> builder.add_edge(agent_a, agent_c, condition="keyword2")
Example Loop: A B A ("loop"), B C ("exit"):
Example Using Custom String Conditions:
>>> builder = GraphBuilder()
>>> builder.add_node(agent_a).add_node(agent_b).add_node(agent_c)
>>> # Add condition strings to check in messages
>>> builder.add_edge(agent_a, agent_b, condition="big")
>>> builder.add_edge(agent_a, agent_c, condition="small")
Example Loop: A B A or B C:
>>> builder = GraphBuilder()
>>> builder.add_node(agent_a).add_node(agent_b).add_node(agent_c)
>>> builder.add_edge(agent_a, agent_b)
>>> builder.add_conditional_edges(agent_b, {"loop": agent_a, "exit": agent_c})
>> # Add a loop back to agent A
>>> builder.add_edge(agent_b, agent_a, condition=lambda msg: "loop" in msg.to_model_text())
>>> # Add exit condition to break the loop
>>> builder.add_edge(agent_b, agent_c, condition=lambda msg: "loop" not in msg.to_model_text())
"""
def __init__(self) -> None:
@ -78,9 +93,26 @@ class DiGraphBuilder:
return self
def add_edge(
self, source: Union[str, ChatAgent], target: Union[str, ChatAgent], condition: Optional[str] = None
self,
source: Union[str, ChatAgent],
target: Union[str, ChatAgent],
condition: Optional[Union[str, Callable[[BaseChatMessage], bool]]] = None,
) -> "DiGraphBuilder":
"""Add a directed edge from source to target, optionally with a condition."""
"""Add a directed edge from source to target, optionally with a condition.
Args:
source: Source node (agent name or agent object)
target: Target node (agent name or agent object)
condition: Optional condition for edge activation.
If string, activates when substring is found in message.
If callable, activates when function returns True for the message.
Returns:
Self for method chaining
Raises:
ValueError: If source or target node doesn't exist in the builder
"""
source_name = self._get_name(source)
target_name = self._get_name(target)
@ -95,9 +127,35 @@ class DiGraphBuilder:
def add_conditional_edges(
self, source: Union[str, ChatAgent], condition_to_target: Dict[str, Union[str, ChatAgent]]
) -> "DiGraphBuilder":
"""Add multiple conditional edges from a source node based on condition strings."""
for condition, target in condition_to_target.items():
self.add_edge(source, target, condition)
"""Add multiple conditional edges from a source node based on keyword checks.
.. warning::
This method interface will be changed in the future to support callable conditions.
Please use `add_edge` if you need to specify custom conditions.
Args:
source: Source node (agent name or agent object)
condition_to_target: Mapping from condition strings to target nodes
Each key is a keyword that will be checked in the message content
Each value is the target node to activate when condition is met
For each key (keyword), a lambda will be created that checks
if the keyword is in the message text.
Returns:
Self for method chaining
"""
warnings.warn(
"add_conditional_edges will be changed in the future to support callable conditions. "
"For now, please use add_edge if you need to specify custom conditions.",
DeprecationWarning,
stacklevel=2,
)
for condition_keyword, target in condition_to_target.items():
self.add_edge(source, target, condition=condition_keyword)
return self
def set_entry_point(self, name: Union[str, ChatAgent]) -> "DiGraphBuilder":

View File

@ -99,6 +99,7 @@ def test_get_leaf_nodes() -> None:
def test_serialization() -> None:
"""Test serializing and deserializing the graph."""
# Use a string condition instead of a lambda
graph = DiGraph(
nodes={
"A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B", condition="trigger1")]),
@ -114,6 +115,11 @@ def test_serialization() -> None:
assert deserialized_graph.nodes["A"].edges[0].condition == "trigger1"
assert deserialized_graph.nodes["B"].edges[0].target == "C"
# Test the original condition works
test_msg = TextMessage(content="this has trigger1 in it", source="test")
# Manually check if the string is in the message text
assert "trigger1" in test_msg.to_model_text()
def test_invalid_graph_no_start_node() -> None:
"""Test validation failure when there is no start node."""
@ -144,6 +150,7 @@ def test_invalid_graph_no_leaf_node() -> None:
def test_condition_edge_execution() -> None:
"""Test conditional edge execution support."""
# Use string condition
graph = DiGraph(
nodes={
"A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B", condition="TRIGGER")]),
@ -152,6 +159,15 @@ def test_condition_edge_execution() -> None:
}
)
# Check the condition manually
test_message = TextMessage(content="This has TRIGGER in it", source="test")
non_match_message = TextMessage(content="This doesn't match", source="test")
# Check if the string condition is in each message text
assert "TRIGGER" in test_message.to_model_text()
assert "TRIGGER" not in non_match_message.to_model_text()
# Check the condition itself
assert graph.nodes["A"].edges[0].condition == "TRIGGER"
assert graph.nodes["B"].edges[0].condition is None
@ -193,6 +209,7 @@ def test_cycle_detection_no_cycle() -> None:
def test_cycle_detection_with_exit_condition() -> None:
"""Test a graph with cycle and conditional exit passes validation."""
# Use a string condition
graph = DiGraph(
nodes={
"A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]),
@ -258,6 +275,7 @@ def test_validate_graph_missing_leaf_node() -> None:
def test_validate_graph_mixed_conditions() -> None:
"""Test validation failure when node has mixed conditional and unconditional edges."""
# Use string for condition
graph = DiGraph(
nodes={
"A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B", condition="cond"), DiGraphEdge(target="C")]),
@ -552,6 +570,7 @@ async def test_digraph_group_chat_conditional_branch(runtime: AgentRuntime | Non
agent_b = _EchoAgent("B", description="Echo agent B")
agent_c = _EchoAgent("C", description="Echo agent C")
# Use string conditions
graph = DiGraph(
nodes={
"A": DiGraphNode(
@ -726,6 +745,7 @@ async def test_digraph_group_chat_multiple_conditional(runtime: AgentRuntime | N
agent_c = _EchoAgent("C", description="Echo agent C")
agent_d = _EchoAgent("D", description="Echo agent D")
# Use string conditions
graph = DiGraph(
nodes={
"A": DiGraphNode(
@ -1005,10 +1025,18 @@ def test_add_conditional_edges() -> None:
edges = builder.nodes["A"].edges
assert len(edges) == 2
conditions = {e.condition for e in edges}
targets = {e.target for e in edges}
assert conditions == {"yes", "no"}
assert targets == {"B", "C"}
# Extract the condition strings to compare them
conditions = [e.condition for e in edges]
assert "yes" in conditions
assert "no" in conditions
# Match edge targets with conditions
yes_edge = next(e for e in edges if e.condition == "yes")
no_edge = next(e for e in edges if e.condition == "no")
assert yes_edge.target == "B"
assert no_edge.target == "C"
def test_set_entry_point() -> None:
@ -1085,8 +1113,16 @@ def test_build_conditional_loop() -> None:
builder.set_entry_point(a)
graph = builder.build()
assert graph.nodes["B"].edges[0].condition == "loop"
assert graph.nodes["B"].edges[1].condition == "exit"
# Check that edges have the right conditions and targets
edges = graph.nodes["B"].edges
assert len(edges) == 2
# Find edges by their conditions
loop_edge = next(e for e in edges if e.condition == "loop")
exit_edge = next(e for e in edges if e.condition == "exit")
assert loop_edge.target == "A"
assert exit_edge.target == "C"
assert graph.has_cycles_with_exit()
@ -1152,6 +1188,7 @@ async def test_graph_builder_conditional_execution(runtime: AgentRuntime | None)
termination_condition=MaxMessageTermination(5),
)
# Input "no" should trigger the edge to C
result = await team.run(task="no")
sources = [m.source for m in result.messages]
assert "C" in sources
@ -1159,27 +1196,45 @@ async def test_graph_builder_conditional_execution(runtime: AgentRuntime | None)
@pytest.mark.asyncio
async def test_graph_builder_with_filter_agent(runtime: AgentRuntime | None) -> None:
inner = _EchoAgent("X", description="Echo X")
filter_agent = MessageFilterAgent(
name="X",
wrapped_agent=inner,
filter=MessageFilterConfig(per_source=[PerSourceFilter(source="user", position="last", count=1)]),
)
async def test_digraph_group_chat_callable_condition(runtime: AgentRuntime | None) -> None:
"""Test that string conditions work correctly in edge transitions."""
agent_a = _EchoAgent("A", description="Echo agent A")
agent_b = _EchoAgent("B", description="Echo agent B")
agent_c = _EchoAgent("C", description="Echo agent C")
builder = DiGraphBuilder()
builder.add_node(filter_agent)
graph = DiGraph(
nodes={
"A": DiGraphNode(
name="A",
edges=[
# Will go to B if "long" is in message
DiGraphEdge(target="B", condition="long"),
# Will go to C if "short" is in message
DiGraphEdge(target="C", condition="short"),
],
),
"B": DiGraphNode(name="B", edges=[]),
"C": DiGraphNode(name="C", edges=[]),
}
)
team = GraphFlow(
participants=builder.get_participants(),
graph=builder.build(),
participants=[agent_a, agent_b, agent_c],
graph=graph,
runtime=runtime,
termination_condition=MaxMessageTermination(3),
termination_condition=MaxMessageTermination(5),
)
result = await team.run(task="Hello")
assert any(m.source == "X" and m.content == "Hello" for m in result.messages) # type: ignore[union-attr]
assert result.stop_reason is not None
# Test with a message containing "long" - should go to B
result = await team.run(task="This is a long message")
assert result.messages[2].source == "B"
# Reset for next test
await team.reset()
# Test with a message containing "short" - should go to C
result = await team.run(task="This is a short message")
assert result.messages[2].source == "C"
@pytest.mark.asyncio
@ -1268,3 +1323,36 @@ async def test_graph_flow_stateful_pause_and_resume_with_termination() -> None:
assert len(result.messages) == 2
assert result.messages[0].source == "B"
assert result.messages[1].source == _DIGRAPH_STOP_AGENT_NAME
@pytest.mark.asyncio
async def test_builder_with_lambda_condition(runtime: AgentRuntime | None) -> None:
"""Test that DiGraphBuilder supports string conditions."""
agent_a = _EchoAgent("A", description="Echo agent A")
agent_b = _EchoAgent("B", description="Echo agent B")
agent_c = _EchoAgent("C", description="Echo agent C")
builder = DiGraphBuilder()
builder.add_node(agent_a).add_node(agent_b).add_node(agent_c)
# Using callable conditions
builder.add_edge(agent_a, agent_b, lambda msg: "even" in msg.to_model_text())
builder.add_edge(agent_a, agent_c, lambda msg: "odd" in msg.to_model_text())
team = GraphFlow(
participants=builder.get_participants(),
graph=builder.build(),
runtime=runtime,
termination_condition=MaxMessageTermination(5),
)
# Test with "even" in message - should go to B
result = await team.run(task="even length")
assert result.messages[2].source == "B"
# Reset for next test
await team.reset()
# Test with "odd" in message - should go to C
result = await team.run(task="odd message")
assert result.messages[2].source == "C"

View File

@ -414,7 +414,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": null,
"id": "af297db2",
"metadata": {},
"outputs": [
@ -591,7 +591,7 @@
"reviewer = AssistantAgent(\n",
" \"reviewer\",\n",
" model_client=model_client,\n",
" system_message=\"Review ideas and say 'REVISE' and provide feedbacks, or 'APPROVE' for final approval.\",\n",
" system_message=\"Review ideas and provide feedbacks, or just 'APPROVE' for final approval.\",\n",
")\n",
"summarizer_core = AssistantAgent(\n",
" \"summary\", model_client=model_client, system_message=\"Summarize the user request and the final feedback.\"\n",
@ -613,8 +613,8 @@
"builder = DiGraphBuilder()\n",
"builder.add_node(generator).add_node(reviewer).add_node(filtered_summarizer)\n",
"builder.add_edge(generator, reviewer)\n",
"builder.add_edge(reviewer, generator, condition=\"REVISE\")\n",
"builder.add_edge(reviewer, filtered_summarizer, condition=\"APPROVE\")\n",
"builder.add_edge(reviewer, filtered_summarizer, condition=lambda msg: \"APPROVE\" in msg.to_model_text())\n",
"builder.add_edge(reviewer, generator, condition=lambda msg: \"APPROVE\" not in msg.to_model_text())\n",
"builder.set_entry_point(generator) # Set entry point to generator. Required if there are no source nodes.\n",
"graph = builder.build()\n",
"\n",