Fixing bug with multi epoch training

This commit is contained in:
Jake Poznanski 2025-09-02 21:03:00 +00:00
parent 72fcfafde7
commit 00f51fb2c7

View File

@ -189,6 +189,46 @@ def evaluate_model(
return eval_metrics
def create_train_dataloader(
train_dataset,
config,
data_collator,
seed_worker,
epoch_num: int = 0,
) -> DataLoader:
"""Create a training dataloader with epoch-specific shuffling.
Args:
train_dataset: The training dataset
config: Training configuration
data_collator: Data collator for batching
seed_worker: Worker initialization function
epoch_num: Current epoch number for seed generation
Returns:
DataLoader with epoch-specific shuffling
"""
# Create generator with epoch-specific seed for different shuffling each epoch
epoch_generator = torch.Generator()
if config.training.data_seed is not None:
# Use epoch number to ensure different shuffling each epoch while maintaining reproducibility
epoch_generator.manual_seed(config.training.data_seed + epoch_num)
else:
# Use a random seed if no data_seed specified
epoch_generator.manual_seed(int(torch.randint(0, 2**32 - 1, (1,)).item()))
return DataLoader(
train_dataset,
batch_size=config.training.per_device_train_batch_size,
shuffle=True,
collate_fn=data_collator,
num_workers=config.training.dataloader_num_workers,
drop_last=config.training.dataloader_drop_last,
worker_init_fn=seed_worker,
generator=epoch_generator,
)
def main():
parser = argparse.ArgumentParser(description="Train OlmOCR model")
parser.add_argument("--config", type=str, default="olmocr/train/configs/example_config.yaml", help="Path to YAML configuration file")
@ -313,12 +353,6 @@ def main():
random.seed(worker_seed)
# Create generator for data loader
generator = None
if config.training.data_seed is not None:
generator = torch.Generator()
generator.manual_seed(config.training.data_seed)
# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
@ -418,16 +452,14 @@ def main():
best_metric = state["best_metric"]
samples_seen = state["samples_seen"]
# Create dataloaders
train_dataloader = DataLoader(
# Create dataloaders - use epoch 0 initially (will be recreated with proper epoch if resuming)
current_epoch_num = int(samples_seen / len(train_dataset)) if samples_seen > 0 else 0
train_dataloader = create_train_dataloader(
train_dataset,
batch_size=config.training.per_device_train_batch_size,
shuffle=True,
collate_fn=data_collator,
num_workers=config.training.dataloader_num_workers,
drop_last=config.training.dataloader_drop_last,
worker_init_fn=seed_worker,
generator=generator,
config,
data_collator,
seed_worker,
epoch_num=current_epoch_num,
)
eval_dataloaders = {
@ -467,14 +499,26 @@ def main():
samples_to_skip = samples_seen % len(train_dataset)
batches_to_skip = samples_to_skip // config.training.per_device_train_batch_size
logger.info(f"Resuming training: skipping {batches_to_skip} batches ({samples_to_skip} samples) to reach position {samples_seen}")
# Skip batches to resume from the correct position within the epoch
for _ in range(batches_to_skip):
try:
next(epoch_iterator)
except StopIteration:
# We've reached the end of the epoch while skipping, create new iterator
# We've reached the end of the epoch while skipping
# This shouldn't normally happen, but handle it gracefully
logger.warning(f"Reached end of epoch while skipping batches. Creating new epoch.")
current_epoch_num += 1
train_dataloader = create_train_dataloader(
train_dataset,
config,
data_collator,
seed_worker,
epoch_num=current_epoch_num,
)
epoch_iterator = iter(train_dataloader)
break
# Create progress bar
pbar = tqdm(total=max_train_samples - samples_seen, desc=f"Training from step {global_step}", unit="samples")
@ -482,9 +526,21 @@ def main():
try:
batch = next(epoch_iterator)
except StopIteration:
# End of epoch, create new iterator
# End of epoch, create new dataloader with fresh shuffle
current_epoch = samples_seen / len(train_dataset)
logger.info(f"Completed epoch {current_epoch:.2f}")
# Increment epoch number for new shuffle seed
current_epoch_num += 1
# Recreate dataloader with new generator for fresh shuffle
train_dataloader = create_train_dataloader(
train_dataset,
config,
data_collator,
seed_worker,
epoch_num=current_epoch_num,
)
epoch_iterator = iter(train_dataloader)
batch = next(epoch_iterator)