mirror of
https://github.com/microsoft/autogen.git
synced 2025-06-26 22:30:10 +00:00

## Why are these changes needed? 1. Problem: When the GraphFlowManager encounters cycles, it tracks the remaining indegree counts for each node's activation. However, this tracking mechanism has a flaw when dealing with cycles. When a node first enters a cycle, the GraphFlowManager evaluates all remaining incoming edges, including those that loop back to the origin node. If the activation prerequisites are not satisfied at that moment, the workflow will eventually finish because the _remaining counter never reaches zero, preventing the select_speaker() method from selecting any agents for execution. 2. Solution: Change the activation map to two layers to distinguish the remaining counts inside different cycles from those outside the cycle. Add an activation group and an activation policy property to each edge, compute the remaining map when the GraphFlowManager is initialized, and check the remaining map by activation group to avoid checking the loop-back edges. <!-- Please give a short summary of the change and the problem this solves. --> ## Related issue number #6710 ## Checks - [x] I've included any doc changes needed for <https://microsoft.github.io/autogen/>. See <https://github.com/microsoft/autogen/blob/main/CONTRIBUTING.md> to build and test documentation locally. - [x] I've added tests (if relevant) corresponding to the changes introduced in this PR. - [x] I've made sure all auto checks have passed.
1629 lines
57 KiB
Python
1629 lines
57 KiB
Python
import asyncio
|
|
import re
|
|
from typing import AsyncGenerator, List, Sequence
|
|
from unittest.mock import patch
|
|
|
|
import pytest
|
|
import pytest_asyncio
|
|
from autogen_agentchat.agents import (
|
|
AssistantAgent,
|
|
BaseChatAgent,
|
|
MessageFilterAgent,
|
|
MessageFilterConfig,
|
|
PerSourceFilter,
|
|
)
|
|
from autogen_agentchat.base import Response, TaskResult
|
|
from autogen_agentchat.conditions import MaxMessageTermination, SourceMatchTermination
|
|
from autogen_agentchat.messages import BaseChatMessage, ChatMessage, MessageFactory, StopMessage, TextMessage
|
|
from autogen_agentchat.teams import (
|
|
DiGraphBuilder,
|
|
GraphFlow,
|
|
)
|
|
from autogen_agentchat.teams._group_chat._events import ( # type: ignore[attr-defined]
|
|
BaseAgentEvent,
|
|
GroupChatTermination,
|
|
)
|
|
from autogen_agentchat.teams._group_chat._graph._digraph_group_chat import (
|
|
_DIGRAPH_STOP_AGENT_NAME, # pyright: ignore[reportPrivateUsage]
|
|
DiGraph,
|
|
DiGraphEdge,
|
|
DiGraphNode,
|
|
GraphFlowManager,
|
|
)
|
|
from autogen_core import AgentRuntime, CancellationToken, Component, SingleThreadedAgentRuntime
|
|
from autogen_ext.models.replay import ReplayChatCompletionClient
|
|
from pydantic import BaseModel
|
|
from utils import compare_message_lists, compare_task_results
|
|
|
|
|
|
def test_create_digraph() -> None:
    """Build the chain A -> B -> C and check node membership and out-degree."""
    node_a = DiGraphNode(name="A", edges=[DiGraphEdge(target="B")])
    node_b = DiGraphNode(name="B", edges=[DiGraphEdge(target="C")])
    node_c = DiGraphNode(name="C", edges=[])
    chain = DiGraph(nodes={"A": node_a, "B": node_b, "C": node_c})

    # Every declared node is registered under its own name.
    for label in ("A", "B", "C"):
        assert label in chain.nodes

    # A and B each have one outgoing edge; C has none.
    expected_edge_counts = {"A": 1, "B": 1, "C": 0}
    for label, count in expected_edge_counts.items():
        assert len(chain.nodes[label].edges) == count
|
|
|
|
|
|
def test_get_parents() -> None:
    """Check that get_parents() lists the predecessors of each node in A -> B -> C."""
    chain = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]),
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="C")]),
            "C": DiGraphNode(name="C", edges=[]),
        }
    )

    parent_map = chain.get_parents()

    # A has no predecessor; B and C each have exactly one.
    expected_parents = {"A": [], "B": ["A"], "C": ["B"]}
    for label, predecessors in expected_parents.items():
        assert parent_map[label] == predecessors
|
|
|
|
|
|
def test_get_start_nodes() -> None:
    """Check that get_start_nodes() returns only A for the chain A -> B -> C."""
    chain = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]),
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="C")]),
            "C": DiGraphNode(name="C", edges=[]),
        }
    )

    # A is the only node that no edge points at.
    assert chain.get_start_nodes() == {"A"}
|
|
|
|
|
|
def test_get_leaf_nodes() -> None:
    """Check that get_leaf_nodes() returns only C for the chain A -> B -> C."""
    chain = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]),
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="C")]),
            "C": DiGraphNode(name="C", edges=[]),
        }
    )

    # C is the only node with an empty edge list.
    assert chain.get_leaf_nodes() == {"C"}
|
|
|
|
|
|
def test_serialization() -> None:
    """Round-trip a graph through JSON and check that targets and the string condition survive."""
    # The conditional edge uses a string condition rather than a lambda.
    original = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B", condition="trigger1")]),
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="C")]),
            "C": DiGraphNode(name="C", edges=[]),
        }
    )

    payload = original.model_dump_json()
    restored = DiGraph.model_validate_json(payload)

    edge_from_a = restored.nodes["A"].edges[0]
    assert edge_from_a.target == "B"
    assert edge_from_a.condition == "trigger1"
    assert restored.nodes["B"].edges[0].target == "C"

    # Sanity check done by hand: the condition keyword appears in the text
    # a matching message would expose to the model.
    sample = TextMessage(content="this has trigger1 in it", source="test")
    assert "trigger1" in sample.to_model_text()
|
|
|
|
|
|
def test_invalid_graph_no_start_node() -> None:
    """A two-node cycle B <-> C yields an empty result from get_start_nodes()."""
    cyclic = DiGraph(
        nodes={
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="C")]),
            "C": DiGraphNode(name="C", edges=[DiGraphEdge(target="B")]),  # closes the cycle
        }
    )

    # Both nodes are targets of an edge, so neither qualifies as a start node.
    entry_points = cyclic.get_start_nodes()
    assert len(entry_points) == 0
|
|
|
|
|
|
def test_invalid_graph_no_leaf_node() -> None:
    """A three-node ring A -> B -> C -> A yields an empty result from get_leaf_nodes()."""
    ring = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]),
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="C")]),
            "C": DiGraphNode(name="C", edges=[DiGraphEdge(target="A")]),  # edge back to A
        }
    )

    # Every node has an outgoing edge, so there is no endpoint.
    endpoints = ring.get_leaf_nodes()
    assert len(endpoints) == 0
|
|
|
|
|
|
def test_condition_edge_execution() -> None:
    """Check that a string edge condition is stored and that unconditioned edges report None."""
    # The A -> B edge carries a string condition; B -> C is unconditional.
    conditional_graph = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B", condition="TRIGGER")]),
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="C")]),
            "C": DiGraphNode(name="C", edges=[]),
        }
    )

    # Hand-check the keyword against the model text of two sample messages.
    matching = TextMessage(content="This has TRIGGER in it", source="test")
    not_matching = TextMessage(content="This doesn't match", source="test")
    assert "TRIGGER" in matching.to_model_text()
    assert "TRIGGER" not in not_matching.to_model_text()

    # The condition attribute reflects what each edge was constructed with.
    edge_a_to_b = conditional_graph.nodes["A"].edges[0]
    edge_b_to_c = conditional_graph.nodes["B"].edges[0]
    assert edge_a_to_b.condition == "TRIGGER"
    assert edge_b_to_c.condition is None
|
|
|
|
|
|
def test_graph_with_multiple_paths() -> None:
    """Diamond graph A -> {B, C} -> D: check parents, start nodes and leaf nodes."""
    diamond = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B"), DiGraphEdge(target="C")]),
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="D")]),
            "C": DiGraphNode(name="C", edges=[DiGraphEdge(target="D")]),
            "D": DiGraphNode(name="D", edges=[]),
        }
    )

    # B and C share the parent A; D lists both branches as parents.
    parent_map = diamond.get_parents()
    assert parent_map["B"] == ["A"]
    assert parent_map["C"] == ["A"]
    assert parent_map["D"] == ["B", "C"]

    assert diamond.get_start_nodes() == {"A"}
    assert diamond.get_leaf_nodes() == {"D"}
