autogen/python/packages/autogen-agentchat/tests/test_group_chat_graph.py
Eric Zhu b31b4e508d
Add callable condition for GraphFlow edges (#6623)
This PR adds support for using a callable as the condition of a
conditional edge in GraphFlow.

```python
import asyncio

from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.conditions import MaxMessageTermination
from autogen_agentchat.teams import DiGraphBuilder, GraphFlow
from autogen_ext.models.openai import OpenAIChatCompletionClient


async def main() -> None:
    """Run a three-agent GraphFlow whose edges out of agent A use callable conditions.

    Agent A is routed to B when A's message text contains "yes" and to C
    otherwise; every event streamed by the team is printed.
    """
    # Initialize agents with OpenAI model clients.
    model_client = OpenAIChatCompletionClient(model="gpt-4.1-nano")
    agent_a = AssistantAgent(
        "A",
        model_client=model_client,
        system_message="Detect if the input is in Chinese. If it is, say 'yes', else say 'no', and nothing else.",
    )
    agent_b = AssistantAgent("B", model_client=model_client, system_message="Translate input to English.")
    agent_c = AssistantAgent("C", model_client=model_client, system_message="Translate input to Chinese.")

    # Create a directed graph with conditional branching flow A -> B ("yes"), A -> C (otherwise).
    builder = DiGraphBuilder()
    builder.add_node(agent_a).add_node(agent_b).add_node(agent_c)
    # Create conditions as callables that check the message content.
    # The two lambdas are complementary, so exactly one of them is true for any message.
    builder.add_edge(agent_a, agent_b, condition=lambda msg: "yes" in msg.to_model_text())
    builder.add_edge(agent_a, agent_c, condition=lambda msg: "yes" not in msg.to_model_text())
    graph = builder.build()

    # Create a GraphFlow team with the directed graph.
    team = GraphFlow(
        participants=[agent_a, agent_b, agent_c],
        graph=graph,
        termination_condition=MaxMessageTermination(5),
    )

    # Run the team and print the events.
    async for event in team.run_stream(task="AutoGen is a framework for building AI agents."):
        print(event)


asyncio.run(main())
```

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: ekzhu <320302+ekzhu@users.noreply.github.com>
2025-06-04 22:43:26 +00:00

1359 lines
47 KiB
Python

import asyncio
from typing import AsyncGenerator, List, Sequence
from unittest.mock import patch
import pytest
import pytest_asyncio
from autogen_agentchat.agents import (
AssistantAgent,
BaseChatAgent,
MessageFilterAgent,
MessageFilterConfig,
PerSourceFilter,
)
from autogen_agentchat.base import Response, TaskResult
from autogen_agentchat.conditions import MaxMessageTermination, SourceMatchTermination
from autogen_agentchat.messages import BaseChatMessage, ChatMessage, MessageFactory, StopMessage, TextMessage
from autogen_agentchat.teams import (
DiGraphBuilder,
GraphFlow,
)
from autogen_agentchat.teams._group_chat._events import ( # type: ignore[attr-defined]
BaseAgentEvent,
GroupChatTermination,
)
from autogen_agentchat.teams._group_chat._graph._digraph_group_chat import (
_DIGRAPH_STOP_AGENT_NAME, # pyright: ignore[reportPrivateUsage]
DiGraph,
DiGraphEdge,
DiGraphNode,
GraphFlowManager,
)
from autogen_core import AgentRuntime, CancellationToken, Component, SingleThreadedAgentRuntime
from autogen_ext.models.replay import ReplayChatCompletionClient
from pydantic import BaseModel
from utils import compare_message_lists, compare_task_results
def test_create_digraph() -> None:
    """A three-node chain A -> B -> C exposes every node with the expected edge count."""
    node_a = DiGraphNode(name="A", edges=[DiGraphEdge(target="B")])
    node_b = DiGraphNode(name="B", edges=[DiGraphEdge(target="C")])
    node_c = DiGraphNode(name="C", edges=[])
    chain = DiGraph(nodes={"A": node_a, "B": node_b, "C": node_c})

    # All nodes must be registered under their names.
    for node_name in ("A", "B", "C"):
        assert node_name in chain.nodes

    # A and B each have one outgoing edge; C terminates the chain.
    for node_name, out_degree in (("A", 1), ("B", 1), ("C", 0)):
        assert len(chain.nodes[node_name].edges) == out_degree
def test_get_parents() -> None:
    """Verify the parent lists computed for the chain A -> B -> C."""
    chain = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]),
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="C")]),
            "C": DiGraphNode(name="C", edges=[]),
        }
    )
    parent_map = chain.get_parents()

    # A heads the chain and has no parents; B and C each have exactly one.
    expected_parents = {"A": [], "B": ["A"], "C": ["B"]}
    for node_name, node_parents in expected_parents.items():
        assert parent_map[node_name] == node_parents
def test_get_start_nodes() -> None:
    """Only the head of the chain A -> B -> C is reported as a start node."""
    to_b = DiGraphEdge(target="B")
    to_c = DiGraphEdge(target="C")
    chain = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[to_b]),
            "B": DiGraphNode(name="B", edges=[to_c]),
            "C": DiGraphNode(name="C", edges=[]),
        }
    )
    # A is the only node without an incoming edge.
    assert chain.get_start_nodes() == {"A"}
def test_get_leaf_nodes() -> None:
    """Only the tail of the chain A -> B -> C is reported as a leaf node."""
    to_b = DiGraphEdge(target="B")
    to_c = DiGraphEdge(target="C")
    chain = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[to_b]),
            "B": DiGraphNode(name="B", edges=[to_c]),
            "C": DiGraphNode(name="C", edges=[]),
        }
    )
    # C is the only node without an outgoing edge.
    assert chain.get_leaf_nodes() == {"C"}
def test_serialization() -> None:
    """Round-trip a graph with a string edge condition through its JSON form."""
    # The A -> B condition is a plain string rather than a lambda.
    original = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B", condition="trigger1")]),
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="C")]),
            "C": DiGraphNode(name="C", edges=[]),
        }
    )
    payload = original.model_dump_json()
    restored = DiGraph.model_validate_json(payload)

    # Targets and the condition string must survive the round trip.
    edge_from_a = restored.nodes["A"].edges[0]
    assert edge_from_a.target == "B"
    assert edge_from_a.condition == "trigger1"
    assert restored.nodes["B"].edges[0].target == "C"

    # Sanity check done by hand: the condition string occurs in a matching message's text.
    probe = TextMessage(content="this has trigger1 in it", source="test")
    assert "trigger1" in probe.to_model_text()
def test_invalid_graph_no_start_node() -> None:
    """A graph made only of the cycle B <-> C yields an empty set of start nodes."""
    cyclic = DiGraph(
        nodes={
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="C")]),
            "C": DiGraphNode(name="C", edges=[DiGraphEdge(target="B")]),  # closes the cycle
        }
    )
    # Every node has an incoming edge, so nothing qualifies as a start node.
    assert len(cyclic.get_start_nodes()) == 0
