Lint fixes

Jake Poznanski 2025-08-13 20:21:04 +00:00
parent 05330150ad
commit 93411a80a0
8 changed files with 157 additions and 194 deletions

View File

@@ -49,7 +49,7 @@ from olmocr.s3_utils import (
)
from olmocr.train.dataloader import FrontMatterParser
from olmocr.version import VERSION
from olmocr.work_queue import WorkQueue, LocalBackend, S3Backend
from olmocr.work_queue import LocalBackend, S3Backend, WorkQueue
# Initialize logger
logger = logging.getLogger(__name__)

View File

@@ -394,9 +394,7 @@ class Config:
steps.append(FrontMatterParser(front_matter_class=front_matter_class))
elif step_name == "PDFRenderer":
steps.append(
PDFRenderer(target_longest_image_dim=step_config.get("target_longest_image_dim", 1024))
)
steps.append(PDFRenderer(target_longest_image_dim=step_config.get("target_longest_image_dim", 1024)))
elif step_name == "StaticLengthDocumentAnchoring":
steps.append(StaticLengthDocumentAnchoring(target_anchor_text_len=step_config.get("target_anchor_text_len", 6000)))
@@ -417,9 +415,7 @@ class Config:
steps.append(FrontMatterOutputFormat())
elif step_name == "InstructUserMessages":
steps.append(InstructUserMessages(
prompt_first=step_config.get("prompt_first", False)
))
steps.append(InstructUserMessages(prompt_first=step_config.get("prompt_first", False)))
elif step_name == "LatexBracketNormalizer":
steps.append(LatexBracketNormalizer())
@@ -462,18 +458,10 @@ class Config:
steps.append(FilterOutRotatedDocuments())
elif step_name == "RotationAugmentation":
steps.append(
RotationAugmentation(
probability=step_config.get("probability", 0.5)
)
)
steps.append(RotationAugmentation(probability=step_config.get("probability", 0.5)))
elif step_name == "AugraphyBasicAugmentations":
steps.append(
AugraphyBasicAugmentations(
probability=step_config.get("probability", 0.5)
)
)
steps.append(AugraphyBasicAugmentations(probability=step_config.get("probability", 0.5)))
else:
raise ValueError(f"Unknown pipeline step: {step_name}")

View File

@@ -5,13 +5,11 @@ import re
from abc import ABC, abstractmethod
from concurrent.futures import ProcessPoolExecutor, as_completed
from dataclasses import dataclass, fields
from functools import reduce
from io import BytesIO
from os import PathLike
from pathlib import Path
from typing import (
Any,
Callable,
Dict,
List,
Optional,
@@ -144,7 +142,7 @@ class BaseMarkdownPDFDataset(Dataset):
pbar.update(1)
# Sort samples by markdown path for consistent ordering across runs
self.samples.sort(key=lambda x: x['markdown_path'])
self.samples.sort(key=lambda x: x["markdown_path"])
logger.info(f"Found {valid_count} valid markdown-PDF pairs")
@@ -551,15 +549,14 @@ class AugraphyBasicAugmentations(PipelineStep):
# Apply a basic augraphy pipeline
from augraphy import (
AugraphyPipeline,
Brightness,
InkBleed,
InkMottling,
InkShifter,
Jpeg,
LowInkPeriodicLines,
LowInkRandomLines,
OneOf,
Jpeg,
InkMottling,
InkShifter,
Brightness,
)
# Apply geometric transformations first, maintaining scale
@@ -590,25 +587,20 @@ class AugraphyBasicAugmentations(PipelineStep):
(width, height),
flags=cv2.INTER_LINEAR,
borderMode=cv2.BORDER_CONSTANT,
borderValue=(255, 255, 255) # White background for documents
borderValue=(255, 255, 255), # White background for documents
)
ink_phase = [
OneOf([InkBleed(p=1), LowInkRandomLines(p=1), LowInkPeriodicLines(p=1), InkMottling(p=1), InkShifter(p=1, text_shift_scale_range=(10, 15))], p=0.2),
]
paper_phase = [
OneOf([Brightness(p=0.2), Jpeg(p=1)])
]
paper_phase = [OneOf([Brightness(p=0.2), Jpeg(p=1)])]
post_phase = [
# Empty on purpose or else augmentations are too strong
]
augmentation_pipeline = AugraphyPipeline(
ink_phase=ink_phase, paper_phase=paper_phase, post_phase=post_phase
)
augmentation_pipeline = AugraphyPipeline(ink_phase=ink_phase, paper_phase=paper_phase, post_phase=post_phase)
# Apply augmentations
augmented_image_bgr = augmentation_pipeline(image_bgr)
@@ -621,12 +613,11 @@ class AugraphyBasicAugmentations(PipelineStep):
sample["image"] = augmented_image_pil
# Double-check PIL image size matches original
assert augmented_image_pil.size == image.size, (
f"PIL image size changed during augmentation: {image.size} -> {augmented_image_pil.size}"
)
assert augmented_image_pil.size == image.size, f"PIL image size changed during augmentation: {image.size} -> {augmented_image_pil.size}"
return sample
@dataclass(frozen=True, slots=True)
class InstructUserMessages(PipelineStep):
"""Creates instruction-following messages format for training."""

View File

@@ -39,7 +39,7 @@ def muon_update(grad, momentum, beta=0.95, ns_steps=5, nesterov=True):
if update.ndim == 4: # for the case of conv filters
update = update.view(len(update), -1)
update = zeropower_via_newtonschulz5(update, steps=ns_steps)
update *= max(1, grad.size(-2) / grad.size(-1))**0.5
update *= max(1, grad.size(-2) / grad.size(-1)) ** 0.5
return update
@@ -64,6 +64,7 @@ class Muon(torch.optim.Optimizer):
weight_decay: The AdamW-style weight decay.
momentum: The momentum. A value of 0.95 here is usually fine.
"""
def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95):
defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum)
assert isinstance(params, list) and len(params) >= 1 and isinstance(params[0], torch.nn.Parameter)
@@ -81,7 +82,7 @@
for group in self.param_groups:
params = group["params"]
params_pad = params + [torch.empty_like(params[-1])] * (dist.get_world_size() - len(params) % dist.get_world_size())
for base_i in range(len(params))[::dist.get_world_size()]:
for base_i in range(len(params))[:: dist.get_world_size()]:
if base_i + dist.get_rank() < len(params):
p = params[base_i + dist.get_rank()]
if p.grad is None:
@@ -93,7 +94,7 @@
update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
p.mul_(1 - group["lr"] * group["weight_decay"])
p.add_(update.reshape(p.shape), alpha=-group["lr"])
dist.all_gather(params_pad[base_i:base_i + dist.get_world_size()], params_pad[base_i + dist.get_rank()])
dist.all_gather(params_pad[base_i : base_i + dist.get_world_size()], params_pad[base_i + dist.get_rank()])
return loss
@@ -102,6 +103,7 @@ class SingleDeviceMuon(torch.optim.Optimizer):
"""
Muon variant for usage in non-distributed settings.
"""
def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95):
defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum)
super().__init__(params, defaults)
@@ -132,8 +134,8 @@ class SingleDeviceMuon(torch.optim.Optimizer):
def adam_update(grad, buf1, buf2, step, betas, eps):
buf1.lerp_(grad, 1 - betas[0])
buf2.lerp_(grad.square(), 1 - betas[1])
buf1c = buf1 / (1 - betas[0]**step)
buf2c = buf2 / (1 - betas[1]**step)
buf1c = buf1 / (1 - betas[0] ** step)
buf2c = buf2 / (1 - betas[1] ** step)
return buf1c / (buf2c.sqrt() + eps)
@@ -164,6 +166,7 @@ class MuonWithAuxAdam(torch.optim.Optimizer):
optimizer = MuonWithAuxAdam(param_groups)
```
"""
def __init__(self, param_groups):
for group in param_groups:
assert "use_muon" in group
@@ -195,7 +198,7 @@ class MuonWithAuxAdam(torch.optim.Optimizer):
if group["use_muon"]:
params = group["params"]
params_pad = params + [torch.empty_like(params[-1])] * (dist.get_world_size() - len(params) % dist.get_world_size())
for base_i in range(len(params))[::dist.get_world_size()]:
for base_i in range(len(params))[:: dist.get_world_size()]:
if base_i + dist.get_rank() < len(params):
p = params[base_i + dist.get_rank()]
if p.grad is None:
@@ -207,7 +210,7 @@ class MuonWithAuxAdam(torch.optim.Optimizer):
update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
p.mul_(1 - group["lr"] * group["weight_decay"])
p.add_(update.reshape(p.shape), alpha=-group["lr"])
dist.all_gather(params_pad[base_i:base_i + dist.get_world_size()], params_pad[base_i + dist.get_rank()])
dist.all_gather(params_pad[base_i : base_i + dist.get_world_size()], params_pad[base_i + dist.get_rank()])
else:
for p in group["params"]:
if p.grad is None:
@@ -219,8 +222,7 @@ class MuonWithAuxAdam(torch.optim.Optimizer):
state["exp_avg_sq"] = torch.zeros_like(p)
state["step"] = 0
state["step"] += 1
update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"],
state["step"], group["betas"], group["eps"])
update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"], state["step"], group["betas"], group["eps"])
p.mul_(1 - group["lr"] * group["weight_decay"])
p.add_(update, alpha=-group["lr"])
@@ -231,6 +233,7 @@ class SingleDeviceMuonWithAuxAdam(torch.optim.Optimizer):
"""
Non-distributed variant of MuonWithAuxAdam.
"""
def __init__(self, param_groups):
for group in param_groups:
assert "use_muon" in group
@@ -280,8 +283,7 @@ class SingleDeviceMuonWithAuxAdam(torch.optim.Optimizer):
state["exp_avg_sq"] = torch.zeros_like(p)
state["step"] = 0
state["step"] += 1
update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"],
state["step"], group["betas"], group["eps"])
update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"], state["step"], group["betas"], group["eps"])
p.mul_(1 - group["lr"] * group["weight_decay"])
p.add_(update, alpha=-group["lr"])
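As the docstring snippet above indicates, MuonWithAuxAdam and its single-device variant take explicit parameter groups flagged with use_muon. A minimal single-device sketch with a hypothetical toy model, routing 2-D weight matrices to Muon and everything else to the Adam path:
import torch
from olmocr.train.muon import SingleDeviceMuonWithAuxAdam

# Hypothetical toy model; any nn.Module works the same way.
model = torch.nn.Sequential(torch.nn.Linear(16, 32), torch.nn.Linear(32, 4))

muon_params = [p for p in model.parameters() if p.ndim >= 2]  # hidden weight matrices
adam_params = [p for p in model.parameters() if p.ndim < 2]   # biases and other 1-D tensors

param_groups = [
    dict(params=muon_params, lr=0.02, momentum=0.95, weight_decay=0.0, use_muon=True),
    dict(params=adam_params, lr=3e-4, betas=(0.9, 0.95), eps=1e-10, weight_decay=0.0, use_muon=False),
]
optimizer = SingleDeviceMuonWithAuxAdam(param_groups)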

View File

@@ -4,26 +4,25 @@ Simple script to test OlmOCR dataset loading with YAML configuration.
import argparse
import logging
import os
import math
import os
import shutil
from typing import Any, Dict, Optional
import numpy as np
import torch
from torch.utils.data import ConcatDataset, DataLoader
from torch.optim import AdamW
from torch.amp import autocast
import wandb
from torch.amp import autocast
from torch.optim import AdamW
from torch.utils.data import ConcatDataset, DataLoader
from tqdm import tqdm
from transformers import (
AutoProcessor,
get_scheduler,
Qwen2_5_VLForConditionalGeneration,
Qwen2VLForConditionalGeneration,
get_scheduler,
)
from typing import Optional, Dict, Any
from olmocr.train.config import Config
from olmocr.train.dataloader import BaseMarkdownPDFDataset
from olmocr.train.muon import SingleDeviceMuonWithAuxAdam
@@ -37,7 +36,6 @@ logging.basicConfig(
logger = logging.getLogger(__name__)
class QwenDataCollator:
"""Data collator for vision-language models that handles numpy arrays."""
@@ -128,10 +126,7 @@ def save_checkpoint(
# Enforce save_total_limit by removing oldest checkpoints
if save_total_limit is not None and save_total_limit > 0:
checkpoints = sorted(
[d for d in os.listdir(output_dir) if d.startswith("checkpoint-")],
key=lambda x: int(x.split("-")[1])
)
checkpoints = sorted([d for d in os.listdir(output_dir) if d.startswith("checkpoint-")], key=lambda x: int(x.split("-")[1]))
while len(checkpoints) > save_total_limit:
oldest = checkpoints.pop(0)
shutil.rmtree(os.path.join(output_dir, oldest))
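The sort collapsed onto one line above is what enforces save_total_limit: checkpoint directories named checkpoint-<step> are ordered by their numeric step suffix and the oldest ones are deleted. The same policy as a standalone sketch, with a hypothetical output directory and limit:
import os
import shutil

output_dir = "/tmp/example_run"  # hypothetical checkpoint directory
os.makedirs(output_dir, exist_ok=True)
save_total_limit = 3             # keep only the three most recent checkpoints

checkpoints = sorted(
    [d for d in os.listdir(output_dir) if d.startswith("checkpoint-")],
    key=lambda x: int(x.split("-")[1]),  # numeric sort on the step suffix, not lexicographic
)
while len(checkpoints) > save_total_limit:
    oldest = checkpoints.pop(0)
    shutil.rmtree(os.path.join(output_dir, oldest))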
@@ -284,7 +279,6 @@ def main():
if len(dataset) > 0:
eval_datasets[dataset_name] = dataset
# Log total evaluation samples across all datasets
total_eval_samples = sum(len(dataset) for dataset in eval_datasets.values())
logger.info(f"Total evaluation samples across {len(eval_datasets)} datasets: {total_eval_samples}")
@@ -316,6 +310,7 @@ def main():
worker_seed = torch.initial_seed() % 2**32
np.random.seed(worker_seed)
import random
random.seed(worker_seed)
# Create generator for data loader
@@ -370,7 +365,7 @@ def main():
adam_groups = [
dict(params=head_params, lr=float(config.training.learning_rate) * config.training.muon_lr_multiplier_head, use_muon=False),
dict(params=embed_params, lr=float(config.training.learning_rate) * config.training.muon_lr_multiplier_embed, use_muon=False),
dict(params=scalar_params, lr=float(config.training.learning_rate) * config.training.muon_lr_multiplier_scalar, use_muon=False)
dict(params=scalar_params, lr=float(config.training.learning_rate) * config.training.muon_lr_multiplier_scalar, use_muon=False),
]
# Add Adam hyperparameters to groups
@@ -385,7 +380,7 @@ def main():
lr=float(config.training.learning_rate),
momentum=config.training.muon_momentum,
weight_decay=config.training.weight_decay,
use_muon=True
use_muon=True,
)
# Combine all groups
@@ -527,12 +522,7 @@ def main():
# Update progress bar with current stats
current_lr = lr_scheduler.get_last_lr()[0]
avg_loss = accumulated_loss / num_losses_accumulated if num_losses_accumulated > 0 else 0
pbar.set_postfix({
'loss': f'{avg_loss:.4f}',
'lr': f'{current_lr:.2e}',
'epoch': f'{current_epoch:.2f}',
'step': global_step
})
pbar.set_postfix({"loss": f"{avg_loss:.4f}", "lr": f"{current_lr:.2e}", "epoch": f"{current_epoch:.2f}", "step": global_step})
# Logging
if config.training.logging_steps > 0 and global_step % config.training.logging_steps == 0:
@@ -560,8 +550,9 @@ def main():
# Update best metric
current_metric = metrics.get(config.training.metric_for_best_model, None)
if current_metric is not None:
if (config.training.greater_is_better and current_metric > best_metric) or \
(not config.training.greater_is_better and current_metric < best_metric):
if (config.training.greater_is_better and current_metric > best_metric) or (
not config.training.greater_is_better and current_metric < best_metric
):
best_metric = current_metric
# Return to training mode
@@ -570,8 +561,7 @@ def main():
# Saving
if config.training.save_steps > 0 and global_step % config.training.save_steps == 0:
save_checkpoint(
model, optimizer, lr_scheduler, current_epoch, global_step, samples_seen, best_metric,
full_output_dir, config.training.save_total_limit
model, optimizer, lr_scheduler, current_epoch, global_step, samples_seen, best_metric, full_output_dir, config.training.save_total_limit
)
# Check if we've reached our training limit
@@ -583,10 +573,7 @@ def main():
# Save the final checkpoint with step number
logger.info(f"Saving final checkpoint at step {global_step}...")
save_checkpoint(
model, optimizer, lr_scheduler, current_epoch, global_step, samples_seen, best_metric,
full_output_dir, config.training.save_total_limit
)
save_checkpoint(model, optimizer, lr_scheduler, current_epoch, global_step, samples_seen, best_metric, full_output_dir, config.training.save_total_limit)
# Log final training state
final_epoch = samples_seen / len(train_dataset)
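The worker seeding touched above (torch.initial_seed() reduced modulo 2**32 and fed to numpy and random) plus an explicit torch.Generator is the standard recipe for reproducible DataLoader workers. A self-contained sketch with a dummy dataset standing in for the real training set:
import random

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

def seed_worker(worker_id):
    # Derive a per-worker seed from torch's initial seed so numpy and random stay deterministic too.
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

generator = torch.Generator()
generator.manual_seed(42)  # hypothetical base seed

dummy_dataset = TensorDataset(torch.arange(16, dtype=torch.float32).unsqueeze(1))  # stand-in dataset
loader = DataLoader(dummy_dataset, batch_size=4, shuffle=True, num_workers=2, worker_init_fn=seed_worker, generator=generator)

for batch in loader:
    pass  # with a fixed base seed, batches are drawn in the same order every run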

View File

@@ -171,7 +171,6 @@ class WorkQueue:
logger.info(f"Initialized queue with {self.size:,} work items")
return self.size
async def get_work(self, worker_lock_timeout_secs: int = 1800) -> Optional[WorkItem]:
"""
Get the next available work item that isn't completed or locked.
@@ -179,7 +178,6 @@ class WorkQueue:
REFRESH_COMPLETED_HASH_CACHE_MAX_ATTEMPTS = 3
refresh_completed_hash_attempt = 0
while True:
try:
work_item = self._queue.get_nowait()
@@ -281,11 +279,7 @@ class LocalBackend(Backend):
def _list_completed() -> Set[str]:
if not os.path.isdir(self._done_flags_dir):
return set()
return {
f[len("done_") : -len(".flag")]
for f in os.listdir(self._done_flags_dir)
if f.startswith("done_") and f.endswith(".flag")
}
return {f[len("done_") : -len(".flag")] for f in os.listdir(self._done_flags_dir) if f.startswith("done_") and f.endswith(".flag")}
return await asyncio.to_thread(_list_completed)
@@ -299,6 +293,7 @@ class LocalBackend(Backend):
async def _get_object_mtime(self, path: str) -> Optional[datetime.datetime]:
"""Internal method to get object mtime."""
def _get_mtime() -> Optional[datetime.datetime]:
if not os.path.exists(path):
return None
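The compacted set comprehension above recovers completed work-item hashes from marker files named done_<hash>.flag. A small sketch of the same convention, using a hypothetical flags directory:
import os
from typing import Set

done_flags_dir = "/tmp/example_done_flags"  # hypothetical marker-file directory
os.makedirs(done_flags_dir, exist_ok=True)
open(os.path.join(done_flags_dir, "done_abc123.flag"), "w").close()  # marks hypothetical item "abc123"

def list_completed(path: str) -> Set[str]:
    if not os.path.isdir(path):
        return set()
    # "done_<hash>.flag" -> "<hash>"
    return {f[len("done_") : -len(".flag")] for f in os.listdir(path) if f.startswith("done_") and f.endswith(".flag")}

print(list_completed(done_flags_dir))  # {'abc123'}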

View File

@@ -6,7 +6,7 @@ from unittest.mock import Mock, patch
from botocore.exceptions import ClientError
# Import the classes we're testing
from olmocr.work_queue import WorkQueue, S3Backend, WorkItem
from olmocr.work_queue import S3Backend, WorkItem, WorkQueue
class TestS3WorkQueue(unittest.TestCase):