From 93411a80a0e1c16a0b9a7475f3b76d510e8c6483 Mon Sep 17 00:00:00 2001 From: Jake Poznanski Date: Wed, 13 Aug 2025 20:21:04 +0000 Subject: [PATCH] Lint fixes --- olmocr/pipeline.py | 2 +- olmocr/train/config.py | 34 +++----- olmocr/train/dataloader.py | 107 +++++++++++------------ olmocr/train/muon.py | 36 ++++---- olmocr/train/prepare_olmocrmix.py | 2 +- olmocr/train/train.py | 139 ++++++++++++++---------------- olmocr/work_queue.py | 27 +++--- tests/test_s3_work_queue.py | 4 +- 8 files changed, 157 insertions(+), 194 deletions(-) diff --git a/olmocr/pipeline.py b/olmocr/pipeline.py index ee33ef4..f88ff4a 100644 --- a/olmocr/pipeline.py +++ b/olmocr/pipeline.py @@ -49,7 +49,7 @@ from olmocr.s3_utils import ( ) from olmocr.train.dataloader import FrontMatterParser from olmocr.version import VERSION -from olmocr.work_queue import WorkQueue, LocalBackend, S3Backend +from olmocr.work_queue import LocalBackend, S3Backend, WorkQueue # Initialize logger logger = logging.getLogger(__name__) diff --git a/olmocr/train/config.py b/olmocr/train/config.py index 729f2dc..63786e4 100644 --- a/olmocr/train/config.py +++ b/olmocr/train/config.py @@ -204,11 +204,11 @@ class TrainingConfig: adam_epsilon: float = 1e-8 weight_decay: float = 0.01 max_grad_norm: float = 1.0 - + # Muon optimizer specific settings muon_momentum: float = 0.95 muon_lr_multiplier_head: float = 11.0 # Learning rate multiplier for head parameters - muon_lr_multiplier_embed: float = 30.0 # Learning rate multiplier for embedding parameters + muon_lr_multiplier_embed: float = 30.0 # Learning rate multiplier for embedding parameters muon_lr_multiplier_scalar: float = 2.0 # Learning rate multiplier for scalar parameters # Gradient checkpointing @@ -243,7 +243,7 @@ class TrainingConfig: # Data collator settings collator_max_token_len: Optional[int] = None remove_unused_columns: bool = False # Important for custom datasets - + # Torch compile settings torch_compile: bool = False torch_compile_backend: str = "inductor" # "inductor", "aot_eager", "cudagraphs", etc. @@ -394,9 +394,7 @@ class Config: steps.append(FrontMatterParser(front_matter_class=front_matter_class)) elif step_name == "PDFRenderer": - steps.append( - PDFRenderer(target_longest_image_dim=step_config.get("target_longest_image_dim", 1024)) - ) + steps.append(PDFRenderer(target_longest_image_dim=step_config.get("target_longest_image_dim", 1024))) elif step_name == "StaticLengthDocumentAnchoring": steps.append(StaticLengthDocumentAnchoring(target_anchor_text_len=step_config.get("target_anchor_text_len", 6000))) @@ -417,9 +415,7 @@ class Config: steps.append(FrontMatterOutputFormat()) elif step_name == "InstructUserMessages": - steps.append(InstructUserMessages( - prompt_first=step_config.get("prompt_first", False) - )) + steps.append(InstructUserMessages(prompt_first=step_config.get("prompt_first", False))) elif step_name == "LatexBracketNormalizer": steps.append(LatexBracketNormalizer()) @@ -457,24 +453,16 @@ class Config: masking_index=step_config.get("masking_index", -100), ) ) - + elif step_name == "FilterOutRotatedDocuments": steps.append(FilterOutRotatedDocuments()) - + elif step_name == "RotationAugmentation": - steps.append( - RotationAugmentation( - probability=step_config.get("probability", 0.5) - ) - ) - + steps.append(RotationAugmentation(probability=step_config.get("probability", 0.5))) + elif step_name == "AugraphyBasicAugmentations": - steps.append( - AugraphyBasicAugmentations( - probability=step_config.get("probability", 0.5) - ) - ) - + steps.append(AugraphyBasicAugmentations(probability=step_config.get("probability", 0.5))) + else: raise ValueError(f"Unknown pipeline step: {step_name}") diff --git a/olmocr/train/dataloader.py b/olmocr/train/dataloader.py index 7271344..c7958f3 100644 --- a/olmocr/train/dataloader.py +++ b/olmocr/train/dataloader.py @@ -5,13 +5,11 @@ import re from abc import ABC, abstractmethod from concurrent.futures import ProcessPoolExecutor, as_completed from dataclasses import dataclass, fields -from functools import reduce from io import BytesIO from os import PathLike from pathlib import Path from typing import ( Any, - Callable, Dict, List, Optional, @@ -144,8 +142,8 @@ class BaseMarkdownPDFDataset(Dataset): pbar.update(1) # Sort samples by markdown path for consistent ordering across runs - self.samples.sort(key=lambda x: x['markdown_path']) - + self.samples.sort(key=lambda x: x["markdown_path"]) + logger.info(f"Found {valid_count} valid markdown-PDF pairs") if invalid_pdfs: @@ -178,7 +176,7 @@ class BaseMarkdownPDFDataset(Dataset): sample = step(sample) if sample is None: return None - + return sample @@ -440,26 +438,26 @@ class LatexBracketNormalizer(PipelineStep): @dataclass(frozen=True, slots=True) class RotationAugmentation(PipelineStep): """Pipeline step that randomly rotates images for augmentation.""" - + probability: float = 0.5 # Probability of applying rotation - + def __call__(self, sample: Sample) -> Optional[Sample]: """Randomly rotate image and update rotation metadata.""" # Only proceed with given probability if np.random.random() > self.probability: return sample - + # Check if image exists if "image" not in sample: return sample - + # Check if page_data exists (we need to update it) if "page_data" not in sample: return sample - + # Randomly choose a rotation (90, 180, or 270 degrees) rotation_degrees = np.random.choice([90, 180, 270]) - + # Apply rotation to image image = sample["image"] if rotation_degrees == 90: @@ -468,13 +466,13 @@ class RotationAugmentation(PipelineStep): transpose = Image.Transpose.ROTATE_180 else: # 270 transpose = Image.Transpose.ROTATE_270 - + rotated_image = image.transpose(transpose) sample["image"] = rotated_image - + # Update page_data page_data = sample["page_data"] - + # Create new PageResponse with updated rotation info # The rotation_correction should be the inverse of what we applied # If we rotated 90 clockwise, we need 270 counter-clockwise to correct it @@ -484,9 +482,9 @@ class RotationAugmentation(PipelineStep): correction = 180 else: # 270 correction = 90 - + from olmocr.prompts.prompts import PageResponse - + new_page_data = PageResponse( primary_language=page_data.primary_language, is_rotation_valid=False, # Mark as invalid since we rotated it @@ -495,7 +493,7 @@ class RotationAugmentation(PipelineStep): is_diagram=page_data.is_diagram, natural_text=page_data.natural_text, ) - + sample["page_data"] = new_page_data return sample @@ -509,24 +507,24 @@ class FilterOutRotatedDocuments(PipelineStep): # Check if page_data exists if "page_data" not in sample: return sample - + page_data = sample["page_data"] - + # Check if page_data has the required attributes if not hasattr(page_data, "is_rotation_valid") or not hasattr(page_data, "rotation_correction"): return sample - + # Filter out if rotation is invalid or rotation correction is not 0 if page_data.is_rotation_valid is False or page_data.rotation_correction != 0: return None - + return sample @dataclass(frozen=True, slots=True) class AugraphyBasicAugmentations(PipelineStep): """Pipeline step that applies a decent selection of augraphy augmentations to the data""" - + probability: float = 0.5 # Overall probability of applying any augmentation def __call__(self, sample: Sample) -> Optional[Sample]: @@ -534,103 +532,96 @@ class AugraphyBasicAugmentations(PipelineStep): # Check that the image data exists if "image" not in sample: return sample - + image = sample["image"] - + # Skip all augmentations based on overall probability if np.random.random() > self.probability: return sample - + # Convert from PIL to BGR for OpenCV/Augraphy image_numpy = np.array(image) if len(image_numpy.shape) < 3: image_bgr = cv2.cvtColor(image_numpy, cv2.COLOR_GRAY2BGR) else: image_bgr = cv2.cvtColor(image_numpy, cv2.COLOR_RGB2BGR) - + # Apply a basic augraphy pipeline from augraphy import ( AugraphyPipeline, + Brightness, InkBleed, + InkMottling, + InkShifter, + Jpeg, LowInkPeriodicLines, LowInkRandomLines, OneOf, - - Jpeg, - InkMottling, - InkShifter, - Brightness, ) - # Apply geometric transformations first, maintaing scale + # Apply geometric transformations first, maintaing scale if np.random.random() < 0.50: # Get dimensions height, width = image_bgr.shape[:2] - + # Random parameters for geometric transformations angle = max(min(np.random.standard_normal(), 3), -3) # Small rotation range scale = np.random.uniform(0.95, 1.05) # Small scale range tx = np.random.uniform(-0.02, 0.02) * width # Translation as fraction of width ty = np.random.uniform(-0.02, 0.02) * height # Translation as fraction of height - + # Calculate center point center = (width / 2, height / 2) - + # Create transformation matrix M = cv2.getRotationMatrix2D(center, angle, scale) - + # Add translation M[0, 2] += tx M[1, 2] += ty - + # Apply transformation image_bgr = cv2.warpAffine( - image_bgr, - M, + image_bgr, + M, (width, height), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT, - borderValue=(255, 255, 255) # White background for documents + borderValue=(255, 255, 255), # White background for documents ) - ink_phase = [ OneOf([InkBleed(p=1), LowInkRandomLines(p=1), LowInkPeriodicLines(p=1), InkMottling(p=1), InkShifter(p=1, text_shift_scale_range=(10, 15))], p=0.2), ] - paper_phase = [ - OneOf([Brightness(p=0.2), Jpeg(p=1)]) - ] + paper_phase = [OneOf([Brightness(p=0.2), Jpeg(p=1)])] post_phase = [ # Empty on purpose or else augmentations are too strong ] - augmentation_pipeline = AugraphyPipeline( - ink_phase=ink_phase, paper_phase=paper_phase, post_phase=post_phase - ) - + augmentation_pipeline = AugraphyPipeline(ink_phase=ink_phase, paper_phase=paper_phase, post_phase=post_phase) + # Apply augmentations augmented_image_bgr = augmentation_pipeline(image_bgr) - + # Convert back to RGB and then to PIL format augmented_image_rgb = cv2.cvtColor(augmented_image_bgr, cv2.COLOR_BGR2RGB) augmented_image_pil = Image.fromarray(augmented_image_rgb) - + # Update the sample with the augmented image sample["image"] = augmented_image_pil - + # Double-check PIL image size matches original - assert augmented_image_pil.size == image.size, ( - f"PIL image size changed during augmentation: {image.size} -> {augmented_image_pil.size}" - ) - + assert augmented_image_pil.size == image.size, f"PIL image size changed during augmentation: {image.size} -> {augmented_image_pil.size}" + return sample + @dataclass(frozen=True, slots=True) class InstructUserMessages(PipelineStep): """Creates instruction-following messages format for training.""" - + prompt_first: bool = False def __call__(self, sample: Sample) -> Sample: @@ -913,12 +904,12 @@ if __name__ == "__main__": print(f"PDF file: {sample['pdf_path'].name}") if "image" in sample and hasattr(sample["image"], "size"): print(f"Image size: {sample['image'].size}") - + # Save image if requested if args.save_image: sample["image"].save(args.save_image) print(f"Saved image to: {args.save_image}") - + if "page_data" in sample: print(f"\nPage data: {sample['page_data']}") if "messages" in sample: diff --git a/olmocr/train/muon.py b/olmocr/train/muon.py index 771e9c7..924ff6d 100644 --- a/olmocr/train/muon.py +++ b/olmocr/train/muon.py @@ -14,8 +14,8 @@ def zeropower_via_newtonschulz5(G, steps: int): where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model performance at all relative to UV^T, where USV^T = G is the SVD. """ - assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng - a, b, c = (3.4445, -4.7750, 2.0315) + assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng + a, b, c = (3.4445, -4.7750, 2.0315) X = G.bfloat16() if G.size(-2) > G.size(-1): X = X.mT @@ -25,9 +25,9 @@ def zeropower_via_newtonschulz5(G, steps: int): # Perform the NS iterations for _ in range(steps): A = X @ X.mT - B = b * A + c * A @ A # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng + B = b * A + c * A @ A # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng X = a * X + B @ X - + if G.size(-2) > G.size(-1): X = X.mT return X @@ -36,10 +36,10 @@ def zeropower_via_newtonschulz5(G, steps: int): def muon_update(grad, momentum, beta=0.95, ns_steps=5, nesterov=True): momentum.lerp_(grad, 1 - beta) update = grad.lerp_(momentum, beta) if nesterov else momentum - if update.ndim == 4: # for the case of conv filters + if update.ndim == 4: # for the case of conv filters update = update.view(len(update), -1) update = zeropower_via_newtonschulz5(update, steps=ns_steps) - update *= max(1, grad.size(-2) / grad.size(-1))**0.5 + update *= max(1, grad.size(-2) / grad.size(-1)) ** 0.5 return update @@ -64,6 +64,7 @@ class Muon(torch.optim.Optimizer): weight_decay: The AdamW-style weight decay. momentum: The momentum. A value of 0.95 here is usually fine. """ + def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95): defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) assert isinstance(params, list) and len(params) >= 1 and isinstance(params[0], torch.nn.Parameter) @@ -81,7 +82,7 @@ class Muon(torch.optim.Optimizer): for group in self.param_groups: params = group["params"] params_pad = params + [torch.empty_like(params[-1])] * (dist.get_world_size() - len(params) % dist.get_world_size()) - for base_i in range(len(params))[::dist.get_world_size()]: + for base_i in range(len(params))[:: dist.get_world_size()]: if base_i + dist.get_rank() < len(params): p = params[base_i + dist.get_rank()] if p.grad is None: @@ -93,7 +94,7 @@ class Muon(torch.optim.Optimizer): update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"]) p.mul_(1 - group["lr"] * group["weight_decay"]) p.add_(update.reshape(p.shape), alpha=-group["lr"]) - dist.all_gather(params_pad[base_i:base_i + dist.get_world_size()], params_pad[base_i + dist.get_rank()]) + dist.all_gather(params_pad[base_i : base_i + dist.get_world_size()], params_pad[base_i + dist.get_rank()]) return loss @@ -102,6 +103,7 @@ class SingleDeviceMuon(torch.optim.Optimizer): """ Muon variant for usage in non-distributed settings. """ + def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95): defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) super().__init__(params, defaults) @@ -132,8 +134,8 @@ class SingleDeviceMuon(torch.optim.Optimizer): def adam_update(grad, buf1, buf2, step, betas, eps): buf1.lerp_(grad, 1 - betas[0]) buf2.lerp_(grad.square(), 1 - betas[1]) - buf1c = buf1 / (1 - betas[0]**step) - buf2c = buf2 / (1 - betas[1]**step) + buf1c = buf1 / (1 - betas[0] ** step) + buf2c = buf2 / (1 - betas[1] ** step) return buf1c / (buf2c.sqrt() + eps) @@ -164,6 +166,7 @@ class MuonWithAuxAdam(torch.optim.Optimizer): optimizer = MuonWithAuxAdam(param_groups) ``` """ + def __init__(self, param_groups): for group in param_groups: assert "use_muon" in group @@ -195,7 +198,7 @@ class MuonWithAuxAdam(torch.optim.Optimizer): if group["use_muon"]: params = group["params"] params_pad = params + [torch.empty_like(params[-1])] * (dist.get_world_size() - len(params) % dist.get_world_size()) - for base_i in range(len(params))[::dist.get_world_size()]: + for base_i in range(len(params))[:: dist.get_world_size()]: if base_i + dist.get_rank() < len(params): p = params[base_i + dist.get_rank()] if p.grad is None: @@ -207,7 +210,7 @@ class MuonWithAuxAdam(torch.optim.Optimizer): update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"]) p.mul_(1 - group["lr"] * group["weight_decay"]) p.add_(update.reshape(p.shape), alpha=-group["lr"]) - dist.all_gather(params_pad[base_i:base_i + dist.get_world_size()], params_pad[base_i + dist.get_rank()]) + dist.all_gather(params_pad[base_i : base_i + dist.get_world_size()], params_pad[base_i + dist.get_rank()]) else: for p in group["params"]: if p.grad is None: @@ -219,8 +222,7 @@ class MuonWithAuxAdam(torch.optim.Optimizer): state["exp_avg_sq"] = torch.zeros_like(p) state["step"] = 0 state["step"] += 1 - update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"], - state["step"], group["betas"], group["eps"]) + update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"], state["step"], group["betas"], group["eps"]) p.mul_(1 - group["lr"] * group["weight_decay"]) p.add_(update, alpha=-group["lr"]) @@ -231,6 +233,7 @@ class SingleDeviceMuonWithAuxAdam(torch.optim.Optimizer): """ Non-distributed variant of MuonWithAuxAdam. """ + def __init__(self, param_groups): for group in param_groups: assert "use_muon" in group @@ -280,9 +283,8 @@ class SingleDeviceMuonWithAuxAdam(torch.optim.Optimizer): state["exp_avg_sq"] = torch.zeros_like(p) state["step"] = 0 state["step"] += 1 - update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"], - state["step"], group["betas"], group["eps"]) + update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"], state["step"], group["betas"], group["eps"]) p.mul_(1 - group["lr"] * group["weight_decay"]) p.add_(update, alpha=-group["lr"]) - return loss \ No newline at end of file + return loss diff --git a/olmocr/train/prepare_olmocrmix.py b/olmocr/train/prepare_olmocrmix.py index 1092c51..e785f40 100644 --- a/olmocr/train/prepare_olmocrmix.py +++ b/olmocr/train/prepare_olmocrmix.py @@ -163,7 +163,7 @@ def prepare_olmocr_mix(dataset_path: str, subset: str, split: str, destination: f.write(natural_text) else: f.write("---") - + # Look for matching PDF in extracted directory and create symlinks extracted_pdfs_dir = dest_path / "hugging_face" / "pdf_tarballs" / "extracted" diff --git a/olmocr/train/train.py b/olmocr/train/train.py index ef869ea..33b74a4 100644 --- a/olmocr/train/train.py +++ b/olmocr/train/train.py @@ -4,26 +4,25 @@ Simple script to test OlmOCR dataset loading with YAML configuration. import argparse import logging -import os import math +import os import shutil +from typing import Any, Dict, Optional import numpy as np import torch -from torch.utils.data import ConcatDataset, DataLoader -from torch.optim import AdamW -from torch.amp import autocast import wandb +from torch.amp import autocast +from torch.optim import AdamW +from torch.utils.data import ConcatDataset, DataLoader from tqdm import tqdm - from transformers import ( AutoProcessor, - get_scheduler, Qwen2_5_VLForConditionalGeneration, Qwen2VLForConditionalGeneration, + get_scheduler, ) -from typing import Optional, Dict, Any from olmocr.train.config import Config from olmocr.train.dataloader import BaseMarkdownPDFDataset from olmocr.train.muon import SingleDeviceMuonWithAuxAdam @@ -37,7 +36,6 @@ logging.basicConfig( logger = logging.getLogger(__name__) - class QwenDataCollator: """Data collator for vision-language models that handles numpy arrays.""" @@ -80,7 +78,7 @@ class QwenDataCollator: # Check if we have any valid samples if not batch["input_ids"]: return None - + # Convert lists to tensors with proper padding # Note: For Qwen2-VL, we typically handle variable length sequences # The model's processor should handle the padding internally @@ -107,14 +105,14 @@ def save_checkpoint( """Save model, optimizer, scheduler, and training state.""" checkpoint_dir = os.path.join(output_dir, f"checkpoint-{global_step}") os.makedirs(checkpoint_dir, exist_ok=True) - + # Save model model.save_pretrained(checkpoint_dir) - + # Save optimizer and scheduler torch.save(optimizer.state_dict(), os.path.join(checkpoint_dir, "optimizer.pt")) torch.save(lr_scheduler.state_dict(), os.path.join(checkpoint_dir, "scheduler.pt")) - + # Save training state state = { "epoch": epoch, @@ -123,15 +121,12 @@ def save_checkpoint( "best_metric": best_metric, } torch.save(state, os.path.join(checkpoint_dir, "training_state.pt")) - + logger.info(f"Saved checkpoint to {checkpoint_dir}") - + # Enforce save_total_limit by removing oldest checkpoints if save_total_limit is not None and save_total_limit > 0: - checkpoints = sorted( - [d for d in os.listdir(output_dir) if d.startswith("checkpoint-")], - key=lambda x: int(x.split("-")[1]) - ) + checkpoints = sorted([d for d in os.listdir(output_dir) if d.startswith("checkpoint-")], key=lambda x: int(x.split("-")[1])) while len(checkpoints) > save_total_limit: oldest = checkpoints.pop(0) shutil.rmtree(os.path.join(output_dir, oldest)) @@ -149,10 +144,10 @@ def load_checkpoint( """Load model, optimizer, scheduler, and training state from checkpoint.""" model = model_class.from_pretrained(checkpoint_dir, **init_kwargs) model.to(device) - + optimizer.load_state_dict(torch.load(os.path.join(checkpoint_dir, "optimizer.pt"), map_location=device)) lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint_dir, "scheduler.pt"), map_location=device)) - + state = torch.load(os.path.join(checkpoint_dir, "training_state.pt"), map_location=device) logger.info(f"Resumed from checkpoint: {checkpoint_dir} at epoch {state['epoch']:.2f}, step {state['global_step']}, samples seen {state['samples_seen']}") return model, state @@ -166,11 +161,11 @@ def evaluate_model( """Evaluate on all eval datasets and return average loss per dataset.""" model.eval() eval_metrics = {} - + for dataset_name, dataloader in eval_dataloaders.items(): total_loss = 0.0 num_batches = 0 - + with torch.no_grad(): for batch in dataloader: # Skip if batch is None (all samples were filtered out) @@ -181,16 +176,16 @@ def evaluate_model( outputs = model(**batch) total_loss += outputs.loss.item() num_batches += 1 - + avg_loss = total_loss / num_batches if num_batches > 0 else 0.0 eval_metrics[f"eval_{dataset_name}_loss"] = avg_loss logger.info(f"Eval {dataset_name} loss: {avg_loss:.4f}") - + # Compute overall eval loss as average across datasets (or customize as needed) if eval_metrics: overall_loss = sum(eval_metrics.values()) / len(eval_metrics) eval_metrics["eval_loss"] = overall_loss - + return eval_metrics @@ -215,11 +210,11 @@ def main(): if config.project_name: os.environ["WANDB_PROJECT"] = config.project_name logger.info(f"Setting WANDB_PROJECT to: {config.project_name}") - + # Initialize wandb if reporting to it if "wandb" in config.training.report_to: wandb.init(project=config.project_name, name=config.run_name, config=config.to_dict()) - + # Load processor for tokenization logger.info(f"Loading processor: {config.model.name}") processor = AutoProcessor.from_pretrained( @@ -284,7 +279,6 @@ def main(): if len(dataset) > 0: eval_datasets[dataset_name] = dataset - # Log total evaluation samples across all datasets total_eval_samples = sum(len(dataset) for dataset in eval_datasets.values()) logger.info(f"Total evaluation samples across {len(eval_datasets)} datasets: {total_eval_samples}") @@ -310,14 +304,15 @@ def main(): # Set seeds torch.manual_seed(config.training.seed) - + # Set up data loader seed worker function def seed_worker(worker_id): worker_seed = torch.initial_seed() % 2**32 np.random.seed(worker_seed) import random + random.seed(worker_seed) - + # Create generator for data loader generator = None if config.training.data_seed is not None: @@ -327,7 +322,7 @@ def main(): # Device setup device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model.to(device) - + # Apply torch compile if enabled if config.training.torch_compile: logger.info(f"Compiling model with torch.compile (backend={config.training.torch_compile_backend}, mode={config.training.torch_compile_mode})") @@ -365,29 +360,29 @@ def main(): embed_params = [p for n, p in model.named_parameters() if "embed" in n] scalar_params = [p for p in model.parameters() if p.ndim < 2] head_params = [p for n, p in model.named_parameters() if "lm_head" in n] - + # Create Adam groups with different learning rates adam_groups = [ dict(params=head_params, lr=float(config.training.learning_rate) * config.training.muon_lr_multiplier_head, use_muon=False), dict(params=embed_params, lr=float(config.training.learning_rate) * config.training.muon_lr_multiplier_embed, use_muon=False), - dict(params=scalar_params, lr=float(config.training.learning_rate) * config.training.muon_lr_multiplier_scalar, use_muon=False) + dict(params=scalar_params, lr=float(config.training.learning_rate) * config.training.muon_lr_multiplier_scalar, use_muon=False), ] - + # Add Adam hyperparameters to groups for g in adam_groups: g["betas"] = (config.training.adam_beta1, config.training.adam_beta2) g["eps"] = float(config.training.adam_epsilon) g["weight_decay"] = config.training.weight_decay - + # Create Muon group muon_group = dict( params=hidden_matrix_params, lr=float(config.training.learning_rate), momentum=config.training.muon_momentum, weight_decay=config.training.weight_decay, - use_muon=True + use_muon=True, ) - + # Combine all groups param_groups = [*adam_groups, muon_group] optimizer = SingleDeviceMuonWithAuxAdam(param_groups) @@ -416,7 +411,7 @@ def main(): global_step = 0 samples_seen = 0 best_metric = float("inf") if not config.training.greater_is_better else -float("inf") - + if found_resumable_checkpoint: model, state = load_checkpoint(model_class, model_init_kwargs, optimizer, lr_scheduler, found_resumable_checkpoint, device) global_step = state["global_step"] @@ -457,7 +452,7 @@ def main(): current_epoch = samples_seen / len(train_dataset) logger.info(f"Starting training from epoch {current_epoch:.2f} (step {global_step}, samples {samples_seen}) to {config.training.num_train_epochs} epochs") logger.info(f"Total training steps: {max_train_steps}, Total samples to process: {max_train_samples}") - + if samples_seen >= max_train_samples: logger.info("Training already completed based on samples seen!") logger.info("Skipping to final model save.") @@ -465,7 +460,7 @@ def main(): model.train() accumulated_loss = 0.0 num_losses_accumulated = 0 - + # Create epoch iterator and skip samples if resuming epoch_iterator = iter(train_dataloader) if samples_seen > 0: @@ -479,10 +474,10 @@ def main(): # We've reached the end of the epoch while skipping, create new iterator epoch_iterator = iter(train_dataloader) break - + # Create progress bar pbar = tqdm(total=max_train_samples - samples_seen, desc=f"Training from step {global_step}", unit="samples") - + while samples_seen < max_train_samples and global_step < max_train_steps: try: batch = next(epoch_iterator) @@ -492,48 +487,43 @@ def main(): logger.info(f"Completed epoch {current_epoch:.2f}") epoch_iterator = iter(train_dataloader) batch = next(epoch_iterator) - + # Skip if batch is None (all samples were filtered out) if batch is None: continue - + batch = {k: v.to(device) for k, v in batch.items()} - + with autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16): outputs = model(**batch) loss = outputs.loss / config.training.gradient_accumulation_steps loss.backward() - + accumulated_loss += outputs.loss.item() # Use undivided loss for logging num_losses_accumulated += 1 samples_seen += config.training.per_device_train_batch_size - + # Update progress bar pbar.update(config.training.per_device_train_batch_size) - + # Check if we should do a gradient update if samples_seen % samples_per_step == 0 or samples_seen >= max_train_samples: # Clip gradients torch.nn.utils.clip_grad_norm_(model.parameters(), config.training.max_grad_norm) - + # Step optimizer and scheduler optimizer.step() lr_scheduler.step() optimizer.zero_grad() - + global_step += 1 current_epoch = samples_seen / len(train_dataset) - + # Update progress bar with current stats current_lr = lr_scheduler.get_last_lr()[0] avg_loss = accumulated_loss / num_losses_accumulated if num_losses_accumulated > 0 else 0 - pbar.set_postfix({ - 'loss': f'{avg_loss:.4f}', - 'lr': f'{current_lr:.2e}', - 'epoch': f'{current_epoch:.2f}', - 'step': global_step - }) - + pbar.set_postfix({"loss": f"{avg_loss:.4f}", "lr": f"{current_lr:.2e}", "epoch": f"{current_epoch:.2f}", "step": global_step}) + # Logging if config.training.logging_steps > 0 and global_step % config.training.logging_steps == 0: avg_train_loss = accumulated_loss / num_losses_accumulated if num_losses_accumulated > 0 else 0 @@ -546,52 +536,49 @@ def main(): logger.info(f"Step {global_step}: epoch={current_epoch:.3f}, loss={avg_train_loss:.4f}, lr={lr_scheduler.get_last_lr()[0]:.2e}") if "wandb" in config.training.report_to: wandb.log(logs, step=global_step) - + accumulated_loss = 0.0 num_losses_accumulated = 0 - + # Evaluation if config.training.eval_steps > 0 and global_step % config.training.eval_steps == 0 and global_step > 0: metrics = evaluate_model(model, eval_dataloaders, device) logger.info(f"Evaluation at step {global_step}: {metrics}") if "wandb" in config.training.report_to: wandb.log(metrics, step=global_step) - + # Update best metric current_metric = metrics.get(config.training.metric_for_best_model, None) if current_metric is not None: - if (config.training.greater_is_better and current_metric > best_metric) or \ - (not config.training.greater_is_better and current_metric < best_metric): + if (config.training.greater_is_better and current_metric > best_metric) or ( + not config.training.greater_is_better and current_metric < best_metric + ): best_metric = current_metric - + # Return to training mode model.train() - + # Saving if config.training.save_steps > 0 and global_step % config.training.save_steps == 0: save_checkpoint( - model, optimizer, lr_scheduler, current_epoch, global_step, samples_seen, best_metric, - full_output_dir, config.training.save_total_limit + model, optimizer, lr_scheduler, current_epoch, global_step, samples_seen, best_metric, full_output_dir, config.training.save_total_limit ) - + # Check if we've reached our training limit if samples_seen >= max_train_samples or global_step >= max_train_steps: break - + # Close progress bar pbar.close() # Save the final checkpoint with step number logger.info(f"Saving final checkpoint at step {global_step}...") - save_checkpoint( - model, optimizer, lr_scheduler, current_epoch, global_step, samples_seen, best_metric, - full_output_dir, config.training.save_total_limit - ) - + save_checkpoint(model, optimizer, lr_scheduler, current_epoch, global_step, samples_seen, best_metric, full_output_dir, config.training.save_total_limit) + # Log final training state final_epoch = samples_seen / len(train_dataset) logger.info(f"Training completed at epoch {final_epoch:.3f}, step {global_step}, samples {samples_seen}") - + # Final evaluation final_metrics = evaluate_model(model, eval_dataloaders, device) logger.info(f"Final evaluation metrics: {final_metrics}") @@ -601,4 +588,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/olmocr/work_queue.py b/olmocr/work_queue.py index 03eb54e..e8e9e5f 100644 --- a/olmocr/work_queue.py +++ b/olmocr/work_queue.py @@ -171,7 +171,6 @@ class WorkQueue: logger.info(f"Initialized queue with {self.size:,} work items") return self.size - async def get_work(self, worker_lock_timeout_secs: int = 1800) -> Optional[WorkItem]: """ Get the next available work item that isn't completed or locked. @@ -179,7 +178,6 @@ class WorkQueue: REFRESH_COMPLETED_HASH_CACHE_MAX_ATTEMPTS = 3 refresh_completed_hash_attempt = 0 - while True: try: work_item = self._queue.get_nowait() @@ -221,7 +219,7 @@ class WorkQueue: """ # Create done flag in done_flags_dir await self.backend.create_done_flag(work_item.hash) - + # Remove the worker lock await self.backend.delete_worker_lock(work_item.hash) self._queue.task_done() @@ -281,11 +279,7 @@ class LocalBackend(Backend): def _list_completed() -> Set[str]: if not os.path.isdir(self._done_flags_dir): return set() - return { - f[len("done_") : -len(".flag")] - for f in os.listdir(self._done_flags_dir) - if f.startswith("done_") and f.endswith(".flag") - } + return {f[len("done_") : -len(".flag")] for f in os.listdir(self._done_flags_dir) if f.startswith("done_") and f.endswith(".flag")} return await asyncio.to_thread(_list_completed) @@ -299,6 +293,7 @@ class LocalBackend(Backend): async def _get_object_mtime(self, path: str) -> Optional[datetime.datetime]: """Internal method to get object mtime.""" + def _get_mtime() -> Optional[datetime.datetime]: if not os.path.exists(path): return None @@ -310,17 +305,17 @@ class LocalBackend(Backend): """Check if a worker lock is taken and not stale.""" lock_path = self._get_worker_lock_path(work_hash) lock_mtime = await self._get_object_mtime(lock_path) - + if not lock_mtime: return False - + now = datetime.datetime.now(datetime.timezone.utc) return (now - lock_mtime).total_seconds() <= worker_lock_timeout_secs async def create_worker_lock(self, work_hash: str) -> None: """Create a worker lock for a work hash.""" lock_path = self._get_worker_lock_path(work_hash) - + def _create() -> None: with open(lock_path, "wb"): pass @@ -330,7 +325,7 @@ class LocalBackend(Backend): async def delete_worker_lock(self, work_hash: str) -> None: """Delete the worker lock for a work hash if it exists.""" lock_path = self._get_worker_lock_path(work_hash) - + def _delete() -> None: if os.path.exists(lock_path): os.remove(lock_path) @@ -345,7 +340,7 @@ class LocalBackend(Backend): async def create_done_flag(self, work_hash: str) -> None: """Create a done flag for a work hash.""" done_flag_path = self._get_done_flag_path(work_hash) - + def _create() -> None: with open(done_flag_path, "wb"): pass @@ -406,10 +401,10 @@ class S3Backend(Backend): """Check if a worker lock is taken and not stale.""" lock_path = self._get_worker_lock_path(work_hash) lock_mtime = await self._get_object_mtime(lock_path) - + if not lock_mtime: return False - + now = datetime.datetime.now(datetime.timezone.utc) return (now - lock_mtime).total_seconds() <= worker_lock_timeout_secs @@ -434,4 +429,4 @@ class S3Backend(Backend): """Create a done flag for a work hash.""" done_flag_path = self._get_done_flag_path(work_hash) bucket, key = parse_s3_path(done_flag_path) - await asyncio.to_thread(self.s3_client.put_object, Bucket=bucket, Key=key, Body=b"") \ No newline at end of file + await asyncio.to_thread(self.s3_client.put_object, Bucket=bucket, Key=key, Body=b"") diff --git a/tests/test_s3_work_queue.py b/tests/test_s3_work_queue.py index 15667d0..2f9e74b 100644 --- a/tests/test_s3_work_queue.py +++ b/tests/test_s3_work_queue.py @@ -6,7 +6,7 @@ from unittest.mock import Mock, patch from botocore.exceptions import ClientError # Import the classes we're testing -from olmocr.work_queue import WorkQueue, S3Backend, WorkItem +from olmocr.work_queue import S3Backend, WorkItem, WorkQueue class TestS3WorkQueue(unittest.TestCase): @@ -214,7 +214,7 @@ class TestS3WorkQueue(unittest.TestCase): self.assertEqual(len(put_calls), 1) done_flag_key = put_calls[0][1]["Key"] self.assertTrue(done_flag_key.endswith(f"done_{work_item.hash}.flag")) - + # Verify lock file was deleted self.s3_client.delete_object.assert_called_once() key = self.s3_client.delete_object.call_args[1]["Key"]