mirror of
				https://github.com/microsoft/autogen.git
				synced 2025-10-30 17:29:47 +00:00 
			
		
		
		
	 6a4a11042c
			
		
	
	
		6a4a11042c
		
			
		
	
	
	
	
		
			
			* add oai support, improve component config typing, minor updates to docs, update ags tests * faq updates * update faq, add model_capabilities * update faq
		
			
				
	
	
		
			246 lines
		
	
	
		
			8.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			246 lines
		
	
	
		
			8.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import os
 | |
| import asyncio
 | |
| import pytest
 | |
| from sqlmodel import Session, text, select
 | |
| from typing import Generator
 | |
| 
 | |
| from autogenstudio.database import DatabaseManager
 | |
| from autogenstudio.datamodel.types import (
 | |
|     ToolConfig,
 | |
|     OpenAIModelConfig,
 | |
|     RoundRobinTeamConfig,
 | |
|     StopMessageTerminationConfig,
 | |
|     AssistantAgentConfig,
 | |
|     ModelTypes, AgentTypes, TeamTypes, ComponentTypes,
 | |
|     TerminationTypes,  ToolTypes
 | |
| )
 | |
| from autogenstudio.datamodel.db import Model, Tool, Agent, Team, LinkTypes
 | |
| 
 | |
| 
 | |
@pytest.fixture
def test_db() -> Generator[DatabaseManager, None, None]:
    """Yield a freshly reset ``DatabaseManager`` backed by the SQLite file ``test.db``.

    Setup resets the database and initializes it with ``auto_upgrade=False``.
    Teardown closes the manager, resets the database again and finally removes
    the SQLite file on a best-effort basis.

    Yields:
        The initialized ``DatabaseManager`` for the test to use.
    """
    db_path = "test.db"
    db = DatabaseManager(f"sqlite:///{db_path}")
    db.reset_db()
    # Initialize database instead of create_db_and_tables
    db.initialize_database(auto_upgrade=False)
    yield db
    # Clean up
    try:
        asyncio.run(db.close())
        db.reset_db()
    finally:
        # Always attempt to delete the file, even when close/reset failed,
        # so a broken teardown does not leave a stale database behind.
        try:
            os.remove(db_path)
        except FileNotFoundError:
            # Nothing was left on disk; there is nothing to clean up.
            pass
        except OSError as e:
            # Removal is best-effort (e.g. the file may still be locked).
            print(f"Warning: Failed to remove test database file: {e}")
 | |
| 
 | |
| 
 | |
@pytest.fixture
def test_user() -> str:
    """Provide the user id that the sample entity fixtures attach to their rows."""
    user_id = "test_user@example.com"
    return user_id
 | |
| 
 | |
| 
 | |
@pytest.fixture
def sample_model(test_user: str) -> Model:
    """Build a ``Model`` row for ``test_user`` carrying a dumped GPT-4 OpenAI config."""
    model_config = OpenAIModelConfig(
        model="gpt-4",
        model_type=ModelTypes.OPENAI,
        component_type=ComponentTypes.MODEL,
        version="1.0.0",
    )
    # The row stores the config as a plain dict, not as the typed config object.
    return Model(user_id=test_user, config=model_config.model_dump())
 | |
| 
 | |
| 
 | |
@pytest.fixture
def sample_tool(test_user: str) -> Tool:
    """Build a ``Tool`` row for ``test_user`` wrapping a small Python-function tool config."""
    # Source code of the tool, kept as a string inside the config.
    tool_source = "async def test_func(x: str) -> str:\n    return f'Test {x}'"
    tool_config = ToolConfig(
        name="test_tool",
        description="A test tool",
        content=tool_source,
        tool_type=ToolTypes.PYTHON_FUNCTION,
        component_type=ComponentTypes.TOOL,
        version="1.0.0",
    )
    return Tool(user_id=test_user, config=tool_config.model_dump())
 | |
| 
 | |
| 
 | |
@pytest.fixture
def sample_agent(test_user: str, sample_model: Model, sample_tool: Tool) -> Agent:
    """Build an ``Agent`` row whose config embeds the sample model and sample tool configs."""
    # Re-hydrate the dict configs stored on the rows into typed config objects.
    model_client = OpenAIModelConfig.model_validate(sample_model.config)
    tool = ToolConfig.model_validate(sample_tool.config)

    agent_config = AssistantAgentConfig(
        name="test_agent",
        agent_type=AgentTypes.ASSISTANT,
        model_client=model_client,
        tools=[tool],
        component_type=ComponentTypes.AGENT,
        version="1.0.0",
    )
    return Agent(user_id=test_user, config=agent_config.model_dump())
 | |
| 
 | |
| 
 | |
