2024-11-08 16:41:34 -08:00
|
|
|
import asyncio
|
|
|
|
from typing import Any, AsyncGenerator, List
|
|
|
|
|
|
|
|
import pytest
|
|
|
|
from autogen_agentchat.agents import AssistantAgent, SocietyOfMindAgent
|
2024-12-03 15:24:25 -08:00
|
|
|
from autogen_agentchat.conditions import MaxMessageTermination
|
2024-11-08 16:41:34 -08:00
|
|
|
from autogen_agentchat.teams import RoundRobinGroupChat
|
2024-12-10 13:18:09 +10:00
|
|
|
from autogen_ext.models.openai import OpenAIChatCompletionClient
|
2024-11-08 16:41:34 -08:00
|
|
|
from openai.resources.chat.completions import AsyncCompletions
|
|
|
|
from openai.types.chat.chat_completion import ChatCompletion, Choice
|
|
|
|
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
|
|
|
from openai.types.chat.chat_completion_message import ChatCompletionMessage
|
|
|
|
from openai.types.completion_usage import CompletionUsage
|
|
|
|
|
|
|
|
|
|
|
|
class _MockChatCompletion:
    """Stand-in for ``AsyncCompletions.create`` that replays canned completions in order.

    Each call to :meth:`mock_create` waits ``delay`` seconds, then returns the next
    saved completion. Any arguments passed to ``create`` are ignored.
    """

    def __init__(self, chat_completions: List[ChatCompletion], *, delay: float = 0.1) -> None:
        """Store the completions to replay.

        Args:
            chat_completions: Completions to return, one per call, in list order.
            delay: Seconds to ``await asyncio.sleep`` before each reply. The default
                of 0.1 preserves the previously hard-coded simulated latency.
        """
        self._saved_chat_completions = chat_completions
        self._curr_index = 0
        self._delay = delay

    async def mock_create(
        self, *args: Any, **kwargs: Any
    ) -> ChatCompletion | AsyncGenerator[ChatCompletionChunk, None]:
        """Return the next saved completion, ignoring all arguments.

        Raises:
            IndexError: If called after every saved completion has been returned.
                The message reports how many were provided, so an under-sized
                fixture is easy to diagnose.
        """
        await asyncio.sleep(self._delay)
        if self._curr_index >= len(self._saved_chat_completions):
            # Same exception type that plain list indexing would raise, but with context.
            raise IndexError(
                "_MockChatCompletion exhausted: mock_create was called more times than the "
                f"{len(self._saved_chat_completions)} completion(s) provided."
            )
        completion = self._saved_chat_completions[self._curr_index]
        self._curr_index += 1
        return completion
|
|
|
|
|
|
|
|
|
|
|
|
def _make_text_completion(model: str, content: str, completion_id: str) -> ChatCompletion:
    """Build a non-streaming ``ChatCompletion`` with one ``stop`` choice carrying ``content``."""
    prompt_tokens, completion_tokens = 10, 5
    return ChatCompletion(
        id=completion_id,
        choices=[
            Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(content=content, role="assistant"))
        ],
        created=0,
        model=model,
        object="chat.completion",
        usage=CompletionUsage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            # Keep the usage record internally consistent: total = prompt + completion.
            total_tokens=prompt_tokens + completion_tokens,
        ),
    )


def _build_society_of_mind_agent(model_client: OpenAIChatCompletionClient) -> SocietyOfMindAgent:
    """Build a ``SocietyOfMindAgent`` named "society_of_mind" over a two-assistant round-robin team.

    The inner team uses ``MaxMessageTermination(3)``. All agents share ``model_client``.
    """
    agent1 = AssistantAgent("assistant1", model_client=model_client, system_message="You are a helpful assistant.")
    agent2 = AssistantAgent("assistant2", model_client=model_client, system_message="You are a helpful assistant.")
    inner_termination = MaxMessageTermination(3)
    inner_team = RoundRobinGroupChat([agent1, agent2], termination_condition=inner_termination)
    return SocietyOfMindAgent("society_of_mind", team=inner_team, model_client=model_client)


@pytest.mark.asyncio
async def test_society_of_mind_agent(monkeypatch: pytest.MonkeyPatch) -> None:
    """Run a SocietyOfMindAgent against mocked OpenAI replies.

    Then check the state save/load round-trip and component (de)serialization.
    """
    model = "gpt-4o-2024-05-13"
    # Three canned replies ("1", "2", "3"), consumed in order by the mock. The
    # assertions below expect messages from assistant1, assistant2, then society_of_mind.
    chat_completions = [_make_text_completion(model, str(i), f"id{i}") for i in range(1, 4)]
    mock = _MockChatCompletion(chat_completions)
    monkeypatch.setattr(AsyncCompletions, "create", mock.mock_create)
    model_client = OpenAIChatCompletionClient(model="gpt-4o", api_key="")

    society_of_mind_agent = _build_society_of_mind_agent(model_client)
    response = await society_of_mind_agent.run(task="Count to 10.")
    assert len(response.messages) == 4
    assert response.messages[0].source == "user"
    assert response.messages[1].source == "assistant1"
    assert response.messages[2].source == "assistant2"
    assert response.messages[3].source == "society_of_mind"

    # Test save and load state: a freshly built agent loaded with the saved state
    # must save back an identical state.
    state = await society_of_mind_agent.save_state()
    assert state is not None
    society_of_mind_agent2 = _build_society_of_mind_agent(model_client)
    await society_of_mind_agent2.load_state(state)
    state2 = await society_of_mind_agent2.save_state()
    assert state == state2

    # Test serialization.
    soc_agent_config = society_of_mind_agent.dump_component()
    assert soc_agent_config.provider == "autogen_agentchat.agents.SocietyOfMindAgent"

    # Test deserialization.
    loaded_soc_agent = SocietyOfMindAgent.load_component(soc_agent_config)
    assert isinstance(loaded_soc_agent, SocietyOfMindAgent)
    assert loaded_soc_agent.name == "society_of_mind"
|