def test_invalid_graph_no_leaf_node() -> None:
    """The cycle A -> B -> C -> A yields an empty set of leaf nodes."""
    cyclic = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]),
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="C")]),
            "C": DiGraphNode(name="C", edges=[DiGraphEdge(target="A")]),  # closes the cycle
        }
    )
    # Every node has an outgoing edge, so there is no endpoint.
    assert len(cyclic.get_leaf_nodes()) == 0
def test_condition_edge_execution() -> None:
    """String conditions are stored on edges; edges built without one keep ``None``."""
    flow_graph = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B", condition="TRIGGER")]),
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="C")]),
            "C": DiGraphNode(name="C", edges=[]),
        }
    )

    # Substring check performed by hand on a matching and a non-matching message.
    matching = TextMessage(content="This has TRIGGER in it", source="test")
    unrelated = TextMessage(content="This doesn't match", source="test")
    assert "TRIGGER" in matching.to_model_text()
    assert "TRIGGER" not in unrelated.to_model_text()

    # The stored condition values themselves.
    conditional_edge = flow_graph.nodes["A"].edges[0]
    assert conditional_edge.condition == "TRIGGER"
    unconditional_edge = flow_graph.nodes["B"].edges[0]
    assert unconditional_edge.condition is None
def test_graph_with_multiple_paths() -> None:
    """Diamond graph A -> {B, C} -> D: check parents, start nodes and leaf nodes."""
    diamond = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B"), DiGraphEdge(target="C")]),
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="D")]),
            "C": DiGraphNode(name="C", edges=[DiGraphEdge(target="D")]),
            "D": DiGraphNode(name="D", edges=[]),
        }
    )

    parent_map = diamond.get_parents()
    for node_name, node_parents in (("B", ["A"]), ("C", ["A"]), ("D", ["B", "C"])):
        assert parent_map[node_name] == node_parents

    # A single entry (A) and a single exit (D).
    assert diamond.get_start_nodes() == {"A"}
    assert diamond.get_leaf_nodes() == {"D"}
def test_cycle_detection_no_cycle() -> None:
    """``has_cycles_with_exit`` is falsy for the acyclic chain A -> B -> C."""
    chain = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]),
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="C")]),
            "C": DiGraphNode(name="C", edges=[]),
        }
    )
    cycle_found = chain.has_cycles_with_exit()
    assert not cycle_found
def test_cycle_detection_with_exit_condition() -> None:
    """A cycle whose closing edge carries a string condition is accepted and reported."""
    # C -> A closes the cycle and is the only conditional edge.
    closing_edge = DiGraphEdge(target="A", condition="exit")
    looped = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]),
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="C")]),
            "C": DiGraphNode(name="C", edges=[closing_edge]),
        }
    )
    cycle_found = looped.has_cycles_with_exit()
    assert cycle_found
def test_cycle_detection_without_exit_condition() -> None:
    """A cycle made only of unconditional edges makes ``has_cycles_with_exit`` raise."""
    # A -> B -> C -> A is an unconditional cycle; D -> E is a separate acyclic component.
    unconditional_cycle = {
        "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]),
        "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="C")]),
        "C": DiGraphNode(name="C", edges=[DiGraphEdge(target="A")]),
    }
    separate_chain = {
        "D": DiGraphNode(name="D", edges=[DiGraphEdge(target="E")]),
        "E": DiGraphNode(name="E", edges=[]),
    }
    graph = DiGraph(nodes={**unconditional_cycle, **separate_chain})

    with pytest.raises(ValueError, match="Cycle detected without exit condition: A -> B -> C -> A"):
        graph.has_cycles_with_exit()
def test_validate_graph_success() -> None:
    """``graph_validate`` accepts the minimal graph A -> B, which has no cycles."""
    source = DiGraphNode(name="A", edges=[DiGraphEdge(target="B")])
    sink = DiGraphNode(name="B", edges=[])
    graph = DiGraph(nodes={"A": source, "B": sink})

    graph.graph_validate()  # must complete without raising
    has_cycles = graph.get_has_cycles()
    assert not has_cycles
def test_validate_graph_missing_start_node() -> None:
    """``graph_validate`` raises when every node has an incoming edge (A <-> B)."""
    node_a = DiGraphNode(name="A", edges=[DiGraphEdge(target="B")])
    node_b = DiGraphNode(name="B", edges=[DiGraphEdge(target="A")])  # closes the cycle
    cyclic = DiGraph(nodes={"A": node_a, "B": node_b})

    with pytest.raises(ValueError, match="Graph must have at least one start node"):
        cyclic.graph_validate()
def test_validate_graph_missing_leaf_node() -> None:
    """``graph_validate`` raises when every node has an outgoing edge (A -> B <-> C)."""
    node_a = DiGraphNode(name="A", edges=[DiGraphEdge(target="B")])
    node_b = DiGraphNode(name="B", edges=[DiGraphEdge(target="C")])
    node_c = DiGraphNode(name="C", edges=[DiGraphEdge(target="B")])  # B <-> C cycle
    graph = DiGraph(nodes={"A": node_a, "B": node_b, "C": node_c})

    with pytest.raises(ValueError, match="Graph must have at least one leaf node"):
        graph.graph_validate()
def test_validate_graph_mixed_conditions() -> None:
    """``graph_validate`` rejects a node that has both conditional and unconditional edges."""
    # A has one edge with a string condition and one edge without any condition.
    mixed_edges = [DiGraphEdge(target="B", condition="cond"), DiGraphEdge(target="C")]
    graph = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=mixed_edges),
            "B": DiGraphNode(name="B", edges=[]),
            "C": DiGraphNode(name="C", edges=[]),
        }
    )

    with pytest.raises(ValueError, match="Node 'A' has a mix of conditional and unconditional edges"):
        graph.graph_validate()
@pytest.mark.asyncio
async def test_invalid_digraph_manager_cycle_without_termination() -> None:
    """Constructing a GraphFlowManager with the pure cycle A -> B -> A raises ValueError.

    ``BaseGroupChatManager.__init__`` is patched to return ``None`` and the manager is
    allocated with ``__new__`` so that ``__init__`` can be invoked explicitly inside
    ``pytest.raises``. The expected error is the missing-start-node message.
    """
    cyclic_graph = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]),
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="A")]),
        }
    )
    event_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination] = asyncio.Queue()
    base_init_path = "autogen_agentchat.teams._group_chat._base_group_chat_manager.BaseGroupChatManager.__init__"

    with patch(base_init_path, return_value=None):
        manager = GraphFlowManager.__new__(GraphFlowManager)
        with pytest.raises(ValueError, match="Graph must have at least one start node"):
            manager.__init__(  # type: ignore[misc]
                name="test_manager",
                group_topic_type="topic",
                output_topic_type="topic",
                participant_topic_types=["topic1", "topic2"],
                participant_names=["A", "B"],
                participant_descriptions=["Agent A", "Agent B"],
                output_message_queue=event_queue,
                termination_condition=None,
                max_turns=None,
                message_factory=MessageFactory(),
                graph=cyclic_graph,
            )
