# autogen/python/packages/autogen-studio/tests/test_component_factory.py
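"""Tests for ComponentFactory: loading tools, models, agents, termination
conditions, and teams from AutoGen Studio config objects."""
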
import pytest
from typing import List
from autogen_agentchat.agents import AssistantAgent
from autogen_agentchat.teams import RoundRobinGroupChat, SelectorGroupChat, MagenticOneGroupChat
from autogen_agentchat.conditions import MaxMessageTermination, StopMessageTermination, TextMentionTermination
from autogen_core.tools import FunctionTool

from autogenstudio.datamodel.types import (
    AssistantAgentConfig,
    OpenAIModelConfig,
    RoundRobinTeamConfig,
    SelectorTeamConfig,
    MagenticOneTeamConfig,
    ToolConfig,
    MaxMessageTerminationConfig,
    StopMessageTerminationConfig,
    TextMentionTerminationConfig,
    CombinationTerminationConfig,
    ModelTypes,
    AgentTypes,
    TeamTypes,
    TerminationTypes,
    ToolTypes,
    ComponentTypes,
)
from autogenstudio.database import ComponentFactory


@pytest.fixture
def component_factory():
    """Factory instance used by every test to build components from configs."""
    return ComponentFactory()


@pytest.fixture
def sample_tool_config():
    """A ToolConfig whose content is the Python source of a small calculator function."""
    return ToolConfig(
        name="calculator",
        description="A simple calculator function",
        content="""
def calculator(a: int, b: int, operation: str = '+') -> int:
    '''
    A simple calculator that performs basic operations
    '''
    if operation == '+':
        return a + b
    elif operation == '-':
        return a - b
    elif operation == '*':
        return a * b
    elif operation == '/':
        return a / b
    else:
        raise ValueError("Invalid operation")
""",
        tool_type=ToolTypes.PYTHON_FUNCTION,
        component_type=ComponentTypes.TOOL,
        version="1.0.0",
    )


@pytest.fixture
def sample_model_config():
    """An OpenAIModelConfig for gpt-4 with a dummy API key."""
    return OpenAIModelConfig(
        model_type=ModelTypes.OPENAI,
        model="gpt-4",
        api_key="test-key",
        component_type=ComponentTypes.MODEL,
        version="1.0.0",
    )


@pytest.fixture
def sample_agent_config(sample_model_config: OpenAIModelConfig, sample_tool_config: ToolConfig):
    """An AssistantAgentConfig wired to the sample model and the calculator tool."""
    return AssistantAgentConfig(
        name="test_agent",
        agent_type=AgentTypes.ASSISTANT,
        system_message="You are a helpful assistant",
        model_client=sample_model_config,
        tools=[sample_tool_config],
        component_type=ComponentTypes.AGENT,
        version="1.0.0",
    )


@pytest.fixture
def sample_termination_config():
    """A MaxMessageTerminationConfig capped at 10 messages."""
    return MaxMessageTerminationConfig(
        termination_type=TerminationTypes.MAX_MESSAGES,
        max_messages=10,
        component_type=ComponentTypes.TERMINATION,
        version="1.0.0",
    )


@pytest.fixture
def sample_team_config(
    sample_agent_config: AssistantAgentConfig,
    sample_termination_config: MaxMessageTerminationConfig,
    sample_model_config: OpenAIModelConfig,
):
    """A RoundRobinTeamConfig with one agent, a max-message termination, and a model client."""
    return RoundRobinTeamConfig(
        name="test_team",
        team_type=TeamTypes.ROUND_ROBIN,
        participants=[sample_agent_config],
        termination_condition=sample_termination_config,
        model_client=sample_model_config,
        component_type=ComponentTypes.TEAM,
        max_turns=10,
        version="1.0.0",
    )


@pytest.mark.asyncio
async def test_load_tool(component_factory: ComponentFactory, sample_tool_config: ToolConfig):
    # Test loading tool from ToolConfig
    tool = await component_factory.load_tool(sample_tool_config)
    assert isinstance(tool, FunctionTool)
    assert tool.name == "calculator"
    assert tool.description == "A simple calculator function"

    # Test tool functionality
    result = tool._func(5, 3, "+")
    assert result == 8


@pytest.mark.asyncio
async def test_load_tool_invalid_config(component_factory: ComponentFactory):
    # Test with missing required fields
    with pytest.raises(ValueError):
        await component_factory.load_tool(
            ToolConfig(
                name="test",
                description="",
                content="",
                tool_type=ToolTypes.PYTHON_FUNCTION,
                component_type=ComponentTypes.TOOL,
                version="1.0.0",
            )
        )

    # Test with invalid Python code
    invalid_config = ToolConfig(
        name="invalid",
        description="Invalid function",
        content="def invalid_func(): return invalid syntax",
        tool_type=ToolTypes.PYTHON_FUNCTION,
        component_type=ComponentTypes.TOOL,
        version="1.0.0",
    )
    with pytest.raises(ValueError):
        await component_factory.load_tool(invalid_config)


@pytest.mark.asyncio
async def test_load_model(component_factory: ComponentFactory, sample_model_config: OpenAIModelConfig):
    # Test loading model from ModelConfig
    model = await component_factory.load_model(sample_model_config)
    assert model is not None


@pytest.mark.asyncio
async def test_load_agent(component_factory: ComponentFactory, sample_agent_config: AssistantAgentConfig):
    # Test loading agent from AgentConfig
    agent = await component_factory.load_agent(sample_agent_config)
    assert isinstance(agent, AssistantAgent)
    assert agent.name == "test_agent"
    assert len(agent._tools) == 1


@pytest.mark.asyncio
async def test_load_termination(component_factory: ComponentFactory):
    max_msg_config = MaxMessageTerminationConfig(
        termination_type=TerminationTypes.MAX_MESSAGES,
        max_messages=5,
        component_type=ComponentTypes.TERMINATION,
        version="1.0.0",
    )
    termination = await component_factory.load_termination(max_msg_config)
    assert isinstance(termination, MaxMessageTermination)
    assert termination._max_messages == 5

    # Test StopMessageTermination
    stop_msg_config = StopMessageTerminationConfig(
        termination_type=TerminationTypes.STOP_MESSAGE, component_type=ComponentTypes.TERMINATION, version="1.0.0"
    )
    termination = await component_factory.load_termination(stop_msg_config)
    assert isinstance(termination, StopMessageTermination)

    # Test TextMentionTermination
    text_mention_config = TextMentionTerminationConfig(
        termination_type=TerminationTypes.TEXT_MENTION,
        text="DONE",
        component_type=ComponentTypes.TERMINATION,
        version="1.0.0",
    )
    termination = await component_factory.load_termination(text_mention_config)
    assert isinstance(termination, TextMentionTermination)
    assert termination._text == "DONE"

    # Test AND combination
    and_combo_config = CombinationTerminationConfig(
        termination_type=TerminationTypes.COMBINATION,
        operator="and",
        conditions=[
            MaxMessageTerminationConfig(
                termination_type=TerminationTypes.MAX_MESSAGES,
                max_messages=5,
                component_type=ComponentTypes.TERMINATION,
                version="1.0.0",
            ),
            TextMentionTerminationConfig(
                termination_type=TerminationTypes.TEXT_MENTION,
                text="DONE",
                component_type=ComponentTypes.TERMINATION,
                version="1.0.0",
            ),
        ],
        component_type=ComponentTypes.TERMINATION,
        version="1.0.0",
    )
    termination = await component_factory.load_termination(and_combo_config)
    assert termination is not None

    # Test OR combination
    or_combo_config = CombinationTerminationConfig(
        termination_type=TerminationTypes.COMBINATION,
        operator="or",
        conditions=[
            MaxMessageTerminationConfig(
                termination_type=TerminationTypes.MAX_MESSAGES,
                max_messages=5,
                component_type=ComponentTypes.TERMINATION,
                version="1.0.0",
            ),
            TextMentionTerminationConfig(
                termination_type=TerminationTypes.TEXT_MENTION,
                text="DONE",
                component_type=ComponentTypes.TERMINATION,
                version="1.0.0",
            ),
        ],
        component_type=ComponentTypes.TERMINATION,
        version="1.0.0",
    )
    termination = await component_factory.load_termination(or_combo_config)
    assert termination is not None

    # Test invalid combinations
    with pytest.raises(ValueError):
        await component_factory.load_termination(
            CombinationTerminationConfig(
                termination_type=TerminationTypes.COMBINATION,
                conditions=[],  # Empty conditions
                component_type=ComponentTypes.TERMINATION,
                version="1.0.0",
            )
        )

    with pytest.raises(ValueError):
        await component_factory.load_termination(
            CombinationTerminationConfig(
                termination_type=TerminationTypes.COMBINATION,
                operator="invalid",  # type: ignore
                conditions=[
                    MaxMessageTerminationConfig(
                        termination_type=TerminationTypes.MAX_MESSAGES,
                        max_messages=5,
                        component_type=ComponentTypes.TERMINATION,
                        version="1.0.0",
                    )
                ],
                component_type=ComponentTypes.TERMINATION,
                version="1.0.0",
            )
        )

    # Test missing operator
    with pytest.raises(ValueError):
        await component_factory.load_termination(
            CombinationTerminationConfig(
                termination_type=TerminationTypes.COMBINATION,
                conditions=[
                    MaxMessageTerminationConfig(
                        termination_type=TerminationTypes.MAX_MESSAGES,
                        max_messages=5,
                        component_type=ComponentTypes.TERMINATION,
                        version="1.0.0",
                    ),
                    TextMentionTerminationConfig(
                        termination_type=TerminationTypes.TEXT_MENTION,
                        text="DONE",
                        component_type=ComponentTypes.TERMINATION,
                        version="1.0.0",
                    ),
                ],
                component_type=ComponentTypes.TERMINATION,
                version="1.0.0",
            )
        )


@pytest.mark.asyncio
async def test_load_team(
    component_factory: ComponentFactory,
    sample_team_config: RoundRobinTeamConfig,
    sample_model_config: OpenAIModelConfig,
):
    # Test loading RoundRobinGroupChat team
    team = await component_factory.load_team(sample_team_config)
    assert isinstance(team, RoundRobinGroupChat)
    assert len(team._participants) == 1

    # Test loading SelectorGroupChat team with multiple participants
    selector_team_config = SelectorTeamConfig(
        name="selector_team",
        team_type=TeamTypes.SELECTOR,
        participants=[  # Add two participants
            sample_team_config.participants[0],  # First agent
            AssistantAgentConfig(  # Second agent
                name="test_agent_2",
                agent_type=AgentTypes.ASSISTANT,
                system_message="You are another helpful assistant",
                model_client=sample_model_config,
                tools=sample_team_config.participants[0].tools,
                component_type=ComponentTypes.AGENT,
                version="1.0.0",
            ),
        ],
        termination_condition=sample_team_config.termination_condition,
        model_client=sample_model_config,
        component_type=ComponentTypes.TEAM,
        version="1.0.0",
    )
    team = await component_factory.load_team(selector_team_config)
    assert isinstance(team, SelectorGroupChat)
    assert len(team._participants) == 2

    # Test loading MagenticOneGroupChat team
    magentic_one_config = MagenticOneTeamConfig(
        name="magentic_one_team",
        team_type=TeamTypes.MAGENTIC_ONE,
        participants=[  # Add two participants
            sample_team_config.participants[0],  # First agent
            AssistantAgentConfig(  # Second agent
                name="test_agent_2",
                agent_type=AgentTypes.ASSISTANT,
                system_message="You are another helpful assistant",
                model_client=sample_model_config,
                tools=sample_team_config.participants[0].tools,
                component_type=ComponentTypes.AGENT,
                max_turns=sample_team_config.max_turns,
                version="1.0.0",
            ),
        ],
        termination_condition=sample_team_config.termination_condition,
        model_client=sample_model_config,
        component_type=ComponentTypes.TEAM,
        version="1.0.0",
    )
    team = await component_factory.load_team(magentic_one_config)
    assert isinstance(team, MagenticOneGroupChat)
    assert len(team._participants) == 2


@pytest.mark.asyncio
async def test_invalid_configs(component_factory: ComponentFactory):
    # Test invalid agent type
    with pytest.raises(ValueError):
        await component_factory.load_agent(
            AssistantAgentConfig(
                name="test",
                agent_type="InvalidAgent",  # type: ignore
                system_message="test",
                component_type=ComponentTypes.AGENT,
                version="1.0.0",
            )
        )

    # Test invalid team type
    with pytest.raises(ValueError):
        await component_factory.load_team(
            RoundRobinTeamConfig(
                name="test",
                team_type="InvalidTeam",  # type: ignore
                participants=[],
                component_type=ComponentTypes.TEAM,
                version="1.0.0",
            )
        )

    # Test invalid termination type
    with pytest.raises(ValueError):
        await component_factory.load_termination(
            MaxMessageTerminationConfig(
                termination_type="InvalidTermination",  # type: ignore
                component_type=ComponentTypes.TERMINATION,
                version="1.0.0",
            )
        )