From 00f51fb2c7a111ca6fb7aa79d8b5cefd237be246 Mon Sep 17 00:00:00 2001
From: Jake Poznanski
Date: Tue, 2 Sep 2025 21:03:00 +0000
Subject: [PATCH] Fixing bug with multi epoch training

---
 olmocr/train/train.py | 92 ++++++++++++++++++++++++++++++++++---------
 1 file changed, 74 insertions(+), 18 deletions(-)

diff --git a/olmocr/train/train.py b/olmocr/train/train.py
index 33b74a4..ba93d5f 100644
--- a/olmocr/train/train.py
+++ b/olmocr/train/train.py
@@ -189,6 +189,46 @@ def evaluate_model(
     return eval_metrics
 
 
+def create_train_dataloader(
+    train_dataset,
+    config,
+    data_collator,
+    seed_worker,
+    epoch_num: int = 0,
+) -> DataLoader:
+    """Create a training dataloader with epoch-specific shuffling.
+
+    Args:
+        train_dataset: The training dataset
+        config: Training configuration
+        data_collator: Data collator for batching
+        seed_worker: Worker initialization function
+        epoch_num: Current epoch number for seed generation
+
+    Returns:
+        DataLoader with epoch-specific shuffling
+    """
+    # Create generator with epoch-specific seed for different shuffling each epoch
+    epoch_generator = torch.Generator()
+    if config.training.data_seed is not None:
+        # Use epoch number to ensure different shuffling each epoch while maintaining reproducibility
+        epoch_generator.manual_seed(config.training.data_seed + epoch_num)
+    else:
+        # Use a random seed if no data_seed specified
+        epoch_generator.manual_seed(int(torch.randint(0, 2**32 - 1, (1,)).item()))
+
+    return DataLoader(
+        train_dataset,
+        batch_size=config.training.per_device_train_batch_size,
+        shuffle=True,
+        collate_fn=data_collator,
+        num_workers=config.training.dataloader_num_workers,
+        drop_last=config.training.dataloader_drop_last,
+        worker_init_fn=seed_worker,
+        generator=epoch_generator,
+    )
+
+
 def main():
     parser = argparse.ArgumentParser(description="Train OlmOCR model")
     parser.add_argument("--config", type=str, default="olmocr/train/configs/example_config.yaml", help="Path to YAML configuration file")
@@ -313,12 +353,6 @@ def main():
         random.seed(worker_seed)
 
-    # Create generator for data loader
-    generator = None
-    if config.training.data_seed is not None:
-        generator = torch.Generator()
-        generator.manual_seed(config.training.data_seed)
-
     # Device setup
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     model.to(device)
@@ -418,16 +452,14 @@ def main():
         best_metric = state["best_metric"]
         samples_seen = state["samples_seen"]
 
-    # Create dataloaders
-    train_dataloader = DataLoader(
+    # Create dataloaders - use epoch 0 initially (will be recreated with proper epoch if resuming)
+    current_epoch_num = int(samples_seen / len(train_dataset)) if samples_seen > 0 else 0
+    train_dataloader = create_train_dataloader(
         train_dataset,
-        batch_size=config.training.per_device_train_batch_size,
-        shuffle=True,
-        collate_fn=data_collator,
-        num_workers=config.training.dataloader_num_workers,
-        drop_last=config.training.dataloader_drop_last,
-        worker_init_fn=seed_worker,
-        generator=generator,
+        config,
+        data_collator,
+        seed_worker,
+        epoch_num=current_epoch_num,
     )
 
     eval_dataloaders = {
@@ -467,14 +499,26 @@ def main():
         samples_to_skip = samples_seen % len(train_dataset)
         batches_to_skip = samples_to_skip // config.training.per_device_train_batch_size
         logger.info(f"Resuming training: skipping {batches_to_skip} batches ({samples_to_skip} samples) to reach position {samples_seen}")
+
+        # Skip batches to resume from the correct position within the epoch
         for _ in range(batches_to_skip):
             try:
                 next(epoch_iterator)
             except StopIteration:
-                # We've reached the end of the epoch while skipping, create new iterator
+                # We've reached the end of the epoch while skipping
+                # This shouldn't normally happen, but handle it gracefully
+                logger.warning(f"Reached end of epoch while skipping batches. Creating new epoch.")
+                current_epoch_num += 1
+                train_dataloader = create_train_dataloader(
+                    train_dataset,
+                    config,
+                    data_collator,
+                    seed_worker,
+                    epoch_num=current_epoch_num,
+                )
                 epoch_iterator = iter(train_dataloader)
                 break
-
+
     # Create progress bar
     pbar = tqdm(total=max_train_samples - samples_seen, desc=f"Training from step {global_step}", unit="samples")
 
@@ -482,9 +526,21 @@ def main():
         try:
             batch = next(epoch_iterator)
         except StopIteration:
-            # End of epoch, create new iterator
+            # End of epoch, create new dataloader with fresh shuffle
             current_epoch = samples_seen / len(train_dataset)
             logger.info(f"Completed epoch {current_epoch:.2f}")
+
+            # Increment epoch number for new shuffle seed
+            current_epoch_num += 1
+
+            # Recreate dataloader with new generator for fresh shuffle
+            train_dataloader = create_train_dataloader(
+                train_dataset,
+                config,
+                data_collator,
+                seed_worker,
+                epoch_num=current_epoch_num,
+            )
             epoch_iterator = iter(train_dataloader)
             batch = next(epoch_iterator)
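
Note: a minimal sketch, not part of the patch, of the shuffling behavior
create_train_dataloader relies on. Seeding a fresh torch.Generator with
data_seed + epoch_num gives every epoch a distinct but reproducible sample
order; DataLoader(shuffle=True, generator=gen) draws its permutation from
that generator, equivalent to the torch.randperm call below. The data_seed
value and dataset size here are hypothetical.

    import torch

    data_seed = 42   # hypothetical config.training.data_seed
    dataset_len = 8  # hypothetical len(train_dataset)

    for epoch_num in range(3):
        # One fresh generator per epoch, seeded as in create_train_dataloader
        gen = torch.Generator()
        gen.manual_seed(data_seed + epoch_num)
        # The RandomSampler inside DataLoader(shuffle=True, generator=gen)
        # draws its permutation from this generator, equivalent to:
        order = torch.randperm(dataset_len, generator=gen)
        print(epoch_num, order.tolist())

    # Each epoch prints a different order, and re-running the loop reproduces
    # the same three orders exactly.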
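
Likewise, a small worked example of the resume arithmetic (all numbers
hypothetical): the epoch is recovered from samples_seen alone, and the
position within that epoch determines how many batches to skip.

    # Hypothetical checkpoint state and config values
    samples_seen = 2600  # restored from the checkpoint
    dataset_len = 1000   # len(train_dataset)
    batch_size = 8       # config.training.per_device_train_batch_size

    # Epoch to rebuild the dataloader with: 2600 / 1000 -> epoch 2, so the
    # shuffle seed becomes data_seed + 2, matching the interrupted run
    current_epoch_num = int(samples_seen / dataset_len) if samples_seen > 0 else 0

    # Position inside that epoch: 600 samples, i.e. 75 batches to skip
    samples_to_skip = samples_seen % dataset_len
    batches_to_skip = samples_to_skip // batch_size

    print(current_epoch_num, samples_to_skip, batches_to_skip)  # 2 600 75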