Remove string option from GraphFlow edges conditions

Co-authored-by: ekzhu <320302+ekzhu@users.noreply.github.com>
This commit is contained in:
copilot-swe-agent[bot] 2025-05-20 22:01:45 +00:00
parent cbb0a5dff6
commit e2fdb7e5d6
3 changed files with 33 additions and 23 deletions

View File

@ -37,10 +37,9 @@ class DiGraphEdge(BaseModel):
"""
target: str # Target node name
condition: Union[str, Callable[[BaseChatMessage], bool], None] = None
condition: Callable[[BaseChatMessage], bool] | None = None
"""(Experimental) Condition to execute this edge.
If None, the edge is unconditional.
If a string, the edge is conditional on the presence of that string in the last agent chat message.
If a callable, the edge is conditional on the callable returning True when given the last message.
"""
@ -49,7 +48,7 @@ class DiGraphEdge(BaseModel):
@model_validator(mode='after')
def _validate_condition(self) -> 'DiGraphEdge':
# Store callable in a separate field and set condition to a string marker
# Store callable in a separate field and set condition to None for serialization
if callable(self.condition):
self._condition_function = self.condition
# For serialization purposes, we'll set the condition to None
@ -69,9 +68,6 @@ class DiGraphEdge(BaseModel):
"""
if self._condition_function is not None:
return self._condition_function(message)
elif isinstance(self.condition, str):
# If it's a string, check if the string is in the message content
return self.condition in message.to_model_text()
return True # None condition is always satisfied
@ -155,7 +151,7 @@ class DiGraph(BaseModel):
cycle_edges: List[DiGraphEdge] = []
for n in cycle_nodes:
cycle_edges.extend(self.nodes[n].edges)
if not any(edge.condition for edge in cycle_edges):
if not any(edge._condition_function is not None for edge in cycle_edges):
raise ValueError(
f"Cycle detected without exit condition: {' -> '.join(cycle_nodes + cycle_nodes[:1])}"
)
@ -194,8 +190,8 @@ class DiGraph(BaseModel):
# Outgoing edge condition validation (per node)
for node in self.nodes.values():
# Check that if a node has an outgoing conditional edge, then all outgoing edges are conditional
has_condition = any(edge.condition for edge in node.edges)
has_unconditioned = any(edge.condition is None for edge in node.edges)
has_condition = any(edge._condition_function is not None for edge in node.edges)
has_unconditioned = any(edge._condition_function is None and edge.condition is None for edge in node.edges)
if has_condition and has_unconditioned:
raise ValueError(f"Node '{node.name}' has a mix of conditional and unconditional edges.")
@ -507,8 +503,11 @@ class GraphFlow(BaseGroupChat, Component[GraphFlowConfig]):
# Create a directed graph with conditional branching flow A -> B ("yes"), A -> C ("no").
builder = DiGraphBuilder()
builder.add_node(agent_a).add_node(agent_b).add_node(agent_c)
builder.add_edge(agent_a, agent_b, condition="yes")
builder.add_edge(agent_a, agent_c, condition="no")
# Create lambda functions to check for specific words in messages
yes_condition = lambda msg: "yes" in msg.to_model_text().lower()
no_condition = lambda msg: "no" in msg.to_model_text().lower()
builder.add_edge(agent_a, agent_b, condition=yes_condition)
builder.add_edge(agent_a, agent_c, condition=no_condition)
graph = builder.build()
# Create a GraphFlow team with the directed graph.
@ -559,7 +558,13 @@ class GraphFlow(BaseGroupChat, Component[GraphFlowConfig]):
builder = DiGraphBuilder()
builder.add_node(agent_a).add_node(agent_b).add_node(agent_c)
builder.add_edge(agent_a, agent_b)
builder.add_conditional_edges(agent_b, {"APPROVE": agent_c, "REJECT": agent_a})
# Create conditional edges using keyword-based lambdas
approve_condition = lambda msg: "APPROVE" in msg.to_model_text()
reject_condition = lambda msg: "REJECT" in msg.to_model_text()
builder.add_edge(agent_b, agent_c, condition=approve_condition)
builder.add_edge(agent_b, agent_a, condition=reject_condition)
builder.set_entry_point(agent_a)
graph = builder.build()

