mirror of
				https://github.com/microsoft/autogen.git
				synced 2025-10-30 17:29:47 +00:00 
			
		
		
		
	 6a4a11042c
			
		
	
	
		6a4a11042c
		
			
		
	
	
	
	
		
			
			* add oai support, improve component config typing, minor updates to docs, update ags tests * faq updates * update faq, add model_capabilities * update faq
		
			
				
	
	
		
			398 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			398 lines
		
	
	
		
			14 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import pytest
 | |
| from typing import List
 | |
| 
 | |
| from autogen_agentchat.agents import AssistantAgent
 | |
| from autogen_agentchat.teams import RoundRobinGroupChat, SelectorGroupChat, MagenticOneGroupChat
 | |
| from autogen_agentchat.conditions import MaxMessageTermination, StopMessageTermination, TextMentionTermination
 | |
| from autogen_core.tools import FunctionTool
 | |
| 
 | |
| from autogenstudio.datamodel.types import (
 | |
|     AssistantAgentConfig,
 | |
|     OpenAIModelConfig,
 | |
|     RoundRobinTeamConfig,
 | |
|     SelectorTeamConfig,
 | |
|     MagenticOneTeamConfig,
 | |
|     ToolConfig,
 | |
|     MaxMessageTerminationConfig,
 | |
|     StopMessageTerminationConfig,
 | |
|     TextMentionTerminationConfig,
 | |
|     CombinationTerminationConfig,
 | |
|     ModelTypes,
 | |
|     AgentTypes,
 | |
|     TeamTypes,
 | |
|     TerminationTypes,
 | |
|     ToolTypes,
 | |
|     ComponentTypes,
 | |
| )
 | |
| from autogenstudio.database import ComponentFactory
 | |
| 
 | |
| 
 | |
@pytest.fixture
def component_factory():
    """Provide a newly constructed ``ComponentFactory`` to the requesting test."""
    factory = ComponentFactory()
    return factory
 | |
| 
 | |
| 
 | |
@pytest.fixture
def sample_tool_config():
    """Build a Python-function ``ToolConfig`` wrapping a four-operation calculator.

    ``content`` carries the tool's source code as a string. ``test_load_tool``
    expects the factory to turn this config into a ``FunctionTool`` named
    ``calculator`` whose callable implements ``+``, ``-``, ``*`` and ``/``.
    """
    return ToolConfig(
        name="calculator",
        description="A simple calculator function",
        # The embedded function is annotated ``-> float`` because the '/'
        # branch performs true division, which always yields a float; an
        # ``-> int`` annotation would misdescribe the tool's return type.
        content="""
def calculator(a: int, b: int, operation: str = '+') -> float:
    '''
    A simple calculator that performs basic operations
    '''
    if operation == '+':
        return a + b
    elif operation == '-':
        return a - b
    elif operation == '*':
        return a * b
    elif operation == '/':
        return a / b
    else:
        raise ValueError("Invalid operation")
""",
        tool_type=ToolTypes.PYTHON_FUNCTION,
        component_type=ComponentTypes.TOOL,
        version="1.0.0",
    )
 | |
| 
 | |
| 
 | |
@pytest.fixture
def sample_model_config():
    """Return an OpenAI ``gpt-4`` model config carrying the placeholder key ``"test-key"``."""
    model_kwargs = dict(
        model_type=ModelTypes.OPENAI,
        model="gpt-4",
        api_key="test-key",
        component_type=ComponentTypes.MODEL,
        version="1.0.0",
    )
    return OpenAIModelConfig(**model_kwargs)
 | |
| 
 | |
| 
 | |
@pytest.fixture
def sample_agent_config(sample_model_config: OpenAIModelConfig, sample_tool_config: ToolConfig):
    """Return an assistant-agent config wired to the sample model and tool fixtures.

    The agent is named ``test_agent`` and is given exactly one tool;
    ``test_load_agent`` asserts on both of those facts.
    """
    agent_tools = [sample_tool_config]
    return AssistantAgentConfig(
        name="test_agent",
        agent_type=AgentTypes.ASSISTANT,
        system_message="You are a helpful assistant",
        model_client=sample_model_config,
        tools=agent_tools,
        component_type=ComponentTypes.AGENT,
        version="1.0.0",
    )
 | |
