mirror of
https://github.com/microsoft/autogen.git
synced 2025-06-26 22:30:10 +00:00
Add callable condition for GraphFlow edges (#6623)
This PR adds callable as an option to specify conditional edges in GraphFlow. ```python import asyncio from autogen_agentchat.agents import AssistantAgent from autogen_agentchat.conditions import MaxMessageTermination from autogen_agentchat.teams import DiGraphBuilder, GraphFlow from autogen_ext.models.openai import OpenAIChatCompletionClient async def main(): # Initialize agents with OpenAI model clients. model_client = OpenAIChatCompletionClient(model="gpt-4.1-nano") agent_a = AssistantAgent( "A", model_client=model_client, system_message="Detect if the input is in Chinese. If it is, say 'yes', else say 'no', and nothing else.", ) agent_b = AssistantAgent("B", model_client=model_client, system_message="Translate input to English.") agent_c = AssistantAgent("C", model_client=model_client, system_message="Translate input to Chinese.") # Create a directed graph with conditional branching flow A -> B ("yes"), A -> C (otherwise). builder = DiGraphBuilder() builder.add_node(agent_a).add_node(agent_b).add_node(agent_c) # Create conditions as callables that check the message content. builder.add_edge(agent_a, agent_b, condition=lambda msg: "yes" in msg.to_model_text()) builder.add_edge(agent_a, agent_c, condition=lambda msg: "yes" not in msg.to_model_text()) graph = builder.build() # Create a GraphFlow team with the directed graph. team = GraphFlow( participants=[agent_a, agent_b, agent_c], graph=graph, termination_condition=MaxMessageTermination(5), ) # Run the team and print the events. async for event in team.run_stream(task="AutoGen is a framework for building AI agents."): print(event) asyncio.run(main()) ``` --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: ekzhu <320302+ekzhu@users.noreply.github.com>
This commit is contained in:
parent
9065c6f37b
commit
b31b4e508d
@ -1,9 +1,9 @@
|
||||
import asyncio
|
||||
from collections import Counter, deque
|
||||
from typing import Any, Callable, Deque, Dict, List, Literal, Mapping, Sequence, Set
|
||||
from typing import Any, Callable, Deque, Dict, List, Literal, Mapping, Sequence, Set, Union
|
||||
|
||||
from autogen_core import AgentRuntime, CancellationToken, Component, ComponentModel
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from autogen_agentchat.agents import BaseChatAgent
|
||||
@ -34,16 +34,50 @@ class DiGraphEdge(BaseModel):
|
||||
|
||||
This is an experimental feature, and the API will change in the future releases.
|
||||
|
||||
.. warning::
|
||||
|
||||
If the condition is a callable, it will not be serialized in the model.
|
||||
|
||||
"""
|
||||
|
||||
target: str # Target node name
|
||||
condition: str | None = None # Optional execution condition (trigger-based)
|
||||
condition: Union[str, Callable[[BaseChatMessage], bool], None] = Field(default=None)
|
||||
"""(Experimental) Condition to execute this edge.
|
||||
If None, the edge is unconditional. If a string, the edge is conditional on the presence of that string in the last agent chat message.
|
||||
NOTE: This is an experimental feature WILL change in the future releases to allow for better spcification of branching conditions
|
||||
similar to the `TerminationCondition` class.
|
||||
If None, the edge is unconditional.
|
||||
If a string, the edge is conditional on the presence of that string in the last agent chat message.
|
||||
If a callable, the edge is conditional on the callable returning True when given the last message.
|
||||
"""
|
||||
|
||||
# Using Field to exclude the condition in serialization if it's a callable
|
||||
condition_function: Callable[[BaseChatMessage], bool] | None = Field(default=None, exclude=True)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate_condition(self) -> "DiGraphEdge":
|
||||
# Store callable in a separate field and set condition to None for serialization
|
||||
if callable(self.condition):
|
||||
self.condition_function = self.condition
|
||||
# For serialization purposes, we'll set the condition to None
|
||||
# when storing as a pydantic model/dict
|
||||
object.__setattr__(self, "condition", None)
|
||||
return self
|
||||
|
||||
def check_condition(self, message: BaseChatMessage) -> bool:
|
||||
"""Check if the edge condition is satisfied for the given message.
|
||||
|
||||
Args:
|
||||
message: The message to check the condition against.
|
||||
|
||||
Returns:
|
||||
True if condition is satisfied (None condition always returns True),
|
||||
False otherwise.
|
||||
"""
|
||||
if self.condition_function is not None:
|
||||
return self.condition_function(message)
|
||||
elif isinstance(self.condition, str):
|
||||
# If it's a string, check if the string is in the message content
|
||||
return self.condition in message.to_model_text()
|
||||
return True # None condition is always satisfied
|
||||
|
||||
|
||||
class DiGraphNode(BaseModel):
|
||||
"""Represents a node (agent) in a :class:`DiGraph`, with its outgoing edges and activation type.
|
||||
@ -125,7 +159,7 @@ class DiGraph(BaseModel):
|
||||
cycle_edges: List[DiGraphEdge] = []
|
||||
for n in cycle_nodes:
|
||||
cycle_edges.extend(self.nodes[n].edges)
|
||||
if not any(edge.condition for edge in cycle_edges):
|
||||
if not any(edge.condition is not None for edge in cycle_edges):
|
||||
raise ValueError(
|
||||
f"Cycle detected without exit condition: {' -> '.join(cycle_nodes + cycle_nodes[:1])}"
|
||||
)
|
||||
@ -164,7 +198,7 @@ class DiGraph(BaseModel):
|
||||
# Outgoing edge condition validation (per node)
|
||||
for node in self.nodes.values():
|
||||
# Check that if a node has an outgoing conditional edge, then all outgoing edges are conditional
|
||||
has_condition = any(edge.condition for edge in node.edges)
|
||||
has_condition = any(edge.condition is not None for edge in node.edges)
|
||||
has_unconditioned = any(edge.condition is None for edge in node.edges)
|
||||
if has_condition and has_unconditioned:
|
||||
raise ValueError(f"Node '{node.name}' has a mix of conditional and unconditional edges.")
|
||||
@ -239,11 +273,11 @@ class GraphFlowManager(BaseGroupChatManager):
|
||||
return
|
||||
assert isinstance(message, BaseChatMessage)
|
||||
source = message.source
|
||||
content = message.to_model_text()
|
||||
|
||||
# Propagate the update to the children of the node.
|
||||
for edge in self._edges[source]:
|
||||
if edge.condition and edge.condition not in content:
|
||||
# Use the new check_condition method that handles both string and callable conditions
|
||||
if not edge.check_condition(message):
|
||||
continue
|
||||
if self._activation[edge.target] == "all":
|
||||
self._remaining[edge.target] -= 1
|
||||
@ -360,6 +394,11 @@ class GraphFlow(BaseGroupChat, Component[GraphFlowConfig]):
|
||||
See the :class:`DiGraphBuilder` documentation for more details.
|
||||
The :class:`GraphFlow` class is designed to be used with the :class:`DiGraphBuilder` for creating complex workflows.
|
||||
|
||||
.. warning::
|
||||
|
||||
When using callable conditions in edges, they will not be serialized
|
||||
when calling :meth:`dump_component`. This will be addressed in future releases.
|
||||
|
||||
|
||||
Args:
|
||||
participants (List[ChatAgent]): The participants in the group chat.
|
||||
@ -450,7 +489,7 @@ class GraphFlow(BaseGroupChat, Component[GraphFlowConfig]):
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
**Conditional Branching: A → B (if 'yes') or C (if 'no')**
|
||||
**Conditional Branching: A → B (if 'yes') or C (otherwise)**
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@ -473,11 +512,12 @@ class GraphFlow(BaseGroupChat, Component[GraphFlowConfig]):
|
||||
agent_b = AssistantAgent("B", model_client=model_client, system_message="Translate input to English.")
|
||||
agent_c = AssistantAgent("C", model_client=model_client, system_message="Translate input to Chinese.")
|
||||
|
||||
# Create a directed graph with conditional branching flow A -> B ("yes"), A -> C ("no").
|
||||
# Create a directed graph with conditional branching flow A -> B ("yes"), A -> C (otherwise).
|
||||
builder = DiGraphBuilder()
|
||||
builder.add_node(agent_a).add_node(agent_b).add_node(agent_c)
|
||||
builder.add_edge(agent_a, agent_b, condition="yes")
|
||||
builder.add_edge(agent_a, agent_c, condition="no")
|
||||
# Create conditions as callables that check the message content.
|
||||
builder.add_edge(agent_a, agent_b, condition=lambda msg: "yes" in msg.to_model_text())
|
||||
builder.add_edge(agent_a, agent_c, condition=lambda msg: "yes" not in msg.to_model_text())
|
||||
graph = builder.build()
|
||||
|
||||
# Create a GraphFlow team with the directed graph.
|
||||
@ -494,7 +534,7 @@ class GraphFlow(BaseGroupChat, Component[GraphFlowConfig]):
|
||||
|
||||
asyncio.run(main())
|
||||
|
||||
**Loop with exit condition: A → B → C (if 'APPROVE') or A (if 'REJECT')**
|
||||
**Loop with exit condition: A → B → C (if 'APPROVE') or A (otherwise)**
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@ -518,17 +558,21 @@ class GraphFlow(BaseGroupChat, Component[GraphFlowConfig]):
|
||||
"B",
|
||||
model_client=model_client,
|
||||
system_message="Provide feedback on the input, if your feedback has been addressed, "
|
||||
"say 'APPROVE', else say 'REJECT' and provide a reason.",
|
||||
"say 'APPROVE', otherwise provide a reason for rejection.",
|
||||
)
|
||||
agent_c = AssistantAgent(
|
||||
"C", model_client=model_client, system_message="Translate the final product to Korean."
|
||||
)
|
||||
|
||||
# Create a loop graph with conditional exit: A -> B -> C ("APPROVE"), B -> A ("REJECT").
|
||||
# Create a loop graph with conditional exit: A -> B -> C ("APPROVE"), B -> A (otherwise).
|
||||
builder = DiGraphBuilder()
|
||||
builder.add_node(agent_a).add_node(agent_b).add_node(agent_c)
|
||||
builder.add_edge(agent_a, agent_b)
|
||||
builder.add_conditional_edges(agent_b, {"APPROVE": agent_c, "REJECT": agent_a})
|
||||
|
||||
# Create conditional edges using strings
|
||||
builder.add_edge(agent_b, agent_c, condition=lambda msg: "APPROVE" in msg.to_model_text())
|
||||
builder.add_edge(agent_b, agent_a, condition=lambda msg: "APPROVE" not in msg.to_model_text())
|
||||
|
||||
builder.set_entry_point(agent_a)
|
||||
graph = builder.build()
|
||||
|
||||
|
@ -1,6 +1,8 @@
|
||||
from typing import Dict, Literal, Optional, Union
|
||||
import warnings
|
||||
from typing import Callable, Dict, Literal, Optional, Union
|
||||
|
||||
from autogen_agentchat.base import ChatAgent
|
||||
from autogen_agentchat.messages import BaseChatMessage
|
||||
|
||||
from ._digraph_group_chat import DiGraph, DiGraphEdge, DiGraphNode
|
||||
|
||||
@ -22,7 +24,7 @@ class DiGraphBuilder:
|
||||
- Cyclic loops with safe exits
|
||||
|
||||
Each node in the graph represents an agent. Edges define execution paths between agents,
|
||||
and can optionally be conditioned on message content.
|
||||
and can optionally be conditioned on message content using callable functions.
|
||||
|
||||
The builder is compatible with the `Graph` runner and supports both standard and filtered agents.
|
||||
|
||||
@ -49,16 +51,29 @@ class DiGraphBuilder:
|
||||
>>> builder.add_node(agent_a).add_node(agent_b).add_node(agent_c)
|
||||
>>> builder.add_edge(agent_a, agent_b).add_edge(agent_a, agent_c)
|
||||
|
||||
Example — Conditional Branching A → B ("yes"), A → C ("no"):
|
||||
Example — Conditional Branching A → B or A → C:
|
||||
>>> builder = GraphBuilder()
|
||||
>>> builder.add_node(agent_a).add_node(agent_b).add_node(agent_c)
|
||||
>>> builder.add_conditional_edges(agent_a, {"yes": agent_b, "no": agent_c})
|
||||
>>> # Add conditional edges using keyword check
|
||||
>>> builder.add_edge(agent_a, agent_b, condition="keyword1")
|
||||
>>> builder.add_edge(agent_a, agent_c, condition="keyword2")
|
||||
|
||||
Example — Loop: A → B → A ("loop"), B → C ("exit"):
|
||||
|
||||
Example — Using Custom String Conditions:
|
||||
>>> builder = GraphBuilder()
|
||||
>>> builder.add_node(agent_a).add_node(agent_b).add_node(agent_c)
|
||||
>>> # Add condition strings to check in messages
|
||||
>>> builder.add_edge(agent_a, agent_b, condition="big")
|
||||
>>> builder.add_edge(agent_a, agent_c, condition="small")
|
||||
|
||||
Example — Loop: A → B → A or B → C:
|
||||
>>> builder = GraphBuilder()
|
||||
>>> builder.add_node(agent_a).add_node(agent_b).add_node(agent_c)
|
||||
>>> builder.add_edge(agent_a, agent_b)
|
||||
>>> builder.add_conditional_edges(agent_b, {"loop": agent_a, "exit": agent_c})
|
||||
>> # Add a loop back to agent A
|
||||
>>> builder.add_edge(agent_b, agent_a, condition=lambda msg: "loop" in msg.to_model_text())
|
||||
>>> # Add exit condition to break the loop
|
||||
>>> builder.add_edge(agent_b, agent_c, condition=lambda msg: "loop" not in msg.to_model_text())
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
@ -78,9 +93,26 @@ class DiGraphBuilder:
|
||||
return self
|
||||
|
||||
def add_edge(
|
||||
self, source: Union[str, ChatAgent], target: Union[str, ChatAgent], condition: Optional[str] = None
|
||||
self,
|
||||
source: Union[str, ChatAgent],
|
||||
target: Union[str, ChatAgent],
|
||||
condition: Optional[Union[str, Callable[[BaseChatMessage], bool]]] = None,
|
||||
) -> "DiGraphBuilder":
|
||||
"""Add a directed edge from source to target, optionally with a condition."""
|
||||
"""Add a directed edge from source to target, optionally with a condition.
|
||||
|
||||
Args:
|
||||
source: Source node (agent name or agent object)
|
||||
target: Target node (agent name or agent object)
|
||||
condition: Optional condition for edge activation.
|
||||
If string, activates when substring is found in message.
|
||||
If callable, activates when function returns True for the message.
|
||||
|
||||
Returns:
|
||||
Self for method chaining
|
||||
|
||||
Raises:
|
||||
ValueError: If source or target node doesn't exist in the builder
|
||||
"""
|
||||
source_name = self._get_name(source)
|
||||
target_name = self._get_name(target)
|
||||
|
||||
@ -95,9 +127,35 @@ class DiGraphBuilder:
|
||||
def add_conditional_edges(
|
||||
self, source: Union[str, ChatAgent], condition_to_target: Dict[str, Union[str, ChatAgent]]
|
||||
) -> "DiGraphBuilder":
|
||||
"""Add multiple conditional edges from a source node based on condition strings."""
|
||||
for condition, target in condition_to_target.items():
|
||||
self.add_edge(source, target, condition)
|
||||
"""Add multiple conditional edges from a source node based on keyword checks.
|
||||
|
||||
.. warning::
|
||||
|
||||
This method interface will be changed in the future to support callable conditions.
|
||||
Please use `add_edge` if you need to specify custom conditions.
|
||||
|
||||
Args:
|
||||
source: Source node (agent name or agent object)
|
||||
condition_to_target: Mapping from condition strings to target nodes
|
||||
Each key is a keyword that will be checked in the message content
|
||||
Each value is the target node to activate when condition is met
|
||||
|
||||
For each key (keyword), a lambda will be created that checks
|
||||
if the keyword is in the message text.
|
||||
|
||||
Returns:
|
||||
Self for method chaining
|
||||
"""
|
||||
|
||||
warnings.warn(
|
||||
"add_conditional_edges will be changed in the future to support callable conditions. "
|
||||
"For now, please use add_edge if you need to specify custom conditions.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
for condition_keyword, target in condition_to_target.items():
|
||||
self.add_edge(source, target, condition=condition_keyword)
|
||||
return self
|
||||
|
||||
def set_entry_point(self, name: Union[str, ChatAgent]) -> "DiGraphBuilder":
|
||||
|
@ -99,6 +99,7 @@ def test_get_leaf_nodes() -> None:
|
||||
|
||||
def test_serialization() -> None:
|
||||
"""Test serializing and deserializing the graph."""
|
||||
# Use a string condition instead of a lambda
|
||||
graph = DiGraph(
|
||||
nodes={
|
||||
"A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B", condition="trigger1")]),
|
||||
@ -114,6 +115,11 @@ def test_serialization() -> None:
|
||||
assert deserialized_graph.nodes["A"].edges[0].condition == "trigger1"
|
||||
assert deserialized_graph.nodes["B"].edges[0].target == "C"
|
||||
|
||||
# Test the original condition works
|
||||
test_msg = TextMessage(content="this has trigger1 in it", source="test")
|
||||
# Manually check if the string is in the message text
|
||||
assert "trigger1" in test_msg.to_model_text()
|
||||
|
||||
|
||||
def test_invalid_graph_no_start_node() -> None:
|
||||
"""Test validation failure when there is no start node."""
|
||||
@ -144,6 +150,7 @@ def test_invalid_graph_no_leaf_node() -> None:
|
||||
|
||||
def test_condition_edge_execution() -> None:
|
||||
"""Test conditional edge execution support."""
|
||||
# Use string condition
|
||||
graph = DiGraph(
|
||||
nodes={
|
||||
"A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B", condition="TRIGGER")]),
|
||||
@ -152,6 +159,15 @@ def test_condition_edge_execution() -> None:
|
||||
}
|
||||
)
|
||||
|
||||
# Check the condition manually
|
||||
test_message = TextMessage(content="This has TRIGGER in it", source="test")
|
||||
non_match_message = TextMessage(content="This doesn't match", source="test")
|
||||
|
||||
# Check if the string condition is in each message text
|
||||
assert "TRIGGER" in test_message.to_model_text()
|
||||
assert "TRIGGER" not in non_match_message.to_model_text()
|
||||
|
||||
# Check the condition itself
|
||||
assert graph.nodes["A"].edges[0].condition == "TRIGGER"
|
||||
assert graph.nodes["B"].edges[0].condition is None
|
||||
|
||||
@ -193,6 +209,7 @@ def test_cycle_detection_no_cycle() -> None:
|
||||
|
||||
def test_cycle_detection_with_exit_condition() -> None:
|
||||
"""Test a graph with cycle and conditional exit passes validation."""
|
||||
# Use a string condition
|
||||
graph = DiGraph(
|
||||
nodes={
|
||||
"A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]),
|
||||
@ -258,6 +275,7 @@ def test_validate_graph_missing_leaf_node() -> None:
|
||||
|
||||
def test_validate_graph_mixed_conditions() -> None:
|
||||
"""Test validation failure when node has mixed conditional and unconditional edges."""
|
||||
# Use string for condition
|
||||
graph = DiGraph(
|
||||
nodes={
|
||||
"A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B", condition="cond"), DiGraphEdge(target="C")]),
|
||||
@ -552,6 +570,7 @@ async def test_digraph_group_chat_conditional_branch(runtime: AgentRuntime | Non
|
||||
agent_b = _EchoAgent("B", description="Echo agent B")
|
||||
agent_c = _EchoAgent("C", description="Echo agent C")
|
||||
|
||||
# Use string conditions
|
||||
graph = DiGraph(
|
||||
nodes={
|
||||
"A": DiGraphNode(
|
||||
@ -726,6 +745,7 @@ async def test_digraph_group_chat_multiple_conditional(runtime: AgentRuntime | N
|
||||
agent_c = _EchoAgent("C", description="Echo agent C")
|
||||
agent_d = _EchoAgent("D", description="Echo agent D")
|
||||
|
||||
# Use string conditions
|
||||
graph = DiGraph(
|
||||
nodes={
|
||||
"A": DiGraphNode(
|
||||
@ -1005,10 +1025,18 @@ def test_add_conditional_edges() -> None:
|
||||
|
||||
edges = builder.nodes["A"].edges
|
||||
assert len(edges) == 2
|
||||
conditions = {e.condition for e in edges}
|
||||
targets = {e.target for e in edges}
|
||||
assert conditions == {"yes", "no"}
|
||||
assert targets == {"B", "C"}
|
||||
|
||||
# Extract the condition strings to compare them
|
||||
conditions = [e.condition for e in edges]
|
||||
assert "yes" in conditions
|
||||
assert "no" in conditions
|
||||
|
||||
# Match edge targets with conditions
|
||||
yes_edge = next(e for e in edges if e.condition == "yes")
|
||||
no_edge = next(e for e in edges if e.condition == "no")
|
||||
|
||||
assert yes_edge.target == "B"
|
||||
assert no_edge.target == "C"
|
||||
|
||||
|
||||
def test_set_entry_point() -> None:
|
||||
@ -1085,8 +1113,16 @@ def test_build_conditional_loop() -> None:
|
||||
builder.set_entry_point(a)
|
||||
graph = builder.build()
|
||||
|
||||
assert graph.nodes["B"].edges[0].condition == "loop"
|
||||
assert graph.nodes["B"].edges[1].condition == "exit"
|
||||
# Check that edges have the right conditions and targets
|
||||
edges = graph.nodes["B"].edges
|
||||
assert len(edges) == 2
|
||||
|
||||
# Find edges by their conditions
|
||||
loop_edge = next(e for e in edges if e.condition == "loop")
|
||||
exit_edge = next(e for e in edges if e.condition == "exit")
|
||||
|
||||
assert loop_edge.target == "A"
|
||||
assert exit_edge.target == "C"
|
||||
assert graph.has_cycles_with_exit()
|
||||
|
||||
|
||||
@ -1152,6 +1188,7 @@ async def test_graph_builder_conditional_execution(runtime: AgentRuntime | None)
|
||||
termination_condition=MaxMessageTermination(5),
|
||||
)
|
||||
|
||||
# Input "no" should trigger the edge to C
|
||||
result = await team.run(task="no")
|
||||
sources = [m.source for m in result.messages]
|
||||
assert "C" in sources
|
||||
@ -1159,27 +1196,45 @@ async def test_graph_builder_conditional_execution(runtime: AgentRuntime | None)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_graph_builder_with_filter_agent(runtime: AgentRuntime | None) -> None:
|
||||
inner = _EchoAgent("X", description="Echo X")
|
||||
filter_agent = MessageFilterAgent(
|
||||
name="X",
|
||||
wrapped_agent=inner,
|
||||
filter=MessageFilterConfig(per_source=[PerSourceFilter(source="user", position="last", count=1)]),
|
||||
)
|
||||
async def test_digraph_group_chat_callable_condition(runtime: AgentRuntime | None) -> None:
|
||||
"""Test that string conditions work correctly in edge transitions."""
|
||||
agent_a = _EchoAgent("A", description="Echo agent A")
|
||||
agent_b = _EchoAgent("B", description="Echo agent B")
|
||||
agent_c = _EchoAgent("C", description="Echo agent C")
|
||||
|
||||
builder = DiGraphBuilder()
|
||||
builder.add_node(filter_agent)
|
||||
graph = DiGraph(
|
||||
nodes={
|
||||
"A": DiGraphNode(
|
||||
name="A",
|
||||
edges=[
|
||||
# Will go to B if "long" is in message
|
||||
DiGraphEdge(target="B", condition="long"),
|
||||
# Will go to C if "short" is in message
|
||||
DiGraphEdge(target="C", condition="short"),
|
||||
],
|
||||
),
|
||||
"B": DiGraphNode(name="B", edges=[]),
|
||||
"C": DiGraphNode(name="C", edges=[]),
|
||||
}
|
||||
)
|
||||
|
||||
team = GraphFlow(
|
||||
participants=builder.get_participants(),
|
||||
graph=builder.build(),
|
||||
participants=[agent_a, agent_b, agent_c],
|
||||
graph=graph,
|
||||
runtime=runtime,
|
||||
termination_condition=MaxMessageTermination(3),
|
||||
termination_condition=MaxMessageTermination(5),
|
||||
)
|
||||
|
||||
result = await team.run(task="Hello")
|
||||
assert any(m.source == "X" and m.content == "Hello" for m in result.messages) # type: ignore[union-attr]
|
||||
assert result.stop_reason is not None
|
||||
# Test with a message containing "long" - should go to B
|
||||
result = await team.run(task="This is a long message")
|
||||
assert result.messages[2].source == "B"
|
||||
|
||||
# Reset for next test
|
||||
await team.reset()
|
||||
|
||||
# Test with a message containing "short" - should go to C
|
||||
result = await team.run(task="This is a short message")
|
||||
assert result.messages[2].source == "C"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -1268,3 +1323,36 @@ async def test_graph_flow_stateful_pause_and_resume_with_termination() -> None:
|
||||
assert len(result.messages) == 2
|
||||
assert result.messages[0].source == "B"
|
||||
assert result.messages[1].source == _DIGRAPH_STOP_AGENT_NAME
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_builder_with_lambda_condition(runtime: AgentRuntime | None) -> None:
|
||||
"""Test that DiGraphBuilder supports string conditions."""
|
||||
agent_a = _EchoAgent("A", description="Echo agent A")
|
||||
agent_b = _EchoAgent("B", description="Echo agent B")
|
||||
agent_c = _EchoAgent("C", description="Echo agent C")
|
||||
|
||||
builder = DiGraphBuilder()
|
||||
builder.add_node(agent_a).add_node(agent_b).add_node(agent_c)
|
||||
|
||||
# Using callable conditions
|
||||
builder.add_edge(agent_a, agent_b, lambda msg: "even" in msg.to_model_text())
|
||||
builder.add_edge(agent_a, agent_c, lambda msg: "odd" in msg.to_model_text())
|
||||
|
||||
team = GraphFlow(
|
||||
participants=builder.get_participants(),
|
||||
graph=builder.build(),
|
||||
runtime=runtime,
|
||||
termination_condition=MaxMessageTermination(5),
|
||||
)
|
||||
|
||||
# Test with "even" in message - should go to B
|
||||
result = await team.run(task="even length")
|
||||
assert result.messages[2].source == "B"
|
||||
|
||||
# Reset for next test
|
||||
await team.reset()
|
||||
|
||||
# Test with "odd" in message - should go to C
|
||||
result = await team.run(task="odd message")
|
||||
assert result.messages[2].source == "C"
|
||||
|
@ -414,7 +414,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"execution_count": null,
|
||||
"id": "af297db2",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@ -591,7 +591,7 @@
|
||||
"reviewer = AssistantAgent(\n",
|
||||
" \"reviewer\",\n",
|
||||
" model_client=model_client,\n",
|
||||
" system_message=\"Review ideas and say 'REVISE' and provide feedbacks, or 'APPROVE' for final approval.\",\n",
|
||||
" system_message=\"Review ideas and provide feedbacks, or just 'APPROVE' for final approval.\",\n",
|
||||
")\n",
|
||||
"summarizer_core = AssistantAgent(\n",
|
||||
" \"summary\", model_client=model_client, system_message=\"Summarize the user request and the final feedback.\"\n",
|
||||
@ -613,8 +613,8 @@
|
||||
"builder = DiGraphBuilder()\n",
|
||||
"builder.add_node(generator).add_node(reviewer).add_node(filtered_summarizer)\n",
|
||||
"builder.add_edge(generator, reviewer)\n",
|
||||
"builder.add_edge(reviewer, generator, condition=\"REVISE\")\n",
|
||||
"builder.add_edge(reviewer, filtered_summarizer, condition=\"APPROVE\")\n",
|
||||
"builder.add_edge(reviewer, filtered_summarizer, condition=lambda msg: \"APPROVE\" in msg.to_model_text())\n",
|
||||
"builder.add_edge(reviewer, generator, condition=lambda msg: \"APPROVE\" not in msg.to_model_text())\n",
|
||||
"builder.set_entry_point(generator) # Set entry point to generator. Required if there are no source nodes.\n",
|
||||
"graph = builder.build()\n",
|
||||
"\n",
|
||||
|
Loading…
x
Reference in New Issue
Block a user