mirror of
				https://github.com/microsoft/autogen.git
				synced 2025-10-20 12:29:37 +00:00 
			
		
		
		
	 fe96f7de24
			
		
	
	
		fe96f7de24
		
			
		
	
	
	
	
		
			
			* fix import issue related to agentchat update #4245 * update uv lock file * fix db auto_upgrade logic issue. * improve msg rendering issue * Support termination condition combination. Closes #4325 * fix db instantiation bug * update yarn.lock, closes #4260 #4262 * remove deps for now with vulnerabilities found by dependabot #4262 * update db tests * add ability to load sessions from db .. * format updates, add format checks to ags * format check fixes * linting and ruff check fixes * make tests for ags non-parallel to avoid db race conditions. * format updates * fix concurrency issue * minor ui tweaks, move run start to websocket * lint fixes * update uv.lock * Update python/packages/autogen-studio/autogenstudio/datamodel/types.py Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com> * Update python/packages/autogen-studio/autogenstudio/teammanager.py Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com> * reuse user proxy from agentchat * ui tweaks --------- Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com> Co-authored-by: Hussein Mozannar <hmozannar@microsoft.com>
		
			
				
	
	
		
			407 lines
		
	
	
		
			16 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			407 lines
		
	
	
		
			16 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| import threading
 | |
| from datetime import datetime
 | |
| from pathlib import Path
 | |
| from typing import Optional
 | |
| 
 | |
| from loguru import logger
 | |
| from sqlalchemy import exc, func, inspect, text
 | |
| from sqlmodel import Session, SQLModel, and_, create_engine, select
 | |
| 
 | |
| from ..datamodel import LinkTypes, Response
 | |
| from .schema_manager import SchemaManager
 | |
| 
 | |
| # from .dbutils import init_db_samples
 | |
| 
 | |
| 
 | |
