diff --git a/olmocr/train/train.py b/olmocr/train/train.py
index c1156df..102d579 100644
--- a/olmocr/train/train.py
+++ b/olmocr/train/train.py
@@ -5,20 +5,22 @@ Simple script to test OlmOCR dataset loading with YAML configuration.
 import argparse
 import logging
 import os
+import math
+import shutil
+import time
 
 import numpy as np
 import torch
-from torch.utils.data import ConcatDataset
+from torch.optim import AdamW
+from torch.utils.data import ConcatDataset, DataLoader
 from transformers import (
     AutoProcessor,
-    EarlyStoppingCallback,
+    get_scheduler,
     Qwen2_5_VLForConditionalGeneration,
     Qwen2VLForConditionalGeneration,
-    Trainer,
-    TrainingArguments,
 )
-from typing import Optional
+from typing import Optional, Dict, Any
 
 from olmocr.train.config import Config
 from olmocr.train.dataloader import BaseMarkdownPDFDataset
 
@@ -82,6 +84,100 @@ class QwenDataCollator:
         }
 
 
+def save_checkpoint(
+    model: torch.nn.Module,
+    optimizer: torch.optim.Optimizer,
+    lr_scheduler: Any,
+    epoch: int,
+    global_step: int,
+    best_metric: float,
+    output_dir: str,
+    save_total_limit: Optional[int] = None,
+):
+    """Save model, optimizer, scheduler, and training state."""
+    checkpoint_dir = os.path.join(output_dir, f"checkpoint-{global_step}")
+    os.makedirs(checkpoint_dir, exist_ok=True)
+
+    # Save model
+    model.save_pretrained(checkpoint_dir)
+
+    # Save optimizer and scheduler
+    torch.save(optimizer.state_dict(), os.path.join(checkpoint_dir, "optimizer.pt"))
+    torch.save(lr_scheduler.state_dict(), os.path.join(checkpoint_dir, "scheduler.pt"))
+
+    # Save training state
+    state = {
+        "epoch": epoch,
+        "global_step": global_step,
+        "best_metric": best_metric,
+    }
+    torch.save(state, os.path.join(checkpoint_dir, "training_state.pt"))
+
+    logger.info(f"Saved checkpoint to {checkpoint_dir}")
+
+    # Enforce save_total_limit by removing oldest checkpoints
+    if save_total_limit is not None and save_total_limit > 0:
+        checkpoints = sorted(
+            [d for d in os.listdir(output_dir) if d.startswith("checkpoint-")],
+            key=lambda x: int(x.split("-")[1]),
+        )
+        while len(checkpoints) > save_total_limit:
+            oldest = checkpoints.pop(0)
+            shutil.rmtree(os.path.join(output_dir, oldest))
+            logger.info(f"Deleted old checkpoint: {oldest}")
+
+
+def load_checkpoint(
+    model: torch.nn.Module,
+    optimizer: torch.optim.Optimizer,
+    lr_scheduler: Any,
+    checkpoint_dir: str,
+) -> Dict[str, Any]:
+    """Load model, optimizer, scheduler, and training state from checkpoint."""
+    model.load_state_dict(type(model).from_pretrained(checkpoint_dir).state_dict())  # PreTrainedModel has no load_pretrained()
+
+    optimizer.load_state_dict(torch.load(os.path.join(checkpoint_dir, "optimizer.pt")))
+    lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint_dir, "scheduler.pt")))
+
+    state = torch.load(os.path.join(checkpoint_dir, "training_state.pt"))
+    logger.info(f"Resumed from checkpoint: {checkpoint_dir} at epoch {state['epoch']}, step {state['global_step']}")
+    return state
+
+
+def evaluate_model(
+    model: torch.nn.Module,
+    eval_dataloaders: Dict[str, DataLoader],
+    device: torch.device,
+) -> Dict[str, float]:
+    """Evaluate on all eval datasets and return average loss per dataset."""
+    model.eval()
+    eval_metrics = {}
+
+    for dataset_name, dataloader in eval_dataloaders.items():
+        total_loss = 0.0
+        num_batches = 0
+
+        with torch.no_grad():
+            for batch in dataloader:
+                batch = {k: v.to(device) for k, v in batch.items()}
+                with torch.autocast(device_type=device.type, dtype=torch.bfloat16):
+                    outputs = model(**batch)
+                total_loss += outputs.loss.item()
+                num_batches += 1
+
+        avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
eval_metrics[f"eval_{dataset_name}_loss"] = avg_loss + logger.info(f"Eval {dataset_name} loss: {avg_loss:.4f}") + + # Compute overall eval loss as average across datasets (or customize as needed) + if eval_metrics: + overall_loss = sum(eval_metrics.values()) / len(eval_metrics) + eval_metrics["eval_loss"] = overall_loss + + return eval_metrics + + def main(): parser = argparse.ArgumentParser(description="Train OlmOCR model") parser.add_argument("--config", type=str, default="olmocr/train/configs/example_config.yaml", help="Path to YAML configuration file") @@ -104,6 +200,11 @@ def main(): os.environ["WANDB_PROJECT"] = config.project_name logger.info(f"Setting WANDB_PROJECT to: {config.project_name}") + # Initialize wandb if reporting to it + if "wandb" in config.training.report_to: + import wandb + wandb.init(project=config.project_name, name=config.run_name, config=config.to_dict()) + # Load processor for tokenization logger.info(f"Loading processor: {config.model.name}") processor = AutoProcessor.from_pretrained( @@ -177,9 +278,10 @@ def main(): # Construct full output directory by appending run_name to base output_dir full_output_dir = os.path.join(config.training.output_dir, config.run_name) logger.info(f"Setting output directory to: {full_output_dir}") + os.makedirs(full_output_dir, exist_ok=True) # Check for existing checkpoints if any - found_resumable_checkpoint = False + found_resumable_checkpoint = None if os.path.exists(full_output_dir): # Look for checkpoint directories checkpoint_dirs = [d for d in os.listdir(full_output_dir) if d.startswith("checkpoint-") and os.path.isdir(os.path.join(full_output_dir, d))] @@ -192,80 +294,197 @@ def main(): else: logger.info("No existing checkpoints found in output directory") - # Set up training arguments - training_args = TrainingArguments( - output_dir=full_output_dir, - num_train_epochs=config.training.num_train_epochs, - per_device_train_batch_size=config.training.per_device_train_batch_size, - per_device_eval_batch_size=config.training.per_device_eval_batch_size, - gradient_accumulation_steps=config.training.gradient_accumulation_steps, - learning_rate=float(config.training.learning_rate), - lr_scheduler_type=config.training.lr_scheduler_type, - warmup_ratio=config.training.warmup_ratio, - lr_scheduler_kwargs=config.training.lr_scheduler_kwargs, - optim=config.training.optim, - adam_beta1=config.training.adam_beta1, - adam_beta2=config.training.adam_beta2, - adam_epsilon=config.training.adam_epsilon, - weight_decay=config.training.weight_decay, - max_grad_norm=config.training.max_grad_norm, - bf16=True, # We're sticking with this known good reduced precision option - eval_strategy=config.training.evaluation_strategy, - eval_steps=config.training.eval_steps, - save_strategy=config.training.save_strategy, - save_steps=config.training.save_steps, - save_total_limit=config.training.save_total_limit, - load_best_model_at_end=config.training.load_best_model_at_end, - metric_for_best_model=config.training.metric_for_best_model, - greater_is_better=config.training.greater_is_better, - logging_dir=config.training.logging_dir, - logging_strategy=config.training.logging_strategy, - logging_steps=config.training.logging_steps, - logging_first_step=config.training.logging_first_step, - report_to=config.training.report_to, - seed=config.training.seed, - data_seed=config.training.data_seed, - push_to_hub=False, - label_names=["labels"], - dataloader_drop_last=config.training.dataloader_drop_last, - 
-        dataloader_num_workers=config.training.dataloader_num_workers,
-        remove_unused_columns=config.training.remove_unused_columns,
-        eval_on_start=True,
-        run_name=config.run_name,
-    )
+    # Set seeds
+    torch.manual_seed(config.training.seed)
+    if config.training.data_seed is not None:
+        np.random.seed(config.training.data_seed)  # torch.utils.data has no module-level seeding API
 
-    # Set up callbacks
-    callbacks = []
-    if config.training.use_early_stopping:
-        callbacks.append(
-            EarlyStoppingCallback(
-                early_stopping_patience=config.training.early_stopping_patience, early_stopping_threshold=config.training.early_stopping_threshold
-            )
-        )
+    # Device setup
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    model.to(device)
+
+    # Set up optimizer
+    no_decay = ["bias", "LayerNorm.weight"]
+    optimizer_grouped_parameters = [
+        {
+            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
+            "weight_decay": config.training.weight_decay,
+        },
+        {
+            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
+            "weight_decay": 0.0,
+        },
+    ]
+    if config.training.optim == "adamw_torch":
+        optimizer = AdamW(
+            optimizer_grouped_parameters,
+            lr=float(config.training.learning_rate),
+            betas=(config.training.adam_beta1, config.training.adam_beta2),
+            eps=config.training.adam_epsilon,
+        )
+    else:
+        raise NotImplementedError(f"Optimizer {config.training.optim} not supported in custom loop")
 
-    # Initialize trainer
-    logger.info("Initializing trainer...")
-    trainer = Trainer(
-        model=model,
-        args=training_args,
-        train_dataset=train_dataset,
-        eval_dataset=eval_datasets,
-        data_collator=QwenDataCollator(max_token_len=config.training.collator_max_token_len),
-        callbacks=callbacks,
-    )
+    # Total training steps calculation
+    num_update_steps_per_epoch = math.ceil(len(train_dataset) / (config.training.per_device_train_batch_size * config.training.gradient_accumulation_steps))
+    max_train_steps = int(config.training.num_train_epochs * num_update_steps_per_epoch)
+
+    # Set up scheduler
+    lr_scheduler = get_scheduler(
+        name=config.training.lr_scheduler_type,
+        optimizer=optimizer,
+        num_warmup_steps=int(max_train_steps * config.training.warmup_ratio),
+        num_training_steps=max_train_steps,
+        scheduler_specific_kwargs=config.training.lr_scheduler_kwargs,
+    )
 
-    # Start training
+    # Mixed precision: bf16 needs only torch.autocast, not a GradScaler
+    # (loss scaling exists to prevent fp16 gradient underflow; bf16 has fp32's exponent range)
+
+    # Data collator
+    data_collator = QwenDataCollator(max_token_len=config.training.collator_max_token_len)
+
+    # Create dataloaders
+    train_dataloader = DataLoader(
+        train_dataset,
+        batch_size=config.training.per_device_train_batch_size,
+        shuffle=True,
+        collate_fn=data_collator,
+        num_workers=config.training.dataloader_num_workers,
+        drop_last=config.training.dataloader_drop_last,
+    )
+
+    eval_dataloaders = {
+        name: DataLoader(
+            dataset,
+            batch_size=config.training.per_device_eval_batch_size,
+            shuffle=False,
+            collate_fn=data_collator,
+            num_workers=config.training.dataloader_num_workers,
+            drop_last=False,
+        )
+        for name, dataset in eval_datasets.items()
+    }
+
+    # Resume from checkpoint if available
+    start_epoch = 0
+    global_step = 0
+    best_metric = -float("inf") if config.training.greater_is_better else float("inf")
+    best_metric_key = config.training.metric_for_best_model  # e.g., "eval_loss"
+    if found_resumable_checkpoint:
+        state = load_checkpoint(model, optimizer, lr_scheduler, found_resumable_checkpoint)
start_epoch = state["epoch"] + 1 # Start from next epoch + global_step = state["global_step"] + best_metric = state["best_metric"] + + # Early stopping setup + patience_counter = 0 + early_stopping_patience = config.training.early_stopping_patience if config.training.use_early_stopping else float("inf") + early_stopping_threshold = config.training.early_stopping_threshold + + # Evaluate on start if configured + if config.training.eval_on_start: + metrics = evaluate_model(model, eval_dataloaders, device, autocast) + logger.info(f"Initial evaluation: {metrics}") + if "wandb" in config.training.report_to: + wandb.log(metrics, step=global_step) + + # Main training loop logger.info("Starting training...") - train_result = trainer.train(resume_from_checkpoint=found_resumable_checkpoint) + model.train() + for epoch in range(start_epoch, int(config.training.num_train_epochs)): + epoch_start_time = time.time() + train_loss = 0.0 + num_batches = 0 + + for batch_idx, batch in enumerate(train_dataloader): + batch = {k: v.to(device) for k, v in batch.items()} + + with autocast(enabled=True): # bf16 + outputs = model(**batch) + loss = outputs.loss / config.training.gradient_accumulation_steps + amp_scaler.scale(loss).backward() + + train_loss += loss.item() * config.training.gradient_accumulation_steps + num_batches += 1 + + if (batch_idx + 1) % config.training.gradient_accumulation_steps == 0: + # Clip gradients + amp_scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model.parameters(), config.training.max_grad_norm) + + # Step optimizer and scheduler + amp_scaler.step(optimizer) + amp_scaler.update() + lr_scheduler.step() + optimizer.zero_grad() + + global_step += 1 + + # Logging + if config.training.logging_steps > 0 and global_step % config.training.logging_steps == 0: + avg_train_loss = train_loss / num_batches + logs = { + "train_loss": avg_train_loss, + "learning_rate": lr_scheduler.get_last_lr()[0], + "epoch": epoch + (batch_idx / len(train_dataloader)), + } + logger.info(f"Step {global_step}: {logs}") + if "wandb" in config.training.report_to: + wandb.log(logs, step=global_step) + + train_loss = 0.0 + num_batches = 0 + + # Evaluation + if config.training.eval_steps > 0 and global_step % config.training.eval_steps == 0: + metrics = evaluate_model(model, eval_dataloaders, device, autocast) + logger.info(f"Evaluation at step {global_step}: {metrics}") + if "wandb" in config.training.report_to: + wandb.log(metrics, step=global_step) + + # Early stopping check + current_metric = metrics.get(best_metric_key, None) + if current_metric is not None: + if (config.training.greater_is_better and current_metric > best_metric + early_stopping_threshold) or \ + (not config.training.greater_is_better and current_metric < best_metric - early_stopping_threshold): + best_metric = current_metric + patience_counter = 0 + if config.training.load_best_model_at_end: + # Save best model (optional: implement loading best at end) + pass + else: + patience_counter += 1 + if patience_counter >= early_stopping_patience: + logger.info(f"Early stopping at step {global_step}") + break + + # Saving + if config.training.save_steps > 0 and global_step % config.training.save_steps == 0: + save_checkpoint( + model, optimizer, lr_scheduler, epoch, global_step, best_metric, + full_output_dir, config.training.save_total_limit + ) + + # End of epoch logging + epoch_time = time.time() - epoch_start_time + logger.info(f"Epoch {epoch} completed in {epoch_time:.2f}s") + + if patience_counter >= early_stopping_patience: + break # 
     # Save the final model
     logger.info("Saving final model...")
-    trainer.save_model()
-    trainer.save_state()
-
-    # Log metrics
-    logger.info(f"Training completed! Metrics: {train_result.metrics}")
+    model.save_pretrained(full_output_dir)
+
+    # Final evaluation
+    final_metrics = evaluate_model(model, eval_dataloaders, device)
+    logger.info(f"Training completed! Final metrics: {final_metrics}")
+    if "wandb" in config.training.report_to:
+        wandb.log(final_metrics, step=global_step)
+        wandb.finish()
 
 
 if __name__ == "__main__":
     main()
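
Note on the bf16 handling in this patch: an earlier revision paired `torch.cuda.amp.GradScaler` with `autocast`, but loss scaling exists only to keep fp16 gradients from underflowing; bf16 shares fp32's exponent range, and `GradScaler` has no `.autocast()` method, which is what `evaluate_model` was being handed. Below is a minimal sketch of the pattern the loop now uses; `model`, `batch`, and `optimizer` are placeholders, and the model is assumed to return an output object with a `.loss` attribute.

```python
import torch

def train_step(model, batch, optimizer, max_grad_norm=1.0):
    """One bf16 training step: autocast only, no GradScaler."""
    device_type = "cuda" if torch.cuda.is_available() else "cpu"
    with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
        # Forward pass runs matmuls in bf16; parameters and gradients stay fp32
        loss = model(**batch).loss
    loss.backward()  # no scaler.scale(): bf16 gradients do not underflow like fp16
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)  # no unscale_ needed
    optimizer.step()
    optimizer.zero_grad()
    return loss.item()
```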