|
|
|
|
|
|
def test_cycle_detection_no_cycle() -> None:
    """has_cycles_with_exit() is falsy for the acyclic chain A -> B -> C."""
    acyclic = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]),
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="C")]),
            "C": DiGraphNode(name="C", edges=[]),
        }
    )
    cycle_found = acyclic.has_cycles_with_exit()
    assert not cycle_found
|
|
|
|
|
|
def test_cycle_detection_with_exit_condition() -> None:
    """has_cycles_with_exit() is truthy when the edge closing the cycle carries a condition."""
    # Variant 1: the C -> A edge is guarded by a string condition.
    string_guarded = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]),
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="C")]),
            "C": DiGraphNode(name="C", edges=[DiGraphEdge(target="A", condition="exit")]),
        }
    )
    assert string_guarded.has_cycles_with_exit()

    # Variant 2: the same cycle, guarded by a callable condition instead.
    closing_edge = DiGraphEdge(target="A", condition=lambda msg: "test" in msg.to_model_text())
    graph_with_lambda = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]),
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="C")]),
            "C": DiGraphNode(name="C", edges=[closing_edge]),
        }
    )
    assert graph_with_lambda.has_cycles_with_exit()
|
|
|
|
|
|
def test_cycle_detection_without_exit_condition() -> None:
    """has_cycles_with_exit() raises ValueError for an unconditional cycle A -> B -> C -> A."""
    nodes = {
        "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]),
        "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="C")]),
        "C": DiGraphNode(name="C", edges=[DiGraphEdge(target="A")]),  # unconditional edge closing the cycle
        # D -> E is a separate acyclic component.
        "D": DiGraphNode(name="D", edges=[DiGraphEdge(target="E")]),
        "E": DiGraphNode(name="E", edges=[]),
    }
    unguarded = DiGraph(nodes=nodes)

    with pytest.raises(ValueError, match="Cycle detected without exit condition: A -> B -> C -> A"):
        unguarded.has_cycles_with_exit()
|
|
|
|
|
|
def test_different_activation_groups_detection() -> None:
    """graph_validate() rejects edges into one target that disagree on activation_condition."""
    # B -> D asks for "all" while C -> D asks for "any"; neither edge names an
    # activation group explicitly.
    conflicting = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B"), DiGraphEdge(target="C")]),
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="D", activation_condition="all")]),
            "C": DiGraphNode(name="C", edges=[DiGraphEdge(target="D", activation_condition="any")]),
            "D": DiGraphNode(name="D", edges=[]),
        }
    )

    expected_error = (
        "Conflicting activation conditions for target 'D' group 'D': "
        "'all' (from node 'B') and 'any' (from node 'C')"
    )
    # re.escape: the message contains parentheses, which are regex metacharacters.
    with pytest.raises(ValueError, match=re.escape(expected_error)):
        conflicting.graph_validate()
|
|
|
|
|
|
def test_validate_graph_success() -> None:
    """graph_validate() accepts a two-node chain, with or without an edge condition."""
    # Unconditional edge A -> B.
    plain = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]),
            "B": DiGraphNode(name="B", edges=[]),
        }
    )
    plain.graph_validate()  # must not raise
    assert not plain.get_has_cycles()

    # Same shape, but the edge is guarded by a callable condition.
    guarded_edge = DiGraphEdge(target="B", condition=lambda msg: "test" in msg.to_model_text())
    graph_with_lambda = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[guarded_edge]),
            "B": DiGraphNode(name="B", edges=[]),
        }
    )
    graph_with_lambda.graph_validate()  # must not raise
    assert not graph_with_lambda.get_has_cycles()
|
|
|
|
|
|
def test_validate_graph_missing_start_node() -> None:
    """graph_validate() raises ValueError when the graph is a bare cycle A <-> B."""
    two_cycle = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]),
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="A")]),  # edge back to A
        }
    )
    with pytest.raises(ValueError, match="Graph must have at least one start node"):
        two_cycle.graph_validate()
|
|
|
|
|
|
def test_validate_graph_missing_leaf_node() -> None:
    """graph_validate() raises ValueError when every node has an outgoing edge."""
    # A is a start node, but B <-> C form a cycle, so nothing terminates the graph.
    no_leaf = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]),
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="C")]),
            "C": DiGraphNode(name="C", edges=[DiGraphEdge(target="B")]),
        }
    )
    with pytest.raises(ValueError, match="Graph must have at least one leaf node"):
        no_leaf.graph_validate()
|
|
|
|
|
|
def test_validate_graph_mixed_conditions() -> None:
    """graph_validate() rejects a node that mixes conditional and unconditional edges."""
    # Variant 1: the conditional edge uses a string condition.
    mixed_string = DiGraph(
        nodes={
            "A": DiGraphNode(
                name="A",
                edges=[DiGraphEdge(target="B", condition="cond"), DiGraphEdge(target="C")],
            ),
            "B": DiGraphNode(name="B", edges=[]),
            "C": DiGraphNode(name="C", edges=[]),
        }
    )
    with pytest.raises(ValueError, match="Node 'A' has a mix of conditional and unconditional edges"):
        mixed_string.graph_validate()

    # Variant 2: the conditional edge uses a callable condition.
    callable_edge = DiGraphEdge(target="B", condition=lambda msg: "test" in msg.to_model_text())
    unconditional_edge = DiGraphEdge(target="C")
    graph_with_lambda = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[callable_edge, unconditional_edge]),
            "B": DiGraphNode(name="B", edges=[]),
            "C": DiGraphNode(name="C", edges=[]),
        }
    )
    with pytest.raises(ValueError, match="Node 'A' has a mix of conditional and unconditional edges"):
        graph_with_lambda.graph_validate()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_invalid_digraph_manager_cycle_without_termination() -> None:
    """GraphFlowManager.__init__ raises ValueError for the start-less cycle A -> B -> A."""
    two_cycle = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]),
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="A")]),
        }
    )

    events: asyncio.Queue[BaseAgentEvent | BaseChatMessage | GroupChatTermination] = asyncio.Queue()

    # The base-class initializer is stubbed out, so only GraphFlowManager's own
    # __init__ logic runs against the graph.
    base_init = "autogen_agentchat.teams._group_chat._base_group_chat_manager.BaseGroupChatManager.__init__"
    with patch(base_init, return_value=None):
        # Allocate without initializing so __init__ can be invoked explicitly below.
        manager = GraphFlowManager.__new__(GraphFlowManager)

        with pytest.raises(ValueError, match="Graph must have at least one start node"):
            manager.__init__(  # type: ignore[misc]
                name="test_manager",
                group_topic_type="topic",
                output_topic_type="topic",
                participant_topic_types=["topic1", "topic2"],
                participant_names=["A", "B"],
                participant_descriptions=["Agent A", "Agent B"],
                output_message_queue=events,
                termination_condition=None,
                max_turns=None,
                message_factory=MessageFactory(),
                graph=two_cycle,
            )
|
|
|
|
|
|
class _EchoAgent(BaseChatAgent):
    """Test agent that echoes the first incoming message, or replays the last one it echoed.

    It also counts how many responses it has produced.
    """

    def __init__(self, name: str, description: str) -> None:
        super().__init__(name, description)
        self._last_message: str | None = None  # content of the most recently echoed message
        self._total_messages = 0  # number of responses produced so far

    @property
    def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
        """This agent only ever emits TextMessage."""
        return (TextMessage,)

    @property
    def total_messages(self) -> int:
        """Number of responses produced by on_messages()."""
        return self._total_messages

    async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
        """Echo the first message's content; with no input, replay the remembered content.

        Asserts that the first message is a TextMessage, or, when ``messages``
        is empty, that a previous message has been remembered.
        """
        if messages:
            first = messages[0]
            assert isinstance(first, TextMessage)
            self._last_message = first.content
        else:
            assert self._last_message is not None
        self._total_messages += 1
        return Response(chat_message=TextMessage(content=self._last_message, source=self.name))

    async def on_reset(self, cancellation_token: CancellationToken) -> None:
        """Forget the remembered message; the response counter is left untouched."""
        self._last_message = None
