autogen/python/packages/autogen-agentchat/tests/test_society_of_mind_agent.py

81 lines
3.4 KiB
Python
Raw Normal View History

import asyncio
from typing import Any, AsyncGenerator, List
import pytest
from autogen_agentchat.agents import AssistantAgent, SocietyOfMindAgent
from autogen_agentchat.task import MaxMessageTermination
from autogen_agentchat.teams import RoundRobinGroupChat
from autogen_ext.models import OpenAIChatCompletionClient
from openai.resources.chat.completions import AsyncCompletions
from openai.types.chat.chat_completion import ChatCompletion, Choice
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from openai.types.chat.chat_completion_message import ChatCompletionMessage
from openai.types.completion_usage import CompletionUsage
class _MockChatCompletion:
def __init__(self, chat_completions: List[ChatCompletion]) -> None:
self._saved_chat_completions = chat_completions
self._curr_index = 0
async def mock_create(
self, *args: Any, **kwargs: Any
) -> ChatCompletion | AsyncGenerator[ChatCompletionChunk, None]:
await asyncio.sleep(0.1)
completion = self._saved_chat_completions[self._curr_index]
self._curr_index += 1
return completion
@pytest.mark.asyncio
async def test_society_of_mind_agent(monkeypatch: pytest.MonkeyPatch) -> None:
model = "gpt-4o-2024-05-13"
chat_completions = [
ChatCompletion(
id="id2",
choices=[
Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(content="1", role="assistant"))
],
created=0,
model=model,
object="chat.completion",
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
),
ChatCompletion(
id="id2",
choices=[
Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(content="2", role="assistant"))
],
created=0,
model=model,
object="chat.completion",
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
),
ChatCompletion(
id="id2",
choices=[
Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(content="3", role="assistant"))
],
created=0,
model=model,
object="chat.completion",
usage=CompletionUsage(prompt_tokens=10, completion_tokens=5, total_tokens=0),
),
]
mock = _MockChatCompletion(chat_completions)
monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
model_client = OpenAIChatCompletionClient(model="gpt-4o", api_key="")
agent1 = AssistantAgent("assistant1", model_client=model_client, system_message="You are a helpful assistant.")
agent2 = AssistantAgent("assistant2", model_client=model_client, system_message="You are a helpful assistant.")
inner_termination = MaxMessageTermination(3)
inner_team = RoundRobinGroupChat([agent1, agent2], termination_condition=inner_termination)
society_of_mind_agent = SocietyOfMindAgent("society_of_mind", team=inner_team, model_client=model_client)
response = await society_of_mind_agent.run(task="Count to 10.")
assert len(response.messages) == 5
assert response.messages[0].source == "user"
assert response.messages[1].source == "user"
assert response.messages[2].source == "assistant1"
assert response.messages[3].source == "assistant2"
assert response.messages[4].source == "society_of_mind"