Mirror of https://github.com/microsoft/autogen.git, synced 2025-07-13 20:11:00 +00:00

## Why are these changes needed?

This PR introduces two changes. The first is the addition of a `name` attribute to `FunctionExecutionResult`. The motivation is that Semantic Kernel requires it for its function result interface, and it is an easy modification because `FunctionExecutionResult` is always created in the context of a `FunctionCall`, which contains the name. I'm unsure whether there was a reason to keep it out, but this change makes it easier to trace which tool a result refers to and also improves API compatibility with SK. The second change is an update to how messages are mapped from AutoGen to Semantic Kernel, which includes an update/fix in the processing of function results.

## Related issue number

Related to #5675, but it won't fix the underlying issue of Anthropic requiring tools during AssistantAgent reflection.

## Checks

- [ ] I've included any doc changes needed for <https://microsoft.github.io/autogen/>. See <https://github.com/microsoft/autogen/blob/main/CONTRIBUTING.md> to build and test documentation locally.
- [ ] I've added tests (if relevant) corresponding to the changes introduced in this PR.
- [ ] I've made sure all auto checks have passed.

---------

Co-authored-by: Leonardo Pinheiro <lpinheiro@microsoft.com>
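As a minimal sketch of the first change, the snippet below shows a `FunctionExecutionResult` carrying the `name` of the tool it answers, taken from the originating `FunctionCall` (the same constructor arguments are exercised in the updated test file below); the snippet itself is illustrative and not part of the PR diff.

```python
from autogen_core import FunctionCall
from autogen_core.models import FunctionExecutionResult

# A tool call issued by the model.
call = FunctionCall(id="1", name="tool1", arguments="{}")

# The result now records which tool produced it, mirroring the call's name.
result = FunctionExecutionResult(content="tool1", call_id=call.id, is_error=False, name=call.name)

assert result.name == call.name  # easier to trace, and closer to SK's function result interface
```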
1380 lines · 56 KiB · Python
import asyncio
import json
import logging
import tempfile
from typing import Any, AsyncGenerator, List, Sequence

import pytest
from autogen_agentchat import EVENT_LOGGER_NAME
from autogen_agentchat.agents import (
    AssistantAgent,
    BaseChatAgent,
    CodeExecutorAgent,
)
from autogen_agentchat.base import Handoff, Response, TaskResult
from autogen_agentchat.conditions import HandoffTermination, MaxMessageTermination, TextMentionTermination
from autogen_agentchat.messages import (
    AgentEvent,
    ChatMessage,
    HandoffMessage,
    MultiModalMessage,
    StopMessage,
    TextMessage,
    ToolCallExecutionEvent,
    ToolCallRequestEvent,
    ToolCallSummaryMessage,
)
from autogen_agentchat.teams import MagenticOneGroupChat, RoundRobinGroupChat, SelectorGroupChat, Swarm
from autogen_agentchat.teams._group_chat._round_robin_group_chat import RoundRobinGroupChatManager
from autogen_agentchat.teams._group_chat._selector_group_chat import SelectorGroupChatManager
from autogen_agentchat.teams._group_chat._swarm_group_chat import SwarmGroupChatManager
from autogen_agentchat.ui import Console
from autogen_core import AgentId, CancellationToken, FunctionCall
from autogen_core.models import (
    AssistantMessage,
    FunctionExecutionResult,
    FunctionExecutionResultMessage,
    LLMMessage,
    UserMessage,
)
from autogen_core.tools import FunctionTool
from autogen_ext.code_executors.local import LocalCommandLineCodeExecutor
from autogen_ext.models.openai import OpenAIChatCompletionClient
from autogen_ext.models.replay import ReplayChatCompletionClient
from openai.resources.chat.completions import AsyncCompletions
from openai.types.chat.chat_completion import ChatCompletion, Choice
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from openai.types.chat.chat_completion_message import ChatCompletionMessage
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall, Function
from openai.types.completion_usage import CompletionUsage
from utils import FileLogHandler

logger = logging.getLogger(EVENT_LOGGER_NAME)
logger.setLevel(logging.DEBUG)
logger.addHandler(FileLogHandler("test_group_chat.log"))


class _MockChatCompletion:
    def __init__(self, chat_completions: List[ChatCompletion]) -> None:
        self._saved_chat_completions = chat_completions
        self._curr_index = 0

    async def mock_create(
        self, *args: Any, **kwargs: Any
    ) -> ChatCompletion | AsyncGenerator[ChatCompletionChunk, None]:
        await asyncio.sleep(0.1)
        completion = self._saved_chat_completions[self._curr_index]
        self._curr_index += 1
        return completion

    def reset(self) -> None:
        self._curr_index = 0


class _EchoAgent(BaseChatAgent):
    def __init__(self, name: str, description: str) -> None:
        super().__init__(name, description)
        self._last_message: str | None = None
        self._total_messages = 0

    @property
    def produced_message_types(self) -> Sequence[type[ChatMessage]]:
        return (TextMessage,)

    @property
    def total_messages(self) -> int:
        return self._total_messages

    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
        if len(messages) > 0:
            assert isinstance(messages[0], TextMessage)
            self._last_message = messages[0].content
            self._total_messages += 1
            return Response(chat_message=TextMessage(content=messages[0].content, source=self.name))
        else:
            assert self._last_message is not None
            self._total_messages += 1
            return Response(chat_message=TextMessage(content=self._last_message, source=self.name))

    async def on_reset(self, cancellation_token: CancellationToken) -> None:
        self._last_message = None


class _FlakyAgent(BaseChatAgent):
    def __init__(self, name: str, description: str) -> None:
        super().__init__(name, description)
        self._last_message: str | None = None
        self._total_messages = 0

    @property
    def produced_message_types(self) -> Sequence[type[ChatMessage]]:
        return (TextMessage,)

    @property
    def total_messages(self) -> int:
        return self._total_messages

    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
        raise ValueError("I am a flaky agent...")

    async def on_reset(self, cancellation_token: CancellationToken) -> None:
        self._last_message = None


class _StopAgent(_EchoAgent):
    def __init__(self, name: str, description: str, *, stop_at: int = 1) -> None:
        super().__init__(name, description)
        self._count = 0
        self._stop_at = stop_at

    @property
    def produced_message_types(self) -> Sequence[type[ChatMessage]]:
        return (TextMessage, StopMessage)

    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
        self._count += 1
        if self._count < self._stop_at:
            return await super().on_messages(messages, cancellation_token)
        return Response(chat_message=StopMessage(content="TERMINATE", source=self.name))


def _pass_function(input: str) -> str:
    return "pass"


@pytest.mark.asyncio
async def test_round_robin_group_chat(monkeypatch: pytest.MonkeyPatch) -> None:
    model = "gpt-4o-2024-05-13"
    chat_completions = [
        ChatCompletion(
            id="id1",
            choices=[
                Choice(
                    finish_reason="stop",
                    index=0,
                    message=ChatCompletionMessage(
                        content="""Here is the program\n ```python\nprint("Hello, world!")\n```""",
                        role="assistant",
                    ),
                )
            ],
            created=0,
            model=model,
            object="chat.completion",
            usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
        ),
        ChatCompletion(
            id="id2",
            choices=[
                Choice(
                    finish_reason="stop",
                    index=0,
                    message=ChatCompletionMessage(content="TERMINATE", role="assistant"),
                )
            ],
            created=0,
            model=model,
            object="chat.completion",
            usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
        ),
    ]
    mock = _MockChatCompletion(chat_completions)
    monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
    with tempfile.TemporaryDirectory() as temp_dir:
        code_executor_agent = CodeExecutorAgent(
            "code_executor", code_executor=LocalCommandLineCodeExecutor(work_dir=temp_dir)
        )
        coding_assistant_agent = AssistantAgent(
            "coding_assistant", model_client=OpenAIChatCompletionClient(model=model, api_key="")
        )
        termination = TextMentionTermination("TERMINATE")
        team = RoundRobinGroupChat(
            participants=[coding_assistant_agent, code_executor_agent], termination_condition=termination
        )
        result = await team.run(
            task="Write a program that prints 'Hello, world!'",
        )
        expected_messages = [
            "Write a program that prints 'Hello, world!'",
            'Here is the program\n ```python\nprint("Hello, world!")\n```',
            "Hello, world!",
            "TERMINATE",
        ]
        # Normalize the messages to remove \r\n and any leading/trailing whitespace.
        normalized_messages = [
            msg.content.replace("\r\n", "\n").rstrip("\n") if isinstance(msg.content, str) else msg.content
            for msg in result.messages
        ]

        # Assert that all expected messages are in the collected messages
        assert normalized_messages == expected_messages

        assert result.stop_reason is not None and result.stop_reason == "Text 'TERMINATE' mentioned"

        # Test streaming.
        mock.reset()
        index = 0
        await team.reset()
        async for message in team.run_stream(
            task="Write a program that prints 'Hello, world!'",
        ):
            if isinstance(message, TaskResult):
                assert message == result
            else:
                assert message == result.messages[index]
                index += 1

        # Test message input.
        # Text message.
        mock.reset()
        index = 0
        await team.reset()
        result_2 = await team.run(
            task=TextMessage(content="Write a program that prints 'Hello, world!'", source="user")
        )
        assert result == result_2

        # Test multi-modal message.
        mock.reset()
        index = 0
        await team.reset()
        result_2 = await team.run(
            task=MultiModalMessage(content=["Write a program that prints 'Hello, world!'"], source="user")
        )
        assert result.messages[0].content == result_2.messages[0].content[0]
        assert result.messages[1:] == result_2.messages[1:]


@pytest.mark.asyncio
async def test_round_robin_group_chat_state() -> None:
    model_client = ReplayChatCompletionClient(
        ["No facts", "No plan", "print('Hello, world!')", "TERMINATE"],
    )
    agent1 = AssistantAgent("agent1", model_client=model_client)
    agent2 = AssistantAgent("agent2", model_client=model_client)
    termination = TextMentionTermination("TERMINATE")
    team1 = RoundRobinGroupChat(participants=[agent1, agent2], termination_condition=termination)
    await team1.run(task="Write a program that prints 'Hello, world!'")
    state = await team1.save_state()

    agent3 = AssistantAgent("agent1", model_client=model_client)
    agent4 = AssistantAgent("agent2", model_client=model_client)
    team2 = RoundRobinGroupChat(participants=[agent3, agent4], termination_condition=termination)
    await team2.load_state(state)
    state2 = await team2.save_state()
    assert state == state2

    agent1_model_ctx_messages = await agent1._model_context.get_messages()  # pyright: ignore
    agent2_model_ctx_messages = await agent2._model_context.get_messages()  # pyright: ignore
    agent3_model_ctx_messages = await agent3._model_context.get_messages()  # pyright: ignore
    agent4_model_ctx_messages = await agent4._model_context.get_messages()  # pyright: ignore
    assert agent3_model_ctx_messages == agent1_model_ctx_messages
    assert agent4_model_ctx_messages == agent2_model_ctx_messages
    manager_1 = await team1._runtime.try_get_underlying_agent_instance(  # pyright: ignore
        AgentId("group_chat_manager", team1._team_id),  # pyright: ignore
        RoundRobinGroupChatManager,  # pyright: ignore
    )  # pyright: ignore
    manager_2 = await team2._runtime.try_get_underlying_agent_instance(  # pyright: ignore
        AgentId("group_chat_manager", team2._team_id),  # pyright: ignore
        RoundRobinGroupChatManager,  # pyright: ignore
    )  # pyright: ignore
    assert manager_1._current_turn == manager_2._current_turn  # pyright: ignore
    assert manager_1._message_thread == manager_2._message_thread  # pyright: ignore


@pytest.mark.asyncio
async def test_round_robin_group_chat_with_tools(monkeypatch: pytest.MonkeyPatch) -> None:
    model = "gpt-4o-2024-05-13"
    chat_completions = [
        ChatCompletion(
            id="id1",
            choices=[
                Choice(
                    finish_reason="tool_calls",
                    index=0,
                    message=ChatCompletionMessage(
                        content=None,
                        tool_calls=[
                            ChatCompletionMessageToolCall(
                                id="1",
                                type="function",
                                function=Function(
                                    name="pass",
                                    arguments=json.dumps({"input": "pass"}),
                                ),
                            )
                        ],
                        role="assistant",
                    ),
                )
            ],
            created=0,
            model=model,
            object="chat.completion",
            usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
        ),
        ChatCompletion(
            id="id2",
            choices=[
                Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(content="Hello", role="assistant"))
            ],
            created=0,
            model=model,
            object="chat.completion",
            usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
        ),
        ChatCompletion(
            id="id2",
            choices=[
                Choice(
                    finish_reason="stop", index=0, message=ChatCompletionMessage(content="TERMINATE", role="assistant")
                )
            ],
            created=0,
            model=model,
            object="chat.completion",
            usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
        ),
    ]
    # Test with repeat tool calls once
    mock = _MockChatCompletion(chat_completions)
    monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
    tool = FunctionTool(_pass_function, name="pass", description="pass function")
    tool_use_agent = AssistantAgent(
        "tool_use_agent",
        model_client=OpenAIChatCompletionClient(model=model, api_key=""),
        tools=[tool],
    )
    echo_agent = _EchoAgent("echo_agent", description="echo agent")
    termination = TextMentionTermination("TERMINATE")
    team = RoundRobinGroupChat(participants=[tool_use_agent, echo_agent], termination_condition=termination)
    result = await team.run(
        task="Write a program that prints 'Hello, world!'",
    )
    assert len(result.messages) == 8
    assert isinstance(result.messages[0], TextMessage)  # task
    assert isinstance(result.messages[1], ToolCallRequestEvent)  # tool call
    assert isinstance(result.messages[2], ToolCallExecutionEvent)  # tool call result
    assert isinstance(result.messages[3], ToolCallSummaryMessage)  # tool use agent response
    assert result.messages[3].content == "pass"  # ensure the tool call was executed
    assert isinstance(result.messages[4], TextMessage)  # echo agent response
    assert isinstance(result.messages[5], TextMessage)  # tool use agent response
    assert isinstance(result.messages[6], TextMessage)  # echo agent response
    assert isinstance(result.messages[7], TextMessage)  # tool use agent response, that has TERMINATE
    assert result.messages[7].content == "TERMINATE"

    assert result.stop_reason is not None and result.stop_reason == "Text 'TERMINATE' mentioned"

    # Test streaming.
    await tool_use_agent._model_context.clear()  # pyright: ignore
    mock.reset()
    index = 0
    await team.reset()
    async for message in team.run_stream(
        task="Write a program that prints 'Hello, world!'",
    ):
        if isinstance(message, TaskResult):
            assert message == result
        else:
            assert message == result.messages[index]
            index += 1

    # Test Console.
    await tool_use_agent._model_context.clear()  # pyright: ignore
    mock.reset()
    index = 0
    await team.reset()
    result2 = await Console(team.run_stream(task="Write a program that prints 'Hello, world!'"))
    assert result2 == result


@pytest.mark.asyncio
async def test_round_robin_group_chat_with_resume_and_reset() -> None:
    agent_1 = _EchoAgent("agent_1", description="echo agent 1")
    agent_2 = _EchoAgent("agent_2", description="echo agent 2")
    agent_3 = _EchoAgent("agent_3", description="echo agent 3")
    agent_4 = _EchoAgent("agent_4", description="echo agent 4")
    termination = MaxMessageTermination(3)
    team = RoundRobinGroupChat(participants=[agent_1, agent_2, agent_3, agent_4], termination_condition=termination)
    result = await team.run(
        task="Write a program that prints 'Hello, world!'",
    )
    assert len(result.messages) == 3
    assert result.messages[1].source == "agent_1"
    assert result.messages[2].source == "agent_2"
    assert result.stop_reason is not None

    # Resume.
    result = await team.run()
    assert len(result.messages) == 3
    assert result.messages[0].source == "agent_3"
    assert result.messages[1].source == "agent_4"
    assert result.messages[2].source == "agent_1"
    assert result.stop_reason is not None

    # Reset.
    await team.reset()
    result = await team.run(task="Write a program that prints 'Hello, world!'")
    assert len(result.messages) == 3
    assert result.messages[1].source == "agent_1"
    assert result.messages[2].source == "agent_2"
    assert result.stop_reason is not None


@pytest.mark.asyncio
async def test_round_robin_group_chat_with_exception_raised() -> None:
    agent_1 = _EchoAgent("agent_1", description="echo agent 1")
    agent_2 = _FlakyAgent("agent_2", description="echo agent 2")
    agent_3 = _EchoAgent("agent_3", description="echo agent 3")
    termination = MaxMessageTermination(3)
    team = RoundRobinGroupChat(
        participants=[agent_1, agent_2, agent_3],
        termination_condition=termination,
    )

    with pytest.raises(ValueError, match="I am a flaky agent..."):
        await team.run(
            task="Write a program that prints 'Hello, world!'",
        )


@pytest.mark.asyncio
async def test_round_robin_group_chat_max_turn() -> None:
    agent_1 = _EchoAgent("agent_1", description="echo agent 1")
    agent_2 = _EchoAgent("agent_2", description="echo agent 2")
    agent_3 = _EchoAgent("agent_3", description="echo agent 3")
    agent_4 = _EchoAgent("agent_4", description="echo agent 4")
    team = RoundRobinGroupChat(participants=[agent_1, agent_2, agent_3, agent_4], max_turns=3)
    result = await team.run(
        task="Write a program that prints 'Hello, world!'",
    )
    assert len(result.messages) == 4
    assert result.messages[1].source == "agent_1"
    assert result.messages[2].source == "agent_2"
    assert result.messages[3].source == "agent_3"
    assert result.stop_reason is not None

    # Resume.
    result = await team.run()
    assert len(result.messages) == 3
    assert result.messages[0].source == "agent_4"
    assert result.messages[1].source == "agent_1"
    assert result.messages[2].source == "agent_2"
    assert result.stop_reason is not None

    # Reset.
    await team.reset()
    result = await team.run(task="Write a program that prints 'Hello, world!'")
    assert len(result.messages) == 4
    assert result.messages[1].source == "agent_1"
    assert result.messages[2].source == "agent_2"
    assert result.messages[3].source == "agent_3"
    assert result.stop_reason is not None


@pytest.mark.asyncio
async def test_round_robin_group_chat_cancellation() -> None:
    agent_1 = _EchoAgent("agent_1", description="echo agent 1")
    agent_2 = _EchoAgent("agent_2", description="echo agent 2")
    agent_3 = _EchoAgent("agent_3", description="echo agent 3")
    agent_4 = _EchoAgent("agent_4", description="echo agent 4")
    # Set max_turns to a large number to avoid stopping due to max_turns before cancellation.
    team = RoundRobinGroupChat(participants=[agent_1, agent_2, agent_3, agent_4], max_turns=1000)
    cancellation_token = CancellationToken()
    run_task = asyncio.create_task(
        team.run(
            task="Write a program that prints 'Hello, world!'",
            cancellation_token=cancellation_token,
        )
    )
    await asyncio.sleep(0.1)
    # Cancel the task.
    cancellation_token.cancel()
    with pytest.raises(asyncio.CancelledError):
        await run_task

    # Total messages produced so far.
    total_messages = agent_1.total_messages + agent_2.total_messages + agent_3.total_messages + agent_4.total_messages

    # Still can run again and finish the task.
    result = await team.run()
    assert len(result.messages) + total_messages == 1000


@pytest.mark.asyncio
async def test_selector_group_chat(monkeypatch: pytest.MonkeyPatch) -> None:
    model = "gpt-4o-2024-05-13"
    chat_completions = [
        ChatCompletion(
            id="id2",
            choices=[
                Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(content="agent3", role="assistant"))
            ],
            created=0,
            model=model,
            object="chat.completion",
            usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
        ),
        ChatCompletion(
            id="id2",
            choices=[
                Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(content="agent2", role="assistant"))
            ],
            created=0,
            model=model,
            object="chat.completion",
            usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
        ),
        ChatCompletion(
            id="id2",
            choices=[
                Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(content="agent1", role="assistant"))
            ],
            created=0,
            model=model,
            object="chat.completion",
            usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
        ),
        ChatCompletion(
            id="id2",
            choices=[
                Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(content="agent2", role="assistant"))
            ],
            created=0,
            model=model,
            object="chat.completion",
            usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
        ),
        ChatCompletion(
            id="id2",
            choices=[
                Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(content="agent1", role="assistant"))
            ],
            created=0,
            model=model,
            object="chat.completion",
            usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
        ),
    ]
    mock = _MockChatCompletion(chat_completions)
    monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)

    agent1 = _StopAgent("agent1", description="echo agent 1", stop_at=2)
    agent2 = _EchoAgent("agent2", description="echo agent 2")
    agent3 = _EchoAgent("agent3", description="echo agent 3")
    termination = TextMentionTermination("TERMINATE")
    team = SelectorGroupChat(
        participants=[agent1, agent2, agent3],
        model_client=OpenAIChatCompletionClient(model=model, api_key=""),
        termination_condition=termination,
    )
    result = await team.run(
        task="Write a program that prints 'Hello, world!'",
    )
    assert len(result.messages) == 6
    assert result.messages[0].content == "Write a program that prints 'Hello, world!'"
    assert result.messages[1].source == "agent3"
    assert result.messages[2].source == "agent2"
    assert result.messages[3].source == "agent1"
    assert result.messages[4].source == "agent2"
    assert result.messages[5].source == "agent1"
    assert result.stop_reason is not None and result.stop_reason == "Text 'TERMINATE' mentioned"

    # Test streaming.
    mock.reset()
    agent1._count = 0  # pyright: ignore
    index = 0
    await team.reset()
    async for message in team.run_stream(
        task="Write a program that prints 'Hello, world!'",
    ):
        if isinstance(message, TaskResult):
            assert message == result
        else:
            assert message == result.messages[index]
            index += 1

    # Test Console.
    mock.reset()
    agent1._count = 0  # pyright: ignore
    index = 0
    await team.reset()
    result2 = await Console(team.run_stream(task="Write a program that prints 'Hello, world!'"))
    assert result2 == result


@pytest.mark.asyncio
async def test_selector_group_chat_state() -> None:
    model_client = ReplayChatCompletionClient(
        ["agent1", "No facts", "agent2", "No plan", "agent1", "print('Hello, world!')", "agent2", "TERMINATE"],
    )
    agent1 = AssistantAgent("agent1", model_client=model_client)
    agent2 = AssistantAgent("agent2", model_client=model_client)
    termination = TextMentionTermination("TERMINATE")
    team1 = SelectorGroupChat(
        participants=[agent1, agent2], termination_condition=termination, model_client=model_client
    )
    await team1.run(task="Write a program that prints 'Hello, world!'")
    state = await team1.save_state()

    agent3 = AssistantAgent("agent1", model_client=model_client)
    agent4 = AssistantAgent("agent2", model_client=model_client)
    team2 = SelectorGroupChat(
        participants=[agent3, agent4], termination_condition=termination, model_client=model_client
    )
    await team2.load_state(state)
    state2 = await team2.save_state()
    assert state == state2

    agent1_model_ctx_messages = await agent1._model_context.get_messages()  # pyright: ignore
    agent2_model_ctx_messages = await agent2._model_context.get_messages()  # pyright: ignore
    agent3_model_ctx_messages = await agent3._model_context.get_messages()  # pyright: ignore
    agent4_model_ctx_messages = await agent4._model_context.get_messages()  # pyright: ignore
    assert agent3_model_ctx_messages == agent1_model_ctx_messages
    assert agent4_model_ctx_messages == agent2_model_ctx_messages
    manager_1 = await team1._runtime.try_get_underlying_agent_instance(  # pyright: ignore
        AgentId("group_chat_manager", team1._team_id),  # pyright: ignore
        SelectorGroupChatManager,  # pyright: ignore
    )  # pyright: ignore
    manager_2 = await team2._runtime.try_get_underlying_agent_instance(  # pyright: ignore
        AgentId("group_chat_manager", team2._team_id),  # pyright: ignore
        SelectorGroupChatManager,  # pyright: ignore
    )  # pyright: ignore
    assert manager_1._message_thread == manager_2._message_thread  # pyright: ignore
    assert manager_1._previous_speaker == manager_2._previous_speaker  # pyright: ignore


@pytest.mark.asyncio
async def test_selector_group_chat_two_speakers(monkeypatch: pytest.MonkeyPatch) -> None:
    model = "gpt-4o-2024-05-13"
    chat_completions = [
        ChatCompletion(
            id="id2",
            choices=[
                Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(content="agent2", role="assistant"))
            ],
            created=0,
            model=model,
            object="chat.completion",
            usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
        ),
    ]
    mock = _MockChatCompletion(chat_completions)
    monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)

    agent1 = _StopAgent("agent1", description="echo agent 1", stop_at=2)
    agent2 = _EchoAgent("agent2", description="echo agent 2")
    termination = TextMentionTermination("TERMINATE")
    team = SelectorGroupChat(
        participants=[agent1, agent2],
        termination_condition=termination,
        model_client=OpenAIChatCompletionClient(model=model, api_key=""),
    )
    result = await team.run(
        task="Write a program that prints 'Hello, world!'",
    )
    assert len(result.messages) == 5
    assert result.messages[0].content == "Write a program that prints 'Hello, world!'"
    assert result.messages[1].source == "agent2"
    assert result.messages[2].source == "agent1"
    assert result.messages[3].source == "agent2"
    assert result.messages[4].source == "agent1"
    # only one chat completion was called
    assert mock._curr_index == 1  # pyright: ignore
    assert result.stop_reason is not None and result.stop_reason == "Text 'TERMINATE' mentioned"

    # Test streaming.
    mock.reset()
    agent1._count = 0  # pyright: ignore
    index = 0
    await team.reset()
    async for message in team.run_stream(task="Write a program that prints 'Hello, world!'"):
        if isinstance(message, TaskResult):
            assert message == result
        else:
            assert message == result.messages[index]
            index += 1

    # Test Console.
    mock.reset()
    agent1._count = 0  # pyright: ignore
    index = 0
    await team.reset()
    result2 = await Console(team.run_stream(task="Write a program that prints 'Hello, world!'"))
    assert result2 == result


@pytest.mark.asyncio
async def test_selector_group_chat_two_speakers_allow_repeated(monkeypatch: pytest.MonkeyPatch) -> None:
    model = "gpt-4o-2024-05-13"
    chat_completions = [
        ChatCompletion(
            id="id2",
            choices=[
                Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(content="agent2", role="assistant"))
            ],
            created=0,
            model=model,
            object="chat.completion",
            usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
        ),
        ChatCompletion(
            id="id2",
            choices=[
                Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(content="agent2", role="assistant"))
            ],
            created=0,
            model=model,
            object="chat.completion",
            usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
        ),
        ChatCompletion(
            id="id2",
            choices=[
                Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(content="agent1", role="assistant"))
            ],
            created=0,
            model=model,
            object="chat.completion",
            usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
        ),
    ]
    mock = _MockChatCompletion(chat_completions)
    monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)

    agent1 = _StopAgent("agent1", description="echo agent 1", stop_at=1)
    agent2 = _EchoAgent("agent2", description="echo agent 2")
    termination = TextMentionTermination("TERMINATE")
    team = SelectorGroupChat(
        participants=[agent1, agent2],
        model_client=OpenAIChatCompletionClient(model=model, api_key=""),
        termination_condition=termination,
        allow_repeated_speaker=True,
    )
    result = await team.run(task="Write a program that prints 'Hello, world!'")
    assert len(result.messages) == 4
    assert result.messages[0].content == "Write a program that prints 'Hello, world!'"
    assert result.messages[1].source == "agent2"
    assert result.messages[2].source == "agent2"
    assert result.messages[3].source == "agent1"
    assert result.stop_reason is not None and result.stop_reason == "Text 'TERMINATE' mentioned"

    # Test streaming.
    mock.reset()
    index = 0
    await team.reset()
    async for message in team.run_stream(task="Write a program that prints 'Hello, world!'"):
        if isinstance(message, TaskResult):
            assert message == result
        else:
            assert message == result.messages[index]
            index += 1

    # Test Console.
    mock.reset()
    index = 0
    await team.reset()
    result2 = await Console(team.run_stream(task="Write a program that prints 'Hello, world!'"))
    assert result2 == result


@pytest.mark.asyncio
async def test_selector_group_chat_succcess_after_2_attempts() -> None:
    model_client = ReplayChatCompletionClient(
        ["agent2, agent3", "agent2"],
    )
    agent1 = _StopAgent("agent1", description="echo agent 1", stop_at=1)
    agent2 = _EchoAgent("agent2", description="echo agent 2")
    agent3 = _EchoAgent("agent3", description="echo agent 3")
    team = SelectorGroupChat(
        participants=[agent1, agent2, agent3],
        model_client=model_client,
        max_turns=1,
    )
    result = await team.run(task="Write a program that prints 'Hello, world!'")
    assert len(result.messages) == 2
    assert result.messages[0].content == "Write a program that prints 'Hello, world!'"
    assert result.messages[1].source == "agent2"


@pytest.mark.asyncio
async def test_selector_group_chat_fall_back_to_first_after_3_attempts() -> None:
    model_client = ReplayChatCompletionClient(
        [
            "agent2, agent3",  # Multiple speakers
            "agent5",  # Non-existent speaker
            "agent3, agent1",  # Multiple speakers
        ]
    )
    agent1 = _StopAgent("agent1", description="echo agent 1", stop_at=1)
    agent2 = _EchoAgent("agent2", description="echo agent 2")
    agent3 = _EchoAgent("agent3", description="echo agent 3")
    team = SelectorGroupChat(
        participants=[agent1, agent2, agent3],
        model_client=model_client,
        max_turns=1,
    )
    result = await team.run(task="Write a program that prints 'Hello, world!'")
    assert len(result.messages) == 2
    assert result.messages[0].content == "Write a program that prints 'Hello, world!'"
    assert result.messages[1].source == "agent1"


@pytest.mark.asyncio
async def test_selector_group_chat_fall_back_to_previous_after_3_attempts() -> None:
    model_client = ReplayChatCompletionClient(
        ["agent2", "agent2", "agent2", "agent2"],
    )
    agent1 = _StopAgent("agent1", description="echo agent 1", stop_at=1)
    agent2 = _EchoAgent("agent2", description="echo agent 2")
    agent3 = _EchoAgent("agent3", description="echo agent 3")
    team = SelectorGroupChat(
        participants=[agent1, agent2, agent3],
        model_client=model_client,
        max_turns=2,
    )
    result = await team.run(task="Write a program that prints 'Hello, world!'")
    assert len(result.messages) == 3
    assert result.messages[0].content == "Write a program that prints 'Hello, world!'"
    assert result.messages[1].source == "agent2"
    assert result.messages[2].source == "agent2"


@pytest.mark.asyncio
async def test_selector_group_chat_custom_selector(monkeypatch: pytest.MonkeyPatch) -> None:
    model = "gpt-4o-2024-05-13"
    chat_completions = [
        ChatCompletion(
            id="id2",
            choices=[
                Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(content="agent3", role="assistant"))
            ],
            created=0,
            model=model,
            object="chat.completion",
            usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
        ),
    ]
    mock = _MockChatCompletion(chat_completions)
    monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
    agent1 = _EchoAgent("agent1", description="echo agent 1")
    agent2 = _EchoAgent("agent2", description="echo agent 2")
    agent3 = _EchoAgent("agent3", description="echo agent 3")
    agent4 = _EchoAgent("agent4", description="echo agent 4")

    def _select_agent(messages: Sequence[AgentEvent | ChatMessage]) -> str | None:
        if len(messages) == 0:
            return "agent1"
        elif messages[-1].source == "agent1":
            return "agent2"
        elif messages[-1].source == "agent2":
            return None
        elif messages[-1].source == "agent3":
            return "agent4"
        else:
            return "agent1"

    termination = MaxMessageTermination(6)
    team = SelectorGroupChat(
        participants=[agent1, agent2, agent3, agent4],
        model_client=OpenAIChatCompletionClient(model=model, api_key=""),
        selector_func=_select_agent,
        termination_condition=termination,
    )
    result = await team.run(task="task")
    assert len(result.messages) == 6
    assert result.messages[1].source == "agent1"
    assert result.messages[2].source == "agent2"
    assert result.messages[3].source == "agent3"
    assert result.messages[4].source == "agent4"
    assert result.messages[5].source == "agent1"
    assert (
        result.stop_reason is not None
        and result.stop_reason == "Maximum number of messages 6 reached, current message count: 6"
    )


class _HandOffAgent(BaseChatAgent):
    def __init__(self, name: str, description: str, next_agent: str) -> None:
        super().__init__(name, description)
        self._next_agent = next_agent

    @property
    def produced_message_types(self) -> Sequence[type[ChatMessage]]:
        return (HandoffMessage,)

    async def on_messages(self, messages: Sequence[ChatMessage], cancellation_token: CancellationToken) -> Response:
        return Response(
            chat_message=HandoffMessage(
                content=f"Transferred to {self._next_agent}.", target=self._next_agent, source=self.name
            )
        )

    async def on_reset(self, cancellation_token: CancellationToken) -> None:
        pass


@pytest.mark.asyncio
async def test_swarm_handoff() -> None:
    first_agent = _HandOffAgent("first_agent", description="first agent", next_agent="second_agent")
    second_agent = _HandOffAgent("second_agent", description="second agent", next_agent="third_agent")
    third_agent = _HandOffAgent("third_agent", description="third agent", next_agent="first_agent")

    termination = MaxMessageTermination(6)
    team = Swarm([second_agent, first_agent, third_agent], termination_condition=termination)
    result = await team.run(task="task")
    assert len(result.messages) == 6
    assert result.messages[0].content == "task"
    assert result.messages[1].content == "Transferred to third_agent."
    assert result.messages[2].content == "Transferred to first_agent."
    assert result.messages[3].content == "Transferred to second_agent."
    assert result.messages[4].content == "Transferred to third_agent."
    assert result.messages[5].content == "Transferred to first_agent."
    assert (
        result.stop_reason is not None
        and result.stop_reason == "Maximum number of messages 6 reached, current message count: 6"
    )

    # Test streaming.
    index = 0
    await team.reset()
    stream = team.run_stream(task="task")
    async for message in stream:
        if isinstance(message, TaskResult):
            assert message == result
        else:
            assert message == result.messages[index]
            index += 1

    # Test save and load.
    state = await team.save_state()
    first_agent2 = _HandOffAgent("first_agent", description="first agent", next_agent="second_agent")
    second_agent2 = _HandOffAgent("second_agent", description="second agent", next_agent="third_agent")
    third_agent2 = _HandOffAgent("third_agent", description="third agent", next_agent="first_agent")
    team2 = Swarm([second_agent2, first_agent2, third_agent2], termination_condition=termination)
    await team2.load_state(state)
    state2 = await team2.save_state()
    assert state == state2
    manager_1 = await team._runtime.try_get_underlying_agent_instance(  # pyright: ignore
        AgentId("group_chat_manager", team._team_id),  # pyright: ignore
        SwarmGroupChatManager,  # pyright: ignore
    )  # pyright: ignore
    manager_2 = await team2._runtime.try_get_underlying_agent_instance(  # pyright: ignore
        AgentId("group_chat_manager", team2._team_id),  # pyright: ignore
        SwarmGroupChatManager,  # pyright: ignore
    )  # pyright: ignore
    assert manager_1._message_thread == manager_2._message_thread  # pyright: ignore
    assert manager_1._current_speaker == manager_2._current_speaker  # pyright: ignore


@pytest.mark.asyncio
async def test_swarm_handoff_using_tool_calls(monkeypatch: pytest.MonkeyPatch) -> None:
    model = "gpt-4o-2024-05-13"
    chat_completions = [
        ChatCompletion(
            id="id1",
            choices=[
                Choice(
                    finish_reason="tool_calls",
                    index=0,
                    message=ChatCompletionMessage(
                        content=None,
                        tool_calls=[
                            ChatCompletionMessageToolCall(
                                id="1",
                                type="function",
                                function=Function(
                                    name="handoff_to_agent2",
                                    arguments=json.dumps({}),
                                ),
                            )
                        ],
                        role="assistant",
                    ),
                )
            ],
            created=0,
            model=model,
            object="chat.completion",
            usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
        ),
        ChatCompletion(
            id="id2",
            choices=[
                Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(content="Hello", role="assistant"))
            ],
            created=0,
            model=model,
            object="chat.completion",
            usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
        ),
        ChatCompletion(
            id="id2",
            choices=[
                Choice(
                    finish_reason="stop", index=0, message=ChatCompletionMessage(content="TERMINATE", role="assistant")
                )
            ],
            created=0,
            model=model,
            object="chat.completion",
            usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
        ),
    ]
    mock = _MockChatCompletion(chat_completions)
    monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)

    agent1 = AssistantAgent(
        "agent1",
        model_client=OpenAIChatCompletionClient(model=model, api_key=""),
        handoffs=[Handoff(target="agent2", name="handoff_to_agent2", message="handoff to agent2")],
    )
    agent2 = _HandOffAgent("agent2", description="agent 2", next_agent="agent1")
    termination = TextMentionTermination("TERMINATE")
    team = Swarm([agent1, agent2], termination_condition=termination)
    result = await team.run(task="task")
    assert len(result.messages) == 7
    assert result.messages[0].content == "task"
    assert isinstance(result.messages[1], ToolCallRequestEvent)
    assert isinstance(result.messages[2], ToolCallExecutionEvent)
    assert result.messages[3].content == "handoff to agent2"
    assert result.messages[4].content == "Transferred to agent1."
    assert result.messages[5].content == "Hello"
    assert result.messages[6].content == "TERMINATE"
    assert result.stop_reason is not None and result.stop_reason == "Text 'TERMINATE' mentioned"

    # Test streaming.
    await agent1._model_context.clear()  # pyright: ignore
    mock.reset()
    index = 0
    await team.reset()
    stream = team.run_stream(task="task")
    async for message in stream:
        if isinstance(message, TaskResult):
            assert message == result
        else:
            assert message == result.messages[index]
            index += 1

    # Test Console
    await agent1._model_context.clear()  # pyright: ignore
    mock.reset()
    index = 0
    await team.reset()
    result2 = await Console(team.run_stream(task="task"))
    assert result2 == result


@pytest.mark.asyncio
async def test_swarm_pause_and_resume() -> None:
    first_agent = _HandOffAgent("first_agent", description="first agent", next_agent="second_agent")
    second_agent = _HandOffAgent("second_agent", description="second agent", next_agent="third_agent")
    third_agent = _HandOffAgent("third_agent", description="third agent", next_agent="first_agent")

    team = Swarm([second_agent, first_agent, third_agent], max_turns=1)
    result = await team.run(task="task")
    assert len(result.messages) == 2
    assert result.messages[0].content == "task"
    assert result.messages[1].content == "Transferred to third_agent."

    # Resume with a new task.
    result = await team.run(task="new task")
    assert len(result.messages) == 2
    assert result.messages[0].content == "new task"
    assert result.messages[1].content == "Transferred to first_agent."

    # Resume with the same task.
    result = await team.run()
    assert len(result.messages) == 1
    assert result.messages[0].content == "Transferred to second_agent."


@pytest.mark.asyncio
async def test_swarm_with_parallel_tool_calls(monkeypatch: pytest.MonkeyPatch) -> None:
    model = "gpt-4o-2024-05-13"
    chat_completions = [
        ChatCompletion(
            id="id1",
            choices=[
                Choice(
                    finish_reason="tool_calls",
                    index=0,
                    message=ChatCompletionMessage(
                        content=None,
                        tool_calls=[
                            ChatCompletionMessageToolCall(
                                id="1",
                                type="function",
                                function=Function(
                                    name="tool1",
                                    arguments=json.dumps({}),
                                ),
                            ),
                            ChatCompletionMessageToolCall(
                                id="2",
                                type="function",
                                function=Function(
                                    name="tool2",
                                    arguments=json.dumps({}),
                                ),
                            ),
                            ChatCompletionMessageToolCall(
                                id="3",
                                type="function",
                                function=Function(
                                    name="handoff_to_agent2",
                                    arguments=json.dumps({}),
                                ),
                            ),
                        ],
                        role="assistant",
                    ),
                )
            ],
            created=0,
            model=model,
            object="chat.completion",
            usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
        ),
        ChatCompletion(
            id="id2",
            choices=[
                Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(content="Hello", role="assistant"))
            ],
            created=0,
            model=model,
            object="chat.completion",
            usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
        ),
        ChatCompletion(
            id="id2",
            choices=[
                Choice(
                    finish_reason="stop", index=0, message=ChatCompletionMessage(content="TERMINATE", role="assistant")
                )
            ],
            created=0,
            model=model,
            object="chat.completion",
            usage=CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0),
        ),
    ]
    mock = _MockChatCompletion(chat_completions)
    monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)

    expected_handoff_context: List[LLMMessage] = [
        AssistantMessage(
            source="agent1",
            content=[
                FunctionCall(id="1", name="tool1", arguments="{}"),
                FunctionCall(id="2", name="tool2", arguments="{}"),
            ],
        ),
        FunctionExecutionResultMessage(
            content=[
                FunctionExecutionResult(content="tool1", call_id="1", is_error=False, name="tool1"),
                FunctionExecutionResult(content="tool2", call_id="2", is_error=False, name="tool2"),
            ]
        ),
    ]
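    # Note: the `name` field on FunctionExecutionResult above is the attribute this PR adds (see description at the top).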

    def tool1() -> str:
        return "tool1"

    def tool2() -> str:
        return "tool2"

    agent1 = AssistantAgent(
        "agent1",
        model_client=OpenAIChatCompletionClient(model=model, api_key=""),
        handoffs=[Handoff(target="agent2", name="handoff_to_agent2", message="handoff to agent2")],
        tools=[tool1, tool2],
    )
    agent2 = AssistantAgent(
        "agent2",
        model_client=OpenAIChatCompletionClient(model=model, api_key=""),
    )
    termination = TextMentionTermination("TERMINATE")
    team = Swarm([agent1, agent2], termination_condition=termination)
    result = await team.run(task="task")
    assert len(result.messages) == 6
    assert result.messages[0] == TextMessage(content="task", source="user")
    assert isinstance(result.messages[1], ToolCallRequestEvent)
    assert isinstance(result.messages[2], ToolCallExecutionEvent)
    assert result.messages[3] == HandoffMessage(
        content="handoff to agent2",
        target="agent2",
        source="agent1",
        context=expected_handoff_context,
    )
    assert result.messages[4].content == "Hello"
    assert result.messages[4].source == "agent2"
    assert result.messages[5].content == "TERMINATE"
    assert result.messages[5].source == "agent2"

    # Verify the tool calls are in agent2's context.
    agent2_model_ctx_messages = await agent2._model_context.get_messages()  # pyright: ignore
    assert agent2_model_ctx_messages[0] == UserMessage(content="task", source="user")
    assert agent2_model_ctx_messages[1] == expected_handoff_context[0]
    assert agent2_model_ctx_messages[2] == expected_handoff_context[1]


@pytest.mark.asyncio
async def test_swarm_with_handoff_termination() -> None:
    first_agent = _HandOffAgent("first_agent", description="first agent", next_agent="second_agent")
    second_agent = _HandOffAgent("second_agent", description="second agent", next_agent="third_agent")
    third_agent = _HandOffAgent("third_agent", description="third agent", next_agent="first_agent")

    # Handoff to an existing agent.
    termination = HandoffTermination(target="third_agent")
    team = Swarm([second_agent, first_agent, third_agent], termination_condition=termination)
    # Start
    result = await team.run(task="task")
    assert len(result.messages) == 2
    assert result.messages[0].content == "task"
    assert result.messages[1].content == "Transferred to third_agent."
    # Resume existing.
    result = await team.run()
    assert len(result.messages) == 3
    assert result.messages[0].content == "Transferred to first_agent."
    assert result.messages[1].content == "Transferred to second_agent."
    assert result.messages[2].content == "Transferred to third_agent."
    # Resume new task.
    result = await team.run(task="new task")
    assert len(result.messages) == 4
    assert result.messages[0].content == "new task"
    assert result.messages[1].content == "Transferred to first_agent."
    assert result.messages[2].content == "Transferred to second_agent."
    assert result.messages[3].content == "Transferred to third_agent."

    # Handoff to a non-existing agent.
    third_agent = _HandOffAgent("third_agent", description="third agent", next_agent="non_existing_agent")
    termination = HandoffTermination(target="non_existing_agent")
    team = Swarm([second_agent, first_agent, third_agent], termination_condition=termination)
    # Start
    result = await team.run(task="task")
    assert len(result.messages) == 3
    assert result.messages[0].content == "task"
    assert result.messages[1].content == "Transferred to third_agent."
    assert result.messages[2].content == "Transferred to non_existing_agent."
    # Attempt to resume.
    with pytest.raises(ValueError):
        await team.run()
    # Attempt to resume with a new task.
    with pytest.raises(ValueError):
        await team.run(task="new task")
    # Resume with a HandoffMessage
    result = await team.run(task=HandoffMessage(content="Handoff to first_agent.", target="first_agent", source="user"))
    assert len(result.messages) == 4
    assert result.messages[0].content == "Handoff to first_agent."
    assert result.messages[1].content == "Transferred to second_agent."
    assert result.messages[2].content == "Transferred to third_agent."
    assert result.messages[3].content == "Transferred to non_existing_agent."


@pytest.mark.asyncio
async def test_round_robin_group_chat_with_message_list() -> None:
    # Create a simple team with echo agents
    agent1 = _EchoAgent("Agent1", "First agent")
    agent2 = _EchoAgent("Agent2", "Second agent")
    termination = MaxMessageTermination(4)  # Stop after 4 messages
    team = RoundRobinGroupChat([agent1, agent2], termination_condition=termination)

    # Create a list of messages
    messages: List[ChatMessage] = [
        TextMessage(content="Message 1", source="user"),
        TextMessage(content="Message 2", source="user"),
        TextMessage(content="Message 3", source="user"),
    ]

    # Run the team with the message list
    result = await team.run(task=messages)

    # Verify the messages were processed in order
    assert len(result.messages) == 4  # Initial messages + echo until termination
    assert result.messages[0].content == "Message 1"  # First message
    assert result.messages[1].content == "Message 2"  # Second message
    assert result.messages[2].content == "Message 3"  # Third message
    assert result.messages[3].content == "Message 1"  # Echo from first agent
    assert result.stop_reason == "Maximum number of messages 4 reached, current message count: 4"

    # Test with streaming
    await team.reset()
    index = 0
    async for message in team.run_stream(task=messages):
        if isinstance(message, TaskResult):
            assert message == result
        else:
            assert message == result.messages[index]
            index += 1

    # Test with invalid message list
    with pytest.raises(ValueError, match="All messages in task list must be valid ChatMessage types"):
        await team.run(task=["not a message"])  # type: ignore[list-item, arg-type] # intentionally testing invalid input

    # Test with empty message list
    with pytest.raises(ValueError, match="Task list cannot be empty"):
        await team.run(task=[])


@pytest.mark.asyncio
async def test_declarative_groupchats_with_config() -> None:
    # Create basic agents and components for testing
    agent1 = AssistantAgent(
        "agent_1",
        model_client=OpenAIChatCompletionClient(model="gpt-4o-2024-05-13", api_key=""),
        handoffs=["agent_2"],
    )
    agent2 = AssistantAgent("agent_2", model_client=OpenAIChatCompletionClient(model="gpt-4o-2024-05-13", api_key=""))
    termination = MaxMessageTermination(4)
    model_client = OpenAIChatCompletionClient(model="gpt-4o-2024-05-13", api_key="")

    # Test round robin - verify config is preserved
    round_robin = RoundRobinGroupChat(participants=[agent1, agent2], termination_condition=termination, max_turns=5)
    config = round_robin.dump_component()
    loaded = RoundRobinGroupChat.load_component(config)
    assert loaded.dump_component() == config

    # Test selector group chat - verify config is preserved
    selector_prompt = "Custom selector prompt with {roles}, {participants}, {history}"
    selector = SelectorGroupChat(
        participants=[agent1, agent2],
        model_client=model_client,
        termination_condition=termination,
        max_turns=10,
        selector_prompt=selector_prompt,
        allow_repeated_speaker=True,
    )
    selector_config = selector.dump_component()
    selector_loaded = SelectorGroupChat.load_component(selector_config)
    assert selector_loaded.dump_component() == selector_config

    # Test swarm with handoff termination
    handoff_termination = HandoffTermination(target="Agent2")
    swarm = Swarm(participants=[agent1, agent2], termination_condition=handoff_termination, max_turns=5)
    swarm_config = swarm.dump_component()
    swarm_loaded = Swarm.load_component(swarm_config)
    assert swarm_loaded.dump_component() == swarm_config

    # Test MagenticOne with custom parameters
    magentic = MagenticOneGroupChat(
        participants=[agent1],
        model_client=model_client,
        max_turns=15,
        max_stalls=5,
        final_answer_prompt="Custom prompt",
    )
    magentic_config = magentic.dump_component()
    magentic_loaded = MagenticOneGroupChat.load_component(magentic_config)
    assert magentic_loaded.dump_component() == magentic_config

    # Verify component types are correctly set for each
    for team in [loaded, selector, swarm, magentic]:
        assert team.component_type == "team"

    # Verify provider strings are correctly set
    assert round_robin.dump_component().provider == "autogen_agentchat.teams.RoundRobinGroupChat"
    assert selector.dump_component().provider == "autogen_agentchat.teams.SelectorGroupChat"
    assert swarm.dump_component().provider == "autogen_agentchat.teams.Swarm"
    assert magentic.dump_component().provider == "autogen_agentchat.teams.MagenticOneGroupChat"