2025-04-29 12:06:27 +10:00
|
|
|
import asyncio
|
2025-06-25 12:20:04 +08:00
|
|
|
import re
|
Enable concurrent execution of agents in GraphFlow (#6545)
Support concurrent execution in `GraphFlow`:
- Updated `BaseGroupChatManager.select_speaker` to return either a
single speaker name string or a list of speaker name strings, and added logic to
track the currently activated speakers so that the next speakers are only
selected once all activated speakers have finished.
- Updated existing teams (e.g., `SelectorGroupChat`) to the new
signature, while still returning a single speaker in their
implementations.
- Updated `GraphFlow` to support selecting multiple speakers.
- Refactored `GraphFlow` to reduce dictionary gymnastics by using a queue
and updating via `update_message_thread`.
Example: a fan out graph:
```python
import asyncio

from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.teams import DiGraphBuilder, GraphFlow
from autogen_ext.models.openai import OpenAIChatCompletionClient


async def main():
    # Initialize agents with OpenAI model clients.
    model_client = OpenAIChatCompletionClient(model="gpt-4.1-nano")
    agent_a = AssistantAgent("A", model_client=model_client, system_message="You are a helpful assistant.")
    agent_b = AssistantAgent("B", model_client=model_client, system_message="Translate input to Chinese.")
    agent_c = AssistantAgent("C", model_client=model_client, system_message="Translate input to Japanese.")

    # Create a directed graph with fan-out flow A -> (B, C).
    builder = DiGraphBuilder()
    builder.add_node(agent_a).add_node(agent_b).add_node(agent_c)
    builder.add_edge(agent_a, agent_b).add_edge(agent_a, agent_c)
    graph = builder.build()

    # Create a GraphFlow team with the directed graph.
    team = GraphFlow(
        participants=[agent_a, agent_b, agent_c],
        graph=graph,
    )

    # Run the team and print the events.
    async for event in team.run_stream(task="Write a short story about a cat."):
        print(event)


asyncio.run(main())
```
Resolves:
#6541
#6533
2025-05-19 14:47:55 -07:00
|
|
|
from typing import AsyncGenerator, List, Sequence
|
|
|
|
from unittest.mock import patch
|
2025-04-29 12:06:27 +10:00
|
|
|
|
|
|
|
import pytest
|
|
|
|
import pytest_asyncio
|
|
|
|
from autogen_agentchat.agents import (
|
|
|
|
AssistantAgent,
|
|
|
|
BaseChatAgent,
|
|
|
|
MessageFilterAgent,
|
|
|
|
MessageFilterConfig,
|
|
|
|
PerSourceFilter,
|
|
|
|
)
|
|
|
|
from autogen_agentchat.base import Response, TaskResult
|
Enable concurrent execution of agents in GraphFlow (#6545)
Support concurrent execution in `GraphFlow`:
- Updated `BaseGroupChatManager.select_speaker` to return either a
single speaker name string or a list of speaker name strings, and added logic to
track the currently activated speakers so that the next speakers are only
selected once all activated speakers have finished.
- Updated existing teams (e.g., `SelectorGroupChat`) to the new
signature, while still returning a single speaker in their
implementations.
- Updated `GraphFlow` to support selecting multiple speakers.
- Refactored `GraphFlow` to reduce dictionary gymnastics by using a queue
and updating via `update_message_thread`.
Example: a fan out graph:
```python
import asyncio

from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.teams import DiGraphBuilder, GraphFlow
from autogen_ext.models.openai import OpenAIChatCompletionClient


async def main():
    # Initialize agents with OpenAI model clients.
    model_client = OpenAIChatCompletionClient(model="gpt-4.1-nano")
    agent_a = AssistantAgent("A", model_client=model_client, system_message="You are a helpful assistant.")
    agent_b = AssistantAgent("B", model_client=model_client, system_message="Translate input to Chinese.")
    agent_c = AssistantAgent("C", model_client=model_client, system_message="Translate input to Japanese.")

    # Create a directed graph with fan-out flow A -> (B, C).
    builder = DiGraphBuilder()
    builder.add_node(agent_a).add_node(agent_b).add_node(agent_c)
    builder.add_edge(agent_a, agent_b).add_edge(agent_a, agent_c)
    graph = builder.build()

    # Create a GraphFlow team with the directed graph.
    team = GraphFlow(
        participants=[agent_a, agent_b, agent_c],
        graph=graph,
    )

    # Run the team and print the events.
    async for event in team.run_stream(task="Write a short story about a cat."):
        print(event)


asyncio.run(main())
```
Resolves:
#6541
#6533
2025-05-19 14:47:55 -07:00
|
|
|
from autogen_agentchat.conditions import MaxMessageTermination, SourceMatchTermination
|
2025-05-01 03:25:20 +09:00
|
|
|
from autogen_agentchat.messages import BaseChatMessage, ChatMessage, MessageFactory, StopMessage, TextMessage
|
2025-04-29 12:06:27 +10:00
|
|
|
from autogen_agentchat.teams import (
|
|
|
|
DiGraphBuilder,
|
|
|
|
GraphFlow,
|
|
|
|
)
|
|
|
|
from autogen_agentchat.teams._group_chat._events import ( # type: ignore[attr-defined]
|
|
|
|
BaseAgentEvent,
|
|
|
|
GroupChatTermination,
|
|
|
|
)
|
|
|
|
from autogen_agentchat.teams._group_chat._graph._digraph_group_chat import (
|
|
|
|
_DIGRAPH_STOP_AGENT_NAME, # pyright: ignore[reportPrivateUsage]
|
|
|
|
DiGraph,
|
|
|
|
DiGraphEdge,
|
|
|
|
DiGraphNode,
|
|
|
|
GraphFlowManager,
|
|
|
|
)
|
|
|
|
from autogen_core import AgentRuntime, CancellationToken, Component, SingleThreadedAgentRuntime
|
|
|
|
from autogen_ext.models.replay import ReplayChatCompletionClient
|
|
|
|
from pydantic import BaseModel
|
2025-05-23 14:29:24 +09:00
|
|
|
from utils import compare_message_lists, compare_task_results
|
2025-04-29 12:06:27 +10:00
|
|
|
|
|
|
|
|
|
|
|
def test_create_digraph() -> None:
    """Build the chain A -> B -> C and check its nodes and per-node out-degree."""
    chain = {
        "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]),
        "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="C")]),
        "C": DiGraphNode(name="C", edges=[]),
    }
    graph = DiGraph(nodes=chain)

    # Every declared node must be present in the constructed graph.
    for node_name in ("A", "B", "C"):
        assert node_name in graph.nodes

    # Number of outgoing edges per node, as declared above.
    expected_out_degree = {"A": 1, "B": 1, "C": 0}
    for node_name, degree in expected_out_degree.items():
        assert len(graph.nodes[node_name].edges) == degree
|
|
|
|
|
|
|
|
|
|
|
|
def test_get_parents() -> None:
    """Check the parent (incoming-edge) map computed for the chain A -> B -> C."""
    graph = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]),
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="C")]),
            "C": DiGraphNode(name="C", edges=[]),
        }
    )

    parent_map = graph.get_parents()
    expected_parents = {"A": [], "B": ["A"], "C": ["B"]}
    for child, parents in expected_parents.items():
        assert parent_map[child] == parents
|
|
|
|
|
|
|
|
|
|
|
|
def test_get_start_nodes() -> None:
    """Test retrieving start nodes (nodes with no incoming edges).

    In the chain A -> B -> C only ``A`` has no incoming edge.
    """
    graph = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]),
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="C")]),
            "C": DiGraphNode(name="C", edges=[]),
        }
    )

    start_nodes = graph.get_start_nodes()
    # Set literal instead of set([...]): same value, no throwaway list.
    assert start_nodes == {"A"}
|
|
|
|
|
|
|
|
|
|
|
|
def test_get_leaf_nodes() -> None:
    """Test retrieving leaf nodes (nodes with no outgoing edges).

    In the chain A -> B -> C only ``C`` has no outgoing edge.
    """
    graph = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]),
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="C")]),
            "C": DiGraphNode(name="C", edges=[]),
        }
    )

    leaf_nodes = graph.get_leaf_nodes()
    # Set literal instead of set([...]): same value, no throwaway list.
    assert leaf_nodes == {"C"}
|
|
|
|
|
|
|
|
|
|
|
|
def test_serialization() -> None:
    """Round-trip a graph with a string edge condition through JSON and compare fields."""
    # Use a string condition rather than a lambda; the round trip below checks it is preserved.
    original = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B", condition="trigger1")]),
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="C")]),
            "C": DiGraphNode(name="C", edges=[]),
        }
    )

    payload = original.model_dump_json()
    restored = DiGraph.model_validate_json(payload)

    first_edge_of_a = restored.nodes["A"].edges[0]
    assert first_edge_of_a.target == "B"
    assert first_edge_of_a.condition == "trigger1"
    assert restored.nodes["B"].edges[0].target == "C"

    # Manually check that the condition string appears in a message's model text.
    probe = TextMessage(content="this has trigger1 in it", source="test")
    assert "trigger1" in probe.to_model_text()
|
|
|
|
|
2025-04-29 12:06:27 +10:00
|
|
|
|
|
|
|
def test_invalid_graph_no_start_node() -> None:
    """A graph in which every node has an incoming edge reports no start nodes."""
    # B -> C -> B: each node is the target of some edge, so none can be a start node.
    cyclic = DiGraph(
        nodes={
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="C")]),
            "C": DiGraphNode(name="C", edges=[DiGraphEdge(target="B")]),
        }
    )

    assert not cyclic.get_start_nodes()
|
|
|
|
|
|
|
|
|
|
|
|
def test_invalid_graph_no_leaf_node() -> None:
    """A graph in which every node has an outgoing edge reports no leaf nodes."""
    # A -> B -> C -> A: the closing edge C -> A leaves no node without a successor.
    ring = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]),
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="C")]),
            "C": DiGraphNode(name="C", edges=[DiGraphEdge(target="A")]),
        }
    )

    assert not ring.get_leaf_nodes()