# Test agent that answers with the content of the first message it receives.
# When called with no messages it repeats the content it echoed last time.
class _EchoAgent(BaseChatAgent):
    def __init__(self, name: str, description: str) -> None:
        super().__init__(name, description)
        # Content of the most recently echoed message; None until the first echo.
        self._last_message: str | None = None
        # Number of responses produced so far (not cleared by on_reset).
        self._total_messages = 0

    @property
    def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
        return (TextMessage,)

    @property
    def total_messages(self) -> int:
        return self._total_messages

    async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
        # Pick the text to echo: the first incoming message, or the remembered one.
        if len(messages) > 0:
            first = messages[0]
            assert isinstance(first, TextMessage)
            self._last_message = first.content
            echoed = first.content
        else:
            assert self._last_message is not None
            echoed = self._last_message
        self._total_messages += 1
        return Response(chat_message=TextMessage(content=echoed, source=self.name))

    async def on_reset(self, cancellation_token: CancellationToken) -> None:
        # Only the remembered content is cleared.
        self._last_message = None
@pytest_asyncio.fixture(params=["single_threaded", "embedded"])  # type: ignore
async def runtime(request: pytest.FixtureRequest) -> AsyncGenerator[AgentRuntime | None, None]:
    """Parametrized fixture yielding the runtime value handed to the team under test.

    For "single_threaded" a started ``SingleThreadedAgentRuntime`` is yielded and
    stopped after the test; for "embedded" ``None`` is yielded.
    """
    mode = request.param
    if mode == "embedded":
        yield None
    elif mode == "single_threaded":
        started_runtime = SingleThreadedAgentRuntime()
        started_runtime.start()
        yield started_runtime
        # Teardown: runs once the test using the fixture has finished.
        await started_runtime.stop()
# Alias for the accepted task forms: a plain string, a list of chat messages,
# or a single chat message.
TaskType = str | List[ChatMessage] | ChatMessage
@pytest.mark.asyncio
async def test_digraph_group_chat_sequential_execution(runtime: AgentRuntime | None) -> None:
    """The chain A -> B -> C produces user, A, B, C and then the stop-agent message."""
    echo_a = _EchoAgent("A", description="Echo agent A")
    echo_b = _EchoAgent("B", description="Echo agent B")
    echo_c = _EchoAgent("C", description="Echo agent C")

    chain = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]),
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="C")]),
            "C": DiGraphNode(name="C", edges=[]),
        }
    )
    team = GraphFlow(
        participants=[echo_a, echo_b, echo_c],
        graph=chain,
        runtime=runtime,
        termination_condition=MaxMessageTermination(5),
    )

    outcome: TaskResult = await team.run(task="Hello from User")
    transcript = outcome.messages

    assert len(transcript) == 5
    assert isinstance(transcript[0], TextMessage)
    # Strict ordering: task, then each agent in chain order, then the stop message.
    assert [m.source for m in transcript] == ["user", "A", "B", "C", _DIGRAPH_STOP_AGENT_NAME]
    # Everything except the final stop message is a TextMessage.
    assert all(isinstance(m, TextMessage) for m in transcript[:-1])
    assert outcome.stop_reason is not None
@pytest.mark.asyncio
async def test_digraph_group_chat_parallel_fanout(runtime: AgentRuntime | None) -> None:
    """Fan-out A -> {B, C}: A speaks first, then B and C in either order, then the stop agent."""
    echo_a = _EchoAgent("A", description="Echo agent A")
    echo_b = _EchoAgent("B", description="Echo agent B")
    echo_c = _EchoAgent("C", description="Echo agent C")

    fanout = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B"), DiGraphEdge(target="C")]),
            "B": DiGraphNode(name="B", edges=[]),
            "C": DiGraphNode(name="C", edges=[]),
        }
    )
    team = GraphFlow(
        participants=[echo_a, echo_b, echo_c],
        graph=fanout,
        runtime=runtime,
        termination_condition=MaxMessageTermination(5),
    )

    outcome: TaskResult = await team.run(task="Start")
    transcript = outcome.messages

    assert len(transcript) == 5
    assert transcript[0].source == "user"
    assert transcript[1].source == "A"
    # The order of the two branches is not fixed, so compare as a set.
    assert {m.source for m in transcript[2:-1]} == {"B", "C"}
    assert transcript[-1].source == _DIGRAPH_STOP_AGENT_NAME
    assert outcome.stop_reason is not None
@pytest.mark.asyncio
async def test_digraph_group_chat_parallel_join_all(runtime: AgentRuntime | None) -> None:
    """Join {A, B} -> C with activation "all": C speaks after both A and B."""
    echo_a = _EchoAgent("A", description="Echo agent A")
    echo_b = _EchoAgent("B", description="Echo agent B")
    echo_c = _EchoAgent("C", description="Echo agent C")

    join = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="C")]),
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="C")]),
            "C": DiGraphNode(name="C", edges=[], activation="all"),
        }
    )
    team = GraphFlow(
        participants=[echo_a, echo_b, echo_c],
        graph=join,
        runtime=runtime,
        termination_condition=MaxMessageTermination(5),
    )

    outcome: TaskResult = await team.run(task="Go")
    transcript = outcome.messages

    assert len(transcript) == 5
    assert transcript[0].source == "user"
    # The two start nodes may answer in either order.
    assert {transcript[1].source, transcript[2].source} == {"A", "B"}
    assert transcript[3].source == "C"
    assert transcript[-1].source == _DIGRAPH_STOP_AGENT_NAME
    assert outcome.stop_reason is not None
@pytest.mark.asyncio
async def test_digraph_group_chat_parallel_join_any(runtime: AgentRuntime | None) -> None:
    """Join {A, B} -> C with activation "any": C is the last agent to speak."""
    echo_a = _EchoAgent("A", description="Echo agent A")
    echo_b = _EchoAgent("B", description="Echo agent B")
    echo_c = _EchoAgent("C", description="Echo agent C")

    join = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="C")]),
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="C")]),
            "C": DiGraphNode(name="C", edges=[], activation="any"),
        }
    )
    team = GraphFlow(
        participants=[echo_a, echo_b, echo_c],
        graph=join,
        runtime=runtime,
        termination_condition=MaxMessageTermination(5),
    )

    outcome: TaskResult = await team.run(task="Start")
    transcript = outcome.messages

    assert len(transcript) == 5
    assert transcript[0].source == "user"

    # Sources of everything after the user task, trailing stop message included.
    speakers = [m.source for m in transcript[1:]]
    # C is the final agent message; only the stop message follows it.
    assert speakers[-2] == "C"
    # Both parents must have run.
    assert {"A", "B"}.issubset(set(speakers))
    # C must come after at least one of its parents.
    earliest_parent = min(speakers.index("A"), speakers.index("B"))
    assert speakers.index("C") > earliest_parent

    assert transcript[-1].source == _DIGRAPH_STOP_AGENT_NAME
    assert outcome.stop_reason is not None