|
|
|
|
|
|
@pytest_asyncio.fixture(params=["single_threaded", "embedded"])  # type: ignore
async def runtime(request: pytest.FixtureRequest) -> AsyncGenerator[AgentRuntime | None, None]:
    """Parametrized fixture: yields a started SingleThreadedAgentRuntime, or None for "embedded".

    The single-threaded runtime is stopped once the dependent test finishes.
    """
    mode = request.param
    if mode == "embedded":
        yield None
    elif mode == "single_threaded":
        started = SingleThreadedAgentRuntime()
        started.start()
        yield started
        await started.stop()
|
|
|
|
|
|
# Alias for a task input: a plain string, a single ChatMessage, or a list of ChatMessage.
TaskType = str | List[ChatMessage] | ChatMessage
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_digraph_group_chat_sequential_execution(runtime: AgentRuntime | None) -> None:
    """Run the chain A -> B -> C and expect user, A, B, C, then the stop agent."""
    agent_a = _EchoAgent("A", description="Echo agent A")
    agent_b = _EchoAgent("B", description="Echo agent B")
    agent_c = _EchoAgent("C", description="Echo agent C")

    chain = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]),
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="C")]),
            "C": DiGraphNode(name="C", edges=[]),
        }
    )

    team = GraphFlow(
        participants=[agent_a, agent_b, agent_c],
        graph=chain,
        runtime=runtime,
        termination_condition=MaxMessageTermination(5),
    )

    result: TaskResult = await team.run(task="Hello from User")
    transcript = result.messages

    assert len(transcript) == 5
    assert isinstance(transcript[0], TextMessage)
    # Sources appear strictly in graph order, followed by the stop agent.
    expected_order = ["user", "A", "B", "C", _DIGRAPH_STOP_AGENT_NAME]
    for position, expected_source in enumerate(expected_order):
        assert transcript[position].source == expected_source
    # Everything except the final stop message is a TextMessage.
    assert all(isinstance(m, TextMessage) for m in transcript[:-1])
    assert result.stop_reason is not None
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_digraph_group_chat_parallel_fanout(runtime: AgentRuntime | None) -> None:
    """Fan out from A to B and C; both branches must run, in either order."""
    agent_a = _EchoAgent("A", description="Echo agent A")
    agent_b = _EchoAgent("B", description="Echo agent B")
    agent_c = _EchoAgent("C", description="Echo agent C")

    fanout = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B"), DiGraphEdge(target="C")]),
            "B": DiGraphNode(name="B", edges=[]),
            "C": DiGraphNode(name="C", edges=[]),
        }
    )

    team = GraphFlow(
        participants=[agent_a, agent_b, agent_c],
        graph=fanout,
        runtime=runtime,
        termination_condition=MaxMessageTermination(5),
    )

    result: TaskResult = await team.run(task="Start")
    transcript = result.messages

    assert len(transcript) == 5
    assert transcript[0].source == "user"
    assert transcript[1].source == "A"
    # The two branch messages sit between A and the stop message; order is not fixed.
    branch_sources = {m.source for m in transcript[2:-1]}
    assert branch_sources == {"B", "C"}
    assert transcript[-1].source == _DIGRAPH_STOP_AGENT_NAME
    assert result.stop_reason is not None
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_digraph_group_chat_parallel_join_all(runtime: AgentRuntime | None) -> None:
    """Two start nodes A and B join into C (activation="all"); C speaks after both."""
    agent_a = _EchoAgent("A", description="Echo agent A")
    agent_b = _EchoAgent("B", description="Echo agent B")
    agent_c = _EchoAgent("C", description="Echo agent C")

    join_graph = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="C")]),
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="C")]),
            "C": DiGraphNode(name="C", edges=[], activation="all"),
        }
    )

    team = GraphFlow(
        participants=[agent_a, agent_b, agent_c],
        graph=join_graph,
        runtime=runtime,
        termination_condition=MaxMessageTermination(5),
    )

    result: TaskResult = await team.run(task="Go")
    transcript = result.messages

    assert len(transcript) == 5
    assert transcript[0].source == "user"
    # A and B occupy slots 1 and 2 in either order; C follows them.
    assert {transcript[1].source, transcript[2].source} == {"A", "B"}
    assert transcript[3].source == "C"
    assert transcript[-1].source == _DIGRAPH_STOP_AGENT_NAME
    assert result.stop_reason is not None
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_digraph_group_chat_parallel_join_any(runtime: AgentRuntime | None) -> None:
    """Two start nodes A and B join into C (activation="any"); C runs after at least one of them."""
    agent_a = _EchoAgent("A", description="Echo agent A")
    agent_b = _EchoAgent("B", description="Echo agent B")
    agent_c = _EchoAgent("C", description="Echo agent C")

    join_graph = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="C")]),
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="C")]),
            "C": DiGraphNode(name="C", edges=[], activation="any"),
        }
    )

    team = GraphFlow(
        participants=[agent_a, agent_b, agent_c],
        graph=join_graph,
        runtime=runtime,
        termination_condition=MaxMessageTermination(5),
    )

    result: TaskResult = await team.run(task="Start")
    transcript = result.messages

    assert len(transcript) == 5
    assert transcript[0].source == "user"

    # Sources of everything after the user's task message.
    speakers = [m.source for m in transcript[1:]]

    # C is the last participant to speak; only the stop message follows it.
    assert speakers[-2] == "C"

    # Both A and B must have executed.
    assert {"A", "B"} <= set(speakers)

    # At least one of A or B precedes C.
    position_a = speakers.index("A")
    position_b = speakers.index("B")
    position_c = speakers.index("C")
    assert position_c > min(position_a, position_b)
    assert transcript[-1].source == _DIGRAPH_STOP_AGENT_NAME
    assert result.stop_reason is not None
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_digraph_group_chat_multiple_start_nodes(runtime: AgentRuntime | None) -> None:
    """Two isolated nodes A and B are both start nodes and both must speak once."""
    agent_a = _EchoAgent("A", description="Echo agent A")
    agent_b = _EchoAgent("B", description="Echo agent B")

    # No edges at all: each node is simultaneously a start node and a leaf.
    isolated = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[]),
            "B": DiGraphNode(name="B", edges=[]),
        }
    )

    team = GraphFlow(
        participants=[agent_a, agent_b],
        graph=isolated,
        runtime=runtime,
        termination_condition=MaxMessageTermination(5),
    )

    result: TaskResult = await team.run(task="Start")
    transcript = result.messages

    assert len(transcript) == 4
    assert transcript[0].source == "user"
    # A and B fill the two middle slots in either order.
    middle_sources = {m.source for m in transcript[1:-1]}
    assert middle_sources == {"A", "B"}
    assert transcript[-1].source == _DIGRAPH_STOP_AGENT_NAME
    assert result.stop_reason is not None
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_digraph_group_chat_disconnected_graph(runtime: AgentRuntime | None) -> None:
    """Two disconnected chains A -> B and C -> D run level by level."""
    agent_a = _EchoAgent("A", description="Echo agent A")
    agent_b = _EchoAgent("B", description="Echo agent B")
    agent_c = _EchoAgent("C", description="Echo agent C")
    agent_d = _EchoAgent("D", description="Echo agent D")

    two_chains = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]),
            "B": DiGraphNode(name="B", edges=[]),
            "C": DiGraphNode(name="C", edges=[DiGraphEdge(target="D")]),
            "D": DiGraphNode(name="D", edges=[]),
        }
    )

    team = GraphFlow(
        participants=[agent_a, agent_b, agent_c, agent_d],
        graph=two_chains,
        runtime=runtime,
        termination_condition=MaxMessageTermination(10),
    )

    result: TaskResult = await team.run(task="Go")
    transcript = result.messages

    assert len(transcript) == 6
    assert transcript[0].source == "user"
    # The two start nodes speak first (any order), then their two successors (any order).
    first_wave = {transcript[1].source, transcript[2].source}
    second_wave = {transcript[3].source, transcript[4].source}
    assert {"A", "C"} == first_wave
    assert {"B", "D"} == second_wave
    assert transcript[-1].source == _DIGRAPH_STOP_AGENT_NAME
    assert result.stop_reason is not None
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_digraph_group_chat_conditional_branch(runtime: AgentRuntime | None) -> None:
    """A branches to B on "yes" and to C on "no", for both string and callable conditions."""
    agent_a = _EchoAgent("A", description="Echo agent A")
    agent_b = _EchoAgent("B", description="Echo agent B")
    agent_c = _EchoAgent("C", description="Echo agent C")
    participants = [agent_a, agent_b, agent_c]

    # Variant 1: string conditions. A echoes the task text, which contains "yes".
    string_graph = DiGraph(
        nodes={
            "A": DiGraphNode(
                name="A",
                edges=[DiGraphEdge(target="B", condition="yes"), DiGraphEdge(target="C", condition="no")],
            ),
            "B": DiGraphNode(name="B", edges=[], activation="any"),
            "C": DiGraphNode(name="C", edges=[], activation="any"),
        }
    )

    team = GraphFlow(
        participants=participants,
        graph=string_graph,
        runtime=runtime,
        termination_condition=MaxMessageTermination(5),
    )

    result = await team.run(task="Trigger yes")
    assert result.messages[2].source == "B"

    # Variant 2: callable conditions. This time the task text contains "no".
    to_b = DiGraphEdge(target="B", condition=lambda msg: "yes" in msg.to_model_text())
    to_c = DiGraphEdge(target="C", condition=lambda msg: "no" in msg.to_model_text())
    graph_with_lambda = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[to_b, to_c]),
            "B": DiGraphNode(name="B", edges=[], activation="any"),
            "C": DiGraphNode(name="C", edges=[], activation="any"),
        }
    )
    team_with_lambda = GraphFlow(
        participants=participants,
        graph=graph_with_lambda,
        runtime=runtime,
        termination_condition=MaxMessageTermination(5),
    )
    result_with_lambda = await team_with_lambda.run(task="Trigger no")
    assert result_with_lambda.messages[2].source == "C"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_digraph_group_chat_loop_with_exit_condition(runtime: AgentRuntime | None) -> None:
    """Loop A -> B -> A twice, then leave through B -> C when B answers "exit"."""
    # A and C simply echo; B is driven by a scripted model client.
    agent_a = _EchoAgent("A", description="Echo agent A")
    agent_c = _EchoAgent("C", description="Echo agent C")

    scripted_replies = [
        "loop",  # B's first turn: go back to A
        "loop",  # B's second turn: go back to A
        "exit",  # B's third turn: continue to C
    ]
    model_client = ReplayChatCompletionClient(chat_completions=scripted_replies)
    agent_b = AssistantAgent("B", description="Decision agent B", model_client=model_client)

    # A -> B, then B -> C on "exit" or B -> A on "loop".
    decision_edges = [DiGraphEdge(target="C", condition="exit"), DiGraphEdge(target="A", condition="loop")]
    loop_graph = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]),
            "B": DiGraphNode(name="B", edges=decision_edges),
            "C": DiGraphNode(name="C", edges=[]),
        },
        default_start_node="A",
    )

    team = GraphFlow(
        participants=[agent_a, agent_b, agent_c],
        graph=loop_graph,
        runtime=runtime,
        termination_condition=MaxMessageTermination(20),
    )

    result = await team.run(task="Start")

    # Exact speaker order expected from the scripted replies.
    expected_sources = [
        "user",
        "A",
        "B",  # replies "loop"
        "A",
        "B",  # replies "loop"
        "A",
        "B",  # replies "exit"
        "C",
        _DIGRAPH_STOP_AGENT_NAME,
    ]
    observed_sources = [m.source for m in result.messages]

    assert observed_sources == expected_sources
    assert result.stop_reason is not None
    assert result.messages[-2].source == "C"
    assert any(m.content == "exit" for m in result.messages[:-1])  # type: ignore[attr-defined,union-attr]
    assert result.messages[-1].source == _DIGRAPH_STOP_AGENT_NAME
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_digraph_group_chat_loop_with_self_cycle(runtime: AgentRuntime | None) -> None:
    """B loops on itself twice via a self-edge, then leaves through B -> C on "exit"."""
    # A and C simply echo; B is driven by a scripted model client.
    agent_a = _EchoAgent("A", description="Echo agent A")
    agent_c = _EchoAgent("C", description="Echo agent C")

    scripted_replies = [
        "loop",  # B's first turn: run B again
        "loop",  # B's second turn: run B again
        "exit",  # B's third turn: continue to C
    ]
    model_client = ReplayChatCompletionClient(chat_completions=scripted_replies)
    agent_b = AssistantAgent("B", description="Decision agent B", model_client=model_client)

    # A -> B, then B -> C on "exit" or B -> B on "loop".
    # The self-edge is placed in its own activation group, "B_loop".
    exit_edge = DiGraphEdge(target="C", condition="exit")
    self_edge = DiGraphEdge(target="B", condition="loop", activation_group="B_loop")
    loop_graph = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]),
            "B": DiGraphNode(name="B", edges=[exit_edge, self_edge]),
            "C": DiGraphNode(name="C", edges=[]),
        },
        default_start_node="A",
    )

    team = GraphFlow(
        participants=[agent_a, agent_b, agent_c],
        graph=loop_graph,
        runtime=runtime,
        termination_condition=MaxMessageTermination(20),
    )

    result = await team.run(task="Start")

    # Exact speaker order expected from the scripted replies.
    expected_sources = [
        "user",
        "A",
        "B",  # replies "loop"
        "B",  # replies "loop"
        "B",  # replies "exit"
        "C",
        _DIGRAPH_STOP_AGENT_NAME,
    ]
    observed_sources = [m.source for m in result.messages]

    assert observed_sources == expected_sources
    assert result.stop_reason is not None
    assert result.messages[-2].source == "C"
    assert any(m.content == "exit" for m in result.messages[:-1])  # type: ignore[attr-defined,union-attr]
    assert result.messages[-1].source == _DIGRAPH_STOP_AGENT_NAME
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_digraph_group_chat_loop_with_two_cycles(runtime: AgentRuntime | None) -> None:
    """Hub node O takes part in two separate cycles (O <-> X and O <-> Y) before exiting to E.

    O is first reached through the join of B and C, then re-entered from X and
    from Y through edges that each use their own activation group.
    """
    # A, B, C and E are echo agents.
    agent_a = _EchoAgent("A", description="Echo agent A")
    agent_b = _EchoAgent("B", description="Echo agent B")
    agent_c = _EchoAgent("C", description="Echo agent C")
    agent_e = _EchoAgent("E", description="Echo agent E")

    # One replay client is shared by O, X and Y, so the scripted replies are
    # consumed in the order those agents take their turns.
    model_client = ReplayChatCompletionClient(
        chat_completions=[
            "to_x",  # O, first turn: branch to X
            "to_o",  # X: go back to O
            "to_y",  # O, second turn: branch to Y
            "to_o",  # Y: go back to O
            "exit",  # O, third turn: exit to E
        ]
    )
    # O, X and Y are assistant agents driven by the replay client.
    agent_o = AssistantAgent("O", description="Decision agent o", model_client=model_client)
    agent_x = AssistantAgent("X", description="Decision agent x", model_client=model_client)
    agent_y = AssistantAgent("Y", description="Decision agent y", model_client=model_client)

    # DiGraph:
    #
    #        A
    #       / \
    #      B   C
    #       \ /
    #   X = O = Y   (X <-> O and Y <-> O are bidirectional)
    #       |
    #       E (exit)
    graph = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B"), DiGraphEdge(target="C")]),
            "B": DiGraphNode(
                name="B", edges=[DiGraphEdge(target="O")]
            ),  # default activation group name is same as target node name "O"
            "C": DiGraphNode(
                name="C", edges=[DiGraphEdge(target="O")]
            ),  # default activation group name is same as target node name "O"
            "O": DiGraphNode(
                name="O",
                edges=[
                    DiGraphEdge(target="X", condition="to_x"),
                    DiGraphEdge(target="Y", condition="to_y"),
                    DiGraphEdge(target="E", condition="exit"),
                ],
            ),
            # Each loop-back edge into O gets its own activation group.
            "X": DiGraphNode(name="X", edges=[DiGraphEdge(target="O", condition="to_o", activation_group="x_o_loop")]),
            "Y": DiGraphNode(name="Y", edges=[DiGraphEdge(target="O", condition="to_o", activation_group="y_o_loop")]),
            "E": DiGraphNode(name="E", edges=[]),
        },
        default_start_node="A",
    )

    team = GraphFlow(
        participants=[agent_a, agent_o, agent_b, agent_c, agent_x, agent_y, agent_e],
        graph=graph,
        runtime=runtime,
        termination_condition=MaxMessageTermination(20),
    )

    # Run
    result = await team.run(task="Start")

    # Assert message order
    expected_sources = [
        "user",
        "A",
        "B",
        "C",
        "O",
        "X",  # O -> X
        "O",  # X -> O
        "Y",  # O -> Y
        "O",  # Y -> O
        "E",  # O -> E
        _DIGRAPH_STOP_AGENT_NAME,
    ]

    actual_sources = [m.source for m in result.messages]

    assert actual_sources == expected_sources
    assert result.stop_reason is not None
    assert result.messages[-2].source == "E"
    assert any(m.content == "exit" for m in result.messages[:-1])  # type: ignore[attr-defined,union-attr]
    assert result.messages[-1].source == _DIGRAPH_STOP_AGENT_NAME
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_digraph_group_chat_parallel_join_any_1(runtime: AgentRuntime | None) -> None:
    """A fans out to B and C, which both feed D through one shared activation group.

    D must run exactly once, and only after both B and C have spoken.
    """
    agent_a = _EchoAgent("A", description="Echo agent A")
    agent_b = _EchoAgent("B", description="Echo agent B")
    agent_c = _EchoAgent("C", description="Echo agent C")
    agent_d = _EchoAgent("D", description="Echo agent D")

    # NOTE(review): "any" here is the activation *group name*; elsewhere in this
    # file activation_condition is a separate keyword. Confirm the name is intentional.
    graph = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B"), DiGraphEdge(target="C")]),
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="D", activation_group="any")]),
            "C": DiGraphNode(name="C", edges=[DiGraphEdge(target="D", activation_group="any")]),
            "D": DiGraphNode(name="D", edges=[]),
        }
    )

    team = GraphFlow(
        participants=[agent_a, agent_b, agent_c, agent_d],
        graph=graph,
        runtime=runtime,
        termination_condition=MaxMessageTermination(10),
    )

    result = await team.run(task="Run parallel join")
    # Only TextMessage sources are considered.
    sequence = [msg.source for msg in result.messages if isinstance(msg, TextMessage)]
    assert sequence[0] == "user"

    # B and C should both run.
    assert "B" in sequence
    assert "C" in sequence

    # D should be triggered exactly once.
    assert sequence.count("D") == 1

    # That single D turn must come after both B and C.
    d_index = sequence.index("D")
    assert d_index > sequence.index("B")
    assert d_index > sequence.index("C")
    assert result.stop_reason is not None
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_digraph_group_chat_chained_parallel_join_any(runtime: AgentRuntime | None) -> None:
    """Fan out A -> (B, C), join into D, then chain D -> E, with D and E declared ``activation="any"``.

    The assertions require D and E to each speak exactly once: D after both B
    and C, and E after D.
    """
    agent_a = _EchoAgent("A", description="Echo agent A")
    agent_b = _EchoAgent("B", description="Echo agent B")
    agent_c = _EchoAgent("C", description="Echo agent C")
    agent_d = _EchoAgent("D", description="Echo agent D")
    agent_e = _EchoAgent("E", description="Echo agent E")

    graph = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B"), DiGraphEdge(target="C")]),
            "B": DiGraphNode(name="B", edges=[DiGraphEdge(target="D")]),
            "C": DiGraphNode(name="C", edges=[DiGraphEdge(target="D")]),
            "D": DiGraphNode(name="D", edges=[DiGraphEdge(target="E")], activation="any"),
            "E": DiGraphNode(name="E", edges=[], activation="any"),
        }
    )

    team = GraphFlow(
        participants=[agent_a, agent_b, agent_c, agent_d, agent_e],
        graph=graph,
        runtime=runtime,
        termination_condition=MaxMessageTermination(20),
    )

    result = await team.run(task="Run chained parallel join-any")

    sequence = [msg.source for msg in result.messages if isinstance(msg, TextMessage)]

    # D must run exactly once, not once per incoming branch.
    d_indices = [i for i, s in enumerate(sequence) if s == "D"]
    assert len(d_indices) == 1

    # That single D turn must come after both B and C.
    d_index = d_indices[0]
    b_index = sequence.index("B")
    c_index = sequence.index("C")
    assert d_index > b_index
    assert d_index > c_index

    # E must likewise run exactly once, and only after D.
    e_indices = [i for i, s in enumerate(sequence) if s == "E"]
    assert len(e_indices) == 1
    assert e_indices[0] > d_index
    assert result.stop_reason is not None
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_digraph_group_chat_multiple_conditional(runtime: AgentRuntime | None) -> None:
    """Check that A hands off to the branch whose condition matches, for string and callable conditions."""
    echo_a = _EchoAgent("A", description="Echo agent A")
    echo_b = _EchoAgent("B", description="Echo agent B")
    echo_c = _EchoAgent("C", description="Echo agent C")
    echo_d = _EchoAgent("D", description="Echo agent D")

    # Variant 1: conditions given as plain strings.
    keyword_edges = [
        DiGraphEdge(target="B", condition="apple"),
        DiGraphEdge(target="C", condition="banana"),
        DiGraphEdge(target="D", condition="cherry"),
    ]
    keyword_graph = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=keyword_edges),
            "B": DiGraphNode(name="B", edges=[]),
            "C": DiGraphNode(name="C", edges=[]),
            "D": DiGraphNode(name="D", edges=[]),
        }
    )
    keyword_team = GraphFlow(
        participants=[echo_a, echo_b, echo_c, echo_d],
        graph=keyword_graph,
        runtime=runtime,
        termination_condition=MaxMessageTermination(5),
    )

    # A task of "banana" is expected to route to C as the third message.
    keyword_result = await keyword_team.run(task="banana")
    assert keyword_result.messages[2].source == "C"

    # Variant 2: the same routing expressed with callables over the message text.
    callable_edges = [
        DiGraphEdge(target="B", condition=lambda msg: "apple" in msg.to_model_text()),
        DiGraphEdge(target="C", condition=lambda msg: "banana" in msg.to_model_text()),
        DiGraphEdge(target="D", condition=lambda msg: "cherry" in msg.to_model_text()),
    ]
    callable_graph = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=callable_edges),
            "B": DiGraphNode(name="B", edges=[]),
            "C": DiGraphNode(name="C", edges=[]),
            "D": DiGraphNode(name="D", edges=[]),
        }
    )
    callable_team = GraphFlow(
        participants=[echo_a, echo_b, echo_c, echo_d],
        graph=callable_graph,
        runtime=runtime,
        termination_condition=MaxMessageTermination(5),
    )

    # A task of "cherry" is expected to route to D as the third message.
    callable_result = await callable_team.run(task="cherry")
    assert callable_result.messages[2].source == "D"