@pytest.fixture
def sample_team(test_user: str, sample_agent: Agent) -> Team:
    """Build a round-robin ``Team`` row with the sample agent and a stop-message termination."""
    participant = AssistantAgentConfig.model_validate(sample_agent.config)
    # The termination condition is handed over in dumped (dict) form.
    termination = StopMessageTerminationConfig(
        termination_type=TerminationTypes.STOP_MESSAGE,
        component_type=ComponentTypes.TERMINATION,
        version="1.0.0",
    ).model_dump()

    team_config = RoundRobinTeamConfig(
        name="test_team",
        participants=[participant],
        termination_condition=termination,
        team_type=TeamTypes.ROUND_ROBIN,
        component_type=ComponentTypes.TEAM,
        version="1.0.0",
    )
    return Team(user_id=test_user, config=team_config.model_dump())
 | |
| 
 | |
| 
 | |
class TestDatabaseOperations:
    """Tests for ``DatabaseManager``: queries, entity creation, links, upsert, delete, init."""

    @staticmethod
    def _openai_model(model_name: str) -> Model:
        """Build a ``Model`` row for ``model_name`` with the literal user id ``"test_user"``."""
        config = OpenAIModelConfig(
            model=model_name,
            model_type=ModelTypes.OPENAI,
            component_type=ComponentTypes.MODEL,
            version="1.0.0",
        )
        return Model(user_id="test_user", config=config.model_dump())

    def test_basic_setup(self, test_db: DatabaseManager):
        """Check that a raw SQL statement and a SQLModel ``select`` both yield 1."""
        with Session(test_db.engine) as session:
            raw_row = session.exec(text("SELECT 1")).first()
            assert raw_row[0] == 1
            scalar = session.exec(select(1)).first()
            assert scalar == 1

    def test_basic_entity_creation(self, test_db: DatabaseManager, sample_model: Model,
                                   sample_tool: Tool, sample_agent: Agent, sample_team: Team):
        """Persist one entity of every type and confirm each can be fetched by id."""
        entities = (sample_model, sample_tool, sample_agent, sample_team)
        with Session(test_db.engine) as session:
            session.add_all(entities)
            session.commit()

            # Capture the primary keys before this session is closed.
            saved = [(type(entity), entity.id) for entity in entities]

        # A second session proves the rows were really written.
        with Session(test_db.engine) as session:
            for entity_cls, entity_id in saved:
                assert session.get(entity_cls, entity_id) is not None

    def test_multiple_links(self, test_db: DatabaseManager, sample_agent: Agent):
        """Link two models to one agent and confirm both come back as linked entities."""
        first_model = self._openai_model("gpt-4")
        second_model = self._openai_model("gpt-3.5")

        with Session(test_db.engine) as session:
            session.add_all([first_model, second_model, sample_agent])
            session.commit()

            model_ids = [first_model.id, second_model.id]
            agent_id = sample_agent.id

        # Links are created through the manager using plain ids.
        for model_id in model_ids:
            test_db.link(LinkTypes.AGENT_MODEL, agent_id, model_id)

        linked = test_db.get_linked_entities(LinkTypes.AGENT_MODEL, agent_id)
        assert len(linked.data) == 2

        linked_names = [entry.config["model"] for entry in linked.data]
        for expected_name in ("gpt-4", "gpt-3.5"):
            assert expected_name in linked_names

    def test_upsert_operations(self, test_db: DatabaseManager, sample_model: Model):
        """Upsert the same model twice: first as a create, then as an update."""
        created = test_db.upsert(sample_model)
        assert created.status is True
        assert "Created Successfully" in created.message

        # Change the config and upsert again to hit the update path.
        sample_model.config["model"] = "gpt-4-turbo"
        updated = test_db.upsert(sample_model)
        assert updated.status is True
        assert "Updated Successfully" in updated.message

        fetched = test_db.get(Model, {"id": sample_model.id})
        assert fetched.status is True
        assert fetched.data[0].config["model"] == "gpt-4-turbo"

    def test_delete_operations(self, test_db: DatabaseManager, sample_model: Model):
        """Delete a stored model by id, then try deleting an id that does not exist."""
        # The row has to exist before it can be deleted.
        test_db.upsert(sample_model)

        deleted = test_db.delete(Model, {"id": sample_model.id})
        assert deleted.status is True
        assert "Deleted Successfully" in deleted.message

        remaining = test_db.get(Model, {"id": sample_model.id})
        assert len(remaining.data) == 0

        missing = test_db.delete(Model, {"id": 999999})
        assert "Row not found" in missing.message

    def test_initialize_database_scenarios(self):
        """Call ``initialize_database`` with default arguments and with ``auto_upgrade=True``."""
        db_path = "test_init.db"
        db = DatabaseManager(f"sqlite:///{db_path}")

        try:
            # First pass: defaults; second pass: explicit auto_upgrade.
            for init_kwargs in ({}, {"auto_upgrade": True}):
                outcome = db.initialize_database(**init_kwargs)
                assert outcome.status is True
        finally:
            asyncio.run(db.close())
            db.reset_db()
            if os.path.exists(db_path):
                os.remove(db_path)
 |