fix: fix self-loop in workflow (#6677)

This commit is contained in:
Zen 2025-06-16 14:00:14 +08:00 committed by GitHub
parent 8c1236dd9e
commit cd15c0853c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 61 additions and 1 deletions

View File

@ -112,7 +112,8 @@ class DiGraph(BaseModel):
parents: Dict[str, List[str]] = {node: [] for node in self.nodes} parents: Dict[str, List[str]] = {node: [] for node in self.nodes}
for node in self.nodes.values(): for node in self.nodes.values():
for edge in node.edges: for edge in node.edges:
parents[edge.target].append(node.name) if edge.target != node.name:
parents[edge.target].append(node.name)
return parents return parents
def get_start_nodes(self) -> Set[str]: def get_start_nodes(self) -> Set[str]:

View File

@ -718,6 +718,65 @@ async def test_digraph_group_chat_loop_with_exit_condition(runtime: AgentRuntime
assert result.messages[-1].source == _DIGRAPH_STOP_AGENT_NAME assert result.messages[-1].source == _DIGRAPH_STOP_AGENT_NAME
@pytest.mark.asyncio
async def test_digraph_group_chat_loop_with_exit_condition_2(runtime: AgentRuntime | None) -> None:
# Agents A and C: Echo Agents
agent_a = _EchoAgent("A", description="Echo agent A")
agent_c = _EchoAgent("C", description="Echo agent C")
# Replay model client for agent B
model_client = ReplayChatCompletionClient(
chat_completions=[
"loop", # First time B will ask to loop
"loop", # Second time B will ask to loop
"exit", # Third time B will say exit
]
)
# Agent B: Assistant Agent using Replay Client
agent_b = AssistantAgent("B", description="Decision agent B", model_client=model_client)
# DiGraph: A → B(self loop) → C (conditional back to A or terminate)
graph = DiGraph(
nodes={
"A": DiGraphNode(name="A", edges=[DiGraphEdge(target="B")]),
"B": DiGraphNode(
name="B", edges=[DiGraphEdge(target="C", condition="exit"), DiGraphEdge(target="B", condition="loop")]
),
"C": DiGraphNode(name="C", edges=[]),
},
default_start_node="A",
)
team = GraphFlow(
participants=[agent_a, agent_b, agent_c],
graph=graph,
runtime=runtime,
termination_condition=MaxMessageTermination(20),
)
# Run
result = await team.run(task="Start")
# Assert message order
expected_sources = [
"user",
"A",
"B", # 1st loop
"B", # 2nd loop
"B",
"C",
_DIGRAPH_STOP_AGENT_NAME,
]
actual_sources = [m.source for m in result.messages]
assert actual_sources == expected_sources
assert result.stop_reason is not None
assert result.messages[-2].source == "C"
assert any(m.content == "exit" for m in result.messages[:-1]) # type: ignore[attr-defined,union-attr]
assert result.messages[-1].source == _DIGRAPH_STOP_AGENT_NAME
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_digraph_group_chat_parallel_join_any_1(runtime: AgentRuntime | None) -> None: async def test_digraph_group_chat_parallel_join_any_1(runtime: AgentRuntime | None) -> None:
agent_a = _EchoAgent("A", description="Echo agent A") agent_a = _EchoAgent("A", description="Echo agent A")