Grok was asked to drop the HF Trainer and implement a custom training loop

This commit is contained in:
Jake Poznanski 2025-07-21 22:33:50 +00:00
parent da6bc458cd
commit 28a207b912

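The new loop follows the plain PyTorch pattern: forward pass under a bf16 autocast context, scale the loss by the accumulation factor, then clip gradients and step the optimizer every gradient_accumulation_steps micro-batches. A minimal, self-contained sketch of that pattern with a toy model and synthetic data (hypothetical names, not the olmocr code itself):

import torch
from torch.utils.data import DataLoader, TensorDataset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.nn.Linear(8, 1).to(device)  # toy stand-in for the real VLM
loader = DataLoader(TensorDataset(torch.randn(64, 8), torch.randn(64, 1)), batch_size=4, shuffle=True)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
grad_accum = 4

model.train()
for step, (x, y) in enumerate(loader):
    x, y = x.to(device), y.to(device)
    # bf16 keeps the fp32 exponent range, so no GradScaler is required
    with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=device.type == "cuda"):
        loss = torch.nn.functional.mse_loss(model(x), y)
    (loss / grad_accum).backward()
    if (step + 1) % grad_accum == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad()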

@@ -5,20 +5,22 @@ Simple script to test OlmOCR dataset loading with YAML configuration.
import argparse
import logging
import os
import math
import shutil
import time
import numpy as np
import torch
from torch.utils.data import ConcatDataset
from torch.utils.data import ConcatDataset, DataLoader
from transformers import (
AutoProcessor,
EarlyStoppingCallback,
get_scheduler,
AdamW,
Qwen2_5_VLForConditionalGeneration,
Qwen2VLForConditionalGeneration,
Trainer,
TrainingArguments,
)
from typing import Optional
from typing import Optional, Dict, Any
from olmocr.train.config import Config
from olmocr.train.dataloader import BaseMarkdownPDFDataset
@@ -82,6 +84,100 @@ class QwenDataCollator:
}
def save_checkpoint(
model: torch.nn.Module,
optimizer: torch.optim.Optimizer,
lr_scheduler: Any,
epoch: int,
global_step: int,
best_metric: float,
output_dir: str,
save_total_limit: Optional[int] = None,
):
"""Save model, optimizer, scheduler, and training state."""
checkpoint_dir = os.path.join(output_dir, f"checkpoint-{global_step}")
os.makedirs(checkpoint_dir, exist_ok=True)
# Save model
model.save_pretrained(checkpoint_dir)
# Save optimizer and scheduler
torch.save(optimizer.state_dict(), os.path.join(checkpoint_dir, "optimizer.pt"))
torch.save(lr_scheduler.state_dict(), os.path.join(checkpoint_dir, "scheduler.pt"))
# Save training state
state = {
"epoch": epoch,
"global_step": global_step,
"best_metric": best_metric,
}
torch.save(state, os.path.join(checkpoint_dir, "training_state.pt"))
logger.info(f"Saved checkpoint to {checkpoint_dir}")
# Enforce save_total_limit by removing oldest checkpoints
if save_total_limit is not None and save_total_limit > 0:
checkpoints = sorted(
[d for d in os.listdir(output_dir) if d.startswith("checkpoint-")],
key=lambda x: int(x.split("-")[1])
)
while len(checkpoints) > save_total_limit:
oldest = checkpoints.pop(0)
shutil.rmtree(os.path.join(output_dir, oldest))
logger.info(f"Deleted old checkpoint: {oldest}")
def load_checkpoint(
model: torch.nn.Module,
optimizer: torch.optim.Optimizer,
lr_scheduler: Any,
checkpoint_dir: str,
) -> Dict[str, Any]:
"""Load model, optimizer, scheduler, and training state from checkpoint."""
# save_pretrained wrote the weights; reload them into the existing model instance
model.load_state_dict(type(model).from_pretrained(checkpoint_dir).state_dict())
optimizer.load_state_dict(torch.load(os.path.join(checkpoint_dir, "optimizer.pt")))
lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint_dir, "scheduler.pt")))
state = torch.load(os.path.join(checkpoint_dir, "training_state.pt"))
logger.info(f"Resumed from checkpoint: {checkpoint_dir} at epoch {state['epoch']}, step {state['global_step']}")
return state
def evaluate_model(
model: torch.nn.Module,
eval_dataloaders: Dict[str, DataLoader],
device: torch.device,
autocast_ctx: Any,  # autocast context manager (bf16)
) -> Dict[str, float]:
"""Evaluate on all eval datasets and return average loss per dataset."""
model.eval()
eval_metrics = {}
for dataset_name, dataloader in eval_dataloaders.items():
total_loss = 0.0
num_batches = 0
with torch.no_grad():
for batch in dataloader:
batch = {k: v.to(device) for k, v in batch.items()}
with autocast_ctx(dtype=torch.bfloat16):  # bf16 autocast; no loss scaling needed
outputs = model(**batch)
total_loss += outputs.loss.item()
num_batches += 1
avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
eval_metrics[f"eval_{dataset_name}_loss"] = avg_loss
logger.info(f"Eval {dataset_name} loss: {avg_loss:.4f}")
# Compute overall eval loss as average across datasets (or customize as needed)
if eval_metrics:
overall_loss = sum(eval_metrics.values()) / len(eval_metrics)
eval_metrics["eval_loss"] = overall_loss
return eval_metrics
def main():
parser = argparse.ArgumentParser(description="Train OlmOCR model")
parser.add_argument("--config", type=str, default="olmocr/train/configs/example_config.yaml", help="Path to YAML configuration file")
@@ -104,6 +200,11 @@ def main():
os.environ["WANDB_PROJECT"] = config.project_name
logger.info(f"Setting WANDB_PROJECT to: {config.project_name}")
# Initialize wandb if reporting to it
if "wandb" in config.training.report_to:
import wandb
wandb.init(project=config.project_name, name=config.run_name, config=config.to_dict())
# Load processor for tokenization
logger.info(f"Loading processor: {config.model.name}")
processor = AutoProcessor.from_pretrained(
@@ -177,9 +278,10 @@ def main():
# Construct full output directory by appending run_name to base output_dir
full_output_dir = os.path.join(config.training.output_dir, config.run_name)
logger.info(f"Setting output directory to: {full_output_dir}")
os.makedirs(full_output_dir, exist_ok=True)
# Check for existing checkpoints if any
found_resumable_checkpoint = False
found_resumable_checkpoint = None
if os.path.exists(full_output_dir):
# Look for checkpoint directories
checkpoint_dirs = [d for d in os.listdir(full_output_dir) if d.startswith("checkpoint-") and os.path.isdir(os.path.join(full_output_dir, d))]
@@ -192,80 +294,197 @@ def main():
else:
logger.info("No existing checkpoints found in output directory")
# Set up training arguments
training_args = TrainingArguments(
output_dir=full_output_dir,
num_train_epochs=config.training.num_train_epochs,
per_device_train_batch_size=config.training.per_device_train_batch_size,
per_device_eval_batch_size=config.training.per_device_eval_batch_size,
gradient_accumulation_steps=config.training.gradient_accumulation_steps,
learning_rate=float(config.training.learning_rate),
lr_scheduler_type=config.training.lr_scheduler_type,
warmup_ratio=config.training.warmup_ratio,
lr_scheduler_kwargs=config.training.lr_scheduler_kwargs,
optim=config.training.optim,
adam_beta1=config.training.adam_beta1,
adam_beta2=config.training.adam_beta2,
adam_epsilon=config.training.adam_epsilon,
weight_decay=config.training.weight_decay,
max_grad_norm=config.training.max_grad_norm,
bf16=True, # We're sticking with this known good reduced precision option
eval_strategy=config.training.evaluation_strategy,
eval_steps=config.training.eval_steps,
save_strategy=config.training.save_strategy,
save_steps=config.training.save_steps,
save_total_limit=config.training.save_total_limit,
load_best_model_at_end=config.training.load_best_model_at_end,
metric_for_best_model=config.training.metric_for_best_model,
greater_is_better=config.training.greater_is_better,
logging_dir=config.training.logging_dir,
logging_strategy=config.training.logging_strategy,
logging_steps=config.training.logging_steps,
logging_first_step=config.training.logging_first_step,
report_to=config.training.report_to,
seed=config.training.seed,
data_seed=config.training.data_seed,
push_to_hub=False,
label_names=["labels"],
dataloader_drop_last=config.training.dataloader_drop_last,
dataloader_num_workers=config.training.dataloader_num_workers,
remove_unused_columns=config.training.remove_unused_columns,
eval_on_start=True,
run_name=config.run_name,
)
# Set seeds
torch.manual_seed(config.training.seed)
if config.training.data_seed is not None:
np.random.seed(config.training.data_seed)  # seed the numpy RNG used by data transforms
# Set up callbacks
callbacks = []
if config.training.use_early_stopping:
callbacks.append(
EarlyStoppingCallback(
early_stopping_patience=config.training.early_stopping_patience, early_stopping_threshold=config.training.early_stopping_threshold
)
# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# Set up optimizer
no_decay = ["bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
{
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
"weight_decay": config.training.weight_decay,
},
{
"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
"weight_decay": 0.0,
},
]
if config.training.optim == "adamw_torch":
optimizer = AdamW(
optimizer_grouped_parameters,
lr=float(config.training.learning_rate),
betas=(config.training.adam_beta1, config.training.adam_beta2),
eps=config.training.adam_epsilon,
)
else:
raise NotImplementedError(f"Optimizer {config.training.optim} not supported in custom loop")
# Initialize trainer
logger.info("Initializing trainer...")
trainer = Trainer(
model=model,
args=training_args,
train_dataset=train_dataset,
eval_dataset=eval_datasets,
data_collator=QwenDataCollator(max_token_len=config.training.collator_max_token_len),
callbacks=callbacks,
# Total training steps calculation
num_update_steps_per_epoch = math.ceil(len(train_dataset) / (config.training.per_device_train_batch_size * config.training.gradient_accumulation_steps))
max_train_steps = int(config.training.num_train_epochs * num_update_steps_per_epoch)
# Set up scheduler
lr_scheduler = get_scheduler(
name=config.training.lr_scheduler_type,
optimizer=optimizer,
num_warmup_steps=int(max_train_steps * config.training.warmup_ratio),
num_training_steps=max_train_steps,
scheduler_specific_kwargs=config.training.lr_scheduler_kwargs,
)
# Start training
# Set up mixed precision (bf16)
from torch.cuda.amp import GradScaler, autocast
amp_scaler = GradScaler(enabled=False)  # bf16 keeps the fp32 dynamic range, so loss scaling is unnecessary
# Data collator
data_collator = QwenDataCollator(max_token_len=config.training.collator_max_token_len)
# Create dataloaders
train_dataloader = DataLoader(
train_dataset,
batch_size=config.training.per_device_train_batch_size,
shuffle=True,
collate_fn=data_collator,
num_workers=config.training.dataloader_num_workers,
drop_last=config.training.dataloader_drop_last,
)
eval_dataloaders = {
name: DataLoader(
dataset,
batch_size=config.training.per_device_eval_batch_size,
shuffle=False,
collate_fn=data_collator,
num_workers=config.training.dataloader_num_workers,
drop_last=False,
)
for name, dataset in eval_datasets.items()
}
# Resume from checkpoint if available
start_epoch = 0
global_step = 0
best_metric = -float("inf") if config.training.greater_is_better else float("inf")
best_metric_key = config.training.metric_for_best_model # e.g., "eval_loss"
if found_resumable_checkpoint:
state = load_checkpoint(model, optimizer, lr_scheduler, found_resumable_checkpoint)
start_epoch = state["epoch"] + 1 # Start from next epoch
global_step = state["global_step"]
best_metric = state["best_metric"]
# Early stopping setup
patience_counter = 0
early_stopping_patience = config.training.early_stopping_patience if config.training.use_early_stopping else float("inf")
early_stopping_threshold = config.training.early_stopping_threshold
# Evaluate on start if configured
if config.training.eval_on_start:
metrics = evaluate_model(model, eval_dataloaders, device, autocast)
logger.info(f"Initial evaluation: {metrics}")
if "wandb" in config.training.report_to:
wandb.log(metrics, step=global_step)
# Main training loop
logger.info("Starting training...")
train_result = trainer.train(resume_from_checkpoint=found_resumable_checkpoint)
model.train()
for epoch in range(start_epoch, int(config.training.num_train_epochs)):
epoch_start_time = time.time()
train_loss = 0.0
num_batches = 0
for batch_idx, batch in enumerate(train_dataloader):
batch = {k: v.to(device) for k, v in batch.items()}
with autocast(dtype=torch.bfloat16):  # bf16 forward pass
outputs = model(**batch)
loss = outputs.loss / config.training.gradient_accumulation_steps
amp_scaler.scale(loss).backward()
train_loss += loss.item() * config.training.gradient_accumulation_steps
num_batches += 1
if (batch_idx + 1) % config.training.gradient_accumulation_steps == 0:
# Clip gradients
amp_scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), config.training.max_grad_norm)
# Step optimizer and scheduler
amp_scaler.step(optimizer)
amp_scaler.update()
lr_scheduler.step()
optimizer.zero_grad()
global_step += 1
# Logging
if config.training.logging_steps > 0 and global_step % config.training.logging_steps == 0:
avg_train_loss = train_loss / num_batches
logs = {
"train_loss": avg_train_loss,
"learning_rate": lr_scheduler.get_last_lr()[0],
"epoch": epoch + (batch_idx / len(train_dataloader)),
}
logger.info(f"Step {global_step}: {logs}")
if "wandb" in config.training.report_to:
wandb.log(logs, step=global_step)
train_loss = 0.0
num_batches = 0
# Evaluation
if config.training.eval_steps > 0 and global_step % config.training.eval_steps == 0:
metrics = evaluate_model(model, eval_dataloaders, device, autocast)
logger.info(f"Evaluation at step {global_step}: {metrics}")
if "wandb" in config.training.report_to:
wandb.log(metrics, step=global_step)
# Early stopping check
current_metric = metrics.get(best_metric_key, None)
if current_metric is not None:
if (config.training.greater_is_better and current_metric > best_metric + early_stopping_threshold) or \
(not config.training.greater_is_better and current_metric < best_metric - early_stopping_threshold):
best_metric = current_metric
patience_counter = 0
if config.training.load_best_model_at_end:
# TODO: save/track the best checkpoint here so it can be reloaded after training
pass
else:
patience_counter += 1
if patience_counter >= early_stopping_patience:
logger.info(f"Early stopping at step {global_step}")
break
# Saving
if config.training.save_steps > 0 and global_step % config.training.save_steps == 0:
save_checkpoint(
model, optimizer, lr_scheduler, epoch, global_step, best_metric,
full_output_dir, config.training.save_total_limit
)
# End of epoch logging
epoch_time = time.time() - epoch_start_time
logger.info(f"Epoch {epoch} completed in {epoch_time:.2f}s")
if patience_counter >= early_stopping_patience:
break
# Save the final model
logger.info("Saving final model...")
trainer.save_model()
trainer.save_state()
# Log metrics
logger.info(f"Training completed! Metrics: {train_result.metrics}")
model.save_pretrained(full_output_dir)
# Final evaluation
final_metrics = evaluate_model(model, eval_dataloaders, device, autocast)
logger.info(f"Training completed! Final metrics: {final_metrics}")
if "wandb" in config.training.report_to:
wandb.log(final_metrics, step=global_step)
wandb.finish()
if __name__ == "__main__":
main()
main()