| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  | import threading | 
					
						
							|  |  |  | from datetime import datetime | 
					
						
							| 
									
										
										
										
											2024-11-26 15:39:36 -08:00
										 |  |  | from pathlib import Path | 
					
						
							| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  | from typing import Optional | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | from loguru import logger | 
					
						
							| 
									
										
										
										
											2024-11-26 15:39:36 -08:00
										 |  |  | from sqlalchemy import exc, func, inspect, text | 
					
						
							| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  | from sqlmodel import Session, SQLModel, and_, create_engine, select | 
					
						
							| 
									
										
										
										
											2024-11-26 15:39:36 -08:00
										 |  |  | 
 | 
					
						
							|  |  |  | from ..datamodel import LinkTypes, Response | 
					
						
							| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  | from .schema_manager import SchemaManager | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | # from .dbutils import init_db_samples | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | class DatabaseManager: | 
					
						
							|  |  |  |     _init_lock = threading.Lock() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-11-26 15:39:36 -08:00
										 |  |  |     def __init__(self, engine_uri: str, base_dir: Optional[Path] = None): | 
					
						
							| 
									
										
										
										
											2024-11-15 14:51:57 -08:00
										 |  |  |         """
 | 
					
						
							| 
									
										
										
										
											2024-11-26 15:39:36 -08:00
										 |  |  |         Initialize DatabaseManager with database connection settings. | 
					
						
							|  |  |  |         Does not perform any database operations. | 
					
						
							| 
									
										
										
										
											2024-11-15 14:51:57 -08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         Args: | 
					
						
							| 
									
										
										
										
											2024-11-26 15:39:36 -08:00
										 |  |  |             engine_uri: Database connection URI (e.g. sqlite:///db.sqlite3) | 
					
						
							|  |  |  |             base_dir: Base directory for migration files. If None, uses current directory | 
					
						
							| 
									
										
										
										
											2024-11-15 14:51:57 -08:00
										 |  |  |         """
 | 
					
						
							| 
									
										
										
										
											2024-11-26 15:39:36 -08:00
										 |  |  |         connection_args = {"check_same_thread": True} if "sqlite" in engine_uri else {} | 
					
						
							| 
									
										
										
										
											2024-11-15 14:51:57 -08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  |         self.engine = create_engine(engine_uri, connect_args=connection_args) | 
					
						
							|  |  |  |         self.schema_manager = SchemaManager( | 
					
						
							|  |  |  |             engine=self.engine, | 
					
						
							| 
									
										
										
										
											2024-11-15 14:51:57 -08:00
										 |  |  |             base_dir=base_dir, | 
					
						
							| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  |         ) | 
					
						
							| 
									
										
										
										
											2024-11-26 15:39:36 -08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def initialize_database(self, auto_upgrade: bool = False, force_init_alembic: bool = True) -> Response: | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         Initialize database and migrations in the correct order. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         Args: | 
					
						
							|  |  |  |             auto_upgrade: If True, automatically generate and apply migrations for schema changes | 
					
						
							|  |  |  |             force_init_alembic: If True, reinitialize alembic configuration even if it exists | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         if not self._init_lock.acquire(blocking=False): | 
					
						
							|  |  |  |             return Response(message="Database initialization already in progress", status=False) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         try: | 
					
						
							|  |  |  |             inspector = inspect(self.engine) | 
					
						
							|  |  |  |             tables_exist = inspector.get_table_names() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             if not tables_exist: | 
					
						
							|  |  |  |                 # Fresh install - create tables and initialize migrations | 
					
						
							|  |  |  |                 logger.info("Creating database tables...") | 
					
						
							|  |  |  |                 SQLModel.metadata.create_all(self.engine) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 if self.schema_manager.initialize_migrations(force=force_init_alembic): | 
					
						
							|  |  |  |                     return Response(message="Database initialized successfully", status=True) | 
					
						
							|  |  |  |                 return Response(message="Failed to initialize migrations", status=False) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             # Handle existing database | 
					
						
							|  |  |  |             if auto_upgrade: | 
					
						
							|  |  |  |                 logger.info("Checking database schema...") | 
					
						
							|  |  |  |                 if self.schema_manager.ensure_schema_up_to_date():  # <-- Use this instead | 
					
						
							|  |  |  |                     return Response(message="Database schema is up to date", status=True) | 
					
						
							|  |  |  |                 return Response(message="Database upgrade failed", status=False) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             return Response(message="Database is ready", status=True) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         except Exception as e: | 
					
						
							|  |  |  |             error_msg = f"Database initialization failed: {str(e)}" | 
					
						
							|  |  |  |             logger.error(error_msg) | 
					
						
							|  |  |  |             return Response(message=error_msg, status=False) | 
					
						
							|  |  |  |         finally: | 
					
						
							|  |  |  |             self._init_lock.release() | 
					
						
							| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  | 
 | 
					
						
							|  |  |  |     def reset_db(self, recreate_tables: bool = True): | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         Reset the database by dropping all tables and optionally recreating them. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         Args: | 
					
						
							|  |  |  |             recreate_tables (bool): If True, recreates the tables after dropping them. | 
					
						
							|  |  |  |                                 Set to False if you want to call create_db_and_tables() separately. | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         if not self._init_lock.acquire(blocking=False): | 
					
						
							|  |  |  |             logger.warning("Database reset already in progress") | 
					
						
							| 
									
										
										
										
											2024-11-26 15:39:36 -08:00
										 |  |  |             return Response(message="Database reset already in progress", status=False, data=None) | 
					
						
							| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  | 
 | 
					
						
							|  |  |  |         try: | 
					
						
							|  |  |  |             # Dispose existing connections | 
					
						
							|  |  |  |             self.engine.dispose() | 
					
						
							|  |  |  |             with Session(self.engine) as session: | 
					
						
							|  |  |  |                 try: | 
					
						
							|  |  |  |                     # Disable foreign key checks for SQLite | 
					
						
							| 
									
										
										
										
											2024-11-26 15:39:36 -08:00
										 |  |  |                     if "sqlite" in str(self.engine.url): | 
					
						
							|  |  |  |                         session.exec(text("PRAGMA foreign_keys=OFF")) | 
					
						
							| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  | 
 | 
					
						
							|  |  |  |                     # Drop all tables | 
					
						
							|  |  |  |                     SQLModel.metadata.drop_all(self.engine) | 
					
						
							|  |  |  |                     logger.info("All tables dropped successfully") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                     # Re-enable foreign key checks for SQLite | 
					
						
							| 
									
										
										
										
											2024-11-26 15:39:36 -08:00
										 |  |  |                     if "sqlite" in str(self.engine.url): | 
					
						
							|  |  |  |                         session.exec(text("PRAGMA foreign_keys=ON")) | 
					
						
							| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  | 
 | 
					
						
							|  |  |  |                     session.commit() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 except Exception as e: | 
					
						
							|  |  |  |                     session.rollback() | 
					
						
							|  |  |  |                     raise e | 
					
						
							|  |  |  |                 finally: | 
					
						
							|  |  |  |                     session.close() | 
					
						
							|  |  |  |                     self._init_lock.release() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             if recreate_tables: | 
					
						
							|  |  |  |                 logger.info("Recreating tables...") | 
					
						
							| 
									
										
										
										
											2024-11-26 15:39:36 -08:00
										 |  |  |                 self.initialize_database(auto_upgrade=False, force_init_alembic=True) | 
					
						
							| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  | 
 | 
					
						
							|  |  |  |             return Response( | 
					
						
							|  |  |  |                 message="Database reset successfully" if recreate_tables else "Database tables dropped successfully", | 
					
						
							|  |  |  |                 status=True, | 
					
						
							| 
									
										
										
										
											2024-11-26 15:39:36 -08:00
										 |  |  |                 data=None, | 
					
						
							| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  |             ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         except Exception as e: | 
					
						
							|  |  |  |             error_msg = f"Error while resetting database: {str(e)}" | 
					
						
							|  |  |  |             logger.error(error_msg) | 
					
						
							| 
									
										
										
										
											2024-11-26 15:39:36 -08:00
										 |  |  |             return Response(message=error_msg, status=False, data=None) | 
					
						
							| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  |         finally: | 
					
						
							|  |  |  |             if self._init_lock.locked(): | 
					
						
							|  |  |  |                 self._init_lock.release() | 
					
						
							|  |  |  |                 logger.info("Database reset lock released") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def upsert(self, model: SQLModel, return_json: bool = True): | 
					
						
							|  |  |  |         """Create or update an entity
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         Args: | 
					
						
							|  |  |  |             model (SQLModel): The model instance to create or update | 
					
						
							| 
									
										
										
										
											2024-11-26 15:39:36 -08:00
										 |  |  |             return_json (bool, optional): If True, returns the model as a dictionary. | 
					
						
							| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  |                 If False, returns the SQLModel instance. Defaults to True. | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         Returns: | 
					
						
							|  |  |  |             Response: Contains status, message and data (either dict or SQLModel based on return_json) | 
					
						
							|  |  |  |         """
 | 
					
						
							|  |  |  |         status = True | 
					
						
							|  |  |  |         model_class = type(model) | 
					
						
							|  |  |  |         existing_model = None | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         with Session(self.engine) as session: | 
					
						
							|  |  |  |             try: | 
					
						
							| 
									
										
										
										
											2024-11-26 15:39:36 -08:00
										 |  |  |                 existing_model = session.exec(select(model_class).where(model_class.id == model.id)).first() | 
					
						
							| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  |                 if existing_model: | 
					
						
							|  |  |  |                     model.updated_at = datetime.now() | 
					
						
							|  |  |  |                     for key, value in model.model_dump().items(): | 
					
						
							|  |  |  |                         setattr(existing_model, key, value) | 
					
						
							|  |  |  |                     model = existing_model  # Use the updated existing model | 
					
						
							|  |  |  |                     session.add(model) | 
					
						
							|  |  |  |                 else: | 
					
						
							|  |  |  |                     session.add(model) | 
					
						
							|  |  |  |                 session.commit() | 
					
						
							|  |  |  |                 session.refresh(model) | 
					
						
							|  |  |  |             except Exception as e: | 
					
						
							|  |  |  |                 session.rollback() | 
					
						
							| 
									
										
										
										
											2024-11-26 15:39:36 -08:00
										 |  |  |                 logger.error("Error while updating/creating " + str(model_class.__name__) + ": " + str(e)) | 
					
						
							| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  |                 status = False | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return Response( | 
					
						
							|  |  |  |             message=( | 
					
						
							|  |  |  |                 f"{model_class.__name__} Updated Successfully" | 
					
						
							|  |  |  |                 if existing_model | 
					
						
							|  |  |  |                 else f"{model_class.__name__} Created Successfully" | 
					
						
							|  |  |  |             ), | 
					
						
							|  |  |  |             status=status, | 
					
						
							|  |  |  |             data=model.model_dump() if return_json else model, | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def _model_to_dict(self, model_obj): | 
					
						
							|  |  |  |         return {col.name: getattr(model_obj, col.name) for col in model_obj.__table__.columns} | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def get( | 
					
						
							|  |  |  |         self, | 
					
						
							|  |  |  |         model_class: SQLModel, | 
					
						
							|  |  |  |         filters: dict = None, | 
					
						
							|  |  |  |         return_json: bool = False, | 
					
						
							|  |  |  |         order: str = "desc", | 
					
						
							|  |  |  |     ): | 
					
						
							|  |  |  |         """List entities""" | 
					
						
							|  |  |  |         with Session(self.engine) as session: | 
					
						
							|  |  |  |             result = [] | 
					
						
							|  |  |  |             status = True | 
					
						
							|  |  |  |             status_message = "" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             try: | 
					
						
							|  |  |  |                 statement = select(model_class) | 
					
						
							|  |  |  |                 if filters: | 
					
						
							| 
									
										
										
										
											2024-11-26 15:39:36 -08:00
										 |  |  |                     conditions = [getattr(model_class, col) == value for col, value in filters.items()] | 
					
						
							| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  |                     statement = statement.where(and_(*conditions)) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 if hasattr(model_class, "created_at") and order: | 
					
						
							| 
									
										
										
										
											2024-11-26 15:39:36 -08:00
										 |  |  |                     order_by_clause = getattr(model_class.created_at, order)()  # Dynamically apply asc/desc | 
					
						
							| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  |                     statement = statement.order_by(order_by_clause) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 items = session.exec(statement).all() | 
					
						
							| 
									
										
										
										
											2024-11-26 15:39:36 -08:00
										 |  |  |                 result = [self._model_to_dict(item) if return_json else item for item in items] | 
					
						
							| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  |                 status_message = f"{model_class.__name__} Retrieved Successfully" | 
					
						
							|  |  |  |             except Exception as e: | 
					
						
							|  |  |  |                 session.rollback() | 
					
						
							|  |  |  |                 status = False | 
					
						
							|  |  |  |                 status_message = f"Error while fetching {model_class.__name__}" | 
					
						
							| 
									
										
										
										
											2024-11-26 15:39:36 -08:00
										 |  |  |                 logger.error("Error while getting items: " + str(model_class.__name__) + " " + str(e)) | 
					
						
							| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  | 
 | 
					
						
							|  |  |  |             return Response(message=status_message, status=status, data=result) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def delete(self, model_class: SQLModel, filters: dict = None): | 
					
						
							|  |  |  |         """Delete an entity""" | 
					
						
							|  |  |  |         status_message = "" | 
					
						
							|  |  |  |         status = True | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         with Session(self.engine) as session: | 
					
						
							|  |  |  |             try: | 
					
						
							|  |  |  |                 statement = select(model_class) | 
					
						
							|  |  |  |                 if filters: | 
					
						
							| 
									
										
										
										
											2024-11-26 15:39:36 -08:00
										 |  |  |                     conditions = [getattr(model_class, col) == value for col, value in filters.items()] | 
					
						
							| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  |                     statement = statement.where(and_(*conditions)) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 rows = session.exec(statement).all() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 if rows: | 
					
						
							|  |  |  |                     for row in rows: | 
					
						
							|  |  |  |                         session.delete(row) | 
					
						
							|  |  |  |                     session.commit() | 
					
						
							|  |  |  |                     status_message = f"{model_class.__name__} Deleted Successfully" | 
					
						
							|  |  |  |                 else: | 
					
						
							|  |  |  |                     status_message = "Row not found" | 
					
						
							|  |  |  |                     logger.info(f"Row with filters {filters} not found") | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             except exc.IntegrityError as e: | 
					
						
							|  |  |  |                 session.rollback() | 
					
						
							|  |  |  |                 status = False | 
					
						
							|  |  |  |                 status_message = f"Integrity error: The {model_class.__name__} is linked to another entity and cannot be deleted. {e}" | 
					
						
							|  |  |  |                 # Log the specific integrity error | 
					
						
							|  |  |  |                 logger.error(status_message) | 
					
						
							|  |  |  |             except Exception as e: | 
					
						
							|  |  |  |                 session.rollback() | 
					
						
							|  |  |  |                 status = False | 
					
						
							|  |  |  |                 status_message = f"Error while deleting: {e}" | 
					
						
							|  |  |  |                 logger.error(status_message) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         return Response(message=status_message, status=status, data=None) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def link( | 
					
						
							|  |  |  |         self, | 
					
						
							|  |  |  |         link_type: LinkTypes, | 
					
						
							|  |  |  |         primary_id: int, | 
					
						
							|  |  |  |         secondary_id: int, | 
					
						
							|  |  |  |         sequence: Optional[int] = None, | 
					
						
							|  |  |  |     ): | 
					
						
							|  |  |  |         """Link two entities with automatic sequence handling.""" | 
					
						
							|  |  |  |         with Session(self.engine) as session: | 
					
						
							|  |  |  |             try: | 
					
						
							|  |  |  |                 # Get classes from LinkTypes | 
					
						
							|  |  |  |                 primary_class = link_type.primary_class | 
					
						
							|  |  |  |                 secondary_class = link_type.secondary_class | 
					
						
							|  |  |  |                 link_table = link_type.link_table | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 # Get entities | 
					
						
							|  |  |  |                 primary_entity = session.get(primary_class, primary_id) | 
					
						
							|  |  |  |                 secondary_entity = session.get(secondary_class, secondary_id) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 if not primary_entity or not secondary_entity: | 
					
						
							|  |  |  |                     return Response(message="One or both entities do not exist", status=False) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 # Get field names | 
					
						
							|  |  |  |                 primary_id_field = f"{primary_class.__name__.lower()}_id" | 
					
						
							|  |  |  |                 secondary_id_field = f"{secondary_class.__name__.lower()}_id" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 # Check for existing link | 
					
						
							|  |  |  |                 existing_link = session.exec( | 
					
						
							|  |  |  |                     select(link_table).where( | 
					
						
							|  |  |  |                         and_( | 
					
						
							|  |  |  |                             getattr(link_table, primary_id_field) == primary_id, | 
					
						
							| 
									
										
										
										
											2024-11-26 15:39:36 -08:00
										 |  |  |                             getattr(link_table, secondary_id_field) == secondary_id, | 
					
						
							| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  |                         ) | 
					
						
							|  |  |  |                     ) | 
					
						
							|  |  |  |                 ).first() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 if existing_link: | 
					
						
							|  |  |  |                     return Response(message="Link already exists", status=False) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 # Get the next sequence number if not provided | 
					
						
							|  |  |  |                 if sequence is None: | 
					
						
							|  |  |  |                     max_seq_result = session.exec( | 
					
						
							| 
									
										
										
										
											2024-11-26 15:39:36 -08:00
										 |  |  |                         select(func.max(link_table.sequence)).where(getattr(link_table, primary_id_field) == primary_id) | 
					
						
							| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  |                     ).first() | 
					
						
							|  |  |  |                     sequence = 0 if max_seq_result is None else max_seq_result + 1 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 # Create new link | 
					
						
							| 
									
										
										
										
											2024-11-26 15:39:36 -08:00
										 |  |  |                 new_link = link_table( | 
					
						
							|  |  |  |                     **{primary_id_field: primary_id, secondary_id_field: secondary_id, "sequence": sequence} | 
					
						
							|  |  |  |                 ) | 
					
						
							| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  |                 session.add(new_link) | 
					
						
							|  |  |  |                 session.commit() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-11-26 15:39:36 -08:00
										 |  |  |                 return Response(message=f"Entities linked successfully with sequence {sequence}", status=True) | 
					
						
							| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  | 
 | 
					
						
							|  |  |  |             except Exception as e: | 
					
						
							|  |  |  |                 session.rollback() | 
					
						
							|  |  |  |                 return Response(message=f"Error linking entities: {str(e)}", status=False) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-11-26 15:39:36 -08:00
										 |  |  |     def unlink(self, link_type: LinkTypes, primary_id: int, secondary_id: int, sequence: Optional[int] = None): | 
					
						
							| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  |         """Unlink two entities and reorder sequences if needed.""" | 
					
						
							|  |  |  |         with Session(self.engine) as session: | 
					
						
							|  |  |  |             try: | 
					
						
							|  |  |  |                 # Get classes from LinkTypes | 
					
						
							|  |  |  |                 primary_class = link_type.primary_class | 
					
						
							|  |  |  |                 secondary_class = link_type.secondary_class | 
					
						
							|  |  |  |                 link_table = link_type.link_table | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 # Get field names | 
					
						
							|  |  |  |                 primary_id_field = f"{primary_class.__name__.lower()}_id" | 
					
						
							|  |  |  |                 secondary_id_field = f"{secondary_class.__name__.lower()}_id" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 # Find existing link | 
					
						
							|  |  |  |                 statement = select(link_table).where( | 
					
						
							|  |  |  |                     and_( | 
					
						
							|  |  |  |                         getattr(link_table, primary_id_field) == primary_id, | 
					
						
							| 
									
										
										
										
											2024-11-26 15:39:36 -08:00
										 |  |  |                         getattr(link_table, secondary_id_field) == secondary_id, | 
					
						
							| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  |                     ) | 
					
						
							|  |  |  |                 ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 if sequence is not None: | 
					
						
							| 
									
										
										
										
											2024-11-26 15:39:36 -08:00
										 |  |  |                     statement = statement.where(link_table.sequence == sequence) | 
					
						
							| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  | 
 | 
					
						
							|  |  |  |                 existing_link = session.exec(statement).first() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 if not existing_link: | 
					
						
							|  |  |  |                     return Response(message="Link does not exist", status=False) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 deleted_sequence = existing_link.sequence | 
					
						
							|  |  |  |                 session.delete(existing_link) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 # Reorder sequences for remaining links | 
					
						
							|  |  |  |                 remaining_links = session.exec( | 
					
						
							|  |  |  |                     select(link_table) | 
					
						
							|  |  |  |                     .where(getattr(link_table, primary_id_field) == primary_id) | 
					
						
							|  |  |  |                     .where(link_table.sequence > deleted_sequence) | 
					
						
							|  |  |  |                     .order_by(link_table.sequence) | 
					
						
							|  |  |  |                 ).all() | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 # Decrease sequence numbers to fill the gap | 
					
						
							|  |  |  |                 for link in remaining_links: | 
					
						
							|  |  |  |                     link.sequence -= 1 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 session.commit() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-11-26 15:39:36 -08:00
										 |  |  |                 return Response(message="Entities unlinked successfully and sequences reordered", status=True) | 
					
						
							| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  | 
 | 
					
						
							|  |  |  |             except Exception as e: | 
					
						
							|  |  |  |                 session.rollback() | 
					
						
							|  |  |  |                 return Response(message=f"Error unlinking entities: {str(e)}", status=False) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     def get_linked_entities( | 
					
						
							|  |  |  |         self, | 
					
						
							|  |  |  |         link_type: LinkTypes, | 
					
						
							|  |  |  |         primary_id: int, | 
					
						
							|  |  |  |         return_json: bool = False, | 
					
						
							|  |  |  |     ): | 
					
						
							|  |  |  |         """Get linked entities based on link type and primary ID, ordered by sequence.""" | 
					
						
							|  |  |  |         with Session(self.engine) as session: | 
					
						
							|  |  |  |             try: | 
					
						
							|  |  |  |                 # Get classes from LinkTypes | 
					
						
							|  |  |  |                 primary_class = link_type.primary_class | 
					
						
							|  |  |  |                 secondary_class = link_type.secondary_class | 
					
						
							|  |  |  |                 link_table = link_type.link_table | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 # Get field names | 
					
						
							|  |  |  |                 primary_id_field = f"{primary_class.__name__.lower()}_id" | 
					
						
							|  |  |  |                 secondary_id_field = f"{secondary_class.__name__.lower()}_id" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 # Query both link and entity, ordered by sequence | 
					
						
							|  |  |  |                 items = session.exec( | 
					
						
							|  |  |  |                     select(secondary_class) | 
					
						
							|  |  |  |                     .join(link_table, getattr(link_table, secondary_id_field) == secondary_class.id) | 
					
						
							|  |  |  |                     .where(getattr(link_table, primary_id_field) == primary_id) | 
					
						
							|  |  |  |                     .order_by(link_table.sequence) | 
					
						
							|  |  |  |                 ).all() | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-11-26 15:39:36 -08:00
										 |  |  |                 result = [item.model_dump() if return_json else item for item in items] | 
					
						
							| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-11-26 15:39:36 -08:00
										 |  |  |                 return Response(message="Linked entities retrieved successfully", status=True, data=result) | 
					
						
							| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  | 
 | 
					
						
							|  |  |  |             except Exception as e: | 
					
						
							|  |  |  |                 logger.error(f"Error getting linked entities: {str(e)}") | 
					
						
							| 
									
										
										
										
											2024-11-26 15:39:36 -08:00
										 |  |  |                 return Response(message=f"Error getting linked entities: {str(e)}", status=False, data=[]) | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-11-09 14:32:24 -08:00
										 |  |  |     # Add new close method | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |     async def close(self): | 
					
						
							|  |  |  |         """Close database connections and cleanup resources""" | 
					
						
							|  |  |  |         logger.info("Closing database connections...") | 
					
						
							|  |  |  |         try: | 
					
						
							|  |  |  |             # Dispose of the SQLAlchemy engine | 
					
						
							|  |  |  |             self.engine.dispose() | 
					
						
							|  |  |  |             logger.info("Database connections closed successfully") | 
					
						
							|  |  |  |         except Exception as e: | 
					
						
							|  |  |  |             logger.error(f"Error closing database connections: {str(e)}") | 
					
						
							|  |  |  |             raise |