mirror of https://github.com/allenai/olmocr.git (synced 2025-10-11 16:22:29 +00:00)
Fixing bug with multi-epoch training
This commit is contained in:
parent 72fcfafde7
commit 00f51fb2c7
@@ -189,6 +189,46 @@ def evaluate_model(
     return eval_metrics


+def create_train_dataloader(
+    train_dataset,
+    config,
+    data_collator,
+    seed_worker,
+    epoch_num: int = 0,
+) -> DataLoader:
+    """Create a training dataloader with epoch-specific shuffling.
+
+    Args:
+        train_dataset: The training dataset
+        config: Training configuration
+        data_collator: Data collator for batching
+        seed_worker: Worker initialization function
+        epoch_num: Current epoch number for seed generation
+
+    Returns:
+        DataLoader with epoch-specific shuffling
+    """
+    # Create generator with epoch-specific seed for different shuffling each epoch
+    epoch_generator = torch.Generator()
+    if config.training.data_seed is not None:
+        # Use epoch number to ensure different shuffling each epoch while maintaining reproducibility
+        epoch_generator.manual_seed(config.training.data_seed + epoch_num)
+    else:
+        # Use a random seed if no data_seed specified
+        epoch_generator.manual_seed(int(torch.randint(0, 2**32 - 1, (1,)).item()))
+
+    return DataLoader(
+        train_dataset,
+        batch_size=config.training.per_device_train_batch_size,
+        shuffle=True,
+        collate_fn=data_collator,
+        num_workers=config.training.dataloader_num_workers,
+        drop_last=config.training.dataloader_drop_last,
+        worker_init_fn=seed_worker,
+        generator=epoch_generator,
+    )
+
+
 def main():
     parser = argparse.ArgumentParser(description="Train OlmOCR model")
     parser.add_argument("--config", type=str, default="olmocr/train/configs/example_config.yaml", help="Path to YAML configuration file")
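A minimal sketch of why this works (not part of the commit; the dataset size and data_seed value below are invented): DataLoader's default sampler draws its shuffle permutation from the generator passed in, so reseeding with data_seed + epoch_num gives each epoch a different but reproducible order.

    import torch

    data_seed = 42  # hypothetical value, stands in for config.training.data_seed
    for epoch_num in range(3):
        g = torch.Generator()
        g.manual_seed(data_seed + epoch_num)
        # With shuffle=True the sampler effectively calls torch.randperm(len(dataset), generator=g),
        # so this is the visit order a 10-sample dataset would get for this epoch.
        print(f"epoch {epoch_num}: {torch.randperm(10, generator=g).tolist()}")
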
@@ -313,12 +353,6 @@ def main():

         random.seed(worker_seed)

-    # Create generator for data loader
-    generator = None
-    if config.training.data_seed is not None:
-        generator = torch.Generator()
-        generator.manual_seed(config.training.data_seed)
-
     # Device setup
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     model.to(device)
@@ -418,16 +452,14 @@ def main():
         best_metric = state["best_metric"]
         samples_seen = state["samples_seen"]

-    # Create dataloaders
-    train_dataloader = DataLoader(
+    # Create dataloaders - use epoch 0 initially (will be recreated with proper epoch if resuming)
+    current_epoch_num = int(samples_seen / len(train_dataset)) if samples_seen > 0 else 0
+    train_dataloader = create_train_dataloader(
         train_dataset,
-        batch_size=config.training.per_device_train_batch_size,
-        shuffle=True,
-        collate_fn=data_collator,
-        num_workers=config.training.dataloader_num_workers,
-        drop_last=config.training.dataloader_drop_last,
-        worker_init_fn=seed_worker,
-        generator=generator,
+        config,
+        data_collator,
+        seed_worker,
+        epoch_num=current_epoch_num,
     )

     eval_dataloaders = {
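For intuition, with invented numbers (a 100_000-sample dataset and a checkpoint taken at 250_000 samples), the resumed run lands in epoch 2, so the rebuilt dataloader reproduces that epoch's shuffle:

    train_dataset_len = 100_000  # hypothetical len(train_dataset)
    samples_seen = 250_000       # hypothetical value restored from checkpoint state
    current_epoch_num = int(samples_seen / train_dataset_len) if samples_seen > 0 else 0
    print(current_epoch_num)     # 2 -> the epoch generator is seeded with data_seed + 2
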
@@ -467,14 +499,26 @@ def main():
         samples_to_skip = samples_seen % len(train_dataset)
         batches_to_skip = samples_to_skip // config.training.per_device_train_batch_size
         logger.info(f"Resuming training: skipping {batches_to_skip} batches ({samples_to_skip} samples) to reach position {samples_seen}")

         # Skip batches to resume from the correct position within the epoch
         for _ in range(batches_to_skip):
             try:
                 next(epoch_iterator)
             except StopIteration:
-                # We've reached the end of the epoch while skipping, create new iterator
+                # We've reached the end of the epoch while skipping
+                # This shouldn't normally happen, but handle it gracefully
+                logger.warning(f"Reached end of epoch while skipping batches. Creating new epoch.")
+                current_epoch_num += 1
+                train_dataloader = create_train_dataloader(
+                    train_dataset,
+                    config,
+                    data_collator,
+                    seed_worker,
+                    epoch_num=current_epoch_num,
+                )
                 epoch_iterator = iter(train_dataloader)
                 break

     # Create progress bar
     pbar = tqdm(total=max_train_samples - samples_seen, desc=f"Training from step {global_step}", unit="samples")
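A small self-contained check of the property the skip loop relies on (toy dataset, made-up seed and batch size, nothing here is from the repo): rebuilding a dataloader with the same generator seed replays the same batch order, so fast-forwarding batches_to_skip batches lands where the interrupted run left off.

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    ds = TensorDataset(torch.arange(16))

    def make_loader(seed):
        g = torch.Generator()
        g.manual_seed(seed)
        return DataLoader(ds, batch_size=4, shuffle=True, generator=g)

    first_run = [batch[0].tolist() for batch in make_loader(7)]  # order of the interrupted run
    resumed = [batch[0].tolist() for batch in make_loader(7)]    # rebuilt loader, same seed
    assert first_run == resumed    # identical shuffle from identical seed
    print(resumed[2:])             # batches still to train on after skipping the first 2
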
@@ -482,9 +526,21 @@ def main():
         try:
             batch = next(epoch_iterator)
         except StopIteration:
-            # End of epoch, create new iterator
+            # End of epoch, create new dataloader with fresh shuffle
+            current_epoch = samples_seen / len(train_dataset)
+            logger.info(f"Completed epoch {current_epoch:.2f}")
+
+            # Increment epoch number for new shuffle seed
+            current_epoch_num += 1
+
+            # Recreate dataloader with new generator for fresh shuffle
+            train_dataloader = create_train_dataloader(
+                train_dataset,
+                config,
+                data_collator,
+                seed_worker,
+                epoch_num=current_epoch_num,
+            )
             epoch_iterator = iter(train_dataloader)
             batch = next(epoch_iterator)
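Reduced to a runnable toy (dataset, batch size, sample budget, and seed are all invented here), the resulting loop structure looks roughly like this: train by sample count and, whenever the iterator is exhausted, bump the epoch number so the next dataloader gets a fresh, reproducible shuffle.

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    ds = TensorDataset(torch.arange(10))

    def make_loader(epoch_num, data_seed=3):
        g = torch.Generator()
        g.manual_seed(data_seed + epoch_num)
        return DataLoader(ds, batch_size=4, shuffle=True, generator=g)

    samples_seen, max_train_samples, epoch_num = 0, 25, 0
    epoch_iterator = iter(make_loader(epoch_num))
    while samples_seen < max_train_samples:
        try:
            (batch,) = next(epoch_iterator)
        except StopIteration:
            epoch_num += 1                                # new seed -> new shuffle
            epoch_iterator = iter(make_loader(epoch_num))
            (batch,) = next(epoch_iterator)
        samples_seen += batch.numel()                     # stands in for one training step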