|
|
|
|
|
|
|
|
|
|
|
|
def test_condition_edge_execution() -> None:
    """Test conditional edge support using a string condition on the A -> B edge."""
    graph = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B", condition="TRIGGER")]),
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="C")]),
            "C": DiGraphNode(name="C", edges=[]),
        }
    )

    matching = TextMessage(content="This has TRIGGER in it", source="test")
    non_matching = TextMessage(content="This doesn't match", source="test")

    # The condition keyword appears in one message's model text and not in the other's.
    assert "TRIGGER" in matching.to_model_text()
    assert "TRIGGER" not in non_matching.to_model_text()

    # Only the A -> B edge carries a condition; B -> C is unconditional.
    assert graph.nodes["A"].edges[0].condition == "TRIGGER"
    assert graph.nodes["B"].edges[0].condition is None
|
|
|
|
|
|
|
|
|
|
|
|
def test_graph_with_multiple_paths() -> None:
    """Test a diamond graph with two execution paths: A -> (B, C) -> D."""
    graph = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B"), DiGraphEdge(target="C")]),
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="D")]),
            "C": DiGraphNode(name="C", edges=[DiGraphEdge(target="D")]),
            "D": DiGraphNode(name="D", edges=[]),
        }
    )

    # D joins both branches, so it has two parents.
    parents = graph.get_parents()
    assert parents["B"] == ["A"]
    assert parents["C"] == ["A"]
    assert parents["D"] == ["B", "C"]

    # Only A lacks incoming edges and only D lacks outgoing edges.
    # Set literals instead of set([...]): same values, no throwaway lists.
    assert graph.get_start_nodes() == {"A"}
    assert graph.get_leaf_nodes() == {"D"}
|
|
|
|
|
|
|
|
|
|
|
|
def test_cycle_detection_no_cycle() -> None:
    """has_cycles_with_exit() is falsy for the acyclic chain A -> B -> C."""
    acyclic = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]),
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="C")]),
            "C": DiGraphNode(name="C", edges=[]),
        }
    )
    assert not acyclic.has_cycles_with_exit()
|
|
|
|
|
|
|
|
|
|
|
|
def test_cycle_detection_with_exit_condition() -> None:
    """has_cycles_with_exit() returns True when the cycle's closing edge carries a condition."""
    # A -> B -> C -> A, with the closing edge C -> A guarded by a string condition.
    string_guarded = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]),
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="C")]),
            "C": DiGraphNode(name="C", edges=[DiGraphEdge(target="A", condition="exit")]),
        }
    )
    assert string_guarded.has_cycles_with_exit()

    # Same cycle, but the closing edge is guarded by a callable condition instead.
    callable_guarded = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]),
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="C")]),
            "C": DiGraphNode(
                name="C",
                edges=[DiGraphEdge(target="A", condition=lambda msg: "test" in msg.to_model_text())],
            ),
        }
    )
    assert callable_guarded.has_cycles_with_exit()
|
|
|
|
|
2025-04-29 12:06:27 +10:00
|
|
|
|
|
|
|
def test_cycle_detection_without_exit_condition() -> None:
    """has_cycles_with_exit() raises ValueError for a cycle with no conditional edge."""
    # A -> B -> C -> A has no exit condition; D -> E is a separate acyclic component.
    graph = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]),
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="C")]),
            "C": DiGraphNode(name="C", edges=[DiGraphEdge(target="A")]),
            "D": DiGraphNode(name="D", edges=[DiGraphEdge(target="E")]),
            "E": DiGraphNode(name="E", edges=[]),
        }
    )

    expected_message = "Cycle detected without exit condition: A -> B -> C -> A"
    with pytest.raises(ValueError, match=expected_message):
        graph.has_cycles_with_exit()
|
|
|
|
|
|
|
|
|
2025-06-25 12:20:04 +08:00
|
|
|
def test_different_activation_groups_detection() -> None:
    """graph_validate() rejects edges into one target that disagree on their activation condition."""
    # A fans out to B and C; B -> D asks for "all" while C -> D asks for "any".
    fan_out = DiGraphNode(
        name="A",
        edges=[
            DiGraphEdge(target="B"),
            DiGraphEdge(target="C"),
        ],
    )
    graph = DiGraph(
        nodes={
            "A": fan_out,
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="D", activation_condition="all")]),
            "C": DiGraphNode(name="C", edges=[DiGraphEdge(target="D", activation_condition="any")]),
            "D": DiGraphNode(name="D", edges=[]),
        }
    )

    # The message contains regex metacharacters ("(", ")"), so match it literally.
    expected_message = (
        "Conflicting activation conditions for target 'D' group 'D': "
        "'all' (from node 'B') and 'any' (from node 'C')"
    )
    with pytest.raises(ValueError, match=re.escape(expected_message)):
        graph.graph_validate()
|
|
|
|
|
|
|
|
|
2025-04-29 12:06:27 +10:00
|
|
|
def test_validate_graph_success() -> None:
    """An acyclic A -> B graph validates cleanly with an unconditional or a callable edge."""
    plain = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]),
            "B": DiGraphNode(name="B", edges=[]),
        }
    )
    plain.graph_validate()  # must not raise
    assert not plain.get_has_cycles()

    # Same shape, with the A -> B edge guarded by a callable condition.
    with_callable = DiGraph(
        nodes={
            "A": DiGraphNode(
                name="A",
                edges=[DiGraphEdge(target="B", condition=lambda msg: "test" in msg.to_model_text())],
            ),
            "B": DiGraphNode(name="B", edges=[]),
        }
    )
    with_callable.graph_validate()  # must not raise
    assert not with_callable.get_has_cycles()
|
|
|
|
|
2025-04-29 12:06:27 +10:00
|
|
|
|
|
|
|
def test_validate_graph_missing_start_node() -> None:
    """graph_validate() raises ValueError when no node is free of incoming edges."""
    # A <-> B: both nodes are edge targets, so there is no start node.
    two_cycle = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]),
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="A")]),
        }
    )

    with pytest.raises(ValueError, match="Graph must have at least one start node"):
        two_cycle.graph_validate()
|
|
|
|
|
|
|
|
|
|
|
|
def test_validate_graph_missing_leaf_node() -> None:
    """graph_validate() raises ValueError when every node has an outgoing edge."""
    # A -> B -> C -> B: A is a valid start node, but the B <-> C loop leaves no leaf.
    no_leaf = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]),
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="C")]),
            "C": DiGraphNode(name="C", edges=[DiGraphEdge(target="B")]),
        }
    )

    with pytest.raises(ValueError, match="Graph must have at least one leaf node"):
        no_leaf.graph_validate()
|
|
|
|
|
|
|
|
|
|
|
|
def test_validate_graph_mixed_conditions() -> None:
    """graph_validate() rejects a node whose outgoing edges mix conditional and unconditional."""

    def expect_mixed_edge_error(candidate: DiGraph) -> None:
        # Both variants below must fail validation with the same message.
        with pytest.raises(ValueError, match="Node 'A' has a mix of conditional and unconditional edges"):
            candidate.graph_validate()

    # Variant 1: A -> B guarded by a string condition, A -> C unconditional.
    string_variant = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B", condition="cond"), DiGraphEdge(target="C")]),
            "B": DiGraphNode(name="B", edges=[]),
            "C": DiGraphNode(name="C", edges=[]),
        }
    )
    expect_mixed_edge_error(string_variant)

    # Variant 2: A -> B guarded by a callable condition, A -> C unconditional.
    callable_variant = DiGraph(
        nodes={
            "A": DiGraphNode(
                name="A",
                edges=[
                    DiGraphEdge(target="B", condition=lambda msg: "test" in msg.to_model_text()),
                    DiGraphEdge(target="C"),
                ],
            ),
            "B": DiGraphNode(name="B", edges=[]),
            "C": DiGraphNode(name="C", edges=[]),
        }
    )
    expect_mixed_edge_error(callable_variant)
|
|
|
|
|
2025-04-29 12:06:27 +10:00
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_invalid_digraph_manager_cycle_without_termination() -> None:
    """GraphFlowManager construction fails for the cycle A -> B -> A, which has no start node."""
    # Both nodes are edge targets, so the graph has no entry point.
    two_cycle = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]),
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="A")]),
        }
    )

    events: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination] = asyncio.Queue()

    # Replace the base-class initializer with a stub so the test exercises
    # GraphFlowManager's own initialization logic.
    base_init_target = "autogen_agentchat.teams._group_chat._base_group_chat_manager.BaseGroupChatManager.__init__"
    with patch(base_init_target, return_value=None):
        manager = GraphFlowManager.__new__(GraphFlowManager)

        with pytest.raises(ValueError, match="Graph must have at least one start node"):
            manager.__init__(  # type: ignore[misc]
                name="test_manager",
                group_topic_type="topic",
                output_topic_type="topic",
                participant_topic_types=["topic1", "topic2"],
                participant_names=["A", "B"],
                participant_descriptions=["Agent A", "Agent B"],
                output_message_queue=events,
                termination_condition=None,
                max_turns=None,
                message_factory=MessageFactory(),
                graph=two_cycle,
            )
|
|
|
|
|
|
|
|
|
|
|
|
class _EchoAgent(BaseChatAgent):
    """Test agent that replies with the content of the first message it receives.

    With new input it echoes ``messages[0]`` and remembers that content; with no
    new input it repeats the remembered content. Each successful call bumps a
    counter exposed as :attr:`total_messages`.
    """

    def __init__(self, name: str, description: str) -> None:
        super().__init__(name, description)
        # Content of the most recently echoed message; replayed when called with no input.
        self._last_message: str | None = None
        # Number of responses produced so far (on_reset leaves this unchanged).
        self._total_messages = 0

    @property
    def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
        return (TextMessage,)

    @property
    def total_messages(self) -> int:
        """How many responses this agent has produced."""
        return self._total_messages

    async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
        """Echo the first incoming message, or replay the last echoed one if there is none."""
        if not messages:
            # Nothing new: repeat what was echoed previously.
            assert self._last_message is not None
            self._total_messages += 1
            return Response(chat_message=TextMessage(content=self._last_message, source=self.name))

        head = messages[0]
        assert isinstance(head, (TextMessage, StopMessage))
        self._last_message = head.content
        self._total_messages += 1
        return Response(chat_message=TextMessage(content=head.content, source=self.name))

    async def on_reset(self, cancellation_token: CancellationToken) -> None:
        """Forget the remembered message; the response counter is not cleared."""
        self._last_message = None
