autogen/python/packages/autogen-agentchat/tests/test_group_chat_graph.py
Zen 9b8dc8d707
add activation group for workflow with multiple cycles (#6711)
## Why are these changes needed?
1. problem
When the GraphFlowManager encounters cycles, it tracks the remaining
in-degree count each node needs before it can be activated. This
tracking mechanism has a flaw when dealing with cycles. When a node
first enters a cycle, the GraphFlowManager counts all of its incoming
edges, including those that loop back to it from inside the cycle.
Those loop-back edges have not fired yet at that moment, so the
activation prerequisites are not satisfied. The `_remaining` counter
therefore never reaches zero, and select_speaker() cannot select any
agent for execution. As a result, the workflow finishes prematurely.
2. solution
Change the activation map to two layers, so that the remaining counts
inside different cycles are distinguished from those outside any cycle.
Add activation group and activation condition (policy) properties to
edges. Compute the remaining map when GraphFlowManager is initialized,
and check it per activation group so that loop-back edges are not
counted.
<!-- Please give a short summary of the change and the problem this
solves. -->

## Related issue number

#6710

## Checks

- [x] I've included any doc changes needed for
<https://microsoft.github.io/autogen/>. See
<https://github.com/microsoft/autogen/blob/main/CONTRIBUTING.md> to
build and test documentation locally.
- [x] I've added tests (if relevant) corresponding to the changes
introduced in this PR.
- [x] I've made sure all auto checks have passed.
2025-06-25 12:20:04 +08:00

1629 lines
57 KiB
Python

import asyncio
import re
from typing import AsyncGenerator, List, Sequence
from unittest.mock import patch
import pytest
import pytest_asyncio
from autogen_agentchat.agents import (
AssistantAgent,
BaseChatAgent,
MessageFilterAgent,
MessageFilterConfig,
PerSourceFilter,
)
from autogen_agentchat.base import Response, TaskResult
from autogen_agentchat.conditions import MaxMessageTermination, SourceMatchTermination
from autogen_agentchat.messages import BaseChatMessage, ChatMessage, MessageFactory, StopMessage, TextMessage
from autogen_agentchat.teams import (
DiGraphBuilder,
GraphFlow,
)
from autogen_agentchat.teams._group_chat._events import ( # type: ignore[attr-defined]
BaseAgentEvent,
GroupChatTermination,
)
from autogen_agentchat.teams._group_chat._graph._digraph_group_chat import (
_DIGRAPH_STOP_AGENT_NAME, # pyright: ignore[reportPrivateUsage]
DiGraph,
DiGraphEdge,
DiGraphNode,
GraphFlowManager,
)
from autogen_core import AgentRuntime, CancellationToken, Component, SingleThreadedAgentRuntime
from autogen_ext.models.replay import ReplayChatCompletionClient
from pydantic import BaseModel
from utils import compare_message_lists, compare_task_results
def test_create_digraph() -> None:
    """A three-node chain A -> B -> C keeps every node and the expected edge counts."""
    chain = {
        "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]),
        "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="C")]),
        "C": DiGraphNode(name="C", edges=[]),
    }
    graph = DiGraph(nodes=chain)

    for node_name in ("A", "B", "C"):
        assert node_name in graph.nodes

    # A and B each point at their successor; C is the end of the chain.
    expected_edge_counts = {"A": 1, "B": 1, "C": 0}
    for node_name, edge_count in expected_edge_counts.items():
        assert len(graph.nodes[node_name].edges) == edge_count
def test_get_parents() -> None:
    """``get_parents`` maps each node of the chain A -> B -> C to its direct predecessors."""
    graph = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]),
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="C")]),
            "C": DiGraphNode(name="C", edges=[]),
        }
    )
    parent_map = graph.get_parents()

    expected: dict[str, list[str]] = {"A": [], "B": ["A"], "C": ["B"]}
    for child, expected_parents in expected.items():
        assert parent_map[child] == expected_parents
def test_get_start_nodes() -> None:
    """Only ``A`` has no incoming edge in A -> B -> C, so it is the sole start node."""
    chain = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]),
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="C")]),
            "C": DiGraphNode(name="C", edges=[]),
        }
    )
    assert chain.get_start_nodes() == {"A"}
def test_get_leaf_nodes() -> None:
    """Only ``C`` has no outgoing edge in A -> B -> C, so it is the sole leaf node."""
    chain = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]),
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="C")]),
            "C": DiGraphNode(name="C", edges=[]),
        }
    )
    assert chain.get_leaf_nodes() == {"C"}
def test_serialization() -> None:
    """A graph carrying a string edge condition survives a JSON dump / validate round trip."""
    # A string condition is used here rather than a lambda.
    original = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B", condition="trigger1")]),
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="C")]),
            "C": DiGraphNode(name="C", edges=[]),
        }
    )
    restored = DiGraph.model_validate_json(original.model_dump_json())

    edge_from_a = restored.nodes["A"].edges[0]
    assert edge_from_a.target == "B"
    assert edge_from_a.condition == "trigger1"
    assert restored.nodes["B"].edges[0].target == "C"

    # Sanity check: the keyword the condition names appears in a matching message's model text.
    probe = TextMessage(content="this has trigger1 in it", source="test")
    assert "trigger1" in probe.to_model_text()
def test_invalid_graph_no_start_node() -> None:
    """A graph consisting only of the cycle B <-> C has no start node."""
    cyclic = DiGraph(
        nodes={
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="C")]),
            "C": DiGraphNode(name="C", edges=[DiGraphEdge(target="B")]),
        }
    )
    # Every node has an incoming edge, so no node qualifies as a start node.
    assert len(cyclic.get_start_nodes()) == 0
def test_invalid_graph_no_leaf_node() -> None:
    """The closed loop A -> B -> C -> A has no leaf node."""
    loop = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]),
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="C")]),
            "C": DiGraphNode(name="C", edges=[DiGraphEdge(target="A")]),  # closes the loop
        }
    )
    # Every node has an outgoing edge, so there is no true endpoint.
    assert len(loop.get_leaf_nodes()) == 0
def test_condition_edge_execution() -> None:
    """String conditions are stored on their edge; unconditional edges keep ``condition=None``."""
    graph = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B", condition="TRIGGER")]),
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="C")]),
            "C": DiGraphNode(name="C", edges=[]),
        }
    )
    matching = TextMessage(content="This has TRIGGER in it", source="test")
    non_matching = TextMessage(content="This doesn't match", source="test")

    # Manually check the keyword against each message's model text.
    keyword = "TRIGGER"
    assert keyword in matching.to_model_text()
    assert keyword not in non_matching.to_model_text()

    # The conditional edge keeps its keyword; the plain edge has none.
    assert graph.nodes["A"].edges[0].condition == keyword
    assert graph.nodes["B"].edges[0].condition is None
def test_graph_with_multiple_paths() -> None:
    """A diamond A -> {B, C} -> D reports the right parents, start node and leaf node."""
    diamond = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B"), DiGraphEdge(target="C")]),
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="D")]),
            "C": DiGraphNode(name="C", edges=[DiGraphEdge(target="D")]),
            "D": DiGraphNode(name="D", edges=[]),
        }
    )

    parent_map = diamond.get_parents()
    expected_parents = {"B": ["A"], "C": ["A"], "D": ["B", "C"]}
    for node_name, predecessors in expected_parents.items():
        assert parent_map[node_name] == predecessors

    assert diamond.get_start_nodes() == {"A"}
    assert diamond.get_leaf_nodes() == {"D"}
def test_cycle_detection_no_cycle() -> None:
    """``has_cycles_with_exit`` is falsy for the acyclic chain A -> B -> C."""
    acyclic = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]),
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="C")]),
            "C": DiGraphNode(name="C", edges=[]),
        }
    )
    found_cycle = acyclic.has_cycles_with_exit()
    assert not found_cycle