| 
 | |
| 
 | |
@pytest.fixture
def sample_termination_config():
    """Return a max-messages termination config with ``max_messages=10``."""
    message_limit = 10
    return MaxMessageTerminationConfig(
        max_messages=message_limit,
        termination_type=TerminationTypes.MAX_MESSAGES,
        version="1.0.0",
        component_type=ComponentTypes.TERMINATION,
    )
 | |
| 
 | |
| 
 | |
@pytest.fixture
def sample_team_config(
    sample_agent_config: AssistantAgentConfig,
    sample_termination_config: MaxMessageTerminationConfig,
    sample_model_config: OpenAIModelConfig,
):
    """Return a one-member round-robin team config assembled from the other fixtures.

    ``max_turns`` is fixed at 10. ``test_load_team`` reuses this config's
    first participant and termination condition when it builds the selector
    and MagenticOne team configs.
    """
    participants = [sample_agent_config]
    return RoundRobinTeamConfig(
        name="test_team",
        team_type=TeamTypes.ROUND_ROBIN,
        participants=participants,
        termination_condition=sample_termination_config,
        model_client=sample_model_config,
        max_turns=10,
        component_type=ComponentTypes.TEAM,
        version="1.0.0",
    )
 | |
| 
 | |
| 
 | |
@pytest.mark.asyncio
async def test_load_tool(component_factory: ComponentFactory, sample_tool_config: ToolConfig):
    """Load the sample calculator tool and exercise every branch of its function.

    Checks the loaded tool's type, name and description, then calls its
    underlying callable for each supported operator, the default operator,
    and an unsupported operator (which the tool source raises ``ValueError`` for).
    """
    tool = await component_factory.load_tool(sample_tool_config)
    assert isinstance(tool, FunctionTool)
    assert tool.name == "calculator"
    assert tool.description == "A simple calculator function"

    # ``_func`` is the callable compiled from the config's ``content``;
    # calling it directly checks each branch of the tool source.
    func = tool._func
    assert func(5, 3, "+") == 8
    assert func(5, 3) == 8  # default operation is '+'
    assert func(5, 3, "-") == 2
    assert func(5, 3, "*") == 15
    assert func(6, 3, "/") == 2
    with pytest.raises(ValueError):
        func(5, 3, "%")
 | |
| 
 | |
| 
 | |
@pytest.mark.asyncio
async def test_load_tool_invalid_config(component_factory: ComponentFactory):
    """Check that ``load_tool`` raises ``ValueError`` for unusable tool configs.

    Two configs are tried: one whose description and source are both empty,
    and one whose source is not valid Python.
    """

    def blank_tool() -> ToolConfig:
        # Invoked inside ``pytest.raises`` so that an error raised while
        # constructing the config is caught there as well.
        return ToolConfig(
            name="test",
            description="",
            content="",
            tool_type=ToolTypes.PYTHON_FUNCTION,
            component_type=ComponentTypes.TOOL,
            version="1.0.0",
        )

    with pytest.raises(ValueError):
        await component_factory.load_tool(blank_tool())

    broken_syntax_tool = ToolConfig(
        name="invalid",
        description="Invalid function",
        content="def invalid_func(): return invalid syntax",
        tool_type=ToolTypes.PYTHON_FUNCTION,
        component_type=ComponentTypes.TOOL,
        version="1.0.0",
    )
    with pytest.raises(ValueError):
        await component_factory.load_tool(broken_syntax_tool)
 | |
| 
 | |
| 
 | |
@pytest.mark.asyncio
async def test_load_model(component_factory: ComponentFactory, sample_model_config: OpenAIModelConfig):
    """``load_model`` must return a non-``None`` object for the sample OpenAI config."""
    loaded_client = await component_factory.load_model(sample_model_config)
    assert loaded_client is not None
 | |
| 
 | |
| 
 | |
