mirror of
https://github.com/microsoft/autogen.git
synced 2025-08-15 04:01:26 +00:00
parent
1b32eb660d
commit
c99aa7416d
@ -198,8 +198,10 @@ class DiGraph(BaseModel):
|
|||||||
# Outgoing edge condition validation (per node)
|
# Outgoing edge condition validation (per node)
|
||||||
for node in self.nodes.values():
|
for node in self.nodes.values():
|
||||||
# Check that if a node has an outgoing conditional edge, then all outgoing edges are conditional
|
# Check that if a node has an outgoing conditional edge, then all outgoing edges are conditional
|
||||||
has_condition = any(edge.condition is not None for edge in node.edges)
|
has_condition = any(
|
||||||
has_unconditioned = any(edge.condition is None for edge in node.edges)
|
edge.condition is not None or edge.condition_function is not None for edge in node.edges
|
||||||
|
)
|
||||||
|
has_unconditioned = any(edge.condition is None and edge.condition_function is None for edge in node.edges)
|
||||||
if has_condition and has_unconditioned:
|
if has_condition and has_unconditioned:
|
||||||
raise ValueError(f"Node '{node.name}' has a mix of conditional and unconditional edges.")
|
raise ValueError(f"Node '{node.name}' has a mix of conditional and unconditional edges.")
|
||||||
|
|
||||||
|
@ -259,6 +259,19 @@ def test_validate_graph_success() -> None:
|
|||||||
graph.graph_validate()
|
graph.graph_validate()
|
||||||
assert not graph.get_has_cycles()
|
assert not graph.get_has_cycles()
|
||||||
|
|
||||||
|
# Use a lambda condition
|
||||||
|
graph_with_lambda = DiGraph(
|
||||||
|
nodes={
|
||||||
|
"A": DiGraphNode(
|
||||||
|
name="A", edges=[DiGraphEdge(target="B", condition=lambda msg: "test" in msg.to_model_text())]
|
||||||
|
),
|
||||||
|
"B": DiGraphNode(name="B", edges=[]),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
# No error should be raised
|
||||||
|
graph_with_lambda.graph_validate()
|
||||||
|
assert not graph_with_lambda.get_has_cycles()
|
||||||
|
|
||||||
|
|
||||||
def test_validate_graph_missing_start_node() -> None:
|
def test_validate_graph_missing_start_node() -> None:
|
||||||
"""Test validation failure when no start node exists."""
|
"""Test validation failure when no start node exists."""
|
||||||
@ -298,6 +311,23 @@ def test_validate_graph_mixed_conditions() -> None:
|
|||||||
with pytest.raises(ValueError, match="Node 'A' has a mix of conditional and unconditional edges"):
|
with pytest.raises(ValueError, match="Node 'A' has a mix of conditional and unconditional edges"):
|
||||||
graph.graph_validate()
|
graph.graph_validate()
|
||||||
|
|
||||||
|
# Use lambda for condition
|
||||||
|
graph_with_lambda = DiGraph(
|
||||||
|
nodes={
|
||||||
|
"A": DiGraphNode(
|
||||||
|
name="A",
|
||||||
|
edges=[
|
||||||
|
DiGraphEdge(target="B", condition=lambda msg: "test" in msg.to_model_text()),
|
||||||
|
DiGraphEdge(target="C"),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
"B": DiGraphNode(name="B", edges=[]),
|
||||||
|
"C": DiGraphNode(name="C", edges=[]),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
with pytest.raises(ValueError, match="Node 'A' has a mix of conditional and unconditional edges"):
|
||||||
|
graph_with_lambda.graph_validate()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_invalid_digraph_manager_cycle_without_termination() -> None:
|
async def test_invalid_digraph_manager_cycle_without_termination() -> None:
|
||||||
@ -603,6 +633,29 @@ async def test_digraph_group_chat_conditional_branch(runtime: AgentRuntime | Non
|
|||||||
result = await team.run(task="Trigger yes")
|
result = await team.run(task="Trigger yes")
|
||||||
assert result.messages[2].source == "B"
|
assert result.messages[2].source == "B"
|
||||||
|
|
||||||
|
# Use lambda conditions
|
||||||
|
graph_with_lambda = DiGraph(
|
||||||
|
nodes={
|
||||||
|
"A": DiGraphNode(
|
||||||
|
name="A",
|
||||||
|
edges=[
|
||||||
|
DiGraphEdge(target="B", condition=lambda msg: "yes" in msg.to_model_text()),
|
||||||
|
DiGraphEdge(target="C", condition=lambda msg: "no" in msg.to_model_text()),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
"B": DiGraphNode(name="B", edges=[], activation="any"),
|
||||||
|
"C": DiGraphNode(name="C", edges=[], activation="any"),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
team_with_lambda = GraphFlow(
|
||||||
|
participants=[agent_a, agent_b, agent_c],
|
||||||
|
graph=graph_with_lambda,
|
||||||
|
runtime=runtime,
|
||||||
|
termination_condition=MaxMessageTermination(5),
|
||||||
|
)
|
||||||
|
result_with_lambda = await team_with_lambda.run(task="Trigger no")
|
||||||
|
assert result_with_lambda.messages[2].source == "C"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_digraph_group_chat_loop_with_exit_condition(runtime: AgentRuntime | None) -> None:
|
async def test_digraph_group_chat_loop_with_exit_condition(runtime: AgentRuntime | None) -> None:
|
||||||
@ -785,6 +838,31 @@ async def test_digraph_group_chat_multiple_conditional(runtime: AgentRuntime | N
|
|||||||
result = await team.run(task="banana")
|
result = await team.run(task="banana")
|
||||||
assert result.messages[2].source == "C"
|
assert result.messages[2].source == "C"
|
||||||
|
|
||||||
|
# Use lambda conditions
|
||||||
|
graph_with_lambda = DiGraph(
|
||||||
|
nodes={
|
||||||
|
"A": DiGraphNode(
|
||||||
|
name="A",
|
||||||
|
edges=[
|
||||||
|
DiGraphEdge(target="B", condition=lambda msg: "apple" in msg.to_model_text()),
|
||||||
|
DiGraphEdge(target="C", condition=lambda msg: "banana" in msg.to_model_text()),
|
||||||
|
DiGraphEdge(target="D", condition=lambda msg: "cherry" in msg.to_model_text()),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
"B": DiGraphNode(name="B", edges=[]),
|
||||||
|
"C": DiGraphNode(name="C", edges=[]),
|
||||||
|
"D": DiGraphNode(name="D", edges=[]),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
team_with_lambda = GraphFlow(
|
||||||
|
participants=[agent_a, agent_b, agent_c, agent_d],
|
||||||
|
graph=graph_with_lambda,
|
||||||
|
runtime=runtime,
|
||||||
|
termination_condition=MaxMessageTermination(5),
|
||||||
|
)
|
||||||
|
result_with_lambda = await team_with_lambda.run(task="cherry")
|
||||||
|
assert result_with_lambda.messages[2].source == "D"
|
||||||
|
|
||||||
|
|
||||||
class _TestMessageFilterAgentConfig(BaseModel):
|
class _TestMessageFilterAgentConfig(BaseModel):
|
||||||
name: str
|
name: str
|
||||||
|
Loading…
x
Reference in New Issue
Block a user