|
|
|
|
|
|
|
|
|
|
|
|
@pytest_asyncio.fixture(params=["single_threaded", "embedded"])  # type: ignore
async def runtime(request: pytest.FixtureRequest) -> AsyncGenerator[AgentRuntime | None, None]:
    """Parametrized agent runtime for the GraphFlow tests.

    ``"single_threaded"`` yields a started :class:`SingleThreadedAgentRuntime`
    and stops it during teardown. ``"embedded"`` yields ``None``, presumably
    leaving the team to manage its own runtime.
    """
    mode = request.param
    if mode == "embedded":
        yield None
    elif mode == "single_threaded":
        rt = SingleThreadedAgentRuntime()
        rt.start()
        yield rt
        # Teardown: runs after the test that consumed the fixture finishes.
        await rt.stop()
|
|
|
|
|
|
|
|
|
|
|
|
# Type alias for a task value: a plain string, a list of chat messages, or a single chat message.
TaskType = str | List[ChatMessage] | ChatMessage
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_digraph_group_chat_sequential_execution(runtime: AgentRuntime | None) -> None:
    """A linear A → B → C graph runs each echo agent once, in order, then stops."""
    first = _EchoAgent("A", description="Echo agent A")
    second = _EchoAgent("B", description="Echo agent B")
    third = _EchoAgent("C", description="Echo agent C")

    # Linear chain A → B → C; C has no outgoing edges.
    chain = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]),
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="C")]),
            "C": DiGraphNode(name="C", edges=[]),
        }
    )
    flow = GraphFlow(
        participants=[first, second, third],
        graph=chain,
        runtime=runtime,
        termination_condition=MaxMessageTermination(5),
    )

    outcome: TaskResult = await flow.run(task="Hello from User")
    msgs = outcome.messages

    # The user task, one turn per agent, then the stop agent's message.
    assert len(msgs) == 5
    assert isinstance(msgs[0], TextMessage)
    assert [m.source for m in msgs] == ["user", "A", "B", "C", _DIGRAPH_STOP_AGENT_NAME]
    assert all(isinstance(m, TextMessage) for m in msgs[:-1])
    assert outcome.stop_reason is not None
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_digraph_group_chat_parallel_fanout(runtime: AgentRuntime | None) -> None:
    """A fans out to B and C; both leaves run after A, in either order."""
    agents = [
        _EchoAgent("A", description="Echo agent A"),
        _EchoAgent("B", description="Echo agent B"),
        _EchoAgent("C", description="Echo agent C"),
    ]

    fanout = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B"), DiGraphEdge(target="C")]),
            "B": DiGraphNode(name="B", edges=[]),
            "C": DiGraphNode(name="C", edges=[]),
        }
    )
    team = GraphFlow(
        participants=agents,
        graph=fanout,
        runtime=runtime,
        termination_condition=MaxMessageTermination(5),
    )

    result: TaskResult = await team.run(task="Start")
    msgs = result.messages
    assert len(msgs) == 5

    opener, root, *leaves, closer = msgs
    assert opener.source == "user"
    assert root.source == "A"
    # B and C may finish in either order.
    assert {m.source for m in leaves} == {"B", "C"}
    assert closer.source == _DIGRAPH_STOP_AGENT_NAME
    assert result.stop_reason is not None
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_digraph_group_chat_parallel_join_all(runtime: AgentRuntime | None) -> None:
    """C (activation="all") runs only after both start nodes A and B have run."""
    node_a_agent = _EchoAgent("A", description="Echo agent A")
    node_b_agent = _EchoAgent("B", description="Echo agent B")
    join_agent = _EchoAgent("C", description="Echo agent C")

    # A and B have no incoming edges; both feed C, which waits for all inputs.
    join_graph = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="C")]),
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="C")]),
            "C": DiGraphNode(name="C", edges=[], activation="all"),
        }
    )
    team = GraphFlow(
        participants=[node_a_agent, node_b_agent, join_agent],
        graph=join_graph,
        runtime=runtime,
        termination_condition=MaxMessageTermination(5),
    )

    result: TaskResult = await team.run(task="Go")
    msgs = result.messages

    assert len(msgs) == 5
    assert msgs[0].source == "user"
    assert {msgs[1].source, msgs[2].source} == {"A", "B"}
    assert msgs[3].source == "C"
    assert msgs[-1].source == _DIGRAPH_STOP_AGENT_NAME
    assert result.stop_reason is not None
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_digraph_group_chat_parallel_join_any(runtime: AgentRuntime | None) -> None:
    """C (activation="any") runs after A/B and is the last agent before the stop agent."""
    node_a_agent = _EchoAgent("A", description="Echo agent A")
    node_b_agent = _EchoAgent("B", description="Echo agent B")
    join_agent = _EchoAgent("C", description="Echo agent C")

    join_graph = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="C")]),
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="C")]),
            "C": DiGraphNode(name="C", edges=[], activation="any"),
        }
    )
    team = GraphFlow(
        participants=[node_a_agent, node_b_agent, join_agent],
        graph=join_graph,
        runtime=runtime,
        termination_condition=MaxMessageTermination(5),
    )

    result: TaskResult = await team.run(task="Start")

    assert len(result.messages) == 5
    assert result.messages[0].source == "user"

    # Sources of every message after the user task.
    order = [m.source for m in result.messages[1:]]

    # C is the final agent turn, immediately before the stop agent's message.
    assert order[-2] == "C"
    # Both A and B must have run.
    assert {"A", "B"} <= set(order)
    # At least one of A or B precedes C.
    position = {name: order.index(name) for name in ("A", "B", "C")}
    assert position["C"] > min(position["A"], position["B"])

    assert result.messages[-1].source == _DIGRAPH_STOP_AGENT_NAME
    assert result.stop_reason is not None
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_digraph_group_chat_multiple_start_nodes(runtime: AgentRuntime | None) -> None:
    """Two edge-less nodes are both start nodes and each runs once."""
    left = _EchoAgent("A", description="Echo agent A")
    right = _EchoAgent("B", description="Echo agent B")

    isolated = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[]),
            "B": DiGraphNode(name="B", edges=[]),
        }
    )
    team = GraphFlow(
        participants=[left, right],
        graph=isolated,
        runtime=runtime,
        termination_condition=MaxMessageTermination(5),
    )

    result: TaskResult = await team.run(task="Start")
    msgs = result.messages

    # user task, A and B (any order), stop agent.
    assert len(msgs) == 4
    assert msgs[0].source == "user"
    assert {m.source for m in msgs[1:-1]} == {"A", "B"}
    assert msgs[-1].source == _DIGRAPH_STOP_AGENT_NAME
    assert result.stop_reason is not None
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_digraph_group_chat_disconnected_graph(runtime: AgentRuntime | None) -> None:
    """Two independent chains, A → B and C → D, both run to completion."""
    agents = [
        _EchoAgent("A", description="Echo agent A"),
        _EchoAgent("B", description="Echo agent B"),
        _EchoAgent("C", description="Echo agent C"),
        _EchoAgent("D", description="Echo agent D"),
    ]

    two_chains = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]),
            "B": DiGraphNode(name="B", edges=[]),
            "C": DiGraphNode(name="C", edges=[DiGraphEdge(target="D")]),
            "D": DiGraphNode(name="D", edges=[]),
        }
    )
    team = GraphFlow(
        participants=agents,
        graph=two_chains,
        runtime=runtime,
        termination_condition=MaxMessageTermination(10),
    )

    result: TaskResult = await team.run(task="Go")
    msgs = result.messages

    assert len(msgs) == 6
    assert msgs[0].source == "user"
    # Both chain heads run first (either order), then both tails.
    assert {msgs[1].source, msgs[2].source} == {"A", "C"}
    assert {msgs[3].source, msgs[4].source} == {"B", "D"}
    assert msgs[-1].source == _DIGRAPH_STOP_AGENT_NAME
    assert result.stop_reason is not None
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_digraph_group_chat_conditional_branch(runtime: AgentRuntime | None) -> None:
    """A routes to B or C based on its reply, using string and then callable edge conditions."""
    agent_a = _EchoAgent("A", description="Echo agent A")
    agent_b = _EchoAgent("B", description="Echo agent B")
    agent_c = _EchoAgent("C", description="Echo agent C")

    def build_team(edges_from_a: list[DiGraphEdge]) -> GraphFlow:
        # Only A's outgoing edges differ between scenarios; B and C are "any"-activated leaves.
        graph = DiGraph(
            nodes={
                "A": DiGraphNode(name="A", edges=edges_from_a),
                "B": DiGraphNode(name="B", edges=[], activation="any"),
                "C": DiGraphNode(name="C", edges=[], activation="any"),
            }
        )
        return GraphFlow(
            participants=[agent_a, agent_b, agent_c],
            graph=graph,
            runtime=runtime,
            termination_condition=MaxMessageTermination(5),
        )

    # String conditions: A echoes the task, so "Trigger yes" is expected to take the edge to B.
    string_team = build_team([DiGraphEdge(target="B", condition="yes"), DiGraphEdge(target="C", condition="no")])
    result = await string_team.run(task="Trigger yes")
    assert result.messages[2].source == "B"

    # Callable conditions inspect the text of A's message.
    lambda_team = build_team(
        [
            DiGraphEdge(target="B", condition=lambda msg: "yes" in msg.to_model_text()),
            DiGraphEdge(target="C", condition=lambda msg: "no" in msg.to_model_text()),
        ]
    )
    result_with_lambda = await lambda_team.run(task="Trigger no")
    assert result_with_lambda.messages[2].source == "C"
