diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_graph/_digraph_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_graph/_digraph_group_chat.py index 544afb222..c269e1963 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_graph/_digraph_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_graph/_digraph_group_chat.py @@ -2,20 +2,16 @@ import asyncio from collections import Counter, deque from typing import Any, Callable, Deque, Dict, List, Literal, Mapping, Sequence, Set, Union -from autogen_core import AgentRuntime, CancellationToken, Component, ComponentModel +from autogen_core import AgentRuntime, Component, ComponentModel from pydantic import BaseModel, Field, model_validator from typing_extensions import Self -from autogen_agentchat.agents import BaseChatAgent -from autogen_agentchat.base import ChatAgent, OrTerminationCondition, Response, TerminationCondition -from autogen_agentchat.conditions import StopMessageTermination +from autogen_agentchat.base import ChatAgent, TerminationCondition from autogen_agentchat.messages import ( BaseAgentEvent, BaseChatMessage, - ChatMessage, MessageFactory, StopMessage, - TextMessage, ) from autogen_agentchat.state import BaseGroupChatManagerState from autogen_agentchat.teams import BaseGroupChat @@ -23,8 +19,7 @@ from autogen_agentchat.teams import BaseGroupChat from ..._group_chat._base_group_chat_manager import BaseGroupChatManager from ..._group_chat._events import GroupChatTermination -_DIGRAPH_STOP_AGENT_NAME = "DiGraphStopAgent" -_DIGRAPH_STOP_AGENT_MESSAGE = "Digraph execution is complete" +_DIGRAPH_STOP_MESSAGE = "Digraph execution is complete" class DiGraphEdge(BaseModel): @@ -469,17 +464,44 @@ class GraphFlowManager(BaseGroupChatManager): # Reset the bookkeeping for the specific activation groups that were triggered self._reset_triggered_activation_groups(speaker) - # If there are no speakers, trigger the stop agent. - if not speakers: - speakers = [_DIGRAPH_STOP_AGENT_NAME] - # Reset the execution state when the stop agent is selected, as this means the graph has naturally completed - self._reset_execution_state() - return speakers async def validate_group_state(self, messages: List[BaseChatMessage] | None) -> None: pass + async def _apply_termination_condition( + self, delta: Sequence[BaseAgentEvent | BaseChatMessage], increment_turn_count: bool = False + ) -> bool: + """Apply termination condition including graph-specific completion logic. + + First checks if graph execution is complete, then checks standard termination conditions. + + Args: + delta: The message delta to check termination conditions against + increment_turn_count: Whether to increment the turn count + + Returns: + True if the conversation should be terminated, False otherwise + """ + # Check if the graph execution is complete (no ready speakers) - prioritize this check + if not self._ready: + stop_message = StopMessage( + content=_DIGRAPH_STOP_MESSAGE, + source=self._name, + ) + # Reset the execution state when the graph has naturally completed + self._reset_execution_state() + # Reset the termination conditions and turn count. + if self._termination_condition is not None: + await self._termination_condition.reset() + self._current_turn = 0 + # Signal termination to the caller of the team. + await self._signal_termination(stop_message) + return True + + # Apply the standard termination conditions from the base class + return await super()._apply_termination_condition(delta, increment_turn_count) + def _reset_execution_state(self) -> None: """Reset the graph execution state to the initial state.""" self._remaining = {target: Counter(groups) for target, groups in self._graph.get_remaining_map().items()} @@ -514,21 +536,6 @@ class GraphFlowManager(BaseGroupChatManager): self._reset_execution_state() -class _StopAgent(BaseChatAgent): - def __init__(self) -> None: - super().__init__(_DIGRAPH_STOP_AGENT_NAME, "Agent that terminates the GraphFlow.") - - @property - def produced_message_types(self) -> Sequence[type[ChatMessage]]: - return (TextMessage, StopMessage) - - async def on_messages(self, messages: Sequence[BaseChatMessage], cancellation_token: CancellationToken) -> Response: - return Response(chat_message=StopMessage(content=_DIGRAPH_STOP_AGENT_MESSAGE, source=self.name)) - - async def on_reset(self, cancellation_token: CancellationToken) -> None: - pass - - class GraphFlowConfig(BaseModel): """The declarative configuration for GraphFlow.""" @@ -779,15 +786,8 @@ class GraphFlow(BaseGroupChat, Component[GraphFlowConfig]): self._input_participants = participants self._input_termination_condition = termination_condition - stop_agent = _StopAgent() - stop_agent_termination = StopMessageTermination() - termination_condition = ( - stop_agent_termination - if not termination_condition - else OrTerminationCondition(stop_agent_termination, termination_condition) - ) - - participants = [stop_agent] + participants + # No longer add _StopAgent or StopMessageTermination + # Termination is now handled directly in GraphFlowManager._apply_termination_condition super().__init__( participants, group_chat_manager_name="GraphManager", diff --git a/python/packages/autogen-agentchat/tests/test_group_chat_graph.py b/python/packages/autogen-agentchat/tests/test_group_chat_graph.py index adcc42958..98cc8ca66 100644 --- a/python/packages/autogen-agentchat/tests/test_group_chat_graph.py +++ b/python/packages/autogen-agentchat/tests/test_group_chat_graph.py @@ -24,7 +24,6 @@ from autogen_agentchat.teams._group_chat._events import ( # type: ignore[attr-d GroupChatTermination, ) from autogen_agentchat.teams._group_chat._graph._digraph_group_chat import ( - _DIGRAPH_STOP_AGENT_NAME, # pyright: ignore[reportPrivateUsage] DiGraph, DiGraphEdge, DiGraphNode, @@ -461,14 +460,13 @@ async def test_digraph_group_chat_sequential_execution(runtime: AgentRuntime | N # Run the chat result: TaskResult = await team.run(task="Hello from User") - assert len(result.messages) == 5 + assert len(result.messages) == 4 assert isinstance(result.messages[0], TextMessage) assert result.messages[0].source == "user" assert result.messages[1].source == "A" assert result.messages[2].source == "B" assert result.messages[3].source == "C" - assert result.messages[4].source == _DIGRAPH_STOP_AGENT_NAME - assert all(isinstance(m, TextMessage) for m in result.messages[:-1]) + assert all(isinstance(m, TextMessage) for m in result.messages) assert result.stop_reason is not None @@ -494,11 +492,10 @@ async def test_digraph_group_chat_parallel_fanout(runtime: AgentRuntime | None) ) result: TaskResult = await team.run(task="Start") - assert len(result.messages) == 5 + assert len(result.messages) == 4 assert result.messages[0].source == "user" assert result.messages[1].source == "A" - assert set(m.source for m in result.messages[2:-1]) == {"B", "C"} - assert result.messages[-1].source == _DIGRAPH_STOP_AGENT_NAME + assert set(m.source for m in result.messages[2:]) == {"B", "C"} assert result.stop_reason is not None @@ -524,11 +521,10 @@ async def test_digraph_group_chat_parallel_join_all(runtime: AgentRuntime | None ) result: TaskResult = await team.run(task="Go") - assert len(result.messages) == 5 + assert len(result.messages) == 4 assert result.messages[0].source == "user" assert set([result.messages[1].source, result.messages[2].source]) == {"A", "B"} assert result.messages[3].source == "C" - assert result.messages[-1].source == _DIGRAPH_STOP_AGENT_NAME assert result.stop_reason is not None @@ -555,12 +551,12 @@ async def test_digraph_group_chat_parallel_join_any(runtime: AgentRuntime | None result: TaskResult = await team.run(task="Start") - assert len(result.messages) == 5 + assert len(result.messages) == 4 assert result.messages[0].source == "user" sources = [m.source for m in result.messages[1:]] # C must be last - assert sources[-2] == "C" + assert sources[-1] == "C" # A and B must both execute assert {"A", "B"}.issubset(set(sources)) @@ -570,7 +566,6 @@ async def test_digraph_group_chat_parallel_join_any(runtime: AgentRuntime | None index_b = sources.index("B") index_c = sources.index("C") assert index_c > min(index_a, index_b) - assert result.messages[-1].source == _DIGRAPH_STOP_AGENT_NAME assert result.stop_reason is not None @@ -594,10 +589,9 @@ async def test_digraph_group_chat_multiple_start_nodes(runtime: AgentRuntime | N ) result: TaskResult = await team.run(task="Start") - assert len(result.messages) == 4 + assert len(result.messages) == 3 assert result.messages[0].source == "user" - assert set(m.source for m in result.messages[1:-1]) == {"A", "B"} - assert result.messages[-1].source == _DIGRAPH_STOP_AGENT_NAME + assert set(m.source for m in result.messages[1:]) == {"A", "B"} assert result.stop_reason is not None @@ -625,11 +619,10 @@ async def test_digraph_group_chat_disconnected_graph(runtime: AgentRuntime | Non ) result: TaskResult = await team.run(task="Go") - assert len(result.messages) == 6 + assert len(result.messages) == 5 assert result.messages[0].source == "user" assert {"A", "C"} == set([result.messages[1].source, result.messages[2].source]) assert {"B", "D"} == set([result.messages[3].source, result.messages[4].source]) - assert result.messages[-1].source == _DIGRAPH_STOP_AGENT_NAME assert result.stop_reason is not None @@ -733,16 +726,14 @@ async def test_digraph_group_chat_loop_with_exit_condition(runtime: AgentRuntime "A", "B", "C", - _DIGRAPH_STOP_AGENT_NAME, ] actual_sources = [m.source for m in result.messages] assert actual_sources == expected_sources assert result.stop_reason is not None - assert result.messages[-2].source == "C" - assert any(m.content == "exit" for m in result.messages[:-1]) # type: ignore[attr-defined,union-attr] - assert result.messages[-1].source == _DIGRAPH_STOP_AGENT_NAME + assert result.messages[-1].source == "C" + assert any(m.content == "exit" for m in result.messages) # type: ignore[attr-defined,union-attr] @pytest.mark.asyncio @@ -796,16 +787,14 @@ async def test_digraph_group_chat_loop_with_self_cycle(runtime: AgentRuntime | N "B", # 2nd loop "B", "C", - _DIGRAPH_STOP_AGENT_NAME, ] actual_sources = [m.source for m in result.messages] assert actual_sources == expected_sources assert result.stop_reason is not None - assert result.messages[-2].source == "C" - assert any(m.content == "exit" for m in result.messages[:-1]) # type: ignore[attr-defined,union-attr] - assert result.messages[-1].source == _DIGRAPH_STOP_AGENT_NAME + assert result.messages[-1].source == "C" + assert any(m.content == "exit" for m in result.messages) # type: ignore[attr-defined,union-attr] @pytest.mark.asyncio @@ -886,16 +875,14 @@ async def test_digraph_group_chat_loop_with_two_cycles(runtime: AgentRuntime | N "Y", # O -> Y "O", # Y -> O "E", # O -> E - _DIGRAPH_STOP_AGENT_NAME, ] actual_sources = [m.source for m in result.messages] assert actual_sources == expected_sources assert result.stop_reason is not None - assert result.messages[-2].source == "E" - assert any(m.content == "exit" for m in result.messages[:-1]) # type: ignore[attr-defined,union-attr] - assert result.messages[-1].source == _DIGRAPH_STOP_AGENT_NAME + assert result.messages[-1].source == "E" + assert any(m.content == "exit" for m in result.messages) # type: ignore[attr-defined,union-attr] @pytest.mark.asyncio @@ -1414,7 +1401,7 @@ async def test_graph_builder_sequential_execution(runtime: AgentRuntime | None) ) result = await team.run(task="Start") - assert [m.source for m in result.messages[1:-1]] == ["A", "B", "C"] + assert [m.source for m in result.messages[1:]] == ["A", "B", "C"] assert result.stop_reason is not None @@ -1546,9 +1533,9 @@ async def test_graph_flow_serialize_deserialize() -> None: assert isinstance(results.messages[2], TextMessage) assert results.messages[2].source == "B" assert results.messages[2].content == "0" - assert isinstance(results.messages[-1], StopMessage) - assert results.messages[-1].source == _DIGRAPH_STOP_AGENT_NAME - assert results.messages[-1].content == "Digraph execution is complete" + # No stop agent message should appear in the conversation + assert all(not isinstance(m, StopMessage) for m in results.messages) + assert results.stop_reason is not None @pytest.mark.asyncio @@ -1590,9 +1577,8 @@ async def test_graph_flow_stateful_pause_and_resume_with_termination() -> None: # Resume. result = await new_team.run() - assert len(result.messages) == 2 + assert len(result.messages) == 1 assert result.messages[0].source == "B" - assert result.messages[1].source == _DIGRAPH_STOP_AGENT_NAME @pytest.mark.asyncio @@ -1656,27 +1642,25 @@ async def test_digraph_group_chat_multiple_task_execution(runtime: AgentRuntime # Run the first task result1: TaskResult = await team.run(task="First task") - assert len(result1.messages) == 5 + assert len(result1.messages) == 4 assert isinstance(result1.messages[0], TextMessage) assert result1.messages[0].source == "user" assert result1.messages[0].content == "First task" assert result1.messages[1].source == "A" assert result1.messages[2].source == "B" assert result1.messages[3].source == "C" - assert result1.messages[4].source == _DIGRAPH_STOP_AGENT_NAME assert result1.stop_reason is not None # Run the second task - should work without explicit reset result2: TaskResult = await team.run(task="Second task") - assert len(result2.messages) == 5 + assert len(result2.messages) == 4 assert isinstance(result2.messages[0], TextMessage) assert result2.messages[0].source == "user" assert result2.messages[0].content == "Second task" assert result2.messages[1].source == "A" assert result2.messages[2].source == "B" assert result2.messages[3].source == "C" - assert result2.messages[4].source == _DIGRAPH_STOP_AGENT_NAME assert result2.stop_reason is not None # Verify agents were properly reset and executed again @@ -1728,10 +1712,9 @@ async def test_digraph_group_chat_resume_with_termination_condition(runtime: Age # Resume the graph flow with no task to continue where it left off result2: TaskResult = await team.run() - # Should continue and execute C, then complete with stop agent - assert len(result2.messages) == 2 + # Should continue and execute C, then complete without stop agent message + assert len(result2.messages) == 1 assert result2.messages[0].source == "C" - assert result2.messages[1].source == _DIGRAPH_STOP_AGENT_NAME assert result2.stop_reason is not None # Verify C now ran and the execution state was preserved