@pytest.mark.asyncio
async def test_load_agent(component_factory: ComponentFactory, sample_agent_config: AssistantAgentConfig):
    """Load the sample agent config and verify the resulting ``AssistantAgent``."""
    agent = await component_factory.load_agent(sample_agent_config)
    assert isinstance(agent, AssistantAgent)

    # The fixture's name and its single configured tool must survive loading.
    expected_name, expected_tool_count = "test_agent", 1
    assert agent.name == expected_name
    assert len(agent._tools) == expected_tool_count
 | |
| 
 | |
| 
 | |
@pytest.mark.asyncio
async def test_load_termination(component_factory: ComponentFactory):
    """Exercise ``load_termination`` for each termination config kind.

    Covers max-messages, stop-message and text-mention terminations, AND/OR
    combinations of them, and three invalid combination configs (no
    conditions, unknown operator, missing operator), each of which must
    raise ``ValueError``. Every invalid case is otherwise well-formed so
    that exactly one defect is responsible for the error.
    """

    def max_messages_config(limit: int = 5) -> MaxMessageTerminationConfig:
        # Fresh max-messages config per call so no instance is shared.
        return MaxMessageTerminationConfig(
            termination_type=TerminationTypes.MAX_MESSAGES,
            max_messages=limit,
            component_type=ComponentTypes.TERMINATION,
            version="1.0.0",
        )

    def text_mention_config(text: str = "DONE") -> TextMentionTerminationConfig:
        # Fresh text-mention config per call so no instance is shared.
        return TextMentionTerminationConfig(
            termination_type=TerminationTypes.TEXT_MENTION,
            text=text,
            component_type=ComponentTypes.TERMINATION,
            version="1.0.0",
        )

    # MaxMessageTermination
    termination = await component_factory.load_termination(max_messages_config())
    assert isinstance(termination, MaxMessageTermination)
    assert termination._max_messages == 5

    # StopMessageTermination
    stop_msg_config = StopMessageTerminationConfig(
        termination_type=TerminationTypes.STOP_MESSAGE,
        component_type=ComponentTypes.TERMINATION,
        version="1.0.0",
    )
    termination = await component_factory.load_termination(stop_msg_config)
    assert isinstance(termination, StopMessageTermination)

    # TextMentionTermination
    termination = await component_factory.load_termination(text_mention_config())
    assert isinstance(termination, TextMentionTermination)
    assert termination._text == "DONE"

    # AND combination
    and_combo_config = CombinationTerminationConfig(
        termination_type=TerminationTypes.COMBINATION,
        operator="and",
        conditions=[max_messages_config(), text_mention_config()],
        component_type=ComponentTypes.TERMINATION,
        version="1.0.0",
    )
    termination = await component_factory.load_termination(and_combo_config)
    assert termination is not None

    # OR combination
    or_combo_config = CombinationTerminationConfig(
        termination_type=TerminationTypes.COMBINATION,
        operator="or",
        conditions=[max_messages_config(), text_mention_config()],
        component_type=ComponentTypes.TERMINATION,
        version="1.0.0",
    )
    termination = await component_factory.load_termination(or_combo_config)
    assert termination is not None

    # Invalid: empty conditions. A valid operator is supplied so the error
    # can only come from the empty ``conditions`` list, not from the missing
    # operator that is checked separately below.
    with pytest.raises(ValueError):
        await component_factory.load_termination(
            CombinationTerminationConfig(
                termination_type=TerminationTypes.COMBINATION,
                operator="and",
                conditions=[],
                component_type=ComponentTypes.TERMINATION,
                version="1.0.0",
            )
        )

    # Invalid: unknown operator. Two valid conditions are supplied so the
    # operator is the only defect in this config.
    with pytest.raises(ValueError):
        await component_factory.load_termination(
            CombinationTerminationConfig(
                termination_type=TerminationTypes.COMBINATION,
                operator="invalid",  # type: ignore
                conditions=[max_messages_config(), text_mention_config()],
                component_type=ComponentTypes.TERMINATION,
                version="1.0.0",
            )
        )

    # Invalid: missing operator with otherwise valid conditions.
    with pytest.raises(ValueError):
        await component_factory.load_termination(
            CombinationTerminationConfig(
                termination_type=TerminationTypes.COMBINATION,
                conditions=[max_messages_config(), text_mention_config()],
                component_type=ComponentTypes.TERMINATION,
                version="1.0.0",
            )
        )
 | |