@pytest.mark.asyncio
async def test_digraph_group_chat_multiple_start_nodes(runtime: AgentRuntime | None) -> None:
    """Two isolated nodes A and B both run once, in either order, before the stop agent."""
    echo_a = _EchoAgent("A", description="Echo agent A")
    echo_b = _EchoAgent("B", description="Echo agent B")

    isolated = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[]),
            "B": DiGraphNode(name="B", edges=[]),
        }
    )
    team = GraphFlow(
        participants=[echo_a, echo_b],
        graph=isolated,
        runtime=runtime,
        termination_condition=MaxMessageTermination(5),
    )

    outcome: TaskResult = await team.run(task="Start")
    transcript = outcome.messages

    assert len(transcript) == 4
    assert transcript[0].source == "user"
    assert {m.source for m in transcript[1:-1]} == {"A", "B"}
    assert transcript[-1].source == _DIGRAPH_STOP_AGENT_NAME
    assert outcome.stop_reason is not None
@pytest.mark.asyncio
async def test_digraph_group_chat_disconnected_graph(runtime: AgentRuntime | None) -> None:
    """Two separate chains A -> B and C -> D: both heads run first, then both tails."""
    echo_a = _EchoAgent("A", description="Echo agent A")
    echo_b = _EchoAgent("B", description="Echo agent B")
    echo_c = _EchoAgent("C", description="Echo agent C")
    echo_d = _EchoAgent("D", description="Echo agent D")

    two_chains = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]),
            "B": DiGraphNode(name="B", edges=[]),
            "C": DiGraphNode(name="C", edges=[DiGraphEdge(target="D")]),
            "D": DiGraphNode(name="D", edges=[]),
        }
    )
    team = GraphFlow(
        participants=[echo_a, echo_b, echo_c, echo_d],
        graph=two_chains,
        runtime=runtime,
        termination_condition=MaxMessageTermination(10),
    )

    outcome: TaskResult = await team.run(task="Go")
    transcript = outcome.messages

    assert len(transcript) == 6
    assert transcript[0].source == "user"
    # Heads of both chains occupy positions 1-2, tails positions 3-4 (order within a pair is free).
    assert {"A", "C"} == {transcript[1].source, transcript[2].source}
    assert {"B", "D"} == {transcript[3].source, transcript[4].source}
    assert transcript[-1].source == _DIGRAPH_STOP_AGENT_NAME
    assert outcome.stop_reason is not None
@pytest.mark.asyncio
async def test_digraph_group_chat_conditional_branch(runtime: AgentRuntime | None) -> None:
    """With the task "Trigger yes" echoed by A, the message after A's comes from B."""
    echo_a = _EchoAgent("A", description="Echo agent A")
    echo_b = _EchoAgent("B", description="Echo agent B")
    echo_c = _EchoAgent("C", description="Echo agent C")

    # A branches on string conditions: "yes" leads to B, "no" leads to C.
    branches = [DiGraphEdge(target="B", condition="yes"), DiGraphEdge(target="C", condition="no")]
    branching = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=branches),
            "B": DiGraphNode(name="B", edges=[], activation="any"),
            "C": DiGraphNode(name="C", edges=[], activation="any"),
        }
    )
    team = GraphFlow(
        participants=[echo_a, echo_b, echo_c],
        graph=branching,
        runtime=runtime,
        termination_condition=MaxMessageTermination(5),
    )

    outcome = await team.run(task="Trigger yes")
    # Index 0 is the user task, index 1 is A's echo, index 2 is the chosen branch.
    chosen_branch = outcome.messages[2]
    assert chosen_branch.source == "B"
@pytest.mark.asyncio
async def test_digraph_group_chat_loop_with_exit_condition(runtime: AgentRuntime | None) -> None:
    """Loop A -> B -> A repeats while B says "loop" and leaves to C when B says "exit".

    B is an AssistantAgent backed by a replay client returning "loop", "loop",
    "exit", so the expected speaker order is user, A, B, A, B, A, B, C, stop agent.
    """
    # A and C simply echo their input.
    echo_a = _EchoAgent("A", description="Echo agent A")
    echo_c = _EchoAgent("C", description="Echo agent C")

    # B's replies are scripted: two requests to loop followed by the exit.
    scripted_replies = ["loop", "loop", "exit"]
    replay_client = ReplayChatCompletionClient(chat_completions=scripted_replies)
    decider_b = AssistantAgent("B", description="Decision agent B", model_client=replay_client)

    # B either hands over to C ("exit") or returns to A ("loop").
    decision_edges = [DiGraphEdge(target="C", condition="exit"), DiGraphEdge(target="A", condition="loop")]
    looped = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]),
            "B": DiGraphNode(name="B", edges=decision_edges),
            "C": DiGraphNode(name="C", edges=[]),
        },
        default_start_node="A",
    )
    team = GraphFlow(
        participants=[echo_a, decider_b, echo_c],
        graph=looped,
        runtime=runtime,
        termination_condition=MaxMessageTermination(20),
    )

    outcome = await team.run(task="Start")
    transcript = outcome.messages

    # Three A/B rounds (two loops plus the exiting one), then C, then the stop message.
    expected_order = ["user"] + ["A", "B"] * 3 + ["C", _DIGRAPH_STOP_AGENT_NAME]
    observed_order = [m.source for m in transcript]
    assert observed_order == expected_order

    assert outcome.stop_reason is not None
    assert transcript[-2].source == "C"
    assert any(m.content == "exit" for m in transcript[:-1])  # type: ignore[attr-defined,union-attr]
    assert transcript[-1].source == _DIGRAPH_STOP_AGENT_NAME
@pytest.mark.asyncio
async def test_digraph_group_chat_parallel_join_any_1(runtime: AgentRuntime | None) -> None:
    """Diamond A -> {B, C} -> D with D's activation "any": D appears exactly once, after B and C."""
    echo_a = _EchoAgent("A", description="Echo agent A")
    echo_b = _EchoAgent("B", description="Echo agent B")
    echo_c = _EchoAgent("C", description="Echo agent C")
    echo_d = _EchoAgent("D", description="Echo agent D")

    diamond = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B"), DiGraphEdge(target="C")]),
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="D")]),
            "C": DiGraphNode(name="C", edges=[DiGraphEdge(target="D")]),
            "D": DiGraphNode(name="D", edges=[], activation="any"),
        }
    )
    team = GraphFlow(
        participants=[echo_a, echo_b, echo_c, echo_d],
        graph=diamond,
        runtime=runtime,
        termination_condition=MaxMessageTermination(10),
    )

    outcome = await team.run(task="Run parallel join")
    # Only TextMessage sources are considered.
    order = [msg.source for msg in outcome.messages if isinstance(msg, TextMessage)]
    assert order[0] == "user"

    # Both branches must have run.
    assert "B" in order
    assert "C" in order

    # D is expected exactly once in the transcript.
    d_positions = [pos for pos, name in enumerate(order) if name == "D"]
    assert len(d_positions) == 1

    # That single D occurrence comes after both B and C.
    b_pos = order.index("B")
    c_pos = order.index("C")
    assert any(pos > b_pos for pos in d_positions)
    assert any(pos > c_pos for pos in d_positions)
    assert outcome.stop_reason is not None
@pytest.mark.asyncio
async def test_digraph_group_chat_chained_parallel_join_any(runtime: AgentRuntime | None) -> None:
    """Diamond A -> {B, C} -> D -> E with "any" activation on D and E: each appears exactly once."""
    echo_a = _EchoAgent("A", description="Echo agent A")
    echo_b = _EchoAgent("B", description="Echo agent B")
    echo_c = _EchoAgent("C", description="Echo agent C")
    echo_d = _EchoAgent("D", description="Echo agent D")
    echo_e = _EchoAgent("E", description="Echo agent E")

    chained = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B"), DiGraphEdge(target="C")]),
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="D")]),
            "C": DiGraphNode(name="C", edges=[DiGraphEdge(target="D")]),
            "D": DiGraphNode(name="D", edges=[DiGraphEdge(target="E")], activation="any"),
            "E": DiGraphNode(name="E", edges=[], activation="any"),
        }
    )
    team = GraphFlow(
        participants=[echo_a, echo_b, echo_c, echo_d, echo_e],
        graph=chained,
        runtime=runtime,
        termination_condition=MaxMessageTermination(20),
    )

    outcome = await team.run(task="Run chained parallel join-any")
    # Only TextMessage sources are considered.
    order = [msg.source for msg in outcome.messages if isinstance(msg, TextMessage)]

    # D is expected exactly once.
    d_positions = [pos for pos, name in enumerate(order) if name == "D"]
    assert len(d_positions) == 1

    # That single D occurrence comes after both B and C.
    b_pos = order.index("B")
    c_pos = order.index("C")
    assert any(pos > b_pos for pos in d_positions)
    assert any(pos > c_pos for pos in d_positions)

    # E is also expected exactly once, after D.
    e_positions = [pos for pos, name in enumerate(order) if name == "E"]
    assert len(e_positions) == 1
    assert e_positions[0] > d_positions[0]
    assert outcome.stop_reason is not None
@pytest.mark.asyncio
async def test_digraph_group_chat_multiple_conditional(runtime: AgentRuntime | None) -> None:
    """With three string-conditioned edges out of A, the task "banana" leads to C."""
    echo_a = _EchoAgent("A", description="Echo agent A")
    echo_b = _EchoAgent("B", description="Echo agent B")
    echo_c = _EchoAgent("C", description="Echo agent C")
    echo_d = _EchoAgent("D", description="Echo agent D")

    # One string condition per branch out of A.
    fruit_edges = [
        DiGraphEdge(target="B", condition="apple"),
        DiGraphEdge(target="C", condition="banana"),
        DiGraphEdge(target="D", condition="cherry"),
    ]
    branching = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=fruit_edges),
            "B": DiGraphNode(name="B", edges=[]),
            "C": DiGraphNode(name="C", edges=[]),
            "D": DiGraphNode(name="D", edges=[]),
        }
    )
    team = GraphFlow(
        participants=[echo_a, echo_b, echo_c, echo_d],
        graph=branching,
        runtime=runtime,
        termination_condition=MaxMessageTermination(5),
    )

    # A echoes "banana", so the message following A's must come from C.
    outcome = await team.run(task="banana")
    chosen_branch = outcome.messages[2]
    assert chosen_branch.source == "C"
# Configuration model for _TestMessageFilterAgent; it is that class's
# component_config_schema and is produced/consumed by _to_config/_from_config.
class _TestMessageFilterAgentConfig(BaseModel):
    # Agent name (required).
    name: str
    # Agent description; "Echo test agent" is used when none is given.
    description: str = "Echo test agent"
# Test agent that records every message it is handed and always answers "ACK".
# It is a Component so it can be dumped to and loaded from its config model.
class _TestMessageFilterAgent(BaseChatAgent, Component[_TestMessageFilterAgentConfig]):
    component_config_schema = _TestMessageFilterAgentConfig
    component_provider_override = "test_group_chat_graph._TestMessageFilterAgent"

    def __init__(self, name: str, description: str = "Echo test agent") -> None:
        super().__init__(name=name, description=description)
        # Every message passed to on_messages since construction or the last reset.
        self.received_messages: list[BaseChatMessage] = []

    @property
    def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
        return (TextMessage,)

    async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
        # Record what was received, in order, then acknowledge.
        for incoming in messages:
            self.received_messages.append(incoming)
        acknowledgement = TextMessage(content="ACK", source=self.name)
        return Response(chat_message=acknowledgement)

    async def on_reset(self, cancellation_token: CancellationToken) -> None:
        self.received_messages.clear()

    def _to_config(self) -> _TestMessageFilterAgentConfig:
        config = _TestMessageFilterAgentConfig(name=self.name, description=self.description)
        return config

    @classmethod
    def _from_config(cls, config: _TestMessageFilterAgentConfig) -> "_TestMessageFilterAgent":
        agent = cls(name=config.name, description=config.description)
        return agent
@pytest.mark.asyncio
async def test_message_filter_agent_empty_filter_blocks_all() -> None:
    """With an empty ``per_source`` list, no message reaches the wrapped agent."""
    inner = _TestMessageFilterAgent("inner")
    empty_filter = MessageFilterConfig(per_source=[])
    wrapper = MessageFilterAgent(name="wrapper", wrapped_agent=inner, filter=empty_filter)

    incoming = [
        TextMessage(source="user", content="Hello"),
        TextMessage(source="system", content="System msg"),
    ]
    await wrapper.on_messages(incoming, CancellationToken())

    assert len(inner.received_messages) == 0
@pytest.mark.asyncio
async def test_message_filter_agent_with_position_none_gets_all() -> None:
    """A "user" filter with ``position=None`` and ``count=None`` forwards both user messages only."""
    inner = _TestMessageFilterAgent("inner")
    all_user_messages = PerSourceFilter(source="user", position=None, count=None)
    wrapper = MessageFilterAgent(
        name="wrapper",
        wrapped_agent=inner,
        filter=MessageFilterConfig(per_source=[all_user_messages]),
    )

    incoming = [
        TextMessage(source="user", content="A"),
        TextMessage(source="user", content="B"),
        TextMessage(source="system", content="Ignore this"),
    ]
    await wrapper.on_messages(incoming, CancellationToken())

    # Both user messages arrive; the system message is filtered out.
    forwarded = inner.received_messages
    assert len(forwarded) == 2
    assert {m.content for m in forwarded} == {"A", "B"}  # type: ignore[attr-defined]
@pytest.mark.asyncio
async def test_digraph_group_chat() -> None:
    """Dump a MessageFilterAgent to a component config, reload it, and re-check filtering.

    Despite its name, this test does not build a graph: it checks that the
    reloaded agent keeps its name, filter and wrapped agent name, and that the
    reloaded filter forwards the last two "user" and the first "system" message.
    """
    inner = _TestMessageFilterAgent("agent")
    filter_config = MessageFilterConfig(
        per_source=[
            PerSourceFilter(source="user", position="last", count=2),
            PerSourceFilter(source="system", position="first", count=1),
        ]
    )
    wrapper = MessageFilterAgent(name="agent", wrapped_agent=inner, filter=filter_config)

    # Round trip through the component config.
    dumped = wrapper.dump_component()
    loaded = MessageFilterAgent.load_component(dumped)

    assert loaded.name == "agent"
    assert loaded._filter == wrapper._filter  # pyright: ignore[reportPrivateUsage]
    assert loaded._wrapped_agent.name == wrapper._wrapped_agent.name  # pyright: ignore[reportPrivateUsage]

    # The reloaded agent must still filter the same way.
    incoming = [
        TextMessage(source="user", content="u1"),
        TextMessage(source="user", content="u2"),
        TextMessage(source="user", content="u3"),
        TextMessage(source="system", content="s1"),
        TextMessage(source="system", content="s2"),
    ]
    await loaded.on_messages(incoming, CancellationToken())

    received = loaded._wrapped_agent.received_messages  # type: ignore[attr-defined]
    assert {m.content for m in received} == {"u2", "u3", "s1"}  # pyright: ignore[reportUnknownVariableType,reportUnknownMemberType]
@pytest.mark.asyncio
async def test_message_filter_agent_in_digraph_group_chat(runtime: AgentRuntime | None) -> None:
    """A MessageFilterAgent used as the only graph node still replies "ACK" in a GraphFlow run."""
    inner = _TestMessageFilterAgent("filtered")
    last_user_only = MessageFilterConfig(
        per_source=[
            PerSourceFilter(source="user", position="last", count=1),
        ]
    )
    filtered_agent = MessageFilterAgent(name="filtered", wrapped_agent=inner, filter=last_user_only)

    single_node = DiGraph(nodes={"filtered": DiGraphNode(name="filtered", edges=[])})
    team = GraphFlow(
        participants=[filtered_agent],
        graph=single_node,
        runtime=runtime,
        termination_condition=MaxMessageTermination(3),
    )

    outcome = await team.run(task="only last user message matters")
    transcript = outcome.messages

    assert outcome.stop_reason is not None
    # The wrapper spoke, and at least one of its messages is the inner agent's "ACK".
    assert any(msg.source == "filtered" for msg in transcript)
    assert any(msg.content == "ACK" for msg in transcript if msg.source == "filtered")  # type: ignore[attr-defined,union-attr]
@pytest.mark.asyncio
async def test_message_filter_agent_loop_graph_visibility(runtime: AgentRuntime | None) -> None:
    """Check which messages each filtered agent receives in the loop A -> B -> (A | C).

    B is an AssistantAgent driven by replayed completions ("loop", "loop", "exit"):
    "loop" routes back to A and "exit" routes to C. A, B and C are each wrapped in
    a MessageFilterAgent; the assertions count, per source, what the wrapped
    agents ended up seeing.
    """
    # A sees the first user message and the latest message from B.
    agent_a_inner = _TestMessageFilterAgent("A")
    agent_a = MessageFilterAgent(
        name="A",
        wrapped_agent=agent_a_inner,
        filter=MessageFilterConfig(
            per_source=[
                PerSourceFilter(source="user", position="first", count=1),
                PerSourceFilter(source="B", position="last", count=1),
            ]
        ),
    )

    # AssistantAgent and ReplayChatCompletionClient are the module-level imports;
    # no function-scope re-import is needed.
    model_client = ReplayChatCompletionClient(["loop", "loop", "exit"])
    agent_b_inner = AssistantAgent("B", model_client=model_client)
    # B sees the first user message, the latest message from A and up to 10 of its own.
    agent_b = MessageFilterAgent(
        name="B",
        wrapped_agent=agent_b_inner,
        filter=MessageFilterConfig(
            per_source=[
                PerSourceFilter(source="user", position="first", count=1),
                PerSourceFilter(source="A", position="last", count=1),
                PerSourceFilter(source="B", position="last", count=10),
            ]
        ),
    )

    # C sees the first user message and the latest message from B.
    agent_c_inner = _TestMessageFilterAgent("C")
    agent_c = MessageFilterAgent(
        name="C",
        wrapped_agent=agent_c_inner,
        filter=MessageFilterConfig(
            per_source=[
                PerSourceFilter(source="user", position="first", count=1),
                PerSourceFilter(source="B", position="last", count=1),
            ]
        ),
    )

    graph = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]),
            "B": DiGraphNode(
                name="B",
                edges=[
                    DiGraphEdge(target="C", condition="exit"),
                    DiGraphEdge(target="A", condition="loop"),
                ],
            ),
            "C": DiGraphNode(name="C", edges=[]),
        },
        default_start_node="A",
    )
    team = GraphFlow(
        participants=[agent_a, agent_b, agent_c],
        graph=graph,
        runtime=runtime,
        termination_condition=MaxMessageTermination(20),
    )

    result = await team.run(task="Start")
    assert result.stop_reason is not None

    # A received: 1 user message + 2 from B (one per "loop" reply).
    a_sources = [m.source for m in agent_a_inner.received_messages]
    assert a_sources.count("user") == 1
    assert a_sources.count("B") == 2

    # C received: 1 user message + 1 from B (the "exit" reply).
    c_sources = [m.source for m in agent_c_inner.received_messages]
    assert c_sources.count("user") == 1
    assert c_sources.count("B") == 1

    # B's model context holds: 1 user message, at least 3 from A and at least 2 of its own.
    model_msgs = await agent_b_inner.model_context.get_messages()
    sources = [m.source for m in model_msgs]  # type: ignore[union-attr]
    assert sources.count("user") == 1  # pyright: ignore[reportUnknownMemberType]
    assert sources.count("A") >= 3  # pyright: ignore[reportUnknownMemberType]
    assert sources.count("B") >= 2  # pyright: ignore[reportUnknownMemberType]
# Test Graph Builder
def test_add_node() -> None:
    """Adding an agent registers it by name in both builder tables with activation "all"."""
    replay_client = ReplayChatCompletionClient(["response"])
    graph_builder = DiGraphBuilder()
    graph_builder.add_node(AssistantAgent("A", model_client=replay_client))

    # The agent's name is the lookup key for the node table and the agent table alike.
    for registry in (graph_builder.nodes, graph_builder.agents):
        assert "A" in registry
    # No activation was passed to add_node, so this pins the value the builder assigns.
    assert graph_builder.nodes["A"].activation == "all"
def test_add_edge() -> None:
    """An edge added without a condition points at the target and carries no condition."""
    shared_client = ReplayChatCompletionClient(["1", "2"])
    source_agent = AssistantAgent("A", model_client=shared_client)
    target_agent = AssistantAgent("B", model_client=shared_client)

    graph_builder = DiGraphBuilder()
    graph_builder.add_node(source_agent).add_node(target_agent)
    graph_builder.add_edge(source_agent, target_agent)

    first_edge = graph_builder.nodes["A"].edges[0]
    assert first_edge.target == "B"
    assert first_edge.condition is None