def test_cycle_detection_with_exit_condition() -> None:
    """A cycle A -> B -> C -> A is accepted when its closing edge is conditional.

    Both a string condition and a callable condition are checked.
    """

    def _loop_closed_by(closing_edge: DiGraphEdge) -> DiGraph:
        # A -> B -> C, with C's only edge (the one passed in) leading back to A.
        return DiGraph(
            nodes={
                "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]),
                "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="C")]),
                "C": DiGraphNode(name="C", edges=[closing_edge]),
            }
        )

    # String condition on the back edge.
    graph = _loop_closed_by(DiGraphEdge(target="A", condition="exit"))
    assert graph.has_cycles_with_exit()

    # Callable condition on the back edge.
    graph_with_lambda = _loop_closed_by(
        DiGraphEdge(target="A", condition=lambda msg: "test" in msg.to_model_text())
    )
    assert graph_with_lambda.has_cycles_with_exit()
def test_cycle_detection_without_exit_condition() -> None:
    """An unconditional cycle A -> B -> C -> A is rejected with a ValueError naming the cycle.

    The graph also contains a separate acyclic component D -> E.
    """
    nodes = {
        "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]),
        "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="C")]),
        "C": DiGraphNode(name="C", edges=[DiGraphEdge(target="A")]),  # back edge with no condition
        "D": DiGraphNode(name="D", edges=[DiGraphEdge(target="E")]),
        "E": DiGraphNode(name="E", edges=[]),
    }
    graph = DiGraph(nodes=nodes)

    expected_error = "Cycle detected without exit condition: A -> B -> C -> A"
    with pytest.raises(ValueError, match=expected_error):
        graph.has_cycles_with_exit()
def test_different_activation_groups_detection() -> None:
    """Edges into D from the same activation group must agree on their activation condition.

    B -> D uses ``"all"`` and C -> D uses ``"any"``. Both fall into D's default group (named
    ``'D'`` in the expected error), so ``graph_validate`` reports the conflict.
    """
    graph = DiGraph(
        nodes={
            "A": DiGraphNode(
                name="A",
                edges=[
                    DiGraphEdge(target="B"),
                    DiGraphEdge(target="C"),
                ],
            ),
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="D", activation_condition="all")]),
            "C": DiGraphNode(name="C", edges=[DiGraphEdge(target="D", activation_condition="any")]),
            "D": DiGraphNode(name="D", edges=[]),
        }
    )

    expected_error = (
        "Conflicting activation conditions for target 'D' group 'D': "
        "'all' (from node 'B') and 'any' (from node 'C')"
    )
    # The message contains regex metacharacters (parentheses), so escape it.
    with pytest.raises(ValueError, match=re.escape(expected_error)):
        graph.graph_validate()
def test_validate_graph_success() -> None:
    """A two-node graph A -> B validates and is acyclic, with or without an edge condition."""

    def _assert_valid_and_acyclic(edge_a_to_b: DiGraphEdge) -> None:
        graph = DiGraph(
            nodes={
                "A": DiGraphNode(name="A", edges=[edge_a_to_b]),
                "B": DiGraphNode(name="B", edges=[]),
            }
        )
        graph.graph_validate()  # must not raise
        assert not graph.get_has_cycles()

    # Unconditional edge.
    _assert_valid_and_acyclic(DiGraphEdge(target="B"))
    # Callable condition on the only edge.
    _assert_valid_and_acyclic(DiGraphEdge(target="B", condition=lambda msg: "test" in msg.to_model_text()))
def test_validate_graph_missing_start_node() -> None:
    """``graph_validate`` rejects the two-node cycle A <-> B because no node lacks an incoming edge."""
    two_cycle = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]),
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="A")]),
        }
    )
    with pytest.raises(ValueError, match="Graph must have at least one start node"):
        two_cycle.graph_validate()
def test_validate_graph_missing_leaf_node() -> None:
    """``graph_validate`` rejects A -> B <-> C: A is a start node, but every node has an outgoing edge."""
    no_exit = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]),
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="C")]),
            "C": DiGraphNode(name="C", edges=[DiGraphEdge(target="B")]),  # B <-> C cycle
        }
    )
    with pytest.raises(ValueError, match="Graph must have at least one leaf node"):
        no_exit.graph_validate()
def test_validate_graph_mixed_conditions() -> None:
    """Validation fails when a node mixes conditional and unconditional outgoing edges.

    Checked once with a string condition and once with a callable condition.
    """

    def _graph_with(conditional_edge: DiGraphEdge) -> DiGraph:
        # A gets the given conditional edge (to B) plus an unconditional edge to C.
        return DiGraph(
            nodes={
                "A": DiGraphNode(name="A", edges=[conditional_edge, DiGraphEdge(target="C")]),
                "B": DiGraphNode(name="B", edges=[]),
                "C": DiGraphNode(name="C", edges=[]),
            }
        )

    expected_error = "Node 'A' has a mix of conditional and unconditional edges"

    graph = _graph_with(DiGraphEdge(target="B", condition="cond"))
    with pytest.raises(ValueError, match=expected_error):
        graph.graph_validate()

    graph_with_lambda = _graph_with(DiGraphEdge(target="B", condition=lambda msg: "test" in msg.to_model_text()))
    with pytest.raises(ValueError, match=expected_error):
        graph_with_lambda.graph_validate()
@pytest.mark.asyncio
async def test_invalid_digraph_manager_cycle_without_termination() -> None:
    """GraphFlowManager rejects a purely cyclic graph when it is constructed.

    In the graph A -> B -> A every node has an incoming edge, so there is no start node.
    ``GraphFlowManager.__init__`` is expected to raise
    ``ValueError("Graph must have at least one start node")``.
    ``BaseGroupChatManager.__init__`` is patched out, so only the manager's own
    initialisation logic runs.
    """
    # Cyclic graph A -> B -> A with no entry point.
    graph = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]),
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="A")]),
        }
    )
    output_queue: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination] = asyncio.Queue()
    with patch(
        "autogen_agentchat.teams._group_chat._base_group_chat_manager.BaseGroupChatManager.__init__",
        return_value=None,
    ):
        # Allocate the instance first, then call __init__ explicitly inside pytest.raises.
        manager = GraphFlowManager.__new__(GraphFlowManager)
        with pytest.raises(ValueError, match="Graph must have at least one start node"):
            manager.__init__(  # type: ignore[misc]
                name="test_manager",
                group_topic_type="topic",
                output_topic_type="topic",
                participant_topic_types=["topic1", "topic2"],
                participant_names=["A", "B"],
                participant_descriptions=["Agent A", "Agent B"],
                output_message_queue=output_queue,
                termination_condition=None,
                max_turns=None,
                message_factory=MessageFactory(),
                graph=graph,
            )
class _EchoAgent(BaseChatAgent):
    """Test agent that replies with the content of the first message it is given.

    If it is called with no new messages, it repeats the content it echoed last time.
    Each call to :meth:`on_messages` increments the counter exposed as :attr:`total_messages`.
    """

    def __init__(self, name: str, description: str) -> None:
        super().__init__(name, description)
        self._last_message: str | None = None
        self._total_messages = 0

    @property
    def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
        return (TextMessage,)

    @property
    def total_messages(self) -> int:
        """Number of ``on_messages`` calls so far (``on_reset`` does not clear it)."""
        return self._total_messages

    async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
        if not messages:
            # Nothing new this turn: repeat the previous echo, which must exist.
            assert self._last_message is not None
            reply_text = self._last_message
        else:
            first = messages[0]
            assert isinstance(first, TextMessage)
            self._last_message = first.content
            reply_text = first.content
        self._total_messages += 1
        return Response(chat_message=TextMessage(content=reply_text, source=self.name))

    async def on_reset(self, cancellation_token: CancellationToken) -> None:
        # Only the remembered content is cleared; the call counter keeps its value.
        self._last_message = None
