mirror of
https://github.com/microsoft/autogen.git
synced 2026-01-08 05:01:59 +00:00
Remove string option from GraphFlow edges conditions
Co-authored-by: ekzhu <320302+ekzhu@users.noreply.github.com>
This commit is contained in:
parent
cbb0a5dff6
commit
e2fdb7e5d6
@ -37,10 +37,9 @@ class DiGraphEdge(BaseModel):
|
||||
"""
|
||||
|
||||
target: str # Target node name
|
||||
condition: Union[str, Callable[[BaseChatMessage], bool], None] = None
|
||||
condition: Callable[[BaseChatMessage], bool] | None = None
|
||||
"""(Experimental) Condition to execute this edge.
|
||||
If None, the edge is unconditional.
|
||||
If a string, the edge is conditional on the presence of that string in the last agent chat message.
|
||||
If a callable, the edge is conditional on the callable returning True when given the last message.
|
||||
"""
|
||||
|
||||
@ -49,7 +48,7 @@ class DiGraphEdge(BaseModel):
|
||||
|
||||
@model_validator(mode='after')
|
||||
def _validate_condition(self) -> 'DiGraphEdge':
|
||||
# Store callable in a separate field and set condition to a string marker
|
||||
# Store callable in a separate field and set condition to None for serialization
|
||||
if callable(self.condition):
|
||||
self._condition_function = self.condition
|
||||
# For serialization purposes, we'll set the condition to None
|
||||
@ -69,9 +68,6 @@ class DiGraphEdge(BaseModel):
|
||||
"""
|
||||
if self._condition_function is not None:
|
||||
return self._condition_function(message)
|
||||
elif isinstance(self.condition, str):
|
||||
# If it's a string, check if the string is in the message content
|
||||
return self.condition in message.to_model_text()
|
||||
return True # None condition is always satisfied
|
||||
|
||||
|
||||
@ -155,7 +151,7 @@ class DiGraph(BaseModel):
|
||||
cycle_edges: List[DiGraphEdge] = []
|
||||
for n in cycle_nodes:
|
||||
cycle_edges.extend(self.nodes[n].edges)
|
||||
if not any(edge.condition for edge in cycle_edges):
|
||||
if not any(edge._condition_function is not None for edge in cycle_edges):
|
||||
raise ValueError(
|
||||
f"Cycle detected without exit condition: {' -> '.join(cycle_nodes + cycle_nodes[:1])}"
|
||||
)
|
||||
@ -194,8 +190,8 @@ class DiGraph(BaseModel):
|
||||
# Outgoing edge condition validation (per node)
|
||||
for node in self.nodes.values():
|
||||
# Check that if a node has an outgoing conditional edge, then all outgoing edges are conditional
|
||||
has_condition = any(edge.condition for edge in node.edges)
|
||||
has_unconditioned = any(edge.condition is None for edge in node.edges)
|
||||
has_condition = any(edge._condition_function is not None for edge in node.edges)
|
||||
has_unconditioned = any(edge._condition_function is None and edge.condition is None for edge in node.edges)
|
||||
if has_condition and has_unconditioned:
|
||||
raise ValueError(f"Node '{node.name}' has a mix of conditional and unconditional edges.")
|
||||
|
||||
@ -507,8 +503,11 @@ class GraphFlow(BaseGroupChat, Component[GraphFlowConfig]):
|
||||
# Create a directed graph with conditional branching flow A -> B ("yes"), A -> C ("no").
|
||||
builder = DiGraphBuilder()
|
||||
builder.add_node(agent_a).add_node(agent_b).add_node(agent_c)
|
||||
builder.add_edge(agent_a, agent_b, condition="yes")
|
||||
builder.add_edge(agent_a, agent_c, condition="no")
|
||||
# Create lambda functions to check for specific words in messages
|
||||
yes_condition = lambda msg: "yes" in msg.to_model_text().lower()
|
||||
no_condition = lambda msg: "no" in msg.to_model_text().lower()
|
||||
builder.add_edge(agent_a, agent_b, condition=yes_condition)
|
||||
builder.add_edge(agent_a, agent_c, condition=no_condition)
|
||||
graph = builder.build()
|
||||
|
||||
# Create a GraphFlow team with the directed graph.
|
||||
@ -559,7 +558,13 @@ class GraphFlow(BaseGroupChat, Component[GraphFlowConfig]):
|
||||
builder = DiGraphBuilder()
|
||||
builder.add_node(agent_a).add_node(agent_b).add_node(agent_c)
|
||||
builder.add_edge(agent_a, agent_b)
|
||||
builder.add_conditional_edges(agent_b, {"APPROVE": agent_c, "REJECT": agent_a})
|
||||
|
||||
# Create conditional edges using keyword-based lambdas
|
||||
approve_condition = lambda msg: "APPROVE" in msg.to_model_text()
|
||||
reject_condition = lambda msg: "REJECT" in msg.to_model_text()
|
||||
builder.add_edge(agent_b, agent_c, condition=approve_condition)
|
||||
builder.add_edge(agent_b, agent_a, condition=reject_condition)
|
||||
|
||||
builder.set_entry_point(agent_a)
|
||||
graph = builder.build()
|
||||
|
||||
|
||||
@ -23,7 +23,7 @@ class DiGraphBuilder:
|
||||
- Cyclic loops with safe exits
|
||||
|
||||
Each node in the graph represents an agent. Edges define execution paths between agents,
|
||||
and can optionally be conditioned on message content or custom callable conditions.
|
||||
and can optionally be conditioned on message content using callable functions.
|
||||
|
||||
The builder is compatible with the `Graph` runner and supports both standard and filtered agents.
|
||||
|
||||
@ -50,9 +50,10 @@ class DiGraphBuilder:
|
||||
>>> builder.add_node(agent_a).add_node(agent_b).add_node(agent_c)
|
||||
>>> builder.add_edge(agent_a, agent_b).add_edge(agent_a, agent_c)
|
||||
|
||||
Example — Conditional Branching A → B ("yes"), A → C ("no"):
|
||||
Example — Conditional Branching A → B or A → C:
|
||||
>>> builder = GraphBuilder()
|
||||
>>> builder.add_node(agent_a).add_node(agent_b).add_node(agent_c)
|
||||
>>> # Add conditional edges using keyword check lambdas
|
||||
>>> builder.add_conditional_edges(agent_a, {"yes": agent_b, "no": agent_c})
|
||||
|
||||
Example — Using Custom Callable Conditions:
|
||||
@ -65,7 +66,7 @@ class DiGraphBuilder:
|
||||
>>> builder.add_edge(agent_a, agent_c,
|
||||
... lambda msg: "error" in msg.to_model_text().lower())
|
||||
|
||||
Example — Loop: A → B → A ("loop"), B → C ("exit"):
|
||||
Example — Loop: A → B → A or B → C:
|
||||
>>> builder = GraphBuilder()
|
||||
>>> builder.add_node(agent_a).add_node(agent_b).add_node(agent_c)
|
||||
>>> builder.add_edge(agent_a, agent_b)
|
||||
@ -89,7 +90,7 @@ class DiGraphBuilder:
|
||||
return self
|
||||
|
||||
def add_edge(
|
||||
self, source: Union[str, ChatAgent], target: Union[str, ChatAgent], condition: Optional[Union[str, Callable[[BaseChatMessage], bool]]] = None
|
||||
self, source: Union[str, ChatAgent], target: Union[str, ChatAgent], condition: Optional[Callable[[BaseChatMessage], bool]] = None
|
||||
) -> "DiGraphBuilder":
|
||||
"""Add a directed edge from source to target, optionally with a condition.
|
||||
|
||||
@ -97,7 +98,6 @@ class DiGraphBuilder:
|
||||
source: Source node (agent name or agent object)
|
||||
target: Target node (agent name or agent object)
|
||||
condition: Optional condition for edge activation.
|
||||
If string, activates when substring is found in message.
|
||||
If callable, activates when function returns True for the message.
|
||||
|
||||
Returns:
|
||||
@ -120,19 +120,24 @@ class DiGraphBuilder:
|
||||
def add_conditional_edges(
|
||||
self, source: Union[str, ChatAgent], condition_to_target: Dict[str, Union[str, ChatAgent]]
|
||||
) -> "DiGraphBuilder":
|
||||
"""Add multiple conditional edges from a source node based on condition strings.
|
||||
"""Add multiple conditional edges from a source node based on keyword checks.
|
||||
|
||||
Args:
|
||||
source: Source node (agent name or agent object)
|
||||
condition_to_target: Mapping from condition strings to target nodes
|
||||
Each key is a condition string that must be present in the message
|
||||
Each key is a keyword that will be checked in the message content
|
||||
Each value is the target node to activate when condition is met
|
||||
|
||||
For each key (keyword), a lambda will be created that checks
|
||||
if the keyword is in the message text.
|
||||
|
||||
Returns:
|
||||
Self for method chaining
|
||||
"""
|
||||
for condition, target in condition_to_target.items():
|
||||
self.add_edge(source, target, condition)
|
||||
for condition_keyword, target in condition_to_target.items():
|
||||
# Create a lambda that checks if keyword is in message
|
||||
condition_func = lambda msg, kw=condition_keyword: kw in msg.to_model_text()
|
||||
self.add_edge(source, target, condition_func)
|
||||
return self
|
||||
|
||||
def set_entry_point(self, name: Union[str, ChatAgent]) -> "DiGraphBuilder":
|
||||
|
||||
@ -1175,8 +1175,8 @@ async def test_digraph_group_chat_callable_condition(runtime: AgentRuntime | Non
|
||||
edges=[
|
||||
# Will go to B if message has >5 chars
|
||||
DiGraphEdge(target="B", condition=check_message_length),
|
||||
# Will go to C if message has <=5 chars (handled by adding edge without condition)
|
||||
DiGraphEdge(target="C"),
|
||||
# Will go to C if message has <=5 chars
|
||||
DiGraphEdge(target="C", condition=lambda msg: len(msg.to_model_text()) <= 5),
|
||||
]
|
||||
),
|
||||
"B": DiGraphNode(name="B", edges=[]),
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user