|
|
|
|
|
|
class _TestMessageFilterAgentConfig(BaseModel):
    # Configuration model used as the component config schema of
    # _TestMessageFilterAgent (see its _to_config / _from_config).

    # Agent name; required.
    name: str
    # Agent description; defaults to the same text as the agent's constructor.
    description: str = "Echo test agent"
|
|
|
|
|
|
class _TestMessageFilterAgent(BaseChatAgent, Component[_TestMessageFilterAgentConfig]):
    # Test double: remembers every message handed to on_messages and always
    # replies with a TextMessage whose content is "ACK".
    component_config_schema = _TestMessageFilterAgentConfig
    component_provider_override = "test_group_chat_graph._TestMessageFilterAgent"

    def __init__(self, name: str, description: str = "Echo test agent") -> None:
        super().__init__(name=name, description=description)
        # Messages seen so far, in arrival order; emptied by on_reset.
        self.received_messages: list[BaseChatMessage] = []

    @property
    def produced_message_types(self) -> Sequence[type[BaseChatMessage]]:
        # This agent only ever emits TextMessage.
        return (TextMessage,)

    async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
        # Record the incoming batch, then acknowledge it.
        for incoming in messages:
            self.received_messages.append(incoming)
        acknowledgement = TextMessage(content="ACK", source=self.name)
        return Response(chat_message=acknowledgement)

    async def on_reset(self, cancellation_token: CancellationToken) -> None:
        # Forget everything recorded so far (same list object is kept).
        del self.received_messages[:]

    def _to_config(self) -> _TestMessageFilterAgentConfig:
        # Serialize only the constructor arguments.
        snapshot = _TestMessageFilterAgentConfig(name=self.name, description=self.description)
        return snapshot

    @classmethod
    def _from_config(cls, config: _TestMessageFilterAgentConfig) -> "_TestMessageFilterAgent":
        # Rebuild an agent from a config produced by _to_config.
        rebuilt = cls(name=config.name, description=config.description)
        return rebuilt
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_message_filter_agent_empty_filter_blocks_all() -> None:
    """With an empty ``per_source`` list, the wrapped agent must receive no messages."""
    recorder = _TestMessageFilterAgent("inner")
    no_sources = MessageFilterConfig(per_source=[])
    gate = MessageFilterAgent(name="wrapper", wrapped_agent=recorder, filter=no_sources)

    batch = [
        TextMessage(source="user", content="Hello"),
        TextMessage(source="system", content="System msg"),
    ]
    await gate.on_messages(batch, CancellationToken())

    # Nothing may have reached the wrapped agent.
    assert not recorder.received_messages
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_message_filter_agent_with_position_none_gets_all() -> None:
    """A ``user`` filter with ``position=None, count=None`` passes every user message and nothing else."""
    recorder = _TestMessageFilterAgent("inner")
    all_user_messages = MessageFilterConfig(
        per_source=[PerSourceFilter(source="user", position=None, count=None)],
    )
    gate = MessageFilterAgent(name="wrapper", wrapped_agent=recorder, filter=all_user_messages)

    batch = [
        TextMessage(source="user", content="A"),
        TextMessage(source="user", content="B"),
        TextMessage(source="system", content="Ignore this"),
    ]
    await gate.on_messages(batch, CancellationToken())

    # Both user messages arrive; the system message is dropped.
    forwarded = recorder.received_messages
    assert len(forwarded) == 2
    assert {m.content for m in forwarded} == {"A", "B"}  # type: ignore[attr-defined]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_digraph_group_chat() -> None:
    """Round-trip a MessageFilterAgent through dump_component/load_component and re-check filtering.

    NOTE(review): despite its name, this test builds no graph; it only
    exercises component serialization of ``MessageFilterAgent``.
    """
    recorder = _TestMessageFilterAgent("agent")
    filter_config = MessageFilterConfig(
        per_source=[
            PerSourceFilter(source="user", position="last", count=2),
            PerSourceFilter(source="system", position="first", count=1),
        ]
    )
    original = MessageFilterAgent(name="agent", wrapped_agent=recorder, filter=filter_config)

    # Serialize, then rebuild from the serialized form.
    restored = MessageFilterAgent.load_component(original.dump_component())
    assert restored.name == "agent"
    assert restored._filter == original._filter  # pyright: ignore[reportPrivateUsage]
    assert restored._wrapped_agent.name == original._wrapped_agent.name  # pyright: ignore[reportPrivateUsage]

    # The restored agent must still filter: last two user messages plus the first system message.
    batch = [
        TextMessage(source=origin, content=text)
        for origin, text in (
            ("user", "u1"),
            ("user", "u2"),
            ("user", "u3"),
            ("system", "s1"),
            ("system", "s2"),
        )
    ]
    await restored.on_messages(batch, CancellationToken())
    seen = restored._wrapped_agent.received_messages  # type: ignore[attr-defined]
    assert {m.content for m in seen} == {"u2", "u3", "s1"}  # pyright: ignore[reportUnknownVariableType,reportUnknownMemberType]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_message_filter_agent_in_digraph_group_chat(runtime: AgentRuntime | None) -> None:
    """Run a single-node GraphFlow whose only participant is a MessageFilterAgent."""
    recorder = _TestMessageFilterAgent("filtered")
    last_user_only = MessageFilterConfig(
        per_source=[PerSourceFilter(source="user", position="last", count=1)],
    )
    filtered = MessageFilterAgent(name="filtered", wrapped_agent=recorder, filter=last_user_only)

    # One node, no edges.
    single_node_graph = DiGraph(nodes={"filtered": DiGraphNode(name="filtered", edges=[])})

    team = GraphFlow(
        participants=[filtered],
        graph=single_node_graph,
        runtime=runtime,
        termination_condition=MaxMessageTermination(3),
    )

    result = await team.run(task="only last user message matters")
    assert result.stop_reason is not None

    # The filtered agent must have spoken, and at least one of its messages must be the "ACK" reply.
    from_filtered = [msg for msg in result.messages if msg.source == "filtered"]
    assert from_filtered
    assert any(msg.content == "ACK" for msg in from_filtered)  # type: ignore[attr-defined,union-attr]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_message_filter_agent_loop_graph_visibility(runtime: AgentRuntime | None) -> None:
    """Run an A -> B loop with an exit to C and check which messages each filtered agent receives.

    B wraps an AssistantAgent driven by a replay client scripted with
    "loop", "loop", "exit"; B's edges route "loop" back to A and "exit" to C.
    """
    agent_a_inner = _TestMessageFilterAgent("A")
    agent_a = MessageFilterAgent(
        name="A",
        wrapped_agent=agent_a_inner,
        filter=MessageFilterConfig(
            per_source=[
                PerSourceFilter(source="user", position="first", count=1),
                PerSourceFilter(source="B", position="last", count=1),
            ]
        ),
    )

    # AssistantAgent and ReplayChatCompletionClient come from the module-level
    # imports; re-importing them inside the function was redundant.
    model_client = ReplayChatCompletionClient(["loop", "loop", "exit"])
    agent_b_inner = AssistantAgent("B", model_client=model_client)
    agent_b = MessageFilterAgent(
        name="B",
        wrapped_agent=agent_b_inner,
        filter=MessageFilterConfig(
            per_source=[
                PerSourceFilter(source="user", position="first", count=1),
                PerSourceFilter(source="A", position="last", count=1),
                PerSourceFilter(source="B", position="last", count=10),
            ]
        ),
    )

    agent_c_inner = _TestMessageFilterAgent("C")
    agent_c = MessageFilterAgent(
        name="C",
        wrapped_agent=agent_c_inner,
        filter=MessageFilterConfig(
            per_source=[
                PerSourceFilter(source="user", position="first", count=1),
                PerSourceFilter(source="B", position="last", count=1),
            ]
        ),
    )

    graph = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]),
            "B": DiGraphNode(
                name="B",
                edges=[
                    DiGraphEdge(target="C", condition="exit"),
                    DiGraphEdge(target="A", condition="loop"),
                ],
            ),
            "C": DiGraphNode(name="C", edges=[]),
        },
        default_start_node="A",
    )

    team = GraphFlow(
        participants=[agent_a, agent_b, agent_c],
        graph=graph,
        runtime=runtime,
        termination_condition=MaxMessageTermination(20),
    )

    result = await team.run(task="Start")
    assert result.stop_reason is not None

    # Check A received: 1 user + 2 from B
    assert [m.source for m in agent_a_inner.received_messages].count("user") == 1
    assert [m.source for m in agent_a_inner.received_messages].count("B") == 2

    # Check C received: 1 user + 1 from B
    assert [m.source for m in agent_c_inner.received_messages].count("user") == 1
    assert [m.source for m in agent_c_inner.received_messages].count("B") == 1

    # Check B received: 1 user + multiple from A + own messages.
    # B's inner agent is an AssistantAgent, so its view is read from the model context.
    model_msgs = await agent_b_inner.model_context.get_messages()
    sources = [m.source for m in model_msgs]  # type: ignore[union-attr]
    assert sources.count("user") == 1  # pyright: ignore[reportUnknownMemberType]
    assert sources.count("A") >= 3  # pyright: ignore[reportUnknownMemberType]
    assert sources.count("B") >= 2  # pyright: ignore[reportUnknownMemberType]
