"""Alembic-backed schema validation and migration helpers for SQLModel."""

import os
import shutil
from pathlib import Path
from typing import List, Optional, Tuple, Union

import sqlmodel
from alembic import command
from alembic.autogenerate import compare_metadata
from alembic.config import Config
from alembic.runtime.migration import MigrationContext
from alembic.script import ScriptDirectory
from alembic.util.exc import CommandError
from loguru import logger
from sqlalchemy import Engine, text
from sqlmodel import SQLModel


class SchemaManager:
    """
    Manages database schema validation and migrations using Alembic.
    Operations are initiated explicitly by DatabaseManager.
    """

    def __init__(
        self,
        engine: Engine,
        base_dir: Optional[Union[str, Path]] = None,
    ):
        """
        Initialize configuration only - no filesystem or DB operations.

        Args:
            engine: SQLAlchemy engine instance
            base_dir: Base directory for Alembic files, as a ``Path`` or a
                string path. If None, the directory containing this module
                is used.
        """
        self.engine = engine
        self.base_dir = Path(base_dir) if base_dir is not None else Path(__file__).parent
        self.alembic_dir = self.base_dir / "alembic"
        self.alembic_ini_path = self.base_dir / "alembic.ini"

    def initialize_migrations(self, force: bool = False) -> bool:
        """
        Set up Alembic (or refresh an existing setup) and generate a revision.

        Args:
            force: If True, delete any existing Alembic files and start over.

        Returns:
            bool: True if the setup succeeded and a revision was generated,
            False on any failure (the error is logged, not raised).
        """
        try:
            if force:
                logger.info("Force reinitialization of migrations...")
                self._cleanup_existing_alembic()
                if not self._initialize_alembic():
                    return False
            else:
                # Only the validation step signals "not set up yet"; keep the
                # try body limited to it so later errors are not misread.
                try:
                    self._validate_alembic_setup()
                except FileNotFoundError:
                    logger.info("Initializing new Alembic configuration")
                    if not self._initialize_alembic():
                        return False
                else:
                    logger.info("Using existing Alembic configuration")
                    self._update_configuration()

            # Only generate initial revision if alembic is properly initialized
            logger.info("Creating initial migration...")
            return self.generate_revision("Initial schema") is not None
        except Exception as e:
            logger.error(f"Failed to initialize migrations: {e}")
            return False

    def _write_alembic_ini(self) -> None:
        """Write a freshly generated alembic.ini to ``self.alembic_ini_path``."""
        self.alembic_ini_path.write_text(self._generate_alembic_ini_content())

    def _update_configuration(self) -> None:
        """Updates existing Alembic configuration with current settings."""
        logger.info("Updating existing Alembic configuration...")

        # Update alembic.ini
        self._write_alembic_ini()

        # Update env.py; _update_env_py writes a minimal file when it is missing.
        self._update_env_py(self.alembic_dir / "env.py")

    def _cleanup_existing_alembic(self) -> None:
        """
        Completely remove existing Alembic configuration including versions.
        For fresh initialization, we don't need to preserve anything.
        """
        logger.info("Cleaning up existing Alembic configuration...")

        # Remove entire alembic directory if it exists
        if self.alembic_dir.exists():
            shutil.rmtree(self.alembic_dir)
            logger.info(f"Removed alembic directory: {self.alembic_dir}")

        # Remove alembic.ini if it exists
        if self.alembic_ini_path.exists():
            self.alembic_ini_path.unlink()
            logger.info("Removed alembic.ini")

    def _ensure_alembic_setup(self, *, force: bool = False) -> None:
        """
        Ensures Alembic is properly set up, initializing if necessary.

        A valid existing setup is left untouched unless ``force`` is given.
        An incomplete setup is removed and recreated.

        Args:
            force: If True, removes existing configuration and reinitializes
        """
        try:
            self._validate_alembic_setup()
        except FileNotFoundError:
            logger.info("Alembic configuration not found. Initializing...")
            if self.alembic_dir.exists():
                logger.warning("Found existing alembic directory but missing configuration")
                self._cleanup_existing_alembic()
        else:
            if not force:
                return
            logger.info("Force initialization requested. Cleaning up existing configuration...")
            self._cleanup_existing_alembic()

        # _initialize_alembic logs its own success message and reports failure
        # through its return value instead of raising.
        if not self._initialize_alembic():
            logger.error("Alembic initialization failed")

    def _initialize_alembic(self) -> bool:
        """
        Initialize alembic structure and configuration.

        Returns:
            bool: True on success, False if any step raised (error is logged).
        """
        try:
            # Ensure the base directory exists, including missing ancestors
            self.alembic_dir.parent.mkdir(parents=True, exist_ok=True)

            # Run alembic init to create fresh directory structure
            logger.info("Initializing alembic directory structure...")

            # Create initial config file for alembic init
            self._write_alembic_ini()

            # Use the config we just created
            config = Config(str(self.alembic_ini_path))
            command.init(config, str(self.alembic_dir))

            # Update script template after initialization
            self.update_script_template()

            # Update env.py with our customizations
            self._update_env_py(self.alembic_dir / "env.py")

            logger.info("Alembic initialization complete")
            return True
        except Exception as e:
            logger.error(f"Failed to initialize alembic: {e}")
            return False

    def _create_minimal_env_py(self, env_path: Path) -> None:
        """Creates a minimal env.py file for Alembic."""
        content = """
from logging.config import fileConfig
from sqlalchemy import engine_from_config
from sqlalchemy import pool
from alembic import context
from sqlmodel import SQLModel

config = context.config
if config.config_file_name is not None:
    fileConfig(config.config_file_name)

target_metadata = SQLModel.metadata

def run_migrations_offline() -> None:
    url = config.get_main_option("sqlalchemy.url")
    context.configure(
        url=url,
        target_metadata=target_metadata,
        literal_binds=True,
        dialect_opts={"paramstyle": "named"},
        compare_type=True
    )
    with context.begin_transaction():
        context.run_migrations()

def run_migrations_online() -> None:
    connectable = engine_from_config(
        config.get_section(config.config_ini_section),
        prefix="sqlalchemy.",
        poolclass=pool.NullPool,
    )
    with connectable.connect() as connection:
        context.configure(
            connection=connection,
            target_metadata=target_metadata,
            compare_type=True
        )
        with context.begin_transaction():
            context.run_migrations()

if context.is_offline_mode():
    run_migrations_offline()
else:
    run_migrations_online()"""

        # Python source files are UTF-8 by default.
        env_path.write_text(content, encoding="utf-8")

    def _generate_alembic_ini_content(self) -> str:
        """
        Generates content for alembic.ini file.

        Returns:
            str: Full ini text with the script location and database URL of
            this manager filled in.
        """
        # str(URL) masks the password in SQLAlchemy 2.x, which would write an
        # unusable "***" credential; render the real URL instead. '%' is doubled
        # because values in the [alembic] section go through configparser
        # interpolation.
        db_url = self.engine.url.render_as_string(hide_password=False).replace("%", "%%")
        script_location = str(self.alembic_dir).replace("%", "%%")

        return f"""
[alembic]
script_location = {script_location}
sqlalchemy.url = {db_url}

[loggers]
keys = root,sqlalchemy,alembic

[handlers]
keys = console

[formatters]
keys = generic

[logger_root]
level = WARN
handlers = console
qualname =

[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine

[logger_alembic]
level = INFO
handlers =
qualname = alembic

[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic

[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S
""".strip()

    def update_script_template(self) -> bool:
        """
        Update the Alembic script template to include SQLModel.

        Adds ``import sqlmodel`` after the standard imports of
        ``script.py.mako``. Running it again on an already patched template
        leaves the file unchanged.

        Returns:
            bool: True on success, False if the template could not be
            read or written (error is logged).
        """
        template_path = self.alembic_dir / "script.py.mako"
        try:
            content = template_path.read_text(encoding="utf-8")

            # Add sqlmodel import to imports section
            import_section = "from alembic import op\nimport sqlalchemy as sa"
            new_imports = "from alembic import op\nimport sqlalchemy as sa\nimport sqlmodel"

            # new_imports starts with import_section, so replacing blindly
            # would append a second "import sqlmodel" on every re-run.
            if new_imports not in content:
                content = content.replace(import_section, new_imports)

            template_path.write_text(content, encoding="utf-8")

            logger.info("Updated script template")
            return True
        except Exception as e:
            logger.error(f"Failed to update script template: {e}")
            return False

    def _update_env_py(self, env_path: Path) -> None:
        """
        Updates the env.py file to use SQLModel metadata.

        If the file does not exist, a minimal env.py is written instead.

        Args:
            env_path: Location of the env.py file to patch.

        Raises:
            Exception: Any error while reading or writing the file is logged
                and re-raised.
        """
        if not env_path.exists():
            self._create_minimal_env_py(env_path)
            return

        try:
            content = env_path.read_text(encoding="utf-8")

            # Add SQLModel import if not present
            if "from sqlmodel import SQLModel" not in content:
                content = "from sqlmodel import SQLModel\n" + content

            # Replace target_metadata
            content = content.replace("target_metadata = None", "target_metadata = SQLModel.metadata")

            # Update both configure blocks properly. The search strings must
            # match the generated env.py text exactly, including indentation.
            content = content.replace(
                """context.configure(
        url=url,
        target_metadata=target_metadata,
        literal_binds=True,
        dialect_opts={"paramstyle": "named"},
    )""",
                """context.configure(
        url=url,
        target_metadata=target_metadata,
        literal_binds=True,
        dialect_opts={"paramstyle": "named"},
        compare_type=True,
    )""",
            )

            content = content.replace(
                """        context.configure(
            connection=connection, target_metadata=target_metadata
        )""",
                """        context.configure(
            connection=connection,
            target_metadata=target_metadata,
            compare_type=True,
        )""",
            )

            env_path.write_text(content, encoding="utf-8")

            logger.info("Updated env.py with SQLModel metadata")
        except Exception as e:
            logger.error(f"Failed to update env.py: {e}")
            raise

    def _validate_alembic_setup(self) -> None:
        """
        Validates that Alembic is properly configured.

        Raises:
            FileNotFoundError: If alembic.ini, env.py or the versions
                directory is missing.
        """
        required_files = [self.alembic_ini_path, self.alembic_dir / "env.py", self.alembic_dir / "versions"]

        missing = [f for f in required_files if not f.exists()]
        if missing:
            raise FileNotFoundError(f"Alembic configuration incomplete. Missing: {', '.join(str(f) for f in missing)}")

    def get_alembic_config(self) -> Config:
        """
        Gets Alembic configuration.

        Returns:
            Config: Alembic Config object

        Raises:
            FileNotFoundError: If alembic.ini cannot be found
        """
        if not self.alembic_ini_path.exists():
            raise FileNotFoundError("Could not find alembic.ini")

        return Config(str(self.alembic_ini_path))

    def get_current_revision(self) -> Optional[str]:
        """
        Gets the current database revision.

        Returns:
            str: Current revision string or None if no revision
        """
        with self.engine.connect() as conn:
            context = MigrationContext.configure(conn)
            return context.get_current_revision()

    def get_head_revision(self) -> Optional[str]:
        """
        Gets the latest available revision.

        Returns:
            str: Head revision string, or None when no revision scripts exist

        Raises:
            FileNotFoundError: If alembic.ini cannot be found
        """
        config = self.get_alembic_config()
        script = ScriptDirectory.from_config(config)
        return script.get_current_head()

    def get_schema_differences(self) -> List[tuple]:
        """
        Detects differences between current database and models.

        Returns:
            List[tuple]: List of differences found
        """
        with self.engine.connect() as conn:
            context = MigrationContext.configure(conn)
            diff = compare_metadata(context, SQLModel.metadata)
            return list(diff)

    def check_schema_status(self) -> Tuple[bool, str]:
        """
        Checks if database schema matches current models and migrations.

        Any error during the check is logged and reported as
        ``(True, <error message>)`` rather than raised.

        Returns:
            Tuple[bool, str]: (needs_upgrade, status_message)
        """
        try:
            current_rev = self.get_current_revision()
            head_rev = self.get_head_revision()

            if current_rev != head_rev:
                return True, f"Database needs upgrade: {current_rev} -> {head_rev}"

            differences = self.get_schema_differences()
            if differences:
                changes_desc = "\n".join(str(diff) for diff in differences)
                return True, f"Unmigrated changes detected:\n{changes_desc}"

            return False, "Database schema is up to date"

        except Exception as e:
            logger.error(f"Error checking schema status: {e}")
            return True, f"Error checking schema: {e}"

    def upgrade_schema(self, revision: str = "head") -> bool:
        """
        Upgrades database schema to specified revision.

        Args:
            revision: Target revision (default: "head")

        Returns:
            bool: True if upgrade successful
        """
        try:
            config = self.get_alembic_config()
            command.upgrade(config, revision)
            logger.info(f"Schema upgraded successfully to {revision}")
            return True

        except Exception as e:
            logger.error(f"Schema upgrade failed: {e}")
            return False

    def check_and_upgrade(self) -> Tuple[bool, str]:
        """
        Checks schema status and upgrades if necessary.

        Returns:
            Tuple[bool, str]: (action_taken, status_message)
        """
        needs_upgrade, status = self.check_schema_status()
        if not needs_upgrade:
            return False, status

        # Calling this method is itself the request to upgrade, so no extra
        # opt-in flag is consulted here.
        if self.upgrade_schema():
            return True, "Schema was automatically upgraded"
        return False, "Automatic schema upgrade failed"

    def generate_revision(self, message: str = "auto") -> Optional[str]:
        """
        Generates new migration revision for current schema changes.

        Args:
            message: Revision message

        Returns:
            str: Revision ID if successful, None otherwise
        """
        try:
            config = self.get_alembic_config()
            command.revision(config, message=message, autogenerate=True)
            return self.get_head_revision()

        except Exception as e:
            logger.error(f"Failed to generate revision: {e}")
            return None

    def get_pending_migrations(self) -> List[str]:
        """
        Gets list of pending migrations that need to be applied.

        Returns:
            List[str]: Pending migration revision IDs, oldest first (the
            order in which they would be applied)

        Raises:
            FileNotFoundError: If alembic.ini cannot be found
        """
        config = self.get_alembic_config()
        script = ScriptDirectory.from_config(config)

        current = self.get_current_revision()
        head = self.get_head_revision()

        if head is None or current == head:
            return []

        # iterate_revisions takes (upper, lower) and walks downwards from
        # upper, so head goes first. A database with no revision starts at
        # "base". The current revision is filtered out explicitly so the
        # result does not depend on whether the lower bound is yielded.
        newest_first = [
            rev.revision
            for rev in script.iterate_revisions(head, current or "base")
            if rev.revision != current
        ]
        return newest_first[::-1]

    def print_status(self) -> None:
        """Prints current migration status information to logger."""
        current = self.get_current_revision()
        head = self.get_head_revision()
        differences = self.get_schema_differences()
        pending = self.get_pending_migrations()

        logger.info("=== Database Schema Status ===")
        logger.info(f"Current revision: {current}")
        logger.info(f"Head revision: {head}")
        logger.info(f"Pending migrations: {len(pending)}")
        for rev in pending:
            logger.info(f"  - {rev}")
        logger.info(f"Unmigrated changes: {len(differences)}")
        for diff in differences:
            logger.info(f"  - {diff}")

    def ensure_schema_up_to_date(self) -> bool:
        """
        Reset migrations and create fresh migration for current schema state.

        WARNING: this is destructive. It deletes the whole alembic directory
        (including existing revision scripts) and drops the
        ``alembic_version`` table before regenerating and applying a single
        new revision.

        Returns:
            bool: True if the reset, generation and upgrade all succeeded,
            False otherwise (errors are logged, not raised).
        """
        try:
            logger.info("Resetting migrations and updating to current schema...")

            # 1. Clear the entire alembic directory
            if self.alembic_dir.exists():
                shutil.rmtree(self.alembic_dir)
                logger.info("Cleared alembic directory")

            # 2. Clear alembic_version table
            with self.engine.connect() as connection:
                connection.execute(text("DROP TABLE IF EXISTS alembic_version"))
                connection.commit()
                logger.info("Reset alembic version")

            # 3. Reinitialize alembic from scratch
            if not self._initialize_alembic():
                logger.error("Failed to reinitialize alembic")
                return False

            # 4. Generate fresh migration from current schema
            revision = self.generate_revision("current_schema")
            if not revision:
                logger.error("Failed to generate new migration")
                return False
            logger.info(f"Generated fresh migration: {revision}")

            # 5. Apply the migration
            if not self.upgrade_schema():
                logger.error("Failed to apply migration")
                return False
            logger.info("Successfully applied migration")

            return True

        except Exception as e:
            logger.error(f"Failed to ensure schema is up to date: {e}")
            return False