Fixes

This commit is contained in:
parent 94c78b4f37
commit d0efc70de6
@@ -13,7 +13,7 @@ import numpy as np
 import torch
 from torch.utils.data import ConcatDataset, DataLoader
 from torch.optim import AdamW
-from torch.cuda.amp import autocast
+from torch.amp import autocast
 import wandb
 
 from transformers import (
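Background for the import change: recent PyTorch releases deprecate torch.cuda.amp.autocast in favor of torch.amp.autocast, which takes an explicit device_type argument. A minimal sketch of the updated usage (the linear layer and tensor sizes are arbitrary placeholders, and a CUDA device is assumed):

import torch
from torch.amp import autocast  # new location; torch.cuda.amp.autocast is deprecated

model = torch.nn.Linear(8, 8).to("cuda")   # placeholder module, assumes a CUDA device
x = torch.randn(4, 8, device="cuda")

# device_type must now be passed explicitly; dtype selects bf16 mixed precision
with autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
    y = model(x)
print(y.dtype)  # torch.bfloat16 inside the autocast region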
@@ -166,7 +166,7 @@ def evaluate_model(
     with torch.no_grad():
         for batch in dataloader:
             batch = {k: v.to(device) for k, v in batch.items()}
-            with autocast(enabled=True, dtype=torch.bfloat16):
+            with autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
                 outputs = model(**batch)
                 total_loss += outputs.loss.item()
                 num_batches += 1
@@ -405,7 +405,7 @@ def main():
         for batch_idx, batch in enumerate(train_dataloader):
             batch = {k: v.to(device) for k, v in batch.items()}
 
-            with autocast(enabled=True, dtype=torch.bfloat16):
+            with autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
                 outputs = model(**batch)
                 loss = outputs.loss / config.training.gradient_accumulation_steps
             loss.backward()
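The surrounding training loop divides the loss by config.training.gradient_accumulation_steps before backward, so gradients from several micro-batches sum to the equivalent of one large batch. A self-contained sketch of that pattern, using toy stand-ins for the real model, data, and config (assumes a CUDA device):

import torch
from torch.amp import autocast
from torch.optim import AdamW

model = torch.nn.Linear(8, 1).to("cuda")             # toy stand-in for the real model
optimizer = AdamW(model.parameters(), lr=1e-3)
batches = [torch.randn(4, 8, device="cuda") for _ in range(8)]
accum_steps = 4                                      # stands in for config.training.gradient_accumulation_steps

optimizer.zero_grad()
for batch_idx, x in enumerate(batches):
    with autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
        loss = model(x).pow(2).mean() / accum_steps  # scale so summed grads match one full batch
    loss.backward()                                  # gradients accumulate across micro-batches
    if (batch_idx + 1) % accum_steps == 0:
        optimizer.step()                             # one optimizer update per accumulation window
        optimizer.zero_grad()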
@@ -439,8 +439,9 @@ def main():
                 train_loss = 0.0
                 num_batches = 0
 
-            # Evaluation
-            if config.training.eval_steps > 0 and global_step % config.training.eval_steps == 0:
+            # Evaluation (only after gradient accumulation is complete)
+            if (batch_idx + 1) % config.training.gradient_accumulation_steps == 0 and \
+                    config.training.eval_steps > 0 and global_step % config.training.eval_steps == 0 and global_step > 0:
                 metrics = evaluate_model(model, eval_dataloaders, device)
                 logger.info(f"Evaluation at step {global_step}: {metrics}")
                 if "wandb" in config.training.report_to:
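The reworked condition only runs evaluation on the micro-batch that closes a gradient-accumulation window, on an eval_steps boundary, and never at step 0, so metrics are not computed while gradients are still being accumulated. The same predicate restated as a standalone helper (the function name and the concrete numbers below are illustrative, not part of the script):

def should_evaluate(batch_idx: int, global_step: int, accum_steps: int, eval_steps: int) -> bool:
    # True only right after an optimizer update, on an eval_steps boundary, and never at step 0.
    return (
        (batch_idx + 1) % accum_steps == 0
        and eval_steps > 0
        and global_step > 0
        and global_step % eval_steps == 0
    )

# With accum_steps=4 and eval_steps=100: fires on the batch that completes a window at step 100,
# but not mid-window and not at step 0.
assert should_evaluate(batch_idx=3, global_step=100, accum_steps=4, eval_steps=100)
assert not should_evaluate(batch_idx=2, global_step=100, accum_steps=4, eval_steps=100)
assert not should_evaluate(batch_idx=3, global_step=0, accum_steps=4, eval_steps=100)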