|
|
|
|
|
|
# Test Graph Builder
|
|
def test_add_node() -> None:
    """Adding an agent registers it under its name in both ``nodes`` and ``agents``, with activation "all"."""
    replay = ReplayChatCompletionClient(["response"])
    assistant = AssistantAgent("A", model_client=replay)

    graph_builder = DiGraphBuilder()
    graph_builder.add_node(assistant)

    assert "A" in graph_builder.nodes
    assert "A" in graph_builder.agents
    # No activation was passed, so the node reports "all".
    registered = graph_builder.nodes["A"]
    assert registered.activation == "all"
|
|
|
|
|
|
def test_add_edge() -> None:
    """An edge added without a condition points at the target and carries ``condition is None``."""
    replay = ReplayChatCompletionClient(["1", "2"])
    source_agent = AssistantAgent("A", model_client=replay)
    target_agent = AssistantAgent("B", model_client=replay)

    graph_builder = DiGraphBuilder()
    graph_builder.add_node(source_agent).add_node(target_agent)
    graph_builder.add_edge(source_agent, target_agent)

    first_edge = graph_builder.nodes["A"].edges[0]
    assert first_edge.target == "B"
    assert first_edge.condition is None
|
|
|
|
|
|
def test_add_conditional_edges() -> None:
    """``add_conditional_edges`` creates one edge per mapping entry, keeping condition and target paired."""
    replay = ReplayChatCompletionClient(["1", "2"])
    origin = AssistantAgent("A", model_client=replay)
    on_yes = AssistantAgent("B", model_client=replay)
    on_no = AssistantAgent("C", model_client=replay)

    graph_builder = DiGraphBuilder()
    graph_builder.add_node(origin).add_node(on_yes).add_node(on_no)
    graph_builder.add_conditional_edges(origin, {"yes": on_yes, "no": on_no})

    outgoing = graph_builder.nodes["A"].edges
    assert len(outgoing) == 2

    # Collect the targets reached by each condition string.
    yes_targets = [edge.target for edge in outgoing if edge.condition == "yes"]
    no_targets = [edge.target for edge in outgoing if edge.condition == "no"]

    # Both conditions must be present ...
    assert yes_targets
    assert no_targets

    # ... and each must lead to the agent it was mapped to.
    assert yes_targets[0] == "B"
    assert no_targets[0] == "C"