@pytest_asyncio.fixture(params=["single_threaded", "embedded"])  # type: ignore
async def runtime(request: pytest.FixtureRequest) -> AsyncGenerator[AgentRuntime | None, None]:
    """Parametrized runtime for the GraphFlow tests.

    * ``"single_threaded"``: start a :class:`SingleThreadedAgentRuntime`, yield it, and stop it
      after the test finishes.
    * ``"embedded"``: yield ``None``, so the team is built without an explicit runtime.
    """
    mode = request.param
    if mode == "embedded":
        yield None
    elif mode == "single_threaded":
        started_runtime = SingleThreadedAgentRuntime()
        started_runtime.start()
        yield started_runtime
        await started_runtime.stop()


# Union of task forms: a plain string, a list of chat messages, or a single chat message.
TaskType = str | List[ChatMessage] | ChatMessage
@pytest.mark.asyncio
async def test_digraph_group_chat_sequential_execution(runtime: AgentRuntime | None) -> None:
    """A linear graph A -> B -> C runs each agent once, in order, and then stops."""
    chain = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]),
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="C")]),
            "C": DiGraphNode(name="C", edges=[]),
        }
    )
    agent_a = _EchoAgent("A", description="Echo agent A")
    agent_b = _EchoAgent("B", description="Echo agent B")
    agent_c = _EchoAgent("C", description="Echo agent C")

    team = GraphFlow(
        participants=[agent_a, agent_b, agent_c],
        graph=chain,
        runtime=runtime,
        termination_condition=MaxMessageTermination(5),
    )
    result: TaskResult = await team.run(task="Hello from User")

    messages = result.messages
    assert len(messages) == 5
    assert isinstance(messages[0], TextMessage)
    # User task, one echo per agent along the chain, then the stop agent's message.
    assert [m.source for m in messages] == ["user", "A", "B", "C", _DIGRAPH_STOP_AGENT_NAME]
    # Everything before the final stop message is a plain text message.
    assert all(isinstance(m, TextMessage) for m in messages[:-1])
    assert result.stop_reason is not None
@pytest.mark.asyncio
async def test_digraph_group_chat_parallel_fanout(runtime: AgentRuntime | None) -> None:
    """A fans out to B and C. Both run after A, in either order, and then the flow stops."""
    fan_out = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B"), DiGraphEdge(target="C")]),
            "B": DiGraphNode(name="B", edges=[]),
            "C": DiGraphNode(name="C", edges=[]),
        }
    )
    agent_a = _EchoAgent("A", description="Echo agent A")
    agent_b = _EchoAgent("B", description="Echo agent B")
    agent_c = _EchoAgent("C", description="Echo agent C")

    team = GraphFlow(
        participants=[agent_a, agent_b, agent_c],
        graph=fan_out,
        runtime=runtime,
        termination_condition=MaxMessageTermination(5),
    )
    result: TaskResult = await team.run(task="Start")

    assert len(result.messages) == 5
    user_msg, first_reply, *branch_replies, stop_msg = result.messages
    assert user_msg.source == "user"
    assert first_reply.source == "A"
    assert {m.source for m in branch_replies} == {"B", "C"}
    assert stop_msg.source == _DIGRAPH_STOP_AGENT_NAME
    assert result.stop_reason is not None
@pytest.mark.asyncio
async def test_digraph_group_chat_parallel_join_all(runtime: AgentRuntime | None) -> None:
    """C (``activation="all"``) runs only after both start nodes A and B have spoken."""
    join_graph = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="C")]),
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="C")]),
            "C": DiGraphNode(name="C", edges=[], activation="all"),
        }
    )
    agent_a = _EchoAgent("A", description="Echo agent A")
    agent_b = _EchoAgent("B", description="Echo agent B")
    agent_c = _EchoAgent("C", description="Echo agent C")

    team = GraphFlow(
        participants=[agent_a, agent_b, agent_c],
        graph=join_graph,
        runtime=runtime,
        termination_condition=MaxMessageTermination(5),
    )
    result: TaskResult = await team.run(task="Go")

    sources = [m.source for m in result.messages]
    assert len(sources) == 5
    assert sources[0] == "user"
    assert set(sources[1:3]) == {"A", "B"}
    assert sources[3] == "C"
    assert sources[-1] == _DIGRAPH_STOP_AGENT_NAME
    assert result.stop_reason is not None
@pytest.mark.asyncio
async def test_digraph_group_chat_parallel_join_any(runtime: AgentRuntime | None) -> None:
    """C (``activation="any"``) runs after at least one of A or B and is the last agent to speak."""
    join_graph = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="C")]),
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="C")]),
            "C": DiGraphNode(name="C", edges=[], activation="any"),
        }
    )
    agent_a = _EchoAgent("A", description="Echo agent A")
    agent_b = _EchoAgent("B", description="Echo agent B")
    agent_c = _EchoAgent("C", description="Echo agent C")

    team = GraphFlow(
        participants=[agent_a, agent_b, agent_c],
        graph=join_graph,
        runtime=runtime,
        termination_condition=MaxMessageTermination(5),
    )
    result: TaskResult = await team.run(task="Start")

    assert len(result.messages) == 5
    assert result.messages[0].source == "user"
    speakers = [m.source for m in result.messages[1:]]
    # C is the last agent; only the stop agent's message follows it.
    assert speakers[-2] == "C"
    # Both A and B still execute.
    assert {"A", "B"} <= set(speakers)
    # C comes after at least one of its predecessors.
    assert speakers.index("C") > min(speakers.index("A"), speakers.index("B"))
    assert result.messages[-1].source == _DIGRAPH_STOP_AGENT_NAME
    assert result.stop_reason is not None
@pytest.mark.asyncio
async def test_digraph_group_chat_multiple_start_nodes(runtime: AgentRuntime | None) -> None:
    """Two isolated nodes are both start nodes. Each runs once before the flow stops."""
    isolated = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[]),
            "B": DiGraphNode(name="B", edges=[]),
        }
    )
    agent_a = _EchoAgent("A", description="Echo agent A")
    agent_b = _EchoAgent("B", description="Echo agent B")

    team = GraphFlow(
        participants=[agent_a, agent_b],
        graph=isolated,
        runtime=runtime,
        termination_condition=MaxMessageTermination(5),
    )
    result: TaskResult = await team.run(task="Start")

    assert len(result.messages) == 4
    user_msg, *agent_replies, stop_msg = result.messages
    assert user_msg.source == "user"
    assert {m.source for m in agent_replies} == {"A", "B"}
    assert stop_msg.source == _DIGRAPH_STOP_AGENT_NAME
    assert result.stop_reason is not None
@pytest.mark.asyncio
async def test_digraph_group_chat_disconnected_graph(runtime: AgentRuntime | None) -> None:
    """Two independent chains A -> B and C -> D run side by side."""
    two_chains = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]),
            "B": DiGraphNode(name="B", edges=[]),
            "C": DiGraphNode(name="C", edges=[DiGraphEdge(target="D")]),
            "D": DiGraphNode(name="D", edges=[]),
        }
    )
    agent_a = _EchoAgent("A", description="Echo agent A")
    agent_b = _EchoAgent("B", description="Echo agent B")
    agent_c = _EchoAgent("C", description="Echo agent C")
    agent_d = _EchoAgent("D", description="Echo agent D")

    team = GraphFlow(
        participants=[agent_a, agent_b, agent_c, agent_d],
        graph=two_chains,
        runtime=runtime,
        termination_condition=MaxMessageTermination(10),
    )
    result: TaskResult = await team.run(task="Go")

    sources = [m.source for m in result.messages]
    assert len(sources) == 6
    assert sources[0] == "user"
    # Both start nodes speak first, then both of their successors.
    assert set(sources[1:3]) == {"A", "C"}
    assert set(sources[3:5]) == {"B", "D"}
    assert sources[-1] == _DIGRAPH_STOP_AGENT_NAME
    assert result.stop_reason is not None