|
|
|
|
|
2025-04-29 12:06:27 +10:00
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_digraph_group_chat_loop_with_exit_condition(runtime: AgentRuntime | None) -> None:
    """A → B → A cycles while B replies "loop"; B's "exit" reply routes to C instead."""
    echo_a = _EchoAgent("A", description="Echo agent A")
    echo_c = _EchoAgent("C", description="Echo agent C")

    # B's scripted replies, one per B turn: loop, loop, then exit.
    replay = ReplayChatCompletionClient(chat_completions=["loop", "loop", "exit"])
    decider = AssistantAgent("B", description="Decision agent B", model_client=replay)

    # A → B; B → C on "exit", B → A on "loop".
    loop_graph = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]),
            "B": DiGraphNode(
                name="B",
                edges=[DiGraphEdge(target="C", condition="exit"), DiGraphEdge(target="A", condition="loop")],
            ),
            "C": DiGraphNode(name="C", edges=[]),
        },
        default_start_node="A",
    )
    flow = GraphFlow(
        participants=[echo_a, decider, echo_c],
        graph=loop_graph,
        runtime=runtime,
        termination_condition=MaxMessageTermination(20),
    )

    outcome = await flow.run(task="Start")

    # Three A → B rounds (two "loop", one "exit"), then C and the stop agent.
    expected_sources = ["user", *(["A", "B"] * 3), "C", _DIGRAPH_STOP_AGENT_NAME]
    assert [m.source for m in outcome.messages] == expected_sources
    assert outcome.stop_reason is not None
    assert outcome.messages[-2].source == "C"
    assert any(m.content == "exit" for m in outcome.messages[:-1])  # type: ignore[attr-defined,union-attr]
    assert outcome.messages[-1].source == _DIGRAPH_STOP_AGENT_NAME
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_digraph_group_chat_loop_with_self_cycle(runtime: AgentRuntime | None) -> None:
    """B re-runs itself while it replies "loop", then hands off to C on "exit".

    Graph: A → B; B → B when B says "loop"; B → C when B says "exit".
    """
    # Agents A and C: Echo Agents
    agent_a = _EchoAgent("A", description="Echo agent A")
    agent_c = _EchoAgent("C", description="Echo agent C")

    # Replay model client for agent B; one reply per B turn, in this order.
    model_client = ReplayChatCompletionClient(
        chat_completions=[
            "loop",  # B's 1st reply: run B again
            "loop",  # B's 2nd reply: run B again
            "exit",  # B's 3rd reply: move on to C
        ]
    )
    # Agent B: Assistant Agent using Replay Client
    agent_b = AssistantAgent("B", description="Decision agent B", model_client=model_client)

    # DiGraph: A → B, B → B (self loop on "loop"), B → C (on "exit").
    graph = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]),
            "B": DiGraphNode(
                name="B",
                edges=[
                    DiGraphEdge(target="C", condition="exit"),
                    # The self edge gets its own activation group instead of the default
                    # group named after the target ("B"), which the A → B edge uses.
                    DiGraphEdge(target="B", condition="loop", activation_group="B_loop"),
                ],
            ),
            "C": DiGraphNode(name="C", edges=[]),
        },
        default_start_node="A",
    )

    team = GraphFlow(
        participants=[agent_a, agent_b, agent_c],
        graph=graph,
        runtime=runtime,
        termination_condition=MaxMessageTermination(20),
    )

    # Run
    result = await team.run(task="Start")

    # Assert message order
    expected_sources = [
        "user",
        "A",
        "B",  # replies "loop"
        "B",  # replies "loop"
        "B",  # replies "exit"
        "C",
        _DIGRAPH_STOP_AGENT_NAME,
    ]

    actual_sources = [m.source for m in result.messages]

    assert actual_sources == expected_sources
    assert result.stop_reason is not None
    assert result.messages[-2].source == "C"
    assert any(m.content == "exit" for m in result.messages[:-1])  # type: ignore[attr-defined,union-attr]
    assert result.messages[-1].source == _DIGRAPH_STOP_AGENT_NAME
|
|
|
|
|
|
|
|
|
2025-06-25 12:20:04 +08:00
|
|
|
@pytest.mark.asyncio
async def test_digraph_group_chat_loop_with_two_cycles(runtime: AgentRuntime | None) -> None:
    """Decision node O cycles through X, then Y, before exiting to E.

    A fans out to B and C, which both feed O. O branches on its reply:
    "to_x" → X, "to_y" → Y, "exit" → E. X and Y each route back to O on "to_o".
    """
    # Agents A, B, C and E: Echo Agents
    agent_a = _EchoAgent("A", description="Echo agent A")
    agent_b = _EchoAgent("B", description="Echo agent B")
    agent_c = _EchoAgent("C", description="Echo agent C")
    agent_e = _EchoAgent("E", description="Echo agent E")

    # Replay model client shared by decision agents O, X and Y; the replies are
    # handed out in this order across all three agents.
    model_client = ReplayChatCompletionClient(
        chat_completions=[
            "to_x",  # O's first reply: branch to X
            "to_o",  # X's reply: go back to O
            "to_y",  # O's second reply: branch to Y
            "to_o",  # Y's reply: go back to O
            "exit",  # O's third reply: exit to E
        ]
    )
    # Agents O, X and Y: Assistant Agents using the shared Replay Client
    agent_o = AssistantAgent("O", description="Decision agent o", model_client=model_client)
    agent_x = AssistantAgent("X", description="Decision agent x", model_client=model_client)
    agent_y = AssistantAgent("Y", description="Decision agent y", model_client=model_client)

    # DiGraph:
    #
    #        A
    #       / \
    #      B   C
    #       \ /
    #  X <=> O <=> Y   (O → X / O → Y conditional; X → O and Y → O on "to_o")
    #        |
    #      E (exit)
    graph = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B"), DiGraphEdge(target="C")]),
            "B": DiGraphNode(
                name="B", edges=[DiGraphEdge(target="O")]
            ),  # default activation group name is same as target node name "O"
            "C": DiGraphNode(
                name="C", edges=[DiGraphEdge(target="O")]
            ),  # default activation group name is same as target node name "O"
            "O": DiGraphNode(
                name="O",
                edges=[
                    DiGraphEdge(target="X", condition="to_x"),
                    DiGraphEdge(target="Y", condition="to_y"),
                    DiGraphEdge(target="E", condition="exit"),
                ],
            ),
            # The return edges use their own activation groups, separate from the "O" group used by B and C.
            "X": DiGraphNode(name="X", edges=[DiGraphEdge(target="O", condition="to_o", activation_group="x_o_loop")]),
            "Y": DiGraphNode(name="Y", edges=[DiGraphEdge(target="O", condition="to_o", activation_group="y_o_loop")]),
            "E": DiGraphNode(name="E", edges=[]),
        },
        default_start_node="A",
    )

    team = GraphFlow(
        participants=[agent_a, agent_o, agent_b, agent_c, agent_x, agent_y, agent_e],
        graph=graph,
        runtime=runtime,
        termination_condition=MaxMessageTermination(20),
    )

    # Run
    result = await team.run(task="Start")

    # Assert message order
    expected_sources = [
        "user",
        "A",
        "B",
        "C",
        "O",
        "X",  # O -> X
        "O",  # X -> O
        "Y",  # O -> Y
        "O",  # Y -> O
        "E",  # O -> E
        _DIGRAPH_STOP_AGENT_NAME,
    ]

    actual_sources = [m.source for m in result.messages]

    assert actual_sources == expected_sources
    assert result.stop_reason is not None
    assert result.messages[-2].source == "E"
    assert any(m.content == "exit" for m in result.messages[:-1])  # type: ignore[attr-defined,union-attr]
    assert result.messages[-1].source == _DIGRAPH_STOP_AGENT_NAME
|
|
|
|
|
|
|
|
|
2025-04-29 12:06:27 +10:00
|
|
|
@pytest.mark.asyncio
async def test_digraph_group_chat_parallel_join_any_1(runtime: AgentRuntime | None) -> None:
    """A fans out to B and C; D, fed by both through one shared activation group, runs once.

    NOTE: "any" below is only the *name* of the activation group shared by the
    B → D and C → D edges. The test expects D to run exactly once, after both
    B and C.
    """
    agent_a = _EchoAgent("A", description="Echo agent A")
    agent_b = _EchoAgent("B", description="Echo agent B")
    agent_c = _EchoAgent("C", description="Echo agent C")
    agent_d = _EchoAgent("D", description="Echo agent D")

    graph = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B"), DiGraphEdge(target="C")]),
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="D", activation_group="any")]),
            "C": DiGraphNode(name="C", edges=[DiGraphEdge(target="D", activation_group="any")]),
            "D": DiGraphNode(name="D", edges=[]),
        }
    )

    team = GraphFlow(
        participants=[agent_a, agent_b, agent_c, agent_d],
        graph=graph,
        runtime=runtime,
        termination_condition=MaxMessageTermination(10),
    )

    result = await team.run(task="Run parallel join")
    sequence = [msg.source for msg in result.messages if isinstance(msg, TextMessage)]
    assert sequence[0] == "user"
    # B and C should both run
    assert "B" in sequence
    assert "C" in sequence
    # D should run exactly once
    d_indices = [i for i, s in enumerate(sequence) if s == "D"]
    assert len(d_indices) == 1
    # ...and only after both B and C have run
    b_index = sequence.index("B")
    c_index = sequence.index("C")
    assert any(d > b_index for d in d_indices)
    assert any(d > c_index for d in d_indices)
    assert result.stop_reason is not None
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_digraph_group_chat_chained_parallel_join_any(runtime: AgentRuntime | None) -> None:
    """A fans out to B and C, which join into D (activation="any"); D then feeds E.

    D and E are each expected to run exactly once.
    """
    agent_a = _EchoAgent("A", description="Echo agent A")
    agent_b = _EchoAgent("B", description="Echo agent B")
    agent_c = _EchoAgent("C", description="Echo agent C")
    agent_d = _EchoAgent("D", description="Echo agent D")
    agent_e = _EchoAgent("E", description="Echo agent E")

    graph = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B"), DiGraphEdge(target="C")]),
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="D")]),
            "C": DiGraphNode(name="C", edges=[DiGraphEdge(target="D")]),
            "D": DiGraphNode(name="D", edges=[DiGraphEdge(target="E")], activation="any"),
            "E": DiGraphNode(name="E", edges=[], activation="any"),
        }
    )

    team = GraphFlow(
        participants=[agent_a, agent_b, agent_c, agent_d, agent_e],
        graph=graph,
        runtime=runtime,
        termination_condition=MaxMessageTermination(20),
    )

    result = await team.run(task="Run chained parallel join-any")

    sequence = [msg.source for msg in result.messages if isinstance(msg, TextMessage)]

    # D runs exactly once, after both B and C. Although D uses activation="any",
    # B and C are presumably scheduled together and both finish before D is selected.
    d_indices = [i for i, s in enumerate(sequence) if s == "D"]
    assert len(d_indices) == 1
    b_index = sequence.index("B")
    c_index = sequence.index("C")
    assert any(d > b_index for d in d_indices)
    assert any(d > c_index for d in d_indices)

    # E runs exactly once, after D.
    e_indices = [i for i, s in enumerate(sequence) if s == "E"]
    assert len(e_indices) == 1
    assert e_indices[0] > d_indices[0]
    assert result.stop_reason is not None
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_digraph_group_chat_multiple_conditional(runtime: AgentRuntime | None) -> None:
    """A has three conditional edges; the task text selects which leaf (B, C or D) runs."""
    agent_a = _EchoAgent("A", description="Echo agent A")
    agent_b = _EchoAgent("B", description="Echo agent B")
    agent_c = _EchoAgent("C", description="Echo agent C")
    agent_d = _EchoAgent("D", description="Echo agent D")

    def build_team(edges_from_a: list[DiGraphEdge]) -> GraphFlow:
        # Only A's outgoing edges differ between scenarios; B, C and D are plain leaves.
        graph = DiGraph(
            nodes={
                "A": DiGraphNode(name="A", edges=edges_from_a),
                "B": DiGraphNode(name="B", edges=[]),
                "C": DiGraphNode(name="C", edges=[]),
                "D": DiGraphNode(name="D", edges=[]),
            }
        )
        return GraphFlow(
            participants=[agent_a, agent_b, agent_c, agent_d],
            graph=graph,
            runtime=runtime,
            termination_condition=MaxMessageTermination(5),
        )

    # String conditions: A echoes "banana", which is expected to route to C.
    string_team = build_team(
        [
            DiGraphEdge(target="B", condition="apple"),
            DiGraphEdge(target="C", condition="banana"),
            DiGraphEdge(target="D", condition="cherry"),
        ]
    )
    result = await string_team.run(task="banana")
    assert result.messages[2].source == "C"

    # Callable conditions inspect the text of A's message; "cherry" should route to D.
    lambda_team = build_team(
        [
            DiGraphEdge(target="B", condition=lambda msg: "apple" in msg.to_model_text()),
            DiGraphEdge(target="C", condition=lambda msg: "banana" in msg.to_model_text()),
            DiGraphEdge(target="D", condition=lambda msg: "cherry" in msg.to_model_text()),
        ]
    )
    result_with_lambda = await lambda_team.run(task="cherry")
    assert result_with_lambda.messages[2].source == "D"