View File

@ -23,7 +23,7 @@ class DiGraphBuilder:
- Cyclic loops with safe exits
Each node in the graph represents an agent. Edges define execution paths between agents,
and can optionally be conditioned on message content or custom callable conditions.
and can optionally be conditioned on message content using callable functions.
The builder is compatible with the `Graph` runner and supports both standard and filtered agents.
@ -50,9 +50,10 @@ class DiGraphBuilder:
>>> builder.add_node(agent_a).add_node(agent_b).add_node(agent_c)
>>> builder.add_edge(agent_a, agent_b).add_edge(agent_a, agent_c)
Example Conditional Branching A B ("yes"), A C ("no"):
Example Conditional Branching A B or A C:
>>> builder = GraphBuilder()
>>> builder.add_node(agent_a).add_node(agent_b).add_node(agent_c)
>>> # Add conditional edges using keyword check lambdas
>>> builder.add_conditional_edges(agent_a, {"yes": agent_b, "no": agent_c})
Example Using Custom Callable Conditions:
@ -65,7 +66,7 @@ class DiGraphBuilder:
>>> builder.add_edge(agent_a, agent_c,
... lambda msg: "error" in msg.to_model_text().lower())
Example Loop: A B A ("loop"), B C ("exit"):
Example Loop: A B A or B C:
>>> builder = GraphBuilder()
>>> builder.add_node(agent_a).add_node(agent_b).add_node(agent_c)
>>> builder.add_edge(agent_a, agent_b)
@ -89,7 +90,7 @@ class DiGraphBuilder:
return self
def add_edge(
self, source: Union[str, ChatAgent], target: Union[str, ChatAgent], condition: Optional[Union[str, Callable[[BaseChatMessage], bool]]] = None
self, source: Union[str, ChatAgent], target: Union[str, ChatAgent], condition: Optional[Callable[[BaseChatMessage], bool]] = None
) -> "DiGraphBuilder":
"""Add a directed edge from source to target, optionally with a condition.
@ -97,7 +98,6 @@ class DiGraphBuilder:
source: Source node (agent name or agent object)
target: Target node (agent name or agent object)
condition: Optional condition for edge activation.
If string, activates when substring is found in message.
If callable, activates when function returns True for the message.
Returns:
@ -120,19 +120,24 @@ class DiGraphBuilder:
def add_conditional_edges(
self, source: Union[str, ChatAgent], condition_to_target: Dict[str, Union[str, ChatAgent]]
) -> "DiGraphBuilder":
"""Add multiple conditional edges from a source node based on condition strings.
"""Add multiple conditional edges from a source node based on keyword checks.
Args:
source: Source node (agent name or agent object)
condition_to_target: Mapping from condition strings to target nodes
Each key is a condition string that must be present in the message
Each key is a keyword that will be checked in the message content
Each value is the target node to activate when condition is met
For each key (keyword), a lambda will be created that checks
if the keyword is in the message text.
Returns:
Self for method chaining
"""
for condition, target in condition_to_target.items():
self.add_edge(source, target, condition)
for condition_keyword, target in condition_to_target.items():
# Create a lambda that checks if keyword is in message
condition_func = lambda msg, kw=condition_keyword: kw in msg.to_model_text()
self.add_edge(source, target, condition_func)
return self
def set_entry_point(self, name: Union[str, ChatAgent]) -> "DiGraphBuilder":

View File

@ -1175,8 +1175,8 @@ async def test_digraph_group_chat_callable_condition(runtime: AgentRuntime | Non
edges=[
# Will go to B if message has >5 chars
DiGraphEdge(target="B", condition=check_message_length),
# Will go to C if message has <=5 chars (handled by adding edge without condition)
DiGraphEdge(target="C"),
# Will go to C if message has <=5 chars
DiGraphEdge(target="C", condition=lambda msg: len(msg.to_model_text()) <= 5),
]
),
"B": DiGraphNode(name="B", edges=[]),