Grok was asked to drop the HF Trainer and implement a custom training loop

Jake Poznanski 2025-07-21 22:33:50 +00:00
parent da6bc458cd
commit 28a207b912

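In outline, the commit removes the Trainer/TrainingArguments setup and replaces it with a hand-written loop: build a DataLoader, an AdamW optimizer and a get_scheduler schedule, then iterate over batches with gradient accumulation, gradient clipping, periodic logging, evaluation, checkpointing, and early stopping. The snippet below is a minimal, self-contained sketch of that pattern only; the model, data, and scheduler are toy stand-ins, not the Qwen2.5-VL model, OlmOCR datasets, bf16 autocast, or checkpoint logic wired up in the diff that follows.

import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader, TensorDataset

# Toy stand-ins for the real model/dataset so the sketch runs on its own
model = torch.nn.Linear(4, 1)
dataset = TensorDataset(torch.randn(64, 4), torch.randn(64, 1))
loader = DataLoader(dataset, batch_size=8, shuffle=True)

optimizer = AdamW(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, total_iters=16)
accum_steps, max_grad_norm, global_step = 2, 1.0, 0

model.train()
for epoch in range(2):
    for step, (x, y) in enumerate(loader):
        # Scale the loss so accumulated gradients average over accum_steps micro-batches
        loss = torch.nn.functional.mse_loss(model(x), y) / accum_steps
        loss.backward()
        if (step + 1) % accum_steps == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            global_step += 1  # the full script logs, evaluates, and checkpoints on these boundaries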

@@ -5,20 +5,22 @@ Simple script to test OlmOCR dataset loading with YAML configuration.
 import argparse
 import logging
 import os
+import math
+import shutil
+import time
 import numpy as np
 import torch
-from torch.utils.data import ConcatDataset
+from torch.utils.data import ConcatDataset, DataLoader
 from transformers import (
     AutoProcessor,
-    EarlyStoppingCallback,
+    get_scheduler,
+    AdamW,
     Qwen2_5_VLForConditionalGeneration,
     Qwen2VLForConditionalGeneration,
-    Trainer,
-    TrainingArguments,
 )
-from typing import Optional
+from typing import Optional, Dict, Any
 from olmocr.train.config import Config
 from olmocr.train.dataloader import BaseMarkdownPDFDataset
@@ -82,6 +84,100 @@ class QwenDataCollator:
         }
 
+
+def save_checkpoint(
+    model: torch.nn.Module,
+    optimizer: torch.optim.Optimizer,
+    lr_scheduler: Any,
+    epoch: int,
+    global_step: int,
+    best_metric: float,
+    output_dir: str,
+    save_total_limit: Optional[int] = None,
+):
+    """Save model, optimizer, scheduler, and training state."""
+    checkpoint_dir = os.path.join(output_dir, f"checkpoint-{global_step}")
+    os.makedirs(checkpoint_dir, exist_ok=True)
+
+    # Save model
+    model.save_pretrained(checkpoint_dir)
+
+    # Save optimizer and scheduler
+    torch.save(optimizer.state_dict(), os.path.join(checkpoint_dir, "optimizer.pt"))
+    torch.save(lr_scheduler.state_dict(), os.path.join(checkpoint_dir, "scheduler.pt"))
+
+    # Save training state
+    state = {
+        "epoch": epoch,
+        "global_step": global_step,
+        "best_metric": best_metric,
+    }
+    torch.save(state, os.path.join(checkpoint_dir, "training_state.pt"))
+
+    logger.info(f"Saved checkpoint to {checkpoint_dir}")
+
+    # Enforce save_total_limit by removing oldest checkpoints
+    if save_total_limit is not None and save_total_limit > 0:
+        checkpoints = sorted(
+            [d for d in os.listdir(output_dir) if d.startswith("checkpoint-")],
+            key=lambda x: int(x.split("-")[1]),
+        )
+        while len(checkpoints) > save_total_limit:
+            oldest = checkpoints.pop(0)
+            shutil.rmtree(os.path.join(output_dir, oldest))
+            logger.info(f"Deleted old checkpoint: {oldest}")
+
+
+def load_checkpoint(
+    model: torch.nn.Module,
+    optimizer: torch.optim.Optimizer,
+    lr_scheduler: Any,
+    checkpoint_dir: str,
+) -> Dict[str, Any]:
+    """Load model, optimizer, scheduler, and training state from checkpoint."""
+    # Reload weights saved via save_pretrained into the existing model instance
+    model.load_state_dict(type(model).from_pretrained(checkpoint_dir).state_dict())
+    optimizer.load_state_dict(torch.load(os.path.join(checkpoint_dir, "optimizer.pt")))
+    lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint_dir, "scheduler.pt")))
+    state = torch.load(os.path.join(checkpoint_dir, "training_state.pt"))
+    logger.info(f"Resumed from checkpoint: {checkpoint_dir} at epoch {state['epoch']}, step {state['global_step']}")
+    return state
+
+
+def evaluate_model(
+    model: torch.nn.Module,
+    eval_dataloaders: Dict[str, DataLoader],
+    device: torch.device,
+    autocast_ctx: Any,  # autocast context manager class, used with bf16
+) -> Dict[str, float]:
+    """Evaluate on all eval datasets and return average loss per dataset."""
+    model.eval()
+    eval_metrics = {}
+    for dataset_name, dataloader in eval_dataloaders.items():
+        total_loss = 0.0
+        num_batches = 0
+        with torch.no_grad():
+            for batch in dataloader:
+                batch = {k: v.to(device) for k, v in batch.items()}
+                with autocast_ctx(dtype=torch.bfloat16):
+                    outputs = model(**batch)
+                total_loss += outputs.loss.item()
+                num_batches += 1
+        avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
+        eval_metrics[f"eval_{dataset_name}_loss"] = avg_loss
+        logger.info(f"Eval {dataset_name} loss: {avg_loss:.4f}")
+
+    # Compute overall eval loss as average across datasets (or customize as needed)
+    if eval_metrics:
+        overall_loss = sum(eval_metrics.values()) / len(eval_metrics)
+        eval_metrics["eval_loss"] = overall_loss
+
+    return eval_metrics
+
+
 def main():
     parser = argparse.ArgumentParser(description="Train OlmOCR model")
     parser.add_argument("--config", type=str, default="olmocr/train/configs/example_config.yaml", help="Path to YAML configuration file")
@@ -104,6 +200,11 @@ def main():
         os.environ["WANDB_PROJECT"] = config.project_name
         logger.info(f"Setting WANDB_PROJECT to: {config.project_name}")
 
+    # Initialize wandb if reporting to it
+    if "wandb" in config.training.report_to:
+        import wandb
+        wandb.init(project=config.project_name, name=config.run_name, config=config.to_dict())
+
     # Load processor for tokenization
     logger.info(f"Loading processor: {config.model.name}")
     processor = AutoProcessor.from_pretrained(
@@ -177,9 +278,10 @@ def main():
     # Construct full output directory by appending run_name to base output_dir
     full_output_dir = os.path.join(config.training.output_dir, config.run_name)
     logger.info(f"Setting output directory to: {full_output_dir}")
+    os.makedirs(full_output_dir, exist_ok=True)
 
     # Check for existing checkpoints if any
-    found_resumable_checkpoint = False
+    found_resumable_checkpoint = None
     if os.path.exists(full_output_dir):
         # Look for checkpoint directories
         checkpoint_dirs = [d for d in os.listdir(full_output_dir) if d.startswith("checkpoint-") and os.path.isdir(os.path.join(full_output_dir, d))]
@@ -192,80 +294,197 @@ def main():
         else:
             logger.info("No existing checkpoints found in output directory")
 
-    # Set up training arguments
-    training_args = TrainingArguments(
-        output_dir=full_output_dir,
-        num_train_epochs=config.training.num_train_epochs,
-        per_device_train_batch_size=config.training.per_device_train_batch_size,
-        per_device_eval_batch_size=config.training.per_device_eval_batch_size,
-        gradient_accumulation_steps=config.training.gradient_accumulation_steps,
-        learning_rate=float(config.training.learning_rate),
-        lr_scheduler_type=config.training.lr_scheduler_type,
-        warmup_ratio=config.training.warmup_ratio,
-        lr_scheduler_kwargs=config.training.lr_scheduler_kwargs,
-        optim=config.training.optim,
-        adam_beta1=config.training.adam_beta1,
-        adam_beta2=config.training.adam_beta2,
-        adam_epsilon=config.training.adam_epsilon,
-        weight_decay=config.training.weight_decay,
-        max_grad_norm=config.training.max_grad_norm,
-        bf16=True,  # We're sticking with this known good reduced precision option
-        eval_strategy=config.training.evaluation_strategy,
-        eval_steps=config.training.eval_steps,
-        save_strategy=config.training.save_strategy,
-        save_steps=config.training.save_steps,
-        save_total_limit=config.training.save_total_limit,
-        load_best_model_at_end=config.training.load_best_model_at_end,
-        metric_for_best_model=config.training.metric_for_best_model,
-        greater_is_better=config.training.greater_is_better,
-        logging_dir=config.training.logging_dir,
-        logging_strategy=config.training.logging_strategy,
-        logging_steps=config.training.logging_steps,
-        logging_first_step=config.training.logging_first_step,
-        report_to=config.training.report_to,
-        seed=config.training.seed,
-        data_seed=config.training.data_seed,
-        push_to_hub=False,
-        label_names=["labels"],
-        dataloader_drop_last=config.training.dataloader_drop_last,
-        dataloader_num_workers=config.training.dataloader_num_workers,
-        remove_unused_columns=config.training.remove_unused_columns,
-        eval_on_start=True,
-        run_name=config.run_name,
-    )
+    # Set seeds
+    torch.manual_seed(config.training.seed)
+    if config.training.data_seed is not None:
+        np.random.seed(config.training.data_seed)
+
-    # Set up callbacks
-    callbacks = []
-    if config.training.use_early_stopping:
-        callbacks.append(
-            EarlyStoppingCallback(
-                early_stopping_patience=config.training.early_stopping_patience, early_stopping_threshold=config.training.early_stopping_threshold
-            )
-        )
+    # Device setup
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    model.to(device)
+
+    # Set up optimizer
+    no_decay = ["bias", "LayerNorm.weight"]
+    optimizer_grouped_parameters = [
+        {
+            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
+            "weight_decay": config.training.weight_decay,
+        },
+        {
+            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
+            "weight_decay": 0.0,
+        },
+    ]
+    if config.training.optim == "adamw_torch":
+        optimizer = AdamW(
+            optimizer_grouped_parameters,
+            lr=float(config.training.learning_rate),
+            betas=(config.training.adam_beta1, config.training.adam_beta2),
+            eps=config.training.adam_epsilon,
+        )
+    else:
+        raise NotImplementedError(f"Optimizer {config.training.optim} not supported in custom loop")
+
-    # Initialize trainer
-    logger.info("Initializing trainer...")
-    trainer = Trainer(
-        model=model,
-        args=training_args,
-        train_dataset=train_dataset,
-        eval_dataset=eval_datasets,
-        data_collator=QwenDataCollator(max_token_len=config.training.collator_max_token_len),
-        callbacks=callbacks,
-    )
-
-    # Start training
+    # Total training steps calculation
+    num_update_steps_per_epoch = math.ceil(len(train_dataset) / (config.training.per_device_train_batch_size * config.training.gradient_accumulation_steps))
+    max_train_steps = int(config.training.num_train_epochs * num_update_steps_per_epoch)
+
+    # Set up scheduler
+    lr_scheduler = get_scheduler(
+        name=config.training.lr_scheduler_type,
+        optimizer=optimizer,
+        num_warmup_steps=int(max_train_steps * config.training.warmup_ratio),
+        num_training_steps=max_train_steps,
+        scheduler_specific_kwargs=config.training.lr_scheduler_kwargs,
+    )
+
+    # Set up mixed precision (bf16)
+    from torch.cuda.amp import GradScaler, autocast
+    amp_scaler = GradScaler(enabled=True)  # For bf16, but note: bf16 doesn't need scaling like fp16
+
+    # Data collator
+    data_collator = QwenDataCollator(max_token_len=config.training.collator_max_token_len)
+
+    # Create dataloaders
+    train_dataloader = DataLoader(
+        train_dataset,
+        batch_size=config.training.per_device_train_batch_size,
+        shuffle=True,
+        collate_fn=data_collator,
+        num_workers=config.training.dataloader_num_workers,
+        drop_last=config.training.dataloader_drop_last,
+    )
+    eval_dataloaders = {
+        name: DataLoader(
+            dataset,
+            batch_size=config.training.per_device_eval_batch_size,
+            shuffle=False,
+            collate_fn=data_collator,
+            num_workers=config.training.dataloader_num_workers,
+            drop_last=False,
+        )
+        for name, dataset in eval_datasets.items()
+    }
+
+    # Resume from checkpoint if available
+    start_epoch = 0
+    global_step = 0
+    best_metric = -float("inf") if config.training.greater_is_better else float("inf")
+    best_metric_key = config.training.metric_for_best_model  # e.g., "eval_loss"
+    if found_resumable_checkpoint:
+        state = load_checkpoint(model, optimizer, lr_scheduler, found_resumable_checkpoint)
+        start_epoch = state["epoch"] + 1  # Start from next epoch
+        global_step = state["global_step"]
+        best_metric = state["best_metric"]
+
+    # Early stopping setup
+    patience_counter = 0
+    early_stopping_patience = config.training.early_stopping_patience if config.training.use_early_stopping else float("inf")
+    early_stopping_threshold = config.training.early_stopping_threshold
+
+    # Evaluate on start if configured
+    if config.training.eval_on_start:
+        metrics = evaluate_model(model, eval_dataloaders, device, autocast)
+        logger.info(f"Initial evaluation: {metrics}")
+        if "wandb" in config.training.report_to:
+            wandb.log(metrics, step=global_step)
+
+    # Main training loop
logger.info("Starting training...") logger.info("Starting training...")
train_result = trainer.train(resume_from_checkpoint=found_resumable_checkpoint) model.train()
+    for epoch in range(start_epoch, int(config.training.num_train_epochs)):
+        epoch_start_time = time.time()
+        train_loss = 0.0
+        num_batches = 0
+
+        for batch_idx, batch in enumerate(train_dataloader):
+            batch = {k: v.to(device) for k, v in batch.items()}
+
+            with autocast(dtype=torch.bfloat16):
+                outputs = model(**batch)
+                loss = outputs.loss / config.training.gradient_accumulation_steps
+
+            amp_scaler.scale(loss).backward()
+            train_loss += loss.item() * config.training.gradient_accumulation_steps
+            num_batches += 1
+
+            if (batch_idx + 1) % config.training.gradient_accumulation_steps == 0:
+                # Clip gradients
+                amp_scaler.unscale_(optimizer)
+                torch.nn.utils.clip_grad_norm_(model.parameters(), config.training.max_grad_norm)
+
+                # Step optimizer and scheduler
+                amp_scaler.step(optimizer)
+                amp_scaler.update()
+                lr_scheduler.step()
+                optimizer.zero_grad()
+                global_step += 1
+
+                # Logging
+                if config.training.logging_steps > 0 and global_step % config.training.logging_steps == 0:
+                    avg_train_loss = train_loss / num_batches
+                    logs = {
+                        "train_loss": avg_train_loss,
+                        "learning_rate": lr_scheduler.get_last_lr()[0],
+                        "epoch": epoch + (batch_idx / len(train_dataloader)),
+                    }
+                    logger.info(f"Step {global_step}: {logs}")
+                    if "wandb" in config.training.report_to:
+                        wandb.log(logs, step=global_step)
+                    train_loss = 0.0
+                    num_batches = 0
+
+                # Evaluation
+                if config.training.eval_steps > 0 and global_step % config.training.eval_steps == 0:
+                    metrics = evaluate_model(model, eval_dataloaders, device, autocast)
+                    logger.info(f"Evaluation at step {global_step}: {metrics}")
+                    if "wandb" in config.training.report_to:
+                        wandb.log(metrics, step=global_step)
+
+                    # Early stopping check
+                    current_metric = metrics.get(best_metric_key, None)
+                    if current_metric is not None:
+                        if (config.training.greater_is_better and current_metric > best_metric + early_stopping_threshold) or \
+                           (not config.training.greater_is_better and current_metric < best_metric - early_stopping_threshold):
+                            best_metric = current_metric
+                            patience_counter = 0
+                            if config.training.load_best_model_at_end:
+                                # Save best model (optional: implement loading best at end)
+                                pass
+                        else:
+                            patience_counter += 1
+                            if patience_counter >= early_stopping_patience:
+                                logger.info(f"Early stopping at step {global_step}")
+                                break
+
+                    # Put the model back into train mode after evaluation
+                    model.train()
+
+                # Saving
+                if config.training.save_steps > 0 and global_step % config.training.save_steps == 0:
+                    save_checkpoint(
+                        model, optimizer, lr_scheduler, epoch, global_step, best_metric,
+                        full_output_dir, config.training.save_total_limit,
+                    )
+
+        # End of epoch logging
+        epoch_time = time.time() - epoch_start_time
+        logger.info(f"Epoch {epoch} completed in {epoch_time:.2f}s")
+
+        if patience_counter >= early_stopping_patience:
+            break
+
     # Save the final model
     logger.info("Saving final model...")
-    trainer.save_model()
-    trainer.save_state()
-
-    # Log metrics
-    logger.info(f"Training completed! Metrics: {train_result.metrics}")
+    model.save_pretrained(full_output_dir)
+
+    # Final evaluation
+    final_metrics = evaluate_model(model, eval_dataloaders, device, autocast)
+    logger.info(f"Training completed! Final metrics: {final_metrics}")
+    if "wandb" in config.training.report_to:
+        wandb.log(final_metrics, step=global_step)
+        wandb.finish()
 
 if __name__ == "__main__":
     main()