@pytest.mark.asyncio
async def test_digraph_group_chat_conditional_branch(runtime: AgentRuntime | None) -> None:
    """A routes to B or C by keyword. This is checked with string and with callable conditions."""
    agent_a = _EchoAgent("A", description="Echo agent A")
    agent_b = _EchoAgent("B", description="Echo agent B")
    agent_c = _EchoAgent("C", description="Echo agent C")

    def _branching_graph(to_b: DiGraphEdge, to_c: DiGraphEdge) -> DiGraph:
        # A has two conditional edges. B and C are leaves with activation="any".
        return DiGraph(
            nodes={
                "A": DiGraphNode(name="A", edges=[to_b, to_c]),
                "B": DiGraphNode(name="B", edges=[], activation="any"),
                "C": DiGraphNode(name="C", edges=[], activation="any"),
            }
        )

    def _team(graph: DiGraph) -> GraphFlow:
        return GraphFlow(
            participants=[agent_a, agent_b, agent_c],
            graph=graph,
            runtime=runtime,
            termination_condition=MaxMessageTermination(5),
        )

    # String conditions: A echoes "Trigger yes", which selects B.
    graph = _branching_graph(
        DiGraphEdge(target="B", condition="yes"),
        DiGraphEdge(target="C", condition="no"),
    )
    result = await _team(graph).run(task="Trigger yes")
    assert result.messages[2].source == "B"

    # Callable conditions: A echoes "Trigger no", which selects C.
    graph_with_lambda = _branching_graph(
        DiGraphEdge(target="B", condition=lambda msg: "yes" in msg.to_model_text()),
        DiGraphEdge(target="C", condition=lambda msg: "no" in msg.to_model_text()),
    )
    result_with_lambda = await _team(graph_with_lambda).run(task="Trigger no")
    assert result_with_lambda.messages[2].source == "C"
@pytest.mark.asyncio
async def test_digraph_group_chat_loop_with_exit_condition(runtime: AgentRuntime | None) -> None:
    """B sends the flow back to A until its reply says "exit", then hands off to C."""
    agent_a = _EchoAgent("A", description="Echo agent A")
    agent_c = _EchoAgent("C", description="Echo agent C")
    # B is an AssistantAgent whose replies are replayed in order: loop, loop, exit.
    model_client = ReplayChatCompletionClient(chat_completions=["loop", "loop", "exit"])
    agent_b = AssistantAgent("B", description="Decision agent B", model_client=model_client)

    # A -> B. From B: back to A on "loop", forward to C on "exit".
    graph = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]),
            "B": DiGraphNode(
                name="B", edges=[DiGraphEdge(target="C", condition="exit"), DiGraphEdge(target="A", condition="loop")]
            ),
            "C": DiGraphNode(name="C", edges=[]),
        },
        default_start_node="A",
    )
    team = GraphFlow(
        participants=[agent_a, agent_b, agent_c],
        graph=graph,
        runtime=runtime,
        termination_condition=MaxMessageTermination(20),
    )
    result = await team.run(task="Start")

    # Three A -> B rounds (two loop, the third exits), then C and the stop agent.
    expected_sources = ["user"] + ["A", "B"] * 3 + ["C", _DIGRAPH_STOP_AGENT_NAME]
    assert [m.source for m in result.messages] == expected_sources
    assert result.stop_reason is not None
    assert result.messages[-2].source == "C"
    assert any(m.content == "exit" for m in result.messages[:-1])  # type: ignore[attr-defined,union-attr]
    assert result.messages[-1].source == _DIGRAPH_STOP_AGENT_NAME
@pytest.mark.asyncio
async def test_digraph_group_chat_loop_with_self_cycle(runtime: AgentRuntime | None) -> None:
    """B re-activates itself through a self-edge until its reply says "exit", then hands off to C.

    The self-edge B -> B has its own activation group (``"B_loop"``). The expected speaker
    order below shows B running again after its own "loop" replies without A running again.
    """
    # Agents A and C: echo agents.
    agent_a = _EchoAgent("A", description="Echo agent A")
    agent_c = _EchoAgent("C", description="Echo agent C")
    # Replay model client for agent B.
    model_client = ReplayChatCompletionClient(
        chat_completions=[
            "loop",  # B's 1st reply: take the self-edge
            "loop",  # B's 2nd reply: take the self-edge again
            "exit",  # B's 3rd reply: move on to C
        ]
    )
    # Agent B: assistant agent using the replay client.
    agent_b = AssistantAgent("B", description="Decision agent B", model_client=model_client)

    # DiGraph: A -> B; B -> B on "loop" (self cycle, group "B_loop"); B -> C on "exit".
    graph = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]),
            "B": DiGraphNode(
                name="B",
                edges=[
                    DiGraphEdge(target="C", condition="exit"),
                    DiGraphEdge(target="B", condition="loop", activation_group="B_loop"),
                ],
            ),
            "C": DiGraphNode(name="C", edges=[]),
        },
        default_start_node="A",
    )
    team = GraphFlow(
        participants=[agent_a, agent_b, agent_c],
        graph=graph,
        runtime=runtime,
        termination_condition=MaxMessageTermination(20),
    )

    result = await team.run(task="Start")

    # A once, B three times (loop, loop, exit), then C and the stop agent.
    expected_sources = [
        "user",
        "A",
        "B",  # replies "loop"
        "B",  # replies "loop"
        "B",  # replies "exit"
        "C",
        _DIGRAPH_STOP_AGENT_NAME,
    ]
    actual_sources = [m.source for m in result.messages]
    assert actual_sources == expected_sources
    assert result.stop_reason is not None
    assert result.messages[-2].source == "C"
    assert any(m.content == "exit" for m in result.messages[:-1])  # type: ignore[attr-defined,union-attr]
    assert result.messages[-1].source == _DIGRAPH_STOP_AGENT_NAME