|
|
|
|
|
|
def test_set_entry_point() -> None:
    """The agent passed to ``set_entry_point`` becomes the built graph's ``default_start_node``."""
    entry_agent = AssistantAgent("A", model_client=ReplayChatCompletionClient(["ok"]))

    built = DiGraphBuilder().add_node(entry_agent).set_entry_point(entry_agent).build()

    assert built.default_start_node == "A"
|
|
|
|
|
|
def test_build_graph_validation() -> None:
    """Build a linear A -> B -> C graph using node names and check its nodes, start set and leaf set."""
    replay = ReplayChatCompletionClient(["1", "2", "3"])
    first = AssistantAgent("A", model_client=replay)
    middle = AssistantAgent("B", model_client=replay)
    last = AssistantAgent("C", model_client=replay)

    graph_builder = DiGraphBuilder()
    graph_builder.add_node(first).add_node(middle).add_node(last)
    # Edges and entry point are given by name here rather than by agent object.
    graph_builder.add_edge("A", "B").add_edge("B", "C")
    graph_builder.set_entry_point("A")
    built = graph_builder.build()

    assert isinstance(built, DiGraph)
    assert set(built.nodes) == {"A", "B", "C"}
    assert built.get_start_nodes() == {"A"}
    assert built.get_leaf_nodes() == {"C"}
