commit d0efc70de6
parent 94c78b4f37
Author: Jake Poznanski
Date:   2025-07-22 20:48:46 +00:00


@@ -13,7 +13,7 @@ import numpy as np
 import torch
 from torch.utils.data import ConcatDataset, DataLoader
 from torch.optim import AdamW
-from torch.cuda.amp import autocast
+from torch.amp import autocast
 import wandb
 from transformers import (
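Note: torch.cuda.amp.autocast is deprecated in recent PyTorch releases in favor of the unified torch.amp.autocast, which requires an explicit device_type argument. A minimal before/after sketch of the change (standalone, not part of this file):

    import torch
    from torch.amp import autocast

    # Old, CUDA-specific API (now emits a deprecation warning):
    #   with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):
    #       ...

    # New, unified API: device_type must be passed explicitly
    with autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
        ...  # forward pass runs in bfloat16 where supported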
@@ -166,7 +166,7 @@ def evaluate_model(
     with torch.no_grad():
         for batch in dataloader:
             batch = {k: v.to(device) for k, v in batch.items()}
-            with autocast(enabled=True, dtype=torch.bfloat16):
+            with autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
                 outputs = model(**batch)
             total_loss += outputs.loss.item()
             num_batches += 1
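For context, the evaluation loop this hunk touches follows the usual no_grad + autocast pattern. A hedged sketch of one dataloader's pass (the helper name and return value are illustrative, not the repo's exact code):

    def eval_one_loader(model, dataloader, device):  # illustrative helper, not from this file
        model.eval()
        total_loss, num_batches = 0.0, 0
        with torch.no_grad():
            for batch in dataloader:
                batch = {k: v.to(device) for k, v in batch.items()}
                # bfloat16 autocast keeps eval numerics consistent with training
                with autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
                    outputs = model(**batch)
                total_loss += outputs.loss.item()
                num_batches += 1
        return total_loss / max(num_batches, 1)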
@@ -405,7 +405,7 @@ def main():
         for batch_idx, batch in enumerate(train_dataloader):
             batch = {k: v.to(device) for k, v in batch.items()}
-            with autocast(enabled=True, dtype=torch.bfloat16):
+            with autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
                 outputs = model(**batch)
             loss = outputs.loss / config.training.gradient_accumulation_steps
             loss.backward()
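The training loop divides the loss by gradient_accumulation_steps so the accumulated gradient averages over the micro-batches, and the optimizer steps only on accumulation boundaries. A condensed sketch of that pattern (optimizer, scheduler, and global_step handling are assumptions, not shown in this hunk):

    for batch_idx, batch in enumerate(train_dataloader):
        batch = {k: v.to(device) for k, v in batch.items()}
        with autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
            outputs = model(**batch)
        # Scale so the summed gradients average over the accumulation window
        loss = outputs.loss / config.training.gradient_accumulation_steps
        loss.backward()
        # Step (and advance global_step) only once per full window
        if (batch_idx + 1) % config.training.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
            global_step += 1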
@@ -439,8 +439,9 @@ def main():
                     train_loss = 0.0
                     num_batches = 0
 
-            # Evaluation
-            if config.training.eval_steps > 0 and global_step % config.training.eval_steps == 0:
+            # Evaluation (only after gradient accumulation is complete)
+            if (batch_idx + 1) % config.training.gradient_accumulation_steps == 0 and \
+                    config.training.eval_steps > 0 and global_step % config.training.eval_steps == 0 and global_step > 0:
                 metrics = evaluate_model(model, eval_dataloaders, device)
                 logger.info(f"Evaluation at step {global_step}: {metrics}")
                 if "wandb" in config.training.report_to: