Mirror of https://github.com/allenai/olmocr.git, synced 2025-10-18 03:32:16 +00:00

Grok was asked to drop the HF Trainer and implement the training loop custom.

This commit is contained in:
parent da6bc458cd
commit 28a207b912
@@ -5,20 +5,22 @@ Simple script to test OlmOCR dataset loading with YAML configuration.
 import argparse
 import logging
 import os
+import math
+import shutil
+import time

 import numpy as np
 import torch
-from torch.utils.data import ConcatDataset
+from torch.utils.data import ConcatDataset, DataLoader
 from transformers import (
     AutoProcessor,
-    EarlyStoppingCallback,
+    get_scheduler,
+    AdamW,
     Qwen2_5_VLForConditionalGeneration,
     Qwen2VLForConditionalGeneration,
-    Trainer,
-    TrainingArguments,
 )

-from typing import Optional
+from typing import Optional, Dict, Any
 from olmocr.train.config import Config
 from olmocr.train.dataloader import BaseMarkdownPDFDataset

@@ -82,6 +84,100 @@ class QwenDataCollator:
         }


+def save_checkpoint(
+    model: torch.nn.Module,
+    optimizer: torch.optim.Optimizer,
+    lr_scheduler: Any,
+    epoch: int,
+    global_step: int,
+    best_metric: float,
+    output_dir: str,
+    save_total_limit: Optional[int] = None,
+):
+    """Save model, optimizer, scheduler, and training state."""
+    checkpoint_dir = os.path.join(output_dir, f"checkpoint-{global_step}")
+    os.makedirs(checkpoint_dir, exist_ok=True)
+
+    # Save model
+    model.save_pretrained(checkpoint_dir)
+
+    # Save optimizer and scheduler
+    torch.save(optimizer.state_dict(), os.path.join(checkpoint_dir, "optimizer.pt"))
+    torch.save(lr_scheduler.state_dict(), os.path.join(checkpoint_dir, "scheduler.pt"))
+
+    # Save training state
+    state = {
+        "epoch": epoch,
+        "global_step": global_step,
+        "best_metric": best_metric,
+    }
+    torch.save(state, os.path.join(checkpoint_dir, "training_state.pt"))
+
+    logger.info(f"Saved checkpoint to {checkpoint_dir}")
+
+    # Enforce save_total_limit by removing the oldest checkpoints
+    if save_total_limit is not None and save_total_limit > 0:
+        checkpoints = sorted(
+            [d for d in os.listdir(output_dir) if d.startswith("checkpoint-")],
+            key=lambda x: int(x.split("-")[1])
+        )
+        while len(checkpoints) > save_total_limit:
+            oldest = checkpoints.pop(0)
+            shutil.rmtree(os.path.join(output_dir, oldest))
+            logger.info(f"Deleted old checkpoint: {oldest}")
+
+
+def load_checkpoint(
+    model: torch.nn.Module,
+    optimizer: torch.optim.Optimizer,
+    lr_scheduler: Any,
+    checkpoint_dir: str,
+) -> Dict[str, Any]:
+    """Load model, optimizer, scheduler, and training state from checkpoint."""
+    # fixed: PreTrainedModel has no load_pretrained(); reload the weights that
+    # save_pretrained() wrote (this briefly materializes a second copy on CPU)
+    model.load_state_dict(type(model).from_pretrained(checkpoint_dir).state_dict())
+
+    optimizer.load_state_dict(torch.load(os.path.join(checkpoint_dir, "optimizer.pt")))
+    lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint_dir, "scheduler.pt")))
+
+    state = torch.load(os.path.join(checkpoint_dir, "training_state.pt"))
+    logger.info(f"Resumed from checkpoint: {checkpoint_dir} at epoch {state['epoch']}, step {state['global_step']}")
+    return state
+
+
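A quick round-trip check of the two helpers above might look like this (a sketch; `model`, `optimizer`, and `lr_scheduler` are assumed to already exist, and the `/tmp/run` path is made up):

# hypothetical sanity check of save_checkpoint / load_checkpoint
save_checkpoint(model, optimizer, lr_scheduler, epoch=0, global_step=500,
                best_metric=1.23, output_dir="/tmp/run", save_total_limit=2)
state = load_checkpoint(model, optimizer, lr_scheduler, "/tmp/run/checkpoint-500")
assert state["global_step"] == 500 and state["best_metric"] == 1.23
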
+def evaluate_model(
+    model: torch.nn.Module,
+    eval_dataloaders: Dict[str, DataLoader],
+    device: torch.device,
+    amp_scaler: Any,  # autocast context factory (bf16)
+) -> Dict[str, float]:
+    """Evaluate on all eval datasets and return average loss per dataset."""
+    model.eval()
+    eval_metrics = {}
+
+    for dataset_name, dataloader in eval_dataloaders.items():
+        total_loss = 0.0
+        num_batches = 0
+
+        with torch.no_grad():
+            for batch in dataloader:
+                batch = {k: v.to(device) for k, v in batch.items()}
+                # fixed: callers pass the autocast class itself, so call it directly
+                # (amp_scaler.autocast(...) does not exist) and pin bf16 explicitly
+                with amp_scaler(enabled=True, dtype=torch.bfloat16):
+                    outputs = model(**batch)
+                total_loss += outputs.loss.item()
+                num_batches += 1
+
+        avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
+        eval_metrics[f"eval_{dataset_name}_loss"] = avg_loss
+        logger.info(f"Eval {dataset_name} loss: {avg_loss:.4f}")
+
+    # Compute overall eval loss as the unweighted mean across datasets
+    if eval_metrics:
+        overall_loss = sum(eval_metrics.values()) / len(eval_metrics)
+        eval_metrics["eval_loss"] = overall_loss
+
+    return eval_metrics
+
+
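Note that `eval_loss` here is the unweighted mean of per-dataset means, so a 100-example eval set counts as much as a 10,000-example one. If size-weighted aggregation were wanted instead, a sketch (assuming each dataloader's dataset supports `len()`):

# sketch: weight each dataset's mean loss by its number of examples
sizes = {name: len(dl.dataset) for name, dl in eval_dataloaders.items()}
total = sum(sizes.values())
eval_metrics["eval_loss"] = sum(
    eval_metrics[f"eval_{name}_loss"] * sizes[name] / total for name in sizes
)
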
 def main():
     parser = argparse.ArgumentParser(description="Train OlmOCR model")
     parser.add_argument("--config", type=str, default="olmocr/train/configs/example_config.yaml", help="Path to YAML configuration file")
@@ -104,6 +200,11 @@ def main():
     os.environ["WANDB_PROJECT"] = config.project_name
     logger.info(f"Setting WANDB_PROJECT to: {config.project_name}")

+    # Initialize wandb if reporting to it
+    if "wandb" in config.training.report_to:
+        import wandb
+        wandb.init(project=config.project_name, name=config.run_name, config=config.to_dict())

     # Load processor for tokenization
     logger.info(f"Loading processor: {config.model.name}")
     processor = AutoProcessor.from_pretrained(
@@ -177,9 +278,10 @@ def main():
     # Construct full output directory by appending run_name to base output_dir
     full_output_dir = os.path.join(config.training.output_dir, config.run_name)
     logger.info(f"Setting output directory to: {full_output_dir}")
     os.makedirs(full_output_dir, exist_ok=True)

     # Check for existing checkpoints if any
-    found_resumable_checkpoint = False
+    found_resumable_checkpoint = None
     if os.path.exists(full_output_dir):
         # Look for checkpoint directories
         checkpoint_dirs = [d for d in os.listdir(full_output_dir) if d.startswith("checkpoint-") and os.path.isdir(os.path.join(full_output_dir, d))]
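The logic that turns `checkpoint_dirs` into `found_resumable_checkpoint` sits in unchanged lines this hunk elides; presumably it picks the highest-numbered checkpoint, along the lines of this sketch:

# sketch (actual selection code not shown in the diff): resume from the
# checkpoint directory with the largest global step
if checkpoint_dirs:
    latest = max(checkpoint_dirs, key=lambda d: int(d.split("-")[1]))
    found_resumable_checkpoint = os.path.join(full_output_dir, latest)
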
@@ -192,80 +294,197 @@ def main():
     else:
         logger.info("No existing checkpoints found in output directory")

-    # Set up training arguments
-    training_args = TrainingArguments(
-        output_dir=full_output_dir,
-        num_train_epochs=config.training.num_train_epochs,
-        per_device_train_batch_size=config.training.per_device_train_batch_size,
-        per_device_eval_batch_size=config.training.per_device_eval_batch_size,
-        gradient_accumulation_steps=config.training.gradient_accumulation_steps,
-        learning_rate=float(config.training.learning_rate),
-        lr_scheduler_type=config.training.lr_scheduler_type,
-        warmup_ratio=config.training.warmup_ratio,
-        lr_scheduler_kwargs=config.training.lr_scheduler_kwargs,
-        optim=config.training.optim,
-        adam_beta1=config.training.adam_beta1,
-        adam_beta2=config.training.adam_beta2,
-        adam_epsilon=config.training.adam_epsilon,
-        weight_decay=config.training.weight_decay,
-        max_grad_norm=config.training.max_grad_norm,
-        bf16=True,  # We're sticking with this known good reduced precision option
-        eval_strategy=config.training.evaluation_strategy,
-        eval_steps=config.training.eval_steps,
-        save_strategy=config.training.save_strategy,
-        save_steps=config.training.save_steps,
-        save_total_limit=config.training.save_total_limit,
-        load_best_model_at_end=config.training.load_best_model_at_end,
-        metric_for_best_model=config.training.metric_for_best_model,
-        greater_is_better=config.training.greater_is_better,
-        logging_dir=config.training.logging_dir,
-        logging_strategy=config.training.logging_strategy,
-        logging_steps=config.training.logging_steps,
-        logging_first_step=config.training.logging_first_step,
-        report_to=config.training.report_to,
-        seed=config.training.seed,
-        data_seed=config.training.data_seed,
-        push_to_hub=False,
-        label_names=["labels"],
-        dataloader_drop_last=config.training.dataloader_drop_last,
-        dataloader_num_workers=config.training.dataloader_num_workers,
-        remove_unused_columns=config.training.remove_unused_columns,
-        eval_on_start=True,
-        run_name=config.run_name,
-    )
-    # Set up callbacks
-    callbacks = []
-    if config.training.use_early_stopping:
-        callbacks.append(
-            EarlyStoppingCallback(
-                early_stopping_patience=config.training.early_stopping_patience, early_stopping_threshold=config.training.early_stopping_threshold
-            )
-        )
+    # Set seeds
+    torch.manual_seed(config.training.seed)
+    if config.training.data_seed is not None:
+        np.random.seed(config.training.data_seed)  # fixed: torch.utils.data has no random.seed helper
+
+    # Device setup
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    model.to(device)
+
+    # Set up optimizer
+    no_decay = ["bias", "LayerNorm.weight"]
+    optimizer_grouped_parameters = [
+        {
+            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
+            "weight_decay": config.training.weight_decay,
+        },
+        {
+            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
+            "weight_decay": 0.0,
+        },
+    ]
+    if config.training.optim == "adamw_torch":
+        optimizer = AdamW(
+            optimizer_grouped_parameters,
+            lr=float(config.training.learning_rate),
+            betas=(config.training.adam_beta1, config.training.adam_beta2),
+            eps=config.training.adam_epsilon,
+        )
+    else:
+        raise NotImplementedError(f"Optimizer {config.training.optim} not supported in custom loop")
+
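`AdamW` is imported from `transformers` above; that export has been deprecated for years and may be missing from newer releases. If so, `torch.optim.AdamW` is a drop-in here (its default `weight_decay=0.01` is irrelevant because the per-group `weight_decay` entries override it); a sketch:

from torch.optim import AdamW  # replaces the deprecated transformers.AdamW

optimizer = AdamW(
    optimizer_grouped_parameters,
    lr=float(config.training.learning_rate),
    betas=(config.training.adam_beta1, config.training.adam_beta2),
    eps=config.training.adam_epsilon,
)
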
-    # Initialize trainer
-    logger.info("Initializing trainer...")
-    trainer = Trainer(
-        model=model,
-        args=training_args,
-        train_dataset=train_dataset,
-        eval_dataset=eval_datasets,
-        data_collator=QwenDataCollator(max_token_len=config.training.collator_max_token_len),
-        callbacks=callbacks,
-    )
+    # Total training steps calculation
+    num_update_steps_per_epoch = math.ceil(len(train_dataset) / (config.training.per_device_train_batch_size * config.training.gradient_accumulation_steps))
+    max_train_steps = int(config.training.num_train_epochs * num_update_steps_per_epoch)
+
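Concretely, with made-up numbers (10,000 training examples, batch size 8, accumulation 4, 3 epochs, warmup_ratio 0.1), the arithmetic works out as:

# hypothetical numbers, just to show the step accounting
num_update_steps_per_epoch = math.ceil(10_000 / (8 * 4))  # 313 optimizer steps per epoch
max_train_steps = int(3 * 313)                            # 939 optimizer steps total
num_warmup_steps = int(939 * 0.1)                         # 93 warmup steps for the scheduler below
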
+    # Set up scheduler
+    lr_scheduler = get_scheduler(
+        name=config.training.lr_scheduler_type,
+        optimizer=optimizer,
+        num_warmup_steps=int(max_train_steps * config.training.warmup_ratio),
+        num_training_steps=max_train_steps,
+        scheduler_specific_kwargs=config.training.lr_scheduler_kwargs,
+    )
+
-    # Start training
+    # Set up mixed precision (bf16)
+    from torch.cuda.amp import GradScaler, autocast
+
+    # fixed: bf16 needs no loss scaling (unlike fp16), so keep the scaler
+    # disabled; its scale()/unscale_()/step()/update() calls become pass-throughs
+    amp_scaler = GradScaler(enabled=False)
+
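The disabled scaler exists only so the loop below can keep the familiar scale()/step()/update() call shape. An equivalent loop body without it, using the non-deprecated `torch.autocast` entry point, would be a sketch like:

# sketch: bf16 training step with no GradScaler at all
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    outputs = model(**batch)
    loss = outputs.loss / config.training.gradient_accumulation_steps
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), config.training.max_grad_norm)
optimizer.step()
optimizer.zero_grad()
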
+    # Data collator
+    data_collator = QwenDataCollator(max_token_len=config.training.collator_max_token_len)
+
+    # Create dataloaders
+    train_dataloader = DataLoader(
+        train_dataset,
+        batch_size=config.training.per_device_train_batch_size,
+        shuffle=True,
+        collate_fn=data_collator,
+        num_workers=config.training.dataloader_num_workers,
+        drop_last=config.training.dataloader_drop_last,
+    )
+
+    eval_dataloaders = {
+        name: DataLoader(
+            dataset,
+            batch_size=config.training.per_device_eval_batch_size,
+            shuffle=False,
+            collate_fn=data_collator,
+            num_workers=config.training.dataloader_num_workers,
+            drop_last=False,
+        )
+        for name, dataset in eval_datasets.items()
+    }
+
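One caveat: `shuffle=True` draws from PyTorch's global RNG, so `config.training.data_seed` does not actually control shuffle order here. Passing a dedicated generator would; a sketch:

# sketch: make shuffle order depend only on data_seed (falling back to seed)
g = torch.Generator()
g.manual_seed(config.training.data_seed or config.training.seed)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=config.training.per_device_train_batch_size,
    shuffle=True,
    generator=g,
    collate_fn=data_collator,
    num_workers=config.training.dataloader_num_workers,
    drop_last=config.training.dataloader_drop_last,
)
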
+    # Resume from checkpoint if available
+    start_epoch = 0
+    global_step = 0
+    # fixed: when greater_is_better, the starting "best" must be -inf, not +inf
+    best_metric = -float("inf") if config.training.greater_is_better else float("inf")
+    best_metric_key = config.training.metric_for_best_model  # e.g., "eval_loss"
+    if found_resumable_checkpoint:
+        state = load_checkpoint(model, optimizer, lr_scheduler, found_resumable_checkpoint)
+        start_epoch = state["epoch"] + 1  # Start from next epoch
+        global_step = state["global_step"]
+        best_metric = state["best_metric"]
+
+    # Early stopping setup
+    patience_counter = 0
+    early_stopping_patience = config.training.early_stopping_patience if config.training.use_early_stopping else float("inf")
+    early_stopping_threshold = config.training.early_stopping_threshold
+
+    # Evaluate on start if configured
+    if config.training.eval_on_start:
+        metrics = evaluate_model(model, eval_dataloaders, device, autocast)
+        logger.info(f"Initial evaluation: {metrics}")
+        if "wandb" in config.training.report_to:
+            wandb.log(metrics, step=global_step)
+
+    # Main training loop
     logger.info("Starting training...")
-    train_result = trainer.train(resume_from_checkpoint=found_resumable_checkpoint)
+    model.train()
+    for epoch in range(start_epoch, int(config.training.num_train_epochs)):
+        epoch_start_time = time.time()
+        train_loss = 0.0
+        num_batches = 0
+
+        for batch_idx, batch in enumerate(train_dataloader):
+            batch = {k: v.to(device) for k, v in batch.items()}
+
+            with autocast(enabled=True, dtype=torch.bfloat16):  # fixed: plain autocast defaults to fp16
+                outputs = model(**batch)
+                loss = outputs.loss / config.training.gradient_accumulation_steps
+            amp_scaler.scale(loss).backward()
+
+            train_loss += loss.item() * config.training.gradient_accumulation_steps
+            num_batches += 1
+
+            if (batch_idx + 1) % config.training.gradient_accumulation_steps == 0:
+                # Clip gradients
+                amp_scaler.unscale_(optimizer)
+                torch.nn.utils.clip_grad_norm_(model.parameters(), config.training.max_grad_norm)
+
+                # Step optimizer and scheduler
+                amp_scaler.step(optimizer)
+                amp_scaler.update()
+                lr_scheduler.step()
+                optimizer.zero_grad()
+
+                global_step += 1
+
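One subtlety: if `len(train_dataloader)` is not a multiple of `gradient_accumulation_steps`, the trailing batches of each epoch accumulate gradients that are never stepped or zeroed, and they leak into the next epoch's first update. A guard after the batch loop would flush them; a sketch:

# sketch: flush a partial accumulation window at the end of the epoch
if (batch_idx + 1) % config.training.gradient_accumulation_steps != 0:
    amp_scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), config.training.max_grad_norm)
    amp_scaler.step(optimizer)
    amp_scaler.update()
    lr_scheduler.step()
    optimizer.zero_grad()
    global_step += 1
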
+                # Logging
+                if config.training.logging_steps > 0 and global_step % config.training.logging_steps == 0:
+                    avg_train_loss = train_loss / num_batches
+                    logs = {
+                        "train_loss": avg_train_loss,
+                        "learning_rate": lr_scheduler.get_last_lr()[0],
+                        "epoch": epoch + (batch_idx / len(train_dataloader)),
+                    }
+                    logger.info(f"Step {global_step}: {logs}")
+                    if "wandb" in config.training.report_to:
+                        wandb.log(logs, step=global_step)
+
+                    train_loss = 0.0
+                    num_batches = 0
+
+                # Evaluation
+                if config.training.eval_steps > 0 and global_step % config.training.eval_steps == 0:
+                    metrics = evaluate_model(model, eval_dataloaders, device, autocast)
+                    model.train()  # fixed: evaluate_model leaves the model in eval mode
+                    logger.info(f"Evaluation at step {global_step}: {metrics}")
+                    if "wandb" in config.training.report_to:
+                        wandb.log(metrics, step=global_step)
+
+                    # Early stopping check
+                    current_metric = metrics.get(best_metric_key, None)
+                    if current_metric is not None:
+                        if (config.training.greater_is_better and current_metric > best_metric + early_stopping_threshold) or \
+                           (not config.training.greater_is_better and current_metric < best_metric - early_stopping_threshold):
+                            best_metric = current_metric
+                            patience_counter = 0
+                            if config.training.load_best_model_at_end:
+                                # Save best model (optional: implement loading best at end)
+                                pass
+                        else:
+                            patience_counter += 1
+                            if patience_counter >= early_stopping_patience:
+                                logger.info(f"Early stopping at step {global_step}")
+                                break
+
+                # Saving
+                if config.training.save_steps > 0 and global_step % config.training.save_steps == 0:
+                    save_checkpoint(
+                        model, optimizer, lr_scheduler, epoch, global_step, best_metric,
+                        full_output_dir, config.training.save_total_limit
+                    )
+
+        # End of epoch logging
+        epoch_time = time.time() - epoch_start_time
+        logger.info(f"Epoch {epoch} completed in {epoch_time:.2f}s")
+
+        if patience_counter >= early_stopping_patience:
+            break
+
     # Save the final model
     logger.info("Saving final model...")
-    trainer.save_model()
-    trainer.save_state()
-
-    # Log metrics
-    logger.info(f"Training completed! Metrics: {train_result.metrics}")
+    model.save_pretrained(full_output_dir)

+    # Final evaluation
+    final_metrics = evaluate_model(model, eval_dataloaders, device, autocast)
+    logger.info(f"Training completed! Final metrics: {final_metrics}")
+    if "wandb" in config.training.report_to:
+        wandb.log(final_metrics, step=global_step)
+        wandb.finish()


 if __name__ == "__main__":
     main()