def test_add_conditional_edges() -> None:
    """``add_conditional_edges`` creates one outgoing edge per entry of the condition mapping."""
    shared_client = ReplayChatCompletionClient(["1", "2"])
    origin = AssistantAgent("A", model_client=shared_client)
    on_yes = AssistantAgent("B", model_client=shared_client)
    on_no = AssistantAgent("C", model_client=shared_client)

    graph_builder = DiGraphBuilder()
    graph_builder.add_node(origin).add_node(on_yes).add_node(on_no)
    graph_builder.add_conditional_edges(origin, {"yes": on_yes, "no": on_no})

    outgoing = graph_builder.nodes["A"].edges
    assert len(outgoing) == 2

    # Both condition strings from the mapping must show up on the edges.
    seen_conditions = [edge.condition for edge in outgoing]
    assert "yes" in seen_conditions
    assert "no" in seen_conditions

    # Each condition must lead to the agent it was mapped to.
    for condition, expected_target in (("yes", "B"), ("no", "C")):
        matching_edge = next(edge for edge in outgoing if edge.condition == condition)
        assert matching_edge.target == expected_target
def test_set_entry_point() -> None:
    """The entry point set on the builder becomes the built graph's ``default_start_node``."""
    only_agent = AssistantAgent("A", model_client=ReplayChatCompletionClient(["ok"]))

    graph_builder = DiGraphBuilder()
    graph_builder.add_node(only_agent)
    graph_builder.set_entry_point(only_agent)

    assert graph_builder.build().default_start_node == "A"
def test_build_graph_validation() -> None:
    """A chain A -> B -> C wired by node name builds into a ``DiGraph`` with the expected shape."""
    shared_client = ReplayChatCompletionClient(["1", "2", "3"])

    graph_builder = DiGraphBuilder()
    for agent_name in ("A", "B", "C"):
        graph_builder.add_node(AssistantAgent(agent_name, model_client=shared_client))

    # Edges and the entry point are given as node names here, not as agent objects.
    graph_builder.add_edge("A", "B")
    graph_builder.add_edge("B", "C")
    graph_builder.set_entry_point("A")

    built = graph_builder.build()
    assert isinstance(built, DiGraph)
    assert set(built.nodes) == {"A", "B", "C"}
    assert built.get_start_nodes() == {"A"}
    assert built.get_leaf_nodes() == {"C"}
def test_build_fan_out() -> None:
    """A fan-out A -> B and A -> C builds with A as the only start node and B, C as leaves."""
    shared_client = ReplayChatCompletionClient(["hi"] * 3)
    root = AssistantAgent("A", model_client=shared_client)
    left = AssistantAgent("B", model_client=shared_client)
    right = AssistantAgent("C", model_client=shared_client)

    graph_builder = DiGraphBuilder()
    graph_builder.add_node(root).add_node(left).add_node(right)
    for branch in (left, right):
        graph_builder.add_edge(root, branch)
    graph_builder.set_entry_point(root)

    built = graph_builder.build()
    assert built.get_start_nodes() == {"A"}
    assert built.get_leaf_nodes() == {"B", "C"}
def test_build_parallel_join() -> None:
    """Two sources, A and B, both feeding C build into a graph whose only leaf is C.

    C is registered with ``activation="all"`` explicitly, and the built graph
    must report that same activation for it.
    """
    client = ReplayChatCompletionClient(["go"] * 3)
    a = AssistantAgent("A", model_client=client)
    b = AssistantAgent("B", model_client=client)
    c = AssistantAgent("C", model_client=client)

    builder = DiGraphBuilder()
    builder.add_node(a).add_node(b).add_node(c, activation="all")
    # Declare each join edge exactly once so the builder alone produces the
    # final node table; no manual patching of builder.nodes is needed.
    builder.add_edge(a, c).add_edge(b, c)
    builder.set_entry_point(a)

    graph = builder.build()
    assert graph.nodes["C"].activation == "all"
    assert graph.get_leaf_nodes() == {"C"}
def test_build_conditional_loop() -> None:
    """A -> B plus conditional edges B -> A ("loop") and B -> C ("exit") builds a cycle with an exit."""
    shared_client = ReplayChatCompletionClient(["loop", "loop", "exit"])
    entry = AssistantAgent("A", model_client=shared_client)
    decider = AssistantAgent("B", model_client=shared_client)
    sink = AssistantAgent("C", model_client=shared_client)

    graph_builder = DiGraphBuilder()
    graph_builder.add_node(entry).add_node(decider).add_node(sink)
    graph_builder.add_edge(entry, decider)
    graph_builder.add_conditional_edges(decider, {"loop": entry, "exit": sink})
    graph_builder.set_entry_point(entry)
    built = graph_builder.build()

    # B must have exactly the two conditional branches declared above.
    branches = built.nodes["B"].edges
    assert len(branches) == 2

    # Look each branch up by its condition and check where it leads.
    for condition, expected_target in (("loop", "A"), ("exit", "C")):
        branch = next(edge for edge in branches if edge.condition == condition)
        assert branch.target == expected_target

    assert built.has_cycles_with_exit()
@pytest.mark.asyncio
async def test_graph_builder_sequential_execution(runtime: AgentRuntime | None) -> None:
    """A builder-made chain A -> B -> C makes the agents reply in that order."""
    first = _EchoAgent("A", description="Echo A")
    second = _EchoAgent("B", description="Echo B")
    third = _EchoAgent("C", description="Echo C")

    graph_builder = DiGraphBuilder()
    for agent in (first, second, third):
        graph_builder.add_node(agent)
    graph_builder.add_edge(first, second)
    graph_builder.add_edge(second, third)

    flow = GraphFlow(
        participants=graph_builder.get_participants(),
        graph=graph_builder.build(),
        runtime=runtime,
        termination_condition=MaxMessageTermination(5),
    )
    outcome = await flow.run(task="Start")

    # Skip the first and last messages; the ones in between must come from A, B, C in order.
    inner_sources = [message.source for message in outcome.messages[1:-1]]
    assert inner_sources == ["A", "B", "C"]
    assert outcome.stop_reason is not None
@pytest.mark.asyncio
async def test_graph_builder_fan_out(runtime: AgentRuntime | None) -> None:
    """A builder-made fan-out A -> B and A -> C produces text messages from all three agents."""
    root = _EchoAgent("A", description="Echo A")
    left = _EchoAgent("B", description="Echo B")
    right = _EchoAgent("C", description="Echo C")

    graph_builder = DiGraphBuilder()
    graph_builder.add_node(root).add_node(left).add_node(right)
    for branch in (left, right):
        graph_builder.add_edge(root, branch)

    flow = GraphFlow(
        participants=graph_builder.get_participants(),
        graph=graph_builder.build(),
        runtime=runtime,
        termination_condition=MaxMessageTermination(5),
    )
    outcome = await flow.run(task="Start")

    text_sources = [message.source for message in outcome.messages if isinstance(message, TextMessage)]
    # The first text message is skipped; the rest are compared as a set, so
    # the relative order of the branches is deliberately not asserted.
    assert set(text_sources[1:]) == {"A", "B", "C"}
    assert outcome.stop_reason is not None
