diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_graph/_digraph_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_graph/_digraph_group_chat.py index 58069ae2e..ffd2a2764 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_graph/_digraph_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_graph/_digraph_group_chat.py @@ -37,10 +37,9 @@ class DiGraphEdge(BaseModel): """ target: str # Target node name - condition: Union[str, Callable[[BaseChatMessage], bool], None] = None + condition: Callable[[BaseChatMessage], bool] | None = None """(Experimental) Condition to execute this edge. If None, the edge is unconditional. - If a string, the edge is conditional on the presence of that string in the last agent chat message. If a callable, the edge is conditional on the callable returning True when given the last message. """ @@ -49,7 +48,7 @@ class DiGraphEdge(BaseModel): @model_validator(mode='after') def _validate_condition(self) -> 'DiGraphEdge': - # Store callable in a separate field and set condition to a string marker + # Store callable in a separate field and set condition to None for serialization if callable(self.condition): self._condition_function = self.condition # For serialization purposes, we'll set the condition to None @@ -69,9 +68,6 @@ class DiGraphEdge(BaseModel): """ if self._condition_function is not None: return self._condition_function(message) - elif isinstance(self.condition, str): - # If it's a string, check if the string is in the message content - return self.condition in message.to_model_text() return True # None condition is always satisfied @@ -155,7 +151,7 @@ class DiGraph(BaseModel): cycle_edges: List[DiGraphEdge] = [] for n in cycle_nodes: cycle_edges.extend(self.nodes[n].edges) - if not any(edge.condition for edge in cycle_edges): + if not any(edge._condition_function is not None for edge in cycle_edges): raise ValueError( f"Cycle detected without exit condition: {' -> '.join(cycle_nodes + cycle_nodes[:1])}" ) @@ -194,8 +190,8 @@ class DiGraph(BaseModel): # Outgoing edge condition validation (per node) for node in self.nodes.values(): # Check that if a node has an outgoing conditional edge, then all outgoing edges are conditional - has_condition = any(edge.condition for edge in node.edges) - has_unconditioned = any(edge.condition is None for edge in node.edges) + has_condition = any(edge._condition_function is not None for edge in node.edges) + has_unconditioned = any(edge._condition_function is None and edge.condition is None for edge in node.edges) if has_condition and has_unconditioned: raise ValueError(f"Node '{node.name}' has a mix of conditional and unconditional edges.") @@ -507,8 +503,11 @@ class GraphFlow(BaseGroupChat, Component[GraphFlowConfig]): # Create a directed graph with conditional branching flow A -> B ("yes"), A -> C ("no"). builder = DiGraphBuilder() builder.add_node(agent_a).add_node(agent_b).add_node(agent_c) - builder.add_edge(agent_a, agent_b, condition="yes") - builder.add_edge(agent_a, agent_c, condition="no") + # Create lambda functions to check for specific words in messages + yes_condition = lambda msg: "yes" in msg.to_model_text().lower() + no_condition = lambda msg: "no" in msg.to_model_text().lower() + builder.add_edge(agent_a, agent_b, condition=yes_condition) + builder.add_edge(agent_a, agent_c, condition=no_condition) graph = builder.build() # Create a GraphFlow team with the directed graph. @@ -559,7 +558,13 @@ class GraphFlow(BaseGroupChat, Component[GraphFlowConfig]): builder = DiGraphBuilder() builder.add_node(agent_a).add_node(agent_b).add_node(agent_c) builder.add_edge(agent_a, agent_b) - builder.add_conditional_edges(agent_b, {"APPROVE": agent_c, "REJECT": agent_a}) + + # Create conditional edges using keyword-based lambdas + approve_condition = lambda msg: "APPROVE" in msg.to_model_text() + reject_condition = lambda msg: "REJECT" in msg.to_model_text() + builder.add_edge(agent_b, agent_c, condition=approve_condition) + builder.add_edge(agent_b, agent_a, condition=reject_condition) + builder.set_entry_point(agent_a) graph = builder.build() diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_graph/_graph_builder.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_graph/_graph_builder.py index b01c81354..72d6c33c3 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_graph/_graph_builder.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_graph/_graph_builder.py @@ -23,7 +23,7 @@ class DiGraphBuilder: - Cyclic loops with safe exits Each node in the graph represents an agent. Edges define execution paths between agents, - and can optionally be conditioned on message content or custom callable conditions. + and can optionally be conditioned on message content using callable functions. The builder is compatible with the `Graph` runner and supports both standard and filtered agents. @@ -50,9 +50,10 @@ class DiGraphBuilder: >>> builder.add_node(agent_a).add_node(agent_b).add_node(agent_c) >>> builder.add_edge(agent_a, agent_b).add_edge(agent_a, agent_c) - Example — Conditional Branching A → B ("yes"), A → C ("no"): + Example — Conditional Branching A → B or A → C: >>> builder = GraphBuilder() >>> builder.add_node(agent_a).add_node(agent_b).add_node(agent_c) + >>> # Add conditional edges using keyword check lambdas >>> builder.add_conditional_edges(agent_a, {"yes": agent_b, "no": agent_c}) Example — Using Custom Callable Conditions: @@ -65,7 +66,7 @@ class DiGraphBuilder: >>> builder.add_edge(agent_a, agent_c, ... lambda msg: "error" in msg.to_model_text().lower()) - Example — Loop: A → B → A ("loop"), B → C ("exit"): + Example — Loop: A → B → A or B → C: >>> builder = GraphBuilder() >>> builder.add_node(agent_a).add_node(agent_b).add_node(agent_c) >>> builder.add_edge(agent_a, agent_b) @@ -89,7 +90,7 @@ class DiGraphBuilder: return self def add_edge( - self, source: Union[str, ChatAgent], target: Union[str, ChatAgent], condition: Optional[Union[str, Callable[[BaseChatMessage], bool]]] = None + self, source: Union[str, ChatAgent], target: Union[str, ChatAgent], condition: Optional[Callable[[BaseChatMessage], bool]] = None ) -> "DiGraphBuilder": """Add a directed edge from source to target, optionally with a condition. @@ -97,7 +98,6 @@ class DiGraphBuilder: source: Source node (agent name or agent object) target: Target node (agent name or agent object) condition: Optional condition for edge activation. - If string, activates when substring is found in message. If callable, activates when function returns True for the message. Returns: @@ -120,19 +120,24 @@ class DiGraphBuilder: def add_conditional_edges( self, source: Union[str, ChatAgent], condition_to_target: Dict[str, Union[str, ChatAgent]] ) -> "DiGraphBuilder": - """Add multiple conditional edges from a source node based on condition strings. + """Add multiple conditional edges from a source node based on keyword checks. Args: source: Source node (agent name or agent object) condition_to_target: Mapping from condition strings to target nodes - Each key is a condition string that must be present in the message + Each key is a keyword that will be checked in the message content Each value is the target node to activate when condition is met + + For each key (keyword), a lambda will be created that checks + if the keyword is in the message text. Returns: Self for method chaining """ - for condition, target in condition_to_target.items(): - self.add_edge(source, target, condition) + for condition_keyword, target in condition_to_target.items(): + # Create a lambda that checks if keyword is in message + condition_func = lambda msg, kw=condition_keyword: kw in msg.to_model_text() + self.add_edge(source, target, condition_func) return self def set_entry_point(self, name: Union[str, ChatAgent]) -> "DiGraphBuilder": diff --git a/python/packages/autogen-agentchat/tests/test_group_chat_graph.py b/python/packages/autogen-agentchat/tests/test_group_chat_graph.py index 3f0dd17cc..82c25aede 100644 --- a/python/packages/autogen-agentchat/tests/test_group_chat_graph.py +++ b/python/packages/autogen-agentchat/tests/test_group_chat_graph.py @@ -1175,8 +1175,8 @@ async def test_digraph_group_chat_callable_condition(runtime: AgentRuntime | Non edges=[ # Will go to B if message has >5 chars DiGraphEdge(target="B", condition=check_message_length), - # Will go to C if message has <=5 chars (handled by adding edge without condition) - DiGraphEdge(target="C"), + # Will go to C if message has <=5 chars + DiGraphEdge(target="C", condition=lambda msg: len(msg.to_model_text()) <= 5), ] ), "B": DiGraphNode(name="B", edges=[]),