This commit is contained in:
Jake Poznanski 2025-07-22 20:48:46 +00:00
parent 94c78b4f37
commit d0efc70de6

View File

@@ -13,7 +13,7 @@ import numpy as np
 import torch
 from torch.utils.data import ConcatDataset, DataLoader
 from torch.optim import AdamW
-from torch.cuda.amp import autocast
+from torch.amp import autocast
 import wandb
 from transformers import (
@@ -166,7 +166,7 @@ def evaluate_model(
     with torch.no_grad():
         for batch in dataloader:
             batch = {k: v.to(device) for k, v in batch.items()}
-            with autocast(enabled=True, dtype=torch.bfloat16):
+            with autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
                 outputs = model(**batch)
                 total_loss += outputs.loss.item()
                 num_batches += 1
@@ -405,7 +405,7 @@ def main():
         for batch_idx, batch in enumerate(train_dataloader):
             batch = {k: v.to(device) for k, v in batch.items()}
-            with autocast(enabled=True, dtype=torch.bfloat16):
+            with autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
                 outputs = model(**batch)
                 loss = outputs.loss / config.training.gradient_accumulation_steps
                 loss.backward()
@@ -439,8 +439,9 @@ def main():
                 train_loss = 0.0
                 num_batches = 0

-            # Evaluation
-            if config.training.eval_steps > 0 and global_step % config.training.eval_steps == 0:
+            # Evaluation (only after gradient accumulation is complete)
+            if (batch_idx + 1) % config.training.gradient_accumulation_steps == 0 and \
+               config.training.eval_steps > 0 and global_step % config.training.eval_steps == 0 and global_step > 0:
                 metrics = evaluate_model(model, eval_dataloaders, device)
                 logger.info(f"Evaluation at step {global_step}: {metrics}")
                 if "wandb" in config.training.report_to: