Mirror of https://github.com/allenai/olmocr.git (synced 2025-10-11 16:22:29 +00:00)

Commit 00f51fb2c7: "Fixing bug with multi epoch training"
Parent: 72fcfafde7
@@ -189,6 +189,46 @@ def evaluate_model(
     return eval_metrics
 
 
+def create_train_dataloader(
+    train_dataset,
+    config,
+    data_collator,
+    seed_worker,
+    epoch_num: int = 0,
+) -> DataLoader:
+    """Create a training dataloader with epoch-specific shuffling.
+
+    Args:
+        train_dataset: The training dataset
+        config: Training configuration
+        data_collator: Data collator for batching
+        seed_worker: Worker initialization function
+        epoch_num: Current epoch number for seed generation
+
+    Returns:
+        DataLoader with epoch-specific shuffling
+    """
+    # Create generator with epoch-specific seed for different shuffling each epoch
+    epoch_generator = torch.Generator()
+    if config.training.data_seed is not None:
+        # Use epoch number to ensure different shuffling each epoch while maintaining reproducibility
+        epoch_generator.manual_seed(config.training.data_seed + epoch_num)
+    else:
+        # Use a random seed if no data_seed specified
+        epoch_generator.manual_seed(int(torch.randint(0, 2**32 - 1, (1,)).item()))
+
+    return DataLoader(
+        train_dataset,
+        batch_size=config.training.per_device_train_batch_size,
+        shuffle=True,
+        collate_fn=data_collator,
+        num_workers=config.training.dataloader_num_workers,
+        drop_last=config.training.dataloader_drop_last,
+        worker_init_fn=seed_worker,
+        generator=epoch_generator,
+    )
+
+
 def main():
     parser = argparse.ArgumentParser(description="Train OlmOCR model")
     parser.add_argument("--config", type=str, default="olmocr/train/configs/example_config.yaml", help="Path to YAML configuration file")
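The core of this new helper is the generator argument: a shuffling DataLoader draws its permutation from the generator it is given, so seeding with data_seed + epoch_num makes every epoch shuffle differently while each epoch's order stays reproducible. A minimal standalone sketch of that behavior, with a toy dataset and an illustrative seed of 42 (neither taken from the repo):

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(8))  # toy dataset, illustrative only
data_seed = 42                            # illustrative seed, not from the config

for epoch_num in range(2):
    gen = torch.Generator()
    gen.manual_seed(data_seed + epoch_num)  # same scheme as create_train_dataloader
    loader = DataLoader(dataset, batch_size=4, shuffle=True, generator=gen)
    order = [int(x) for batch in loader for x in batch[0]]
    print(f"epoch {epoch_num}: {order}")  # stable across reruns, different across epochs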
@@ -313,12 +353,6 @@ def main():
         random.seed(worker_seed)
 
-    # Create generator for data loader
-    generator = None
-    if config.training.data_seed is not None:
-        generator = torch.Generator()
-        generator.manual_seed(config.training.data_seed)
-
     # Device setup
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     model.to(device)
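For context on this deletion: the old code built one torch.Generator at startup and handed it to a single long-lived DataLoader. Because each iter() of a shuffling DataLoader draws the next permutation from the generator's current state, reproducing epoch N's order required replaying every earlier epoch's draws, which falls apart once training resumes mid-run. A sketch of the difference (my reading of the change; the seed of 42 is illustrative):

import torch

# Old pattern: one long-lived generator; epoch 1's order depends on prior draws.
g = torch.Generator()
g.manual_seed(42)
epoch0 = torch.randperm(8, generator=g)  # roughly what a shuffling sampler draws
epoch1 = torch.randperm(8, generator=g)  # only reachable by consuming epoch 0 first

# New pattern: per-epoch seed; epoch 1's order is reconstructible in isolation.
g1 = torch.Generator()
g1.manual_seed(42 + 1)                   # data_seed + epoch_num
epoch1_direct = torch.randperm(8, generator=g1)

print(epoch1.tolist(), epoch1_direct.tolist())  # different sequences by design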
@@ -418,16 +452,14 @@ def main():
     best_metric = state["best_metric"]
     samples_seen = state["samples_seen"]
 
-    # Create dataloaders
-    train_dataloader = DataLoader(
+    # Create dataloaders - use epoch 0 initially (will be recreated with proper epoch if resuming)
+    current_epoch_num = int(samples_seen / len(train_dataset)) if samples_seen > 0 else 0
+    train_dataloader = create_train_dataloader(
         train_dataset,
-        batch_size=config.training.per_device_train_batch_size,
-        shuffle=True,
-        collate_fn=data_collator,
-        num_workers=config.training.dataloader_num_workers,
-        drop_last=config.training.dataloader_drop_last,
-        worker_init_fn=seed_worker,
-        generator=generator,
+        config,
+        data_collator,
+        seed_worker,
+        epoch_num=current_epoch_num,
     )
 
     eval_dataloaders = {
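The current_epoch_num expression recovers which epoch a checkpoint stopped in from the running sample count; for non-negative values, int(a / b) is just floor division. Worked through with hypothetical checkpoint numbers (not from any real run):

samples_seen = 2500  # hypothetical value restored from a checkpoint
dataset_len = 1000   # hypothetical len(train_dataset)

current_epoch_num = int(samples_seen / dataset_len) if samples_seen > 0 else 0
samples_to_skip = samples_seen % dataset_len

print(current_epoch_num, samples_to_skip)  # 2 500: epoch 2, 500 samples into it

For very large sample counts, samples_seen // len(train_dataset) would avoid float rounding entirely; the int(a / b) form is fine at realistic scales.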
@@ -467,14 +499,26 @@ def main():
         samples_to_skip = samples_seen % len(train_dataset)
         batches_to_skip = samples_to_skip // config.training.per_device_train_batch_size
         logger.info(f"Resuming training: skipping {batches_to_skip} batches ({samples_to_skip} samples) to reach position {samples_seen}")
 
+        # Skip batches to resume from the correct position within the epoch
         for _ in range(batches_to_skip):
             try:
                 next(epoch_iterator)
             except StopIteration:
-                # We've reached the end of the epoch while skipping, create new iterator
+                # We've reached the end of the epoch while skipping
+                # This shouldn't normally happen, but handle it gracefully
+                logger.warning(f"Reached end of epoch while skipping batches. Creating new epoch.")
+                current_epoch_num += 1
+                train_dataloader = create_train_dataloader(
+                    train_dataset,
+                    config,
+                    data_collator,
+                    seed_worker,
+                    epoch_num=current_epoch_num,
+                )
                 epoch_iterator = iter(train_dataloader)
                 break
 
     # Create progress bar
     pbar = tqdm(total=max_train_samples - samples_seen, desc=f"Training from step {global_step}", unit="samples")
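The resume loop fast-forwards the iterator by drawing and discarding already-seen batches. A toy sketch of the same pattern (the range iterator stands in for iter(train_dataloader); the counts are made up):

epoch_iterator = iter(range(10))  # stand-in for iter(train_dataloader)
batches_to_skip = 3               # made-up resume position

for _ in range(batches_to_skip):
    try:
        next(epoch_iterator)      # discard a batch we already trained on
    except StopIteration:
        # The real script logs a warning and rebuilds the dataloader here.
        break

print(next(epoch_iterator))  # 3: the first batch not yet seen

One cost of this pattern is that skipped batches are still loaded and collated, so resuming deep into a large epoch pays real data-loading time.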
@@ -482,9 +526,21 @@ def main():
         try:
             batch = next(epoch_iterator)
         except StopIteration:
-            # End of epoch, create new iterator
+            # End of epoch, create new dataloader with fresh shuffle
            current_epoch = samples_seen / len(train_dataset)
            logger.info(f"Completed epoch {current_epoch:.2f}")
+
+            # Increment epoch number for new shuffle seed
+            current_epoch_num += 1
+
+            # Recreate dataloader with new generator for fresh shuffle
+            train_dataloader = create_train_dataloader(
+                train_dataset,
+                config,
+                data_collator,
+                seed_worker,
+                epoch_num=current_epoch_num,
+            )
             epoch_iterator = iter(train_dataloader)
             batch = next(epoch_iterator)
 
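Putting the pieces together, the training loop now treats the dataloader as a rebuild-on-exhaustion stream: when an epoch runs out, bump the epoch counter, rebuild the dataloader with the new seed, and keep drawing batches. A compact sketch of that control flow with toy stand-ins (the factory below plays the role of create_train_dataloader):

def make_loader(epoch_num):
    # Stand-in for create_train_dataloader(...): three "batches" per epoch.
    return iter([f"epoch{epoch_num}-batch{i}" for i in range(3)])

current_epoch_num = 0
epoch_iterator = make_loader(current_epoch_num)

for _ in range(7):  # stand-in for the samples_seen < max_train_samples loop
    try:
        batch = next(epoch_iterator)
    except StopIteration:
        current_epoch_num += 1                         # new shuffle seed next epoch
        epoch_iterator = make_loader(current_epoch_num)
        batch = next(epoch_iterator)
    print(batch)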