|
|
|
|
|
|
def test_build_fan_out() -> None:
    """Build A -> B and A -> C; A must be the only start node and B, C the leaves."""
    replay = ReplayChatCompletionClient(["hi"] * 3)
    root = AssistantAgent("A", model_client=replay)
    left = AssistantAgent("B", model_client=replay)
    right = AssistantAgent("C", model_client=replay)

    graph_builder = DiGraphBuilder()
    graph_builder.add_node(root).add_node(left).add_node(right)
    graph_builder.add_edge(root, left).add_edge(root, right)
    graph_builder.set_entry_point(root)
    built = graph_builder.build()

    assert built.get_start_nodes() == {"A"}
    assert built.get_leaf_nodes() == {"B", "C"}
|
|
|
|
|
|
def test_build_parallel_join() -> None:
    """Build a graph where A and B both feed C (activation "all") and check C is the only leaf."""
    replay = ReplayChatCompletionClient(["go"] * 3)
    feeder_a = AssistantAgent("A", model_client=replay)
    feeder_b = AssistantAgent("B", model_client=replay)
    joiner = AssistantAgent("C", model_client=replay)

    graph_builder = DiGraphBuilder()
    graph_builder.add_node(feeder_a).add_node(feeder_b).add_node(joiner, activation="all")
    graph_builder.add_edge(feeder_a, joiner).add_edge(feeder_b, joiner)
    graph_builder.set_entry_point(feeder_a)

    # NOTE(review): B -> C is added a second time here, and the next statement
    # then replaces node B wholesale with a single B -> C edge. This looks
    # redundant with the chained add_edge above — confirm before simplifying.
    graph_builder.add_edge(feeder_b, joiner)
    graph_builder.nodes["B"] = DiGraphNode(name="B", edges=[DiGraphEdge(target="C")])
    built = graph_builder.build()

    assert built.nodes["C"].activation == "all"
    assert built.get_leaf_nodes() == {"C"}
|
|
|
|
|
|
def test_build_conditional_loop() -> None:
    """Build A -> B with B looping back to A on "loop" and leaving to C on "exit"."""
    replay = ReplayChatCompletionClient(["loop", "loop", "exit"])
    entry = AssistantAgent("A", model_client=replay)
    decider = AssistantAgent("B", model_client=replay)
    terminal = AssistantAgent("C", model_client=replay)

    graph_builder = DiGraphBuilder()
    graph_builder.add_node(entry).add_node(decider).add_node(terminal)
    graph_builder.add_edge(entry, decider)
    graph_builder.add_conditional_edges(decider, {"loop": entry, "exit": terminal})
    graph_builder.set_entry_point(entry)
    built = graph_builder.build()

    # B must have exactly the two conditional edges.
    outgoing = built.nodes["B"].edges
    assert len(outgoing) == 2

    # Look each edge up by its condition; next() raises StopIteration if one is missing.
    back_edge = next(edge for edge in outgoing if edge.condition == "loop")
    out_edge = next(edge for edge in outgoing if edge.condition == "exit")

    assert back_edge.target == "A"
    assert out_edge.target == "C"
    assert built.has_cycles_with_exit()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_graph_builder_sequential_execution(runtime: AgentRuntime | None) -> None:
    """A builder-made chain A -> B -> C must produce messages from A, B and C in that order."""
    echo_a = _EchoAgent("A", description="Echo A")
    echo_b = _EchoAgent("B", description="Echo B")
    echo_c = _EchoAgent("C", description="Echo C")

    graph_builder = DiGraphBuilder()
    graph_builder.add_node(echo_a).add_node(echo_b).add_node(echo_c)
    graph_builder.add_edge(echo_a, echo_b).add_edge(echo_b, echo_c)

    team = GraphFlow(
        participants=graph_builder.get_participants(),
        graph=graph_builder.build(),
        runtime=runtime,
        termination_condition=MaxMessageTermination(5),
    )

    result = await team.run(task="Start")

    # Drop the first and last message and compare what lies between.
    inner_sources = [m.source for m in result.messages[1:-1]]
    assert inner_sources == ["A", "B", "C"]
    assert result.stop_reason is not None
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_graph_builder_fan_out(runtime: AgentRuntime | None) -> None:
    """A builder-made fan-out A -> (B, C) must produce text messages from all three agents."""
    echo_a = _EchoAgent("A", description="Echo A")
    echo_b = _EchoAgent("B", description="Echo B")
    echo_c = _EchoAgent("C", description="Echo C")

    graph_builder = DiGraphBuilder()
    graph_builder.add_node(echo_a).add_node(echo_b).add_node(echo_c)
    graph_builder.add_edge(echo_a, echo_b).add_edge(echo_a, echo_c)

    team = GraphFlow(
        participants=graph_builder.get_participants(),
        graph=graph_builder.build(),
        runtime=runtime,
        termination_condition=MaxMessageTermination(5),
    )

    result = await team.run(task="Start")

    # Only TextMessage sources count; the first one is skipped, order is not checked.
    text_sources = [m.source for m in result.messages if isinstance(m, TextMessage)]
    assert set(text_sources[1:]) == {"A", "B", "C"}
    assert result.stop_reason is not None
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_graph_builder_conditional_execution(runtime: AgentRuntime | None) -> None:
    """With conditional edges {"yes": B, "no": C} on A, the task "no" must lead to a message from C."""
    echo_a = _EchoAgent("A", description="Echo A")
    echo_b = _EchoAgent("B", description="Echo B")
    echo_c = _EchoAgent("C", description="Echo C")

    graph_builder = DiGraphBuilder()
    graph_builder.add_node(echo_a).add_node(echo_b).add_node(echo_c)
    graph_builder.add_conditional_edges(echo_a, {"yes": echo_b, "no": echo_c})

    team = GraphFlow(
        participants=graph_builder.get_participants(),
        graph=graph_builder.build(),
        runtime=runtime,
        termination_condition=MaxMessageTermination(5),
    )

    # Input "no" should trigger the edge to C.
    result = await team.run(task="no")
    assert any(m.source == "C" for m in result.messages)
    assert result.stop_reason is not None
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_digraph_group_chat_callable_condition(runtime: AgentRuntime | None) -> None:
    """String conditions on A's edges select the next speaker: "long" -> B, "short" -> C.

    NOTE(review): despite the function name, the conditions used here are
    plain strings, not callables.
    """
    echo_a = _EchoAgent("A", description="Echo agent A")
    echo_b = _EchoAgent("B", description="Echo agent B")
    echo_c = _EchoAgent("C", description="Echo agent C")

    routing_edges = [
        # Taken when "long" is in the message.
        DiGraphEdge(target="B", condition="long"),
        # Taken when "short" is in the message.
        DiGraphEdge(target="C", condition="short"),
    ]
    graph = DiGraph(
        nodes={
            "A": DiGraphNode(name="A", edges=routing_edges),
            "B": DiGraphNode(name="B", edges=[]),
            "C": DiGraphNode(name="C", edges=[]),
        }
    )

    team = GraphFlow(
        participants=[echo_a, echo_b, echo_c],
        graph=graph,
        runtime=runtime,
        termination_condition=MaxMessageTermination(5),
    )

    # A task containing "long" must be followed by B.
    long_run = await team.run(task="This is a long message")
    assert long_run.messages[2].source == "B"

    # Start the second scenario from a clean team.
    await team.reset()

    # A task containing "short" must be followed by C.
    short_run = await team.run(task="This is a short message")
    assert short_run.messages[2].source == "C"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_graph_flow_serialize_deserialize() -> None:
    """A GraphFlow rebuilt from its dumped component must dump identically and run identically."""
    replay_a = ReplayChatCompletionClient(list(map(str, range(10))))
    replay_b = ReplayChatCompletionClient(list(map(str, range(10))))
    assistant_a = AssistantAgent("A", model_client=replay_a)
    assistant_b = AssistantAgent("B", model_client=replay_b)

    graph_builder = DiGraphBuilder()
    graph_builder.add_node(assistant_a).add_node(assistant_b)
    graph_builder.add_edge(assistant_a, assistant_b)
    graph_builder.set_entry_point(assistant_a)

    team = GraphFlow(
        participants=graph_builder.get_participants(),
        graph=graph_builder.build(),
        runtime=None,
    )

    # Dump, reload, and dump again before either team is run.
    dumped = team.dump_component()
    reloaded_team = GraphFlow.load_component(dumped)
    dumped_again = reloaded_team.dump_component()

    results = await team.run(task="Start")
    de_results = await reloaded_team.run(task="Start")

    # The serialized form survives the round trip and both teams behave the same.
    assert dumped == dumped_again
    assert compare_task_results(results, de_results)
    assert results.stop_reason is not None
    assert results.stop_reason == de_results.stop_reason
    assert compare_message_lists(results.messages, de_results.messages)

    # Transcript of the original team: user task, A's reply, B's reply, ..., stop message.
    opening = results.messages[0]
    assert isinstance(opening, TextMessage)
    assert opening.source == "user"
    assert opening.content == "Start"

    reply_a = results.messages[1]
    assert isinstance(reply_a, TextMessage)
    assert reply_a.source == "A"
    assert reply_a.content == "0"

    reply_b = results.messages[2]
    assert isinstance(reply_b, TextMessage)
    assert reply_b.source == "B"
    assert reply_b.content == "0"

    closing = results.messages[-1]
    assert isinstance(closing, StopMessage)
    assert closing.source == _DIGRAPH_STOP_AGENT_NAME
    assert closing.content == "Digraph execution is complete"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_graph_flow_stateful_pause_and_resume_with_termination() -> None:
    """Stop an A -> B flow after A via a termination condition, then resume from saved state in a new team."""
    replay_a = ReplayChatCompletionClient(["A1", "A2"])
    replay_b = ReplayChatCompletionClient(["B1"])

    assistant_a = AssistantAgent("A", model_client=replay_a)
    assistant_b = AssistantAgent("B", model_client=replay_b)

    graph_builder = DiGraphBuilder()
    graph_builder.add_node(assistant_a).add_node(assistant_b)
    graph_builder.add_edge(assistant_a, assistant_b)
    graph_builder.set_entry_point(assistant_a)

    # First team: terminates as soon as A has produced a message.
    paused_team = GraphFlow(
        participants=graph_builder.get_participants(),
        graph=graph_builder.build(),
        runtime=None,
        termination_condition=SourceMatchTermination(sources=["A"]),
    )

    first_leg = await paused_team.run(task="Start")
    assert len(first_leg.messages) == 2
    assert [m.source for m in first_leg.messages] == ["user", "A"]
    assert first_leg.stop_reason == "'A' answered"

    # Export state.
    snapshot = await paused_team.save_state()

    # Load state into a new team (built without a termination condition).
    resumed_team = GraphFlow(
        participants=graph_builder.get_participants(),
        graph=graph_builder.build(),
        runtime=None,
    )
    await resumed_team.load_state(snapshot)

    # Resume: only B and the stop agent are left to speak.
    second_leg = await resumed_team.run()
    assert len(second_leg.messages) == 2
    assert second_leg.messages[0].source == "B"
    assert second_leg.messages[1].source == _DIGRAPH_STOP_AGENT_NAME
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_builder_with_lambda_condition(runtime: AgentRuntime | None) -> None:
    """Test that DiGraphBuilder.add_edge accepts callable (lambda) conditions.

    A's edge to B fires when the message text contains "even"; its edge to C
    fires when the text contains "odd".
    """
    agent_a = _EchoAgent("A", description="Echo agent A")
    agent_b = _EchoAgent("B", description="Echo agent B")
    agent_c = _EchoAgent("C", description="Echo agent C")

    builder = DiGraphBuilder()
    builder.add_node(agent_a).add_node(agent_b).add_node(agent_c)

    # Callable conditions: each lambda inspects the text of the message it is given.
    builder.add_edge(agent_a, agent_b, lambda msg: "even" in msg.to_model_text())
    builder.add_edge(agent_a, agent_c, lambda msg: "odd" in msg.to_model_text())

    team = GraphFlow(
        participants=builder.get_participants(),
        graph=builder.build(),
        runtime=runtime,
        termination_condition=MaxMessageTermination(5),
    )

    # Test with "even" in message - should go to B
    result = await team.run(task="even length")
    assert result.messages[2].source == "B"

    # Reset for next test
    await team.reset()

    # Test with "odd" in message - should go to C
    result = await team.run(task="odd message")
    assert result.messages[2].source == "C"
|