|
|
|
|
|
2025-04-29 12:06:27 +10:00
|
|
|
|
|
|
|
class _TestMessageFilterAgentConfig(BaseModel):
    """Serializable configuration for the ``_TestMessageFilterAgent`` test component."""

    # Agent name, passed straight through to the agent constructor.
    name: str
    # Agent description; same default text as the agent constructor uses.
    description: str = "Echo test agent"
|
|
|
|
|
|
|
|
|
|
|
|
class _TestMessageFilterAgent(BaseChatAgent, Component[_TestMessageFilterAgentConfig]):
    """Recording agent used to observe what a wrapping agent forwards to it.

    Every message passed to ``on_messages`` is appended to ``received_messages``,
    and each call is answered with a fixed "ACK" text message.
    """

    component_config_schema = _TestMessageFilterAgentConfig
    component_provider_override = "test_group_chat_graph._TestMessageFilterAgent"

    def __init__(self, name: str, description: str = "Echo test agent") -> None:
        super().__init__(name=name, description=description)
        # All messages delivered to this agent, in arrival order; emptied by on_reset.
        self.received_messages: list[BaseChatMessage] = []

    @property
    def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
        return (TextMessage,)

    async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
        """Record the incoming messages and acknowledge them."""
        self.received_messages += messages
        ack = TextMessage(content="ACK", source=self.name)
        return Response(chat_message=ack)

    async def on_reset(self, cancellation_token: CancellationToken) -> None:
        """Drop everything recorded so far."""
        del self.received_messages[:]

    def _to_config(self) -> _TestMessageFilterAgentConfig:
        config = _TestMessageFilterAgentConfig(name=self.name, description=self.description)
        return config

    @classmethod
    def _from_config(cls, config: _TestMessageFilterAgentConfig) -> "_TestMessageFilterAgent":
        return cls(config.name, description=config.description)
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_message_filter_agent_empty_filter_blocks_all() -> None:
    """An empty ``per_source`` list must forward nothing to the wrapped agent."""
    recorder = _TestMessageFilterAgent("inner")
    no_sources = MessageFilterConfig(per_source=[])
    wrapper = MessageFilterAgent(name="wrapper", wrapped_agent=recorder, filter=no_sources)

    incoming = [
        TextMessage(source="user", content="Hello"),
        TextMessage(source="system", content="System msg"),
    ]
    await wrapper.on_messages(incoming, CancellationToken())

    assert not recorder.received_messages
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_message_filter_agent_with_position_none_gets_all() -> None:
    """``position=None``/``count=None`` lets every message from that source through."""
    recorder = _TestMessageFilterAgent("inner")
    user_only = PerSourceFilter(source="user", position=None, count=None)
    wrapper = MessageFilterAgent(
        name="wrapper",
        wrapped_agent=recorder,
        filter=MessageFilterConfig(per_source=[user_only]),
    )

    incoming = [
        TextMessage(source="user", content="A"),
        TextMessage(source="user", content="B"),
        TextMessage(source="system", content="Ignore this"),
    ]
    await wrapper.on_messages(incoming, CancellationToken())

    delivered = recorder.received_messages
    # Both "user" messages survive; the "system" one is not listed in the filter.
    assert len(delivered) == 2
    assert {m.content for m in delivered} == {"A", "B"}  # type: ignore[attr-defined]
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_digraph_group_chat() -> None:
    """Round-trip a ``MessageFilterAgent`` through dump/load and check its filter still applies.

    NOTE(review): despite its name, this test exercises ``MessageFilterAgent``
    serialization rather than a DiGraph group chat; consider a more descriptive
    name (and confirm no other test in this module shares this one).
    """
    recorder = _TestMessageFilterAgent("agent")
    filter_config = MessageFilterConfig(
        per_source=[
            PerSourceFilter(source="user", position="last", count=2),
            PerSourceFilter(source="system", position="first", count=1),
        ]
    )
    original = MessageFilterAgent(name="agent", wrapped_agent=recorder, filter=filter_config)

    dumped = original.dump_component()
    restored = MessageFilterAgent.load_component(dumped)
    assert restored.name == "agent"
    assert restored._filter == original._filter  # pyright: ignore[reportPrivateUsage]
    assert restored._wrapped_agent.name == original._wrapped_agent.name  # pyright: ignore[reportPrivateUsage]

    # The restored agent must keep the last two "user" messages and the first "system" one.
    history = [
        TextMessage(source="user", content="u1"),
        TextMessage(source="user", content="u2"),
        TextMessage(source="user", content="u3"),
        TextMessage(source="system", content="s1"),
        TextMessage(source="system", content="s2"),
    ]
    await restored.on_messages(history, CancellationToken())
    seen = restored._wrapped_agent.received_messages  # type: ignore[attr-defined]
    assert {m.content for m in seen} == {"u2", "u3", "s1"}  # pyright: ignore[reportUnknownVariableType,reportUnknownMemberType]
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_message_filter_agent_in_digraph_group_chat(runtime: AgentRuntime | None) -> None:
    """A ``MessageFilterAgent`` works as the only node of a ``GraphFlow``."""
    only_last_user = MessageFilterConfig(
        per_source=[PerSourceFilter(source="user", position="last", count=1)]
    )
    node_agent = MessageFilterAgent(
        name="filtered",
        wrapped_agent=_TestMessageFilterAgent("filtered"),
        filter=only_last_user,
    )

    single_node_graph = DiGraph(nodes={"filtered": DiGraphNode(name="filtered", edges=[])})
    team = GraphFlow(
        participants=[node_agent],
        graph=single_node_graph,
        runtime=runtime,
        termination_condition=MaxMessageTermination(3),
    )

    result = await team.run(task="only last user message matters")
    assert result.stop_reason is not None

    from_filtered = [msg for msg in result.messages if msg.source == "filtered"]
    assert from_filtered
    # The wrapped recorder always answers "ACK".
    assert any(msg.content == "ACK" for msg in from_filtered)  # type: ignore[attr-defined,union-attr]
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_message_filter_agent_loop_graph_visibility(runtime: AgentRuntime | None) -> None:
    """Per-source filters control what each agent sees in an A -> B -> (A | C) loop.

    B's replayed replies are "loop", "loop", "exit", so the flow is expected to
    visit A three times, B three times, and C once.
    """
    from autogen_agentchat.agents import AssistantAgent
    from autogen_ext.models.replay import ReplayChatCompletionClient

    def first_user_then(*others: PerSourceFilter) -> MessageFilterConfig:
        # Every agent sees the original task exactly once, plus the listed sources.
        return MessageFilterConfig(
            per_source=[PerSourceFilter(source="user", position="first", count=1), *others]
        )

    a_recorder = _TestMessageFilterAgent("A")
    a_node = MessageFilterAgent(
        name="A",
        wrapped_agent=a_recorder,
        filter=first_user_then(PerSourceFilter(source="B", position="last", count=1)),
    )

    b_assistant = AssistantAgent("B", model_client=ReplayChatCompletionClient(["loop", "loop", "exit"]))
    b_node = MessageFilterAgent(
        name="B",
        wrapped_agent=b_assistant,
        filter=first_user_then(
            PerSourceFilter(source="A", position="last", count=1),
            PerSourceFilter(source="B", position="last", count=10),
        ),
    )

    c_recorder = _TestMessageFilterAgent("C")
    c_node = MessageFilterAgent(
        name="C",
        wrapped_agent=c_recorder,
        filter=first_user_then(PerSourceFilter(source="B", position="last", count=1)),
    )

    loop_graph = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]),
            "B": DiGraphNode(
                name="B",
                edges=[
                    DiGraphEdge(target="C", condition="exit"),
                    DiGraphEdge(target="A", condition="loop"),
                ],
            ),
            "C": DiGraphNode(name="C", edges=[]),
        },
        default_start_node="A",
    )
    team = GraphFlow(
        participants=[a_node, b_node, c_node],
        graph=loop_graph,
        runtime=runtime,
        termination_condition=MaxMessageTermination(20),
    )

    result = await team.run(task="Start")
    assert result.stop_reason is not None

    # A: the task once, plus B's latest reply on each of its two re-entries.
    a_sources = [m.source for m in a_recorder.received_messages]
    assert a_sources.count("user") == 1
    assert a_sources.count("B") == 2

    # C: the task once, plus B's final ("exit") reply.
    c_sources = [m.source for m in c_recorder.received_messages]
    assert c_sources.count("user") == 1
    assert c_sources.count("B") == 1

    # B: the task once, every A turn, and its own earlier replies (read from its model context).
    b_context = await b_assistant.model_context.get_messages()
    b_sources = [m.source for m in b_context]  # type: ignore[union-attr]
    assert b_sources.count("user") == 1  # pyright: ignore[reportUnknownMemberType]
    assert b_sources.count("A") >= 3  # pyright: ignore[reportUnknownMemberType]
    assert b_sources.count("B") >= 2  # pyright: ignore[reportUnknownMemberType]
|
|
|
|
|
|
|
|
|
|
|
|
# DiGraphBuilder tests
def test_add_node() -> None:
    """``add_node`` records the agent under its name in both ``nodes`` and ``agents``."""
    builder = DiGraphBuilder()
    builder.add_node(AssistantAgent("A", model_client=ReplayChatCompletionClient(["response"])))

    assert "A" in builder.nodes
    assert "A" in builder.agents
    # A node added without an explicit activation uses "all".
    assert builder.nodes["A"].activation == "all"
|
|
|
|
|
|
|
|
|
|
|
|
def test_add_edge() -> None:
    """An edge added without a condition points at its target with ``condition is None``."""
    replay = ReplayChatCompletionClient(["1", "2"])
    source_agent = AssistantAgent("A", model_client=replay)
    target_agent = AssistantAgent("B", model_client=replay)

    builder = DiGraphBuilder()
    builder.add_node(source_agent).add_node(target_agent)
    builder.add_edge(source_agent, target_agent)

    first_edge = builder.nodes["A"].edges[0]
    assert first_edge.target == "B"
    assert first_edge.condition is None
|
|
|
|
|
|
|
|
|
|
|
|
def test_add_conditional_edges() -> None:
    """``add_conditional_edges`` creates one edge per condition string, pointing at its target."""
    replay = ReplayChatCompletionClient(["1", "2"])
    agent_a = AssistantAgent("A", model_client=replay)
    agent_b = AssistantAgent("B", model_client=replay)
    agent_c = AssistantAgent("C", model_client=replay)

    builder = DiGraphBuilder()
    builder.add_node(agent_a).add_node(agent_b).add_node(agent_c)
    builder.add_conditional_edges(agent_a, {"yes": agent_b, "no": agent_c})

    out_edges = builder.nodes["A"].edges
    assert len(out_edges) == 2

    # Both condition strings must be present on A's outgoing edges.
    conditions = [edge.condition for edge in out_edges]
    assert "yes" in conditions
    assert "no" in conditions

    # Look edges up by condition instead of relying on insertion order.
    def target_for(condition: str) -> str:
        return next(edge.target for edge in out_edges if edge.condition == condition)

    assert target_for("yes") == "B"
    assert target_for("no") == "C"
|
2025-04-29 12:06:27 +10:00
|
|
|
|
|
|
|
|
|
|
|
def test_set_entry_point() -> None:
    """The node passed to ``set_entry_point`` becomes the graph's ``default_start_node``."""
    only_agent = AssistantAgent("A", model_client=ReplayChatCompletionClient(["ok"]))
    graph = DiGraphBuilder().add_node(only_agent).set_entry_point(only_agent).build()
    assert graph.default_start_node == "A"
|
|
|
|
|
|
|
|
|
|
|
|
def test_build_graph_validation() -> None:
    """A chain A -> B -> C wired by node *names* builds with one start node and one leaf."""
    replay = ReplayChatCompletionClient(["1", "2", "3"])
    builder = DiGraphBuilder()
    for name in ("A", "B", "C"):
        builder.add_node(AssistantAgent(name, model_client=replay))

    # Edges and entry point are given as strings rather than agent objects.
    builder.add_edge("A", "B").add_edge("B", "C")
    builder.set_entry_point("A")
    graph = builder.build()

    assert isinstance(graph, DiGraph)
    assert set(graph.nodes.keys()) == {"A", "B", "C"}
    assert graph.get_start_nodes() == {"A"}
    assert graph.get_leaf_nodes() == {"C"}
|
|
|
|
|
|
|
|
|
|
|
|
def test_build_fan_out() -> None:
    """Fan-out A -> B and A -> C: a single start node and two leaves."""
    replay = ReplayChatCompletionClient(["hi"] * 3)
    hub = AssistantAgent("A", model_client=replay)
    left = AssistantAgent("B", model_client=replay)
    right = AssistantAgent("C", model_client=replay)

    builder = DiGraphBuilder()
    builder.add_node(hub).add_node(left).add_node(right)
    for leaf in (left, right):
        builder.add_edge(hub, leaf)
    graph = builder.set_entry_point(hub).build()

    assert graph.get_start_nodes() == {"A"}
    assert graph.get_leaf_nodes() == {"B", "C"}
|
|
|
|
|
|
|
|
|
|
|
|
def test_build_parallel_join() -> None:
    """A and B both feed C, which is added with ``activation="all"``; C is the only leaf.

    The builder alone must produce the join. Previously the test re-added
    B -> C and then overwrote B's node with a hand-built ``DiGraphNode``, so B's
    join edge was not actually coming from ``DiGraphBuilder``.
    """
    client = ReplayChatCompletionClient(["go"] * 3)
    a = AssistantAgent("A", model_client=client)
    b = AssistantAgent("B", model_client=client)
    c = AssistantAgent("C", model_client=client)

    builder = DiGraphBuilder()
    builder.add_node(a).add_node(b).add_node(c, activation="all")
    # Each join edge is added exactly once.
    builder.add_edge(a, c).add_edge(b, c)
    builder.set_entry_point(a)
    graph = builder.build()

    # Both upstream nodes point only at C.
    assert [edge.target for edge in graph.nodes["A"].edges] == ["C"]
    assert [edge.target for edge in graph.nodes["B"].edges] == ["C"]
    assert graph.nodes["C"].activation == "all"
    assert graph.get_leaf_nodes() == {"C"}
|
|
|
|
|
|
|
|
|
|
|
|
def test_build_conditional_loop() -> None:
    """B loops back to A on "loop" and exits to C on "exit"; the cycle has an exit."""
    replay = ReplayChatCompletionClient(["loop", "loop", "exit"])
    agent_a = AssistantAgent("A", model_client=replay)
    agent_b = AssistantAgent("B", model_client=replay)
    agent_c = AssistantAgent("C", model_client=replay)

    builder = DiGraphBuilder()
    builder.add_node(agent_a).add_node(agent_b).add_node(agent_c)
    builder.add_edge(agent_a, agent_b)
    builder.add_conditional_edges(agent_b, {"loop": agent_a, "exit": agent_c})
    builder.set_entry_point(agent_a)
    graph = builder.build()

    b_edges = graph.nodes["B"].edges
    assert len(b_edges) == 2

    # Resolve each branch by its condition string.
    loop_target = next(edge.target for edge in b_edges if edge.condition == "loop")
    exit_target = next(edge.target for edge in b_edges if edge.condition == "exit")
    assert loop_target == "A"
    assert exit_target == "C"

    assert graph.has_cycles_with_exit()
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_graph_builder_sequential_execution(runtime: AgentRuntime | None) -> None:
    """A builder-made chain A -> B -> C runs its agents in that order."""
    first = _EchoAgent("A", description="Echo A")
    second = _EchoAgent("B", description="Echo B")
    third = _EchoAgent("C", description="Echo C")

    builder = DiGraphBuilder()
    for agent in (first, second, third):
        builder.add_node(agent)
    builder.add_edge(first, second).add_edge(second, third)

    team = GraphFlow(
        participants=builder.get_participants(),
        graph=builder.build(),
        runtime=runtime,
        termination_condition=MaxMessageTermination(5),
    )
    result = await team.run(task="Start")

    # Drop the leading task message and the trailing message after the chain.
    speakers = [message.source for message in result.messages[1:-1]]
    assert speakers == ["A", "B", "C"]
    assert result.stop_reason is not None
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_graph_builder_fan_out(runtime: AgentRuntime | None) -> None:
    """A builder-made fan-out A -> {B, C} runs all three agents."""
    root = _EchoAgent("A", description="Echo A")
    branch_b = _EchoAgent("B", description="Echo B")
    branch_c = _EchoAgent("C", description="Echo C")

    builder = DiGraphBuilder()
    builder.add_node(root).add_node(branch_b).add_node(branch_c)
    builder.add_edge(root, branch_b).add_edge(root, branch_c)

    team = GraphFlow(
        participants=builder.get_participants(),
        graph=builder.build(),
        runtime=runtime,
        termination_condition=MaxMessageTermination(5),
    )
    result = await team.run(task="Start")

    text_sources = [m.source for m in result.messages if isinstance(m, TextMessage)]
    # Skip the task message; the relative order of B and C is not asserted.
    assert set(text_sources[1:]) == {"A", "B", "C"}
    assert result.stop_reason is not None
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_graph_builder_conditional_execution(runtime: AgentRuntime | None) -> None:
    """Conditional edges from ``add_conditional_edges`` route execution at run time."""
    decider = _EchoAgent("A", description="Echo A")
    on_yes = _EchoAgent("B", description="Echo B")
    on_no = _EchoAgent("C", description="Echo C")

    builder = DiGraphBuilder()
    builder.add_node(decider).add_node(on_yes).add_node(on_no)
    builder.add_conditional_edges(decider, {"yes": on_yes, "no": on_no})

    team = GraphFlow(
        participants=builder.get_participants(),
        graph=builder.build(),
        runtime=runtime,
        termination_condition=MaxMessageTermination(5),
    )

    # The task text "no" is expected to route execution along the "no" edge to C.
    result = await team.run(task="no")
    assert "C" in [m.source for m in result.messages]
    assert result.stop_reason is not None
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_digraph_group_chat_callable_condition(runtime: AgentRuntime | None) -> None:
    """Test that string conditions on A's edges choose the next speaker.

    NOTE(review): despite the test name, the edges below use *string*
    conditions, not callables.
    """
    agent_a = _EchoAgent("A", description="Echo agent A")
    agent_b = _EchoAgent("B", description="Echo agent B")
    agent_c = _EchoAgent("C", description="Echo agent C")

    branching = DiGraph(
        nodes={
            "A": DiGraphNode(
                name="A",
                edges=[
                    DiGraphEdge(target="B", condition="long"),  # taken when "long" is in the message
                    DiGraphEdge(target="C", condition="short"),  # taken when "short" is in the message
                ],
            ),
            "B": DiGraphNode(name="B", edges=[]),
            "C": DiGraphNode(name="C", edges=[]),
        }
    )
    team = GraphFlow(
        participants=[agent_a, agent_b, agent_c],
        graph=branching,
        runtime=runtime,
        termination_condition=MaxMessageTermination(5),
    )

    cases = [
        ("This is a long message", "B"),
        ("This is a short message", "C"),
    ]
    for index, (task, expected_speaker) in enumerate(cases):
        if index:
            # Start each case after the first from a clean team state.
            await team.reset()
        outcome = await team.run(task=task)
        # messages[0] is the task, [1] is A, [2] is the branch A selected.
        assert outcome.messages[2].source == expected_speaker
|
2025-05-01 03:25:20 +09:00
|
|
|
|
|
|
|
|
|
|
|
# Round-trip check: a GraphFlow dumped with dump_component() and reloaded with
# load_component() must re-serialize to the same config, and running the original
# and the reloaded team on the same task must yield matching results.
@pytest.mark.asyncio
|
|
|
|
async def test_graph_flow_serialize_deserialize() -> None:
|
|
|
|
# Replay clients scripted with the strings "0".."9"; the assertions below expect
# each agent's first reply to be "0".
client_a = ReplayChatCompletionClient(list(map(str, range(10))))
|
|
|
|
client_b = ReplayChatCompletionClient(list(map(str, range(10))))
|
|
|
|
a = AssistantAgent("A", model_client=client_a)
|
|
|
|
b = AssistantAgent("B", model_client=client_b)
|
|
|
|
|
|
|
|
# Two-node graph A -> B with A as the entry point.
builder = DiGraphBuilder()
|
|
|
|
builder.add_node(a).add_node(b)
|
|
|
|
builder.add_edge(a, b)
|
|
|
|
builder.set_entry_point(a)
|
|
|
|
|
|
|
|
team = GraphFlow(
|
|
|
|
participants=builder.get_participants(),
|
|
|
|
graph=builder.build(),
|
|
|
|
runtime=None,
|
|
|
|
)
|
|
|
|
|
|
|
|
# Serialize, reload, and re-serialize the team configuration.
serialized = team.dump_component()
|
|
|
|
deserialized_team = GraphFlow.load_component(serialized)
|
|
|
|
serialized_deserialized = deserialized_team.dump_component()
|
|
|
|
|
|
|
|
# Run the original and the reloaded team on the same task.
results = await team.run(task="Start")
|
|
|
|
de_results = await deserialized_team.run(task="Start")
|
|
|
|
|
|
|
|
# The config must survive a dump/load/dump cycle unchanged, and both runs must match.
# NOTE(review): compare_task_results / compare_message_lists are helpers not shown in
# this excerpt; presumably they compare content rather than identity.
assert serialized == serialized_deserialized
|
2025-05-23 14:29:24 +09:00
|
|
|
assert compare_task_results(results, de_results)
|
2025-05-01 03:25:20 +09:00
|
|
|
assert results.stop_reason is not None
|
|
|
|
assert results.stop_reason == de_results.stop_reason
|
2025-05-23 14:29:24 +09:00
|
|
|
assert compare_message_lists(results.messages, de_results.messages)
|
2025-05-01 03:25:20 +09:00
|
|
|
# Expected transcript: user task, A's reply "0", B's reply "0", then the stop agent's message.
assert isinstance(results.messages[0], TextMessage)
|
|
|
|
assert results.messages[0].source == "user"
|
|
|
|
assert results.messages[0].content == "Start"
|
|
|
|
assert isinstance(results.messages[1], TextMessage)
|
|
|
|
assert results.messages[1].source == "A"
|
|
|
|
assert results.messages[1].content == "0"
|
|
|
|
assert isinstance(results.messages[2], TextMessage)
|
Enable concurrent execution of agents in GraphFlow (#6545)
Support concurrent execution in `GraphFlow`:
- Updated `BaseGroupChatManager.select_speaker` to return a union of a
single string or a list of speaker name strings and added logics to
check for currently activated speakers and only proceed to select next
speakers when all activated speakers have finished.
- Updated existing teams (e.g., `SelectorGroupChat`) with the new
signature, while still returning a single speaker in their
implementations.
- Updated `GraphFlow` to support multiple speakers selected.
- Refactored `GraphFlow` for less dictionary gymnastic by using a queue
and update using `update_message_thread`.
Example: a fan out graph:
```python
import asyncio
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.teams import DiGraphBuilder, GraphFlow
from autogen_ext.models.openai import OpenAIChatCompletionClient
async def main():
# Initialize agents with OpenAI model clients.
model_client = OpenAIChatCompletionClient(model="gpt-4.1-nano")
agent_a = AssistantAgent("A", model_client=model_client, system_message="You are a helpful assistant.")
agent_b = AssistantAgent("B", model_client=model_client, system_message="Translate input to Chinese.")
agent_c = AssistantAgent("C", model_client=model_client, system_message="Translate input to Japanese.")
# Create a directed graph with fan-out flow A -> (B, C).
builder = DiGraphBuilder()
builder.add_node(agent_a).add_node(agent_b).add_node(agent_c)
builder.add_edge(agent_a, agent_b).add_edge(agent_a, agent_c)
graph = builder.build()
# Create a GraphFlow team with the directed graph.
team = GraphFlow(
participants=[agent_a, agent_b, agent_c],
graph=graph,
)
# Run the team and print the events.
async for event in team.run_stream(task="Write a short story about a cat."):
print(event)
asyncio.run(main())
```
Resolves:
#6541
#6533
2025-05-19 14:47:55 -07:00
|
|
|
assert results.messages[2].source == "B"
|
|
|
|
assert results.messages[2].content == "0"
|
2025-05-01 03:25:20 +09:00
|
|
|
assert isinstance(results.messages[-1], StopMessage)
|
|
|
|
assert results.messages[-1].source == _DIGRAPH_STOP_AGENT_NAME
|
|
|
|
assert results.messages[-1].content == "Digraph execution is complete"
|
Enable concurrent execution of agents in GraphFlow (#6545)
Support concurrent execution in `GraphFlow`:
- Updated `BaseGroupChatManager.select_speaker` to return either a
single speaker name string or a list of speaker name strings, and added logic to
check for currently activated speakers and only proceed to select the next
speakers when all activated speakers have finished.
- Updated existing teams (e.g., `SelectorGroupChat`) to the new
signature, while still returning a single speaker in their
implementations.
- Updated `GraphFlow` to support selecting multiple speakers.
- Refactored `GraphFlow` to reduce dictionary gymnastics by using a queue
and updating the message thread via `update_message_thread`.
Example: a fan-out graph:
```python
import asyncio
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.teams import DiGraphBuilder, GraphFlow
from autogen_ext.models.openai import OpenAIChatCompletionClient


async def main():
    # Initialize agents with OpenAI model clients.
    model_client = OpenAIChatCompletionClient(model="gpt-4.1-nano")
    agent_a = AssistantAgent("A", model_client=model_client, system_message="You are a helpful assistant.")
    agent_b = AssistantAgent("B", model_client=model_client, system_message="Translate input to Chinese.")
    agent_c = AssistantAgent("C", model_client=model_client, system_message="Translate input to Japanese.")
    # Create a directed graph with fan-out flow A -> (B, C).
    builder = DiGraphBuilder()
    builder.add_node(agent_a).add_node(agent_b).add_node(agent_c)
    builder.add_edge(agent_a, agent_b).add_edge(agent_a, agent_c)
    graph = builder.build()
    # Create a GraphFlow team with the directed graph.
    team = GraphFlow(
        participants=[agent_a, agent_b, agent_c],
        graph=graph,
    )
    # Run the team and print the events.
    async for event in team.run_stream(task="Write a short story about a cat."):
        print(event)


asyncio.run(main())
```
Resolves:
#6541
#6533
2025-05-19 14:47:55 -07:00
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_graph_flow_stateful_pause_and_resume_with_termination() -> None:
    """Pause an A -> B flow on A's turn, restore it into a fresh team, and resume at B.

    The first run stops as soon as A speaks (SourceMatchTermination on "A"). Its
    saved state is loaded into a new GraphFlow that has no termination condition,
    and a task-less run() on that team must continue with B and then the stop agent.
    """
    agent_a = AssistantAgent("A", model_client=ReplayChatCompletionClient(["A1", "A2"]))
    agent_b = AssistantAgent("B", model_client=ReplayChatCompletionClient(["B1"]))

    builder = DiGraphBuilder()
    builder.add_node(agent_a).add_node(agent_b)
    builder.add_edge(agent_a, agent_b)
    builder.set_entry_point(agent_a)

    paused_team = GraphFlow(
        participants=builder.get_participants(),
        graph=builder.build(),
        runtime=None,
        termination_condition=SourceMatchTermination(sources=["A"]),
    )

    # Phase 1: the run halts right after A's first reply.
    first_run = await paused_team.run(task="Start")
    assert [message.source for message in first_run.messages] == ["user", "A"]
    assert first_run.stop_reason is not None
    assert first_run.stop_reason == "'A' answered"

    # Phase 2: snapshot the paused team and restore the snapshot into a brand-new team.
    snapshot = await paused_team.save_state()
    resumed_team = GraphFlow(
        participants=builder.get_participants(),
        graph=builder.build(),
        runtime=None,
    )
    await resumed_team.load_state(snapshot)

    # Phase 3: continuing without a task picks up at B, then the stop agent closes the flow.
    second_run = await resumed_team.run()
    assert [message.source for message in second_run.messages] == ["B", _DIGRAPH_STOP_AGENT_NAME]
|
Add callable condition for GraphFlow edges (#6623)
This PR adds callables as an option for specifying conditional edges in
GraphFlow.
```python
import asyncio
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.conditions import MaxMessageTermination
from autogen_agentchat.teams import DiGraphBuilder, GraphFlow
from autogen_ext.models.openai import OpenAIChatCompletionClient


async def main():
    # Initialize agents with OpenAI model clients.
    model_client = OpenAIChatCompletionClient(model="gpt-4.1-nano")
    agent_a = AssistantAgent(
        "A",
        model_client=model_client,
        system_message="Detect if the input is in Chinese. If it is, say 'yes', else say 'no', and nothing else.",
    )
    agent_b = AssistantAgent("B", model_client=model_client, system_message="Translate input to English.")
    agent_c = AssistantAgent("C", model_client=model_client, system_message="Translate input to Chinese.")
    # Create a directed graph with conditional branching flow A -> B ("yes"), A -> C (otherwise).
    builder = DiGraphBuilder()
    builder.add_node(agent_a).add_node(agent_b).add_node(agent_c)
    # Create conditions as callables that check the message content.
    builder.add_edge(agent_a, agent_b, condition=lambda msg: "yes" in msg.to_model_text())
    builder.add_edge(agent_a, agent_c, condition=lambda msg: "yes" not in msg.to_model_text())
    graph = builder.build()
    # Create a GraphFlow team with the directed graph.
    team = GraphFlow(
        participants=[agent_a, agent_b, agent_c],
        graph=graph,
        termination_condition=MaxMessageTermination(5),
    )
    # Run the team and print the events.
    async for event in team.run_stream(task="AutoGen is a framework for building AI agents."):
        print(event)


asyncio.run(main())
```
---------
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: ekzhu <320302+ekzhu@users.noreply.github.com>
2025-06-04 15:43:26 -07:00
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_builder_with_lambda_condition(runtime: AgentRuntime | None) -> None:
    """Test that DiGraphBuilder supports callable (lambda) edge conditions.

    A has two outgoing edges. A -> B is followed when the message text contains
    "even", and A -> C is followed when it contains "odd". Each run checks which
    branch spoke third and that the other branch never spoke.
    """
    agent_a = _EchoAgent("A", description="Echo agent A")
    agent_b = _EchoAgent("B", description="Echo agent B")
    agent_c = _EchoAgent("C", description="Echo agent C")

    builder = DiGraphBuilder()
    builder.add_node(agent_a).add_node(agent_b).add_node(agent_c)

    # Callable conditions: the edge is followed when the predicate on the message text is true.
    # Passed by keyword so it is explicit that these are edge conditions.
    builder.add_edge(agent_a, agent_b, condition=lambda msg: "even" in msg.to_model_text())
    builder.add_edge(agent_a, agent_c, condition=lambda msg: "odd" in msg.to_model_text())

    team = GraphFlow(
        participants=builder.get_participants(),
        graph=builder.build(),
        runtime=runtime,
        termination_condition=MaxMessageTermination(5),
    )

    # Test with "even" in message - should go to B, and C must never speak
    result = await team.run(task="even length")
    assert result.messages[2].source == "B"
    assert "C" not in [message.source for message in result.messages]

    # Reset for next test
    await team.reset()

    # Test with "odd" in message - should go to C, and B must never speak
    result = await team.run(task="odd message")
    assert result.messages[2].source == "C"
    assert "B" not in [message.source for message in result.messages]
|
Fix GraphFlow to support multiple task execution without explicit reset (#6747)
## Problem
When using GraphFlow with a termination condition, the second task
execution would immediately terminate without running any agents. The
first task would run successfully, but subsequent tasks would skip all
agents and go directly to the stop agent.
This was demonstrated by the following issue:
```python
# First task runs correctly
result1 = await team.run(task="First task") # ✅ Works fine
# Second task fails immediately
result2 = await team.run(task="Second task") # ❌ Only user + stop messages
```
## Root Cause
The `GraphFlowManager` was not resetting its execution state when
termination occurred. After the first task completed:
1. The `_ready` queue was empty (all nodes had been processed)
2. The `_remaining` and `_enqueued_any` tracking structures remained in
"completed" state
3. The `_message_thread` retained history from the previous task
This left the graph in a "completed" state, causing subsequent tasks to
immediately trigger the stop agent instead of executing the workflow.
## Solution
Added an override of the `_apply_termination_condition` method in
`GraphFlowManager` to automatically reset the graph execution state when
termination occurs:
```python
async def _apply_termination_condition(
    self, delta: Sequence[BaseAgentEvent | BaseChatMessage], increment_turn_count: bool = False
) -> bool:
    # Call the base implementation first
    terminated = await super()._apply_termination_condition(delta, increment_turn_count)
    # If terminated, reset the graph execution state and message thread for the next task
    if terminated:
        self._remaining = {target: Counter(groups) for target, groups in self._graph.get_remaining_map().items()}
        self._enqueued_any = {n: {g: False for g in self._enqueued_any[n]} for n in self._enqueued_any}
        self._ready = deque([n for n in self._graph.get_start_nodes()])
        # Clear the message thread to start fresh for the next task
        self._message_thread.clear()
    return terminated
```
This ensures that when a task completes (termination condition is met),
the graph is automatically reset to its initial state ready for the next
task.
## Testing
Added a comprehensive test case
`test_digraph_group_chat_multiple_task_execution` that validates:
- Multiple tasks can be run sequentially without explicit reset calls
- All agents are executed the expected number of times
- Both tasks produce the correct number of messages
- The fix works with various termination conditions
(MaxMessageTermination, TextMentionTermination)
## Result
GraphFlow now works like SelectorGroupChat where multiple tasks can be
run sequentially without explicit resets between them:
```python
# Both tasks now work correctly
result1 = await team.run(task="First task") # ✅ 5 messages, all agents called
result2 = await team.run(task="Second task") # ✅ 5 messages, all agents called again
```
Fixes #6746.
> [!WARNING]
>
> <details>
> <summary>Firewall rules blocked me from connecting to one or more
addresses</summary>
>
> #### I tried to connect to the following addresses, but was blocked by
firewall rules:
>
> - `esm.ubuntu.com`
> - Triggering command: `/usr/lib/apt/methods/https` (dns block)
>
> If you need me to access, download, or install something from one of
these locations, you can either:
>
> - Configure [Actions setup
steps](https://gh.io/copilot/actions-setup-steps) to set up my
environment, which run before the firewall is enabled
> - Add the appropriate URLs or hosts to my [firewall allow
list](https://gh.io/copilot/firewall-config)
>
> </details>
<!-- START COPILOT CODING AGENT TIPS -->
---
💬 Share your feedback on Copilot coding agent for the chance to win a
$200 gift card! Click
[here](https://survey.alchemer.com/s3/8343779/Copilot-Coding-agent) to
start the survey.
---------
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: ekzhu <320302+ekzhu@users.noreply.github.com>
Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
2025-07-05 23:32:50 -07:00
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_digraph_group_chat_multiple_task_execution(runtime: AgentRuntime | None) -> None:
    """Run two tasks back-to-back on one A -> B -> C GraphFlow without calling reset().

    Each run must produce the user message, one reply each from A, B and C, and the
    stop agent's message. Afterwards every agent must have handled two messages.
    """
    agent_a = _EchoAgent("A", description="Echo agent A")
    agent_b = _EchoAgent("B", description="Echo agent B")
    agent_c = _EchoAgent("C", description="Echo agent C")

    # Linear chain A -> B -> C; C has no outgoing edges.
    graph = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]),
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="C")]),
            "C": DiGraphNode(name="C", edges=[]),
        }
    )

    team = GraphFlow(
        participants=[agent_a, agent_b, agent_c],
        graph=graph,
        runtime=runtime,
        termination_condition=MaxMessageTermination(5),
    )

    expected_speakers = ["A", "B", "C", _DIGRAPH_STOP_AGENT_NAME]

    # No reset() between iterations: the second task must still traverse the full chain.
    for task in ("First task", "Second task"):
        result: TaskResult = await team.run(task=task)

        assert len(result.messages) == 5
        opening = result.messages[0]
        assert isinstance(opening, TextMessage)
        assert opening.source == "user"
        assert opening.content == task
        assert [message.source for message in result.messages[1:]] == expected_speakers
        assert result.stop_reason is not None

    # Each agent handled exactly one message per task.
    assert [agent.total_messages for agent in (agent_a, agent_b, agent_c)] == [2, 2, 2]
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_digraph_group_chat_resume_with_termination_condition(runtime: AgentRuntime | None) -> None:
    """Stop an A -> B -> C flow early via MaxMessageTermination, then resume it with a task-less run().

    The resumed run must start at the pending node C instead of restarting at A,
    so A and B are not invoked again.
    """
    agent_a = _EchoAgent("A", description="Echo agent A")
    agent_b = _EchoAgent("B", description="Echo agent B")
    agent_c = _EchoAgent("C", description="Echo agent C")

    def speaker_counts() -> tuple[int, int, int]:
        # Snapshot of how many messages each agent has handled so far.
        return agent_a.total_messages, agent_b.total_messages, agent_c.total_messages

    # Linear chain A -> B -> C; C has no outgoing edges.
    graph = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]),
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="C")]),
            "C": DiGraphNode(name="C", edges=[]),
        }
    )

    # MaxMessageTermination(3) admits only user + A + B, leaving C pending.
    team = GraphFlow(
        participants=[agent_a, agent_b, agent_c],
        graph=graph,
        runtime=runtime,
        termination_condition=MaxMessageTermination(3),
    )

    first_run: TaskResult = await team.run(task="Start execution")
    assert [message.source for message in first_run.messages] == ["user", "A", "B"]
    assert first_run.stop_reason is not None
    assert speaker_counts() == (1, 1, 0)

    # No task: continue from where the first run stopped instead of starting over.
    second_run: TaskResult = await team.run()
    assert [message.source for message in second_run.messages] == ["C", _DIGRAPH_STOP_AGENT_NAME]
    assert second_run.stop_reason is not None
    assert speaker_counts() == (1, 1, 1)
|