olmocr/olmocr/train/train.py
"""
Simple script to test OlmOCR dataset loading with YAML configuration.
"""
import argparse
import logging
import math
import os
import shutil
import time
from typing import Any, Dict, Optional

import numpy as np
import torch
import wandb
from torch.amp import autocast
from torch.optim import AdamW
from torch.utils.data import ConcatDataset, DataLoader
from tqdm import tqdm
from transformers import (
    AutoProcessor,
    get_scheduler,
    Qwen2_5_VLForConditionalGeneration,
    Qwen2VLForConditionalGeneration,
)

from olmocr.train.config import Config
from olmocr.train.dataloader import BaseMarkdownPDFDataset
from olmocr.train.muon import SingleDeviceMuonWithAuxAdam

# Configure logging
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

class QwenDataCollator:
    """Data collator for vision-language models that handles numpy arrays."""

    def __init__(self, max_token_len: Optional[int] = None):
        self.max_token_len = max_token_len

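    # Each example handed to __call__ is expected to be a dict produced by the dataset
    # pipeline with "input_ids", "attention_mask", "labels", "pixel_values", and
    # "image_grid_thw", stored either as numpy arrays or as torch tensors.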
    def __call__(self, examples):
        # Filter out None values and extract the fields we need
        batch = {"input_ids": [], "attention_mask": [], "labels": [], "pixel_values": [], "image_grid_thw": []}

        for example in examples:
            if example is not None:
                # Convert numpy arrays to tensors
                input_ids = torch.from_numpy(example["input_ids"]) if isinstance(example["input_ids"], np.ndarray) else example["input_ids"]
                attention_mask = torch.from_numpy(example["attention_mask"]) if isinstance(example["attention_mask"], np.ndarray) else example["attention_mask"]
                labels = torch.from_numpy(example["labels"]) if isinstance(example["labels"], np.ndarray) else example["labels"]

                # Trim to max_token_len if specified
                if self.max_token_len is not None:
                    input_ids = input_ids[:self.max_token_len]
                    attention_mask = attention_mask[:self.max_token_len]
                    labels = labels[:self.max_token_len]

                batch["input_ids"].append(input_ids)
                batch["attention_mask"].append(attention_mask)
                batch["labels"].append(labels)

                # Handle pixel_values which might be numpy array or already a tensor
                pixel_values = example["pixel_values"]
                if isinstance(pixel_values, np.ndarray):
                    pixel_values = torch.from_numpy(pixel_values)
                batch["pixel_values"].append(pixel_values)

                # Handle image_grid_thw
                image_grid_thw = example["image_grid_thw"]
                if isinstance(image_grid_thw, np.ndarray):
                    image_grid_thw = torch.from_numpy(image_grid_thw)
                batch["image_grid_thw"].append(image_grid_thw)
        # Convert lists to tensors
        # Note: torch.stack requires every sequence in the batch to share the same length,
        # so the dataset pipeline is expected to have padded/truncated examples to a
        # uniform length (the optional max_token_len trim above preserves that property).
        return {
            "input_ids": torch.stack(batch["input_ids"]),
            "attention_mask": torch.stack(batch["attention_mask"]),
            "labels": torch.stack(batch["labels"]),
            "pixel_values": torch.stack(batch["pixel_values"]),
            "image_grid_thw": torch.stack(batch["image_grid_thw"]),
        }


def save_checkpoint(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    lr_scheduler: Any,
    epoch: float,
    global_step: int,
    samples_seen: int,
    best_metric: float,
    output_dir: str,
    save_total_limit: Optional[int] = None,
):
    """Save model, optimizer, scheduler, and training state."""
    checkpoint_dir = os.path.join(output_dir, f"checkpoint-{global_step}")
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Save model
    model.save_pretrained(checkpoint_dir)

    # Save optimizer and scheduler
    torch.save(optimizer.state_dict(), os.path.join(checkpoint_dir, "optimizer.pt"))
    torch.save(lr_scheduler.state_dict(), os.path.join(checkpoint_dir, "scheduler.pt"))

    # Save training state
    state = {
        "epoch": epoch,
        "global_step": global_step,
        "samples_seen": samples_seen,
        "best_metric": best_metric,
    }
    torch.save(state, os.path.join(checkpoint_dir, "training_state.pt"))
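    # Resulting layout: checkpoint-{global_step}/ holds the HF model files written by
    # save_pretrained plus optimizer.pt, scheduler.pt, and training_state.pt, which
    # load_checkpoint below reads back when resuming.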
logger.info(f"Saved checkpoint to {checkpoint_dir}")
# Enforce save_total_limit by removing oldest checkpoints
if save_total_limit is not None and save_total_limit > 0:
checkpoints = sorted(
[d for d in os.listdir(output_dir) if d.startswith("checkpoint-")],
key=lambda x: int(x.split("-")[1])
)
while len(checkpoints) > save_total_limit:
oldest = checkpoints.pop(0)
shutil.rmtree(os.path.join(output_dir, oldest))
logger.info(f"Deleted old checkpoint: {oldest}")

def load_checkpoint(
    model_class: type,
    init_kwargs: Dict[str, Any],
    optimizer: torch.optim.Optimizer,
    lr_scheduler: Any,
    checkpoint_dir: str,
    device: torch.device,
) -> tuple[torch.nn.Module, Dict[str, Any]]:
    """Load model, optimizer, scheduler, and training state from checkpoint."""
    model = model_class.from_pretrained(checkpoint_dir, **init_kwargs)
    model.to(device)

    optimizer.load_state_dict(torch.load(os.path.join(checkpoint_dir, "optimizer.pt"), map_location=device))
    lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint_dir, "scheduler.pt"), map_location=device))

    state = torch.load(os.path.join(checkpoint_dir, "training_state.pt"), map_location=device)
    logger.info(
        f"Resumed from checkpoint: {checkpoint_dir} at epoch {state['epoch']:.2f}, "
        f"step {state['global_step']}, samples seen {state['samples_seen']}"
    )
    return model, state


def evaluate_model(
    model: torch.nn.Module,
    eval_dataloaders: Dict[str, DataLoader],
    device: torch.device,
) -> Dict[str, float]:
    """Evaluate on all eval datasets and return average loss per dataset."""
    model.eval()
    eval_metrics = {}

    for dataset_name, dataloader in eval_dataloaders.items():
        total_loss = 0.0
        num_batches = 0

        with torch.no_grad():
            for batch in dataloader:
                batch = {k: v.to(device) for k, v in batch.items()}
                with autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
                    outputs = model(**batch)
                total_loss += outputs.loss.item()
                num_batches += 1

        avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
        eval_metrics[f"eval_{dataset_name}_loss"] = avg_loss
        logger.info(f"Eval {dataset_name} loss: {avg_loss:.4f}")

    # Compute overall eval loss as average across datasets (or customize as needed)
    if eval_metrics:
        overall_loss = sum(eval_metrics.values()) / len(eval_metrics)
        eval_metrics["eval_loss"] = overall_loss

    return eval_metrics


def main():
    parser = argparse.ArgumentParser(description="Train OlmOCR model")
    parser.add_argument("--config", type=str, default="olmocr/train/configs/example_config.yaml", help="Path to YAML configuration file")
    args = parser.parse_args()

    # Load configuration
    logger.info(f"Loading configuration from: {args.config}")
    config = Config.from_yaml(args.config)

    # Validate configuration
    try:
        config.validate()
    except ValueError as e:
        logger.error(f"Configuration validation failed: {e}")
        return

    # Set wandb project from config
    if config.project_name:
        os.environ["WANDB_PROJECT"] = config.project_name
        logger.info(f"Setting WANDB_PROJECT to: {config.project_name}")

    # Initialize wandb if reporting to it
    if "wandb" in config.training.report_to:
        wandb.init(project=config.project_name, name=config.run_name, config=config.to_dict())

    # Load processor for tokenization
    logger.info(f"Loading processor: {config.model.name}")
    processor = AutoProcessor.from_pretrained(
        config.model.name,
    )

    # Model init kwargs to reuse for loading checkpoints
    model_init_kwargs = {
        "torch_dtype": getattr(torch, config.model.torch_dtype) if config.model.torch_dtype != "auto" else "auto",
        "device_map": config.model.device_map,
        "trust_remote_code": config.model.trust_remote_code,
        "attn_implementation": config.model.attn_implementation if config.model.use_flash_attention else None,
    }
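    # These kwargs are reused by load_checkpoint() so a resumed model is created with the
    # same dtype, device_map, and attention implementation as the original run.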

    # Load model
    logger.info(f"Loading model: {config.model.name}")
    if "Qwen2.5-VL" in config.model.name:
        model_class = Qwen2_5_VLForConditionalGeneration
        model = model_class.from_pretrained(config.model.name, **model_init_kwargs)
    elif "Qwen2-VL" in config.model.name:
        model_class = Qwen2VLForConditionalGeneration
        model = model_class.from_pretrained(config.model.name, **model_init_kwargs)
    else:
        raise NotImplementedError(f"Unsupported model: {config.model.name}")

    # Enable gradient checkpointing if configured
    if config.training.gradient_checkpointing:
        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=config.training.gradient_checkpointing_kwargs)
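    # Gradient checkpointing trades extra forward-pass compute for a large reduction in
    # activation memory, which is usually the limiting factor when fine-tuning VLMs.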

    # Create training datasets
    logger.info("Creating training datasets...")
    train_datasets = []
    for i, dataset_cfg in enumerate(config.dataset.train):
        root_dir = dataset_cfg["root_dir"]
        pipeline_steps = config.get_pipeline_steps(dataset_cfg["pipeline"], processor)

        logger.info(f"Creating training dataset {i+1} from: {root_dir}")
        dataset = BaseMarkdownPDFDataset(root_dir, pipeline_steps)
        logger.info(f"Found {len(dataset)} samples")

        if len(dataset) > 0:
            train_datasets.append(dataset)

    # Combine all training datasets
    train_dataset = ConcatDataset(train_datasets) if len(train_datasets) > 1 else train_datasets[0]
    logger.info(f"Total training samples: {len(train_dataset)}")

    # Create evaluation datasets
    logger.info("Creating evaluation datasets...")
    eval_datasets = {}
    for i, dataset_cfg in enumerate(config.dataset.eval):
        root_dir = dataset_cfg["root_dir"]
        pipeline_steps = config.get_pipeline_steps(dataset_cfg["pipeline"], processor)

        # Use the dataset name if provided, otherwise fall back to an indexed default
        dataset_name = dataset_cfg.get("name", f"eval_dataset_{i+1}")

        logger.info(f"Creating evaluation dataset '{dataset_name}' from: {root_dir}")
        dataset = BaseMarkdownPDFDataset(root_dir, pipeline_steps)
        logger.info(f"Found {len(dataset)} samples")

        if len(dataset) > 0:
            eval_datasets[dataset_name] = dataset

    # Log total evaluation samples across all datasets
    total_eval_samples = sum(len(dataset) for dataset in eval_datasets.values())
    logger.info(f"Total evaluation samples across {len(eval_datasets)} datasets: {total_eval_samples}")

    # Construct the full output directory by appending run_name to the base output_dir
    full_output_dir = os.path.join(config.training.output_dir, config.run_name)
    logger.info(f"Setting output directory to: {full_output_dir}")
    os.makedirs(full_output_dir, exist_ok=True)

    # Check for an existing checkpoint to resume from
    found_resumable_checkpoint = None
    if os.path.exists(full_output_dir):
        # Look for checkpoint directories
        checkpoint_dirs = [d for d in os.listdir(full_output_dir) if d.startswith("checkpoint-") and os.path.isdir(os.path.join(full_output_dir, d))]
        if checkpoint_dirs:
            # Sort by checkpoint number and get the latest
            checkpoint_dirs.sort(key=lambda x: int(x.split("-")[1]))
            latest_checkpoint = os.path.join(full_output_dir, checkpoint_dirs[-1])
            logger.info(f"Found existing checkpoint: {latest_checkpoint}")
            found_resumable_checkpoint = latest_checkpoint
        else:
            logger.info("No existing checkpoints found in output directory")

    # Set seeds
    torch.manual_seed(config.training.seed)

    # Set up data loader seed worker function
    def seed_worker(worker_id):
        worker_seed = torch.initial_seed() % 2**32
        np.random.seed(worker_seed)
        import random
        random.seed(worker_seed)

    # Create generator for data loader
    generator = None
    if config.training.data_seed is not None:
        generator = torch.Generator()
        generator.manual_seed(config.training.data_seed)

    # Device setup
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Set up optimizer
    if config.training.optim == "adamw_torch":
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": config.training.weight_decay,
            },
            {
                "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        optimizer = AdamW(
            optimizer_grouped_parameters,
            lr=float(config.training.learning_rate),
            betas=(config.training.adam_beta1, config.training.adam_beta2),
            eps=config.training.adam_epsilon,
        )
    elif config.training.optim == "muon":
        # Separate parameters for Muon (hidden matrices) and Adam (embeddings, scalars, head)
        hidden_matrix_params = [p for n, p in model.named_parameters() if p.ndim >= 2 and "embed" not in n and "lm_head" not in n]
        embed_params = [p for n, p in model.named_parameters() if "embed" in n]
        scalar_params = [p for p in model.parameters() if p.ndim < 2]
        head_params = [p for n, p in model.named_parameters() if "lm_head" in n]
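        # Muon is intended for the 2-D hidden weight matrices; embeddings, the LM head,
        # and scalar/1-D parameters (norms, biases) stay on the auxiliary Adam with
        # per-group learning-rate multipliers, mirroring the standard Muon recipe.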

        # Create Adam groups with different learning rates
        adam_groups = [
            dict(params=head_params, lr=float(config.training.learning_rate) * config.training.muon_lr_multiplier_head, use_muon=False),
            dict(params=embed_params, lr=float(config.training.learning_rate) * config.training.muon_lr_multiplier_embed, use_muon=False),
            dict(params=scalar_params, lr=float(config.training.learning_rate) * config.training.muon_lr_multiplier_scalar, use_muon=False),
        ]

        # Add Adam hyperparameters to groups
        for g in adam_groups:
            g["betas"] = (config.training.adam_beta1, config.training.adam_beta2)
            g["eps"] = config.training.adam_epsilon
            g["weight_decay"] = config.training.weight_decay

        # Create Muon group
        muon_group = dict(
            params=hidden_matrix_params,
            lr=float(config.training.learning_rate),
            momentum=config.training.muon_momentum,
            weight_decay=config.training.weight_decay,
            use_muon=True,
        )

        # Combine all groups
        param_groups = [*adam_groups, muon_group]
        optimizer = SingleDeviceMuonWithAuxAdam(param_groups)
    else:
        raise NotImplementedError(f"Optimizer {config.training.optim} not supported in custom loop")

    # Total training steps calculation
    samples_per_step = config.training.per_device_train_batch_size * config.training.gradient_accumulation_steps
    num_update_steps_per_epoch = math.ceil(len(train_dataset) / samples_per_step)
    max_train_steps = int(math.ceil(config.training.num_train_epochs * num_update_steps_per_epoch))
    max_train_samples = int(math.ceil(config.training.num_train_epochs * len(train_dataset)))
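    # Worked example (hypothetical numbers): with 10,000 training samples, a per-device
    # batch size of 2 and gradient_accumulation_steps of 8, samples_per_step = 16,
    # num_update_steps_per_epoch = ceil(10000 / 16) = 625, so 3 epochs give
    # max_train_steps = 1875 and max_train_samples = 30000.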

    # Set up scheduler
    lr_scheduler = get_scheduler(
        name=config.training.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=int(max_train_steps * config.training.warmup_ratio),
        num_training_steps=max_train_steps,
        scheduler_specific_kwargs=config.training.lr_scheduler_kwargs,
    )

    # Data collator
    data_collator = QwenDataCollator(max_token_len=config.training.collator_max_token_len)

    # Resume from checkpoint if available
    global_step = 0
    samples_seen = 0
    best_metric = float("inf") if not config.training.greater_is_better else -float("inf")

    if found_resumable_checkpoint:
        model, state = load_checkpoint(model_class, model_init_kwargs, optimizer, lr_scheduler, found_resumable_checkpoint, device)
        global_step = state["global_step"]
        best_metric = state["best_metric"]
        samples_seen = state["samples_seen"]

    # Create dataloaders
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=config.training.per_device_train_batch_size,
        shuffle=True,
        collate_fn=data_collator,
        num_workers=config.training.dataloader_num_workers,
        drop_last=config.training.dataloader_drop_last,
        worker_init_fn=seed_worker,
        generator=generator,
    )

    eval_dataloaders = {
        name: DataLoader(
            dataset,
            batch_size=config.training.per_device_eval_batch_size,
            shuffle=False,
            collate_fn=data_collator,
            num_workers=config.training.dataloader_num_workers,
            drop_last=False,
        )
        for name, dataset in eval_datasets.items()
    }
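    # Eval loaders keep shuffle=False and drop_last=False so every evaluation sees each
    # sample in a fixed order, keeping eval losses comparable across steps and runs.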

    # Always run an initial evaluation
    metrics = evaluate_model(model, eval_dataloaders, device)
    logger.info(f"Initial evaluation: {metrics}")
    if "wandb" in config.training.report_to:
        wandb.log(metrics, step=global_step)

    # Main training loop
    current_epoch = samples_seen / len(train_dataset)
    logger.info(f"Starting training from epoch {current_epoch:.2f} (step {global_step}, samples {samples_seen}) to {config.training.num_train_epochs} epochs")
    logger.info(f"Total training steps: {max_train_steps}, Total samples to process: {max_train_samples}")

    if samples_seen >= max_train_samples:
        logger.info("Training already completed based on samples seen!")
        logger.info("Skipping to final model save.")
    else:
        model.train()
        accumulated_loss = 0.0
        num_losses_accumulated = 0

        # Create epoch iterator and skip samples if resuming
        epoch_iterator = iter(train_dataloader)
        if samples_seen > 0:
            samples_to_skip = samples_seen % len(train_dataset)
            batches_to_skip = samples_to_skip // config.training.per_device_train_batch_size
            logger.info(f"Resuming training: skipping {batches_to_skip} batches ({samples_to_skip} samples) to reach position {samples_seen}")
            for _ in range(batches_to_skip):
                try:
                    next(epoch_iterator)
                except StopIteration:
                    # We've reached the end of the epoch while skipping, create new iterator
                    epoch_iterator = iter(train_dataloader)
                    break

        # Create progress bar
        pbar = tqdm(total=max_train_samples - samples_seen, desc=f"Training from step {global_step}", unit="samples")

        while samples_seen < max_train_samples and global_step < max_train_steps:
            try:
                batch = next(epoch_iterator)
            except StopIteration:
                # End of epoch, create new iterator
                current_epoch = samples_seen / len(train_dataset)
                logger.info(f"Completed epoch {current_epoch:.2f}")
                epoch_iterator = iter(train_dataloader)
                batch = next(epoch_iterator)

            batch = {k: v.to(device) for k, v in batch.items()}

            with autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
                outputs = model(**batch)
                loss = outputs.loss / config.training.gradient_accumulation_steps

            loss.backward()
            accumulated_loss += loss.item()
            num_losses_accumulated += 1
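            # The loss is divided by gradient_accumulation_steps above so that, once the
            # gradients from all micro-batches in a step have been summed, the update
            # corresponds to the average loss over the full effective batch.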
            samples_seen += config.training.per_device_train_batch_size

            # Update progress bar
            pbar.update(config.training.per_device_train_batch_size)

            # Check if we should do a gradient update
            if samples_seen % samples_per_step == 0 or samples_seen >= max_train_samples:
                # Clip gradients
                torch.nn.utils.clip_grad_norm_(model.parameters(), config.training.max_grad_norm)

                # Step optimizer and scheduler
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

                global_step += 1
                current_epoch = samples_seen / len(train_dataset)

                # Update progress bar with current stats
                current_lr = lr_scheduler.get_last_lr()[0]
                avg_loss = accumulated_loss / num_losses_accumulated if num_losses_accumulated > 0 else 0
                pbar.set_postfix({
                    "loss": f"{avg_loss:.4f}",
                    "lr": f"{current_lr:.2e}",
                    "epoch": f"{current_epoch:.2f}",
                    "step": global_step,
                })

                # Logging
                if config.training.logging_steps > 0 and global_step % config.training.logging_steps == 0:
                    avg_train_loss = accumulated_loss / num_losses_accumulated if num_losses_accumulated > 0 else 0
                    logs = {
                        "train_loss": avg_train_loss,
                        "learning_rate": lr_scheduler.get_last_lr()[0],
                        "epoch": current_epoch,
                        "samples_seen": samples_seen,
                    }
                    logger.info(f"Step {global_step}: epoch={current_epoch:.3f}, loss={avg_train_loss:.4f}, lr={lr_scheduler.get_last_lr()[0]:.2e}")
                    if "wandb" in config.training.report_to:
                        wandb.log(logs, step=global_step)
                    accumulated_loss = 0.0
                    num_losses_accumulated = 0

                # Evaluation
                if config.training.eval_steps > 0 and global_step % config.training.eval_steps == 0 and global_step > 0:
                    metrics = evaluate_model(model, eval_dataloaders, device)
                    logger.info(f"Evaluation at step {global_step}: {metrics}")
                    if "wandb" in config.training.report_to:
                        wandb.log(metrics, step=global_step)

                    # Update best metric
                    current_metric = metrics.get(config.training.metric_for_best_model, None)
                    if current_metric is not None:
                        if (config.training.greater_is_better and current_metric > best_metric) or \
                           (not config.training.greater_is_better and current_metric < best_metric):
                            best_metric = current_metric

                    # Return to training mode
                    model.train()

                # Saving
                if config.training.save_steps > 0 and global_step % config.training.save_steps == 0:
                    save_checkpoint(
                        model, optimizer, lr_scheduler, current_epoch, global_step, samples_seen, best_metric,
                        full_output_dir, config.training.save_total_limit,
                    )

                # Check if we've reached our training limit
                if samples_seen >= max_train_samples or global_step >= max_train_steps:
                    break

        # Close progress bar
        pbar.close()

    # Save the final checkpoint with step number
    logger.info(f"Saving final checkpoint at step {global_step}...")
    save_checkpoint(
        model, optimizer, lr_scheduler, current_epoch, global_step, samples_seen, best_metric,
        full_output_dir, config.training.save_total_limit,
    )

    # Log final training state
    final_epoch = samples_seen / len(train_dataset)
    logger.info(f"Training completed at epoch {final_epoch:.3f}, step {global_step}, samples {samples_seen}")

    # Final evaluation
    final_metrics = evaluate_model(model, eval_dataloaders, device)
    logger.info(f"Final evaluation metrics: {final_metrics}")
    if "wandb" in config.training.report_to:
        wandb.log(final_metrics, step=global_step)
        wandb.finish()


if __name__ == "__main__":
    main()