class DatabaseManager:
    """Persistence helper around a SQLModel/SQLAlchemy engine.

    Each data-access method opens its own short-lived ``Session`` and reports the
    outcome as a ``Response`` (``message``, ``status``, ``data``) instead of raising.
    ``close`` is the exception: it logs and re-raises.
    """

    # Class attribute, so this lock is shared by *every* DatabaseManager instance in
    # the process: an initialize/reset on one manager makes the others' attempts fail fast.
    _init_lock = threading.Lock()

    def __init__(self, engine_uri: str, base_dir: Optional[Path] = None):
        """
        Initialize DatabaseManager with database connection settings.

        Creates the SQLAlchemy engine (which does not open a connection by itself)
        and a SchemaManager bound to it; no tables are created here.

        Args:
            engine_uri: Database connection URI (e.g. sqlite:///db.sqlite3)
            base_dir: Base directory for migration files, passed to SchemaManager.
                If None, SchemaManager's own default applies.
        """
        connection_args = {"check_same_thread": True} if "sqlite" in engine_uri else {}

        self.engine = create_engine(engine_uri, connect_args=connection_args)
        self.schema_manager = SchemaManager(
            engine=self.engine,
            base_dir=base_dir,
        )

    def initialize_database(self, auto_upgrade: bool = False, force_init_alembic: bool = True) -> Response:
        """
        Initialize database and migrations in the correct order.

        On an empty database all SQLModel tables are created and migrations are
        initialized. On a database that already has tables nothing is changed unless
        ``auto_upgrade`` is set.

        Args:
            auto_upgrade: If True and tables already exist, delegate to
                ``SchemaManager.ensure_schema_up_to_date()``.
            force_init_alembic: Forwarded as ``force`` to
                ``SchemaManager.initialize_migrations()`` on a fresh database.

        Returns:
            Response: ``status`` is False when another initialize/reset holds the lock,
            migration setup or upgrade fails, or an exception occurs (text in ``message``).
        """
        # Non-blocking acquire: a concurrent caller gets an immediate failure, not a wait.
        if not self._init_lock.acquire(blocking=False):
            return Response(message="Database initialization already in progress", status=False)

        try:
            table_names = inspect(self.engine).get_table_names()

            if not table_names:
                # Fresh install - create tables and initialize migrations
                logger.info("Creating database tables...")
                SQLModel.metadata.create_all(self.engine)

                if self.schema_manager.initialize_migrations(force=force_init_alembic):
                    return Response(message="Database initialized successfully", status=True)
                return Response(message="Failed to initialize migrations", status=False)

            # Existing database: only touch the schema when explicitly asked to.
            if auto_upgrade:
                logger.info("Checking database schema...")
                if self.schema_manager.ensure_schema_up_to_date():
                    return Response(message="Database schema is up to date", status=True)
                return Response(message="Database upgrade failed", status=False)

            return Response(message="Database is ready", status=True)

        except Exception as e:
            error_msg = f"Database initialization failed: {str(e)}"
            logger.error(error_msg)
            return Response(message=error_msg, status=False)
        finally:
            self._init_lock.release()

    def _drop_all_tables(self) -> None:
        """Drop every table registered on ``SQLModel.metadata`` using a single connection.

        SQLite's ``foreign_keys`` pragma is per-connection, so it is switched off on the
        same connection that runs ``drop_all`` and then restored to its previous value,
        leaving the pooled connection as it was found.

        Raises:
            Exception: whatever the engine/driver raises while dropping.
        """
        # Drop pooled connections before the reset (presumably so no idle connection
        # holds the database while tables are dropped — intent not visible here).
        self.engine.dispose()
        is_sqlite = self.engine.dialect.name == "sqlite"

        with self.engine.connect() as conn:
            previous_fk = None
            if is_sqlite:
                previous_fk = conn.execute(text("PRAGMA foreign_keys")).scalar()
                conn.execute(text("PRAGMA foreign_keys=OFF"))
            try:
                SQLModel.metadata.drop_all(conn)
                conn.commit()
            finally:
                if is_sqlite:
                    restored = "ON" if previous_fk else "OFF"
                    conn.execute(text(f"PRAGMA foreign_keys={restored}"))

        logger.info("All tables dropped successfully")

    def reset_db(self, recreate_tables: bool = True) -> Response:
        """
        Reset the database by dropping all tables and optionally recreating them.

        Args:
            recreate_tables (bool): If True, recreates the tables after dropping them by
                calling ``initialize_database(auto_upgrade=False, force_init_alembic=True)``.
                Set to False to only drop the tables.

        Returns:
            Response: ``status`` is False if another initialize/reset is in progress,
            dropping fails, or recreation fails; ``data`` is always None.
        """
        if not self._init_lock.acquire(blocking=False):
            logger.warning("Database reset already in progress")
            return Response(message="Database reset already in progress", status=False, data=None)

        try:
            self._drop_all_tables()
        except Exception as e:
            error_msg = f"Error while resetting database: {str(e)}"
            logger.error(error_msg)
            return Response(message=error_msg, status=False, data=None)
        finally:
            # Release exactly once, and before recreating: initialize_database()
            # acquires this same (non-reentrant) lock.
            self._init_lock.release()
            logger.info("Database reset lock released")

        if recreate_tables:
            logger.info("Recreating tables...")
            init_response = self.initialize_database(auto_upgrade=False, force_init_alembic=True)
            if not init_response.status:
                # Don't report success when the tables could not be recreated.
                error_msg = f"Error while resetting database: {init_response.message}"
                logger.error(error_msg)
                return Response(message=error_msg, status=False, data=None)

        return Response(
            message="Database reset successfully" if recreate_tables else "Database tables dropped successfully",
            status=True,
            data=None,
        )

    def upsert(self, model: SQLModel, return_json: bool = True) -> Response:
        """Create or update an entity.

        If a row of the same class with ``model.id`` exists, ``model.updated_at`` is set
        to ``datetime.now()`` and every field from ``model.model_dump()`` is copied onto
        the stored row; otherwise ``model`` is inserted.

        NOTE(review): the copy includes ``id`` and ``created_at``, so the stored
        ``created_at`` is replaced by the incoming value — confirm that is intended.

        Args:
            model (SQLModel): The model instance to create or update
            return_json (bool, optional): If True, returns the model as a dictionary.
                If False, returns the SQLModel instance. Defaults to True.

        Returns:
            Response: Contains status, message and data (dict or SQLModel based on
            return_json). On failure ``status`` is False, ``message`` describes the
            error and ``data`` is None.
        """
        model_class = type(model)
        class_name = model_class.__name__

        with Session(self.engine) as session:
            try:
                existing_model = session.exec(select(model_class).where(model_class.id == model.id)).first()
                if existing_model:
                    model.updated_at = datetime.now()
                    for key, value in model.model_dump().items():
                        setattr(existing_model, key, value)
                    model = existing_model  # Persist through the row loaded in this session
                session.add(model)
                session.commit()
                session.refresh(model)
                # Serialize while the session is open; after a failed commit the
                # rollback expires loaded rows and they can't be read once detached.
                data = model.model_dump() if return_json else model
            except Exception as e:
                session.rollback()
                error_msg = f"Error while updating/creating {class_name}: {str(e)}"
                logger.error(error_msg)
                return Response(message=error_msg, status=False, data=None)

        return Response(
            message=f"{class_name} Updated Successfully" if existing_model else f"{class_name} Created Successfully",
            status=True,
            data=data,
        )

    def _model_to_dict(self, model_obj):
        """Return ``{column_name: value}`` for every column of ``model_obj``'s table."""
        return {col.name: getattr(model_obj, col.name) for col in model_obj.__table__.columns}

    def get(
        self,
        model_class: SQLModel,
        filters: Optional[dict] = None,
        return_json: bool = False,
        order: str = "desc",
    ) -> Response:
        """List entities.

        Args:
            model_class: SQLModel table class to query.
            filters: Mapping of column name -> value; every pair is ANDed as ``==``.
            return_json: If True, each row is returned as a dict of its table columns.
            order: Name of the ordering method applied to ``created_at`` (e.g. "asc",
                "desc"). Ignored if the model has no ``created_at`` or ``order`` is falsy.

        Returns:
            Response: ``data`` is the list of rows (empty list on error).
        """
        with Session(self.engine) as session:
            result = []
            status = True
            status_message = ""

            try:
                statement = select(model_class)
                if filters:
                    conditions = [getattr(model_class, col) == value for col, value in filters.items()]
                    statement = statement.where(and_(*conditions))

                if hasattr(model_class, "created_at") and order:
                    order_by_clause = getattr(model_class.created_at, order)()  # Dynamically apply asc/desc
                    statement = statement.order_by(order_by_clause)

                items = session.exec(statement).all()
                result = [self._model_to_dict(item) if return_json else item for item in items]
                status_message = f"{model_class.__name__} Retrieved Successfully"
            except Exception as e:
                session.rollback()
                status = False
                status_message = f"Error while fetching {model_class.__name__}"
                logger.error("Error while getting items: " + str(model_class.__name__) + " " + str(e))

            return Response(message=status_message, status=status, data=result)

    def delete(self, model_class: SQLModel, filters: Optional[dict] = None) -> Response:
        """Delete every row of ``model_class`` matching ``filters``.

        Args:
            model_class: SQLModel table class to delete from.
            filters: Mapping of column name -> value, ANDed as ``==``. When empty or
                None there is no WHERE clause, so EVERY row of the table is deleted.

        Returns:
            Response: ``status`` is True (message "Row not found") when nothing matched;
            False on an integrity or other database error. ``data`` is always None.
        """
        status_message = ""
        status = True

        with Session(self.engine) as session:
            try:
                statement = select(model_class)
                if filters:
                    conditions = [getattr(model_class, col) == value for col, value in filters.items()]
                    statement = statement.where(and_(*conditions))

                rows = session.exec(statement).all()

                if rows:
                    for row in rows:
                        session.delete(row)
                    session.commit()
                    status_message = f"{model_class.__name__} Deleted Successfully"
                else:
                    status_message = "Row not found"
                    logger.info(f"Row with filters {filters} not found")

            except exc.IntegrityError as e:
                session.rollback()
                status = False
                status_message = f"Integrity error: The {model_class.__name__} is linked to another entity and cannot be deleted. {e}"
                # Log the specific integrity error
                logger.error(status_message)
            except Exception as e:
                session.rollback()
                status = False
                status_message = f"Error while deleting: {e}"
                logger.error(status_message)

        return Response(message=status_message, status=status, data=None)

    def link(
        self,
        link_type: LinkTypes,
        primary_id: int,
        secondary_id: int,
        sequence: Optional[int] = None,
    ) -> Response:
        """Link two entities with automatic sequence handling.

        Link-table column names are derived from the class names, e.g. a primary class
        ``Team`` maps to ``team_id``.

        Args:
            link_type: Supplies ``primary_class``, ``secondary_class`` and ``link_table``.
            primary_id: Id of the primary entity.
            secondary_id: Id of the secondary entity.
            sequence: Position of the link. When None it is one past the current maximum
                sequence for ``primary_id`` (0 if the primary has no links yet).

        Returns:
            Response: False if either entity is missing, the link already exists, or a
            database error occurs.
        """
        with Session(self.engine) as session:
            try:
                primary_class = link_type.primary_class
                secondary_class = link_type.secondary_class
                link_table = link_type.link_table

                primary_entity = session.get(primary_class, primary_id)
                secondary_entity = session.get(secondary_class, secondary_id)

                if not primary_entity or not secondary_entity:
                    return Response(message="One or both entities do not exist", status=False)

                primary_id_field = f"{primary_class.__name__.lower()}_id"
                secondary_id_field = f"{secondary_class.__name__.lower()}_id"

                existing_link = session.exec(
                    select(link_table).where(
                        and_(
                            getattr(link_table, primary_id_field) == primary_id,
                            getattr(link_table, secondary_id_field) == secondary_id,
                        )
                    )
                ).first()

                if existing_link:
                    return Response(message="Link already exists", status=False)

                if sequence is None:
                    max_seq_result = session.exec(
                        select(func.max(link_table.sequence)).where(getattr(link_table, primary_id_field) == primary_id)
                    ).first()
                    sequence = 0 if max_seq_result is None else max_seq_result + 1

                new_link = link_table(
                    **{primary_id_field: primary_id, secondary_id_field: secondary_id, "sequence": sequence}
                )
                session.add(new_link)
                session.commit()

                return Response(message=f"Entities linked successfully with sequence {sequence}", status=True)

            except Exception as e:
                session.rollback()
                logger.error(f"Error linking entities: {str(e)}")
                return Response(message=f"Error linking entities: {str(e)}", status=False)

    def unlink(
        self, link_type: LinkTypes, primary_id: int, secondary_id: int, sequence: Optional[int] = None
    ) -> Response:
        """Unlink two entities and close the resulting gap in sequence numbers.

        Args:
            link_type: Supplies ``primary_class``, ``secondary_class`` and ``link_table``.
            primary_id: Id of the primary entity.
            secondary_id: Id of the secondary entity.
            sequence: If given, only the link at this sequence is matched.

        Returns:
            Response: False if no matching link exists or a database error occurs.
            On success, links of the same primary with a higher sequence are
            decremented by one.
        """
        with Session(self.engine) as session:
            try:
                primary_class = link_type.primary_class
                secondary_class = link_type.secondary_class
                link_table = link_type.link_table

                primary_id_field = f"{primary_class.__name__.lower()}_id"
                secondary_id_field = f"{secondary_class.__name__.lower()}_id"

                statement = select(link_table).where(
                    and_(
                        getattr(link_table, primary_id_field) == primary_id,
                        getattr(link_table, secondary_id_field) == secondary_id,
                    )
                )

                if sequence is not None:
                    statement = statement.where(link_table.sequence == sequence)

                existing_link = session.exec(statement).first()

                if not existing_link:
                    return Response(message="Link does not exist", status=False)

                deleted_sequence = existing_link.sequence
                session.delete(existing_link)

                # Shift every later link of this primary down by one to fill the gap.
                remaining_links = session.exec(
                    select(link_table)
                    .where(getattr(link_table, primary_id_field) == primary_id)
                    .where(link_table.sequence > deleted_sequence)
                    .order_by(link_table.sequence)
                ).all()

                for link in remaining_links:
                    link.sequence -= 1

                session.commit()

                return Response(message="Entities unlinked successfully and sequences reordered", status=True)

            except Exception as e:
                session.rollback()
                logger.error(f"Error unlinking entities: {str(e)}")
                return Response(message=f"Error unlinking entities: {str(e)}", status=False)

    def get_linked_entities(
        self,
        link_type: LinkTypes,
        primary_id: int,
        return_json: bool = False,
    ) -> Response:
        """Get the secondary entities linked to ``primary_id``, ordered by link sequence.

        Args:
            link_type: Supplies ``primary_class``, ``secondary_class`` and ``link_table``.
            primary_id: Id of the primary entity.
            return_json: If True, each entity is returned via ``model_dump()``.

        Returns:
            Response: ``data`` is the ordered list (empty list on error).
        """
        with Session(self.engine) as session:
            try:
                primary_class = link_type.primary_class
                secondary_class = link_type.secondary_class
                link_table = link_type.link_table

                primary_id_field = f"{primary_class.__name__.lower()}_id"
                secondary_id_field = f"{secondary_class.__name__.lower()}_id"

                items = session.exec(
                    select(secondary_class)
                    .join(link_table, getattr(link_table, secondary_id_field) == secondary_class.id)
                    .where(getattr(link_table, primary_id_field) == primary_id)
                    .order_by(link_table.sequence)
                ).all()

                result = [item.model_dump() if return_json else item for item in items]

                return Response(message="Linked entities retrieved successfully", status=True, data=result)

            except Exception as e:
                logger.error(f"Error getting linked entities: {str(e)}")
                return Response(message=f"Error getting linked entities: {str(e)}", status=False, data=[])

    async def close(self):
        """Dispose of the engine's connection pool.

        Raises:
            Exception: re-raised after logging if ``engine.dispose()`` fails.
        """
        logger.info("Closing database connections...")
        try:
            self.engine.dispose()
            logger.info("Database connections closed successfully")
        except Exception as e:
            logger.error(f"Error closing database connections: {str(e)}")
            raise
 |