Lint fixes

Jake Poznanski 2025-08-13 20:21:04 +00:00
parent 05330150ad
commit 93411a80a0
8 changed files with 157 additions and 194 deletions

View File

@@ -49,7 +49,7 @@ from olmocr.s3_utils import (
)
from olmocr.train.dataloader import FrontMatterParser
from olmocr.version import VERSION
-from olmocr.work_queue import WorkQueue, LocalBackend, S3Backend
+from olmocr.work_queue import LocalBackend, S3Backend, WorkQueue

# Initialize logger
logger = logging.getLogger(__name__)

View File

@@ -204,11 +204,11 @@ class TrainingConfig:
    adam_epsilon: float = 1e-8
    weight_decay: float = 0.01
    max_grad_norm: float = 1.0

    # Muon optimizer specific settings
    muon_momentum: float = 0.95
    muon_lr_multiplier_head: float = 11.0  # Learning rate multiplier for head parameters
    muon_lr_multiplier_embed: float = 30.0  # Learning rate multiplier for embedding parameters
    muon_lr_multiplier_scalar: float = 2.0  # Learning rate multiplier for scalar parameters

    # Gradient checkpointing
@@ -243,7 +243,7 @@ class TrainingConfig:
    # Data collator settings
    collator_max_token_len: Optional[int] = None
    remove_unused_columns: bool = False  # Important for custom datasets

    # Torch compile settings
    torch_compile: bool = False
    torch_compile_backend: str = "inductor"  # "inductor", "aot_eager", "cudagraphs", etc.
@@ -394,9 +394,7 @@ class Config:
                    steps.append(FrontMatterParser(front_matter_class=front_matter_class))
                elif step_name == "PDFRenderer":
-                    steps.append(
-                        PDFRenderer(target_longest_image_dim=step_config.get("target_longest_image_dim", 1024))
-                    )
+                    steps.append(PDFRenderer(target_longest_image_dim=step_config.get("target_longest_image_dim", 1024)))
                elif step_name == "StaticLengthDocumentAnchoring":
                    steps.append(StaticLengthDocumentAnchoring(target_anchor_text_len=step_config.get("target_anchor_text_len", 6000)))
@@ -417,9 +415,7 @@ class Config:
                    steps.append(FrontMatterOutputFormat())
                elif step_name == "InstructUserMessages":
-                    steps.append(InstructUserMessages(
-                        prompt_first=step_config.get("prompt_first", False)
-                    ))
+                    steps.append(InstructUserMessages(prompt_first=step_config.get("prompt_first", False)))
                elif step_name == "LatexBracketNormalizer":
                    steps.append(LatexBracketNormalizer())
@@ -457,24 +453,16 @@ class Config:
                            masking_index=step_config.get("masking_index", -100),
                        )
                    )
                elif step_name == "FilterOutRotatedDocuments":
                    steps.append(FilterOutRotatedDocuments())
                elif step_name == "RotationAugmentation":
-                    steps.append(
-                        RotationAugmentation(
-                            probability=step_config.get("probability", 0.5)
-                        )
-                    )
+                    steps.append(RotationAugmentation(probability=step_config.get("probability", 0.5)))
                elif step_name == "AugraphyBasicAugmentations":
-                    steps.append(
-                        AugraphyBasicAugmentations(
-                            probability=step_config.get("probability", 0.5)
-                        )
-                    )
+                    steps.append(AugraphyBasicAugmentations(probability=step_config.get("probability", 0.5)))
                else:
                    raise ValueError(f"Unknown pipeline step: {step_name}")

View File

@@ -5,13 +5,11 @@ import re
from abc import ABC, abstractmethod
from concurrent.futures import ProcessPoolExecutor, as_completed
from dataclasses import dataclass, fields
-from functools import reduce
from io import BytesIO
from os import PathLike
from pathlib import Path
from typing import (
    Any,
-    Callable,
    Dict,
    List,
    Optional,
@@ -144,8 +142,8 @@ class BaseMarkdownPDFDataset(Dataset):
                    pbar.update(1)

        # Sort samples by markdown path for consistent ordering across runs
-        self.samples.sort(key=lambda x: x['markdown_path'])
+        self.samples.sort(key=lambda x: x["markdown_path"])

        logger.info(f"Found {valid_count} valid markdown-PDF pairs")

        if invalid_pdfs:
@@ -178,7 +176,7 @@ class BaseMarkdownPDFDataset(Dataset):
            sample = step(sample)
            if sample is None:
                return None

        return sample
@@ -440,26 +438,26 @@ class LatexBracketNormalizer(PipelineStep):
@dataclass(frozen=True, slots=True)
class RotationAugmentation(PipelineStep):
    """Pipeline step that randomly rotates images for augmentation."""

    probability: float = 0.5  # Probability of applying rotation

    def __call__(self, sample: Sample) -> Optional[Sample]:
        """Randomly rotate image and update rotation metadata."""
        # Only proceed with given probability
        if np.random.random() > self.probability:
            return sample

        # Check if image exists
        if "image" not in sample:
            return sample

        # Check if page_data exists (we need to update it)
        if "page_data" not in sample:
            return sample

        # Randomly choose a rotation (90, 180, or 270 degrees)
        rotation_degrees = np.random.choice([90, 180, 270])

        # Apply rotation to image
        image = sample["image"]
        if rotation_degrees == 90:
@@ -468,13 +466,13 @@ class RotationAugmentation(PipelineStep):
            transpose = Image.Transpose.ROTATE_180
        else:  # 270
            transpose = Image.Transpose.ROTATE_270

        rotated_image = image.transpose(transpose)
        sample["image"] = rotated_image

        # Update page_data
        page_data = sample["page_data"]

        # Create new PageResponse with updated rotation info
        # The rotation_correction should be the inverse of what we applied
        # If we rotated 90 clockwise, we need 270 counter-clockwise to correct it
@@ -484,9 +482,9 @@ class RotationAugmentation(PipelineStep):
            correction = 180
        else:  # 270
            correction = 90

        from olmocr.prompts.prompts import PageResponse

        new_page_data = PageResponse(
            primary_language=page_data.primary_language,
            is_rotation_valid=False,  # Mark as invalid since we rotated it
@@ -495,7 +493,7 @@ class RotationAugmentation(PipelineStep):
            is_diagram=page_data.is_diagram,
            natural_text=page_data.natural_text,
        )
        sample["page_data"] = new_page_data

        return sample
@@ -509,24 +507,24 @@ class FilterOutRotatedDocuments(PipelineStep):
        # Check if page_data exists
        if "page_data" not in sample:
            return sample

        page_data = sample["page_data"]

        # Check if page_data has the required attributes
        if not hasattr(page_data, "is_rotation_valid") or not hasattr(page_data, "rotation_correction"):
            return sample

        # Filter out if rotation is invalid or rotation correction is not 0
        if page_data.is_rotation_valid is False or page_data.rotation_correction != 0:
            return None

        return sample


@dataclass(frozen=True, slots=True)
class AugraphyBasicAugmentations(PipelineStep):
    """Pipeline step that applies a decent selection of augraphy augmentations to the data"""

    probability: float = 0.5  # Overall probability of applying any augmentation

    def __call__(self, sample: Sample) -> Optional[Sample]:
@@ -534,103 +532,96 @@ class AugraphyBasicAugmentations(PipelineStep):
        # Check that the image data exists
        if "image" not in sample:
            return sample

        image = sample["image"]

        # Skip all augmentations based on overall probability
        if np.random.random() > self.probability:
            return sample

        # Convert from PIL to BGR for OpenCV/Augraphy
        image_numpy = np.array(image)
        if len(image_numpy.shape) < 3:
            image_bgr = cv2.cvtColor(image_numpy, cv2.COLOR_GRAY2BGR)
        else:
            image_bgr = cv2.cvtColor(image_numpy, cv2.COLOR_RGB2BGR)

        # Apply a basic augraphy pipeline
        from augraphy import (
            AugraphyPipeline,
+            Brightness,
            InkBleed,
+            InkMottling,
+            InkShifter,
+            Jpeg,
            LowInkPeriodicLines,
            LowInkRandomLines,
            OneOf,
-            Jpeg,
-            InkMottling,
-            InkShifter,
-            Brightness,
        )

        # Apply geometric transformations first, maintaing scale
        if np.random.random() < 0.50:
            # Get dimensions
            height, width = image_bgr.shape[:2]

            # Random parameters for geometric transformations
            angle = max(min(np.random.standard_normal(), 3), -3)  # Small rotation range
            scale = np.random.uniform(0.95, 1.05)  # Small scale range
            tx = np.random.uniform(-0.02, 0.02) * width  # Translation as fraction of width
            ty = np.random.uniform(-0.02, 0.02) * height  # Translation as fraction of height

            # Calculate center point
            center = (width / 2, height / 2)

            # Create transformation matrix
            M = cv2.getRotationMatrix2D(center, angle, scale)

            # Add translation
            M[0, 2] += tx
            M[1, 2] += ty

            # Apply transformation
            image_bgr = cv2.warpAffine(
                image_bgr,
                M,
                (width, height),
                flags=cv2.INTER_LINEAR,
                borderMode=cv2.BORDER_CONSTANT,
-                borderValue=(255, 255, 255) # White background for documents
+                borderValue=(255, 255, 255),  # White background for documents
            )

        ink_phase = [
            OneOf([InkBleed(p=1), LowInkRandomLines(p=1), LowInkPeriodicLines(p=1), InkMottling(p=1), InkShifter(p=1, text_shift_scale_range=(10, 15))], p=0.2),
        ]

-        paper_phase = [
-            OneOf([Brightness(p=0.2), Jpeg(p=1)])
-        ]
+        paper_phase = [OneOf([Brightness(p=0.2), Jpeg(p=1)])]

        post_phase = [
            # Empty on purpose or else augmentations are too strong
        ]

-        augmentation_pipeline = AugraphyPipeline(
-            ink_phase=ink_phase, paper_phase=paper_phase, post_phase=post_phase
-        )
+        augmentation_pipeline = AugraphyPipeline(ink_phase=ink_phase, paper_phase=paper_phase, post_phase=post_phase)

        # Apply augmentations
        augmented_image_bgr = augmentation_pipeline(image_bgr)

        # Convert back to RGB and then to PIL format
        augmented_image_rgb = cv2.cvtColor(augmented_image_bgr, cv2.COLOR_BGR2RGB)
        augmented_image_pil = Image.fromarray(augmented_image_rgb)

        # Update the sample with the augmented image
        sample["image"] = augmented_image_pil

        # Double-check PIL image size matches original
-        assert augmented_image_pil.size == image.size, (
-            f"PIL image size changed during augmentation: {image.size} -> {augmented_image_pil.size}"
-        )
+        assert augmented_image_pil.size == image.size, f"PIL image size changed during augmentation: {image.size} -> {augmented_image_pil.size}"

        return sample


@dataclass(frozen=True, slots=True)
class InstructUserMessages(PipelineStep):
    """Creates instruction-following messages format for training."""

    prompt_first: bool = False

    def __call__(self, sample: Sample) -> Sample:
@@ -913,12 +904,12 @@ if __name__ == "__main__":
    print(f"PDF file: {sample['pdf_path'].name}")
    if "image" in sample and hasattr(sample["image"], "size"):
        print(f"Image size: {sample['image'].size}")

        # Save image if requested
        if args.save_image:
            sample["image"].save(args.save_image)
            print(f"Saved image to: {args.save_image}")

    if "page_data" in sample:
        print(f"\nPage data: {sample['page_data']}")

    if "messages" in sample:

View File

@@ -14,8 +14,8 @@ def zeropower_via_newtonschulz5(G, steps: int):
    where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
    performance at all relative to UV^T, where USV^T = G is the SVD.
    """
    assert G.ndim >= 2  # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    if G.size(-2) > G.size(-1):
        X = X.mT
@@ -25,9 +25,9 @@ def zeropower_via_newtonschulz5(G, steps: int):
    # Perform the NS iterations
    for _ in range(steps):
        A = X @ X.mT
        B = b * A + c * A @ A  # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
        X = a * X + B @ X

    if G.size(-2) > G.size(-1):
        X = X.mT

    return X
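Aside (a reference restatement, not part of the commit): the loop in the hunk above is the quintic Newton-Schulz step

    X_{k+1} = a X_k + b\,(X_k X_k^{\top})\,X_k + c\,(X_k X_k^{\top})^{2}\,X_k, \qquad (a, b, c) = (3.4445,\ -4.7750,\ 2.0315),

which drives the singular values of X toward a band around 1, giving the approximate orthogonalization of G (the UV^T described in the docstring).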
@@ -36,10 +36,10 @@ def zeropower_via_newtonschulz5(G, steps: int):
def muon_update(grad, momentum, beta=0.95, ns_steps=5, nesterov=True):
    momentum.lerp_(grad, 1 - beta)
    update = grad.lerp_(momentum, beta) if nesterov else momentum
    if update.ndim == 4:  # for the case of conv filters
        update = update.view(len(update), -1)
    update = zeropower_via_newtonschulz5(update, steps=ns_steps)
-    update *= max(1, grad.size(-2) / grad.size(-1))**0.5
+    update *= max(1, grad.size(-2) / grad.size(-1)) ** 0.5
    return update
@@ -64,6 +64,7 @@ class Muon(torch.optim.Optimizer):
        weight_decay: The AdamW-style weight decay.
        momentum: The momentum. A value of 0.95 here is usually fine.
    """

    def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95):
        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum)
        assert isinstance(params, list) and len(params) >= 1 and isinstance(params[0], torch.nn.Parameter)
@@ -81,7 +82,7 @@ class Muon(torch.optim.Optimizer):
        for group in self.param_groups:
            params = group["params"]
            params_pad = params + [torch.empty_like(params[-1])] * (dist.get_world_size() - len(params) % dist.get_world_size())
-            for base_i in range(len(params))[::dist.get_world_size()]:
+            for base_i in range(len(params))[:: dist.get_world_size()]:
                if base_i + dist.get_rank() < len(params):
                    p = params[base_i + dist.get_rank()]
                    if p.grad is None:
@@ -93,7 +94,7 @@ class Muon(torch.optim.Optimizer):
                    update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
                    p.mul_(1 - group["lr"] * group["weight_decay"])
                    p.add_(update.reshape(p.shape), alpha=-group["lr"])
-                dist.all_gather(params_pad[base_i:base_i + dist.get_world_size()], params_pad[base_i + dist.get_rank()])
+                dist.all_gather(params_pad[base_i : base_i + dist.get_world_size()], params_pad[base_i + dist.get_rank()])

        return loss
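Aside (a hypothetical sketch, not part of the commit; all values and names below are made up): the padding arithmetic in Muon.step rounds the parameter list up to a multiple of the world size so that every rank can join each all_gather, even on the last, partially filled group of parameters.

    world_size = 4                                 # assumed number of ranks
    params = ["p0", "p1", "p2", "p3", "p4", "p5"]  # stand-ins for 6 parameter tensors
    pad = world_size - len(params) % world_size    # 4 - 6 % 4 = 2
    params_pad = params + ["pad"] * pad            # length 8, a multiple of world_size
    # Rank r updates params[base_i + r] for base_i in (0, 4); a rank whose index falls
    # past the real list skips the update but still participates in the all_gather
    # over params_pad[base_i : base_i + world_size].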
@@ -102,6 +103,7 @@ class SingleDeviceMuon(torch.optim.Optimizer):
    """
    Muon variant for usage in non-distributed settings.
    """

    def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95):
        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum)
        super().__init__(params, defaults)
@@ -132,8 +134,8 @@ class SingleDeviceMuon(torch.optim.Optimizer):
def adam_update(grad, buf1, buf2, step, betas, eps):
    buf1.lerp_(grad, 1 - betas[0])
    buf2.lerp_(grad.square(), 1 - betas[1])
-    buf1c = buf1 / (1 - betas[0]**step)
-    buf2c = buf2 / (1 - betas[1]**step)
+    buf1c = buf1 / (1 - betas[0] ** step)
+    buf2c = buf2 / (1 - betas[1] ** step)
    return buf1c / (buf2c.sqrt() + eps)
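Aside (not part of the commit): adam_update is the standard bias-corrected Adam direction, where buf1 and buf2 play the role of the first and second moments m_t and v_t for gradient g_t at step t:

    m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t, \qquad v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2,
    \hat{m}_t = m_t / (1 - \beta_1^t), \qquad \hat{v}_t = v_t / (1 - \beta_2^t), \qquad \text{update} = \hat{m}_t / (\sqrt{\hat{v}_t} + \epsilon).

The caller then applies the learning rate and weight decay separately, as the surrounding hunks show.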
@@ -164,6 +166,7 @@ class MuonWithAuxAdam(torch.optim.Optimizer):
    optimizer = MuonWithAuxAdam(param_groups)
    ```
    """

    def __init__(self, param_groups):
        for group in param_groups:
            assert "use_muon" in group
@@ -195,7 +198,7 @@ class MuonWithAuxAdam(torch.optim.Optimizer):
            if group["use_muon"]:
                params = group["params"]
                params_pad = params + [torch.empty_like(params[-1])] * (dist.get_world_size() - len(params) % dist.get_world_size())
-                for base_i in range(len(params))[::dist.get_world_size()]:
+                for base_i in range(len(params))[:: dist.get_world_size()]:
                    if base_i + dist.get_rank() < len(params):
                        p = params[base_i + dist.get_rank()]
                        if p.grad is None:
@@ -207,7 +210,7 @@ class MuonWithAuxAdam(torch.optim.Optimizer):
                        update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
                        p.mul_(1 - group["lr"] * group["weight_decay"])
                        p.add_(update.reshape(p.shape), alpha=-group["lr"])
-                    dist.all_gather(params_pad[base_i:base_i + dist.get_world_size()], params_pad[base_i + dist.get_rank()])
+                    dist.all_gather(params_pad[base_i : base_i + dist.get_world_size()], params_pad[base_i + dist.get_rank()])
            else:
                for p in group["params"]:
                    if p.grad is None:
@@ -219,8 +222,7 @@ class MuonWithAuxAdam(torch.optim.Optimizer):
                        state["exp_avg_sq"] = torch.zeros_like(p)
                        state["step"] = 0
                    state["step"] += 1
-                    update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"],
-                                         state["step"], group["betas"], group["eps"])
+                    update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"], state["step"], group["betas"], group["eps"])
                    p.mul_(1 - group["lr"] * group["weight_decay"])
                    p.add_(update, alpha=-group["lr"])
@@ -231,6 +233,7 @@ class SingleDeviceMuonWithAuxAdam(torch.optim.Optimizer):
    """
    Non-distributed variant of MuonWithAuxAdam.
    """

    def __init__(self, param_groups):
        for group in param_groups:
            assert "use_muon" in group
@@ -280,9 +283,8 @@ class SingleDeviceMuonWithAuxAdam(torch.optim.Optimizer):
                        state["exp_avg_sq"] = torch.zeros_like(p)
                        state["step"] = 0
                    state["step"] += 1
-                    update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"],
-                                         state["step"], group["betas"], group["eps"])
+                    update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"], state["step"], group["betas"], group["eps"])
                    p.mul_(1 - group["lr"] * group["weight_decay"])
                    p.add_(update, alpha=-group["lr"])

        return loss

View File

@@ -163,7 +163,7 @@ def prepare_olmocr_mix(dataset_path: str, subset: str, split: str, destination:
                    f.write(natural_text)
                else:
                    f.write("---")

            # Look for matching PDF in extracted directory and create symlinks
            extracted_pdfs_dir = dest_path / "hugging_face" / "pdf_tarballs" / "extracted"

View File

@@ -4,26 +4,25 @@ Simple script to test OlmOCR dataset loading with YAML configuration.
import argparse
import logging
-import os
import math
+import os
import shutil
+from typing import Any, Dict, Optional

import numpy as np
import torch
-from torch.utils.data import ConcatDataset, DataLoader
-from torch.optim import AdamW
-from torch.amp import autocast
import wandb
+from torch.amp import autocast
+from torch.optim import AdamW
+from torch.utils.data import ConcatDataset, DataLoader
from tqdm import tqdm
from transformers import (
    AutoProcessor,
-    get_scheduler,
    Qwen2_5_VLForConditionalGeneration,
    Qwen2VLForConditionalGeneration,
+    get_scheduler,
)
-from typing import Optional, Dict, Any

from olmocr.train.config import Config
from olmocr.train.dataloader import BaseMarkdownPDFDataset
from olmocr.train.muon import SingleDeviceMuonWithAuxAdam
@@ -37,7 +36,6 @@ logging.basicConfig(
logger = logging.getLogger(__name__)


class QwenDataCollator:
    """Data collator for vision-language models that handles numpy arrays."""
@@ -80,7 +78,7 @@ class QwenDataCollator:
        # Check if we have any valid samples
        if not batch["input_ids"]:
            return None

        # Convert lists to tensors with proper padding
        # Note: For Qwen2-VL, we typically handle variable length sequences
        # The model's processor should handle the padding internally
@@ -107,14 +105,14 @@ def save_checkpoint(
    """Save model, optimizer, scheduler, and training state."""
    checkpoint_dir = os.path.join(output_dir, f"checkpoint-{global_step}")
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Save model
    model.save_pretrained(checkpoint_dir)

    # Save optimizer and scheduler
    torch.save(optimizer.state_dict(), os.path.join(checkpoint_dir, "optimizer.pt"))
    torch.save(lr_scheduler.state_dict(), os.path.join(checkpoint_dir, "scheduler.pt"))

    # Save training state
    state = {
        "epoch": epoch,
@@ -123,15 +121,12 @@ def save_checkpoint(
        "best_metric": best_metric,
    }
    torch.save(state, os.path.join(checkpoint_dir, "training_state.pt"))

    logger.info(f"Saved checkpoint to {checkpoint_dir}")

    # Enforce save_total_limit by removing oldest checkpoints
    if save_total_limit is not None and save_total_limit > 0:
-        checkpoints = sorted(
-            [d for d in os.listdir(output_dir) if d.startswith("checkpoint-")],
-            key=lambda x: int(x.split("-")[1])
-        )
+        checkpoints = sorted([d for d in os.listdir(output_dir) if d.startswith("checkpoint-")], key=lambda x: int(x.split("-")[1]))
        while len(checkpoints) > save_total_limit:
            oldest = checkpoints.pop(0)
            shutil.rmtree(os.path.join(output_dir, oldest))
@@ -149,10 +144,10 @@ def load_checkpoint(
    """Load model, optimizer, scheduler, and training state from checkpoint."""
    model = model_class.from_pretrained(checkpoint_dir, **init_kwargs)
    model.to(device)

    optimizer.load_state_dict(torch.load(os.path.join(checkpoint_dir, "optimizer.pt"), map_location=device))
    lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint_dir, "scheduler.pt"), map_location=device))

    state = torch.load(os.path.join(checkpoint_dir, "training_state.pt"), map_location=device)
    logger.info(f"Resumed from checkpoint: {checkpoint_dir} at epoch {state['epoch']:.2f}, step {state['global_step']}, samples seen {state['samples_seen']}")
    return model, state
@@ -166,11 +161,11 @@ def evaluate_model(
    """Evaluate on all eval datasets and return average loss per dataset."""
    model.eval()
    eval_metrics = {}

    for dataset_name, dataloader in eval_dataloaders.items():
        total_loss = 0.0
        num_batches = 0

        with torch.no_grad():
            for batch in dataloader:
                # Skip if batch is None (all samples were filtered out)
@@ -181,16 +176,16 @@ def evaluate_model(
                outputs = model(**batch)
                total_loss += outputs.loss.item()
                num_batches += 1

        avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
        eval_metrics[f"eval_{dataset_name}_loss"] = avg_loss
        logger.info(f"Eval {dataset_name} loss: {avg_loss:.4f}")

    # Compute overall eval loss as average across datasets (or customize as needed)
    if eval_metrics:
        overall_loss = sum(eval_metrics.values()) / len(eval_metrics)
        eval_metrics["eval_loss"] = overall_loss

    return eval_metrics
@@ -215,11 +210,11 @@ def main():
    if config.project_name:
        os.environ["WANDB_PROJECT"] = config.project_name
        logger.info(f"Setting WANDB_PROJECT to: {config.project_name}")

    # Initialize wandb if reporting to it
    if "wandb" in config.training.report_to:
        wandb.init(project=config.project_name, name=config.run_name, config=config.to_dict())

    # Load processor for tokenization
    logger.info(f"Loading processor: {config.model.name}")
    processor = AutoProcessor.from_pretrained(
@@ -284,7 +279,6 @@ def main():
        if len(dataset) > 0:
            eval_datasets[dataset_name] = dataset

    # Log total evaluation samples across all datasets
    total_eval_samples = sum(len(dataset) for dataset in eval_datasets.values())
    logger.info(f"Total evaluation samples across {len(eval_datasets)} datasets: {total_eval_samples}")
@@ -310,14 +304,15 @@ def main():
    # Set seeds
    torch.manual_seed(config.training.seed)

    # Set up data loader seed worker function
    def seed_worker(worker_id):
        worker_seed = torch.initial_seed() % 2**32
        np.random.seed(worker_seed)
        import random

        random.seed(worker_seed)

    # Create generator for data loader
    generator = None
    if config.training.data_seed is not None:
@@ -327,7 +322,7 @@ def main():
    # Device setup
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Apply torch compile if enabled
    if config.training.torch_compile:
        logger.info(f"Compiling model with torch.compile (backend={config.training.torch_compile_backend}, mode={config.training.torch_compile_mode})")
@@ -365,29 +360,29 @@ def main():
        embed_params = [p for n, p in model.named_parameters() if "embed" in n]
        scalar_params = [p for p in model.parameters() if p.ndim < 2]
        head_params = [p for n, p in model.named_parameters() if "lm_head" in n]

        # Create Adam groups with different learning rates
        adam_groups = [
            dict(params=head_params, lr=float(config.training.learning_rate) * config.training.muon_lr_multiplier_head, use_muon=False),
            dict(params=embed_params, lr=float(config.training.learning_rate) * config.training.muon_lr_multiplier_embed, use_muon=False),
-            dict(params=scalar_params, lr=float(config.training.learning_rate) * config.training.muon_lr_multiplier_scalar, use_muon=False)
+            dict(params=scalar_params, lr=float(config.training.learning_rate) * config.training.muon_lr_multiplier_scalar, use_muon=False),
        ]

        # Add Adam hyperparameters to groups
        for g in adam_groups:
            g["betas"] = (config.training.adam_beta1, config.training.adam_beta2)
            g["eps"] = float(config.training.adam_epsilon)
            g["weight_decay"] = config.training.weight_decay

        # Create Muon group
        muon_group = dict(
            params=hidden_matrix_params,
            lr=float(config.training.learning_rate),
            momentum=config.training.muon_momentum,
            weight_decay=config.training.weight_decay,
-            use_muon=True
+            use_muon=True,
        )

        # Combine all groups
        param_groups = [*adam_groups, muon_group]
        optimizer = SingleDeviceMuonWithAuxAdam(param_groups)
@@ -416,7 +411,7 @@ def main():
    global_step = 0
    samples_seen = 0
    best_metric = float("inf") if not config.training.greater_is_better else -float("inf")

    if found_resumable_checkpoint:
        model, state = load_checkpoint(model_class, model_init_kwargs, optimizer, lr_scheduler, found_resumable_checkpoint, device)
        global_step = state["global_step"]
@@ -457,7 +452,7 @@ def main():
    current_epoch = samples_seen / len(train_dataset)
    logger.info(f"Starting training from epoch {current_epoch:.2f} (step {global_step}, samples {samples_seen}) to {config.training.num_train_epochs} epochs")
    logger.info(f"Total training steps: {max_train_steps}, Total samples to process: {max_train_samples}")

    if samples_seen >= max_train_samples:
        logger.info("Training already completed based on samples seen!")
        logger.info("Skipping to final model save.")
@@ -465,7 +460,7 @@ def main():
    model.train()
    accumulated_loss = 0.0
    num_losses_accumulated = 0

    # Create epoch iterator and skip samples if resuming
    epoch_iterator = iter(train_dataloader)
    if samples_seen > 0:
@@ -479,10 +474,10 @@ def main():
                # We've reached the end of the epoch while skipping, create new iterator
                epoch_iterator = iter(train_dataloader)
                break

    # Create progress bar
    pbar = tqdm(total=max_train_samples - samples_seen, desc=f"Training from step {global_step}", unit="samples")

    while samples_seen < max_train_samples and global_step < max_train_steps:
        try:
            batch = next(epoch_iterator)
@@ -492,48 +487,43 @@ def main():
            logger.info(f"Completed epoch {current_epoch:.2f}")
            epoch_iterator = iter(train_dataloader)
            batch = next(epoch_iterator)

        # Skip if batch is None (all samples were filtered out)
        if batch is None:
            continue

        batch = {k: v.to(device) for k, v in batch.items()}

        with autocast(device_type="cuda", enabled=True, dtype=torch.bfloat16):
            outputs = model(**batch)
            loss = outputs.loss / config.training.gradient_accumulation_steps

        loss.backward()
        accumulated_loss += outputs.loss.item()  # Use undivided loss for logging
        num_losses_accumulated += 1
        samples_seen += config.training.per_device_train_batch_size

        # Update progress bar
        pbar.update(config.training.per_device_train_batch_size)

        # Check if we should do a gradient update
        if samples_seen % samples_per_step == 0 or samples_seen >= max_train_samples:
            # Clip gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.training.max_grad_norm)

            # Step optimizer and scheduler
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

            global_step += 1
            current_epoch = samples_seen / len(train_dataset)

            # Update progress bar with current stats
            current_lr = lr_scheduler.get_last_lr()[0]
            avg_loss = accumulated_loss / num_losses_accumulated if num_losses_accumulated > 0 else 0
-            pbar.set_postfix({
-                'loss': f'{avg_loss:.4f}',
-                'lr': f'{current_lr:.2e}',
-                'epoch': f'{current_epoch:.2f}',
-                'step': global_step
-            })
+            pbar.set_postfix({"loss": f"{avg_loss:.4f}", "lr": f"{current_lr:.2e}", "epoch": f"{current_epoch:.2f}", "step": global_step})

            # Logging
            if config.training.logging_steps > 0 and global_step % config.training.logging_steps == 0:
                avg_train_loss = accumulated_loss / num_losses_accumulated if num_losses_accumulated > 0 else 0
@@ -546,52 +536,49 @@ def main():
                logger.info(f"Step {global_step}: epoch={current_epoch:.3f}, loss={avg_train_loss:.4f}, lr={lr_scheduler.get_last_lr()[0]:.2e}")
                if "wandb" in config.training.report_to:
                    wandb.log(logs, step=global_step)

                accumulated_loss = 0.0
                num_losses_accumulated = 0

            # Evaluation
            if config.training.eval_steps > 0 and global_step % config.training.eval_steps == 0 and global_step > 0:
                metrics = evaluate_model(model, eval_dataloaders, device)
                logger.info(f"Evaluation at step {global_step}: {metrics}")
                if "wandb" in config.training.report_to:
                    wandb.log(metrics, step=global_step)

                # Update best metric
                current_metric = metrics.get(config.training.metric_for_best_model, None)
                if current_metric is not None:
-                    if (config.training.greater_is_better and current_metric > best_metric) or \
-                       (not config.training.greater_is_better and current_metric < best_metric):
+                    if (config.training.greater_is_better and current_metric > best_metric) or (
+                        not config.training.greater_is_better and current_metric < best_metric
+                    ):
                        best_metric = current_metric

                # Return to training mode
                model.train()

            # Saving
            if config.training.save_steps > 0 and global_step % config.training.save_steps == 0:
                save_checkpoint(
-                    model, optimizer, lr_scheduler, current_epoch, global_step, samples_seen, best_metric,
-                    full_output_dir, config.training.save_total_limit
+                    model, optimizer, lr_scheduler, current_epoch, global_step, samples_seen, best_metric, full_output_dir, config.training.save_total_limit
                )

        # Check if we've reached our training limit
        if samples_seen >= max_train_samples or global_step >= max_train_steps:
            break

    # Close progress bar
    pbar.close()

    # Save the final checkpoint with step number
    logger.info(f"Saving final checkpoint at step {global_step}...")
-    save_checkpoint(
-        model, optimizer, lr_scheduler, current_epoch, global_step, samples_seen, best_metric,
-        full_output_dir, config.training.save_total_limit
-    )
+    save_checkpoint(model, optimizer, lr_scheduler, current_epoch, global_step, samples_seen, best_metric, full_output_dir, config.training.save_total_limit)

    # Log final training state
    final_epoch = samples_seen / len(train_dataset)
    logger.info(f"Training completed at epoch {final_epoch:.3f}, step {global_step}, samples {samples_seen}")

    # Final evaluation
    final_metrics = evaluate_model(model, eval_dataloaders, device)
    logger.info(f"Final evaluation metrics: {final_metrics}")
@@ -601,4 +588,4 @@ def main():

if __name__ == "__main__":
    main()

View File

@@ -171,7 +171,6 @@ class WorkQueue:
        logger.info(f"Initialized queue with {self.size:,} work items")
        return self.size

    async def get_work(self, worker_lock_timeout_secs: int = 1800) -> Optional[WorkItem]:
        """
        Get the next available work item that isn't completed or locked.
@@ -179,7 +178,6 @@ class WorkQueue:
        REFRESH_COMPLETED_HASH_CACHE_MAX_ATTEMPTS = 3
        refresh_completed_hash_attempt = 0

        while True:
            try:
                work_item = self._queue.get_nowait()
@@ -221,7 +219,7 @@ class WorkQueue:
        """
        # Create done flag in done_flags_dir
        await self.backend.create_done_flag(work_item.hash)

        # Remove the worker lock
        await self.backend.delete_worker_lock(work_item.hash)
        self._queue.task_done()
@@ -281,11 +279,7 @@ class LocalBackend(Backend):
        def _list_completed() -> Set[str]:
            if not os.path.isdir(self._done_flags_dir):
                return set()
-            return {
-                f[len("done_") : -len(".flag")]
-                for f in os.listdir(self._done_flags_dir)
-                if f.startswith("done_") and f.endswith(".flag")
-            }
+            return {f[len("done_") : -len(".flag")] for f in os.listdir(self._done_flags_dir) if f.startswith("done_") and f.endswith(".flag")}

        return await asyncio.to_thread(_list_completed)
@@ -299,6 +293,7 @@ class LocalBackend(Backend):
    async def _get_object_mtime(self, path: str) -> Optional[datetime.datetime]:
        """Internal method to get object mtime."""

        def _get_mtime() -> Optional[datetime.datetime]:
            if not os.path.exists(path):
                return None
@@ -310,17 +305,17 @@ class LocalBackend(Backend):
        """Check if a worker lock is taken and not stale."""
        lock_path = self._get_worker_lock_path(work_hash)
        lock_mtime = await self._get_object_mtime(lock_path)

        if not lock_mtime:
            return False

        now = datetime.datetime.now(datetime.timezone.utc)
        return (now - lock_mtime).total_seconds() <= worker_lock_timeout_secs

    async def create_worker_lock(self, work_hash: str) -> None:
        """Create a worker lock for a work hash."""
        lock_path = self._get_worker_lock_path(work_hash)

        def _create() -> None:
            with open(lock_path, "wb"):
                pass
@@ -330,7 +325,7 @@ class LocalBackend(Backend):
    async def delete_worker_lock(self, work_hash: str) -> None:
        """Delete the worker lock for a work hash if it exists."""
        lock_path = self._get_worker_lock_path(work_hash)

        def _delete() -> None:
            if os.path.exists(lock_path):
                os.remove(lock_path)
@@ -345,7 +340,7 @@ class LocalBackend(Backend):
    async def create_done_flag(self, work_hash: str) -> None:
        """Create a done flag for a work hash."""
        done_flag_path = self._get_done_flag_path(work_hash)

        def _create() -> None:
            with open(done_flag_path, "wb"):
                pass
@@ -406,10 +401,10 @@ class S3Backend(Backend):
        """Check if a worker lock is taken and not stale."""
        lock_path = self._get_worker_lock_path(work_hash)
        lock_mtime = await self._get_object_mtime(lock_path)

        if not lock_mtime:
            return False

        now = datetime.datetime.now(datetime.timezone.utc)
        return (now - lock_mtime).total_seconds() <= worker_lock_timeout_secs
@@ -434,4 +429,4 @@ class S3Backend(Backend):
        """Create a done flag for a work hash."""
        done_flag_path = self._get_done_flag_path(work_hash)
        bucket, key = parse_s3_path(done_flag_path)
        await asyncio.to_thread(self.s3_client.put_object, Bucket=bucket, Key=key, Body=b"")

View File

@@ -6,7 +6,7 @@ from unittest.mock import Mock, patch
from botocore.exceptions import ClientError

# Import the classes we're testing
-from olmocr.work_queue import WorkQueue, S3Backend, WorkItem
+from olmocr.work_queue import S3Backend, WorkItem, WorkQueue


class TestS3WorkQueue(unittest.TestCase):
@@ -214,7 +214,7 @@ class TestS3WorkQueue(unittest.TestCase):
        self.assertEqual(len(put_calls), 1)
        done_flag_key = put_calls[0][1]["Key"]
        self.assertTrue(done_flag_key.endswith(f"done_{work_item.hash}.flag"))

        # Verify lock file was deleted
        self.s3_client.delete_object.assert_called_once()
        key = self.s3_client.delete_object.call_args[1]["Key"]