mirror of
https://github.com/microsoft/autogen.git
synced 2025-12-27 15:09:41 +00:00
Fix tests to use callables instead of strings in GraphFlow edge conditions
Co-authored-by: ekzhu <320302+ekzhu@users.noreply.github.com>
This commit is contained in:
parent
637417dcc2
commit
e6cb014614
@ -98,9 +98,12 @@ def test_get_leaf_nodes() -> None:
|
||||
|
||||
def test_serialization() -> None:
|
||||
"""Test serializing and deserializing the graph."""
|
||||
# Create a lambda function condition instead of a string
|
||||
trigger_condition = lambda msg: "trigger1" in msg.to_model_text()
|
||||
|
||||
graph = DiGraph(
|
||||
nodes={
|
||||
"A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B", condition="trigger1")]),
|
||||
"A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B", condition=trigger_condition)]),
|
||||
"B": DiGraphNode(name="B", edges=[DiGraphEdge(target="C")]),
|
||||
"C": DiGraphNode(name="C", edges=[]),
|
||||
}
|
||||
@ -110,8 +113,13 @@ def test_serialization() -> None:
|
||||
deserialized_graph = DiGraph.model_validate_json(serialized)
|
||||
|
||||
assert deserialized_graph.nodes["A"].edges[0].target == "B"
|
||||
assert deserialized_graph.nodes["A"].edges[0].condition == "trigger1"
|
||||
# Condition should be None in serialized form since callables can't be serialized directly
|
||||
assert deserialized_graph.nodes["A"].edges[0].condition is None
|
||||
assert deserialized_graph.nodes["B"].edges[0].target == "C"
|
||||
|
||||
# Test the original condition works
|
||||
test_msg = TextMessage(content="this has trigger1 in it", source="test")
|
||||
assert graph.nodes["A"].edges[0].check_condition(test_msg)
|
||||
|
||||
|
||||
def test_invalid_graph_no_start_node() -> None:
|
||||
@ -143,15 +151,23 @@ def test_invalid_graph_no_leaf_node() -> None:
|
||||
|
||||
def test_condition_edge_execution() -> None:
|
||||
"""Test conditional edge execution support."""
|
||||
# Use a lambda function instead of a string condition
|
||||
trigger_condition = lambda msg: "TRIGGER" in msg.to_model_text()
|
||||
|
||||
graph = DiGraph(
|
||||
nodes={
|
||||
"A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B", condition="TRIGGER")]),
|
||||
"A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B", condition=trigger_condition)]),
|
||||
"B": DiGraphNode(name="B", edges=[DiGraphEdge(target="C")]),
|
||||
"C": DiGraphNode(name="C", edges=[]),
|
||||
}
|
||||
)
|
||||
|
||||
assert graph.nodes["A"].edges[0].condition == "TRIGGER"
|
||||
# Check the condition actually works as expected
|
||||
test_message = TextMessage(content="This has TRIGGER in it", source="test")
|
||||
non_match_message = TextMessage(content="This doesn't match", source="test")
|
||||
|
||||
assert graph.nodes["A"].edges[0].check_condition(test_message)
|
||||
assert not graph.nodes["A"].edges[0].check_condition(non_match_message)
|
||||
assert graph.nodes["B"].edges[0].condition is None
|
||||
|
||||
|
||||
@ -192,11 +208,14 @@ def test_cycle_detection_no_cycle() -> None:
|
||||
|
||||
def test_cycle_detection_with_exit_condition() -> None:
|
||||
"""Test a graph with cycle and conditional exit passes validation."""
|
||||
# Use a lambda condition instead of a string
|
||||
exit_condition = lambda msg: "exit" in msg.to_model_text()
|
||||
|
||||
graph = DiGraph(
|
||||
nodes={
|
||||
"A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]),
|
||||
"B": DiGraphNode(name="B", edges=[DiGraphEdge(target="C")]),
|
||||
"C": DiGraphNode(name="C", edges=[DiGraphEdge(target="A", condition="exit")]), # Cycle with condition
|
||||
"C": DiGraphNode(name="C", edges=[DiGraphEdge(target="A", condition=exit_condition)]), # Cycle with condition
|
||||
}
|
||||
)
|
||||
assert graph.has_cycles_with_exit()
|
||||
@ -257,9 +276,12 @@ def test_validate_graph_missing_leaf_node() -> None:
|
||||
|
||||
def test_validate_graph_mixed_conditions() -> None:
|
||||
"""Test validation failure when node has mixed conditional and unconditional edges."""
|
||||
# Use lambda instead of string for condition
|
||||
cond_function = lambda msg: "cond" in msg.to_model_text()
|
||||
|
||||
graph = DiGraph(
|
||||
nodes={
|
||||
"A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B", condition="cond"), DiGraphEdge(target="C")]),
|
||||
"A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B", condition=cond_function), DiGraphEdge(target="C")]),
|
||||
"B": DiGraphNode(name="B", edges=[]),
|
||||
"C": DiGraphNode(name="C", edges=[]),
|
||||
}
|
||||
@ -551,10 +573,14 @@ async def test_digraph_group_chat_conditional_branch(runtime: AgentRuntime | Non
|
||||
agent_b = _EchoAgent("B", description="Echo agent B")
|
||||
agent_c = _EchoAgent("C", description="Echo agent C")
|
||||
|
||||
# Use lambda functions instead of strings
|
||||
yes_condition = lambda msg: "yes" in msg.to_model_text().lower()
|
||||
no_condition = lambda msg: "no" in msg.to_model_text().lower()
|
||||
|
||||
graph = DiGraph(
|
||||
nodes={
|
||||
"A": DiGraphNode(
|
||||
name="A", edges=[DiGraphEdge(target="B", condition="yes"), DiGraphEdge(target="C", condition="no")]
|
||||
name="A", edges=[DiGraphEdge(target="B", condition=yes_condition), DiGraphEdge(target="C", condition=no_condition)]
|
||||
),
|
||||
"B": DiGraphNode(name="B", edges=[], activation="any"),
|
||||
"C": DiGraphNode(name="C", edges=[], activation="any"),
|
||||
@ -589,12 +615,16 @@ async def test_digraph_group_chat_loop_with_exit_condition(runtime: AgentRuntime
|
||||
# Agent B: Assistant Agent using Replay Client
|
||||
agent_b = AssistantAgent("B", description="Decision agent B", model_client=model_client)
|
||||
|
||||
# Create lambda conditions instead of strings
|
||||
loop_condition = lambda msg: "loop" in msg.to_model_text().lower()
|
||||
exit_condition = lambda msg: "exit" in msg.to_model_text().lower()
|
||||
|
||||
# DiGraph: A → B → C (conditional back to A or terminate)
|
||||
graph = DiGraph(
|
||||
nodes={
|
||||
"A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]),
|
||||
"B": DiGraphNode(
|
||||
name="B", edges=[DiGraphEdge(target="C", condition="exit"), DiGraphEdge(target="A", condition="loop")]
|
||||
name="B", edges=[DiGraphEdge(target="C", condition=exit_condition), DiGraphEdge(target="A", condition=loop_condition)]
|
||||
),
|
||||
"C": DiGraphNode(name="C", edges=[]),
|
||||
},
|
||||
@ -725,14 +755,19 @@ async def test_digraph_group_chat_multiple_conditional(runtime: AgentRuntime | N
|
||||
agent_c = _EchoAgent("C", description="Echo agent C")
|
||||
agent_d = _EchoAgent("D", description="Echo agent D")
|
||||
|
||||
# Use lambda functions for conditions
|
||||
apple_condition = lambda msg: "apple" in msg.to_model_text().lower()
|
||||
banana_condition = lambda msg: "banana" in msg.to_model_text().lower()
|
||||
cherry_condition = lambda msg: "cherry" in msg.to_model_text().lower()
|
||||
|
||||
graph = DiGraph(
|
||||
nodes={
|
||||
"A": DiGraphNode(
|
||||
name="A",
|
||||
edges=[
|
||||
DiGraphEdge(target="B", condition="apple"),
|
||||
DiGraphEdge(target="C", condition="banana"),
|
||||
DiGraphEdge(target="D", condition="cherry"),
|
||||
DiGraphEdge(target="B", condition=apple_condition),
|
||||
DiGraphEdge(target="C", condition=banana_condition),
|
||||
DiGraphEdge(target="D", condition=cherry_condition),
|
||||
],
|
||||
),
|
||||
"B": DiGraphNode(name="B", edges=[]),
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user