Mirror of https://github.com/allenai/olmocr.git (synced 2025-10-29 09:01:35 +00:00)
Lint fixes

commit 93411a80a0, parent 05330150ad

The hunks below are formatting-only: imports are sorted, multi-line calls are collapsed onto single lines, single quotes become double quotes, spacing around `**` and slice operators is normalized, trailing commas are added, and stray blank lines are adjusted. No behavior changes appear in the hunks shown.
@@ -49,7 +49,7 @@ from olmocr.s3_utils import (
 )
 from olmocr.train.dataloader import FrontMatterParser
 from olmocr.version import VERSION
-from olmocr.work_queue import WorkQueue, LocalBackend, S3Backend
+from olmocr.work_queue import LocalBackend, S3Backend, WorkQueue
 
 # Initialize logger
 logger = logging.getLogger(__name__)
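The only change in this hunk is the order of names inside the `from olmocr.work_queue import ...` statement. As a standalone illustration (not part of the commit), the new order is simply an alphabetical sort of the imported names, which is what isort-style import sorters enforce:

```python
# Standalone illustration: the re-ordered import matches a plain sort of the names.
names = ["WorkQueue", "LocalBackend", "S3Backend"]
print(", ".join(sorted(names)))  # LocalBackend, S3Backend, WorkQueue
```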
@@ -394,9 +394,7 @@ class Config:
 steps.append(FrontMatterParser(front_matter_class=front_matter_class))
 
 elif step_name == "PDFRenderer":
-steps.append(
-PDFRenderer(target_longest_image_dim=step_config.get("target_longest_image_dim", 1024))
-)
+steps.append(PDFRenderer(target_longest_image_dim=step_config.get("target_longest_image_dim", 1024)))
 
 elif step_name == "StaticLengthDocumentAnchoring":
 steps.append(StaticLengthDocumentAnchoring(target_anchor_text_len=step_config.get("target_anchor_text_len", 6000)))

@@ -417,9 +415,7 @@ class Config:
 steps.append(FrontMatterOutputFormat())
 
 elif step_name == "InstructUserMessages":
-steps.append(InstructUserMessages(
-prompt_first=step_config.get("prompt_first", False)
-))
+steps.append(InstructUserMessages(prompt_first=step_config.get("prompt_first", False)))
 
 elif step_name == "LatexBracketNormalizer":
 steps.append(LatexBracketNormalizer())

@@ -462,18 +458,10 @@ class Config:
 steps.append(FilterOutRotatedDocuments())
 
 elif step_name == "RotationAugmentation":
-steps.append(
-RotationAugmentation(
-probability=step_config.get("probability", 0.5)
-)
-)
+steps.append(RotationAugmentation(probability=step_config.get("probability", 0.5)))
 
 elif step_name == "AugraphyBasicAugmentations":
-steps.append(
-AugraphyBasicAugmentations(
-probability=step_config.get("probability", 0.5)
-)
-)
+steps.append(AugraphyBasicAugmentations(probability=step_config.get("probability", 0.5)))
 
 else:
 raise ValueError(f"Unknown pipeline step: {step_name}")
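These Config hunks only collapse multi-line calls onto single lines, but they illustrate how pipeline steps are built: each `step_name` selects a class, and per-step options come from `step_config.get(...)` with a default. A minimal sketch of that dispatch pattern follows; the dataclasses and config layout are hypothetical stand-ins for illustration, not olmocr's actual classes or schema.

```python
# Minimal sketch of the step-dispatch pattern visible in the hunks above.
# The classes and config layout are illustrative stand-ins, not olmocr's API.
from dataclasses import dataclass
from typing import Any, Dict, List


@dataclass
class PDFRenderer:
    target_longest_image_dim: int = 1024


@dataclass
class RotationAugmentation:
    probability: float = 0.5


def build_steps(pipeline_config: List[Dict[str, Any]]) -> list:
    steps = []
    for step in pipeline_config:
        step_name = step["name"]
        step_config = step.get("args", {})
        if step_name == "PDFRenderer":
            steps.append(PDFRenderer(target_longest_image_dim=step_config.get("target_longest_image_dim", 1024)))
        elif step_name == "RotationAugmentation":
            steps.append(RotationAugmentation(probability=step_config.get("probability", 0.5)))
        else:
            raise ValueError(f"Unknown pipeline step: {step_name}")
    return steps


print(build_steps([{"name": "PDFRenderer", "args": {"target_longest_image_dim": 1288}}]))
```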
@@ -5,13 +5,11 @@ import re
 from abc import ABC, abstractmethod
 from concurrent.futures import ProcessPoolExecutor, as_completed
 from dataclasses import dataclass, fields
-from functools import reduce
 from io import BytesIO
 from os import PathLike
 from pathlib import Path
 from typing import (
 Any,
-Callable,
 Dict,
 List,
 Optional,

@@ -144,7 +142,7 @@ class BaseMarkdownPDFDataset(Dataset):
 pbar.update(1)
 
 # Sort samples by markdown path for consistent ordering across runs
-self.samples.sort(key=lambda x: x['markdown_path'])
+self.samples.sort(key=lambda x: x["markdown_path"])
 
 logger.info(f"Found {valid_count} valid markdown-PDF pairs")
 
@@ -551,15 +549,14 @@ class AugraphyBasicAugmentations(PipelineStep):
 # Apply a basic augraphy pipeline
 from augraphy import (
 AugraphyPipeline,
+Brightness,
 InkBleed,
+InkMottling,
+InkShifter,
+Jpeg,
 LowInkPeriodicLines,
 LowInkRandomLines,
 OneOf,
-
-Jpeg,
-InkMottling,
-InkShifter,
-Brightness,
 )
 
 # Apply geometric transformations first, maintaing scale

@@ -590,25 +587,20 @@ class AugraphyBasicAugmentations(PipelineStep):
 (width, height),
 flags=cv2.INTER_LINEAR,
 borderMode=cv2.BORDER_CONSTANT,
-borderValue=(255, 255, 255) # White background for documents
+borderValue=(255, 255, 255),  # White background for documents
 )
 
-
 ink_phase = [
 OneOf([InkBleed(p=1), LowInkRandomLines(p=1), LowInkPeriodicLines(p=1), InkMottling(p=1), InkShifter(p=1, text_shift_scale_range=(10, 15))], p=0.2),
 ]
 
-paper_phase = [
-OneOf([Brightness(p=0.2), Jpeg(p=1)])
-]
+paper_phase = [OneOf([Brightness(p=0.2), Jpeg(p=1)])]
 
 post_phase = [
 # Empty on purpose or else augmentations are too strong
 ]
 
-augmentation_pipeline = AugraphyPipeline(
-ink_phase=ink_phase, paper_phase=paper_phase, post_phase=post_phase
-)
+augmentation_pipeline = AugraphyPipeline(ink_phase=ink_phase, paper_phase=paper_phase, post_phase=post_phase)
 
 # Apply augmentations
 augmented_image_bgr = augmentation_pipeline(image_bgr)
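Beyond the import re-ordering and line collapsing, these hunks show the shape of the augmentation step: a standard Augraphy pipeline is assembled from ink/paper/post phases and then called on a BGR image. A minimal, self-contained sketch of that usage (it assumes `augraphy` and `numpy` are installed; the parameters are illustrative, not the repo's exact values):

```python
# Illustrative sketch of the AugraphyPipeline usage shown in the hunks above.
import numpy as np
from augraphy import AugraphyPipeline, Brightness, InkBleed, Jpeg, OneOf

ink_phase = [OneOf([InkBleed(p=1)], p=0.2)]             # ink-level degradations, applied only sometimes
paper_phase = [OneOf([Brightness(p=0.2), Jpeg(p=1)])]   # paper/scan-level degradations
post_phase = []                                         # left empty so the effect stays mild

pipeline = AugraphyPipeline(ink_phase=ink_phase, paper_phase=paper_phase, post_phase=post_phase)

image_bgr = np.full((256, 256, 3), 255, dtype=np.uint8)  # stand-in for a white document page
augmented_bgr = pipeline(image_bgr)                      # same call pattern as in the diff
```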
@@ -621,12 +613,11 @@ class AugraphyBasicAugmentations(PipelineStep):
 sample["image"] = augmented_image_pil
 
 # Double-check PIL image size matches original
-assert augmented_image_pil.size == image.size, (
-f"PIL image size changed during augmentation: {image.size} -> {augmented_image_pil.size}"
-)
+assert augmented_image_pil.size == image.size, f"PIL image size changed during augmentation: {image.size} -> {augmented_image_pil.size}"
 
 return sample
 
+
 @dataclass(frozen=True, slots=True)
 class InstructUserMessages(PipelineStep):
 """Creates instruction-following messages format for training."""
@@ -39,7 +39,7 @@ def muon_update(grad, momentum, beta=0.95, ns_steps=5, nesterov=True):
 if update.ndim == 4: # for the case of conv filters
 update = update.view(len(update), -1)
 update = zeropower_via_newtonschulz5(update, steps=ns_steps)
-update *= max(1, grad.size(-2) / grad.size(-1))**0.5
+update *= max(1, grad.size(-2) / grad.size(-1)) ** 0.5
 return update
 
 
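Only the spacing around `**` changes here, but for readers of the hunk: `muon_update` flattens conv filters to 2-D, orthogonalizes the momentum-smoothed update with a five-step Newton-Schulz iteration, and rescales it by the weight matrix's aspect ratio. Written out (a paraphrase of the lines above, not text from the commit):

$$
U \leftarrow \mathrm{NewtonSchulz}_5(U), \qquad
U \leftarrow U \cdot \sqrt{\max\!\left(1,\ \frac{\mathrm{rows}(G)}{\mathrm{cols}(G)}\right)}
$$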
@@ -64,6 +64,7 @@ class Muon(torch.optim.Optimizer):
 weight_decay: The AdamW-style weight decay.
 momentum: The momentum. A value of 0.95 here is usually fine.
 """
+
 def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95):
 defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum)
 assert isinstance(params, list) and len(params) >= 1 and isinstance(params[0], torch.nn.Parameter)

@@ -81,7 +82,7 @@ class Muon(torch.optim.Optimizer):
 for group in self.param_groups:
 params = group["params"]
 params_pad = params + [torch.empty_like(params[-1])] * (dist.get_world_size() - len(params) % dist.get_world_size())
-for base_i in range(len(params))[::dist.get_world_size()]:
+for base_i in range(len(params))[:: dist.get_world_size()]:
 if base_i + dist.get_rank() < len(params):
 p = params[base_i + dist.get_rank()]
 if p.grad is None:

@@ -93,7 +94,7 @@ class Muon(torch.optim.Optimizer):
 update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
 p.mul_(1 - group["lr"] * group["weight_decay"])
 p.add_(update.reshape(p.shape), alpha=-group["lr"])
-dist.all_gather(params_pad[base_i:base_i + dist.get_world_size()], params_pad[base_i + dist.get_rank()])
+dist.all_gather(params_pad[base_i : base_i + dist.get_world_size()], params_pad[base_i + dist.get_rank()])
 
 return loss
 

@@ -102,6 +103,7 @@ class SingleDeviceMuon(torch.optim.Optimizer):
 """
 Muon variant for usage in non-distributed settings.
 """
+
 def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95):
 defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum)
 super().__init__(params, defaults)

@@ -132,8 +134,8 @@ class SingleDeviceMuon(torch.optim.Optimizer):
 def adam_update(grad, buf1, buf2, step, betas, eps):
 buf1.lerp_(grad, 1 - betas[0])
 buf2.lerp_(grad.square(), 1 - betas[1])
-buf1c = buf1 / (1 - betas[0]**step)
-buf2c = buf2 / (1 - betas[1]**step)
+buf1c = buf1 / (1 - betas[0] ** step)
+buf2c = buf2 / (1 - betas[1] ** step)
 return buf1c / (buf2c.sqrt() + eps)
 
 
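Again only operator spacing changes, but `adam_update` is the standard bias-corrected Adam step with the learning rate applied by the caller. In symbols (paraphrasing the code, with $g_t$ the gradient, $m_t$ and $v_t$ the two buffers, and $t$ the step count):

$$
m_t = \beta_1 m_{t-1} + (1-\beta_1)\,g_t,\qquad
v_t = \beta_2 v_{t-1} + (1-\beta_2)\,g_t^2,\qquad
\hat m_t = \frac{m_t}{1-\beta_1^{\,t}},\qquad
\hat v_t = \frac{v_t}{1-\beta_2^{\,t}},\qquad
\text{update} = \frac{\hat m_t}{\sqrt{\hat v_t}+\epsilon}
$$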
@@ -164,6 +166,7 @@ class MuonWithAuxAdam(torch.optim.Optimizer):
 optimizer = MuonWithAuxAdam(param_groups)
 ```
 """
+
 def __init__(self, param_groups):
 for group in param_groups:
 assert "use_muon" in group

@@ -195,7 +198,7 @@ class MuonWithAuxAdam(torch.optim.Optimizer):
 if group["use_muon"]:
 params = group["params"]
 params_pad = params + [torch.empty_like(params[-1])] * (dist.get_world_size() - len(params) % dist.get_world_size())
-for base_i in range(len(params))[::dist.get_world_size()]:
+for base_i in range(len(params))[:: dist.get_world_size()]:
 if base_i + dist.get_rank() < len(params):
 p = params[base_i + dist.get_rank()]
 if p.grad is None:

@@ -207,7 +210,7 @@ class MuonWithAuxAdam(torch.optim.Optimizer):
 update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
 p.mul_(1 - group["lr"] * group["weight_decay"])
 p.add_(update.reshape(p.shape), alpha=-group["lr"])
-dist.all_gather(params_pad[base_i:base_i + dist.get_world_size()], params_pad[base_i + dist.get_rank()])
+dist.all_gather(params_pad[base_i : base_i + dist.get_world_size()], params_pad[base_i + dist.get_rank()])
 else:
 for p in group["params"]:
 if p.grad is None:

@@ -219,8 +222,7 @@ class MuonWithAuxAdam(torch.optim.Optimizer):
 state["exp_avg_sq"] = torch.zeros_like(p)
 state["step"] = 0
 state["step"] += 1
-update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"],
-state["step"], group["betas"], group["eps"])
+update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"], state["step"], group["betas"], group["eps"])
 p.mul_(1 - group["lr"] * group["weight_decay"])
 p.add_(update, alpha=-group["lr"])
 

@@ -231,6 +233,7 @@ class SingleDeviceMuonWithAuxAdam(torch.optim.Optimizer):
 """
 Non-distributed variant of MuonWithAuxAdam.
 """
+
 def __init__(self, param_groups):
 for group in param_groups:
 assert "use_muon" in group

@@ -280,8 +283,7 @@ class SingleDeviceMuonWithAuxAdam(torch.optim.Optimizer):
 state["exp_avg_sq"] = torch.zeros_like(p)
 state["step"] = 0
 state["step"] += 1
-update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"],
-state["step"], group["betas"], group["eps"])
+update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"], state["step"], group["betas"], group["eps"])
 p.mul_(1 - group["lr"] * group["weight_decay"])
 p.add_(update, alpha=-group["lr"])
 
@@ -4,26 +4,25 @@ Simple script to test OlmOCR dataset loading with YAML configuration.
 
 import argparse
 import logging
-import os
 import math
+import os
 import shutil
+from typing import Any, Dict, Optional
 
 import numpy as np
 import torch
-from torch.utils.data import ConcatDataset, DataLoader
-from torch.optim import AdamW
-from torch.amp import autocast
 import wandb
+from torch.amp import autocast
+from torch.optim import AdamW
+from torch.utils.data import ConcatDataset, DataLoader
 from tqdm import tqdm
 
 from transformers import (
 AutoProcessor,
-get_scheduler,
 Qwen2_5_VLForConditionalGeneration,
 Qwen2VLForConditionalGeneration,
+get_scheduler,
 )
 
-from typing import Optional, Dict, Any
 from olmocr.train.config import Config
 from olmocr.train.dataloader import BaseMarkdownPDFDataset
 from olmocr.train.muon import SingleDeviceMuonWithAuxAdam
@@ -37,7 +36,6 @@ logging.basicConfig(
 logger = logging.getLogger(__name__)
 
 
-
 class QwenDataCollator:
 """Data collator for vision-language models that handles numpy arrays."""
 

@@ -128,10 +126,7 @@ def save_checkpoint(
 
 # Enforce save_total_limit by removing oldest checkpoints
 if save_total_limit is not None and save_total_limit > 0:
-checkpoints = sorted(
-[d for d in os.listdir(output_dir) if d.startswith("checkpoint-")],
-key=lambda x: int(x.split("-")[1])
-)
+checkpoints = sorted([d for d in os.listdir(output_dir) if d.startswith("checkpoint-")], key=lambda x: int(x.split("-")[1]))
 while len(checkpoints) > save_total_limit:
 oldest = checkpoints.pop(0)
 shutil.rmtree(os.path.join(output_dir, oldest))
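The collapsed `sorted(...)` call is behavior-preserving, and the numeric key it keeps is what makes the pruning correct: checkpoint directory names sort badly as plain strings. A standalone illustration:

```python
# Why the numeric key matters: as strings, "checkpoint-1000" sorts before "checkpoint-200".
dirs = ["checkpoint-200", "checkpoint-1000", "checkpoint-50"]
print(sorted(dirs))                                      # ['checkpoint-1000', 'checkpoint-200', 'checkpoint-50']
print(sorted(dirs, key=lambda x: int(x.split("-")[1])))  # ['checkpoint-50', 'checkpoint-200', 'checkpoint-1000']
```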
@@ -284,7 +279,6 @@ def main():
 if len(dataset) > 0:
 eval_datasets[dataset_name] = dataset
 
-
 # Log total evaluation samples across all datasets
 total_eval_samples = sum(len(dataset) for dataset in eval_datasets.values())
 logger.info(f"Total evaluation samples across {len(eval_datasets)} datasets: {total_eval_samples}")

@@ -316,6 +310,7 @@ def main():
 worker_seed = torch.initial_seed() % 2**32
 np.random.seed(worker_seed)
 import random
+
 random.seed(worker_seed)
 
 # Create generator for data loader

@@ -370,7 +365,7 @@ def main():
 adam_groups = [
 dict(params=head_params, lr=float(config.training.learning_rate) * config.training.muon_lr_multiplier_head, use_muon=False),
 dict(params=embed_params, lr=float(config.training.learning_rate) * config.training.muon_lr_multiplier_embed, use_muon=False),
-dict(params=scalar_params, lr=float(config.training.learning_rate) * config.training.muon_lr_multiplier_scalar, use_muon=False)
+dict(params=scalar_params, lr=float(config.training.learning_rate) * config.training.muon_lr_multiplier_scalar, use_muon=False),
 ]
 
 # Add Adam hyperparameters to groups

@@ -385,7 +380,7 @@ def main():
 lr=float(config.training.learning_rate),
 momentum=config.training.muon_momentum,
 weight_decay=config.training.weight_decay,
-use_muon=True
+use_muon=True,
 )
 
 # Combine all groups
@@ -527,12 +522,7 @@ def main():
 # Update progress bar with current stats
 current_lr = lr_scheduler.get_last_lr()[0]
 avg_loss = accumulated_loss / num_losses_accumulated if num_losses_accumulated > 0 else 0
-pbar.set_postfix({
-'loss': f'{avg_loss:.4f}',
-'lr': f'{current_lr:.2e}',
-'epoch': f'{current_epoch:.2f}',
-'step': global_step
-})
+pbar.set_postfix({"loss": f"{avg_loss:.4f}", "lr": f"{current_lr:.2e}", "epoch": f"{current_epoch:.2f}", "step": global_step})
 
 # Logging
 if config.training.logging_steps > 0 and global_step % config.training.logging_steps == 0:

@@ -560,8 +550,9 @@ def main():
 # Update best metric
 current_metric = metrics.get(config.training.metric_for_best_model, None)
 if current_metric is not None:
-if (config.training.greater_is_better and current_metric > best_metric) or \
-(not config.training.greater_is_better and current_metric < best_metric):
+if (config.training.greater_is_better and current_metric > best_metric) or (
+not config.training.greater_is_better and current_metric < best_metric
+):
 best_metric = current_metric
 
 # Return to training mode

@@ -570,8 +561,7 @@ def main():
 # Saving
 if config.training.save_steps > 0 and global_step % config.training.save_steps == 0:
 save_checkpoint(
-model, optimizer, lr_scheduler, current_epoch, global_step, samples_seen, best_metric,
-full_output_dir, config.training.save_total_limit
+model, optimizer, lr_scheduler, current_epoch, global_step, samples_seen, best_metric, full_output_dir, config.training.save_total_limit
 )
 
 # Check if we've reached our training limit
@@ -583,10 +573,7 @@ def main():
 
 # Save the final checkpoint with step number
 logger.info(f"Saving final checkpoint at step {global_step}...")
-save_checkpoint(
-model, optimizer, lr_scheduler, current_epoch, global_step, samples_seen, best_metric,
-full_output_dir, config.training.save_total_limit
-)
+save_checkpoint(model, optimizer, lr_scheduler, current_epoch, global_step, samples_seen, best_metric, full_output_dir, config.training.save_total_limit)
 
 # Log final training state
 final_epoch = samples_seen / len(train_dataset)
@@ -171,7 +171,6 @@ class WorkQueue:
 logger.info(f"Initialized queue with {self.size:,} work items")
 return self.size
 
-
 async def get_work(self, worker_lock_timeout_secs: int = 1800) -> Optional[WorkItem]:
 """
 Get the next available work item that isn't completed or locked.

@@ -179,7 +178,6 @@ class WorkQueue:
 REFRESH_COMPLETED_HASH_CACHE_MAX_ATTEMPTS = 3
 refresh_completed_hash_attempt = 0
 
-
 while True:
 try:
 work_item = self._queue.get_nowait()
@@ -281,11 +279,7 @@ class LocalBackend(Backend):
 def _list_completed() -> Set[str]:
 if not os.path.isdir(self._done_flags_dir):
 return set()
-return {
-f[len("done_") : -len(".flag")]
-for f in os.listdir(self._done_flags_dir)
-if f.startswith("done_") and f.endswith(".flag")
-}
+return {f[len("done_") : -len(".flag")] for f in os.listdir(self._done_flags_dir) if f.startswith("done_") and f.endswith(".flag")}
 
 return await asyncio.to_thread(_list_completed)
 
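The collapsed set comprehension recovers each work-item identifier from its done-flag filename by slicing off the fixed `done_` prefix and `.flag` suffix. A standalone illustration with a made-up filename:

```python
# Slicing a done-flag filename the same way as the comprehension above.
flag = "done_0f3a9c.flag"  # hypothetical example filename
if flag.startswith("done_") and flag.endswith(".flag"):
    print(flag[len("done_") : -len(".flag")])  # -> 0f3a9c
```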
@@ -299,6 +293,7 @@ class LocalBackend(Backend):
 
 async def _get_object_mtime(self, path: str) -> Optional[datetime.datetime]:
 """Internal method to get object mtime."""
+
 def _get_mtime() -> Optional[datetime.datetime]:
 if not os.path.exists(path):
 return None
@@ -6,7 +6,7 @@ from unittest.mock import Mock, patch
 from botocore.exceptions import ClientError
 
 # Import the classes we're testing
-from olmocr.work_queue import WorkQueue, S3Backend, WorkItem
+from olmocr.work_queue import S3Backend, WorkItem, WorkQueue
 
 
 class TestS3WorkQueue(unittest.TestCase):