Mirror of https://github.com/allenai/olmocr.git (synced 2025-10-29 17:05:18 +00:00)

Lint fixes

parent 05330150ad
commit 93411a80a0
@@ -49,7 +49,7 @@ from olmocr.s3_utils import (
)
from olmocr.train.dataloader import FrontMatterParser
from olmocr.version import VERSION
from olmocr.work_queue import WorkQueue, LocalBackend, S3Backend
from olmocr.work_queue import LocalBackend, S3Backend, WorkQueue

# Initialize logger
logger = logging.getLogger(__name__)

@@ -394,9 +394,7 @@ class Config:
steps.append(FrontMatterParser(front_matter_class=front_matter_class))

elif step_name == "PDFRenderer":
steps.append(
PDFRenderer(target_longest_image_dim=step_config.get("target_longest_image_dim", 1024))
)
steps.append(PDFRenderer(target_longest_image_dim=step_config.get("target_longest_image_dim", 1024)))

elif step_name == "StaticLengthDocumentAnchoring":
steps.append(StaticLengthDocumentAnchoring(target_anchor_text_len=step_config.get("target_anchor_text_len", 6000)))
@@ -417,9 +415,7 @@ class Config:
steps.append(FrontMatterOutputFormat())

elif step_name == "InstructUserMessages":
steps.append(InstructUserMessages(
prompt_first=step_config.get("prompt_first", False)
))
steps.append(InstructUserMessages(prompt_first=step_config.get("prompt_first", False)))

elif step_name == "LatexBracketNormalizer":
steps.append(LatexBracketNormalizer())
@@ -462,18 +458,10 @@ class Config:
steps.append(FilterOutRotatedDocuments())

elif step_name == "RotationAugmentation":
steps.append(
RotationAugmentation(
probability=step_config.get("probability", 0.5)
)
)
steps.append(RotationAugmentation(probability=step_config.get("probability", 0.5)))

elif step_name == "AugraphyBasicAugmentations":
steps.append(
AugraphyBasicAugmentations(
probability=step_config.get("probability", 0.5)
)
)
steps.append(AugraphyBasicAugmentations(probability=step_config.get("probability", 0.5)))

else:
raise ValueError(f"Unknown pipeline step: {step_name}")
@@ -5,13 +5,11 @@ import re
from abc import ABC, abstractmethod
from concurrent.futures import ProcessPoolExecutor, as_completed
from dataclasses import dataclass, fields
from functools import reduce
from io import BytesIO
from os import PathLike
from pathlib import Path
from typing import (
Any,
Callable,
Dict,
List,
Optional,
@@ -144,7 +142,7 @@ class BaseMarkdownPDFDataset(Dataset):
pbar.update(1)

# Sort samples by markdown path for consistent ordering across runs
self.samples.sort(key=lambda x: x['markdown_path'])
self.samples.sort(key=lambda x: x["markdown_path"])

logger.info(f"Found {valid_count} valid markdown-PDF pairs")

@@ -551,18 +549,17 @@ class AugraphyBasicAugmentations(PipelineStep):
# Apply a basic augraphy pipeline
from augraphy import (
AugraphyPipeline,
Brightness,
InkBleed,
InkMottling,
InkShifter,
Jpeg,
LowInkPeriodicLines,
LowInkRandomLines,
OneOf,

Jpeg,
InkMottling,
InkShifter,
Brightness,
)

# Apply geometric transformations first, maintaing scale
# Apply geometric transformations first, maintaing scale
if np.random.random() < 0.50:
# Get dimensions
height, width = image_bgr.shape[:2]

@@ -590,25 +587,20 @@ class AugraphyBasicAugmentations(PipelineStep):
(width, height),
flags=cv2.INTER_LINEAR,
borderMode=cv2.BORDER_CONSTANT,
borderValue=(255, 255, 255) # White background for documents
borderValue=(255, 255, 255), # White background for documents
)

ink_phase = [
OneOf([InkBleed(p=1), LowInkRandomLines(p=1), LowInkPeriodicLines(p=1), InkMottling(p=1), InkShifter(p=1, text_shift_scale_range=(10, 15))], p=0.2),
]

paper_phase = [
OneOf([Brightness(p=0.2), Jpeg(p=1)])
]
paper_phase = [OneOf([Brightness(p=0.2), Jpeg(p=1)])]

post_phase = [
# Empty on purpose or else augmentations are too strong
]

augmentation_pipeline = AugraphyPipeline(
ink_phase=ink_phase, paper_phase=paper_phase, post_phase=post_phase
)
augmentation_pipeline = AugraphyPipeline(ink_phase=ink_phase, paper_phase=paper_phase, post_phase=post_phase)

# Apply augmentations
augmented_image_bgr = augmentation_pipeline(image_bgr)

@@ -621,12 +613,11 @@ class AugraphyBasicAugmentations(PipelineStep):
sample["image"] = augmented_image_pil

# Double-check PIL image size matches original
assert augmented_image_pil.size == image.size, (
f"PIL image size changed during augmentation: {image.size} -> {augmented_image_pil.size}"
)
assert augmented_image_pil.size == image.size, f"PIL image size changed during augmentation: {image.size} -> {augmented_image_pil.size}"

return sample

@dataclass(frozen=True, slots=True)
class InstructUserMessages(PipelineStep):
"""Creates instruction-following messages format for training."""

@@ -14,8 +14,8 @@ def zeropower_via_newtonschulz5(G, steps: int):
where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
performance at all relative to UV^T, where USV^T = G is the SVD.
"""
assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
a, b, c = (3.4445, -4.7750, 2.0315)
assert G.ndim >= 2 # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
a, b, c = (3.4445, -4.7750, 2.0315)
X = G.bfloat16()
if G.size(-2) > G.size(-1):
X = X.mT

@@ -25,7 +25,7 @@ def zeropower_via_newtonschulz5(G, steps: int):
# Perform the NS iterations
for _ in range(steps):
A = X @ X.mT
B = b * A + c * A @ A # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
B = b * A + c * A @ A # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
X = a * X + B @ X

if G.size(-2) > G.size(-1):
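As a quick illustration of the docstring's claim above (this check is not part of the commit): the iteration returns an approximate orthogonalization U S' V^T with S'_{ii} roughly in (0.5, 1.5), so the singular values of its output should cluster near 1 regardless of the input's spectrum. A minimal sketch, assuming the function is importable from olmocr.train.muon as the training script elsewhere in this diff does:

```python
import torch

from olmocr.train.muon import zeropower_via_newtonschulz5  # import path assumed from this repo

G = torch.randn(256, 128)                      # any matrix-shaped "gradient"
X = zeropower_via_newtonschulz5(G, steps=5)    # bfloat16 result with the same shape as G
svals = torch.linalg.svdvals(X.float())
print(svals.min().item(), svals.max().item())  # expected to fall roughly inside (0.5, 1.5)
```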
@@ -36,10 +36,10 @@ def zeropower_via_newtonschulz5(G, steps: int):
def muon_update(grad, momentum, beta=0.95, ns_steps=5, nesterov=True):
momentum.lerp_(grad, 1 - beta)
update = grad.lerp_(momentum, beta) if nesterov else momentum
if update.ndim == 4: # for the case of conv filters
if update.ndim == 4: # for the case of conv filters
update = update.view(len(update), -1)
update = zeropower_via_newtonschulz5(update, steps=ns_steps)
update *= max(1, grad.size(-2) / grad.size(-1))**0.5
update *= max(1, grad.size(-2) / grad.size(-1)) ** 0.5
return update
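Spelled out (added for reference, not part of the commit), the update computed above keeps an EMA momentum buffer of the gradient, optionally blends it Nesterov-style, orthogonalizes the result with the Newton-Schulz iteration, and rescales for non-square matrices:

```latex
m \leftarrow \beta\,m + (1-\beta)\,g,\qquad
u = \begin{cases}(1-\beta)\,g + \beta\,m & \text{Nesterov}\\ m & \text{otherwise}\end{cases}
```
```latex
\text{update} = \mathrm{NS}_5(u)\,\sqrt{\max\!\left(1,\ \tfrac{\text{rows}}{\text{cols}}\right)}
```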
@@ -64,6 +64,7 @@ class Muon(torch.optim.Optimizer):
weight_decay: The AdamW-style weight decay.
momentum: The momentum. A value of 0.95 here is usually fine.
"""

def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95):
defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum)
assert isinstance(params, list) and len(params) >= 1 and isinstance(params[0], torch.nn.Parameter)

@@ -81,7 +82,7 @@ class Muon(torch.optim.Optimizer):
for group in self.param_groups:
params = group["params"]
params_pad = params + [torch.empty_like(params[-1])] * (dist.get_world_size() - len(params) % dist.get_world_size())
for base_i in range(len(params))[::dist.get_world_size()]:
for base_i in range(len(params))[:: dist.get_world_size()]:
if base_i + dist.get_rank() < len(params):
p = params[base_i + dist.get_rank()]
if p.grad is None:
@@ -93,7 +94,7 @@ class Muon(torch.optim.Optimizer):
update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
p.mul_(1 - group["lr"] * group["weight_decay"])
p.add_(update.reshape(p.shape), alpha=-group["lr"])
dist.all_gather(params_pad[base_i:base_i + dist.get_world_size()], params_pad[base_i + dist.get_rank()])
dist.all_gather(params_pad[base_i : base_i + dist.get_world_size()], params_pad[base_i + dist.get_rank()])

return loss

@@ -102,6 +103,7 @@ class SingleDeviceMuon(torch.optim.Optimizer):
"""
Muon variant for usage in non-distributed settings.
"""

def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95):
defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum)
super().__init__(params, defaults)

@@ -132,8 +134,8 @@ class SingleDeviceMuon(torch.optim.Optimizer):
def adam_update(grad, buf1, buf2, step, betas, eps):
buf1.lerp_(grad, 1 - betas[0])
buf2.lerp_(grad.square(), 1 - betas[1])
buf1c = buf1 / (1 - betas[0]**step)
buf2c = buf2 / (1 - betas[1]**step)
buf1c = buf1 / (1 - betas[0] ** step)
buf2c = buf2 / (1 - betas[1] ** step)
return buf1c / (buf2c.sqrt() + eps)
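In equation form (added here for reference, not part of the commit), adam_update is the standard bias-corrected Adam direction without the learning-rate factor, which is applied by the caller:

```latex
m_t = \beta_1 m_{t-1} + (1-\beta_1)\,g_t,\qquad
v_t = \beta_2 v_{t-1} + (1-\beta_2)\,g_t^2
```
```latex
\hat m_t = \frac{m_t}{1-\beta_1^{t}},\qquad
\hat v_t = \frac{v_t}{1-\beta_2^{t}},\qquad
\text{update} = \frac{\hat m_t}{\sqrt{\hat v_t} + \varepsilon}
```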
@@ -164,6 +166,7 @@ class MuonWithAuxAdam(torch.optim.Optimizer):
optimizer = MuonWithAuxAdam(param_groups)
```
"""

def __init__(self, param_groups):
for group in param_groups:
assert "use_muon" in group
@@ -195,7 +198,7 @@ class MuonWithAuxAdam(torch.optim.Optimizer):
if group["use_muon"]:
params = group["params"]
params_pad = params + [torch.empty_like(params[-1])] * (dist.get_world_size() - len(params) % dist.get_world_size())
for base_i in range(len(params))[::dist.get_world_size()]:
for base_i in range(len(params))[:: dist.get_world_size()]:
if base_i + dist.get_rank() < len(params):
p = params[base_i + dist.get_rank()]
if p.grad is None:
@@ -207,7 +210,7 @@ class MuonWithAuxAdam(torch.optim.Optimizer):
update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
p.mul_(1 - group["lr"] * group["weight_decay"])
p.add_(update.reshape(p.shape), alpha=-group["lr"])
dist.all_gather(params_pad[base_i:base_i + dist.get_world_size()], params_pad[base_i + dist.get_rank()])
dist.all_gather(params_pad[base_i : base_i + dist.get_world_size()], params_pad[base_i + dist.get_rank()])
else:
for p in group["params"]:
if p.grad is None:

@@ -219,8 +222,7 @@ class MuonWithAuxAdam(torch.optim.Optimizer):
state["exp_avg_sq"] = torch.zeros_like(p)
state["step"] = 0
state["step"] += 1
update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"],
state["step"], group["betas"], group["eps"])
update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"], state["step"], group["betas"], group["eps"])
p.mul_(1 - group["lr"] * group["weight_decay"])
p.add_(update, alpha=-group["lr"])

@@ -231,6 +233,7 @@ class SingleDeviceMuonWithAuxAdam(torch.optim.Optimizer):
"""
Non-distributed variant of MuonWithAuxAdam.
"""

def __init__(self, param_groups):
for group in param_groups:
assert "use_muon" in group

@@ -280,8 +283,7 @@ class SingleDeviceMuonWithAuxAdam(torch.optim.Optimizer):
state["exp_avg_sq"] = torch.zeros_like(p)
state["step"] = 0
state["step"] += 1
update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"],
state["step"], group["betas"], group["eps"])
update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"], state["step"], group["betas"], group["eps"])
p.mul_(1 - group["lr"] * group["weight_decay"])
p.add_(update, alpha=-group["lr"])

@@ -4,26 +4,25 @@ Simple script to test OlmOCR dataset loading with YAML configuration.

import argparse
import logging
import os
import math
import os
import shutil
from typing import Any, Dict, Optional

import numpy as np
import torch
from torch.utils.data import ConcatDataset, DataLoader
from torch.optim import AdamW
from torch.amp import autocast
import wandb
from torch.amp import autocast
from torch.optim import AdamW
from torch.utils.data import ConcatDataset, DataLoader
from tqdm import tqdm

from transformers import (
AutoProcessor,
get_scheduler,
Qwen2_5_VLForConditionalGeneration,
Qwen2VLForConditionalGeneration,
get_scheduler,
)

from typing import Optional, Dict, Any
from olmocr.train.config import Config
from olmocr.train.dataloader import BaseMarkdownPDFDataset
from olmocr.train.muon import SingleDeviceMuonWithAuxAdam

@@ -37,7 +36,6 @@ logging.basicConfig(
logger = logging.getLogger(__name__)


class QwenDataCollator:
"""Data collator for vision-language models that handles numpy arrays."""

@@ -128,10 +126,7 @@ def save_checkpoint(

# Enforce save_total_limit by removing oldest checkpoints
if save_total_limit is not None and save_total_limit > 0:
checkpoints = sorted(
[d for d in os.listdir(output_dir) if d.startswith("checkpoint-")],
key=lambda x: int(x.split("-")[1])
)
checkpoints = sorted([d for d in os.listdir(output_dir) if d.startswith("checkpoint-")], key=lambda x: int(x.split("-")[1]))
while len(checkpoints) > save_total_limit:
oldest = checkpoints.pop(0)
shutil.rmtree(os.path.join(output_dir, oldest))

@@ -284,7 +279,6 @@ def main():
if len(dataset) > 0:
eval_datasets[dataset_name] = dataset

# Log total evaluation samples across all datasets
total_eval_samples = sum(len(dataset) for dataset in eval_datasets.values())
logger.info(f"Total evaluation samples across {len(eval_datasets)} datasets: {total_eval_samples}")
@@ -316,6 +310,7 @@ def main():
worker_seed = torch.initial_seed() % 2**32
np.random.seed(worker_seed)
import random

random.seed(worker_seed)

# Create generator for data loader

@@ -370,7 +365,7 @@ def main():
adam_groups = [
dict(params=head_params, lr=float(config.training.learning_rate) * config.training.muon_lr_multiplier_head, use_muon=False),
dict(params=embed_params, lr=float(config.training.learning_rate) * config.training.muon_lr_multiplier_embed, use_muon=False),
dict(params=scalar_params, lr=float(config.training.learning_rate) * config.training.muon_lr_multiplier_scalar, use_muon=False)
dict(params=scalar_params, lr=float(config.training.learning_rate) * config.training.muon_lr_multiplier_scalar, use_muon=False),
]

# Add Adam hyperparameters to groups
@@ -385,7 +380,7 @@ def main():
lr=float(config.training.learning_rate),
momentum=config.training.muon_momentum,
weight_decay=config.training.weight_decay,
use_muon=True
use_muon=True,
)

# Combine all groups

@@ -527,12 +522,7 @@ def main():
# Update progress bar with current stats
current_lr = lr_scheduler.get_last_lr()[0]
avg_loss = accumulated_loss / num_losses_accumulated if num_losses_accumulated > 0 else 0
pbar.set_postfix({
'loss': f'{avg_loss:.4f}',
'lr': f'{current_lr:.2e}',
'epoch': f'{current_epoch:.2f}',
'step': global_step
})
pbar.set_postfix({"loss": f"{avg_loss:.4f}", "lr": f"{current_lr:.2e}", "epoch": f"{current_epoch:.2f}", "step": global_step})

# Logging
if config.training.logging_steps > 0 and global_step % config.training.logging_steps == 0:

@@ -560,8 +550,9 @@ def main():
# Update best metric
current_metric = metrics.get(config.training.metric_for_best_model, None)
if current_metric is not None:
if (config.training.greater_is_better and current_metric > best_metric) or \
(not config.training.greater_is_better and current_metric < best_metric):
if (config.training.greater_is_better and current_metric > best_metric) or (
not config.training.greater_is_better and current_metric < best_metric
):
best_metric = current_metric

# Return to training mode

@@ -570,8 +561,7 @@ def main():
# Saving
if config.training.save_steps > 0 and global_step % config.training.save_steps == 0:
save_checkpoint(
model, optimizer, lr_scheduler, current_epoch, global_step, samples_seen, best_metric,
full_output_dir, config.training.save_total_limit
model, optimizer, lr_scheduler, current_epoch, global_step, samples_seen, best_metric, full_output_dir, config.training.save_total_limit
)

# Check if we've reached our training limit

@@ -583,10 +573,7 @@ def main():

# Save the final checkpoint with step number
logger.info(f"Saving final checkpoint at step {global_step}...")
save_checkpoint(
model, optimizer, lr_scheduler, current_epoch, global_step, samples_seen, best_metric,
full_output_dir, config.training.save_total_limit
)
save_checkpoint(model, optimizer, lr_scheduler, current_epoch, global_step, samples_seen, best_metric, full_output_dir, config.training.save_total_limit)

# Log final training state
final_epoch = samples_seen / len(train_dataset)

@@ -171,7 +171,6 @@ class WorkQueue:
logger.info(f"Initialized queue with {self.size:,} work items")
return self.size


async def get_work(self, worker_lock_timeout_secs: int = 1800) -> Optional[WorkItem]:
"""
Get the next available work item that isn't completed or locked.
@@ -179,7 +178,6 @@ class WorkQueue:
REFRESH_COMPLETED_HASH_CACHE_MAX_ATTEMPTS = 3
refresh_completed_hash_attempt = 0


while True:
try:
work_item = self._queue.get_nowait()

@@ -281,11 +279,7 @@ class LocalBackend(Backend):
def _list_completed() -> Set[str]:
if not os.path.isdir(self._done_flags_dir):
return set()
return {
f[len("done_") : -len(".flag")]
for f in os.listdir(self._done_flags_dir)
if f.startswith("done_") and f.endswith(".flag")
}
return {f[len("done_") : -len(".flag")] for f in os.listdir(self._done_flags_dir) if f.startswith("done_") and f.endswith(".flag")}

return await asyncio.to_thread(_list_completed)

@@ -299,6 +293,7 @@ class LocalBackend(Backend):

async def _get_object_mtime(self, path: str) -> Optional[datetime.datetime]:
"""Internal method to get object mtime."""

def _get_mtime() -> Optional[datetime.datetime]:
if not os.path.exists(path):
return None

@@ -6,7 +6,7 @@ from unittest.mock import Mock, patch
from botocore.exceptions import ClientError

# Import the classes we're testing
from olmocr.work_queue import WorkQueue, S3Backend, WorkItem
from olmocr.work_queue import S3Backend, WorkItem, WorkQueue


class TestS3WorkQueue(unittest.TestCase):