From d0efc70de6b21168e1cbe71c8d7d8bfda40ca02e Mon Sep 17 00:00:00 2001
From: Jake Poznanski
Date: Tue, 22 Jul 2025 20:48:46 +0000
Subject: [PATCH] Fixes

---
 olmocr/train/train.py | 11 ++++++-----
 1 file changed, 6 insertions(+), 5 deletions(-)

diff --git a/olmocr/train/train.py b/olmocr/train/train.py
index 847b95f..8618bed 100644
--- a/olmocr/train/train.py
+++ b/olmocr/train/train.py
@@ -13,7 +13,7 @@ import numpy as np
 import torch
 from torch.utils.data import ConcatDataset, DataLoader
 from torch.optim import AdamW
-from torch.cuda.amp import autocast
+from torch.amp import autocast
 import wandb
 
 from transformers import (
@@ -166,7 +166,7 @@ def evaluate_model(
     with torch.no_grad():
         for batch in dataloader:
             batch = {k: v.to(device) for k, v in batch.items()}
-            with autocast(enabled=True, dtype=torch.bfloat16):
+            with autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
                 outputs = model(**batch)
                 total_loss += outputs.loss.item()
                 num_batches += 1
@@ -405,7 +405,7 @@ def main():
         for batch_idx, batch in enumerate(train_dataloader):
             batch = {k: v.to(device) for k, v in batch.items()}
 
-            with autocast(enabled=True, dtype=torch.bfloat16):
+            with autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
                 outputs = model(**batch)
                 loss = outputs.loss / config.training.gradient_accumulation_steps
                 loss.backward()
@@ -439,8 +439,9 @@ def main():
                     train_loss = 0.0
                     num_batches = 0
 
-            # Evaluation
-            if config.training.eval_steps > 0 and global_step % config.training.eval_steps == 0:
+            # Evaluation (only after gradient accumulation is complete)
+            if (batch_idx + 1) % config.training.gradient_accumulation_steps == 0 and \
+               config.training.eval_steps > 0 and global_step % config.training.eval_steps == 0 and global_step > 0:
                 metrics = evaluate_model(model, eval_dataloaders, device)
                 logger.info(f"Evaluation at step {global_step}: {metrics}")
                 if "wandb" in config.training.report_to:
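
For context, the patch makes two changes: it migrates from the deprecated `torch.cuda.amp.autocast` to `torch.amp.autocast`, which takes an explicit `device_type` argument, and it gates evaluation so it only fires on completed optimizer steps rather than mid-accumulation. Below is a minimal standalone sketch of that pattern; names such as `train_loop`, `accum_steps`, and `eval_every` are illustrative stand-ins and not the actual `train.py` configuration.

```python
# Sketch of the pattern this patch adopts: torch.amp.autocast with an explicit
# device_type, plus evaluation gated on completed optimizer steps.
import torch
from torch.amp import autocast  # replaces the deprecated torch.cuda.amp.autocast


def train_loop(model, train_dataloader, optimizer, device="cuda",
               accum_steps=4, eval_every=100):
    global_step = 0
    for batch_idx, batch in enumerate(train_dataloader):
        batch = {k: v.to(device) for k, v in batch.items()}

        # The new API requires device_type; enabled/dtype keep the same meaning.
        with autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
            outputs = model(**batch)
            loss = outputs.loss / accum_steps  # scale for gradient accumulation
        loss.backward()

        # Step the optimizer (and count a global step) only when the
        # accumulation window is full.
        if (batch_idx + 1) % accum_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
            global_step += 1

            # Evaluate only on completed optimizer steps, never mid-accumulation
            # and never at step 0.
            if eval_every > 0 and global_step % eval_every == 0 and global_step > 0:
                pass  # e.g. evaluate_model(model, eval_dataloaders, device)
```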