Fix GraphFlowManager termination to prevent _StopAgent from polluting conversation context (#6752)

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: ekzhu <320302+ekzhu@users.noreply.github.com>
Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
This commit is contained in:
Copilot 2025-07-06 00:57:03 -07:00 committed by GitHub
parent c23b9454a8
commit e10767421f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 63 additions and 80 deletions

View File

@ -2,20 +2,16 @@ import asyncio
from collections import Counter, deque
from typing import Any, Callable, Deque, Dict, List, Literal, Mapping, Sequence, Set, Union
from autogen_core import AgentRuntime, CancellationToken, Component, ComponentModel
from autogen_core import AgentRuntime, Component, ComponentModel
from pydantic import BaseModel, Field, model_validator
from typing_extensions import Self
from autogen_agentchat.agents import BaseChatAgent
from autogen_agentchat.base import ChatAgent, OrTerminationCondition, Response, TerminationCondition
from autogen_agentchat.conditions import StopMessageTermination
from autogen_agentchat.base import ChatAgent, TerminationCondition
from autogen_agentchat.messages import (
BaseAgentEvent,
BaseChatMessage,
ChatMessage,
MessageFactory,
StopMessage,
TextMessage,
)
from autogen_agentchat.state import BaseGroupChatManagerState
from autogen_agentchat.teams import BaseGroupChat
@ -23,8 +19,7 @@ from autogen_agentchat.teams import BaseGroupChat
from ..._group_chat._base_group_chat_manager import BaseGroupChatManager
from ..._group_chat._events import GroupChatTermination
_DIGRAPH_STOP_AGENT_NAME = "DiGraphStopAgent"
_DIGRAPH_STOP_AGENT_MESSAGE = "Digraph execution is complete"
_DIGRAPH_STOP_MESSAGE = "Digraph execution is complete"
class DiGraphEdge(BaseModel):
@ -469,17 +464,44 @@ class GraphFlowManager(BaseGroupChatManager):
# Reset the bookkeeping for the specific activation groups that were triggered
self._reset_triggered_activation_groups(speaker)
# If there are no speakers, trigger the stop agent.
if not speakers:
speakers = [_DIGRAPH_STOP_AGENT_NAME]
# Reset the execution state when the stop agent is selected, as this means the graph has naturally completed
self._reset_execution_state()
return speakers
async def validate_group_state(self, messages: List[BaseChatMessage] | None) -> None:
pass
async def _apply_termination_condition(
self, delta: Sequence[BaseAgentEvent | BaseChatMessage], increment_turn_count: bool = False
) -> bool:
"""Apply termination condition including graph-specific completion logic.
First checks if graph execution is complete, then checks standard termination conditions.
Args:
delta: The message delta to check termination conditions against
increment_turn_count: Whether to increment the turn count
Returns:
True if the conversation should be terminated, False otherwise
"""
# Check if the graph execution is complete (no ready speakers) - prioritize this check
if not self._ready:
stop_message = StopMessage(
content=_DIGRAPH_STOP_MESSAGE,
source=self._name,
)
# Reset the execution state when the graph has naturally completed
self._reset_execution_state()
# Reset the termination conditions and turn count.
if self._termination_condition is not None:
await self._termination_condition.reset()
self._current_turn = 0
# Signal termination to the caller of the team.
await self._signal_termination(stop_message)
return True
# Apply the standard termination conditions from the base class
return await super()._apply_termination_condition(delta, increment_turn_count)
def _reset_execution_state(self) -> None:
"""Reset the graph execution state to the initial state."""
self._remaining = {target: Counter(groups) for target, groups in self._graph.get_remaining_map().items()}
@ -514,21 +536,6 @@ class GraphFlowManager(BaseGroupChatManager):
self._reset_execution_state()
class _StopAgent(BaseChatAgent):
def __init__(self) -> None:
super().__init__(_DIGRAPH_STOP_AGENT_NAME, "Agent that terminates the GraphFlow.")
@property
def produced_message_types(self) -> Sequence[type[ChatMessage]]:
return (TextMessage, StopMessage)
async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response:
return Response(chat_message=StopMessage(content=_DIGRAPH_STOP_AGENT_MESSAGE, source=self.name))
async def on_reset(self, cancellation_token: CancellationToken) -> None:
pass
class GraphFlowConfig(BaseModel):
"""The declarative configuration for GraphFlow."""
@ -779,15 +786,8 @@ class GraphFlow(BaseGroupChat, Component[GraphFlowConfig]):
self._input_participants = participants
self._input_termination_condition = termination_condition
stop_agent = _StopAgent()
stop_agent_termination = StopMessageTermination()
termination_condition = (
stop_agent_termination
if not termination_condition
else OrTerminationCondition(stop_agent_termination, termination_condition)
)
participants = [stop_agent] + participants
# No longer add _StopAgent or StopMessageTermination
# Termination is now handled directly in GraphFlowManager._apply_termination_condition
super().__init__(
participants,
group_chat_manager_name="GraphManager",

View File

@ -24,7 +24,6 @@ from autogen_agentchat.teams._group_chat._events import ( # type: ignore[attr-d
GroupChatTermination,
)
from autogen_agentchat.teams._group_chat._graph._digraph_group_chat import (
_DIGRAPH_STOP_AGENT_NAME, # pyright: ignore[reportPrivateUsage]
DiGraph,
DiGraphEdge,
DiGraphNode,
@ -461,14 +460,13 @@ async def test_digraph_group_chat_sequential_execution(runtime: AgentRuntime | N
# Run the chat
result: TaskResult = await team.run(task="Hello from User")
assert len(result.messages) == 5
assert len(result.messages) == 4
assert isinstance(result.messages[0], TextMessage)
assert result.messages[0].source == "user"
assert result.messages[1].source == "A"
assert result.messages[2].source == "B"
assert result.messages[3].source == "C"
assert result.messages[4].source == _DIGRAPH_STOP_AGENT_NAME
assert all(isinstance(m, TextMessage) for m in result.messages[:-1])
assert all(isinstance(m, TextMessage) for m in result.messages)
assert result.stop_reason is not None
@ -494,11 +492,10 @@ async def test_digraph_group_chat_parallel_fanout(runtime: AgentRuntime | None)
)
result: TaskResult = await team.run(task="Start")
assert len(result.messages) == 5
assert len(result.messages) == 4
assert result.messages[0].source == "user"
assert result.messages[1].source == "A"
assert set(m.source for m in result.messages[2:-1]) == {"B", "C"}
assert result.messages[-1].source == _DIGRAPH_STOP_AGENT_NAME
assert set(m.source for m in result.messages[2:]) == {"B", "C"}
assert result.stop_reason is not None
@ -524,11 +521,10 @@ async def test_digraph_group_chat_parallel_join_all(runtime: AgentRuntime | None
)
result: TaskResult = await team.run(task="Go")
assert len(result.messages) == 5
assert len(result.messages) == 4
assert result.messages[0].source == "user"
assert set([result.messages[1].source, result.messages[2].source]) == {"A", "B"}
assert result.messages[3].source == "C"
assert result.messages[-1].source == _DIGRAPH_STOP_AGENT_NAME
assert result.stop_reason is not None
@ -555,12 +551,12 @@ async def test_digraph_group_chat_parallel_join_any(runtime: AgentRuntime | None
result: TaskResult = await team.run(task="Start")
assert len(result.messages) == 5
assert len(result.messages) == 4
assert result.messages[0].source == "user"
sources = [m.source for m in result.messages[1:]]
# C must be last
assert sources[-2] == "C"
assert sources[-1] == "C"
# A and B must both execute
assert {"A", "B"}.issubset(set(sources))
@ -570,7 +566,6 @@ async def test_digraph_group_chat_parallel_join_any(runtime: AgentRuntime | None
index_b = sources.index("B")
index_c = sources.index("C")
assert index_c > min(index_a, index_b)
assert result.messages[-1].source == _DIGRAPH_STOP_AGENT_NAME
assert result.stop_reason is not None
@ -594,10 +589,9 @@ async def test_digraph_group_chat_multiple_start_nodes(runtime: AgentRuntime | N
)
result: TaskResult = await team.run(task="Start")
assert len(result.messages) == 4
assert len(result.messages) == 3
assert result.messages[0].source == "user"
assert set(m.source for m in result.messages[1:-1]) == {"A", "B"}
assert result.messages[-1].source == _DIGRAPH_STOP_AGENT_NAME
assert set(m.source for m in result.messages[1:]) == {"A", "B"}
assert result.stop_reason is not None
@ -625,11 +619,10 @@ async def test_digraph_group_chat_disconnected_graph(runtime: AgentRuntime | Non
)
result: TaskResult = await team.run(task="Go")
assert len(result.messages) == 6
assert len(result.messages) == 5
assert result.messages[0].source == "user"
assert {"A", "C"} == set([result.messages[1].source, result.messages[2].source])
assert {"B", "D"} == set([result.messages[3].source, result.messages[4].source])
assert result.messages[-1].source == _DIGRAPH_STOP_AGENT_NAME
assert result.stop_reason is not None
@ -733,16 +726,14 @@ async def test_digraph_group_chat_loop_with_exit_condition(runtime: AgentRuntime
"A",
"B",
"C",
_DIGRAPH_STOP_AGENT_NAME,
]
actual_sources = [m.source for m in result.messages]
assert actual_sources == expected_sources
assert result.stop_reason is not None
assert result.messages[-2].source == "C"
assert any(m.content == "exit" for m in result.messages[:-1]) # type: ignore[attr-defined,union-attr]
assert result.messages[-1].source == _DIGRAPH_STOP_AGENT_NAME
assert result.messages[-1].source == "C"
assert any(m.content == "exit" for m in result.messages) # type: ignore[attr-defined,union-attr]
@pytest.mark.asyncio
@ -796,16 +787,14 @@ async def test_digraph_group_chat_loop_with_self_cycle(runtime: AgentRuntime | N
"B", # 2nd loop
"B",
"C",
_DIGRAPH_STOP_AGENT_NAME,
]
actual_sources = [m.source for m in result.messages]
assert actual_sources == expected_sources
assert result.stop_reason is not None
assert result.messages[-2].source == "C"
assert any(m.content == "exit" for m in result.messages[:-1]) # type: ignore[attr-defined,union-attr]
assert result.messages[-1].source == _DIGRAPH_STOP_AGENT_NAME
assert result.messages[-1].source == "C"
assert any(m.content == "exit" for m in result.messages) # type: ignore[attr-defined,union-attr]
@pytest.mark.asyncio
@ -886,16 +875,14 @@ async def test_digraph_group_chat_loop_with_two_cycles(runtime: AgentRuntime | N
"Y", # O -> Y
"O", # Y -> O
"E", # O -> E
_DIGRAPH_STOP_AGENT_NAME,
]
actual_sources = [m.source for m in result.messages]
assert actual_sources == expected_sources
assert result.stop_reason is not None
assert result.messages[-2].source == "E"
assert any(m.content == "exit" for m in result.messages[:-1]) # type: ignore[attr-defined,union-attr]
assert result.messages[-1].source == _DIGRAPH_STOP_AGENT_NAME
assert result.messages[-1].source == "E"
assert any(m.content == "exit" for m in result.messages) # type: ignore[attr-defined,union-attr]
@pytest.mark.asyncio
@ -1414,7 +1401,7 @@ async def test_graph_builder_sequential_execution(runtime: AgentRuntime | None)
)
result = await team.run(task="Start")
assert [m.source for m in result.messages[1:-1]] == ["A", "B", "C"]
assert [m.source for m in result.messages[1:]] == ["A", "B", "C"]
assert result.stop_reason is not None
@ -1546,9 +1533,9 @@ async def test_graph_flow_serialize_deserialize() -> None:
assert isinstance(results.messages[2], TextMessage)
assert results.messages[2].source == "B"
assert results.messages[2].content == "0"
assert isinstance(results.messages[-1], StopMessage)
assert results.messages[-1].source == _DIGRAPH_STOP_AGENT_NAME
assert results.messages[-1].content == "Digraph execution is complete"
# No stop agent message should appear in the conversation
assert all(not isinstance(m, StopMessage) for m in results.messages)
assert results.stop_reason is not None
@pytest.mark.asyncio
@ -1590,9 +1577,8 @@ async def test_graph_flow_stateful_pause_and_resume_with_termination() -> None:
# Resume.
result = await new_team.run()
assert len(result.messages) == 2
assert len(result.messages) == 1
assert result.messages[0].source == "B"
assert result.messages[1].source == _DIGRAPH_STOP_AGENT_NAME
@pytest.mark.asyncio
@ -1656,27 +1642,25 @@ async def test_digraph_group_chat_multiple_task_execution(runtime: AgentRuntime
# Run the first task
result1: TaskResult = await team.run(task="First task")
assert len(result1.messages) == 5
assert len(result1.messages) == 4
assert isinstance(result1.messages[0], TextMessage)
assert result1.messages[0].source == "user"
assert result1.messages[0].content == "First task"
assert result1.messages[1].source == "A"
assert result1.messages[2].source == "B"
assert result1.messages[3].source == "C"
assert result1.messages[4].source == _DIGRAPH_STOP_AGENT_NAME
assert result1.stop_reason is not None
# Run the second task - should work without explicit reset
result2: TaskResult = await team.run(task="Second task")
assert len(result2.messages) == 5
assert len(result2.messages) == 4
assert isinstance(result2.messages[0], TextMessage)
assert result2.messages[0].source == "user"
assert result2.messages[0].content == "Second task"
assert result2.messages[1].source == "A"
assert result2.messages[2].source == "B"
assert result2.messages[3].source == "C"
assert result2.messages[4].source == _DIGRAPH_STOP_AGENT_NAME
assert result2.stop_reason is not None
# Verify agents were properly reset and executed again
@ -1728,10 +1712,9 @@ async def test_digraph_group_chat_resume_with_termination_condition(runtime: Age
# Resume the graph flow with no task to continue where it left off
result2: TaskResult = await team.run()
# Should continue and execute C, then complete with stop agent
assert len(result2.messages) == 2
# Should continue and execute C, then complete without stop agent message
assert len(result2.messages) == 1
assert result2.messages[0].source == "C"
assert result2.messages[1].source == _DIGRAPH_STOP_AGENT_NAME
assert result2.stop_reason is not None
# Verify C now ran and the execution state was preserved