Fix GraphFlowManager termination to prevent _StopAgent from polluting conversation context (#6752)

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: ekzhu <320302+ekzhu@users.noreply.github.com>
Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
This commit is contained in:
Copilot 2025-07-06 00:57:03 -07:00 committed by GitHub
parent c23b9454a8
commit e10767421f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 63 additions and 80 deletions

View File

@ -2,20 +2,16 @@ import asyncio
from collections import Counter, deque from collections import Counter, deque
from typing import Any, Callable, Deque, Dict, List, Literal, Mapping, Sequence, Set, Union from typing import Any, Callable, Deque, Dict, List, Literal, Mapping, Sequence, Set, Union
from autogen_core import AgentRuntime, CancellationToken, Component, ComponentModel from autogen_core import AgentRuntime, Component, ComponentModel
from pydantic import BaseModel, Field, model_validator from pydantic import BaseModel, Field, model_validator
from typing_extensions import Self from typing_extensions import Self
from autogen_agentchat.agents import BaseChatAgent from autogen_agentchat.base import ChatAgent, TerminationCondition
from autogen_agentchat.base import ChatAgent, OrTerminationCondition, Response, TerminationCondition
from autogen_agentchat.conditions import StopMessageTermination
from autogen_agentchat.messages import ( from autogen_agentchat.messages import (
BaseAgentEvent, BaseAgentEvent,
BaseChatMessage, BaseChatMessage,
ChatMessage,
MessageFactory, MessageFactory,
StopMessage, StopMessage,
TextMessage,
) )
from autogen_agentchat.state import BaseGroupChatManagerState from autogen_agentchat.state import BaseGroupChatManagerState
from autogen_agentchat.teams import BaseGroupChat from autogen_agentchat.teams import BaseGroupChat
@ -23,8 +19,7 @@ from autogen_agentchat.teams import BaseGroupChat
from ..._group_chat._base_group_chat_manager import BaseGroupChatManager from ..._group_chat._base_group_chat_manager import BaseGroupChatManager
from ..._group_chat._events import GroupChatTermination from ..._group_chat._events import GroupChatTermination
_DIGRAPH_STOP_AGENT_NAME = "DiGraphStopAgent" _DIGRAPH_STOP_MESSAGE = "Digraph execution is complete"
_DIGRAPH_STOP_AGENT_MESSAGE = "Digraph execution is complete"
class DiGraphEdge(BaseModel): class DiGraphEdge(BaseModel):
@ -469,17 +464,44 @@ class GraphFlowManager(BaseGroupChatManager):
# Reset the bookkeeping for the specific activation groups that were triggered # Reset the bookkeeping for the specific activation groups that were triggered
self._reset_triggered_activation_groups(speaker) self._reset_triggered_activation_groups(speaker)
# If there are no speakers, trigger the stop agent.
if not speakers:
speakers = [_DIGRAPH_STOP_AGENT_NAME]
# Reset the execution state when the stop agent is selected, as this means the graph has naturally completed
self._reset_execution_state()
return speakers return speakers
async def validate_group_state(self, messages: List[BaseChatMessage] | None) -> None: async def validate_group_state(self, messages: List[BaseChatMessage] | None) -> None:
pass pass
async def _apply_termination_condition(
self, delta: Sequence[BaseAgentEvent | BaseChatMessage], increment_turn_count: bool = False
) -> bool:
"""Apply termination condition including graph-specific completion logic.
First checks if graph execution is complete, then checks standard termination conditions.
Args:
delta: The message delta to check termination conditions against
increment_turn_count: Whether to increment the turn count
Returns:
True if the conversation should be terminated, False otherwise
"""
# Check if the graph execution is complete (no ready speakers) - prioritize this check
if not self._ready:
stop_message = StopMessage(
content=_DIGRAPH_STOP_MESSAGE,
source=self._name,
)
# Reset the execution state when the graph has naturally completed
self._reset_execution_state()
# Reset the termination conditions and turn count.
if self._termination_condition is not None:
await self._termination_condition.reset()
self._current_turn = 0
# Signal termination to the caller of the team.
await self._signal_termination(stop_message)
return True
# Apply the standard termination conditions from the base class
return await super()._apply_termination_condition(delta, increment_turn_count)
def _reset_execution_state(self) -> None: def _reset_execution_state(self) -> None:
"""Reset the graph execution state to the initial state.""" """Reset the graph execution state to the initial state."""
self._remaining = {target: Counter(groups) for target, groups in self._graph.get_remaining_map().items()} self._remaining = {target: Counter(groups) for target, groups in self._graph.get_remaining_map().items()}
@ -514,21 +536,6 @@ class GraphFlowManager(BaseGroupChatManager):
self._reset_execution_state() self._reset_execution_state()
class _StopAgent(BaseChatAgent):
def __init__(self) -> None:
super().__init__(_DIGRAPH_STOP_AGENT_NAME, "Agent that terminates the GraphFlow.")
@property
def produced_message_types(self) -> Sequence[type[ChatMessage]]:
return (TextMessage, StopMessage)
async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
return Response(chat_message=StopMessage(content=_DIGRAPH_STOP_AGENT_MESSAGE, source=self.name))
async def on_reset(self, cancellation_token: CancellationToken) -> None:
pass
class GraphFlowConfig(BaseModel): class GraphFlowConfig(BaseModel):
"""The declarative configuration for GraphFlow.""" """The declarative configuration for GraphFlow."""
@ -779,15 +786,8 @@ class GraphFlow(BaseGroupChat, Component[GraphFlowConfig]):
self._input_participants = participants self._input_participants = participants
self._input_termination_condition = termination_condition self._input_termination_condition = termination_condition
stop_agent = _StopAgent() # No longer add _StopAgent or StopMessageTermination
stop_agent_termination = StopMessageTermination() # Termination is now handled directly in GraphFlowManager._apply_termination_condition
termination_condition = (
stop_agent_termination
if not termination_condition
else OrTerminationCondition(stop_agent_termination, termination_condition)
)
participants = [stop_agent] + participants
super().__init__( super().__init__(
participants, participants,
group_chat_manager_name="GraphManager", group_chat_manager_name="GraphManager",

View File

@ -24,7 +24,6 @@ from autogen_agentchat.teams._group_chat._events import ( # type: ignore[attr-d
GroupChatTermination, GroupChatTermination,
) )
from autogen_agentchat.teams._group_chat._graph._digraph_group_chat import ( from autogen_agentchat.teams._group_chat._graph._digraph_group_chat import (
_DIGRAPH_STOP_AGENT_NAME, # pyright: ignore[reportPrivateUsage]
DiGraph, DiGraph,
DiGraphEdge, DiGraphEdge,
DiGraphNode, DiGraphNode,
@ -461,14 +460,13 @@ async def test_digraph_group_chat_sequential_execution(runtime: AgentRuntime | N
# Run the chat # Run the chat
result: TaskResult = await team.run(task="Hello from User") result: TaskResult = await team.run(task="Hello from User")
assert len(result.messages) == 5 assert len(result.messages) == 4
assert isinstance(result.messages[0], TextMessage) assert isinstance(result.messages[0], TextMessage)
assert result.messages[0].source == "user" assert result.messages[0].source == "user"
assert result.messages[1].source == "A" assert result.messages[1].source == "A"
assert result.messages[2].source == "B" assert result.messages[2].source == "B"
assert result.messages[3].source == "C" assert result.messages[3].source == "C"
assert result.messages[4].source == _DIGRAPH_STOP_AGENT_NAME assert all(isinstance(m, TextMessage) for m in result.messages)
assert all(isinstance(m, TextMessage) for m in result.messages[:-1])
assert result.stop_reason is not None assert result.stop_reason is not None
@ -494,11 +492,10 @@ async def test_digraph_group_chat_parallel_fanout(runtime: AgentRuntime | None)
) )
result: TaskResult = await team.run(task="Start") result: TaskResult = await team.run(task="Start")
assert len(result.messages) == 5 assert len(result.messages) == 4
assert result.messages[0].source == "user" assert result.messages[0].source == "user"
assert result.messages[1].source == "A" assert result.messages[1].source == "A"
assert set(m.source for m in result.messages[2:-1]) == {"B", "C"} assert set(m.source for m in result.messages[2:]) == {"B", "C"}
assert result.messages[-1].source == _DIGRAPH_STOP_AGENT_NAME
assert result.stop_reason is not None assert result.stop_reason is not None
@ -524,11 +521,10 @@ async def test_digraph_group_chat_parallel_join_all(runtime: AgentRuntime | None
) )
result: TaskResult = await team.run(task="Go") result: TaskResult = await team.run(task="Go")
assert len(result.messages) == 5 assert len(result.messages) == 4
assert result.messages[0].source == "user" assert result.messages[0].source == "user"
assert set([result.messages[1].source, result.messages[2].source]) == {"A", "B"} assert set([result.messages[1].source, result.messages[2].source]) == {"A", "B"}
assert result.messages[3].source == "C" assert result.messages[3].source == "C"
assert result.messages[-1].source == _DIGRAPH_STOP_AGENT_NAME
assert result.stop_reason is not None assert result.stop_reason is not None
@ -555,12 +551,12 @@ async def test_digraph_group_chat_parallel_join_any(runtime: AgentRuntime | None
result: TaskResult = await team.run(task="Start") result: TaskResult = await team.run(task="Start")
assert len(result.messages) == 5 assert len(result.messages) == 4
assert result.messages[0].source == "user" assert result.messages[0].source == "user"
sources = [m.source for m in result.messages[1:]] sources = [m.source for m in result.messages[1:]]
# C must be last # C must be last
assert sources[-2] == "C" assert sources[-1] == "C"
# A and B must both execute # A and B must both execute
assert {"A", "B"}.issubset(set(sources)) assert {"A", "B"}.issubset(set(sources))
@ -570,7 +566,6 @@ async def test_digraph_group_chat_parallel_join_any(runtime: AgentRuntime | None
index_b = sources.index("B") index_b = sources.index("B")
index_c = sources.index("C") index_c = sources.index("C")
assert index_c > min(index_a, index_b) assert index_c > min(index_a, index_b)
assert result.messages[-1].source == _DIGRAPH_STOP_AGENT_NAME
assert result.stop_reason is not None assert result.stop_reason is not None
@ -594,10 +589,9 @@ async def test_digraph_group_chat_multiple_start_nodes(runtime: AgentRuntime | N
) )
result: TaskResult = await team.run(task="Start") result: TaskResult = await team.run(task="Start")
assert len(result.messages) == 4 assert len(result.messages) == 3
assert result.messages[0].source == "user" assert result.messages[0].source == "user"
assert set(m.source for m in result.messages[1:-1]) == {"A", "B"} assert set(m.source for m in result.messages[1:]) == {"A", "B"}
assert result.messages[-1].source == _DIGRAPH_STOP_AGENT_NAME
assert result.stop_reason is not None assert result.stop_reason is not None
@ -625,11 +619,10 @@ async def test_digraph_group_chat_disconnected_graph(runtime: AgentRuntime | Non
) )
result: TaskResult = await team.run(task="Go") result: TaskResult = await team.run(task="Go")
assert len(result.messages) == 6 assert len(result.messages) == 5
assert result.messages[0].source == "user" assert result.messages[0].source == "user"
assert {"A", "C"} == set([result.messages[1].source, result.messages[2].source]) assert {"A", "C"} == set([result.messages[1].source, result.messages[2].source])
assert {"B", "D"} == set([result.messages[3].source, result.messages[4].source]) assert {"B", "D"} == set([result.messages[3].source, result.messages[4].source])
assert result.messages[-1].source == _DIGRAPH_STOP_AGENT_NAME
assert result.stop_reason is not None assert result.stop_reason is not None
@ -733,16 +726,14 @@ async def test_digraph_group_chat_loop_with_exit_condition(runtime: AgentRuntime
"A", "A",
"B", "B",
"C", "C",
_DIGRAPH_STOP_AGENT_NAME,
] ]
actual_sources = [m.source for m in result.messages] actual_sources = [m.source for m in result.messages]
assert actual_sources == expected_sources assert actual_sources == expected_sources
assert result.stop_reason is not None assert result.stop_reason is not None
assert result.messages[-2].source == "C" assert result.messages[-1].source == "C"
assert any(m.content == "exit" for m in result.messages[:-1]) # type: ignore[attr-defined,union-attr] assert any(m.content == "exit" for m in result.messages) # type: ignore[attr-defined,union-attr]
assert result.messages[-1].source == _DIGRAPH_STOP_AGENT_NAME
@pytest.mark.asyncio @pytest.mark.asyncio
@ -796,16 +787,14 @@ async def test_digraph_group_chat_loop_with_self_cycle(runtime: AgentRuntime | N
"B", # 2nd loop "B", # 2nd loop
"B", "B",
"C", "C",
_DIGRAPH_STOP_AGENT_NAME,
] ]
actual_sources = [m.source for m in result.messages] actual_sources = [m.source for m in result.messages]
assert actual_sources == expected_sources assert actual_sources == expected_sources
assert result.stop_reason is not None assert result.stop_reason is not None
assert result.messages[-2].source == "C" assert result.messages[-1].source == "C"
assert any(m.content == "exit" for m in result.messages[:-1]) # type: ignore[attr-defined,union-attr] assert any(m.content == "exit" for m in result.messages) # type: ignore[attr-defined,union-attr]
assert result.messages[-1].source == _DIGRAPH_STOP_AGENT_NAME
@pytest.mark.asyncio @pytest.mark.asyncio
@ -886,16 +875,14 @@ async def test_digraph_group_chat_loop_with_two_cycles(runtime: AgentRuntime | N
"Y", # O -> Y "Y", # O -> Y
"O", # Y -> O "O", # Y -> O
"E", # O -> E "E", # O -> E
_DIGRAPH_STOP_AGENT_NAME,
] ]
actual_sources = [m.source for m in result.messages] actual_sources = [m.source for m in result.messages]
assert actual_sources == expected_sources assert actual_sources == expected_sources
assert result.stop_reason is not None assert result.stop_reason is not None
assert result.messages[-2].source == "E" assert result.messages[-1].source == "E"
assert any(m.content == "exit" for m in result.messages[:-1]) # type: ignore[attr-defined,union-attr] assert any(m.content == "exit" for m in result.messages) # type: ignore[attr-defined,union-attr]
assert result.messages[-1].source == _DIGRAPH_STOP_AGENT_NAME
@pytest.mark.asyncio @pytest.mark.asyncio
@ -1414,7 +1401,7 @@ async def test_graph_builder_sequential_execution(runtime: AgentRuntime | None)
) )
result = await team.run(task="Start") result = await team.run(task="Start")
assert [m.source for m in result.messages[1:-1]] == ["A", "B", "C"] assert [m.source for m in result.messages[1:]] == ["A", "B", "C"]
assert result.stop_reason is not None assert result.stop_reason is not None
@ -1546,9 +1533,9 @@ async def test_graph_flow_serialize_deserialize() -> None:
assert isinstance(results.messages[2], TextMessage) assert isinstance(results.messages[2], TextMessage)
assert results.messages[2].source == "B" assert results.messages[2].source == "B"
assert results.messages[2].content == "0" assert results.messages[2].content == "0"
assert isinstance(results.messages[-1], StopMessage) # No stop agent message should appear in the conversation
assert results.messages[-1].source == _DIGRAPH_STOP_AGENT_NAME assert all(not isinstance(m, StopMessage) for m in results.messages)
assert results.messages[-1].content == "Digraph execution is complete" assert results.stop_reason is not None
@pytest.mark.asyncio @pytest.mark.asyncio
@ -1590,9 +1577,8 @@ async def test_graph_flow_stateful_pause_and_resume_with_termination() -> None:
# Resume. # Resume.
result = await new_team.run() result = await new_team.run()
assert len(result.messages) == 2 assert len(result.messages) == 1
assert result.messages[0].source == "B" assert result.messages[0].source == "B"
assert result.messages[1].source == _DIGRAPH_STOP_AGENT_NAME
@pytest.mark.asyncio @pytest.mark.asyncio
@ -1656,27 +1642,25 @@ async def test_digraph_group_chat_multiple_task_execution(runtime: AgentRuntime
# Run the first task # Run the first task
result1: TaskResult = await team.run(task="First task") result1: TaskResult = await team.run(task="First task")
assert len(result1.messages) == 5 assert len(result1.messages) == 4
assert isinstance(result1.messages[0], TextMessage) assert isinstance(result1.messages[0], TextMessage)
assert result1.messages[0].source == "user" assert result1.messages[0].source == "user"
assert result1.messages[0].content == "First task" assert result1.messages[0].content == "First task"
assert result1.messages[1].source == "A" assert result1.messages[1].source == "A"
assert result1.messages[2].source == "B" assert result1.messages[2].source == "B"
assert result1.messages[3].source == "C" assert result1.messages[3].source == "C"
assert result1.messages[4].source == _DIGRAPH_STOP_AGENT_NAME
assert result1.stop_reason is not None assert result1.stop_reason is not None
# Run the second task - should work without explicit reset # Run the second task - should work without explicit reset
result2: TaskResult = await team.run(task="Second task") result2: TaskResult = await team.run(task="Second task")
assert len(result2.messages) == 5 assert len(result2.messages) == 4
assert isinstance(result2.messages[0], TextMessage) assert isinstance(result2.messages[0], TextMessage)
assert result2.messages[0].source == "user" assert result2.messages[0].source == "user"
assert result2.messages[0].content == "Second task" assert result2.messages[0].content == "Second task"
assert result2.messages[1].source == "A" assert result2.messages[1].source == "A"
assert result2.messages[2].source == "B" assert result2.messages[2].source == "B"
assert result2.messages[3].source == "C" assert result2.messages[3].source == "C"
assert result2.messages[4].source == _DIGRAPH_STOP_AGENT_NAME
assert result2.stop_reason is not None assert result2.stop_reason is not None
# Verify agents were properly reset and executed again # Verify agents were properly reset and executed again
@ -1728,10 +1712,9 @@ async def test_digraph_group_chat_resume_with_termination_condition(runtime: Age
# Resume the graph flow with no task to continue where it left off # Resume the graph flow with no task to continue where it left off
result2: TaskResult = await team.run() result2: TaskResult = await team.run()
# Should continue and execute C, then complete with stop agent # Should continue and execute C, then complete without stop agent message
assert len(result2.messages) == 2 assert len(result2.messages) == 1
assert result2.messages[0].source == "C" assert result2.messages[0].source == "C"
assert result2.messages[1].source == _DIGRAPH_STOP_AGENT_NAME
assert result2.stop_reason is not None assert result2.stop_reason is not None
# Verify C now ran and the execution state was preserved # Verify C now ran and the execution state was preserved