@pytest.mark.asyncio
async def test_graph_builder_conditional_execution(runtime: AgentRuntime | None) -> None:
    """With conditional edges "yes" -> B and "no" -> C out of A, the task "no" reaches C."""
    entry = _EchoAgent("A", description="Echo A")
    on_yes = _EchoAgent("B", description="Echo B")
    on_no = _EchoAgent("C", description="Echo C")

    graph_builder = DiGraphBuilder()
    graph_builder.add_node(entry).add_node(on_yes).add_node(on_no)
    graph_builder.add_conditional_edges(entry, {"yes": on_yes, "no": on_no})

    flow = GraphFlow(
        participants=graph_builder.get_participants(),
        graph=graph_builder.build(),
        runtime=runtime,
        termination_condition=MaxMessageTermination(5),
    )

    # Input "no" should trigger the edge to C
    outcome = await flow.run(task="no")
    assert "C" in [message.source for message in outcome.messages]
    assert outcome.stop_reason is not None
@pytest.mark.asyncio
async def test_digraph_group_chat_callable_condition(runtime: AgentRuntime | None) -> None:
    """Check that string edge conditions pick the matching branch out of A.

    Note: despite the test's name, the edges below use string conditions,
    not callables.
    """
    echo_a = _EchoAgent("A", description="Echo agent A")
    echo_b = _EchoAgent("B", description="Echo agent B")
    echo_c = _EchoAgent("C", description="Echo agent C")

    branches_from_a = [
        # Will go to B if "long" is in message
        DiGraphEdge(target="B", condition="long"),
        # Will go to C if "short" is in message
        DiGraphEdge(target="C", condition="short"),
    ]
    flow_graph = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=branches_from_a),
            "B": DiGraphNode(name="B", edges=[]),
            "C": DiGraphNode(name="C", edges=[]),
        }
    )
    flow = GraphFlow(
        participants=[echo_a, echo_b, echo_c],
        graph=flow_graph,
        runtime=runtime,
        termination_condition=MaxMessageTermination(5),
    )

    # A task containing "long" must be followed up by B.
    long_outcome = await flow.run(task="This is a long message")
    assert long_outcome.messages[2].source == "B"

    # Clear the team before the second scenario.
    await flow.reset()

    # A task containing "short" must be followed up by C.
    short_outcome = await flow.run(task="This is a short message")
    assert short_outcome.messages[2].source == "C"
@pytest.mark.asyncio
async def test_graph_flow_serialize_deserialize() -> None:
    """A ``GraphFlow`` round-trips through dump/load and the loaded copy runs like the original."""
    replies_a = ReplayChatCompletionClient([str(number) for number in range(10)])
    replies_b = ReplayChatCompletionClient([str(number) for number in range(10)])
    upstream = AssistantAgent("A", model_client=replies_a)
    downstream = AssistantAgent("B", model_client=replies_b)

    graph_builder = DiGraphBuilder()
    graph_builder.add_node(upstream).add_node(downstream)
    graph_builder.add_edge(upstream, downstream)
    graph_builder.set_entry_point(upstream)

    original_team = GraphFlow(
        participants=graph_builder.get_participants(),
        graph=graph_builder.build(),
        runtime=None,
    )

    # Dump, load, and dump again before either team has run.
    first_dump = original_team.dump_component()
    restored_team = GraphFlow.load_component(first_dump)
    second_dump = restored_team.dump_component()

    results = await original_team.run(task="Start")
    de_results = await restored_team.run(task="Start")

    # The component config must survive the round trip unchanged ...
    assert first_dump == second_dump
    # ... and both teams must produce equivalent runs.
    assert compare_task_results(results, de_results)
    assert results.stop_reason is not None
    assert results.stop_reason == de_results.stop_reason
    assert compare_message_lists(results.messages, de_results.messages)

    # Leading messages: the task from the user, then one reply each from A and B.
    expected_head = (("user", "Start"), ("A", "0"), ("B", "0"))
    for position, (expected_source, expected_content) in enumerate(expected_head):
        message = results.messages[position]
        assert isinstance(message, TextMessage)
        assert message.source == expected_source
        assert message.content == expected_content

    # The run ends with the graph's own stop message.
    final_message = results.messages[-1]
    assert isinstance(final_message, StopMessage)
    assert final_message.source == _DIGRAPH_STOP_AGENT_NAME
    assert final_message.content == "Digraph execution is complete"
@pytest.mark.asyncio
async def test_graph_flow_stateful_pause_and_resume_with_termination() -> None:
    """State saved after an early termination lets a fresh ``GraphFlow`` resume at B."""
    replies_a = ReplayChatCompletionClient(["A1", "A2"])
    replies_b = ReplayChatCompletionClient(["B1"])
    upstream = AssistantAgent("A", model_client=replies_a)
    downstream = AssistantAgent("B", model_client=replies_b)

    graph_builder = DiGraphBuilder()
    graph_builder.add_node(upstream).add_node(downstream)
    graph_builder.add_edge(upstream, downstream)
    graph_builder.set_entry_point(upstream)

    # The first team stops as soon as A has answered, before B gets a turn.
    paused_team = GraphFlow(
        participants=graph_builder.get_participants(),
        graph=graph_builder.build(),
        runtime=None,
        termination_condition=SourceMatchTermination(sources=["A"]),
    )
    first_run = await paused_team.run(task="Start")
    assert len(first_run.messages) == 2
    assert [message.source for message in first_run.messages] == ["user", "A"]
    assert first_run.stop_reason is not None
    assert first_run.stop_reason == "'A' answered"

    # Export state.
    saved_state = await paused_team.save_state()

    # Load state into a new team.
    resumed_team = GraphFlow(
        participants=graph_builder.get_participants(),
        graph=graph_builder.build(),
        runtime=None,
    )
    await resumed_team.load_state(saved_state)

    # Resume.
    second_run = await resumed_team.run()
    assert len(second_run.messages) == 2
    assert second_run.messages[0].source == "B"
    assert second_run.messages[1].source == _DIGRAPH_STOP_AGENT_NAME
@pytest.mark.asyncio
async def test_builder_with_lambda_condition(runtime: AgentRuntime | None) -> None:
    """Test that DiGraphBuilder supports callable (lambda) edge conditions.

    Two edges leave A, each guarded by a lambda that inspects the message's
    model text: one looks for "even" and leads to B, the other looks for "odd"
    and leads to C. The test runs the team once per keyword and checks which
    agent produced the third message.
    """
    agent_a = _EchoAgent("A", description="Echo agent A")
    agent_b = _EchoAgent("B", description="Echo agent B")
    agent_c = _EchoAgent("C", description="Echo agent C")

    builder = DiGraphBuilder()
    builder.add_node(agent_a).add_node(agent_b).add_node(agent_c)

    # Using callable conditions
    builder.add_edge(agent_a, agent_b, condition=lambda msg: "even" in msg.to_model_text())
    builder.add_edge(agent_a, agent_c, condition=lambda msg: "odd" in msg.to_model_text())

    team = GraphFlow(
        participants=builder.get_participants(),
        graph=builder.build(),
        runtime=runtime,
        termination_condition=MaxMessageTermination(5),
    )

    # Test with "even" in message - should go to B
    result = await team.run(task="even length")
    assert result.messages[2].source == "B"

    # Reset for next test
    await team.reset()

    # Test with "odd" in message - should go to C
    result = await team.run(task="odd message")
    assert result.messages[2].source == "C"