@pytest.mark.asyncio
async def test_digraph_group_chat_loop_with_two_cycles(runtime: AgentRuntime | None) -> None:
    """Node O sits on two cycles (O <-> X and O <-> Y) and also has an exit edge to E.

    O is first reached through the fan-in B, C -> O. The loop-back edges X -> O and Y -> O each
    use their own activation group, and the expected speaker order below shows O speaking again
    right after X (and after Y) without B and C running again.
    """
    # Echo agents: A, B, C and E.
    agent_a = _EchoAgent("A", description="Echo agent A")
    agent_b = _EchoAgent("B", description="Echo agent B")
    agent_c = _EchoAgent("C", description="Echo agent C")
    agent_e = _EchoAgent("E", description="Echo agent E")
    # A single replay client is shared by the assistant agents O, X and Y. Its completions
    # are consumed in the order those agents speak.
    model_client = ReplayChatCompletionClient(
        chat_completions=[
            "to_x",  # O (1st turn) branches to X
            "to_o",  # X goes back to O
            "to_y",  # O (2nd turn) branches to Y
            "to_o",  # Y goes back to O
            "exit",  # O (3rd turn) exits to E
        ]
    )
    # Decision agents O, X and Y: assistant agents using the shared replay client.
    agent_o = AssistantAgent("O", description="Decision agent o", model_client=model_client)
    agent_x = AssistantAgent("X", description="Decision agent x", model_client=model_client)
    agent_y = AssistantAgent("Y", description="Decision agent y", model_client=model_client)

    # DiGraph:
    #
    #        A
    #       / \
    #      B   C
    #       \ /
    #  X <-> O <-> Y   (conditional edges in both directions)
    #        |
    #        E (exit)
    graph = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B"), DiGraphEdge(target="C")]),
            # The default activation group name is the target node name, here "O".
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="O")]),
            # The default activation group name is the target node name, here "O".
            "C": DiGraphNode(name="C", edges=[DiGraphEdge(target="O")]),
            "O": DiGraphNode(
                name="O",
                edges=[
                    DiGraphEdge(target="X", condition="to_x"),
                    DiGraphEdge(target="Y", condition="to_y"),
                    DiGraphEdge(target="E", condition="exit"),
                ],
            ),
            # Each loop-back edge gets its own activation group.
            "X": DiGraphNode(name="X", edges=[DiGraphEdge(target="O", condition="to_o", activation_group="x_o_loop")]),
            "Y": DiGraphNode(name="Y", edges=[DiGraphEdge(target="O", condition="to_o", activation_group="y_o_loop")]),
            "E": DiGraphNode(name="E", edges=[]),
        },
        default_start_node="A",
    )
    team = GraphFlow(
        participants=[agent_a, agent_o, agent_b, agent_c, agent_x, agent_y, agent_e],
        graph=graph,
        runtime=runtime,
        termination_condition=MaxMessageTermination(20),
    )

    result = await team.run(task="Start")

    expected_sources = [
        "user",
        "A",
        "B",
        "C",
        "O",
        "X",  # O -> X
        "O",  # X -> O
        "Y",  # O -> Y
        "O",  # Y -> O
        "E",  # O -> E
        _DIGRAPH_STOP_AGENT_NAME,
    ]
    actual_sources = [m.source for m in result.messages]
    assert actual_sources == expected_sources
    assert result.stop_reason is not None
    assert result.messages[-2].source == "E"
    assert any(m.content == "exit" for m in result.messages[:-1])  # type: ignore[attr-defined,union-attr]
    assert result.messages[-1].source == _DIGRAPH_STOP_AGENT_NAME
@pytest.mark.asyncio
async def test_digraph_group_chat_parallel_join_any_1(runtime: AgentRuntime | None) -> None:
    """A fans out to B and C, whose edges into D share one activation group.

    D runs exactly once, after both B and C have spoken.
    """
    agent_a = _EchoAgent("A", description="Echo agent A")
    agent_b = _EchoAgent("B", description="Echo agent B")
    agent_c = _EchoAgent("C", description="Echo agent C")
    agent_d = _EchoAgent("D", description="Echo agent D")

    # NOTE(review): ``activation_group="any"`` only *names* the group shared by B -> D and C -> D.
    # It does not set ``activation_condition``. The assertions below expect D to run once, after
    # both B and C. Confirm whether ``activation_condition="any"`` was intended here.
    graph = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B"), DiGraphEdge(target="C")]),
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="D", activation_group="any")]),
            "C": DiGraphNode(name="C", edges=[DiGraphEdge(target="D", activation_group="any")]),
            "D": DiGraphNode(name="D", edges=[]),
        }
    )

    team = GraphFlow(
        participants=[agent_a, agent_b, agent_c, agent_d],
        graph=graph,
        runtime=runtime,
        termination_condition=MaxMessageTermination(10),
    )

    result = await team.run(task="Run parallel join")
    sequence = [msg.source for msg in result.messages if isinstance(msg, TextMessage)]
    assert sequence[0] == "user"
    # B and C both run.
    assert "B" in sequence
    assert "C" in sequence
    # D runs exactly once ...
    d_indices = [i for i, s in enumerate(sequence) if s == "D"]
    assert len(d_indices) == 1
    # ... and only after both B and C have spoken (run order of B and C depends on the runtime).
    assert d_indices[0] > max(sequence.index("B"), sequence.index("C"))
    assert result.stop_reason is not None
@pytest.mark.asyncio
async def test_digraph_group_chat_chained_parallel_join_any(runtime: AgentRuntime | None) -> None:
    """Chained "any" joins: A -> {B, C} -> D -> E, where D and E use ``activation="any"``.

    D runs exactly once, after both B and C, and E runs exactly once, after D.
    """
    agent_a = _EchoAgent("A", description="Echo agent A")
    agent_b = _EchoAgent("B", description="Echo agent B")
    agent_c = _EchoAgent("C", description="Echo agent C")
    agent_d = _EchoAgent("D", description="Echo agent D")
    agent_e = _EchoAgent("E", description="Echo agent E")

    graph = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B"), DiGraphEdge(target="C")]),
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="D")]),
            "C": DiGraphNode(name="C", edges=[DiGraphEdge(target="D")]),
            "D": DiGraphNode(name="D", edges=[DiGraphEdge(target="E")], activation="any"),
            "E": DiGraphNode(name="E", edges=[], activation="any"),
        }
    )

    team = GraphFlow(
        participants=[agent_a, agent_b, agent_c, agent_d, agent_e],
        graph=graph,
        runtime=runtime,
        termination_condition=MaxMessageTermination(20),
    )

    result = await team.run(task="Run chained parallel join-any")
    sequence = [msg.source for msg in result.messages if isinstance(msg, TextMessage)]

    # D runs exactly once, after both B and C have spoken.
    d_indices = [i for i, s in enumerate(sequence) if s == "D"]
    assert len(d_indices) == 1
    assert d_indices[0] > max(sequence.index("B"), sequence.index("C"))

    # E also runs exactly once, after D.
    e_indices = [i for i, s in enumerate(sequence) if s == "E"]
    assert len(e_indices) == 1
    assert e_indices[0] > d_indices[0]
    assert result.stop_reason is not None
@pytest.mark.asyncio
async def test_digraph_group_chat_multiple_conditional(runtime: AgentRuntime | None) -> None:
    """A picks exactly one of three branches by keyword, with string and callable conditions."""
    agent_a = _EchoAgent("A", description="Echo agent A")
    agent_b = _EchoAgent("B", description="Echo agent B")
    agent_c = _EchoAgent("C", description="Echo agent C")
    agent_d = _EchoAgent("D", description="Echo agent D")

    def _three_way(to_b: DiGraphEdge, to_c: DiGraphEdge, to_d: DiGraphEdge) -> DiGraph:
        # A has one conditional edge to each of the leaves B, C and D.
        return DiGraph(
            nodes={
                "A": DiGraphNode(name="A", edges=[to_b, to_c, to_d]),
                "B": DiGraphNode(name="B", edges=[]),
                "C": DiGraphNode(name="C", edges=[]),
                "D": DiGraphNode(name="D", edges=[]),
            }
        )

    def _team(graph: DiGraph) -> GraphFlow:
        return GraphFlow(
            participants=[agent_a, agent_b, agent_c, agent_d],
            graph=graph,
            runtime=runtime,
            termination_condition=MaxMessageTermination(5),
        )

    # String conditions: A echoes "banana", which routes to C.
    graph = _three_way(
        DiGraphEdge(target="B", condition="apple"),
        DiGraphEdge(target="C", condition="banana"),
        DiGraphEdge(target="D", condition="cherry"),
    )
    result = await _team(graph).run(task="banana")
    assert result.messages[2].source == "C"

    # Callable conditions: A echoes "cherry", which routes to D.
    graph_with_lambda = _three_way(
        DiGraphEdge(target="B", condition=lambda msg: "apple" in msg.to_model_text()),
        DiGraphEdge(target="C", condition=lambda msg: "banana" in msg.to_model_text()),
        DiGraphEdge(target="D", condition=lambda msg: "cherry" in msg.to_model_text()),
    )
    result_with_lambda = await _team(graph_with_lambda).run(task="cherry")
    assert result_with_lambda.messages[2].source == "D"
