Mirror of https://github.com/microsoft/autogen.git (synced 2025-10-31 09:50:11 +00:00)

Commit 979d8ab4f1:

* initial pass on making group chats declarative
* update group chat tests
* update impl to include participant serialization for all teams
* v1 making soc declarative
* update memory test
* update chatagent and team base classes
* update serialization doc notebook
* formatting updates
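As a rough illustration of the declarative round-trip these changes add (a minimal sketch only; the model name and empty api_key mirror the test fixtures below, see test_declarative_groupchats_with_config for the complete checks):

    from autogen_agentchat.agents import AssistantAgent
    from autogen_agentchat.teams import RoundRobinGroupChat
    from autogen_ext.models.openai import OpenAIChatCompletionClient

    # Build a team, dump it to a declarative component config, then load an
    # equivalent team back from that config.
    agent = AssistantAgent("agent", model_client=OpenAIChatCompletionClient(model="gpt-4o-2024-05-13", api_key=""))
    team = RoundRobinGroupChat(participants=[agent], max_turns=5)
    config = team.dump_component()
    loaded = RoundRobinGroupChat.load_component(config)
    assert loaded.dump_component() == config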
1280 lines · 52 KiB · Python
import asyncio
import json
import logging
import tempfile
from typing import Any, AsyncGenerator, List, Sequence

import pytest
from autogen_agentchat import EVENT_LOGGER_NAME
from autogen_agentchat.agents import (
    AssistantAgent,
    BaseChatAgent,
    CodeExecutorAgent,
)
from autogen_agentchat.base import Handoff, Response, TaskResult
from autogen_agentchat.conditions import HandoffTermination, MaxMessageTermination, TextMentionTermination
from autogen_agentchat.messages import (
    AgentEvent,
    ChatMessage,
    HandoffMessage,
    MultiModalMessage,
    StopMessage,
    TextMessage,
    ToolCallExecutionEvent,
    ToolCallRequestEvent,
    ToolCallSummaryMessage,
)
from autogen_agentchat.teams import MagenticOneGroupChat, RoundRobinGroupChat, SelectorGroupChat, Swarm
from autogen_agentchat.teams._group_chat._round_robin_group_chat import RoundRobinGroupChatManager
from autogen_agentchat.teams._group_chat._selector_group_chat import SelectorGroupChatManager
from autogen_agentchat.teams._group_chat._swarm_group_chat import SwarmGroupChatManager
from autogen_agentchat.ui import Console
from autogen_core import AgentId, CancellationToken, FunctionCall
from autogen_core.models import (
    AssistantMessage,
    FunctionExecutionResult,
    FunctionExecutionResultMessage,
    LLMMessage,
    UserMessage,
)
from autogen_core.tools import FunctionTool
from autogen_ext.code_executors.local import LocalCommandLineCodeExecutor
from autogen_ext.models.openai import OpenAIChatCompletionClient
from autogen_ext.models.replay import ReplayChatCompletionClient
from openai.resources.chat.completions import AsyncCompletions
from openai.types.chat.chat_completion import ChatCompletion, Choice
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from openai.types.chat.chat_completion_message import ChatCompletionMessage
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall, Function
from openai.types.completion_usage import CompletionUsage
from utils import FileLogHandler

logger = logging.getLogger(EVENT_LOGGER_NAME)
logger.setLevel(logging.DEBUG)
logger.addHandler(FileLogHandler("test_group_chat.log"))


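# Test helper: stands in for the OpenAI client by replacing AsyncCompletions.create
# via monkeypatch and returning the canned ChatCompletion objects in order.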
class _MockChatCompletion:
    def __init__(self, chat_completions: List[ChatCompletion]) -> None:
        self._saved_chat_completions = chat_completions
        self._curr_index = 0

    async def mock_create(
        self, *args: Any, **kwargs: Any
    ) -> ChatCompletion | AsyncGenerator[ChatCompletionChunk, None]:
        await asyncio.sleep(0.1)
        completion = self._saved_chat_completions[self._curr_index]
        self._curr_index += 1
        return completion

    def reset(self) -> None:
        self._curr_index = 0


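# Deterministic ChatAgent used as a team participant: echoes the first incoming
# TextMessage back (or repeats its last message when called with no new messages)
# and counts how many times it has responded.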
class _EchoAgent(BaseChatAgent):
    def __init__(self, name: str, description: str) -> None:
        super().__init__(name, description)
        self._last_message: str | None = None
        self._total_messages = 0

    @property
    def produced_message_types(self) -> Sequence[type[ChatMessage]]:
        return (TextMessage,)

    @property
    def total_messages(self) -> int:
        return self._total_messages

    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
        if len(messages) > 0:
            assert isinstance(messages[0], TextMessage)
            self._last_message = messages[0].content
            self._total_messages += 1
            return Response(chat_message=TextMessage(content=messages[0].content, source=self.name))
        else:
            assert self._last_message is not None
            self._total_messages += 1
            return Response(chat_message=TextMessage(content=self._last_message, source=self.name))

    async def on_reset(self, cancellation_token: CancellationToken) -> None:
        self._last_message = None


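# Echo agent variant that responds normally until it has been called `stop_at`
# times, then emits a StopMessage with content "TERMINATE".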
class _StopAgent(_EchoAgent):
    def __init__(self, name: str, description: str, *, stop_at: int = 1) -> None:
        super().__init__(name, description)
        self._count = 0
        self._stop_at = stop_at

    @property
    def produced_message_types(self) -> Sequence[type[ChatMessage]]:
        return (TextMessage, StopMessage)

    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
        self._count += 1
        if self._count < self._stop_at:
            return await super().on_messages(messages, cancellation_token)
        return Response(chat_message=StopMessage(content="TERMINATE", source=self.name))


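# Trivial tool wrapped in a FunctionTool for the tool-call tests; always returns "pass".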
def _pass_function(input: str) -> str:
    return "pass"


@pytest.mark.asyncio
async def test_round_robin_group_chat(monkeypatch: pytest.MonkeyPatch) -> None:
    model = "gpt-4o-2024-05-13"
    chat_completions = [
        ChatCompletion(
            id="id1",
            choices=[
                Choice(
                    finish_reason="stop",
                    index=0,
                    message=ChatCompletionMessage(
                        content="""Here is the program\n ```python\nprint("Hello, world!")\n```""",
                        role="assistant",
                    ),
                )
            ],
            created=0,
            model=model,
            object="chat.completion",
            usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
        ),
        ChatCompletion(
            id="id2",
            choices=[
                Choice(
                    finish_reason="stop",
                    index=0,
                    message=ChatCompletionMessage(content="TERMINATE", role="assistant"),
                )
            ],
            created=0,
            model=model,
            object="chat.completion",
            usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
        ),
    ]
    mock = _MockChatCompletion(chat_completions)
    monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
    with tempfile.TemporaryDirectory() as temp_dir:
        code_executor_agent = CodeExecutorAgent(
            "code_executor", code_executor=LocalCommandLineCodeExecutor(work_dir=temp_dir)
        )
        coding_assistant_agent = AssistantAgent(
            "coding_assistant", model_client=OpenAIChatCompletionClient(model=model, api_key="")
        )
        termination = TextMentionTermination("TERMINATE")
        team = RoundRobinGroupChat(
            participants=[coding_assistant_agent, code_executor_agent], termination_condition=termination
        )
        result = await team.run(
            task="Write a program that prints 'Hello, world!'",
        )
        expected_messages = [
            "Write a program that prints 'Hello, world!'",
            'Here is the program\n ```python\nprint("Hello, world!")\n```',
            "Hello, world!",
            "TERMINATE",
        ]
        # Normalize the messages to remove \r\n and any leading/trailing whitespace.
        normalized_messages = [
            msg.content.replace("\r\n", "\n").rstrip("\n") if isinstance(msg.content, str) else msg.content
            for msg in result.messages
        ]

        # Assert that all expected messages are in the collected messages
        assert normalized_messages == expected_messages

        assert result.stop_reason is not None and result.stop_reason == "Text 'TERMINATE' mentioned"

        # Test streaming.
        mock.reset()
        index = 0
        await team.reset()
        async for message in team.run_stream(
            task="Write a program that prints 'Hello, world!'",
        ):
            if isinstance(message, TaskResult):
                assert message == result
            else:
                assert message == result.messages[index]
            index += 1

        # Test message input.
        # Text message.
        mock.reset()
        index = 0
        await team.reset()
        result_2 = await team.run(
            task=TextMessage(content="Write a program that prints 'Hello, world!'", source="user")
        )
        assert result == result_2

        # Test multi-modal message.
        mock.reset()
        index = 0
        await team.reset()
        result_2 = await team.run(
            task=MultiModalMessage(content=["Write a program that prints 'Hello, world!'"], source="user")
        )
        assert result.messages[0].content == result_2.messages[0].content[0]
        assert result.messages[1:] == result_2.messages[1:]


@pytest.mark.asyncio
async def test_round_robin_group_chat_state() -> None:
    model_client = ReplayChatCompletionClient(
        ["No facts", "No plan", "print('Hello, world!')", "TERMINATE"],
    )
    agent1 = AssistantAgent("agent1", model_client=model_client)
    agent2 = AssistantAgent("agent2", model_client=model_client)
    termination = TextMentionTermination("TERMINATE")
    team1 = RoundRobinGroupChat(participants=[agent1, agent2], termination_condition=termination)
    await team1.run(task="Write a program that prints 'Hello, world!'")
    state = await team1.save_state()

    agent3 = AssistantAgent("agent1", model_client=model_client)
    agent4 = AssistantAgent("agent2", model_client=model_client)
    team2 = RoundRobinGroupChat(participants=[agent3, agent4], termination_condition=termination)
    await team2.load_state(state)
    state2 = await team2.save_state()
    assert state == state2

    agent1_model_ctx_messages = await agent1._model_context.get_messages()  # pyright: ignore
    agent2_model_ctx_messages = await agent2._model_context.get_messages()  # pyright: ignore
    agent3_model_ctx_messages = await agent3._model_context.get_messages()  # pyright: ignore
    agent4_model_ctx_messages = await agent4._model_context.get_messages()  # pyright: ignore
    assert agent3_model_ctx_messages == agent1_model_ctx_messages
    assert agent4_model_ctx_messages == agent2_model_ctx_messages
    manager_1 = await team1._runtime.try_get_underlying_agent_instance(  # pyright: ignore
        AgentId("group_chat_manager", team1._team_id),  # pyright: ignore
        RoundRobinGroupChatManager,  # pyright: ignore
    )  # pyright: ignore
    manager_2 = await team2._runtime.try_get_underlying_agent_instance(  # pyright: ignore
        AgentId("group_chat_manager", team2._team_id),  # pyright: ignore
        RoundRobinGroupChatManager,  # pyright: ignore
    )  # pyright: ignore
    assert manager_1._current_turn == manager_2._current_turn  # pyright: ignore
    assert manager_1._message_thread == manager_2._message_thread  # pyright: ignore


@pytest.mark.asyncio
async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
    model = "gpt-4o-2024-05-13"
    chat_completions = [
        ChatCompletion(
            id="id1",
            choices=[
                Choice(
                    finish_reason="tool_calls",
                    index=0,
                    message=ChatCompletionMessage(
                        content=None,
                        tool_calls=[
                            ChatCompletionMessageToolCall(
                                id="1",
                                type="function",
                                function=Function(
                                    name="pass",
                                    arguments=json.dumps({"input": "pass"}),
                                ),
                            )
                        ],
                        role="assistant",
                    ),
                )
            ],
            created=0,
            model=model,
            object="chat.completion",
            usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
        ),
        ChatCompletion(
            id="id2",
            choices=[
                Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(content="Hello", role="assistant"))
            ],
            created=0,
            model=model,
            object="chat.completion",
            usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
        ),
        ChatCompletion(
            id="id2",
            choices=[
                Choice(
                    finish_reason="stop", index=0, message=ChatCompletionMessage(content="TERMINATE", role="assistant")
                )
            ],
            created=0,
            model=model,
            object="chat.completion",
            usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
        ),
    ]
    # Test with repeat tool calls once
    mock = _MockChatCompletion(chat_completions)
    monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
    tool = FunctionTool(_pass_function, name="pass", description="pass function")
    tool_use_agent = AssistantAgent(
        "tool_use_agent",
        model_client=OpenAIChatCompletionClient(model=model, api_key=""),
        tools=[tool],
    )
    echo_agent = _EchoAgent("echo_agent", description="echo agent")
    termination = TextMentionTermination("TERMINATE")
    team = RoundRobinGroupChat(participants=[tool_use_agent, echo_agent], termination_condition=termination)
    result = await team.run(
        task="Write a program that prints 'Hello, world!'",
    )
    assert len(result.messages) == 8
    assert isinstance(result.messages[0], TextMessage)  # task
    assert isinstance(result.messages[1], ToolCallRequestEvent)  # tool call
    assert isinstance(result.messages[2], ToolCallExecutionEvent)  # tool call result
    assert isinstance(result.messages[3], ToolCallSummaryMessage)  # tool use agent response
    assert result.messages[3].content == "pass"  # ensure the tool call was executed
    assert isinstance(result.messages[4], TextMessage)  # echo agent response
    assert isinstance(result.messages[5], TextMessage)  # tool use agent response
    assert isinstance(result.messages[6], TextMessage)  # echo agent response
    assert isinstance(result.messages[7], TextMessage)  # tool use agent response, that has TERMINATE
    assert result.messages[7].content == "TERMINATE"

    assert result.stop_reason is not None and result.stop_reason == "Text 'TERMINATE' mentioned"

    # Test streaming.
    await tool_use_agent._model_context.clear()  # pyright: ignore
    mock.reset()
    index = 0
    await team.reset()
    async for message in team.run_stream(
        task="Write a program that prints 'Hello, world!'",
    ):
        if isinstance(message, TaskResult):
            assert message == result
        else:
            assert message == result.messages[index]
        index += 1

    # Test Console.
    await tool_use_agent._model_context.clear()  # pyright: ignore
    mock.reset()
    index = 0
    await team.reset()
    result2 = await Console(team.run_stream(task="Write a program that prints 'Hello, world!'"))
    assert result2 == result


@pytest.mark.asyncio
async def test_round_robin_group_chat_with_resume_and_reset() -> None:
    agent_1 = _EchoAgent("agent_1", description="echo agent 1")
    agent_2 = _EchoAgent("agent_2", description="echo agent 2")
    agent_3 = _EchoAgent("agent_3", description="echo agent 3")
    agent_4 = _EchoAgent("agent_4", description="echo agent 4")
    termination = MaxMessageTermination(3)
    team = RoundRobinGroupChat(participants=[agent_1, agent_2, agent_3, agent_4], termination_condition=termination)
    result = await team.run(
        task="Write a program that prints 'Hello, world!'",
    )
    assert len(result.messages) == 3
    assert result.messages[1].source == "agent_1"
    assert result.messages[2].source == "agent_2"
    assert result.stop_reason is not None

    # Resume.
    result = await team.run()
    assert len(result.messages) == 3
    assert result.messages[0].source == "agent_3"
    assert result.messages[1].source == "agent_4"
    assert result.messages[2].source == "agent_1"
    assert result.stop_reason is not None

    # Reset.
    await team.reset()
    result = await team.run(task="Write a program that prints 'Hello, world!'")
    assert len(result.messages) == 3
    assert result.messages[1].source == "agent_1"
    assert result.messages[2].source == "agent_2"
    assert result.stop_reason is not None


@pytest.mark.asyncio
async def test_round_robin_group_chat_max_turn() -> None:
    agent_1 = _EchoAgent("agent_1", description="echo agent 1")
    agent_2 = _EchoAgent("agent_2", description="echo agent 2")
    agent_3 = _EchoAgent("agent_3", description="echo agent 3")
    agent_4 = _EchoAgent("agent_4", description="echo agent 4")
    team = RoundRobinGroupChat(participants=[agent_1, agent_2, agent_3, agent_4], max_turns=3)
    result = await team.run(
        task="Write a program that prints 'Hello, world!'",
    )
    assert len(result.messages) == 4
    assert result.messages[1].source == "agent_1"
    assert result.messages[2].source == "agent_2"
    assert result.messages[3].source == "agent_3"
    assert result.stop_reason is not None

    # Resume.
    result = await team.run()
    assert len(result.messages) == 3
    assert result.messages[0].source == "agent_4"
    assert result.messages[1].source == "agent_1"
    assert result.messages[2].source == "agent_2"
    assert result.stop_reason is not None

    # Reset.
    await team.reset()
    result = await team.run(task="Write a program that prints 'Hello, world!'")
    assert len(result.messages) == 4
    assert result.messages[1].source == "agent_1"
    assert result.messages[2].source == "agent_2"
    assert result.messages[3].source == "agent_3"
    assert result.stop_reason is not None


@pytest.mark.asyncio
async def test_round_robin_group_chat_cancellation() -> None:
    agent_1 = _EchoAgent("agent_1", description="echo agent 1")
    agent_2 = _EchoAgent("agent_2", description="echo agent 2")
    agent_3 = _EchoAgent("agent_3", description="echo agent 3")
    agent_4 = _EchoAgent("agent_4", description="echo agent 4")
    # Set max_turns to a large number to avoid stopping due to max_turns before cancellation.
    team = RoundRobinGroupChat(participants=[agent_1, agent_2, agent_3, agent_4], max_turns=1000)
    cancellation_token = CancellationToken()
    run_task = asyncio.create_task(
        team.run(
            task="Write a program that prints 'Hello, world!'",
            cancellation_token=cancellation_token,
        )
    )
    await asyncio.sleep(0.1)
    # Cancel the task.
    cancellation_token.cancel()
    with pytest.raises(asyncio.CancelledError):
        await run_task

    # Total messages produced so far.
    total_messages = agent_1.total_messages + agent_2.total_messages + agent_3.total_messages + agent_4.total_messages

    # Still can run again and finish the task.
    result = await team.run()
    assert len(result.messages) + total_messages == 1000


@pytest.mark.asyncio
async def test_selector_group_chat(monkeypatch: pytest.MonkeyPatch) -> None:
    model = "gpt-4o-2024-05-13"
    chat_completions = [
        ChatCompletion(
            id="id2",
            choices=[
                Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(content="agent3", role="assistant"))
            ],
            created=0,
            model=model,
            object="chat.completion",
            usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
        ),
        ChatCompletion(
            id="id2",
            choices=[
                Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(content="agent2", role="assistant"))
            ],
            created=0,
            model=model,
            object="chat.completion",
            usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
        ),
        ChatCompletion(
            id="id2",
            choices=[
                Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(content="agent1", role="assistant"))
            ],
            created=0,
            model=model,
            object="chat.completion",
            usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
        ),
        ChatCompletion(
            id="id2",
            choices=[
                Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(content="agent2", role="assistant"))
            ],
            created=0,
            model=model,
            object="chat.completion",
            usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
        ),
        ChatCompletion(
            id="id2",
            choices=[
                Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(content="agent1", role="assistant"))
            ],
            created=0,
            model=model,
            object="chat.completion",
            usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
        ),
    ]
    mock = _MockChatCompletion(chat_completions)
    monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)

    agent1 = _StopAgent("agent1", description="echo agent 1", stop_at=2)
    agent2 = _EchoAgent("agent2", description="echo agent 2")
    agent3 = _EchoAgent("agent3", description="echo agent 3")
    termination = TextMentionTermination("TERMINATE")
    team = SelectorGroupChat(
        participants=[agent1, agent2, agent3],
        model_client=OpenAIChatCompletionClient(model=model, api_key=""),
        termination_condition=termination,
    )
    result = await team.run(
        task="Write a program that prints 'Hello, world!'",
    )
    assert len(result.messages) == 6
    assert result.messages[0].content == "Write a program that prints 'Hello, world!'"
    assert result.messages[1].source == "agent3"
    assert result.messages[2].source == "agent2"
    assert result.messages[3].source == "agent1"
    assert result.messages[4].source == "agent2"
    assert result.messages[5].source == "agent1"
    assert result.stop_reason is not None and result.stop_reason == "Text 'TERMINATE' mentioned"

    # Test streaming.
    mock.reset()
    agent1._count = 0  # pyright: ignore
    index = 0
    await team.reset()
    async for message in team.run_stream(
        task="Write a program that prints 'Hello, world!'",
    ):
        if isinstance(message, TaskResult):
            assert message == result
        else:
            assert message == result.messages[index]
        index += 1

    # Test Console.
    mock.reset()
    agent1._count = 0  # pyright: ignore
    index = 0
    await team.reset()
    result2 = await Console(team.run_stream(task="Write a program that prints 'Hello, world!'"))
    assert result2 == result


@pytest.mark.asyncio
async def test_selector_group_chat_state() -> None:
    model_client = ReplayChatCompletionClient(
        ["agent1", "No facts", "agent2", "No plan", "agent1", "print('Hello, world!')", "agent2", "TERMINATE"],
    )
    agent1 = AssistantAgent("agent1", model_client=model_client)
    agent2 = AssistantAgent("agent2", model_client=model_client)
    termination = TextMentionTermination("TERMINATE")
    team1 = SelectorGroupChat(
        participants=[agent1, agent2], termination_condition=termination, model_client=model_client
    )
    await team1.run(task="Write a program that prints 'Hello, world!'")
    state = await team1.save_state()

    agent3 = AssistantAgent("agent1", model_client=model_client)
    agent4 = AssistantAgent("agent2", model_client=model_client)
    team2 = SelectorGroupChat(
        participants=[agent3, agent4], termination_condition=termination, model_client=model_client
    )
    await team2.load_state(state)
    state2 = await team2.save_state()
    assert state == state2

    agent1_model_ctx_messages = await agent1._model_context.get_messages()  # pyright: ignore
    agent2_model_ctx_messages = await agent2._model_context.get_messages()  # pyright: ignore
    agent3_model_ctx_messages = await agent3._model_context.get_messages()  # pyright: ignore
    agent4_model_ctx_messages = await agent4._model_context.get_messages()  # pyright: ignore
    assert agent3_model_ctx_messages == agent1_model_ctx_messages
    assert agent4_model_ctx_messages == agent2_model_ctx_messages
    manager_1 = await team1._runtime.try_get_underlying_agent_instance(  # pyright: ignore
        AgentId("group_chat_manager", team1._team_id),  # pyright: ignore
        SelectorGroupChatManager,  # pyright: ignore
    )  # pyright: ignore
    manager_2 = await team2._runtime.try_get_underlying_agent_instance(  # pyright: ignore
        AgentId("group_chat_manager", team2._team_id),  # pyright: ignore
        SelectorGroupChatManager,  # pyright: ignore
    )  # pyright: ignore
    assert manager_1._message_thread == manager_2._message_thread  # pyright: ignore
    assert manager_1._previous_speaker == manager_2._previous_speaker  # pyright: ignore


@pytest.mark.asyncio
async def test_selector_group_chat_two_speakers(monkeypatch: pytest.MonkeyPatch) -> None:
    model = "gpt-4o-2024-05-13"
    chat_completions = [
        ChatCompletion(
            id="id2",
            choices=[
                Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(content="agent2", role="assistant"))
            ],
            created=0,
            model=model,
            object="chat.completion",
            usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
        ),
    ]
    mock = _MockChatCompletion(chat_completions)
    monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)

    agent1 = _StopAgent("agent1", description="echo agent 1", stop_at=2)
    agent2 = _EchoAgent("agent2", description="echo agent 2")
    termination = TextMentionTermination("TERMINATE")
    team = SelectorGroupChat(
        participants=[agent1, agent2],
        termination_condition=termination,
        model_client=OpenAIChatCompletionClient(model=model, api_key=""),
    )
    result = await team.run(
        task="Write a program that prints 'Hello, world!'",
    )
    assert len(result.messages) == 5
    assert result.messages[0].content == "Write a program that prints 'Hello, world!'"
    assert result.messages[1].source == "agent2"
    assert result.messages[2].source == "agent1"
    assert result.messages[3].source == "agent2"
    assert result.messages[4].source == "agent1"
    # only one chat completion was called
    assert mock._curr_index == 1  # pyright: ignore
    assert result.stop_reason is not None and result.stop_reason == "Text 'TERMINATE' mentioned"

    # Test streaming.
    mock.reset()
    agent1._count = 0  # pyright: ignore
    index = 0
    await team.reset()
    async for message in team.run_stream(task="Write a program that prints 'Hello, world!'"):
        if isinstance(message, TaskResult):
            assert message == result
        else:
            assert message == result.messages[index]
        index += 1

    # Test Console.
    mock.reset()
    agent1._count = 0  # pyright: ignore
    index = 0
    await team.reset()
    result2 = await Console(team.run_stream(task="Write a program that prints 'Hello, world!'"))
    assert result2 == result


@pytest.mark.asyncio
async def test_selector_group_chat_two_speakers_allow_repeated(monkeypatch: pytest.MonkeyPatch) -> None:
    model = "gpt-4o-2024-05-13"
    chat_completions = [
        ChatCompletion(
            id="id2",
            choices=[
                Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(content="agent2", role="assistant"))
            ],
            created=0,
            model=model,
            object="chat.completion",
            usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
        ),
        ChatCompletion(
            id="id2",
            choices=[
                Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(content="agent2", role="assistant"))
            ],
            created=0,
            model=model,
            object="chat.completion",
            usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
        ),
        ChatCompletion(
            id="id2",
            choices=[
                Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(content="agent1", role="assistant"))
            ],
            created=0,
            model=model,
            object="chat.completion",
            usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
        ),
    ]
    mock = _MockChatCompletion(chat_completions)
    monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)

    agent1 = _StopAgent("agent1", description="echo agent 1", stop_at=1)
    agent2 = _EchoAgent("agent2", description="echo agent 2")
    termination = TextMentionTermination("TERMINATE")
    team = SelectorGroupChat(
        participants=[agent1, agent2],
        model_client=OpenAIChatCompletionClient(model=model, api_key=""),
        termination_condition=termination,
        allow_repeated_speaker=True,
    )
    result = await team.run(task="Write a program that prints 'Hello, world!'")
    assert len(result.messages) == 4
    assert result.messages[0].content == "Write a program that prints 'Hello, world!'"
    assert result.messages[1].source == "agent2"
    assert result.messages[2].source == "agent2"
    assert result.messages[3].source == "agent1"
    assert result.stop_reason is not None and result.stop_reason == "Text 'TERMINATE' mentioned"

    # Test streaming.
    mock.reset()
    index = 0
    await team.reset()
    async for message in team.run_stream(task="Write a program that prints 'Hello, world!'"):
        if isinstance(message, TaskResult):
            assert message == result
        else:
            assert message == result.messages[index]
        index += 1

    # Test Console.
    mock.reset()
    index = 0
    await team.reset()
    result2 = await Console(team.run_stream(task="Write a program that prints 'Hello, world!'"))
    assert result2 == result


@pytest.mark.asyncio
async def test_selector_group_chat_custom_selector(monkeypatch: pytest.MonkeyPatch) -> None:
    model = "gpt-4o-2024-05-13"
    chat_completions = [
        ChatCompletion(
            id="id2",
            choices=[
                Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(content="agent3", role="assistant"))
            ],
            created=0,
            model=model,
            object="chat.completion",
            usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
        ),
    ]
    mock = _MockChatCompletion(chat_completions)
    monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
    agent1 = _EchoAgent("agent1", description="echo agent 1")
    agent2 = _EchoAgent("agent2", description="echo agent 2")
    agent3 = _EchoAgent("agent3", description="echo agent 3")
    agent4 = _EchoAgent("agent4", description="echo agent 4")

    def _select_agent(messages: Sequence[AgentEvent | ChatMessage]) -> str | None:
        if len(messages) == 0:
            return "agent1"
        elif messages[-1].source == "agent1":
            return "agent2"
        elif messages[-1].source == "agent2":
            return None
        elif messages[-1].source == "agent3":
            return "agent4"
        else:
            return "agent1"

    termination = MaxMessageTermination(6)
    team = SelectorGroupChat(
        participants=[agent1, agent2, agent3, agent4],
        model_client=OpenAIChatCompletionClient(model=model, api_key=""),
        selector_func=_select_agent,
        termination_condition=termination,
    )
    result = await team.run(task="task")
    assert len(result.messages) == 6
    assert result.messages[1].source == "agent1"
    assert result.messages[2].source == "agent2"
    assert result.messages[3].source == "agent3"
    assert result.messages[4].source == "agent4"
    assert result.messages[5].source == "agent1"
    assert (
        result.stop_reason is not None
        and result.stop_reason == "Maximum number of messages 6 reached, current message count: 6"
    )


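# Agent that always responds with a HandoffMessage targeting a fixed next agent;
# used to drive the Swarm handoff tests below.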
class _HandOffAgent(BaseChatAgent):
    def __init__(self, name: str, description: str, next_agent: str) -> None:
        super().__init__(name, description)
        self._next_agent = next_agent

    @property
    def produced_message_types(self) -> Sequence[type[ChatMessage]]:
        return (HandoffMessage,)

    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
        return Response(
            chat_message=HandoffMessage(
                content=f"Transferred to {self._next_agent}.", target=self._next_agent, source=self.name
            )
        )

    async def on_reset(self, cancellation_token: CancellationToken) -> None:
        pass


@pytest.mark.asyncio
async def test_swarm_handoff() -> None:
    first_agent = _HandOffAgent("first_agent", description="first agent", next_agent="second_agent")
    second_agent = _HandOffAgent("second_agent", description="second agent", next_agent="third_agent")
    third_agent = _HandOffAgent("third_agent", description="third agent", next_agent="first_agent")

    termination = MaxMessageTermination(6)
    team = Swarm([second_agent, first_agent, third_agent], termination_condition=termination)
    result = await team.run(task="task")
    assert len(result.messages) == 6
    assert result.messages[0].content == "task"
    assert result.messages[1].content == "Transferred to third_agent."
    assert result.messages[2].content == "Transferred to first_agent."
    assert result.messages[3].content == "Transferred to second_agent."
    assert result.messages[4].content == "Transferred to third_agent."
    assert result.messages[5].content == "Transferred to first_agent."
    assert (
        result.stop_reason is not None
        and result.stop_reason == "Maximum number of messages 6 reached, current message count: 6"
    )

    # Test streaming.
    index = 0
    await team.reset()
    stream = team.run_stream(task="task")
    async for message in stream:
        if isinstance(message, TaskResult):
            assert message == result
        else:
            assert message == result.messages[index]
        index += 1

    # Test save and load.
    state = await team.save_state()
    first_agent2 = _HandOffAgent("first_agent", description="first agent", next_agent="second_agent")
    second_agent2 = _HandOffAgent("second_agent", description="second agent", next_agent="third_agent")
    third_agent2 = _HandOffAgent("third_agent", description="third agent", next_agent="first_agent")
    team2 = Swarm([second_agent2, first_agent2, third_agent2], termination_condition=termination)
    await team2.load_state(state)
    state2 = await team2.save_state()
    assert state == state2
    manager_1 = await team._runtime.try_get_underlying_agent_instance(  # pyright: ignore
        AgentId("group_chat_manager", team._team_id),  # pyright: ignore
        SwarmGroupChatManager,  # pyright: ignore
    )  # pyright: ignore
    manager_2 = await team2._runtime.try_get_underlying_agent_instance(  # pyright: ignore
        AgentId("group_chat_manager", team2._team_id),  # pyright: ignore
        SwarmGroupChatManager,  # pyright: ignore
    )  # pyright: ignore
    assert manager_1._message_thread == manager_2._message_thread  # pyright: ignore
    assert manager_1._current_speaker == manager_2._current_speaker  # pyright: ignore


@pytest.mark.asyncio
async def test_swarm_handoff_using_tool_calls(monkeypatch: pytest.MonkeyPatch) -> None:
    model = "gpt-4o-2024-05-13"
    chat_completions = [
        ChatCompletion(
            id="id1",
            choices=[
                Choice(
                    finish_reason="tool_calls",
                    index=0,
                    message=ChatCompletionMessage(
                        content=None,
                        tool_calls=[
                            ChatCompletionMessageToolCall(
                                id="1",
                                type="function",
                                function=Function(
                                    name="handoff_to_agent2",
                                    arguments=json.dumps({}),
                                ),
                            )
                        ],
                        role="assistant",
                    ),
                )
            ],
            created=0,
            model=model,
            object="chat.completion",
            usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
        ),
        ChatCompletion(
            id="id2",
            choices=[
                Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(content="Hello", role="assistant"))
            ],
            created=0,
            model=model,
            object="chat.completion",
            usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
        ),
        ChatCompletion(
            id="id2",
            choices=[
                Choice(
                    finish_reason="stop", index=0, message=ChatCompletionMessage(content="TERMINATE", role="assistant")
                )
            ],
            created=0,
            model=model,
            object="chat.completion",
            usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
        ),
    ]
    mock = _MockChatCompletion(chat_completions)
    monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)

    agent1 = AssistantAgent(
        "agent1",
        model_client=OpenAIChatCompletionClient(model=model, api_key=""),
        handoffs=[Handoff(target="agent2", name="handoff_to_agent2", message="handoff to agent2")],
    )
    agent2 = _HandOffAgent("agent2", description="agent 2", next_agent="agent1")
    termination = TextMentionTermination("TERMINATE")
    team = Swarm([agent1, agent2], termination_condition=termination)
    result = await team.run(task="task")
    assert len(result.messages) == 7
    assert result.messages[0].content == "task"
    assert isinstance(result.messages[1], ToolCallRequestEvent)
    assert isinstance(result.messages[2], ToolCallExecutionEvent)
    assert result.messages[3].content == "handoff to agent2"
    assert result.messages[4].content == "Transferred to agent1."
    assert result.messages[5].content == "Hello"
    assert result.messages[6].content == "TERMINATE"
    assert result.stop_reason is not None and result.stop_reason == "Text 'TERMINATE' mentioned"

    # Test streaming.
    await agent1._model_context.clear()  # pyright: ignore
    mock.reset()
    index = 0
    await team.reset()
    stream = team.run_stream(task="task")
    async for message in stream:
        if isinstance(message, TaskResult):
            assert message == result
        else:
            assert message == result.messages[index]
        index += 1

    # Test Console
    await agent1._model_context.clear()  # pyright: ignore
    mock.reset()
    index = 0
    await team.reset()
    result2 = await Console(team.run_stream(task="task"))
    assert result2 == result


@pytest.mark.asyncio
async def test_swarm_pause_and_resume() -> None:
    first_agent = _HandOffAgent("first_agent", description="first agent", next_agent="second_agent")
    second_agent = _HandOffAgent("second_agent", description="second agent", next_agent="third_agent")
    third_agent = _HandOffAgent("third_agent", description="third agent", next_agent="first_agent")

    team = Swarm([second_agent, first_agent, third_agent], max_turns=1)
    result = await team.run(task="task")
    assert len(result.messages) == 2
    assert result.messages[0].content == "task"
    assert result.messages[1].content == "Transferred to third_agent."

    # Resume with a new task.
    result = await team.run(task="new task")
    assert len(result.messages) == 2
    assert result.messages[0].content == "new task"
    assert result.messages[1].content == "Transferred to first_agent."

    # Resume with the same task.
    result = await team.run()
    assert len(result.messages) == 1
    assert result.messages[0].content == "Transferred to second_agent."


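# The handoff tool call below is made together with two regular tool calls; the test
# checks that the non-handoff results are carried along in the HandoffMessage context
# and end up in the receiving agent's model context.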
@pytest.mark.asyncio
async def test_swarm_with_parallel_tool_calls(monkeypatch: pytest.MonkeyPatch) -> None:
    model = "gpt-4o-2024-05-13"
    chat_completions = [
        ChatCompletion(
            id="id1",
            choices=[
                Choice(
                    finish_reason="tool_calls",
                    index=0,
                    message=ChatCompletionMessage(
                        content=None,
                        tool_calls=[
                            ChatCompletionMessageToolCall(
                                id="1",
                                type="function",
                                function=Function(
                                    name="tool1",
                                    arguments=json.dumps({}),
                                ),
                            ),
                            ChatCompletionMessageToolCall(
                                id="2",
                                type="function",
                                function=Function(
                                    name="tool2",
                                    arguments=json.dumps({}),
                                ),
                            ),
                            ChatCompletionMessageToolCall(
                                id="3",
                                type="function",
                                function=Function(
                                    name="handoff_to_agent2",
                                    arguments=json.dumps({}),
                                ),
                            ),
                        ],
                        role="assistant",
                    ),
                )
            ],
            created=0,
            model=model,
            object="chat.completion",
            usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
        ),
        ChatCompletion(
            id="id2",
            choices=[
                Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(content="Hello", role="assistant"))
            ],
            created=0,
            model=model,
            object="chat.completion",
            usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
        ),
        ChatCompletion(
            id="id2",
            choices=[
                Choice(
                    finish_reason="stop", index=0, message=ChatCompletionMessage(content="TERMINATE", role="assistant")
                )
            ],
            created=0,
            model=model,
            object="chat.completion",
            usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
        ),
    ]
    mock = _MockChatCompletion(chat_completions)
    monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)

    expected_handoff_context: List[LLMMessage] = [
        AssistantMessage(
            source="agent1",
            content=[
                FunctionCall(id="1", name="tool1", arguments="{}"),
                FunctionCall(id="2", name="tool2", arguments="{}"),
            ],
        ),
        FunctionExecutionResultMessage(
            content=[
                FunctionExecutionResult(content="tool1", call_id="1"),
                FunctionExecutionResult(content="tool2", call_id="2"),
            ]
        ),
    ]

    def tool1() -> str:
        return "tool1"

    def tool2() -> str:
        return "tool2"

    agent1 = AssistantAgent(
        "agent1",
        model_client=OpenAIChatCompletionClient(model=model, api_key=""),
        handoffs=[Handoff(target="agent2", name="handoff_to_agent2", message="handoff to agent2")],
        tools=[tool1, tool2],
    )
    agent2 = AssistantAgent(
        "agent2",
        model_client=OpenAIChatCompletionClient(model=model, api_key=""),
    )
    termination = TextMentionTermination("TERMINATE")
    team = Swarm([agent1, agent2], termination_condition=termination)
    result = await team.run(task="task")
    assert len(result.messages) == 6
    assert result.messages[0] == TextMessage(content="task", source="user")
    assert isinstance(result.messages[1], ToolCallRequestEvent)
    assert isinstance(result.messages[2], ToolCallExecutionEvent)
    assert result.messages[3] == HandoffMessage(
        content="handoff to agent2",
        target="agent2",
        source="agent1",
        context=expected_handoff_context,
    )
    assert result.messages[4].content == "Hello"
    assert result.messages[4].source == "agent2"
    assert result.messages[5].content == "TERMINATE"
    assert result.messages[5].source == "agent2"

    # Verify the tool calls are in agent2's context.
    agent2_model_ctx_messages = await agent2._model_context.get_messages()  # pyright: ignore
    assert agent2_model_ctx_messages[0] == UserMessage(content="task", source="user")
    assert agent2_model_ctx_messages[1] == expected_handoff_context[0]
    assert agent2_model_ctx_messages[2] == expected_handoff_context[1]


@pytest.mark.asyncio
async def test_swarm_with_handoff_termination() -> None:
    first_agent = _HandOffAgent("first_agent", description="first agent", next_agent="second_agent")
    second_agent = _HandOffAgent("second_agent", description="second agent", next_agent="third_agent")
    third_agent = _HandOffAgent("third_agent", description="third agent", next_agent="first_agent")

    # Handoff to an existing agent.
    termination = HandoffTermination(target="third_agent")
    team = Swarm([second_agent, first_agent, third_agent], termination_condition=termination)
    # Start
    result = await team.run(task="task")
    assert len(result.messages) == 2
    assert result.messages[0].content == "task"
    assert result.messages[1].content == "Transferred to third_agent."
    # Resume existing.
    result = await team.run()
    assert len(result.messages) == 3
    assert result.messages[0].content == "Transferred to first_agent."
    assert result.messages[1].content == "Transferred to second_agent."
    assert result.messages[2].content == "Transferred to third_agent."
    # Resume new task.
    result = await team.run(task="new task")
    assert len(result.messages) == 4
    assert result.messages[0].content == "new task"
    assert result.messages[1].content == "Transferred to first_agent."
    assert result.messages[2].content == "Transferred to second_agent."
    assert result.messages[3].content == "Transferred to third_agent."

    # Handoff to a non-existing agent.
    third_agent = _HandOffAgent("third_agent", description="third agent", next_agent="non_existing_agent")
    termination = HandoffTermination(target="non_existing_agent")
    team = Swarm([second_agent, first_agent, third_agent], termination_condition=termination)
    # Start
    result = await team.run(task="task")
    assert len(result.messages) == 3
    assert result.messages[0].content == "task"
    assert result.messages[1].content == "Transferred to third_agent."
    assert result.messages[2].content == "Transferred to non_existing_agent."
    # Attempt to resume.
    with pytest.raises(ValueError):
        await team.run()
    # Attempt to resume with a new task.
    with pytest.raises(ValueError):
        await team.run(task="new task")
    # Resume with a HandoffMessage
    result = await team.run(task=HandoffMessage(content="Handoff to first_agent.", target="first_agent", source="user"))
    assert len(result.messages) == 4
    assert result.messages[0].content == "Handoff to first_agent."
    assert result.messages[1].content == "Transferred to second_agent."
    assert result.messages[2].content == "Transferred to third_agent."
    assert result.messages[3].content == "Transferred to non_existing_agent."


@pytest.mark.asyncio
async def test_round_robin_group_chat_with_message_list() -> None:
    # Create a simple team with echo agents
    agent1 = _EchoAgent("Agent1", "First agent")
    agent2 = _EchoAgent("Agent2", "Second agent")
    termination = MaxMessageTermination(4)  # Stop after 4 messages
    team = RoundRobinGroupChat([agent1, agent2], termination_condition=termination)

    # Create a list of messages
    messages: List[ChatMessage] = [
        TextMessage(content="Message 1", source="user"),
        TextMessage(content="Message 2", source="user"),
        TextMessage(content="Message 3", source="user"),
    ]

    # Run the team with the message list
    result = await team.run(task=messages)

    # Verify the messages were processed in order
    assert len(result.messages) == 4  # Initial messages + echo until termination
    assert result.messages[0].content == "Message 1"  # First message
    assert result.messages[1].content == "Message 2"  # Second message
    assert result.messages[2].content == "Message 3"  # Third message
    assert result.messages[3].content == "Message 1"  # Echo from first agent
    assert result.stop_reason == "Maximum number of messages 4 reached, current message count: 4"

    # Test with streaming
    await team.reset()
    index = 0
    async for message in team.run_stream(task=messages):
        if isinstance(message, TaskResult):
            assert message == result
        else:
            assert message == result.messages[index]
            index += 1

    # Test with invalid message list
    with pytest.raises(ValueError, match="All messages in task list must be valid ChatMessage types"):
        await team.run(task=["not a message"])  # type: ignore[list-item, arg-type]  # intentionally testing invalid input

    # Test with empty message list
    with pytest.raises(ValueError, match="Task list cannot be empty"):
        await team.run(task=[])


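# Round-trips each team type through dump_component() / load_component() and checks
# that the resulting config, component_type, and provider strings are preserved.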
@pytest.mark.asyncio
async def test_declarative_groupchats_with_config() -> None:
    # Create basic agents and components for testing
    agent1 = AssistantAgent(
        "agent_1",
        model_client=OpenAIChatCompletionClient(model="gpt-4o-2024-05-13", api_key=""),
        handoffs=["agent_2"],
    )
    agent2 = AssistantAgent("agent_2", model_client=OpenAIChatCompletionClient(model="gpt-4o-2024-05-13", api_key=""))
    termination = MaxMessageTermination(4)
    model_client = OpenAIChatCompletionClient(model="gpt-4o-2024-05-13", api_key="")

    # Test round robin - verify config is preserved
    round_robin = RoundRobinGroupChat(participants=[agent1, agent2], termination_condition=termination, max_turns=5)
    config = round_robin.dump_component()
    loaded = RoundRobinGroupChat.load_component(config)
    assert loaded.dump_component() == config

    # Test selector group chat - verify config is preserved
    selector_prompt = "Custom selector prompt with {roles}, {participants}, {history}"
    selector = SelectorGroupChat(
        participants=[agent1, agent2],
        model_client=model_client,
        termination_condition=termination,
        max_turns=10,
        selector_prompt=selector_prompt,
        allow_repeated_speaker=True,
    )
    selector_config = selector.dump_component()
    selector_loaded = SelectorGroupChat.load_component(selector_config)
    assert selector_loaded.dump_component() == selector_config

    # Test swarm with handoff termination
    handoff_termination = HandoffTermination(target="Agent2")
    swarm = Swarm(participants=[agent1, agent2], termination_condition=handoff_termination, max_turns=5)
    swarm_config = swarm.dump_component()
    swarm_loaded = Swarm.load_component(swarm_config)
    assert swarm_loaded.dump_component() == swarm_config

    # Test MagenticOne with custom parameters
    magentic = MagenticOneGroupChat(
        participants=[agent1],
        model_client=model_client,
        max_turns=15,
        max_stalls=5,
        final_answer_prompt="Custom prompt",
    )
    magentic_config = magentic.dump_component()
    magentic_loaded = MagenticOneGroupChat.load_component(magentic_config)
    assert magentic_loaded.dump_component() == magentic_config

    # Verify component types are correctly set for each
    for team in [loaded, selector, swarm, magentic]:
        assert team.component_type == "team"

    # Verify provider strings are correctly set
    assert round_robin.dump_component().provider == "autogen_agentchat.teams.RoundRobinGroupChat"
    assert selector.dump_component().provider == "autogen_agentchat.teams.SelectorGroupChat"
    assert swarm.dump_component().provider == "autogen_agentchat.teams.Swarm"
    assert magentic.dump_component().provider == "autogen_agentchat.teams.MagenticOneGroupChat"