Lint fixes

Jake Poznanski 2025-08-13 20:21:04 +00:00
parent 05330150ad
commit 93411a80a0
8 changed files with 157 additions and 194 deletions

View File

@ -49,7 +49,7 @@ from olmocr.s3_utils import (
)
from olmocr.train.dataloader import FrontMatterParser
from olmocr.version import VERSION
from olmocr.work_queue import WorkQueue, LocalBackend, S3Backend
from olmocr.work_queue import LocalBackend, S3Backend, WorkQueue
# Initialize logger
logger = logging.getLogger(__name__)

View File

@ -204,11 +204,11 @@ class TrainingConfig:
adam_epsilon: float = 1e-8
weight_decay: float = 0.01
max_grad_norm: float = 1.0
# Muon optimizer specific settings
muon_momentum: float = 0.95
muon_lr_multiplier_head: float = 11.0 # Learning rate multiplier for head parameters
muon_lr_multiplier_embed: float = 30.0 # Learning rate multiplier for embedding parameters
muon_lr_multiplier_embed: float = 30.0 # Learning rate multiplier for embedding parameters
muon_lr_multiplier_scalar: float = 2.0 # Learning rate multiplier for scalar parameters
# Gradient checkpointing
@ -243,7 +243,7 @@ class TrainingConfig:
# Data collator settings
collator_max_token_len: Optional[int] = None
remove_unused_columns: bool = False # Important for custom datasets
# Torch compile settings
torch_compile: bool = False
torch_compile_backend: str = "inductor" # "inductor", "aot_eager", "cudagraphs", etc.
@ -394,9 +394,7 @@ class Config:
steps.append(FrontMatterParser(front_matter_class=front_matter_class))
elif step_name == "PDFRenderer":
steps.append(
PDFRenderer(target_longest_image_dim=step_config.get("target_longest_image_dim", 1024))
)
steps.append(PDFRenderer(target_longest_image_dim=step_config.get("target_longest_image_dim", 1024)))
elif step_name == "StaticLengthDocumentAnchoring":
steps.append(StaticLengthDocumentAnchoring(target_anchor_text_len=step_config.get("target_anchor_text_len", 6000)))
@ -417,9 +415,7 @@ class Config:
steps.append(FrontMatterOutputFormat())
elif step_name == "InstructUserMessages":
steps.append(InstructUserMessages(
prompt_first=step_config.get("prompt_first", False)
))
steps.append(InstructUserMessages(prompt_first=step_config.get("prompt_first", False)))
elif step_name == "LatexBracketNormalizer":
steps.append(LatexBracketNormalizer())
@ -457,24 +453,16 @@ class Config:
masking_index=step_config.get("masking_index", -100),
)
)
elif step_name == "FilterOutRotatedDocuments":
steps.append(FilterOutRotatedDocuments())
elif step_name == "RotationAugmentation":
steps.append(
RotationAugmentation(
probability=step_config.get("probability", 0.5)
)
)
steps.append(RotationAugmentation(probability=step_config.get("probability", 0.5)))
elif step_name == "AugraphyBasicAugmentations":
steps.append(
AugraphyBasicAugmentations(
probability=step_config.get("probability", 0.5)
)
)
steps.append(AugraphyBasicAugmentations(probability=step_config.get("probability", 0.5)))
else:
raise ValueError(f"Unknown pipeline step: {step_name}")
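For orientation, the branches above all follow the same pattern: look up the step by name, pull its options out of step_config with defaults, and append the constructed step. A minimal, self-contained sketch of that dispatch (the list-of-dicts spec and factory helpers are illustrative assumptions, not the project's real config format):

```python
# Illustrative only: mirrors the name -> constructor dispatch above using
# stand-in factories. Step names come from the branches in Config; the
# dict-based spec is an assumed stand-in for the real YAML configuration.
from typing import Any, Callable, Dict, List

FACTORIES: Dict[str, Callable[[Dict[str, Any]], str]] = {
    "PDFRenderer": lambda c: f"PDFRenderer(target_longest_image_dim={c.get('target_longest_image_dim', 1024)})",
    "InstructUserMessages": lambda c: f"InstructUserMessages(prompt_first={c.get('prompt_first', False)})",
    "RotationAugmentation": lambda c: f"RotationAugmentation(probability={c.get('probability', 0.5)})",
    "AugraphyBasicAugmentations": lambda c: f"AugraphyBasicAugmentations(probability={c.get('probability', 0.5)})",
}

def build_steps(spec: List[Dict[str, Any]]) -> List[str]:
    steps = []
    for step_config in spec:
        name = step_config["name"]
        if name not in FACTORIES:
            raise ValueError(f"Unknown pipeline step: {name}")
        steps.append(FACTORIES[name](step_config))
    return steps

print(build_steps([{"name": "PDFRenderer"}, {"name": "RotationAugmentation", "probability": 0.3}]))
```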

View File

@ -5,13 +5,11 @@ import re
from abc import ABC, abstractmethod
from concurrent.futures import ProcessPoolExecutor, as_completed
from dataclasses import dataclass, fields
from functools import reduce
from io import BytesIO
from os import PathLike
from pathlib import Path
from typing import (
Any,
Callable,
Dict,
List,
Optional,
@ -144,8 +142,8 @@ class BaseMarkdownPDFDataset(Dataset):
pbar.update(1)
# Sort samples by markdown path for consistent ordering across runs
self.samples.sort(key=lambda x: x['markdown_path'])
self.samples.sort(key=lambda x: x["markdown_path"])
logger.info(f"Found {valid_count} valid markdown-PDF pairs")
if invalid_pdfs:
@ -178,7 +176,7 @@ class BaseMarkdownPDFDataset(Dataset):
sample = step(sample)
if sample is None:
return None
return sample
@ -440,26 +438,26 @@ class LatexBracketNormalizer(PipelineStep):
@dataclass(frozen=True, slots=True)
class RotationAugmentation(PipelineStep):
"""Pipeline step that randomly rotates images for augmentation."""
probability: float = 0.5 # Probability of applying rotation
def __call__(self, sample: Sample) -> Optional[Sample]:
"""Randomly rotate image and update rotation metadata."""
# Only proceed with given probability
if np.random.random() > self.probability:
return sample
# Check if image exists
if "image" not in sample:
return sample
# Check if page_data exists (we need to update it)
if "page_data" not in sample:
return sample
# Randomly choose a rotation (90, 180, or 270 degrees)
rotation_degrees = np.random.choice([90, 180, 270])
# Apply rotation to image
image = sample["image"]
if rotation_degrees == 90:
@ -468,13 +466,13 @@ class RotationAugmentation(PipelineStep):
transpose = Image.Transpose.ROTATE_180
else: # 270
transpose = Image.Transpose.ROTATE_270
rotated_image = image.transpose(transpose)
sample["image"] = rotated_image
# Update page_data
page_data = sample["page_data"]
# Create new PageResponse with updated rotation info
# The rotation_correction should be the inverse of what we applied
# e.g. if we applied a 90-degree rotation, a further 270 degrees in the same direction restores the page
@ -484,9 +482,9 @@ class RotationAugmentation(PipelineStep):
correction = 180
else: # 270
correction = 90
from olmocr.prompts.prompts import PageResponse
new_page_data = PageResponse(
primary_language=page_data.primary_language,
is_rotation_valid=False, # Mark as invalid since we rotated it
@ -495,7 +493,7 @@ class RotationAugmentation(PipelineStep):
is_diagram=page_data.is_diagram,
natural_text=page_data.natural_text,
)
sample["page_data"] = new_page_data
return sample
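The invariant in this step is that page_data.rotation_correction must undo whatever rotation was applied to the image. A small stand-alone sketch of that pairing, assuming Pillow is available (the helper name is made up):

```python
# Hypothetical helper showing the applied-rotation -> correction mapping used
# above: 90 -> 270, 180 -> 180, 270 -> 90 (the correction undoes the rotation).
from PIL import Image

def rotate_with_correction(image: Image.Image, rotation_degrees: int):
    transpose = {
        90: Image.Transpose.ROTATE_90,
        180: Image.Transpose.ROTATE_180,
        270: Image.Transpose.ROTATE_270,
    }[rotation_degrees]
    correction = (360 - rotation_degrees) % 360  # inverse rotation restores the page
    return image.transpose(transpose), correction

img = Image.new("RGB", (200, 100), "white")
rotated, correction = rotate_with_correction(img, 90)
print(rotated.size, correction)  # (100, 200) 270
```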
@ -509,24 +507,24 @@ class FilterOutRotatedDocuments(PipelineStep):
# Check if page_data exists
if "page_data" not in sample:
return sample
page_data = sample["page_data"]
# Check if page_data has the required attributes
if not hasattr(page_data, "is_rotation_valid") or not hasattr(page_data, "rotation_correction"):
return sample
# Filter out if rotation is invalid or rotation correction is not 0
if page_data.is_rotation_valid is False or page_data.rotation_correction != 0:
return None
return sample
@dataclass(frozen=True, slots=True)
class AugraphyBasicAugmentations(PipelineStep):
"""Pipeline step that applies a decent selection of augraphy augmentations to the data"""
probability: float = 0.5 # Overall probability of applying any augmentation
def __call__(self, sample: Sample) -> Optional[Sample]:
@ -534,103 +532,96 @@ class AugraphyBasicAugmentations(PipelineStep):
# Check that the image data exists
if "image" not in sample:
return sample
image = sample["image"]
# Skip all augmentations based on overall probability
if np.random.random() > self.probability:
return sample
# Convert from PIL to BGR for OpenCV/Augraphy
image_numpy = np.array(image)
if len(image_numpy.shape) < 3:
image_bgr = cv2.cvtColor(image_numpy, cv2.COLOR_GRAY2BGR)
else:
image_bgr = cv2.cvtColor(image_numpy, cv2.COLOR_RGB2BGR)
# Apply a basic augraphy pipeline
from augraphy import (
AugraphyPipeline,
Brightness,
InkBleed,
InkMottling,
InkShifter,
Jpeg,
LowInkPeriodicLines,
LowInkRandomLines,
OneOf,
Jpeg,
InkMottling,
InkShifter,
Brightness,
)
# Apply geometric transformations first, maintaining scale
# Apply geometric transformations first, maintaining scale
if np.random.random() < 0.50:
# Get dimensions
height, width = image_bgr.shape[:2]
# Random parameters for geometric transformations
angle = max(min(np.random.standard_normal(), 3), -3) # Small rotation range
scale = np.random.uniform(0.95, 1.05) # Small scale range
tx = np.random.uniform(-0.02, 0.02) * width # Translation as fraction of width
ty = np.random.uniform(-0.02, 0.02) * height # Translation as fraction of height
# Calculate center point
center = (width / 2, height / 2)
# Create transformation matrix
M = cv2.getRotationMatrix2D(center, angle, scale)
# Add translation
M[0, 2] += tx
M[1, 2] += ty
# Apply transformation
image_bgr = cv2.warpAffine(
image_bgr,
M,
image_bgr,
M,
(width, height),
flags=cv2.INTER_LINEAR,
borderMode=cv2.BORDER_CONSTANT,
borderValue=(255, 255, 255) # White background for documents
borderValue=(255, 255, 255), # White background for documents
)
ink_phase = [
OneOf([InkBleed(p=1), LowInkRandomLines(p=1), LowInkPeriodicLines(p=1), InkMottling(p=1), InkShifter(p=1, text_shift_scale_range=(10, 15))], p=0.2),
]
paper_phase = [
OneOf([Brightness(p=0.2), Jpeg(p=1)])
]
paper_phase = [OneOf([Brightness(p=0.2), Jpeg(p=1)])]
post_phase = [
# Empty on purpose or else augmentations are too strong
]
augmentation_pipeline = AugraphyPipeline(
ink_phase=ink_phase, paper_phase=paper_phase, post_phase=post_phase
)
augmentation_pipeline = AugraphyPipeline(ink_phase=ink_phase, paper_phase=paper_phase, post_phase=post_phase)
# Apply augmentations
augmented_image_bgr = augmentation_pipeline(image_bgr)
# Convert back to RGB and then to PIL format
augmented_image_rgb = cv2.cvtColor(augmented_image_bgr, cv2.COLOR_BGR2RGB)
augmented_image_pil = Image.fromarray(augmented_image_rgb)
# Update the sample with the augmented image
sample["image"] = augmented_image_pil
# Double-check PIL image size matches original
assert augmented_image_pil.size == image.size, (
f"PIL image size changed during augmentation: {image.size} -> {augmented_image_pil.size}"
)
assert augmented_image_pil.size == image.size, f"PIL image size changed during augmentation: {image.size} -> {augmented_image_pil.size}"
return sample
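A short usage sketch for the step above, assuming augraphy, OpenCV, and numpy are installed; probability=1.0 forces the augmentation so it is always applied:

```python
# Illustrative usage of AugraphyBasicAugmentations on a synthetic sample.
import numpy as np
from PIL import Image
from olmocr.train.dataloader import AugraphyBasicAugmentations

original = Image.fromarray(np.full((256, 256, 3), 255, dtype=np.uint8))  # blank white page
step = AugraphyBasicAugmentations(probability=1.0)
out = step({"image": original})
assert out is not None and out["image"].size == original.size  # size preserved (asserted in the step too)
```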
@dataclass(frozen=True, slots=True)
class InstructUserMessages(PipelineStep):
"""Creates instruction-following messages format for training."""
prompt_first: bool = False
def __call__(self, sample: Sample) -> Sample:
@ -913,12 +904,12 @@ if __name__ == "__main__":
print(f"PDF file: {sample['pdf_path'].name}")
if "image" in sample and hasattr(sample["image"], "size"):
print(f"Image size: {sample['image'].size}")
# Save image if requested
if args.save_image:
sample["image"].save(args.save_image)
print(f"Saved image to: {args.save_image}")
if "page_data" in sample:
print(f"\nPage data: {sample['page_data']}")
if "messages" in sample:

View File

@ -14,8 +14,8 @@ def zeropower_via_newtonschulz5(G, steps: int):
where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
performance at all relative to UV^T, where USV^T = G is the SVD.
"""
assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
a, b, c = (3.4445, -4.7750, 2.0315)
assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
a, b, c = (3.4445, -4.7750, 2.0315)
X = G.bfloat16()
if G.size(-2) > G.size(-1):
X = X.mT
@ -25,9 +25,9 @@ def zeropower_via_newtonschulz5(G, steps: int):
# Perform the NS iterations
for _ in range(steps):
A = X @ X.mT
B = b * A + c * A @ A # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
B = b * A + c * A @ A # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
X = a * X + B @ X
if G.size(-2) > G.size(-1):
X = X.mT
return X
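In math form, each loop iteration above (A = X Xᵀ, B = bA + cA², X = aX + BX) is the quintic Newton–Schulz step; per the docstring it drives X toward the orthogonal factor UVᵀ of the SVD G = USVᵀ:

```latex
% One iteration of zeropower_via_newtonschulz5, with (a, b, c) = (3.4445, -4.7750, 2.0315):
A_k = X_k X_k^{\top}, \qquad
X_{k+1} = a\,X_k + \bigl(b\,A_k + c\,A_k^{2}\bigr)\,X_k
```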
@ -36,10 +36,10 @@ def zeropower_via_newtonschulz5(G, steps: int):
def muon_update(grad, momentum, beta=0.95, ns_steps=5, nesterov=True):
momentum.lerp_(grad, 1 - beta)
update = grad.lerp_(momentum, beta) if nesterov else momentum
if update.ndim == 4: # for the case of conv filters
if update.ndim == 4: # for the case of conv filters
update = update.view(len(update), -1)
update = zeropower_via_newtonschulz5(update, steps=ns_steps)
update *= max(1, grad.size(-2) / grad.size(-1))**0.5
update *= max(1, grad.size(-2) / grad.size(-1)) ** 0.5
return update
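Equivalently, muon_update computes a Nesterov-style momentum update and orthogonalizes it before rescaling by the matrix aspect ratio:

```latex
% muon_update in math form (NS_5 = zeropower_via_newtonschulz5 with ns_steps iterations):
m_t = \beta\, m_{t-1} + (1-\beta)\, g_t, \qquad
u_t = \mathrm{NS}_5\bigl((1-\beta)\, g_t + \beta\, m_t\bigr)\cdot
      \sqrt{\max\bigl(1,\ \mathrm{rows}/\mathrm{cols}\bigr)}
```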
@ -64,6 +64,7 @@ class Muon(torch.optim.Optimizer):
weight_decay: The AdamW-style weight decay.
momentum: The momentum. A value of 0.95 here is usually fine.
"""
def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95):
defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum)
assert isinstance(params, list) and len(params) >= 1 and isinstance(params[0], torch.nn.Parameter)
@ -81,7 +82,7 @@ class Muon(torch.optim.Optimizer):
for group in self.param_groups:
params = group["params"]
params_pad = params + [torch.empty_like(params[-1])] * (dist.get_world_size() - len(params) % dist.get_world_size())
for base_i in range(len(params))[::dist.get_world_size()]:
for base_i in range(len(params))[:: dist.get_world_size()]:
if base_i + dist.get_rank() < len(params):
p = params[base_i + dist.get_rank()]
if p.grad is None:
@ -93,7 +94,7 @@ class Muon(torch.optim.Optimizer):
update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
p.mul_(1 - group["lr"] * group["weight_decay"])
p.add_(update.reshape(p.shape), alpha=-group["lr"])
dist.all_gather(params_pad[base_i:base_i + dist.get_world_size()], params_pad[base_i + dist.get_rank()])
dist.all_gather(params_pad[base_i : base_i + dist.get_world_size()], params_pad[base_i + dist.get_rank()])
return loss
@ -102,6 +103,7 @@ class SingleDeviceMuon(torch.optim.Optimizer):
"""
Muon variant for usage in non-distributed settings.
"""
def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95):
defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum)
super().__init__(params, defaults)
@ -132,8 +134,8 @@ class SingleDeviceMuon(torch.optim.Optimizer):
def adam_update(grad, buf1, buf2, step, betas, eps):
buf1.lerp_(grad, 1 - betas[0])
buf2.lerp_(grad.square(), 1 - betas[1])
buf1c = buf1 / (1 - betas[0]**step)
buf2c = buf2 / (1 - betas[1]**step)
buf1c = buf1 / (1 - betas[0] ** step)
buf2c = buf2 / (1 - betas[1] ** step)
return buf1c / (buf2c.sqrt() + eps)
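adam_update is the standard bias-corrected Adam step without the learning rate (the caller applies lr and weight decay separately):

```latex
% adam_update in math form:
m_t = \beta_1 m_{t-1} + (1-\beta_1)\, g_t, \qquad
v_t = \beta_2 v_{t-1} + (1-\beta_2)\, g_t^{2}, \qquad
u_t = \frac{m_t/(1-\beta_1^{t})}{\sqrt{v_t/(1-\beta_2^{t})} + \epsilon}
```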
@ -164,6 +166,7 @@ class MuonWithAuxAdam(torch.optim.Optimizer):
optimizer = MuonWithAuxAdam(param_groups)
```
"""
def __init__(self, param_groups):
for group in param_groups:
assert "use_muon" in group
@ -195,7 +198,7 @@ class MuonWithAuxAdam(torch.optim.Optimizer):
if group["use_muon"]:
params = group["params"]
params_pad = params + [torch.empty_like(params[-1])] * (dist.get_world_size() - len(params) % dist.get_world_size())
for base_i in range(len(params))[::dist.get_world_size()]:
for base_i in range(len(params))[:: dist.get_world_size()]:
if base_i + dist.get_rank() < len(params):
p = params[base_i + dist.get_rank()]
if p.grad is None:
@ -207,7 +210,7 @@ class MuonWithAuxAdam(torch.optim.Optimizer):
update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
p.mul_(1 - group["lr"] * group["weight_decay"])
p.add_(update.reshape(p.shape), alpha=-group["lr"])
dist.all_gather(params_pad[base_i:base_i + dist.get_world_size()], params_pad[base_i + dist.get_rank()])
dist.all_gather(params_pad[base_i : base_i + dist.get_world_size()], params_pad[base_i + dist.get_rank()])
else:
for p in group["params"]:
if p.grad is None:
@ -219,8 +222,7 @@ class MuonWithAuxAdam(torch.optim.Optimizer):
state["exp_avg_sq"] = torch.zeros_like(p)
state["step"] = 0
state["step"] += 1
update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"],
state["step"], group["betas"], group["eps"])
update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"], state["step"], group["betas"], group["eps"])
p.mul_(1 - group["lr"] * group["weight_decay"])
p.add_(update, alpha=-group["lr"])
@ -231,6 +233,7 @@ class SingleDeviceMuonWithAuxAdam(torch.optim.Optimizer):
"""
Non-distributed variant of MuonWithAuxAdam.
"""
def __init__(self, param_groups):
for group in param_groups:
assert "use_muon" in group
@ -280,9 +283,8 @@ class SingleDeviceMuonWithAuxAdam(torch.optim.Optimizer):
state["exp_avg_sq"] = torch.zeros_like(p)
state["step"] = 0
state["step"] += 1
update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"],
state["step"], group["betas"], group["eps"])
update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"], state["step"], group["betas"], group["eps"])
p.mul_(1 - group["lr"] * group["weight_decay"])
p.add_(update, alpha=-group["lr"])
return loss
return loss

View File

@ -163,7 +163,7 @@ def prepare_olmocr_mix(dataset_path: str, subset: str, split: str, destination:
f.write(natural_text)
else:
f.write("---")
# Look for matching PDF in extracted directory and create symlinks
extracted_pdfs_dir = dest_path / "hugging_face" / "pdf_tarballs" / "extracted"

View File

@ -4,26 +4,25 @@ Simple script to test OlmOCR dataset loading with YAML configuration.
import argparse
import logging
import os
import math
import os
import shutil
from typing import Any, Dict, Optional
import numpy as np
import torch
from torch.utils.data import ConcatDataset, DataLoader
from torch.optim import AdamW
from torch.amp import autocast
import wandb
from torch.amp import autocast
from torch.optim import AdamW
from torch.utils.data import ConcatDataset, DataLoader
from tqdm import tqdm
from transformers import (
AutoProcessor,
get_scheduler,
Qwen2_5_VLForConditionalGeneration,
Qwen2VLForConditionalGeneration,
get_scheduler,
)
from typing import Optional, Dict, Any
from olmocr.train.config import Config
from olmocr.train.dataloader import BaseMarkdownPDFDataset
from olmocr.train.muon import SingleDeviceMuonWithAuxAdam
@ -37,7 +36,6 @@ logging.basicConfig(
logger = logging.getLogger(__name__)
class QwenDataCollator:
"""Data collator for vision-language models that handles numpy arrays."""
@ -80,7 +78,7 @@ class QwenDataCollator:
# Check if we have any valid samples
if not batch["input_ids"]:
return None
# Convert lists to tensors with proper padding
# Note: For Qwen2-VL, we typically handle variable length sequences
# The model's processor should handle the padding internally
@ -107,14 +105,14 @@ def save_checkpoint(
"""Save model, optimizer, scheduler, and training state."""
checkpoint_dir = os.path.join(output_dir, f"checkpoint-{global_step}")
os.makedirs(checkpoint_dir, exist_ok=True)
# Save model
model.save_pretrained(checkpoint_dir)
# Save optimizer and scheduler
torch.save(optimizer.state_dict(), os.path.join(checkpoint_dir, "optimizer.pt"))
torch.save(lr_scheduler.state_dict(), os.path.join(checkpoint_dir, "scheduler.pt"))
# Save training state
state = {
"epoch": epoch,
@ -123,15 +121,12 @@ def save_checkpoint(
"best_metric": best_metric,
}
torch.save(state, os.path.join(checkpoint_dir, "training_state.pt"))
logger.info(f"Saved checkpoint to {checkpoint_dir}")
# Enforce save_total_limit by removing oldest checkpoints
if save_total_limit is not None and save_total_limit > 0:
checkpoints = sorted(
[d for d in os.listdir(output_dir) if d.startswith("checkpoint-")],
key=lambda x: int(x.split("-")[1])
)
checkpoints = sorted([d for d in os.listdir(output_dir) if d.startswith("checkpoint-")], key=lambda x: int(x.split("-")[1]))
while len(checkpoints) > save_total_limit:
oldest = checkpoints.pop(0)
shutil.rmtree(os.path.join(output_dir, oldest))
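The numeric sort key matters here because plain lexicographic order sorts "checkpoint-1000" before "checkpoint-200" and would prune the wrong directories; a tiny illustration:

```python
# Why key=lambda x: int(x.split("-")[1]) is used when pruning old checkpoints.
names = ["checkpoint-1000", "checkpoint-200", "checkpoint-50"]
print(sorted(names))                                      # ['checkpoint-1000', 'checkpoint-200', 'checkpoint-50']
print(sorted(names, key=lambda x: int(x.split("-")[1])))  # ['checkpoint-50', 'checkpoint-200', 'checkpoint-1000']
```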
@ -149,10 +144,10 @@ def load_checkpoint(
"""Load model, optimizer, scheduler, and training state from checkpoint."""
model = model_class.from_pretrained(checkpoint_dir, **init_kwargs)
model.to(device)
optimizer.load_state_dict(torch.load(os.path.join(checkpoint_dir, "optimizer.pt"), map_location=device))
lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint_dir, "scheduler.pt"), map_location=device))
state = torch.load(os.path.join(checkpoint_dir, "training_state.pt"), map_location=device)
logger.info(f"Resumed from checkpoint: {checkpoint_dir} at epoch {state['epoch']:.2f}, step {state['global_step']}, samples seen {state['samples_seen']}")
return model, state
@ -166,11 +161,11 @@ def evaluate_model(
"""Evaluate on all eval datasets and return average loss per dataset."""
model.eval()
eval_metrics = {}
for dataset_name, dataloader in eval_dataloaders.items():
total_loss = 0.0
num_batches = 0
with torch.no_grad():
for batch in dataloader:
# Skip if batch is None (all samples were filtered out)
@ -181,16 +176,16 @@ def evaluate_model(
outputs = model(**batch)
total_loss += outputs.loss.item()
num_batches += 1
avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
eval_metrics[f"eval_{dataset_name}_loss"] = avg_loss
logger.info(f"Eval {dataset_name} loss: {avg_loss:.4f}")
# Compute overall eval loss as average across datasets (or customize as needed)
if eval_metrics:
overall_loss = sum(eval_metrics.values()) / len(eval_metrics)
eval_metrics["eval_loss"] = overall_loss
return eval_metrics
@ -215,11 +210,11 @@ def main():
if config.project_name:
os.environ["WANDB_PROJECT"] = config.project_name
logger.info(f"Setting WANDB_PROJECT to: {config.project_name}")
# Initialize wandb if reporting to it
if "wandb" in config.training.report_to:
wandb.init(project=config.project_name, name=config.run_name, config=config.to_dict())
# Load processor for tokenization
logger.info(f"Loading processor: {config.model.name}")
processor = AutoProcessor.from_pretrained(
@ -284,7 +279,6 @@ def main():
if len(dataset) > 0:
eval_datasets[dataset_name] = dataset
# Log total evaluation samples across all datasets
total_eval_samples = sum(len(dataset) for dataset in eval_datasets.values())
logger.info(f"Total evaluation samples across {len(eval_datasets)} datasets: {total_eval_samples}")
@ -310,14 +304,15 @@ def main():
# Set seeds
torch.manual_seed(config.training.seed)
# Set up data loader seed worker function
def seed_worker(worker_id):
worker_seed = torch.initial_seed() % 2**32
np.random.seed(worker_seed)
import random
random.seed(worker_seed)
# Create generator for data loader
generator = None
if config.training.data_seed is not None:
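The hunk cuts off before the DataLoader construction; below is a self-contained sketch of how seed_worker and the generator are typically wired in (the dataset is a placeholder, and the wiring itself is an assumption about the elided code):

```python
# Sketch of seed_worker/generator wiring into a DataLoader (assumed, not quoted).
import random
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

generator = torch.Generator()
generator.manual_seed(42)  # stands in for config.training.data_seed

loader = DataLoader(
    TensorDataset(torch.arange(10)),  # placeholder dataset
    batch_size=2,
    shuffle=True,
    num_workers=2,
    worker_init_fn=seed_worker,  # reseed numpy/random inside each worker process
    generator=generator,         # reproducible shuffle order
)
```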
@ -327,7 +322,7 @@ def main():
# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# Apply torch compile if enabled
if config.training.torch_compile:
logger.info(f"Compiling model with torch.compile (backend={config.training.torch_compile_backend}, mode={config.training.torch_compile_mode})")
@ -365,29 +360,29 @@ def main():
embed_params = [p for n, p in model.named_parameters() if "embed" in n]
scalar_params = [p for p in model.parameters() if p.ndim < 2]
head_params = [p for n, p in model.named_parameters() if "lm_head" in n]
# Create Adam groups with different learning rates
adam_groups = [
dict(params=head_params, lr=float(config.training.learning_rate) * config.training.muon_lr_multiplier_head, use_muon=False),
dict(params=embed_params, lr=float(config.training.learning_rate) * config.training.muon_lr_multiplier_embed, use_muon=False),
dict(params=scalar_params, lr=float(config.training.learning_rate) * config.training.muon_lr_multiplier_scalar, use_muon=False)
dict(params=scalar_params, lr=float(config.training.learning_rate) * config.training.muon_lr_multiplier_scalar, use_muon=False),
]
# Add Adam hyperparameters to groups
for g in adam_groups:
g["betas"] = (config.training.adam_beta1, config.training.adam_beta2)
g["eps"] = float(config.training.adam_epsilon)
g["weight_decay"] = config.training.weight_decay
# Create Muon group
muon_group = dict(
params=hidden_matrix_params,
lr=float(config.training.learning_rate),
momentum=config.training.muon_momentum,
weight_decay=config.training.weight_decay,
use_muon=True
use_muon=True,
)
# Combine all groups
param_groups = [*adam_groups, muon_group]
optimizer = SingleDeviceMuonWithAuxAdam(param_groups)
@ -416,7 +411,7 @@ def main():
global_step = 0
samples_seen = 0
best_metric = float("inf") if not config.training.greater_is_better else -float("inf")
if found_resumable_checkpoint:
model, state = load_checkpoint(model_class, model_init_kwargs, optimizer, lr_scheduler, found_resumable_checkpoint, device)
global_step = state["global_step"]
@ -457,7 +452,7 @@ def main():
current_epoch = samples_seen / len(train_dataset)
logger.info(f"Starting training from epoch {current_epoch:.2f} (step {global_step}, samples {samples_seen}) to {config.training.num_train_epochs} epochs")
logger.info(f"Total training steps: {max_train_steps}, Total samples to process: {max_train_samples}")
if samples_seen >= max_train_samples:
logger.info("Training already completed based on samples seen!")
logger.info("Skipping to final model save.")
@ -465,7 +460,7 @@ def main():
model.train()
accumulated_loss = 0.0
num_losses_accumulated = 0
# Create epoch iterator and skip samples if resuming
epoch_iterator = iter(train_dataloader)
if samples_seen > 0:
@ -479,10 +474,10 @@ def main():
# We've reached the end of the epoch while skipping, create new iterator
epoch_iterator = iter(train_dataloader)
break
# Create progress bar
pbar = tqdm(total=max_train_samples - samples_seen, desc=f"Training from step {global_step}", unit="samples")
while samples_seen < max_train_samples and global_step < max_train_steps:
try:
batch = next(epoch_iterator)
@ -492,48 +487,43 @@ def main():
logger.info(f"Completed epoch {current_epoch:.2f}")
epoch_iterator = iter(train_dataloader)
batch = next(epoch_iterator)
# Skip if batch is None (all samples were filtered out)
if batch is None:
continue
batch = {k: v.to(device) for k, v in batch.items()}
with autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
outputs = model(**batch)
loss = outputs.loss / config.training.gradient_accumulation_steps
loss.backward()
accumulated_loss += outputs.loss.item() # Use undivided loss for logging
num_losses_accumulated += 1
samples_seen += config.training.per_device_train_batch_size
# Update progress bar
pbar.update(config.training.per_device_train_batch_size)
# Check if we should do a gradient update
if samples_seen % samples_per_step == 0 or samples_seen >= max_train_samples:
# Clip gradients
torch.nn.utils.clip_grad_norm_(model.parameters(), config.training.max_grad_norm)
# Step optimizer and scheduler
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
global_step += 1
current_epoch = samples_seen / len(train_dataset)
# Update progress bar with current stats
current_lr = lr_scheduler.get_last_lr()[0]
avg_loss = accumulated_loss / num_losses_accumulated if num_losses_accumulated > 0 else 0
pbar.set_postfix({
'loss': f'{avg_loss:.4f}',
'lr': f'{current_lr:.2e}',
'epoch': f'{current_epoch:.2f}',
'step': global_step
})
pbar.set_postfix({"loss": f"{avg_loss:.4f}", "lr": f"{current_lr:.2e}", "epoch": f"{current_epoch:.2f}", "step": global_step})
# Logging
if config.training.logging_steps > 0 and global_step % config.training.logging_steps == 0:
avg_train_loss = accumulated_loss / num_losses_accumulated if num_losses_accumulated > 0 else 0
@ -546,52 +536,49 @@ def main():
logger.info(f"Step {global_step}: epoch={current_epoch:.3f}, loss={avg_train_loss:.4f}, lr={lr_scheduler.get_last_lr()[0]:.2e}")
if "wandb" in config.training.report_to:
wandb.log(logs, step=global_step)
accumulated_loss = 0.0
num_losses_accumulated = 0
# Evaluation
if config.training.eval_steps > 0 and global_step % config.training.eval_steps == 0 and global_step > 0:
metrics = evaluate_model(model, eval_dataloaders, device)
logger.info(f"Evaluation at step {global_step}: {metrics}")
if "wandb" in config.training.report_to:
wandb.log(metrics, step=global_step)
# Update best metric
current_metric = metrics.get(config.training.metric_for_best_model, None)
if current_metric is not None:
if (config.training.greater_is_better and current_metric > best_metric) or \
(not config.training.greater_is_better and current_metric < best_metric):
if (config.training.greater_is_better and current_metric > best_metric) or (
not config.training.greater_is_better and current_metric < best_metric
):
best_metric = current_metric
# Return to training mode
model.train()
# Saving
if config.training.save_steps > 0 and global_step % config.training.save_steps == 0:
save_checkpoint(
model, optimizer, lr_scheduler, current_epoch, global_step, samples_seen, best_metric,
full_output_dir, config.training.save_total_limit
model, optimizer, lr_scheduler, current_epoch, global_step, samples_seen, best_metric, full_output_dir, config.training.save_total_limit
)
# Check if we've reached our training limit
if samples_seen >= max_train_samples or global_step >= max_train_steps:
break
# Close progress bar
pbar.close()
# Save the final checkpoint with step number
logger.info(f"Saving final checkpoint at step {global_step}...")
save_checkpoint(
model, optimizer, lr_scheduler, current_epoch, global_step, samples_seen, best_metric,
full_output_dir, config.training.save_total_limit
)
save_checkpoint(model, optimizer, lr_scheduler, current_epoch, global_step, samples_seen, best_metric, full_output_dir, config.training.save_total_limit)
# Log final training state
final_epoch = samples_seen / len(train_dataset)
logger.info(f"Training completed at epoch {final_epoch:.3f}, step {global_step}, samples {samples_seen}")
# Final evaluation
final_metrics = evaluate_model(model, eval_dataloaders, device)
logger.info(f"Final evaluation metrics: {final_metrics}")
@ -601,4 +588,4 @@ def main():
if __name__ == "__main__":
main()
main()

View File

@ -171,7 +171,6 @@ class WorkQueue:
logger.info(f"Initialized queue with {self.size:,} work items")
return self.size
async def get_work(self, worker_lock_timeout_secs: int = 1800) -> Optional[WorkItem]:
"""
Get the next available work item that isn't completed or locked.
@ -179,7 +178,6 @@ class WorkQueue:
REFRESH_COMPLETED_HASH_CACHE_MAX_ATTEMPTS = 3
refresh_completed_hash_attempt = 0
while True:
try:
work_item = self._queue.get_nowait()
@ -221,7 +219,7 @@ class WorkQueue:
"""
# Create done flag in done_flags_dir
await self.backend.create_done_flag(work_item.hash)
# Remove the worker lock
await self.backend.delete_worker_lock(work_item.hash)
self._queue.task_done()
@ -281,11 +279,7 @@ class LocalBackend(Backend):
def _list_completed() -> Set[str]:
if not os.path.isdir(self._done_flags_dir):
return set()
return {
f[len("done_") : -len(".flag")]
for f in os.listdir(self._done_flags_dir)
if f.startswith("done_") and f.endswith(".flag")
}
return {f[len("done_") : -len(".flag")] for f in os.listdir(self._done_flags_dir) if f.startswith("done_") and f.endswith(".flag")}
return await asyncio.to_thread(_list_completed)
@ -299,6 +293,7 @@ class LocalBackend(Backend):
async def _get_object_mtime(self, path: str) -> Optional[datetime.datetime]:
"""Internal method to get object mtime."""
def _get_mtime() -> Optional[datetime.datetime]:
if not os.path.exists(path):
return None
@ -310,17 +305,17 @@ class LocalBackend(Backend):
"""Check if a worker lock is taken and not stale."""
lock_path = self._get_worker_lock_path(work_hash)
lock_mtime = await self._get_object_mtime(lock_path)
if not lock_mtime:
return False
now = datetime.datetime.now(datetime.timezone.utc)
return (now - lock_mtime).total_seconds() <= worker_lock_timeout_secs
async def create_worker_lock(self, work_hash: str) -> None:
"""Create a worker lock for a work hash."""
lock_path = self._get_worker_lock_path(work_hash)
def _create() -> None:
with open(lock_path, "wb"):
pass
@ -330,7 +325,7 @@ class LocalBackend(Backend):
async def delete_worker_lock(self, work_hash: str) -> None:
"""Delete the worker lock for a work hash if it exists."""
lock_path = self._get_worker_lock_path(work_hash)
def _delete() -> None:
if os.path.exists(lock_path):
os.remove(lock_path)
@ -345,7 +340,7 @@ class LocalBackend(Backend):
async def create_done_flag(self, work_hash: str) -> None:
"""Create a done flag for a work hash."""
done_flag_path = self._get_done_flag_path(work_hash)
def _create() -> None:
with open(done_flag_path, "wb"):
pass
@ -406,10 +401,10 @@ class S3Backend(Backend):
"""Check if a worker lock is taken and not stale."""
lock_path = self._get_worker_lock_path(work_hash)
lock_mtime = await self._get_object_mtime(lock_path)
if not lock_mtime:
return False
now = datetime.datetime.now(datetime.timezone.utc)
return (now - lock_mtime).total_seconds() <= worker_lock_timeout_secs
@ -434,4 +429,4 @@ class S3Backend(Backend):
"""Create a done flag for a work hash."""
done_flag_path = self._get_done_flag_path(work_hash)
bucket, key = parse_s3_path(done_flag_path)
await asyncio.to_thread(self.s3_client.put_object, Bucket=bucket, Key=key, Body=b"")
await asyncio.to_thread(self.s3_client.put_object, Bucket=bucket, Key=key, Body=b"")
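Both backends share the same staleness rule for worker locks: a lock counts as held only while its mtime is within worker_lock_timeout_secs of now. A small illustration with made-up timestamps:

```python
# The lock-staleness check used by LocalBackend and S3Backend above.
import datetime

worker_lock_timeout_secs = 1800  # default from get_work()
now = datetime.datetime.now(datetime.timezone.utc)

def lock_taken(lock_mtime: datetime.datetime) -> bool:
    return (now - lock_mtime).total_seconds() <= worker_lock_timeout_secs

print(lock_taken(now - datetime.timedelta(minutes=5)))  # True  -> another worker still holds it
print(lock_taken(now - datetime.timedelta(hours=2)))    # False -> stale, the item can be retried
```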

View File

@ -6,7 +6,7 @@ from unittest.mock import Mock, patch
from botocore.exceptions import ClientError
# Import the classes we're testing
from olmocr.work_queue import WorkQueue, S3Backend, WorkItem
from olmocr.work_queue import S3Backend, WorkItem, WorkQueue
class TestS3WorkQueue(unittest.TestCase):
@ -214,7 +214,7 @@ class TestS3WorkQueue(unittest.TestCase):
self.assertEqual(len(put_calls), 1)
done_flag_key = put_calls[0][1]["Key"]
self.assertTrue(done_flag_key.endswith(f"done_{work_item.hash}.flag"))
# Verify lock file was deleted
self.s3_client.delete_object.assert_called_once()
key = self.s3_client.delete_object.call_args[1]["Key"]