class _TestMessageFilterAgentConfig(BaseModel):
    """Serializable component configuration for ``_TestMessageFilterAgent``."""

    # Agent name; passed straight to the agent constructor on load.
    name: str
    # Agent description; the default mirrors the agent constructor's default.
    description: str = "Echo test agent"
class _TestMessageFilterAgent(BaseChatAgent, Component[_TestMessageFilterAgentConfig]):
    """Recording test agent: keeps every message it is handed and always replies "ACK".

    Wrapped inside ``MessageFilterAgent`` in the tests below so assertions can
    inspect exactly which messages made it through the filter.
    """

    component_config_schema = _TestMessageFilterAgentConfig
    component_provider_override = "test_group_chat_graph._TestMessageFilterAgent"

    def __init__(self, name: str, description: str = "Echo test agent") -> None:
        super().__init__(name=name, description=description)
        # Accumulates across on_messages calls; emptied by on_reset.
        self.received_messages: list[BaseChatMessage] = []

    @property
    def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
        return (TextMessage,)

    async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
        for message in messages:
            self.received_messages.append(message)
        reply = TextMessage(content="ACK", source=self.name)
        return Response(chat_message=reply)

    async def on_reset(self, cancellation_token: CancellationToken) -> None:
        # Empty the existing list in place rather than rebinding it.
        del self.received_messages[:]

    def _to_config(self) -> _TestMessageFilterAgentConfig:
        return _TestMessageFilterAgentConfig(description=self.description, name=self.name)

    @classmethod
    def _from_config(cls, config: _TestMessageFilterAgentConfig) -> "_TestMessageFilterAgent":
        return cls(config.name, config.description)
@pytest.mark.asyncio
async def test_message_filter_agent_empty_filter_blocks_all() -> None:
    """With no per-source rules, MessageFilterAgent forwards nothing to the wrapped agent."""
    recorder = _TestMessageFilterAgent("inner")
    no_rules = MessageFilterConfig(per_source=[])
    gate = MessageFilterAgent(name="wrapper", wrapped_agent=recorder, filter=no_rules)

    incoming = [
        TextMessage(source="user", content="Hello"),
        TextMessage(source="system", content="System msg"),
    ]
    await gate.on_messages(incoming, CancellationToken())

    assert not recorder.received_messages
@pytest.mark.asyncio
async def test_message_filter_agent_with_position_none_gets_all() -> None:
    """A rule with position=None and count=None passes every message from its source, and only that source."""
    recorder = _TestMessageFilterAgent("inner")
    user_only = PerSourceFilter(source="user", position=None, count=None)
    gate = MessageFilterAgent(
        name="wrapper",
        wrapped_agent=recorder,
        filter=MessageFilterConfig(per_source=[user_only]),
    )

    await gate.on_messages(
        [
            TextMessage(source="user", content="A"),
            TextMessage(source="user", content="B"),
            TextMessage(source="system", content="Ignore this"),
        ],
        CancellationToken(),
    )

    received = recorder.received_messages
    assert len(received) == 2
    assert {m.content for m in received} == {"A", "B"}  # type: ignore[attr-defined]
@pytest.mark.asyncio
async def test_digraph_group_chat() -> None:
    """Round-trip a MessageFilterAgent through dump_component / load_component.

    NOTE(review): despite its name, this test exercises MessageFilterAgent
    component serialization, not a DiGraph group chat.
    """
    original = MessageFilterAgent(
        name="agent",
        wrapped_agent=_TestMessageFilterAgent("agent"),
        filter=MessageFilterConfig(
            per_source=[
                PerSourceFilter(source="user", position="last", count=2),
                PerSourceFilter(source="system", position="first", count=1),
            ]
        ),
    )

    restored = MessageFilterAgent.load_component(original.dump_component())

    # The restored agent keeps its name, its filter and its wrapped agent's name.
    assert restored.name == "agent"
    assert restored._filter == original._filter  # pyright: ignore[reportPrivateUsage]
    assert restored._wrapped_agent.name == original._wrapped_agent.name  # pyright: ignore[reportPrivateUsage]

    # Filtering still applies after the round trip: last two "user", first "system".
    history = [
        TextMessage(source="user", content="u1"),
        TextMessage(source="user", content="u2"),
        TextMessage(source="user", content="u3"),
        TextMessage(source="system", content="s1"),
        TextMessage(source="system", content="s2"),
    ]
    await restored.on_messages(history, CancellationToken())
    seen = restored._wrapped_agent.received_messages  # type: ignore[attr-defined]
    assert {m.content for m in seen} == {"u2", "u3", "s1"}  # pyright: ignore[reportUnknownVariableType,reportUnknownMemberType]
@pytest.mark.asyncio
async def test_message_filter_agent_in_digraph_group_chat(runtime: AgentRuntime | None) -> None:
    """A single-node GraphFlow whose only participant is a MessageFilterAgent runs and replies "ACK"."""
    filtered = MessageFilterAgent(
        name="filtered",
        wrapped_agent=_TestMessageFilterAgent("filtered"),
        filter=MessageFilterConfig(per_source=[PerSourceFilter(source="user", position="last", count=1)]),
    )
    team = GraphFlow(
        participants=[filtered],
        graph=DiGraph(nodes={"filtered": DiGraphNode(name="filtered", edges=[])}),
        runtime=runtime,
        termination_condition=MaxMessageTermination(3),
    )

    result = await team.run(task="only last user message matters")

    assert result.stop_reason is not None
    from_filtered = [msg for msg in result.messages if msg.source == "filtered"]
    assert from_filtered
    assert any(msg.content == "ACK" for msg in from_filtered)  # type: ignore[attr-defined,union-attr]
@pytest.mark.asyncio
async def test_message_filter_agent_loop_graph_visibility(runtime: AgentRuntime | None) -> None:
    """Verify per-agent message visibility in an A -> B -> (A | C) loop.

    Every participant is wrapped in a MessageFilterAgent. B is an
    AssistantAgent whose replay client answers "loop", "loop", "exit". B's
    conditional edges send "loop" back to A and "exit" on to C.

    The assertions count how many messages from each source reached A and C
    (through their recording inner agents) and B (through its model context).
    """
    # A: first user message plus B's most recent message.
    agent_a_inner = _TestMessageFilterAgent("A")
    agent_a = MessageFilterAgent(
        name="A",
        wrapped_agent=agent_a_inner,
        filter=MessageFilterConfig(
            per_source=[
                PerSourceFilter(source="user", position="first", count=1),
                PerSourceFilter(source="B", position="last", count=1),
            ]
        ),
    )

    # B: scripted AssistantAgent; first user message, A's latest, up to 10 of its own.
    model_client = ReplayChatCompletionClient(["loop", "loop", "exit"])
    agent_b_inner = AssistantAgent("B", model_client=model_client)
    agent_b = MessageFilterAgent(
        name="B",
        wrapped_agent=agent_b_inner,
        filter=MessageFilterConfig(
            per_source=[
                PerSourceFilter(source="user", position="first", count=1),
                PerSourceFilter(source="A", position="last", count=1),
                PerSourceFilter(source="B", position="last", count=10),
            ]
        ),
    )

    # C: first user message plus B's most recent message.
    agent_c_inner = _TestMessageFilterAgent("C")
    agent_c = MessageFilterAgent(
        name="C",
        wrapped_agent=agent_c_inner,
        filter=MessageFilterConfig(
            per_source=[
                PerSourceFilter(source="user", position="first", count=1),
                PerSourceFilter(source="B", position="last", count=1),
            ]
        ),
    )

    graph = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]),
            "B": DiGraphNode(
                name="B",
                edges=[
                    DiGraphEdge(target="C", condition="exit"),
                    DiGraphEdge(target="A", condition="loop"),
                ],
            ),
            "C": DiGraphNode(name="C", edges=[]),
        },
        default_start_node="A",
    )

    team = GraphFlow(
        participants=[agent_a, agent_b, agent_c],
        graph=graph,
        runtime=runtime,
        termination_condition=MaxMessageTermination(20),
    )

    result = await team.run(task="Start")
    assert result.stop_reason is not None

    # A received: 1 user + 2 from B.
    a_sources = [m.source for m in agent_a_inner.received_messages]
    assert a_sources.count("user") == 1
    assert a_sources.count("B") == 2

    # C received: 1 user + 1 from B.
    c_sources = [m.source for m in agent_c_inner.received_messages]
    assert c_sources.count("user") == 1
    assert c_sources.count("B") == 1

    # B's model context: 1 user + multiple from A + its own messages.
    model_msgs = await agent_b_inner.model_context.get_messages()
    sources = [m.source for m in model_msgs]  # type: ignore[union-attr]
    assert sources.count("user") == 1  # pyright: ignore[reportUnknownMemberType]
    assert sources.count("A") >= 3  # pyright: ignore[reportUnknownMemberType]
    assert sources.count("B") >= 2  # pyright: ignore[reportUnknownMemberType]
