Mirror of https://github.com/allenai/olmocr.git (synced 2025-10-28 00:24:16 +00:00)

Commit 93411a80a0: Lint fixes
Parent: 05330150ad
@@ -49,7 +49,7 @@ from olmocr.s3_utils import (
)
from olmocr.train.dataloader import FrontMatterParser
from olmocr.version import VERSION
from olmocr.work_queue import WorkQueue, LocalBackend, S3Backend
from olmocr.work_queue import LocalBackend, S3Backend, WorkQueue

# Initialize logger
logger = logging.getLogger(__name__)

@@ -204,11 +204,11 @@ class TrainingConfig:
adam_epsilon: float = 1e-8
weight_decay: float = 0.01
max_grad_norm: float = 1.0

# Muon optimizer specific settings
muon_momentum: float = 0.95
muon_lr_multiplier_head: float = 11.0 # Learning rate multiplier for head parameters
muon_lr_multiplier_embed: float = 30.0 # Learning rate multiplier for embedding parameters
muon_lr_multiplier_embed: float = 30.0 # Learning rate multiplier for embedding parameters
muon_lr_multiplier_scalar: float = 2.0 # Learning rate multiplier for scalar parameters

# Gradient checkpointing
@@ -243,7 +243,7 @@ class TrainingConfig:
# Data collator settings
collator_max_token_len: Optional[int] = None
remove_unused_columns: bool = False # Important for custom datasets

# Torch compile settings
torch_compile: bool = False
torch_compile_backend: str = "inductor" # "inductor", "aot_eager", "cudagraphs", etc.
@@ -394,9 +394,7 @@ class Config:
steps.append(FrontMatterParser(front_matter_class=front_matter_class))

elif step_name == "PDFRenderer":
steps.append(
PDFRenderer(target_longest_image_dim=step_config.get("target_longest_image_dim", 1024))
)
steps.append(PDFRenderer(target_longest_image_dim=step_config.get("target_longest_image_dim", 1024)))

elif step_name == "StaticLengthDocumentAnchoring":
steps.append(StaticLengthDocumentAnchoring(target_anchor_text_len=step_config.get("target_anchor_text_len", 6000)))
@@ -417,9 +415,7 @@ class Config:
steps.append(FrontMatterOutputFormat())

elif step_name == "InstructUserMessages":
steps.append(InstructUserMessages(
prompt_first=step_config.get("prompt_first", False)
))
steps.append(InstructUserMessages(prompt_first=step_config.get("prompt_first", False)))

elif step_name == "LatexBracketNormalizer":
steps.append(LatexBracketNormalizer())
@@ -457,24 +453,16 @@ class Config:
masking_index=step_config.get("masking_index", -100),
)
)

elif step_name == "FilterOutRotatedDocuments":
steps.append(FilterOutRotatedDocuments())

elif step_name == "RotationAugmentation":
steps.append(
RotationAugmentation(
probability=step_config.get("probability", 0.5)
)
)

steps.append(RotationAugmentation(probability=step_config.get("probability", 0.5)))

elif step_name == "AugraphyBasicAugmentations":
steps.append(
AugraphyBasicAugmentations(
probability=step_config.get("probability", 0.5)
)
)

steps.append(AugraphyBasicAugmentations(probability=step_config.get("probability", 0.5)))

else:
raise ValueError(f"Unknown pipeline step: {step_name}")
@@ -5,13 +5,11 @@ import re
from abc import ABC, abstractmethod
from concurrent.futures import ProcessPoolExecutor, as_completed
from dataclasses import dataclass, fields
from functools import reduce
from io import BytesIO
from os import PathLike
from pathlib import Path
from typing import (
Any,
Callable,
Dict,
List,
Optional,
@@ -144,8 +142,8 @@ class BaseMarkdownPDFDataset(Dataset):
pbar.update(1)

# Sort samples by markdown path for consistent ordering across runs
self.samples.sort(key=lambda x: x['markdown_path'])

self.samples.sort(key=lambda x: x["markdown_path"])

logger.info(f"Found {valid_count} valid markdown-PDF pairs")

if invalid_pdfs:
@@ -178,7 +176,7 @@ class BaseMarkdownPDFDataset(Dataset):
sample = step(sample)
if sample is None:
return None

return sample
@@ -440,26 +438,26 @@ class LatexBracketNormalizer(PipelineStep):
@dataclass(frozen=True, slots=True)
class RotationAugmentation(PipelineStep):
"""Pipeline step that randomly rotates images for augmentation."""

probability: float = 0.5 # Probability of applying rotation

def __call__(self, sample: Sample) -> Optional[Sample]:
"""Randomly rotate image and update rotation metadata."""
# Only proceed with given probability
if np.random.random() > self.probability:
return sample

# Check if image exists
if "image" not in sample:
return sample

# Check if page_data exists (we need to update it)
if "page_data" not in sample:
return sample

# Randomly choose a rotation (90, 180, or 270 degrees)
rotation_degrees = np.random.choice([90, 180, 270])

# Apply rotation to image
image = sample["image"]
if rotation_degrees == 90:
@@ -468,13 +466,13 @@ class RotationAugmentation(PipelineStep):
transpose = Image.Transpose.ROTATE_180
else: # 270
transpose = Image.Transpose.ROTATE_270

rotated_image = image.transpose(transpose)
sample["image"] = rotated_image

# Update page_data
page_data = sample["page_data"]

# Create new PageResponse with updated rotation info
# The rotation_correction should be the inverse of what we applied
# If we rotated 90 clockwise, we need 270 counter-clockwise to correct it
@@ -484,9 +482,9 @@ class RotationAugmentation(PipelineStep):
correction = 180
else: # 270
correction = 90

from olmocr.prompts.prompts import PageResponse

new_page_data = PageResponse(
primary_language=page_data.primary_language,
is_rotation_valid=False, # Mark as invalid since we rotated it
@@ -495,7 +493,7 @@ class RotationAugmentation(PipelineStep):
is_diagram=page_data.is_diagram,
natural_text=page_data.natural_text,
)

sample["page_data"] = new_page_data
return sample
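A note on the rotation bookkeeping above: rotation_correction is meant to be the angle that undoes the augmentation, so the applied rotation and its correction always sum to a full turn. A minimal standalone sketch of that mapping (illustrative only; the clockwise-degrees convention is assumed from the comments in the hunk):

# Applied clockwise rotation -> correction needed to restore the page upright.
CORRECTION_FOR_ROTATION = {90: 270, 180: 180, 270: 90}

for applied, correction in CORRECTION_FOR_ROTATION.items():
    # Rotating by `applied` and then by `correction` is a full 360-degree turn.
    assert (applied + correction) % 360 == 0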
@@ -509,24 +507,24 @@ class FilterOutRotatedDocuments(PipelineStep):
# Check if page_data exists
if "page_data" not in sample:
return sample

page_data = sample["page_data"]

# Check if page_data has the required attributes
if not hasattr(page_data, "is_rotation_valid") or not hasattr(page_data, "rotation_correction"):
return sample

# Filter out if rotation is invalid or rotation correction is not 0
if page_data.is_rotation_valid is False or page_data.rotation_correction != 0:
return None

return sample

@dataclass(frozen=True, slots=True)
class AugraphyBasicAugmentations(PipelineStep):
"""Pipeline step that applies a decent selection of augraphy augmentations to the data"""

probability: float = 0.5 # Overall probability of applying any augmentation

def __call__(self, sample: Sample) -> Optional[Sample]:
@@ -534,103 +532,96 @@ class AugraphyBasicAugmentations(PipelineStep):
# Check that the image data exists
if "image" not in sample:
return sample

image = sample["image"]

# Skip all augmentations based on overall probability
if np.random.random() > self.probability:
return sample

# Convert from PIL to BGR for OpenCV/Augraphy
image_numpy = np.array(image)
if len(image_numpy.shape) < 3:
image_bgr = cv2.cvtColor(image_numpy, cv2.COLOR_GRAY2BGR)
else:
image_bgr = cv2.cvtColor(image_numpy, cv2.COLOR_RGB2BGR)

# Apply a basic augraphy pipeline
from augraphy import (
AugraphyPipeline,
Brightness,
InkBleed,
InkMottling,
InkShifter,
Jpeg,
LowInkPeriodicLines,
LowInkRandomLines,
OneOf,

Jpeg,
InkMottling,
InkShifter,
Brightness,
)

# Apply geometric transformations first, maintaing scale
# Apply geometric transformations first, maintaing scale
if np.random.random() < 0.50:
# Get dimensions
height, width = image_bgr.shape[:2]

# Random parameters for geometric transformations
angle = max(min(np.random.standard_normal(), 3), -3) # Small rotation range
scale = np.random.uniform(0.95, 1.05) # Small scale range
tx = np.random.uniform(-0.02, 0.02) * width # Translation as fraction of width
ty = np.random.uniform(-0.02, 0.02) * height # Translation as fraction of height

# Calculate center point
center = (width / 2, height / 2)

# Create transformation matrix
M = cv2.getRotationMatrix2D(center, angle, scale)

# Add translation
M[0, 2] += tx
M[1, 2] += ty

# Apply transformation
image_bgr = cv2.warpAffine(
image_bgr,
M,
image_bgr,
M,
(width, height),
flags=cv2.INTER_LINEAR,
borderMode=cv2.BORDER_CONSTANT,
borderValue=(255, 255, 255) # White background for documents
borderValue=(255, 255, 255), # White background for documents
)

ink_phase = [
OneOf([InkBleed(p=1), LowInkRandomLines(p=1), LowInkPeriodicLines(p=1), InkMottling(p=1), InkShifter(p=1, text_shift_scale_range=(10, 15))], p=0.2),
]

paper_phase = [
OneOf([Brightness(p=0.2), Jpeg(p=1)])
]
paper_phase = [OneOf([Brightness(p=0.2), Jpeg(p=1)])]

post_phase = [
# Empty on purpose or else augmentations are too strong
]

augmentation_pipeline = AugraphyPipeline(
ink_phase=ink_phase, paper_phase=paper_phase, post_phase=post_phase
)

augmentation_pipeline = AugraphyPipeline(ink_phase=ink_phase, paper_phase=paper_phase, post_phase=post_phase)

# Apply augmentations
augmented_image_bgr = augmentation_pipeline(image_bgr)

# Convert back to RGB and then to PIL format
augmented_image_rgb = cv2.cvtColor(augmented_image_bgr, cv2.COLOR_BGR2RGB)
augmented_image_pil = Image.fromarray(augmented_image_rgb)

# Update the sample with the augmented image
sample["image"] = augmented_image_pil

# Double-check PIL image size matches original
assert augmented_image_pil.size == image.size, (
f"PIL image size changed during augmentation: {image.size} -> {augmented_image_pil.size}"
)

assert augmented_image_pil.size == image.size, f"PIL image size changed during augmentation: {image.size} -> {augmented_image_pil.size}"

return sample
@dataclass(frozen=True, slots=True)
class InstructUserMessages(PipelineStep):
"""Creates instruction-following messages format for training."""

prompt_first: bool = False

def __call__(self, sample: Sample) -> Sample:
@@ -913,12 +904,12 @@ if __name__ == "__main__":
print(f"PDF file: {sample['pdf_path'].name}")
if "image" in sample and hasattr(sample["image"], "size"):
print(f"Image size: {sample['image'].size}")

# Save image if requested
if args.save_image:
sample["image"].save(args.save_image)
print(f"Saved image to: {args.save_image}")

if "page_data" in sample:
print(f"\nPage data: {sample['page_data']}")
if "messages" in sample:
@@ -14,8 +14,8 @@ def zeropower_via_newtonschulz5(G, steps: int):
where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
performance at all relative to UV^T, where USV^T = G is the SVD.
"""
assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
a, b, c = (3.4445, -4.7750, 2.0315)
assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
a, b, c = (3.4445, -4.7750, 2.0315)
X = G.bfloat16()
if G.size(-2) > G.size(-1):
X = X.mT
@@ -25,9 +25,9 @@ def zeropower_via_newtonschulz5(G, steps: int):
# Perform the NS iterations
for _ in range(steps):
A = X @ X.mT
B = b * A + c * A @ A # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
B = b * A + c * A @ A # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
X = a * X + B @ X

if G.size(-2) > G.size(-1):
X = X.mT
return X
@@ -36,10 +36,10 @@ def zeropower_via_newtonschulz5(G, steps: int):
def muon_update(grad, momentum, beta=0.95, ns_steps=5, nesterov=True):
momentum.lerp_(grad, 1 - beta)
update = grad.lerp_(momentum, beta) if nesterov else momentum
if update.ndim == 4: # for the case of conv filters
if update.ndim == 4: # for the case of conv filters
update = update.view(len(update), -1)
update = zeropower_via_newtonschulz5(update, steps=ns_steps)
update *= max(1, grad.size(-2) / grad.size(-1))**0.5
update *= max(1, grad.size(-2) / grad.size(-1)) ** 0.5
return update
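To see what zeropower_via_newtonschulz5 computes: the quintic iteration with the (a, b, c) coefficients above pushes the singular values of a normalized matrix toward 1, i.e. it approximately orthogonalizes the update. Below is a self-contained float32 sketch that re-implements the loop from the hunk rather than importing it; the Frobenius-norm pre-scaling is an assumption carried over from the reference Muon code and is not shown in this diff:

import torch

def newton_schulz5(G: torch.Tensor, steps: int = 5) -> torch.Tensor:
    # Same quintic iteration as in the diff, kept in float32 for the demo.
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G / (G.norm() + 1e-7)  # assumed pre-scaling so the spectral norm is <= 1
    if G.size(-2) > G.size(-1):
        X = X.mT
    for _ in range(steps):
        A = X @ X.mT
        B = b * A + c * A @ A
        X = a * X + B @ X
    if G.size(-2) > G.size(-1):
        X = X.mT
    return X

G = torch.randn(64, 128)
X = newton_schulz5(G)
S = torch.linalg.svdvals(X)
# Singular values end up roughly in (0.5, 1.5), as the docstring above promises.
print(S.min().item(), S.max().item())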
@@ -64,6 +64,7 @@ class Muon(torch.optim.Optimizer):
weight_decay: The AdamW-style weight decay.
momentum: The momentum. A value of 0.95 here is usually fine.
"""

def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95):
defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum)
assert isinstance(params, list) and len(params) >= 1 and isinstance(params[0], torch.nn.Parameter)
@@ -81,7 +82,7 @@ class Muon(torch.optim.Optimizer):
for group in self.param_groups:
params = group["params"]
params_pad = params + [torch.empty_like(params[-1])] * (dist.get_world_size() - len(params) % dist.get_world_size())
for base_i in range(len(params))[::dist.get_world_size()]:
for base_i in range(len(params))[:: dist.get_world_size()]:
if base_i + dist.get_rank() < len(params):
p = params[base_i + dist.get_rank()]
if p.grad is None:
@@ -93,7 +94,7 @@ class Muon(torch.optim.Optimizer):
update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
p.mul_(1 - group["lr"] * group["weight_decay"])
p.add_(update.reshape(p.shape), alpha=-group["lr"])
dist.all_gather(params_pad[base_i:base_i + dist.get_world_size()], params_pad[base_i + dist.get_rank()])
dist.all_gather(params_pad[base_i : base_i + dist.get_world_size()], params_pad[base_i + dist.get_rank()])

return loss
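The padding and striding in Muon.step (params_pad, range(len(params))[::dist.get_world_size()], and the trailing all_gather) amount to a round-robin assignment of parameters to ranks, padded so every rank participates in every all_gather. A sketch of just the indexing, with no torch.distributed involved and made-up sizes:

world_size = 4   # hypothetical
num_params = 10  # hypothetical

# Mirror of `params + [torch.empty_like(params[-1])] * (world_size - len(params) % world_size)`.
padded_len = num_params + (world_size - num_params % world_size)
print("padded length:", padded_len)  # 12, a multiple of world_size

for rank in range(world_size):
    owned = [base_i + rank for base_i in range(0, num_params, world_size) if base_i + rank < num_params]
    # Each rank updates only its own parameters, then all ranks all_gather the
    # block params_pad[base_i : base_i + world_size] so everyone stays in sync.
    print(rank, owned)  # rank 0 -> [0, 4, 8], rank 1 -> [1, 5, 9], rank 2 -> [2, 6], rank 3 -> [3, 7]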
@@ -102,6 +103,7 @@ class SingleDeviceMuon(torch.optim.Optimizer):
"""
Muon variant for usage in non-distributed settings.
"""

def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95):
defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum)
super().__init__(params, defaults)
@@ -132,8 +134,8 @@ class SingleDeviceMuon(torch.optim.Optimizer):
def adam_update(grad, buf1, buf2, step, betas, eps):
buf1.lerp_(grad, 1 - betas[0])
buf2.lerp_(grad.square(), 1 - betas[1])
buf1c = buf1 / (1 - betas[0]**step)
buf2c = buf2 / (1 - betas[1]**step)
buf1c = buf1 / (1 - betas[0] ** step)
buf2c = buf2 / (1 - betas[1] ** step)
return buf1c / (buf2c.sqrt() + eps)
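The buf1c and buf2c lines in adam_update are the standard Adam bias correction: the EMA buffers start at zero, so dividing by 1 - beta**step rescales them into unbiased estimates. A tiny standalone check with made-up values:

import torch

betas, eps = (0.9, 0.999), 1e-8
grad = torch.tensor(2.0)
buf1, buf2 = torch.zeros(()), torch.zeros(())

step = 1
buf1.lerp_(grad, 1 - betas[0])           # 0.9 * 0 + 0.1 * grad = 0.2
buf2.lerp_(grad.square(), 1 - betas[1])  # 0.999 * 0 + 0.001 * grad**2 = 0.004
buf1c = buf1 / (1 - betas[0] ** step)    # bias correction undoes the 0.1 factor -> 2.0
buf2c = buf2 / (1 - betas[1] ** step)    # -> 4.0
print(buf1c.item(), buf2c.sqrt().item())      # both recover |grad| = 2.0 on step 1
print((buf1c / (buf2c.sqrt() + eps)).item())  # so the first update has magnitude ~1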
@@ -164,6 +166,7 @@ class MuonWithAuxAdam(torch.optim.Optimizer):
optimizer = MuonWithAuxAdam(param_groups)
```
"""

def __init__(self, param_groups):
for group in param_groups:
assert "use_muon" in group
@@ -195,7 +198,7 @@ class MuonWithAuxAdam(torch.optim.Optimizer):
if group["use_muon"]:
params = group["params"]
params_pad = params + [torch.empty_like(params[-1])] * (dist.get_world_size() - len(params) % dist.get_world_size())
for base_i in range(len(params))[::dist.get_world_size()]:
for base_i in range(len(params))[:: dist.get_world_size()]:
if base_i + dist.get_rank() < len(params):
p = params[base_i + dist.get_rank()]
if p.grad is None:
@@ -207,7 +210,7 @@ class MuonWithAuxAdam(torch.optim.Optimizer):
update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
p.mul_(1 - group["lr"] * group["weight_decay"])
p.add_(update.reshape(p.shape), alpha=-group["lr"])
dist.all_gather(params_pad[base_i:base_i + dist.get_world_size()], params_pad[base_i + dist.get_rank()])
dist.all_gather(params_pad[base_i : base_i + dist.get_world_size()], params_pad[base_i + dist.get_rank()])
else:
for p in group["params"]:
if p.grad is None:
@@ -219,8 +222,7 @@ class MuonWithAuxAdam(torch.optim.Optimizer):
state["exp_avg_sq"] = torch.zeros_like(p)
state["step"] = 0
state["step"] += 1
update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"],
state["step"], group["betas"], group["eps"])
update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"], state["step"], group["betas"], group["eps"])
p.mul_(1 - group["lr"] * group["weight_decay"])
p.add_(update, alpha=-group["lr"])

@@ -231,6 +233,7 @@ class SingleDeviceMuonWithAuxAdam(torch.optim.Optimizer):
"""
Non-distributed variant of MuonWithAuxAdam.
"""

def __init__(self, param_groups):
for group in param_groups:
assert "use_muon" in group
@@ -280,9 +283,8 @@ class SingleDeviceMuonWithAuxAdam(torch.optim.Optimizer):
state["exp_avg_sq"] = torch.zeros_like(p)
state["step"] = 0
state["step"] += 1
update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"],
state["step"], group["betas"], group["eps"])
update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"], state["step"], group["betas"], group["eps"])
p.mul_(1 - group["lr"] * group["weight_decay"])
p.add_(update, alpha=-group["lr"])

return loss
return loss
@@ -163,7 +163,7 @@ def prepare_olmocr_mix(dataset_path: str, subset: str, split: str, destination:
f.write(natural_text)
else:
f.write("---")

# Look for matching PDF in extracted directory and create symlinks
extracted_pdfs_dir = dest_path / "hugging_face" / "pdf_tarballs" / "extracted"
@@ -4,26 +4,25 @@ Simple script to test OlmOCR dataset loading with YAML configuration.

import argparse
import logging
import os
import math
import os
import shutil
from typing import Any, Dict, Optional

import numpy as np
import torch
from torch.utils.data import ConcatDataset, DataLoader
from torch.optim import AdamW
from torch.amp import autocast
import wandb
from torch.amp import autocast
from torch.optim import AdamW
from torch.utils.data import ConcatDataset, DataLoader
from tqdm import tqdm

from transformers import (
AutoProcessor,
get_scheduler,
Qwen2_5_VLForConditionalGeneration,
Qwen2VLForConditionalGeneration,
get_scheduler,
)

from typing import Optional, Dict, Any
from olmocr.train.config import Config
from olmocr.train.dataloader import BaseMarkdownPDFDataset
from olmocr.train.muon import SingleDeviceMuonWithAuxAdam
@@ -37,7 +36,6 @@ logging.basicConfig(
logger = logging.getLogger(__name__)

class QwenDataCollator:
"""Data collator for vision-language models that handles numpy arrays."""
@@ -80,7 +78,7 @@ class QwenDataCollator:
# Check if we have any valid samples
if not batch["input_ids"]:
return None

# Convert lists to tensors with proper padding
# Note: For Qwen2-VL, we typically handle variable length sequences
# The model's processor should handle the padding internally
@@ -107,14 +105,14 @@ def save_checkpoint(
"""Save model, optimizer, scheduler, and training state."""
checkpoint_dir = os.path.join(output_dir, f"checkpoint-{global_step}")
os.makedirs(checkpoint_dir, exist_ok=True)

# Save model
model.save_pretrained(checkpoint_dir)

# Save optimizer and scheduler
torch.save(optimizer.state_dict(), os.path.join(checkpoint_dir, "optimizer.pt"))
torch.save(lr_scheduler.state_dict(), os.path.join(checkpoint_dir, "scheduler.pt"))

# Save training state
state = {
"epoch": epoch,
@@ -123,15 +121,12 @@ def save_checkpoint(
"best_metric": best_metric,
}
torch.save(state, os.path.join(checkpoint_dir, "training_state.pt"))

logger.info(f"Saved checkpoint to {checkpoint_dir}")

# Enforce save_total_limit by removing oldest checkpoints
if save_total_limit is not None and save_total_limit > 0:
checkpoints = sorted(
[d for d in os.listdir(output_dir) if d.startswith("checkpoint-")],
key=lambda x: int(x.split("-")[1])
)
checkpoints = sorted([d for d in os.listdir(output_dir) if d.startswith("checkpoint-")], key=lambda x: int(x.split("-")[1]))
while len(checkpoints) > save_total_limit:
oldest = checkpoints.pop(0)
shutil.rmtree(os.path.join(output_dir, oldest))
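One detail in the save_total_limit pruning above: the checkpoint directories are sorted with key=lambda x: int(x.split("-")[1]), so they are ordered by step number rather than lexicographically, which is what makes pop(0) remove the genuinely oldest checkpoint. A quick illustration with hypothetical directory names:

dirs = ["checkpoint-900", "checkpoint-1000", "checkpoint-80"]

print(sorted(dirs))
# ['checkpoint-1000', 'checkpoint-80', 'checkpoint-900']  (string order, wrong)

print(sorted(dirs, key=lambda x: int(x.split("-")[1])))
# ['checkpoint-80', 'checkpoint-900', 'checkpoint-1000']  (true step order)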
@@ -149,10 +144,10 @@ def load_checkpoint(
"""Load model, optimizer, scheduler, and training state from checkpoint."""
model = model_class.from_pretrained(checkpoint_dir, **init_kwargs)
model.to(device)

optimizer.load_state_dict(torch.load(os.path.join(checkpoint_dir, "optimizer.pt"), map_location=device))
lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint_dir, "scheduler.pt"), map_location=device))

state = torch.load(os.path.join(checkpoint_dir, "training_state.pt"), map_location=device)
logger.info(f"Resumed from checkpoint: {checkpoint_dir} at epoch {state['epoch']:.2f}, step {state['global_step']}, samples seen {state['samples_seen']}")
return model, state
@@ -166,11 +161,11 @@ def evaluate_model(
"""Evaluate on all eval datasets and return average loss per dataset."""
model.eval()
eval_metrics = {}

for dataset_name, dataloader in eval_dataloaders.items():
total_loss = 0.0
num_batches = 0

with torch.no_grad():
for batch in dataloader:
# Skip if batch is None (all samples were filtered out)
@@ -181,16 +176,16 @@ def evaluate_model(
outputs = model(**batch)
total_loss += outputs.loss.item()
num_batches += 1

avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
eval_metrics[f"eval_{dataset_name}_loss"] = avg_loss
logger.info(f"Eval {dataset_name} loss: {avg_loss:.4f}")

# Compute overall eval loss as average across datasets (or customize as needed)
if eval_metrics:
overall_loss = sum(eval_metrics.values()) / len(eval_metrics)
eval_metrics["eval_loss"] = overall_loss

return eval_metrics
@@ -215,11 +210,11 @@ def main():
if config.project_name:
os.environ["WANDB_PROJECT"] = config.project_name
logger.info(f"Setting WANDB_PROJECT to: {config.project_name}")

# Initialize wandb if reporting to it
if "wandb" in config.training.report_to:
wandb.init(project=config.project_name, name=config.run_name, config=config.to_dict())

# Load processor for tokenization
logger.info(f"Loading processor: {config.model.name}")
processor = AutoProcessor.from_pretrained(
@@ -284,7 +279,6 @@ def main():
if len(dataset) > 0:
eval_datasets[dataset_name] = dataset

# Log total evaluation samples across all datasets
total_eval_samples = sum(len(dataset) for dataset in eval_datasets.values())
logger.info(f"Total evaluation samples across {len(eval_datasets)} datasets: {total_eval_samples}")
@@ -310,14 +304,15 @@ def main():
# Set seeds
torch.manual_seed(config.training.seed)

# Set up data loader seed worker function
def seed_worker(worker_id):
worker_seed = torch.initial_seed() % 2**32
np.random.seed(worker_seed)
import random

random.seed(worker_seed)

# Create generator for data loader
generator = None
if config.training.data_seed is not None:
@@ -327,7 +322,7 @@ def main():
# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Apply torch compile if enabled
if config.training.torch_compile:
logger.info(f"Compiling model with torch.compile (backend={config.training.torch_compile_backend}, mode={config.training.torch_compile_mode})")
@@ -365,29 +360,29 @@ def main():
embed_params = [p for n, p in model.named_parameters() if "embed" in n]
scalar_params = [p for p in model.parameters() if p.ndim < 2]
head_params = [p for n, p in model.named_parameters() if "lm_head" in n]

# Create Adam groups with different learning rates
adam_groups = [
dict(params=head_params, lr=float(config.training.learning_rate) * config.training.muon_lr_multiplier_head, use_muon=False),
dict(params=embed_params, lr=float(config.training.learning_rate) * config.training.muon_lr_multiplier_embed, use_muon=False),
dict(params=scalar_params, lr=float(config.training.learning_rate) * config.training.muon_lr_multiplier_scalar, use_muon=False)
dict(params=scalar_params, lr=float(config.training.learning_rate) * config.training.muon_lr_multiplier_scalar, use_muon=False),
]

# Add Adam hyperparameters to groups
for g in adam_groups:
g["betas"] = (config.training.adam_beta1, config.training.adam_beta2)
g["eps"] = float(config.training.adam_epsilon)
g["weight_decay"] = config.training.weight_decay

# Create Muon group
muon_group = dict(
params=hidden_matrix_params,
lr=float(config.training.learning_rate),
momentum=config.training.muon_momentum,
weight_decay=config.training.weight_decay,
use_muon=True
use_muon=True,
)

# Combine all groups
param_groups = [*adam_groups, muon_group]
optimizer = SingleDeviceMuonWithAuxAdam(param_groups)
@@ -416,7 +411,7 @@ def main():
global_step = 0
samples_seen = 0
best_metric = float("inf") if not config.training.greater_is_better else -float("inf")

if found_resumable_checkpoint:
model, state = load_checkpoint(model_class, model_init_kwargs, optimizer, lr_scheduler, found_resumable_checkpoint, device)
global_step = state["global_step"]
@@ -457,7 +452,7 @@ def main():
current_epoch = samples_seen / len(train_dataset)
logger.info(f"Starting training from epoch {current_epoch:.2f} (step {global_step}, samples {samples_seen}) to {config.training.num_train_epochs} epochs")
logger.info(f"Total training steps: {max_train_steps}, Total samples to process: {max_train_samples}")

if samples_seen >= max_train_samples:
logger.info("Training already completed based on samples seen!")
logger.info("Skipping to final model save.")
@@ -465,7 +460,7 @@ def main():
model.train()
accumulated_loss = 0.0
num_losses_accumulated = 0

# Create epoch iterator and skip samples if resuming
epoch_iterator = iter(train_dataloader)
if samples_seen > 0:
@@ -479,10 +474,10 @@ def main():
# We've reached the end of the epoch while skipping, create new iterator
epoch_iterator = iter(train_dataloader)
break

# Create progress bar
pbar = tqdm(total=max_train_samples - samples_seen, desc=f"Training from step {global_step}", unit="samples")

while samples_seen < max_train_samples and global_step < max_train_steps:
try:
batch = next(epoch_iterator)
@@ -492,48 +487,43 @@ def main():
logger.info(f"Completed epoch {current_epoch:.2f}")
epoch_iterator = iter(train_dataloader)
batch = next(epoch_iterator)

# Skip if batch is None (all samples were filtered out)
if batch is None:
continue

batch = {k: v.to(device) for k, v in batch.items()}

with autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
outputs = model(**batch)
loss = outputs.loss / config.training.gradient_accumulation_steps
loss.backward()

accumulated_loss += outputs.loss.item() # Use undivided loss for logging
num_losses_accumulated += 1
samples_seen += config.training.per_device_train_batch_size

# Update progress bar
pbar.update(config.training.per_device_train_batch_size)

# Check if we should do a gradient update
if samples_seen % samples_per_step == 0 or samples_seen >= max_train_samples:
# Clip gradients
torch.nn.utils.clip_grad_norm_(model.parameters(), config.training.max_grad_norm)

# Step optimizer and scheduler
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()

global_step += 1
current_epoch = samples_seen / len(train_dataset)

# Update progress bar with current stats
current_lr = lr_scheduler.get_last_lr()[0]
avg_loss = accumulated_loss / num_losses_accumulated if num_losses_accumulated > 0 else 0
pbar.set_postfix({
'loss': f'{avg_loss:.4f}',
'lr': f'{current_lr:.2e}',
'epoch': f'{current_epoch:.2f}',
'step': global_step
})

pbar.set_postfix({"loss": f"{avg_loss:.4f}", "lr": f"{current_lr:.2e}", "epoch": f"{current_epoch:.2f}", "step": global_step})

# Logging
if config.training.logging_steps > 0 and global_step % config.training.logging_steps == 0:
avg_train_loss = accumulated_loss / num_losses_accumulated if num_losses_accumulated > 0 else 0
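For readers following the loop structure: the loss is divided by gradient_accumulation_steps before backward, gradients accumulate across micro-batches, and the optimizer only steps when samples_seen crosses a samples_per_step boundary (the undivided loss is kept separately for logging). A stripped-down sketch of that control flow with hypothetical sizes, assuming samples_per_step = per_device_train_batch_size * gradient_accumulation_steps:

per_device_train_batch_size = 4   # hypothetical
gradient_accumulation_steps = 8   # hypothetical
samples_per_step = per_device_train_batch_size * gradient_accumulation_steps  # 32

samples_seen, global_step = 0, 0
for _ in range(100):  # stand-in for pulling 100 micro-batches from the dataloader
    samples_seen += per_device_train_batch_size
    # loss.backward() would run here on loss / gradient_accumulation_steps, so the
    # summed gradient matches one large batch of samples_per_step samples.
    if samples_seen % samples_per_step == 0:
        global_step += 1  # optimizer.step(), lr_scheduler.step(), optimizer.zero_grad()

print(global_step)  # 100 micro-batches * 4 samples / 32 samples per step -> 12 optimizer steps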
@@ -546,52 +536,49 @@ def main():
logger.info(f"Step {global_step}: epoch={current_epoch:.3f}, loss={avg_train_loss:.4f}, lr={lr_scheduler.get_last_lr()[0]:.2e}")
if "wandb" in config.training.report_to:
wandb.log(logs, step=global_step)

accumulated_loss = 0.0
num_losses_accumulated = 0

# Evaluation
if config.training.eval_steps > 0 and global_step % config.training.eval_steps == 0 and global_step > 0:
metrics = evaluate_model(model, eval_dataloaders, device)
logger.info(f"Evaluation at step {global_step}: {metrics}")
if "wandb" in config.training.report_to:
wandb.log(metrics, step=global_step)

# Update best metric
current_metric = metrics.get(config.training.metric_for_best_model, None)
if current_metric is not None:
if (config.training.greater_is_better and current_metric > best_metric) or \
(not config.training.greater_is_better and current_metric < best_metric):
if (config.training.greater_is_better and current_metric > best_metric) or (
not config.training.greater_is_better and current_metric < best_metric
):
best_metric = current_metric

# Return to training mode
model.train()

# Saving
if config.training.save_steps > 0 and global_step % config.training.save_steps == 0:
save_checkpoint(
model, optimizer, lr_scheduler, current_epoch, global_step, samples_seen, best_metric,
full_output_dir, config.training.save_total_limit
model, optimizer, lr_scheduler, current_epoch, global_step, samples_seen, best_metric, full_output_dir, config.training.save_total_limit
)

# Check if we've reached our training limit
if samples_seen >= max_train_samples or global_step >= max_train_steps:
break

# Close progress bar
pbar.close()

# Save the final checkpoint with step number
logger.info(f"Saving final checkpoint at step {global_step}...")
save_checkpoint(
model, optimizer, lr_scheduler, current_epoch, global_step, samples_seen, best_metric,
full_output_dir, config.training.save_total_limit
)

save_checkpoint(model, optimizer, lr_scheduler, current_epoch, global_step, samples_seen, best_metric, full_output_dir, config.training.save_total_limit)

# Log final training state
final_epoch = samples_seen / len(train_dataset)
logger.info(f"Training completed at epoch {final_epoch:.3f}, step {global_step}, samples {samples_seen}")

# Final evaluation
final_metrics = evaluate_model(model, eval_dataloaders, device)
logger.info(f"Final evaluation metrics: {final_metrics}")
@@ -601,4 +588,4 @@ def main():

if __name__ == "__main__":
main()
main()
@@ -171,7 +171,6 @@ class WorkQueue:
logger.info(f"Initialized queue with {self.size:,} work items")
return self.size

async def get_work(self, worker_lock_timeout_secs: int = 1800) -> Optional[WorkItem]:
"""
Get the next available work item that isn't completed or locked.
@@ -179,7 +178,6 @@ class WorkQueue:
REFRESH_COMPLETED_HASH_CACHE_MAX_ATTEMPTS = 3
refresh_completed_hash_attempt = 0

while True:
try:
work_item = self._queue.get_nowait()
@@ -221,7 +219,7 @@ class WorkQueue:
"""
# Create done flag in done_flags_dir
await self.backend.create_done_flag(work_item.hash)

# Remove the worker lock
await self.backend.delete_worker_lock(work_item.hash)
self._queue.task_done()
@@ -281,11 +279,7 @@ class LocalBackend(Backend):
def _list_completed() -> Set[str]:
if not os.path.isdir(self._done_flags_dir):
return set()
return {
f[len("done_") : -len(".flag")]
for f in os.listdir(self._done_flags_dir)
if f.startswith("done_") and f.endswith(".flag")
}
return {f[len("done_") : -len(".flag")] for f in os.listdir(self._done_flags_dir) if f.startswith("done_") and f.endswith(".flag")}

return await asyncio.to_thread(_list_completed)
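The set comprehension in _list_completed recovers work hashes from done-flag filenames by slicing off the fixed prefix and suffix. A one-line check with a made-up hash:

f = "done_3f2a9c.flag"  # hypothetical flag filename
assert f.startswith("done_") and f.endswith(".flag")
print(f[len("done_") : -len(".flag")])  # -> 3f2a9c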
@@ -299,6 +293,7 @@ class LocalBackend(Backend):

async def _get_object_mtime(self, path: str) -> Optional[datetime.datetime]:
"""Internal method to get object mtime."""

def _get_mtime() -> Optional[datetime.datetime]:
if not os.path.exists(path):
return None
@@ -310,17 +305,17 @@ class LocalBackend(Backend):
"""Check if a worker lock is taken and not stale."""
lock_path = self._get_worker_lock_path(work_hash)
lock_mtime = await self._get_object_mtime(lock_path)

if not lock_mtime:
return False

now = datetime.datetime.now(datetime.timezone.utc)
return (now - lock_mtime).total_seconds() <= worker_lock_timeout_secs

async def create_worker_lock(self, work_hash: str) -> None:
"""Create a worker lock for a work hash."""
lock_path = self._get_worker_lock_path(work_hash)

def _create() -> None:
with open(lock_path, "wb"):
pass
@@ -330,7 +325,7 @@ class LocalBackend(Backend):
async def delete_worker_lock(self, work_hash: str) -> None:
"""Delete the worker lock for a work hash if it exists."""
lock_path = self._get_worker_lock_path(work_hash)

def _delete() -> None:
if os.path.exists(lock_path):
os.remove(lock_path)
@@ -345,7 +340,7 @@ class LocalBackend(Backend):
async def create_done_flag(self, work_hash: str) -> None:
"""Create a done flag for a work hash."""
done_flag_path = self._get_done_flag_path(work_hash)

def _create() -> None:
with open(done_flag_path, "wb"):
pass
@@ -406,10 +401,10 @@ class S3Backend(Backend):
"""Check if a worker lock is taken and not stale."""
lock_path = self._get_worker_lock_path(work_hash)
lock_mtime = await self._get_object_mtime(lock_path)

if not lock_mtime:
return False

now = datetime.datetime.now(datetime.timezone.utc)
return (now - lock_mtime).total_seconds() <= worker_lock_timeout_secs

@@ -434,4 +429,4 @@ class S3Backend(Backend):
"""Create a done flag for a work hash."""
done_flag_path = self._get_done_flag_path(work_hash)
bucket, key = parse_s3_path(done_flag_path)
await asyncio.to_thread(self.s3_client.put_object, Bucket=bucket, Key=key, Body=b"")
await asyncio.to_thread(self.s3_client.put_object, Bucket=bucket, Key=key, Body=b"")
@@ -6,7 +6,7 @@ from unittest.mock import Mock, patch
from botocore.exceptions import ClientError

# Import the classes we're testing
from olmocr.work_queue import WorkQueue, S3Backend, WorkItem
from olmocr.work_queue import S3Backend, WorkItem, WorkQueue

class TestS3WorkQueue(unittest.TestCase):
@@ -214,7 +214,7 @@ class TestS3WorkQueue(unittest.TestCase):
self.assertEqual(len(put_calls), 1)
done_flag_key = put_calls[0][1]["Key"]
self.assertTrue(done_flag_key.endswith(f"done_{work_item.hash}.flag"))

# Verify lock file was deleted
self.s3_client.delete_object.assert_called_once()
key = self.s3_client.delete_object.call_args[1]["Key"]