| 
 | |
| 
 | |
@pytest.mark.asyncio
async def test_load_team(
    component_factory: ComponentFactory, sample_team_config: RoundRobinTeamConfig, sample_model_config: OpenAIModelConfig
):
    """Load round-robin, selector and MagenticOne teams from their configs.

    The selector and MagenticOne configs reuse the fixture team's first
    agent, termination condition and model config, and add a second
    assistant agent, so each loaded team must report two participants.
    """

    def second_agent_config() -> AssistantAgentConfig:
        # A second assistant sharing the first agent's tools; built fresh for
        # each team config that needs it.
        return AssistantAgentConfig(
            name="test_agent_2",
            agent_type=AgentTypes.ASSISTANT,
            system_message="You are another helpful assistant",
            model_client=sample_model_config,
            tools=sample_team_config.participants[0].tools,
            component_type=ComponentTypes.AGENT,
            version="1.0.0",
        )

    # RoundRobinGroupChat with the fixture's single participant.
    team = await component_factory.load_team(sample_team_config)
    assert isinstance(team, RoundRobinGroupChat)
    assert len(team._participants) == 1

    # SelectorGroupChat with two participants.
    selector_team_config = SelectorTeamConfig(
        name="selector_team",
        team_type=TeamTypes.SELECTOR,
        participants=[sample_team_config.participants[0], second_agent_config()],
        termination_condition=sample_team_config.termination_condition,
        model_client=sample_model_config,
        component_type=ComponentTypes.TEAM,
        version="1.0.0",
    )
    team = await component_factory.load_team(selector_team_config)
    assert isinstance(team, SelectorGroupChat)
    assert len(team._participants) == 2

    # MagenticOneGroupChat with two participants. ``max_turns`` is a team
    # setting (as in ``sample_team_config``), so it goes on the team config
    # rather than on an agent config.
    # NOTE(review): assumes MagenticOneTeamConfig declares ``max_turns``
    # (presumably via a shared team base config) — confirm in datamodel types.
    magentic_one_config = MagenticOneTeamConfig(
        name="magentic_one_team",
        team_type=TeamTypes.MAGENTIC_ONE,
        participants=[sample_team_config.participants[0], second_agent_config()],
        termination_condition=sample_team_config.termination_condition,
        model_client=sample_model_config,
        max_turns=sample_team_config.max_turns,
        component_type=ComponentTypes.TEAM,
        version="1.0.0",
    )
    team = await component_factory.load_team(magentic_one_config)
    assert isinstance(team, MagenticOneGroupChat)
    assert len(team._participants) == 2
 | |
| 
 | |
| 
 | |
@pytest.mark.asyncio
async def test_invalid_configs(component_factory: ComponentFactory):
    """Each loader must raise ``ValueError`` for an unknown type discriminator.

    NOTE(review): if these config classes are pydantic models, a
    ``ValidationError`` (which subclasses ``ValueError``) raised while
    *constructing* the config also satisfies ``pytest.raises(ValueError)``,
    so a case can pass without reaching the factory. The termination case
    therefore supplies ``max_messages`` so that a missing required field
    cannot be what triggers the error.
    """
    # Unknown agent type
    with pytest.raises(ValueError):
        await component_factory.load_agent(
            AssistantAgentConfig(
                name="test",
                agent_type="InvalidAgent",  # type: ignore
                system_message="test",
                component_type=ComponentTypes.AGENT,
                version="1.0.0",
            )
        )

    # Unknown team type
    with pytest.raises(ValueError):
        await component_factory.load_team(
            RoundRobinTeamConfig(
                name="test",
                team_type="InvalidTeam",  # type: ignore
                participants=[],
                component_type=ComponentTypes.TEAM,
                version="1.0.0",
            )
        )

    # Unknown termination type; ``max_messages`` is provided so the type
    # field is the only invalid input.
    with pytest.raises(ValueError):
        await component_factory.load_termination(
            MaxMessageTerminationConfig(
                termination_type="InvalidTermination",  # type: ignore
                max_messages=5,
                component_type=ComponentTypes.TERMINATION,
                version="1.0.0",
            )
        )
 |