# Test Graph Builder
def test_add_node() -> None:
    """add_node registers the agent both as a graph node and as an agent, with activation "all" by default."""
    agent = AssistantAgent("A", model_client=ReplayChatCompletionClient(["response"]))
    builder = DiGraphBuilder()
    builder.add_node(agent)

    assert all("A" in registry for registry in (builder.nodes, builder.agents))
    assert builder.nodes["A"].activation == "all"
def test_add_edge() -> None:
    """add_edge without a condition creates an unconditional edge from source to target."""
    client = ReplayChatCompletionClient(["1", "2"])
    source = AssistantAgent("A", model_client=client)
    target = AssistantAgent("B", model_client=client)

    builder = DiGraphBuilder()
    builder.add_node(source).add_node(target)
    builder.add_edge(source, target)

    first_edge = builder.nodes["A"].edges[0]
    assert first_edge.target == "B"
    assert first_edge.condition is None
def test_add_conditional_edges() -> None:
    """add_conditional_edges creates one edge per condition key, targeting the mapped agent."""
    client = ReplayChatCompletionClient(["1", "2"])
    agent_a, agent_b, agent_c = (AssistantAgent(name, model_client=client) for name in ("A", "B", "C"))

    builder = DiGraphBuilder()
    builder.add_node(agent_a).add_node(agent_b).add_node(agent_c)
    builder.add_conditional_edges(agent_a, {"yes": agent_b, "no": agent_c})

    edges = builder.nodes["A"].edges
    assert len(edges) == 2

    # With exactly two edges, both conditions present means one edge per condition.
    edge_by_condition = {edge.condition: edge for edge in edges}
    assert "yes" in edge_by_condition
    assert "no" in edge_by_condition
    assert edge_by_condition["yes"].target == "B"
    assert edge_by_condition["no"].target == "C"
def test_set_entry_point() -> None:
    """set_entry_point becomes the built graph's default_start_node."""
    entry = AssistantAgent("A", model_client=ReplayChatCompletionClient(["ok"]))
    builder = DiGraphBuilder().add_node(entry).set_entry_point(entry)
    built = builder.build()
    assert built.default_start_node == "A"
def test_build_graph_validation() -> None:
    """A linear A -> B -> C chain wired by node name builds with A as start and C as leaf."""
    client = ReplayChatCompletionClient(["1", "2", "3"])
    agents = [AssistantAgent(name, model_client=client) for name in ("A", "B", "C")]

    builder = DiGraphBuilder()
    for agent in agents:
        builder.add_node(agent)
    builder.add_edge("A", "B").add_edge("B", "C")
    builder.set_entry_point("A")

    graph = builder.build()

    assert isinstance(graph, DiGraph)
    assert set(graph.nodes) == {"A", "B", "C"}
    assert graph.get_start_nodes() == {"A"}
    assert graph.get_leaf_nodes() == {"C"}
def test_build_fan_out() -> None:
    """A single source A fanning out to B and C yields one start node and two leaves."""
    client = ReplayChatCompletionClient(["hi"] * 3)
    root, left, right = (AssistantAgent(name, model_client=client) for name in ("A", "B", "C"))

    builder = DiGraphBuilder()
    builder.add_node(root).add_node(left).add_node(right)
    builder.add_edge(root, left).add_edge(root, right)
    builder.set_entry_point(root)

    graph = builder.build()
    assert graph.get_start_nodes() == {"A"}
    assert graph.get_leaf_nodes() == {"B", "C"}
def test_build_parallel_join() -> None:
    """A and B both feed the join node C, which is configured with activation "all".

    Each parent gets exactly one edge into C. The previous version added B -> C
    twice and then overwrote node B by hand, which silently discarded the
    duplicate. That construction is removed here, and each parent's edges are
    pinned explicitly.
    """
    client = ReplayChatCompletionClient(["go"] * 3)
    a = AssistantAgent("A", model_client=client)
    b = AssistantAgent("B", model_client=client)
    c = AssistantAgent("C", model_client=client)

    builder = DiGraphBuilder()
    builder.add_node(a).add_node(b).add_node(c, activation="all")
    builder.add_edge(a, c).add_edge(b, c)
    builder.set_entry_point(a)

    graph = builder.build()

    assert graph.nodes["C"].activation == "all"
    assert graph.get_leaf_nodes() == {"C"}
    # Exactly one edge from each parent into the join node.
    assert [edge.target for edge in graph.nodes["A"].edges] == ["C"]
    assert [edge.target for edge in graph.nodes["B"].edges] == ["C"]
def test_build_conditional_loop() -> None:
    """B routes back to A on "loop" and on to C on "exit"; the graph has a cycle with an exit."""
    client = ReplayChatCompletionClient(["loop", "loop", "exit"])
    a, b, c = (AssistantAgent(name, model_client=client) for name in ("A", "B", "C"))

    builder = DiGraphBuilder()
    builder.add_node(a).add_node(b).add_node(c)
    builder.add_edge(a, b)
    builder.add_conditional_edges(b, {"loop": a, "exit": c})
    builder.set_entry_point(a)
    graph = builder.build()

    b_edges = graph.nodes["B"].edges
    assert len(b_edges) == 2
    target_by_condition = {edge.condition: edge.target for edge in b_edges}
    assert target_by_condition["loop"] == "A"
    assert target_by_condition["exit"] == "C"

    assert graph.has_cycles_with_exit()
@pytest.mark.asyncio
async def test_graph_builder_sequential_execution(runtime: AgentRuntime | None) -> None:
    """A builder-made chain A -> B -> C runs each agent once, in order."""
    echo_a = _EchoAgent("A", description="Echo A")
    echo_b = _EchoAgent("B", description="Echo B")
    echo_c = _EchoAgent("C", description="Echo C")

    builder = DiGraphBuilder()
    builder.add_node(echo_a).add_node(echo_b).add_node(echo_c)
    builder.add_edge(echo_a, echo_b).add_edge(echo_b, echo_c)

    flow = GraphFlow(
        participants=builder.get_participants(),
        graph=builder.build(),
        runtime=runtime,
        termination_condition=MaxMessageTermination(5),
    )
    outcome = await flow.run(task="Start")

    # Skip the first (user task) and last (completion) messages.
    speaker_order = [message.source for message in outcome.messages[1:-1]]
    assert speaker_order == ["A", "B", "C"]
    assert outcome.stop_reason is not None
