mirror of
https://github.com/microsoft/autogen.git
synced 2025-07-04 23:50:39 +00:00

Resolves #4075 1. Introduce custom runtime parameter for all AgentChat teams (RoundRobinGroupChat, SelectorGroupChat, etc.). This is done by making sure each team's topics are isolated from other teams, and decoupling state from agent identities. Also, I removed the closure agent from the BaseGroupChat and use the group chat manager agent to relay messages to the output message queue. 2. Added unit tests to test scenarios with custom runtimes by using pytest fixture 3. Refactored existing unit tests to use ReplayChatCompletionClient with a few improvements to the client. 4. Fix a one-liner bug in AssistantAgent that caused deserialized agent to have handoffs. How to use it? ```python import asyncio from autogen_core import SingleThreadedAgentRuntime from autogen_agentchat.agents import AssistantAgent from autogen_agentchat.teams import RoundRobinGroupChat from autogen_agentchat.conditions import TextMentionTermination from autogen_ext.models.replay import ReplayChatCompletionClient async def main() -> None: # Create a runtime runtime = SingleThreadedAgentRuntime() runtime.start() # Create a model client. model_client = ReplayChatCompletionClient( ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10"], ) # Create agents agent1 = AssistantAgent("assistant1", model_client=model_client, system_message="You are a helpful assistant.") agent2 = AssistantAgent("assistant2", model_client=model_client, system_message="You are a helpful assistant.") # Create a termination condition termination_condition = TextMentionTermination("10", sources=["assistant1", "assistant2"]) # Create a team team = RoundRobinGroupChat([agent1, agent2], runtime=runtime, termination_condition=termination_condition) # Run the team stream = team.run_stream(task="Count to 10.") async for message in stream: print(message) # Save the state. state = await team.save_state() # Load the state to an existing team. await team.load_state(state) # Run the team again model_client.reset() stream = team.run_stream(task="Count to 10.") async for message in stream: print(message) # Create a new team, with the same agent names. agent3 = AssistantAgent("assistant1", model_client=model_client, system_message="You are a helpful assistant.") agent4 = AssistantAgent("assistant2", model_client=model_client, system_message="You are a helpful assistant.") new_team = RoundRobinGroupChat([agent3, agent4], runtime=runtime, termination_condition=termination_condition) # Load the state to the new team. await new_team.load_state(state) # Run the new team model_client.reset() new_stream = new_team.run_stream(task="Count to 10.") async for message in new_stream: print(message) # Stop the runtime await runtime.stop() asyncio.run(main()) ``` TODOs as future PRs: 1. Documentation. 2. How to handle errors in custom runtime when the agent has exception? --------- Co-authored-by: Ryan Sweet <rysweet@microsoft.com>
1114 lines
46 KiB
Python
1114 lines
46 KiB
Python
import asyncio
|
|
import json
|
|
import logging
|
|
import tempfile
|
|
from typing import AsyncGenerator, List, Sequence
|
|
|
|
import pytest
|
|
import pytest_asyncio
|
|
from autogen_agentchat import EVENT_LOGGER_NAME
|
|
from autogen_agentchat.agents import (
|
|
AssistantAgent,
|
|
BaseChatAgent,
|
|
CodeExecutorAgent,
|
|
)
|
|
from autogen_agentchat.base import Handoff, Response, TaskResult
|
|
from autogen_agentchat.conditions import HandoffTermination, MaxMessageTermination, TextMentionTermination
|
|
from autogen_agentchat.messages import (
|
|
AgentEvent,
|
|
ChatMessage,
|
|
HandoffMessage,
|
|
MultiModalMessage,
|
|
StopMessage,
|
|
TextMessage,
|
|
ToolCallExecutionEvent,
|
|
ToolCallRequestEvent,
|
|
ToolCallSummaryMessage,
|
|
)
|
|
from autogen_agentchat.teams import MagenticOneGroupChat, RoundRobinGroupChat, SelectorGroupChat, Swarm
|
|
from autogen_agentchat.teams._group_chat._round_robin_group_chat import RoundRobinGroupChatManager
|
|
from autogen_agentchat.teams._group_chat._selector_group_chat import SelectorGroupChatManager
|
|
from autogen_agentchat.teams._group_chat._swarm_group_chat import SwarmGroupChatManager
|
|
from autogen_agentchat.ui import Console
|
|
from autogen_core import AgentId, AgentRuntime, CancellationToken, FunctionCall, SingleThreadedAgentRuntime
|
|
from autogen_core.models import (
|
|
AssistantMessage,
|
|
CreateResult,
|
|
FunctionExecutionResult,
|
|
FunctionExecutionResultMessage,
|
|
LLMMessage,
|
|
RequestUsage,
|
|
UserMessage,
|
|
)
|
|
from autogen_core.tools import FunctionTool
|
|
from autogen_ext.code_executors.local import LocalCommandLineCodeExecutor
|
|
from autogen_ext.models.openai import OpenAIChatCompletionClient
|
|
from autogen_ext.models.replay import ReplayChatCompletionClient
|
|
from utils import FileLogHandler
|
|
|
|
logger = logging.getLogger(EVENT_LOGGER_NAME)
|
|
logger.setLevel(logging.DEBUG)
|
|
logger.addHandler(FileLogHandler("test_group_chat.log"))
|
|
|
|
|
|
class _EchoAgent(BaseChatAgent):
|
|
def __init__(self, name: str, description: str) -> None:
|
|
super().__init__(name, description)
|
|
self._last_message: str | None = None
|
|
self._total_messages = 0
|
|
|
|
@property
|
|
def produced_message_types(self) -> Sequence[type[ChatMessage]]:
|
|
return (TextMessage,)
|
|
|
|
@property
|
|
def total_messages(self) -> int:
|
|
return self._total_messages
|
|
|
|
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
|
|
if len(messages) > 0:
|
|
assert isinstance(messages[0], TextMessage)
|
|
self._last_message = messages[0].content
|
|
self._total_messages += 1
|
|
return Response(chat_message=TextMessage(content=messages[0].content, source=self.name))
|
|
else:
|
|
assert self._last_message is not None
|
|
self._total_messages += 1
|
|
return Response(chat_message=TextMessage(content=self._last_message, source=self.name))
|
|
|
|
async def on_reset(self, cancellation_token: CancellationToken) -> None:
|
|
self._last_message = None
|
|
|
|
|
|
class _FlakyAgent(BaseChatAgent):
|
|
def __init__(self, name: str, description: str) -> None:
|
|
super().__init__(name, description)
|
|
self._last_message: str | None = None
|
|
self._total_messages = 0
|
|
|
|
@property
|
|
def produced_message_types(self) -> Sequence[type[ChatMessage]]:
|
|
return (TextMessage,)
|
|
|
|
@property
|
|
def total_messages(self) -> int:
|
|
return self._total_messages
|
|
|
|
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
|
|
raise ValueError("I am a flaky agent...")
|
|
|
|
async def on_reset(self, cancellation_token: CancellationToken) -> None:
|
|
self._last_message = None
|
|
|
|
|
|
class _StopAgent(_EchoAgent):
|
|
def __init__(self, name: str, description: str, *, stop_at: int = 1) -> None:
|
|
super().__init__(name, description)
|
|
self._count = 0
|
|
self._stop_at = stop_at
|
|
|
|
@property
|
|
def produced_message_types(self) -> Sequence[type[ChatMessage]]:
|
|
return (TextMessage, StopMessage)
|
|
|
|
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
|
|
self._count += 1
|
|
if self._count < self._stop_at:
|
|
return await super().on_messages(messages, cancellation_token)
|
|
return Response(chat_message=StopMessage(content="TERMINATE", source=self.name))
|
|
|
|
|
|
def _pass_function(input: str) -> str:
|
|
return "pass"
|
|
|
|
|
|
@pytest_asyncio.fixture(params=["single_threaded", "embedded"]) # type: ignore
|
|
async def runtime(request: pytest.FixtureRequest) -> AsyncGenerator[AgentRuntime | None, None]:
|
|
if request.param == "single_threaded":
|
|
runtime = SingleThreadedAgentRuntime()
|
|
runtime.start()
|
|
yield runtime
|
|
await runtime.stop()
|
|
elif request.param == "embedded":
|
|
yield None
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_round_robin_group_chat(runtime: AgentRuntime | None) -> None:
|
|
model_client = ReplayChatCompletionClient(
|
|
[
|
|
'Here is the program\n ```python\nprint("Hello, world!")\n```',
|
|
"TERMINATE",
|
|
],
|
|
)
|
|
with tempfile.TemporaryDirectory() as temp_dir:
|
|
code_executor_agent = CodeExecutorAgent(
|
|
"code_executor", code_executor=LocalCommandLineCodeExecutor(work_dir=temp_dir)
|
|
)
|
|
coding_assistant_agent = AssistantAgent(
|
|
"coding_assistant",
|
|
model_client=model_client,
|
|
)
|
|
termination = TextMentionTermination("TERMINATE")
|
|
team = RoundRobinGroupChat(
|
|
participants=[coding_assistant_agent, code_executor_agent],
|
|
termination_condition=termination,
|
|
runtime=runtime,
|
|
)
|
|
result = await team.run(
|
|
task="Write a program that prints 'Hello, world!'",
|
|
)
|
|
expected_messages = [
|
|
"Write a program that prints 'Hello, world!'",
|
|
'Here is the program\n ```python\nprint("Hello, world!")\n```',
|
|
"Hello, world!",
|
|
"TERMINATE",
|
|
]
|
|
# Normalize the messages to remove \r\n and any leading/trailing whitespace.
|
|
normalized_messages = [
|
|
msg.content.replace("\r\n", "\n").rstrip("\n") if isinstance(msg.content, str) else msg.content
|
|
for msg in result.messages
|
|
]
|
|
|
|
# Assert that all expected messages are in the collected messages
|
|
assert normalized_messages == expected_messages
|
|
|
|
assert result.stop_reason is not None and result.stop_reason == "Text 'TERMINATE' mentioned"
|
|
|
|
# Test streaming.
|
|
model_client.reset()
|
|
index = 0
|
|
await team.reset()
|
|
async for message in team.run_stream(
|
|
task="Write a program that prints 'Hello, world!'",
|
|
):
|
|
if isinstance(message, TaskResult):
|
|
assert message == result
|
|
else:
|
|
assert message == result.messages[index]
|
|
index += 1
|
|
|
|
# Test message input.
|
|
# Text message.
|
|
model_client.reset()
|
|
index = 0
|
|
await team.reset()
|
|
result_2 = await team.run(
|
|
task=TextMessage(content="Write a program that prints 'Hello, world!'", source="user")
|
|
)
|
|
assert result == result_2
|
|
|
|
# Test multi-modal message.
|
|
model_client.reset()
|
|
index = 0
|
|
await team.reset()
|
|
result_2 = await team.run(
|
|
task=MultiModalMessage(content=["Write a program that prints 'Hello, world!'"], source="user")
|
|
)
|
|
assert result.messages[0].content == result_2.messages[0].content[0]
|
|
assert result.messages[1:] == result_2.messages[1:]
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_round_robin_group_chat_state(runtime: AgentRuntime | None) -> None:
|
|
model_client = ReplayChatCompletionClient(
|
|
["No facts", "No plan", "print('Hello, world!')", "TERMINATE"],
|
|
)
|
|
agent1 = AssistantAgent("agent1", model_client=model_client)
|
|
agent2 = AssistantAgent("agent2", model_client=model_client)
|
|
termination = TextMentionTermination("TERMINATE")
|
|
team1 = RoundRobinGroupChat(participants=[agent1, agent2], termination_condition=termination, runtime=runtime)
|
|
await team1.run(task="Write a program that prints 'Hello, world!'")
|
|
state = await team1.save_state()
|
|
|
|
agent3 = AssistantAgent("agent1", model_client=model_client)
|
|
agent4 = AssistantAgent("agent2", model_client=model_client)
|
|
team2 = RoundRobinGroupChat(participants=[agent3, agent4], termination_condition=termination, runtime=runtime)
|
|
await team2.load_state(state)
|
|
state2 = await team2.save_state()
|
|
assert state == state2
|
|
|
|
agent1_model_ctx_messages = await agent1._model_context.get_messages() # pyright: ignore
|
|
agent2_model_ctx_messages = await agent2._model_context.get_messages() # pyright: ignore
|
|
agent3_model_ctx_messages = await agent3._model_context.get_messages() # pyright: ignore
|
|
agent4_model_ctx_messages = await agent4._model_context.get_messages() # pyright: ignore
|
|
assert agent3_model_ctx_messages == agent1_model_ctx_messages
|
|
assert agent4_model_ctx_messages == agent2_model_ctx_messages
|
|
manager_1 = await team1._runtime.try_get_underlying_agent_instance( # pyright: ignore
|
|
AgentId(f"{team1._group_chat_manager_name}_{team1._team_id}", team1._team_id), # pyright: ignore
|
|
RoundRobinGroupChatManager, # pyright: ignore
|
|
) # pyright: ignore
|
|
manager_2 = await team2._runtime.try_get_underlying_agent_instance( # pyright: ignore
|
|
AgentId(f"{team2._group_chat_manager_name}_{team2._team_id}", team2._team_id), # pyright: ignore
|
|
RoundRobinGroupChatManager, # pyright: ignore
|
|
) # pyright: ignore
|
|
assert manager_1._current_turn == manager_2._current_turn # pyright: ignore
|
|
assert manager_1._message_thread == manager_2._message_thread # pyright: ignore
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_round_robin_group_chat_with_tools(runtime: AgentRuntime | None) -> None:
|
|
model_client = ReplayChatCompletionClient(
|
|
chat_completions=[
|
|
CreateResult(
|
|
finish_reason="function_calls",
|
|
content=[FunctionCall(id="1", name="pass", arguments=json.dumps({"input": "pass"}))],
|
|
usage=RequestUsage(prompt_tokens=0, completion_tokens=0),
|
|
cached=False,
|
|
),
|
|
"Hello",
|
|
"TERMINATE",
|
|
],
|
|
model_info={"family": "gpt-4o", "function_calling": True, "json_output": True, "vision": True},
|
|
)
|
|
tool = FunctionTool(_pass_function, name="pass", description="pass function")
|
|
tool_use_agent = AssistantAgent("tool_use_agent", model_client=model_client, tools=[tool])
|
|
echo_agent = _EchoAgent("echo_agent", description="echo agent")
|
|
termination = TextMentionTermination("TERMINATE")
|
|
team = RoundRobinGroupChat(
|
|
participants=[tool_use_agent, echo_agent], termination_condition=termination, runtime=runtime
|
|
)
|
|
result = await team.run(
|
|
task="Write a program that prints 'Hello, world!'",
|
|
)
|
|
assert len(result.messages) == 8
|
|
assert isinstance(result.messages[0], TextMessage) # task
|
|
assert isinstance(result.messages[1], ToolCallRequestEvent) # tool call
|
|
assert isinstance(result.messages[2], ToolCallExecutionEvent) # tool call result
|
|
assert isinstance(result.messages[3], ToolCallSummaryMessage) # tool use agent response
|
|
assert result.messages[3].content == "pass" # ensure the tool call was executed
|
|
assert isinstance(result.messages[4], TextMessage) # echo agent response
|
|
assert isinstance(result.messages[5], TextMessage) # tool use agent response
|
|
assert isinstance(result.messages[6], TextMessage) # echo agent response
|
|
assert isinstance(result.messages[7], TextMessage) # tool use agent response, that has TERMINATE
|
|
assert result.messages[7].content == "TERMINATE"
|
|
|
|
assert result.stop_reason is not None and result.stop_reason == "Text 'TERMINATE' mentioned"
|
|
|
|
# Test streaming.
|
|
await tool_use_agent._model_context.clear() # pyright: ignore
|
|
model_client.reset()
|
|
index = 0
|
|
await team.reset()
|
|
async for message in team.run_stream(
|
|
task="Write a program that prints 'Hello, world!'",
|
|
):
|
|
if isinstance(message, TaskResult):
|
|
assert message == result
|
|
else:
|
|
assert message == result.messages[index]
|
|
index += 1
|
|
|
|
# Test Console.
|
|
await tool_use_agent._model_context.clear() # pyright: ignore
|
|
model_client.reset()
|
|
index = 0
|
|
await team.reset()
|
|
result2 = await Console(team.run_stream(task="Write a program that prints 'Hello, world!'"))
|
|
assert result2 == result
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_round_robin_group_chat_with_resume_and_reset(runtime: AgentRuntime | None) -> None:
|
|
agent_1 = _EchoAgent("agent_1", description="echo agent 1")
|
|
agent_2 = _EchoAgent("agent_2", description="echo agent 2")
|
|
agent_3 = _EchoAgent("agent_3", description="echo agent 3")
|
|
agent_4 = _EchoAgent("agent_4", description="echo agent 4")
|
|
termination = MaxMessageTermination(3)
|
|
team = RoundRobinGroupChat(
|
|
participants=[agent_1, agent_2, agent_3, agent_4], termination_condition=termination, runtime=runtime
|
|
)
|
|
result = await team.run(
|
|
task="Write a program that prints 'Hello, world!'",
|
|
)
|
|
assert len(result.messages) == 3
|
|
assert result.messages[1].source == "agent_1"
|
|
assert result.messages[2].source == "agent_2"
|
|
assert result.stop_reason is not None
|
|
|
|
# Resume.
|
|
result = await team.run()
|
|
assert len(result.messages) == 3
|
|
assert result.messages[0].source == "agent_3"
|
|
assert result.messages[1].source == "agent_4"
|
|
assert result.messages[2].source == "agent_1"
|
|
assert result.stop_reason is not None
|
|
|
|
# Reset.
|
|
await team.reset()
|
|
result = await team.run(task="Write a program that prints 'Hello, world!'")
|
|
assert len(result.messages) == 3
|
|
assert result.messages[1].source == "agent_1"
|
|
assert result.messages[2].source == "agent_2"
|
|
assert result.stop_reason is not None
|
|
|
|
|
|
# TODO: add runtime fixture for testing with custom runtime once the issue regarding
|
|
# hanging on exception is resolved.
|
|
@pytest.mark.asyncio
|
|
async def test_round_robin_group_chat_with_exception_raised() -> None:
|
|
agent_1 = _EchoAgent("agent_1", description="echo agent 1")
|
|
agent_2 = _FlakyAgent("agent_2", description="echo agent 2")
|
|
agent_3 = _EchoAgent("agent_3", description="echo agent 3")
|
|
termination = MaxMessageTermination(3)
|
|
team = RoundRobinGroupChat(
|
|
participants=[agent_1, agent_2, agent_3],
|
|
termination_condition=termination,
|
|
)
|
|
|
|
with pytest.raises(ValueError, match="I am a flaky agent..."):
|
|
await team.run(
|
|
task="Write a program that prints 'Hello, world!'",
|
|
)
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_round_robin_group_chat_max_turn(runtime: AgentRuntime | None) -> None:
|
|
agent_1 = _EchoAgent("agent_1", description="echo agent 1")
|
|
agent_2 = _EchoAgent("agent_2", description="echo agent 2")
|
|
agent_3 = _EchoAgent("agent_3", description="echo agent 3")
|
|
agent_4 = _EchoAgent("agent_4", description="echo agent 4")
|
|
team = RoundRobinGroupChat(participants=[agent_1, agent_2, agent_3, agent_4], max_turns=3, runtime=runtime)
|
|
result = await team.run(
|
|
task="Write a program that prints 'Hello, world!'",
|
|
)
|
|
assert len(result.messages) == 4
|
|
assert result.messages[1].source == "agent_1"
|
|
assert result.messages[2].source == "agent_2"
|
|
assert result.messages[3].source == "agent_3"
|
|
assert result.stop_reason is not None
|
|
|
|
# Resume.
|
|
result = await team.run()
|
|
assert len(result.messages) == 3
|
|
assert result.messages[0].source == "agent_4"
|
|
assert result.messages[1].source == "agent_1"
|
|
assert result.messages[2].source == "agent_2"
|
|
assert result.stop_reason is not None
|
|
|
|
# Reset.
|
|
await team.reset()
|
|
result = await team.run(task="Write a program that prints 'Hello, world!'")
|
|
assert len(result.messages) == 4
|
|
assert result.messages[1].source == "agent_1"
|
|
assert result.messages[2].source == "agent_2"
|
|
assert result.messages[3].source == "agent_3"
|
|
assert result.stop_reason is not None
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_round_robin_group_chat_cancellation(runtime: AgentRuntime | None) -> None:
|
|
agent_1 = _EchoAgent("agent_1", description="echo agent 1")
|
|
agent_2 = _EchoAgent("agent_2", description="echo agent 2")
|
|
agent_3 = _EchoAgent("agent_3", description="echo agent 3")
|
|
agent_4 = _EchoAgent("agent_4", description="echo agent 4")
|
|
# Set max_turns to a large number to avoid stopping due to max_turns before cancellation.
|
|
team = RoundRobinGroupChat(participants=[agent_1, agent_2, agent_3, agent_4], max_turns=1000, runtime=runtime)
|
|
cancellation_token = CancellationToken()
|
|
run_task = asyncio.create_task(
|
|
team.run(
|
|
task="Write a program that prints 'Hello, world!'",
|
|
cancellation_token=cancellation_token,
|
|
)
|
|
)
|
|
await asyncio.sleep(0.1)
|
|
# Cancel the task.
|
|
cancellation_token.cancel()
|
|
with pytest.raises(asyncio.CancelledError):
|
|
await run_task
|
|
|
|
# Still can run again and finish the task.
|
|
result = await team.run()
|
|
assert result.stop_reason is not None and result.stop_reason == "Maximum number of turns 1000 reached."
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_selector_group_chat(runtime: AgentRuntime | None) -> None:
|
|
model_client = ReplayChatCompletionClient(
|
|
chat_completions=[
|
|
"agent3",
|
|
"agent2",
|
|
"agent1",
|
|
"agent2",
|
|
"agent1",
|
|
]
|
|
)
|
|
agent1 = _StopAgent("agent1", description="echo agent 1", stop_at=2)
|
|
agent2 = _EchoAgent("agent2", description="echo agent 2")
|
|
agent3 = _EchoAgent("agent3", description="echo agent 3")
|
|
termination = TextMentionTermination("TERMINATE")
|
|
team = SelectorGroupChat(
|
|
participants=[agent1, agent2, agent3],
|
|
model_client=model_client,
|
|
termination_condition=termination,
|
|
runtime=runtime,
|
|
)
|
|
result = await team.run(
|
|
task="Write a program that prints 'Hello, world!'",
|
|
)
|
|
assert len(result.messages) == 6
|
|
assert result.messages[0].content == "Write a program that prints 'Hello, world!'"
|
|
assert result.messages[1].source == "agent3"
|
|
assert result.messages[2].source == "agent2"
|
|
assert result.messages[3].source == "agent1"
|
|
assert result.messages[4].source == "agent2"
|
|
assert result.messages[5].source == "agent1"
|
|
assert result.stop_reason is not None and result.stop_reason == "Text 'TERMINATE' mentioned"
|
|
|
|
# Test streaming.
|
|
model_client.reset()
|
|
agent1._count = 0 # pyright: ignore
|
|
index = 0
|
|
await team.reset()
|
|
async for message in team.run_stream(
|
|
task="Write a program that prints 'Hello, world!'",
|
|
):
|
|
if isinstance(message, TaskResult):
|
|
assert message == result
|
|
else:
|
|
assert message == result.messages[index]
|
|
index += 1
|
|
|
|
# Test Console.
|
|
model_client.reset()
|
|
agent1._count = 0 # pyright: ignore
|
|
index = 0
|
|
await team.reset()
|
|
result2 = await Console(team.run_stream(task="Write a program that prints 'Hello, world!'"))
|
|
assert result2 == result
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_selector_group_chat_state(runtime: AgentRuntime | None) -> None:
|
|
model_client = ReplayChatCompletionClient(
|
|
["agent1", "No facts", "agent2", "No plan", "agent1", "print('Hello, world!')", "agent2", "TERMINATE"],
|
|
)
|
|
agent1 = AssistantAgent("agent1", model_client=model_client)
|
|
agent2 = AssistantAgent("agent2", model_client=model_client)
|
|
termination = TextMentionTermination("TERMINATE")
|
|
team1 = SelectorGroupChat(
|
|
participants=[agent1, agent2],
|
|
termination_condition=termination,
|
|
model_client=model_client,
|
|
runtime=runtime,
|
|
)
|
|
await team1.run(task="Write a program that prints 'Hello, world!'")
|
|
state = await team1.save_state()
|
|
|
|
agent3 = AssistantAgent("agent1", model_client=model_client)
|
|
agent4 = AssistantAgent("agent2", model_client=model_client)
|
|
team2 = SelectorGroupChat(
|
|
participants=[agent3, agent4], termination_condition=termination, model_client=model_client
|
|
)
|
|
await team2.load_state(state)
|
|
state2 = await team2.save_state()
|
|
assert state == state2
|
|
|
|
agent1_model_ctx_messages = await agent1._model_context.get_messages() # pyright: ignore
|
|
agent2_model_ctx_messages = await agent2._model_context.get_messages() # pyright: ignore
|
|
agent3_model_ctx_messages = await agent3._model_context.get_messages() # pyright: ignore
|
|
agent4_model_ctx_messages = await agent4._model_context.get_messages() # pyright: ignore
|
|
assert agent3_model_ctx_messages == agent1_model_ctx_messages
|
|
assert agent4_model_ctx_messages == agent2_model_ctx_messages
|
|
manager_1 = await team1._runtime.try_get_underlying_agent_instance( # pyright: ignore
|
|
AgentId(f"{team1._group_chat_manager_name}_{team1._team_id}", team1._team_id), # pyright: ignore
|
|
SelectorGroupChatManager, # pyright: ignore
|
|
) # pyright: ignore
|
|
manager_2 = await team2._runtime.try_get_underlying_agent_instance( # pyright: ignore
|
|
AgentId(f"{team2._group_chat_manager_name}_{team2._team_id}", team2._team_id), # pyright: ignore
|
|
SelectorGroupChatManager, # pyright: ignore
|
|
) # pyright: ignore
|
|
assert manager_1._message_thread == manager_2._message_thread # pyright: ignore
|
|
assert manager_1._previous_speaker == manager_2._previous_speaker # pyright: ignore
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_selector_group_chat_two_speakers(runtime: AgentRuntime | None) -> None:
|
|
model_client = ReplayChatCompletionClient(["agent2"])
|
|
|
|
agent1 = _StopAgent("agent1", description="echo agent 1", stop_at=2)
|
|
agent2 = _EchoAgent("agent2", description="echo agent 2")
|
|
termination = TextMentionTermination("TERMINATE")
|
|
team = SelectorGroupChat(
|
|
participants=[agent1, agent2],
|
|
termination_condition=termination,
|
|
model_client=model_client,
|
|
runtime=runtime,
|
|
)
|
|
result = await team.run(
|
|
task="Write a program that prints 'Hello, world!'",
|
|
)
|
|
assert len(result.messages) == 5
|
|
assert result.messages[0].content == "Write a program that prints 'Hello, world!'"
|
|
assert result.messages[1].source == "agent2"
|
|
assert result.messages[2].source == "agent1"
|
|
assert result.messages[3].source == "agent2"
|
|
assert result.messages[4].source == "agent1"
|
|
assert result.stop_reason is not None and result.stop_reason == "Text 'TERMINATE' mentioned"
|
|
|
|
# Test streaming.
|
|
model_client.reset()
|
|
agent1._count = 0 # pyright: ignore
|
|
index = 0
|
|
await team.reset()
|
|
async for message in team.run_stream(task="Write a program that prints 'Hello, world!'"):
|
|
if isinstance(message, TaskResult):
|
|
assert message == result
|
|
else:
|
|
assert message == result.messages[index]
|
|
index += 1
|
|
|
|
# Test Console.
|
|
model_client.reset()
|
|
agent1._count = 0 # pyright: ignore
|
|
index = 0
|
|
await team.reset()
|
|
result2 = await Console(team.run_stream(task="Write a program that prints 'Hello, world!'"))
|
|
assert result2 == result
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_selector_group_chat_two_speakers_allow_repeated(runtime: AgentRuntime | None) -> None:
|
|
model_client = ReplayChatCompletionClient(
|
|
[
|
|
"agent2",
|
|
"agent2",
|
|
"agent1",
|
|
]
|
|
)
|
|
agent1 = _StopAgent("agent1", description="echo agent 1", stop_at=1)
|
|
agent2 = _EchoAgent("agent2", description="echo agent 2")
|
|
termination = TextMentionTermination("TERMINATE")
|
|
team = SelectorGroupChat(
|
|
participants=[agent1, agent2],
|
|
model_client=model_client,
|
|
termination_condition=termination,
|
|
allow_repeated_speaker=True,
|
|
runtime=runtime,
|
|
)
|
|
result = await team.run(task="Write a program that prints 'Hello, world!'")
|
|
assert len(result.messages) == 4
|
|
assert result.messages[0].content == "Write a program that prints 'Hello, world!'"
|
|
assert result.messages[1].source == "agent2"
|
|
assert result.messages[2].source == "agent2"
|
|
assert result.messages[3].source == "agent1"
|
|
assert result.stop_reason is not None and result.stop_reason == "Text 'TERMINATE' mentioned"
|
|
|
|
# Test streaming.
|
|
model_client.reset()
|
|
index = 0
|
|
await team.reset()
|
|
async for message in team.run_stream(task="Write a program that prints 'Hello, world!'"):
|
|
if isinstance(message, TaskResult):
|
|
assert message == result
|
|
else:
|
|
assert message == result.messages[index]
|
|
index += 1
|
|
|
|
# Test Console.
|
|
model_client.reset()
|
|
index = 0
|
|
await team.reset()
|
|
result2 = await Console(team.run_stream(task="Write a program that prints 'Hello, world!'"))
|
|
assert result2 == result
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_selector_group_chat_succcess_after_2_attempts(runtime: AgentRuntime | None) -> None:
|
|
model_client = ReplayChatCompletionClient(
|
|
["agent2, agent3", "agent2"],
|
|
)
|
|
agent1 = _StopAgent("agent1", description="echo agent 1", stop_at=1)
|
|
agent2 = _EchoAgent("agent2", description="echo agent 2")
|
|
agent3 = _EchoAgent("agent3", description="echo agent 3")
|
|
team = SelectorGroupChat(
|
|
participants=[agent1, agent2, agent3],
|
|
model_client=model_client,
|
|
max_turns=1,
|
|
runtime=runtime,
|
|
)
|
|
result = await team.run(task="Write a program that prints 'Hello, world!'")
|
|
assert len(result.messages) == 2
|
|
assert result.messages[0].content == "Write a program that prints 'Hello, world!'"
|
|
assert result.messages[1].source == "agent2"
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_selector_group_chat_fall_back_to_first_after_3_attempts(runtime: AgentRuntime | None) -> None:
|
|
model_client = ReplayChatCompletionClient(
|
|
[
|
|
"agent2, agent3", # Multiple speakers
|
|
"agent5", # Non-existent speaker
|
|
"agent3, agent1", # Multiple speakers
|
|
]
|
|
)
|
|
agent1 = _StopAgent("agent1", description="echo agent 1", stop_at=1)
|
|
agent2 = _EchoAgent("agent2", description="echo agent 2")
|
|
agent3 = _EchoAgent("agent3", description="echo agent 3")
|
|
team = SelectorGroupChat(
|
|
participants=[agent1, agent2, agent3],
|
|
model_client=model_client,
|
|
max_turns=1,
|
|
runtime=runtime,
|
|
)
|
|
result = await team.run(task="Write a program that prints 'Hello, world!'")
|
|
assert len(result.messages) == 2
|
|
assert result.messages[0].content == "Write a program that prints 'Hello, world!'"
|
|
assert result.messages[1].source == "agent1"
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_selector_group_chat_fall_back_to_previous_after_3_attempts(runtime: AgentRuntime | None) -> None:
|
|
model_client = ReplayChatCompletionClient(
|
|
["agent2", "agent2", "agent2", "agent2"],
|
|
)
|
|
agent1 = _StopAgent("agent1", description="echo agent 1", stop_at=1)
|
|
agent2 = _EchoAgent("agent2", description="echo agent 2")
|
|
agent3 = _EchoAgent("agent3", description="echo agent 3")
|
|
team = SelectorGroupChat(
|
|
participants=[agent1, agent2, agent3],
|
|
model_client=model_client,
|
|
max_turns=2,
|
|
runtime=runtime,
|
|
)
|
|
result = await team.run(task="Write a program that prints 'Hello, world!'")
|
|
assert len(result.messages) == 3
|
|
assert result.messages[0].content == "Write a program that prints 'Hello, world!'"
|
|
assert result.messages[1].source == "agent2"
|
|
assert result.messages[2].source == "agent2"
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_selector_group_chat_custom_selector(runtime: AgentRuntime | None) -> None:
|
|
model_client = ReplayChatCompletionClient(["agent3"])
|
|
agent1 = _EchoAgent("agent1", description="echo agent 1")
|
|
agent2 = _EchoAgent("agent2", description="echo agent 2")
|
|
agent3 = _EchoAgent("agent3", description="echo agent 3")
|
|
agent4 = _EchoAgent("agent4", description="echo agent 4")
|
|
|
|
def _select_agent(messages: Sequence[AgentEvent | ChatMessage]) -> str | None:
|
|
if len(messages) == 0:
|
|
return "agent1"
|
|
elif messages[-1].source == "agent1":
|
|
return "agent2"
|
|
elif messages[-1].source == "agent2":
|
|
return None
|
|
elif messages[-1].source == "agent3":
|
|
return "agent4"
|
|
else:
|
|
return "agent1"
|
|
|
|
termination = MaxMessageTermination(6)
|
|
team = SelectorGroupChat(
|
|
participants=[agent1, agent2, agent3, agent4],
|
|
model_client=model_client,
|
|
selector_func=_select_agent,
|
|
termination_condition=termination,
|
|
runtime=runtime,
|
|
)
|
|
result = await team.run(task="task")
|
|
assert len(result.messages) == 6
|
|
assert result.messages[1].source == "agent1"
|
|
assert result.messages[2].source == "agent2"
|
|
assert result.messages[3].source == "agent3"
|
|
assert result.messages[4].source == "agent4"
|
|
assert result.messages[5].source == "agent1"
|
|
assert (
|
|
result.stop_reason is not None
|
|
and result.stop_reason == "Maximum number of messages 6 reached, current message count: 6"
|
|
)
|
|
|
|
|
|
class _HandOffAgent(BaseChatAgent):
|
|
def __init__(self, name: str, description: str, next_agent: str) -> None:
|
|
super().__init__(name, description)
|
|
self._next_agent = next_agent
|
|
|
|
@property
|
|
def produced_message_types(self) -> Sequence[type[ChatMessage]]:
|
|
return (HandoffMessage,)
|
|
|
|
async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
|
|
return Response(
|
|
chat_message=HandoffMessage(
|
|
content=f"Transferred to {self._next_agent}.", target=self._next_agent, source=self.name
|
|
)
|
|
)
|
|
|
|
async def on_reset(self, cancellation_token: CancellationToken) -> None:
|
|
pass
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_swarm_handoff(runtime: AgentRuntime | None) -> None:
|
|
first_agent = _HandOffAgent("first_agent", description="first agent", next_agent="second_agent")
|
|
second_agent = _HandOffAgent("second_agent", description="second agent", next_agent="third_agent")
|
|
third_agent = _HandOffAgent("third_agent", description="third agent", next_agent="first_agent")
|
|
|
|
termination = MaxMessageTermination(6)
|
|
team = Swarm([second_agent, first_agent, third_agent], termination_condition=termination, runtime=runtime)
|
|
result = await team.run(task="task")
|
|
assert len(result.messages) == 6
|
|
assert result.messages[0].content == "task"
|
|
assert result.messages[1].content == "Transferred to third_agent."
|
|
assert result.messages[2].content == "Transferred to first_agent."
|
|
assert result.messages[3].content == "Transferred to second_agent."
|
|
assert result.messages[4].content == "Transferred to third_agent."
|
|
assert result.messages[5].content == "Transferred to first_agent."
|
|
assert (
|
|
result.stop_reason is not None
|
|
and result.stop_reason == "Maximum number of messages 6 reached, current message count: 6"
|
|
)
|
|
|
|
# Test streaming.
|
|
index = 0
|
|
await team.reset()
|
|
stream = team.run_stream(task="task")
|
|
async for message in stream:
|
|
if isinstance(message, TaskResult):
|
|
assert message == result
|
|
else:
|
|
assert message == result.messages[index]
|
|
index += 1
|
|
|
|
# Test save and load.
|
|
state = await team.save_state()
|
|
first_agent2 = _HandOffAgent("first_agent", description="first agent", next_agent="second_agent")
|
|
second_agent2 = _HandOffAgent("second_agent", description="second agent", next_agent="third_agent")
|
|
third_agent2 = _HandOffAgent("third_agent", description="third agent", next_agent="first_agent")
|
|
team2 = Swarm([second_agent2, first_agent2, third_agent2], termination_condition=termination, runtime=runtime)
|
|
await team2.load_state(state)
|
|
state2 = await team2.save_state()
|
|
assert state == state2
|
|
manager_1 = await team._runtime.try_get_underlying_agent_instance( # pyright: ignore
|
|
AgentId(f"{team._group_chat_manager_name}_{team._team_id}", team._team_id), # pyright: ignore
|
|
SwarmGroupChatManager, # pyright: ignore
|
|
) # pyright: ignore
|
|
manager_2 = await team2._runtime.try_get_underlying_agent_instance( # pyright: ignore
|
|
AgentId(f"{team2._group_chat_manager_name}_{team2._team_id}", team2._team_id), # pyright: ignore
|
|
SwarmGroupChatManager, # pyright: ignore
|
|
) # pyright: ignore
|
|
assert manager_1._message_thread == manager_2._message_thread # pyright: ignore
|
|
assert manager_1._current_speaker == manager_2._current_speaker # pyright: ignore
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_swarm_handoff_using_tool_calls(runtime: AgentRuntime | None) -> None:
|
|
model_client = ReplayChatCompletionClient(
|
|
chat_completions=[
|
|
CreateResult(
|
|
finish_reason="function_calls",
|
|
content=[FunctionCall(id="1", name="handoff_to_agent2", arguments=json.dumps({}))],
|
|
usage=RequestUsage(prompt_tokens=0, completion_tokens=0),
|
|
cached=False,
|
|
),
|
|
"Hello",
|
|
"TERMINATE",
|
|
],
|
|
model_info={"family": "gpt-4o", "function_calling": True, "json_output": True, "vision": True},
|
|
)
|
|
agent1 = AssistantAgent(
|
|
"agent1",
|
|
model_client=model_client,
|
|
handoffs=[Handoff(target="agent2", name="handoff_to_agent2", message="handoff to agent2")],
|
|
)
|
|
agent2 = _HandOffAgent("agent2", description="agent 2", next_agent="agent1")
|
|
termination = TextMentionTermination("TERMINATE")
|
|
team = Swarm([agent1, agent2], termination_condition=termination, runtime=runtime)
|
|
result = await team.run(task="task")
|
|
assert len(result.messages) == 7
|
|
assert result.messages[0].content == "task"
|
|
assert isinstance(result.messages[1], ToolCallRequestEvent)
|
|
assert isinstance(result.messages[2], ToolCallExecutionEvent)
|
|
assert result.messages[3].content == "handoff to agent2"
|
|
assert result.messages[4].content == "Transferred to agent1."
|
|
assert result.messages[5].content == "Hello"
|
|
assert result.messages[6].content == "TERMINATE"
|
|
assert result.stop_reason is not None and result.stop_reason == "Text 'TERMINATE' mentioned"
|
|
|
|
# Test streaming.
|
|
await agent1._model_context.clear() # pyright: ignore
|
|
model_client.reset()
|
|
index = 0
|
|
await team.reset()
|
|
stream = team.run_stream(task="task")
|
|
async for message in stream:
|
|
if isinstance(message, TaskResult):
|
|
assert message == result
|
|
else:
|
|
assert message == result.messages[index]
|
|
index += 1
|
|
|
|
# Test Console
|
|
await agent1._model_context.clear() # pyright: ignore
|
|
model_client.reset()
|
|
index = 0
|
|
await team.reset()
|
|
result2 = await Console(team.run_stream(task="task"))
|
|
assert result2 == result
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_swarm_pause_and_resume(runtime: AgentRuntime | None) -> None:
|
|
first_agent = _HandOffAgent("first_agent", description="first agent", next_agent="second_agent")
|
|
second_agent = _HandOffAgent("second_agent", description="second agent", next_agent="third_agent")
|
|
third_agent = _HandOffAgent("third_agent", description="third agent", next_agent="first_agent")
|
|
|
|
team = Swarm([second_agent, first_agent, third_agent], max_turns=1, runtime=runtime)
|
|
result = await team.run(task="task")
|
|
assert len(result.messages) == 2
|
|
assert result.messages[0].content == "task"
|
|
assert result.messages[1].content == "Transferred to third_agent."
|
|
|
|
# Resume with a new task.
|
|
result = await team.run(task="new task")
|
|
assert len(result.messages) == 2
|
|
assert result.messages[0].content == "new task"
|
|
assert result.messages[1].content == "Transferred to first_agent."
|
|
|
|
# Resume with the same task.
|
|
result = await team.run()
|
|
assert len(result.messages) == 1
|
|
assert result.messages[0].content == "Transferred to second_agent."
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_swarm_with_parallel_tool_calls(runtime: AgentRuntime | None) -> None:
|
|
model_client = ReplayChatCompletionClient(
|
|
[
|
|
CreateResult(
|
|
finish_reason="function_calls",
|
|
content=[
|
|
FunctionCall(id="1", name="tool1", arguments="{}"),
|
|
FunctionCall(id="2", name="tool2", arguments="{}"),
|
|
FunctionCall(id="3", name="handoff_to_agent2", arguments=json.dumps({})),
|
|
],
|
|
usage=RequestUsage(prompt_tokens=0, completion_tokens=0),
|
|
cached=False,
|
|
),
|
|
"Hello",
|
|
"TERMINATE",
|
|
],
|
|
model_info={"family": "gpt-4o", "function_calling": True, "json_output": True, "vision": True},
|
|
)
|
|
|
|
expected_handoff_context: List[LLMMessage] = [
|
|
AssistantMessage(
|
|
source="agent1",
|
|
content=[
|
|
FunctionCall(id="1", name="tool1", arguments="{}"),
|
|
FunctionCall(id="2", name="tool2", arguments="{}"),
|
|
],
|
|
),
|
|
FunctionExecutionResultMessage(
|
|
content=[
|
|
FunctionExecutionResult(content="tool1", call_id="1", is_error=False, name="tool1"),
|
|
FunctionExecutionResult(content="tool2", call_id="2", is_error=False, name="tool2"),
|
|
]
|
|
),
|
|
]
|
|
|
|
def tool1() -> str:
|
|
return "tool1"
|
|
|
|
def tool2() -> str:
|
|
return "tool2"
|
|
|
|
agent1 = AssistantAgent(
|
|
"agent1",
|
|
model_client=model_client,
|
|
handoffs=[Handoff(target="agent2", name="handoff_to_agent2", message="handoff to agent2")],
|
|
tools=[tool1, tool2],
|
|
)
|
|
agent2 = AssistantAgent(
|
|
"agent2",
|
|
model_client=model_client,
|
|
)
|
|
termination = TextMentionTermination("TERMINATE")
|
|
team = Swarm([agent1, agent2], termination_condition=termination, runtime=runtime)
|
|
result = await team.run(task="task")
|
|
assert len(result.messages) == 6
|
|
assert result.messages[0] == TextMessage(content="task", source="user")
|
|
assert isinstance(result.messages[1], ToolCallRequestEvent)
|
|
assert isinstance(result.messages[2], ToolCallExecutionEvent)
|
|
assert result.messages[3] == HandoffMessage(
|
|
content="handoff to agent2",
|
|
target="agent2",
|
|
source="agent1",
|
|
context=expected_handoff_context,
|
|
)
|
|
assert result.messages[4].content == "Hello"
|
|
assert result.messages[4].source == "agent2"
|
|
assert result.messages[5].content == "TERMINATE"
|
|
assert result.messages[5].source == "agent2"
|
|
|
|
# Verify the tool calls are in agent2's context.
|
|
agent2_model_ctx_messages = await agent2._model_context.get_messages() # pyright: ignore
|
|
assert agent2_model_ctx_messages[0] == UserMessage(content="task", source="user")
|
|
assert agent2_model_ctx_messages[1] == expected_handoff_context[0]
|
|
assert agent2_model_ctx_messages[2] == expected_handoff_context[1]
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_swarm_with_handoff_termination(runtime: AgentRuntime | None) -> None:
|
|
first_agent = _HandOffAgent("first_agent", description="first agent", next_agent="second_agent")
|
|
second_agent = _HandOffAgent("second_agent", description="second agent", next_agent="third_agent")
|
|
third_agent = _HandOffAgent("third_agent", description="third agent", next_agent="first_agent")
|
|
|
|
# Handoff to an existing agent.
|
|
termination = HandoffTermination(target="third_agent")
|
|
team = Swarm([second_agent, first_agent, third_agent], termination_condition=termination, runtime=runtime)
|
|
# Start
|
|
result = await team.run(task="task")
|
|
assert len(result.messages) == 2
|
|
assert result.messages[0].content == "task"
|
|
assert result.messages[1].content == "Transferred to third_agent."
|
|
# Resume existing.
|
|
result = await team.run()
|
|
assert len(result.messages) == 3
|
|
assert result.messages[0].content == "Transferred to first_agent."
|
|
assert result.messages[1].content == "Transferred to second_agent."
|
|
assert result.messages[2].content == "Transferred to third_agent."
|
|
# Resume new task.
|
|
result = await team.run(task="new task")
|
|
assert len(result.messages) == 4
|
|
assert result.messages[0].content == "new task"
|
|
assert result.messages[1].content == "Transferred to first_agent."
|
|
assert result.messages[2].content == "Transferred to second_agent."
|
|
assert result.messages[3].content == "Transferred to third_agent."
|
|
|
|
# Handoff to a non-existing agent.
|
|
third_agent = _HandOffAgent("third_agent", description="third agent", next_agent="non_existing_agent")
|
|
termination = HandoffTermination(target="non_existing_agent")
|
|
team = Swarm([second_agent, first_agent, third_agent], termination_condition=termination, runtime=runtime)
|
|
# Start
|
|
result = await team.run(task="task")
|
|
assert len(result.messages) == 3
|
|
assert result.messages[0].content == "task"
|
|
assert result.messages[1].content == "Transferred to third_agent."
|
|
assert result.messages[2].content == "Transferred to non_existing_agent."
|
|
# Attempt to resume.
|
|
with pytest.raises(ValueError):
|
|
await team.run()
|
|
# Attempt to resume with a new task.
|
|
with pytest.raises(ValueError):
|
|
await team.run(task="new task")
|
|
# Resume with a HandoffMessage
|
|
result = await team.run(task=HandoffMessage(content="Handoff to first_agent.", target="first_agent", source="user"))
|
|
assert len(result.messages) == 4
|
|
assert result.messages[0].content == "Handoff to first_agent."
|
|
assert result.messages[1].content == "Transferred to second_agent."
|
|
assert result.messages[2].content == "Transferred to third_agent."
|
|
assert result.messages[3].content == "Transferred to non_existing_agent."
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_round_robin_group_chat_with_message_list(runtime: AgentRuntime | None) -> None:
|
|
# Create a simple team with echo agents
|
|
agent1 = _EchoAgent("Agent1", "First agent")
|
|
agent2 = _EchoAgent("Agent2", "Second agent")
|
|
termination = MaxMessageTermination(4) # Stop after 4 messages
|
|
team = RoundRobinGroupChat([agent1, agent2], termination_condition=termination, runtime=runtime)
|
|
|
|
# Create a list of messages
|
|
messages: List[ChatMessage] = [
|
|
TextMessage(content="Message 1", source="user"),
|
|
TextMessage(content="Message 2", source="user"),
|
|
TextMessage(content="Message 3", source="user"),
|
|
]
|
|
|
|
# Run the team with the message list
|
|
result = await team.run(task=messages)
|
|
|
|
# Verify the messages were processed in order
|
|
assert len(result.messages) == 4 # Initial messages + echo until termination
|
|
assert result.messages[0].content == "Message 1" # First message
|
|
assert result.messages[1].content == "Message 2" # Second message
|
|
assert result.messages[2].content == "Message 3" # Third message
|
|
assert result.messages[3].content == "Message 1" # Echo from first agent
|
|
assert result.stop_reason == "Maximum number of messages 4 reached, current message count: 4"
|
|
|
|
# Test with streaming
|
|
await team.reset()
|
|
index = 0
|
|
async for message in team.run_stream(task=messages):
|
|
if isinstance(message, TaskResult):
|
|
assert message == result
|
|
else:
|
|
assert message == result.messages[index]
|
|
index += 1
|
|
|
|
# Test with invalid message list
|
|
with pytest.raises(ValueError, match="All messages in task list must be valid ChatMessage types"):
|
|
await team.run(task=["not a message"]) # type: ignore[list-item, arg-type] # intentionally testing invalid input
|
|
|
|
# Test with empty message list
|
|
with pytest.raises(ValueError, match="Task list cannot be empty"):
|
|
await team.run(task=[])
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_declarative_groupchats_with_config(runtime: AgentRuntime | None) -> None:
|
|
# Create basic agents and components for testing
|
|
agent1 = AssistantAgent(
|
|
"agent_1",
|
|
model_client=OpenAIChatCompletionClient(model="gpt-4o-2024-05-13", api_key=""),
|
|
handoffs=["agent_2"],
|
|
)
|
|
agent2 = AssistantAgent("agent_2", model_client=OpenAIChatCompletionClient(model="gpt-4o-2024-05-13", api_key=""))
|
|
termination = MaxMessageTermination(4)
|
|
model_client = OpenAIChatCompletionClient(model="gpt-4o-2024-05-13", api_key="")
|
|
|
|
# Test round robin - verify config is preserved
|
|
round_robin = RoundRobinGroupChat(participants=[agent1, agent2], termination_condition=termination, max_turns=5)
|
|
config = round_robin.dump_component()
|
|
loaded = RoundRobinGroupChat.load_component(config)
|
|
assert loaded.dump_component() == config
|
|
|
|
# Test selector group chat - verify config is preserved
|
|
selector_prompt = "Custom selector prompt with {roles}, {participants}, {history}"
|
|
selector = SelectorGroupChat(
|
|
participants=[agent1, agent2],
|
|
model_client=model_client,
|
|
termination_condition=termination,
|
|
max_turns=10,
|
|
selector_prompt=selector_prompt,
|
|
allow_repeated_speaker=True,
|
|
runtime=runtime,
|
|
)
|
|
selector_config = selector.dump_component()
|
|
selector_loaded = SelectorGroupChat.load_component(selector_config)
|
|
assert selector_loaded.dump_component() == selector_config
|
|
|
|
# Test swarm with handoff termination
|
|
handoff_termination = HandoffTermination(target="Agent2")
|
|
swarm = Swarm(
|
|
participants=[agent1, agent2], termination_condition=handoff_termination, max_turns=5, runtime=runtime
|
|
)
|
|
swarm_config = swarm.dump_component()
|
|
swarm_loaded = Swarm.load_component(swarm_config)
|
|
assert swarm_loaded.dump_component() == swarm_config
|
|
|
|
# Test MagenticOne with custom parameters
|
|
magentic = MagenticOneGroupChat(
|
|
participants=[agent1],
|
|
model_client=model_client,
|
|
max_turns=15,
|
|
max_stalls=5,
|
|
final_answer_prompt="Custom prompt",
|
|
runtime=runtime,
|
|
)
|
|
magentic_config = magentic.dump_component()
|
|
magentic_loaded = MagenticOneGroupChat.load_component(magentic_config)
|
|
assert magentic_loaded.dump_component() == magentic_config
|
|
|
|
# Verify component types are correctly set for each
|
|
for team in [loaded, selector, swarm, magentic]:
|
|
assert team.component_type == "team"
|
|
|
|
# Verify provider strings are correctly set
|
|
assert round_robin.dump_component().provider == "autogen_agentchat.teams.RoundRobinGroupChat"
|
|
assert selector.dump_component().provider == "autogen_agentchat.teams.SelectorGroupChat"
|
|
assert swarm.dump_component().provider == "autogen_agentchat.teams.Swarm"
|
|
assert magentic.dump_component().provider == "autogen_agentchat.teams.MagenticOneGroupChat"
|