@pytest.mark.asyncio
async def test_graph_builder_fan_out(runtime: AgentRuntime | None) -> None:
    """A builder-made fan-out A -> {B, C} has every agent produce a text message."""
    echo_a = _EchoAgent("A", description="Echo A")
    echo_b = _EchoAgent("B", description="Echo B")
    echo_c = _EchoAgent("C", description="Echo C")

    builder = DiGraphBuilder()
    builder.add_node(echo_a).add_node(echo_b).add_node(echo_c)
    builder.add_edge(echo_a, echo_b).add_edge(echo_a, echo_c)

    flow = GraphFlow(
        participants=builder.get_participants(),
        graph=builder.build(),
        runtime=runtime,
        termination_condition=MaxMessageTermination(5),
    )
    outcome = await flow.run(task="Start")

    # The first TextMessage is the user task; the rest come from the agents.
    text_sources = [message.source for message in outcome.messages if isinstance(message, TextMessage)]
    assert set(text_sources[1:]) == {"A", "B", "C"}
    assert outcome.stop_reason is not None
@pytest.mark.asyncio
async def test_graph_builder_conditional_execution(runtime: AgentRuntime | None) -> None:
    """Builder-made conditional edges route A's "no" message to C."""
    echo_a = _EchoAgent("A", description="Echo A")
    echo_b = _EchoAgent("B", description="Echo B")
    echo_c = _EchoAgent("C", description="Echo C")

    builder = DiGraphBuilder()
    builder.add_node(echo_a).add_node(echo_b).add_node(echo_c)
    builder.add_conditional_edges(echo_a, {"yes": echo_b, "no": echo_c})

    flow = GraphFlow(
        participants=builder.get_participants(),
        graph=builder.build(),
        runtime=runtime,
        termination_condition=MaxMessageTermination(5),
    )

    outcome = await flow.run(task="no")

    assert "C" in {message.source for message in outcome.messages}
    assert outcome.stop_reason is not None
@pytest.mark.asyncio
async def test_digraph_group_chat_callable_condition(runtime: AgentRuntime | None) -> None:
    """Route A to B on "long" and to C on "short", resetting the team between runs.

    NOTE(review): despite the name, the edges here use plain string
    conditions, not callables.
    """
    echo_a = _EchoAgent("A", description="Echo agent A")
    echo_b = _EchoAgent("B", description="Echo agent B")
    echo_c = _EchoAgent("C", description="Echo agent C")

    a_edges = [
        DiGraphEdge(target="B", condition="long"),
        DiGraphEdge(target="C", condition="short"),
    ]
    graph = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=a_edges),
            "B": DiGraphNode(name="B", edges=[]),
            "C": DiGraphNode(name="C", edges=[]),
        }
    )
    team = GraphFlow(
        participants=[echo_a, echo_b, echo_c],
        graph=graph,
        runtime=runtime,
        termination_condition=MaxMessageTermination(5),
    )

    cases = [
        ("This is a long message", "B"),
        ("This is a short message", "C"),
    ]
    for index, (task, expected_source) in enumerate(cases):
        if index:
            # Start every run after the first from a clean state.
            await team.reset()
        result = await team.run(task=task)
        assert result.messages[2].source == expected_source
@pytest.mark.asyncio
async def test_graph_flow_serialize_deserialize() -> None:
    """A GraphFlow survives dump/load unchanged, and the copy produces identical results."""
    # Each agent gets its own replay client with its own copy of the script.
    agent_a = AssistantAgent("A", model_client=ReplayChatCompletionClient([str(i) for i in range(10)]))
    agent_b = AssistantAgent("B", model_client=ReplayChatCompletionClient([str(i) for i in range(10)]))

    builder = DiGraphBuilder()
    builder.add_node(agent_a).add_node(agent_b)
    builder.add_edge(agent_a, agent_b)
    builder.set_entry_point(agent_a)

    team = GraphFlow(
        participants=builder.get_participants(),
        graph=builder.build(),
        runtime=None,
    )

    dumped = team.dump_component()
    clone = GraphFlow.load_component(dumped)
    redumped = clone.dump_component()

    results = await team.run(task="Start")
    clone_results = await clone.run(task="Start")

    # The serialized form is stable, and both teams behave the same.
    assert dumped == redumped
    assert compare_task_results(results, clone_results)
    assert results.stop_reason is not None
    assert results.stop_reason == clone_results.stop_reason
    assert compare_message_lists(results.messages, clone_results.messages)

    # Expected transcript: user task, A's first reply, B's first reply, then stop.
    for index, (source, content) in enumerate([("user", "Start"), ("A", "0"), ("B", "0")]):
        message = results.messages[index]
        assert isinstance(message, TextMessage)
        assert message.source == source
        assert message.content == content

    final = results.messages[-1]
    assert isinstance(final, StopMessage)
    assert final.source == _DIGRAPH_STOP_AGENT_NAME
    assert final.content == "Digraph execution is complete"
@pytest.mark.asyncio
async def test_graph_flow_stateful_pause_and_resume_with_termination() -> None:
    """Stop a GraphFlow once A answers, save its state, and resume it in a fresh team that runs B."""
    agent_a = AssistantAgent("A", model_client=ReplayChatCompletionClient(["A1", "A2"]))
    agent_b = AssistantAgent("B", model_client=ReplayChatCompletionClient(["B1"]))

    builder = DiGraphBuilder()
    builder.add_node(agent_a).add_node(agent_b)
    builder.add_edge(agent_a, agent_b)
    builder.set_entry_point(agent_a)

    # Phase 1: run until A has spoken.
    first_team = GraphFlow(
        participants=builder.get_participants(),
        graph=builder.build(),
        runtime=None,
        termination_condition=SourceMatchTermination(sources=["A"]),
    )
    paused = await first_team.run(task="Start")
    assert len(paused.messages) == 2
    assert [message.source for message in paused.messages] == ["user", "A"]
    assert paused.stop_reason == "'A' answered"

    snapshot = await first_team.save_state()

    # Phase 2: restore into a new team with no termination condition and continue.
    resumed_team = GraphFlow(
        participants=builder.get_participants(),
        graph=builder.build(),
        runtime=None,
    )
    await resumed_team.load_state(snapshot)
    resumed = await resumed_team.run()

    assert len(resumed.messages) == 2
    assert [message.source for message in resumed.messages] == ["B", _DIGRAPH_STOP_AGENT_NAME]
@pytest.mark.asyncio
async def test_builder_with_lambda_condition(runtime: AgentRuntime | None) -> None:
    """Test that DiGraphBuilder supports callable (lambda) edge conditions.

    A routes to B when its message contains "even" and to C when it contains
    "odd". The team is reset between the two runs.
    """
    agent_a = _EchoAgent("A", description="Echo agent A")
    agent_b = _EchoAgent("B", description="Echo agent B")
    agent_c = _EchoAgent("C", description="Echo agent C")

    builder = DiGraphBuilder()
    builder.add_node(agent_a).add_node(agent_b).add_node(agent_c)
    # Callable conditions: each predicate inspects the model text of A's message.
    builder.add_edge(agent_a, agent_b, lambda msg: "even" in msg.to_model_text())
    builder.add_edge(agent_a, agent_c, lambda msg: "odd" in msg.to_model_text())

    team = GraphFlow(
        participants=builder.get_participants(),
        graph=builder.build(),
        runtime=runtime,
        termination_condition=MaxMessageTermination(5),
    )

    # "even" in the message should route to B.
    result = await team.run(task="even length")
    assert result.messages[2].source == "B"

    # Reset before the next run.
    await team.reset()

    # "odd" in the message should route to C.
    result = await team.run(task="odd message")
    assert result.messages[2].source == "C"