Mirror of https://github.com/allenai/olmocr.git (synced 2025-10-31 10:04:26 +00:00)
	Lint fixes
parent 05330150ad
commit 93411a80a0
				| @ -49,7 +49,7 @@ from olmocr.s3_utils import ( | ||||
| ) | ||||
| from olmocr.train.dataloader import FrontMatterParser | ||||
| from olmocr.version import VERSION | ||||
| from olmocr.work_queue import WorkQueue, LocalBackend, S3Backend | ||||
| from olmocr.work_queue import LocalBackend, S3Backend, WorkQueue | ||||
| 
 | ||||
| # Initialize logger | ||||
| logger = logging.getLogger(__name__) | ||||
|  | ||||
| @ -394,9 +394,7 @@ class Config: | ||||
|                 steps.append(FrontMatterParser(front_matter_class=front_matter_class)) | ||||
| 
 | ||||
|             elif step_name == "PDFRenderer": | ||||
|                 steps.append( | ||||
|                     PDFRenderer(target_longest_image_dim=step_config.get("target_longest_image_dim", 1024)) | ||||
|                 ) | ||||
|                 steps.append(PDFRenderer(target_longest_image_dim=step_config.get("target_longest_image_dim", 1024))) | ||||
| 
 | ||||
|             elif step_name == "StaticLengthDocumentAnchoring": | ||||
|                 steps.append(StaticLengthDocumentAnchoring(target_anchor_text_len=step_config.get("target_anchor_text_len", 6000))) | ||||
| @ -417,9 +415,7 @@ class Config: | ||||
|                 steps.append(FrontMatterOutputFormat()) | ||||
| 
 | ||||
|             elif step_name == "InstructUserMessages": | ||||
|                 steps.append(InstructUserMessages( | ||||
|                     prompt_first=step_config.get("prompt_first", False) | ||||
|                 )) | ||||
|                 steps.append(InstructUserMessages(prompt_first=step_config.get("prompt_first", False))) | ||||
| 
 | ||||
|             elif step_name == "LatexBracketNormalizer": | ||||
|                 steps.append(LatexBracketNormalizer()) | ||||
| @ -462,18 +458,10 @@ class Config: | ||||
|                 steps.append(FilterOutRotatedDocuments()) | ||||
| 
 | ||||
|             elif step_name == "RotationAugmentation": | ||||
|                 steps.append( | ||||
|                     RotationAugmentation( | ||||
|                         probability=step_config.get("probability", 0.5) | ||||
|                     ) | ||||
|                 ) | ||||
|                 steps.append(RotationAugmentation(probability=step_config.get("probability", 0.5))) | ||||
| 
 | ||||
|             elif step_name == "AugraphyBasicAugmentations": | ||||
|                 steps.append( | ||||
|                     AugraphyBasicAugmentations( | ||||
|                         probability=step_config.get("probability", 0.5) | ||||
|                     ) | ||||
|                 ) | ||||
|                 steps.append(AugraphyBasicAugmentations(probability=step_config.get("probability", 0.5))) | ||||
| 
 | ||||
|             else: | ||||
|                 raise ValueError(f"Unknown pipeline step: {step_name}") | ||||
|  | ||||
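The hunks above collapse several multi-line step constructions onto single lines; the underlying pattern is a name-keyed dispatch that fills defaults with step_config.get(). Below is a self-contained sketch of that pattern using hypothetical stand-in dataclasses and an assumed {"name": ...} config shape, not the real olmocr step classes or config format.

    from dataclasses import dataclass

    # Hypothetical stand-ins for two of the real pipeline steps; the default
    # values mirror the ones visible in the diff above.
    @dataclass(frozen=True)
    class PDFRenderer:
        target_longest_image_dim: int = 1024

    @dataclass(frozen=True)
    class InstructUserMessages:
        prompt_first: bool = False

    def build_steps(step_configs):
        steps = []
        for step_config in step_configs:
            step_name = step_config["name"]  # assumed config shape, for illustration only
            if step_name == "PDFRenderer":
                steps.append(PDFRenderer(target_longest_image_dim=step_config.get("target_longest_image_dim", 1024)))
            elif step_name == "InstructUserMessages":
                steps.append(InstructUserMessages(prompt_first=step_config.get("prompt_first", False)))
            else:
                raise ValueError(f"Unknown pipeline step: {step_name}")
        return steps

    print(build_steps([{"name": "PDFRenderer", "target_longest_image_dim": 1288}, {"name": "InstructUserMessages"}]))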
| @ -5,13 +5,11 @@ import re | ||||
| from abc import ABC, abstractmethod | ||||
| from concurrent.futures import ProcessPoolExecutor, as_completed | ||||
| from dataclasses import dataclass, fields | ||||
| from functools import reduce | ||||
| from io import BytesIO | ||||
| from os import PathLike | ||||
| from pathlib import Path | ||||
| from typing import ( | ||||
|     Any, | ||||
|     Callable, | ||||
|     Dict, | ||||
|     List, | ||||
|     Optional, | ||||
| @ -144,7 +142,7 @@ class BaseMarkdownPDFDataset(Dataset): | ||||
|                     pbar.update(1) | ||||
| 
 | ||||
|         # Sort samples by markdown path for consistent ordering across runs | ||||
|         self.samples.sort(key=lambda x: x['markdown_path']) | ||||
|         self.samples.sort(key=lambda x: x["markdown_path"]) | ||||
| 
 | ||||
|         logger.info(f"Found {valid_count} valid markdown-PDF pairs") | ||||
| 
 | ||||
| @ -551,15 +549,14 @@ class AugraphyBasicAugmentations(PipelineStep): | ||||
|         # Apply a basic augraphy pipeline | ||||
|         from augraphy import ( | ||||
|             AugraphyPipeline, | ||||
|             Brightness, | ||||
|             InkBleed, | ||||
|             InkMottling, | ||||
|             InkShifter, | ||||
|             Jpeg, | ||||
|             LowInkPeriodicLines, | ||||
|             LowInkRandomLines, | ||||
|             OneOf, | ||||
|              | ||||
|             Jpeg, | ||||
|             InkMottling, | ||||
|             InkShifter, | ||||
|             Brightness, | ||||
|         ) | ||||
| 
 | ||||
|         # Apply geometric transformations first, maintaining scale | ||||
| @ -590,25 +587,20 @@ class AugraphyBasicAugmentations(PipelineStep): | ||||
|                 (width, height), | ||||
|                 flags=cv2.INTER_LINEAR, | ||||
|                 borderMode=cv2.BORDER_CONSTANT, | ||||
|                 borderValue=(255, 255, 255)  # White background for documents | ||||
|                 borderValue=(255, 255, 255),  # White background for documents | ||||
|             ) | ||||
| 
 | ||||
| 
 | ||||
|         ink_phase = [ | ||||
|             OneOf([InkBleed(p=1), LowInkRandomLines(p=1), LowInkPeriodicLines(p=1), InkMottling(p=1), InkShifter(p=1, text_shift_scale_range=(10, 15))], p=0.2), | ||||
|         ] | ||||
| 
 | ||||
|         paper_phase = [ | ||||
|             OneOf([Brightness(p=0.2), Jpeg(p=1)]) | ||||
|         ] | ||||
|         paper_phase = [OneOf([Brightness(p=0.2), Jpeg(p=1)])] | ||||
| 
 | ||||
|         post_phase = [ | ||||
|             # Empty on purpose or else augmentations are too strong | ||||
|         ] | ||||
| 
 | ||||
|         augmentation_pipeline = AugraphyPipeline( | ||||
|             ink_phase=ink_phase, paper_phase=paper_phase, post_phase=post_phase | ||||
|         ) | ||||
|         augmentation_pipeline = AugraphyPipeline(ink_phase=ink_phase, paper_phase=paper_phase, post_phase=post_phase) | ||||
| 
 | ||||
|         # Apply augmentations | ||||
|         augmented_image_bgr = augmentation_pipeline(image_bgr) | ||||
| @ -621,12 +613,11 @@ class AugraphyBasicAugmentations(PipelineStep): | ||||
|         sample["image"] = augmented_image_pil | ||||
| 
 | ||||
|         # Double-check PIL image size matches original | ||||
|         assert augmented_image_pil.size == image.size, ( | ||||
|             f"PIL image size changed during augmentation: {image.size} -> {augmented_image_pil.size}" | ||||
|         ) | ||||
|         assert augmented_image_pil.size == image.size, f"PIL image size changed during augmentation: {image.size} -> {augmented_image_pil.size}" | ||||
| 
 | ||||
|         return sample | ||||
| 
 | ||||
| 
 | ||||
| @dataclass(frozen=True, slots=True) | ||||
| class InstructUserMessages(PipelineStep): | ||||
|     """Creates instruction-following messages format for training.""" | ||||
|  | ||||
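For context on the Augraphy block reformatted above, here is a minimal standalone sketch of the same pipeline. The phase composition and the AugraphyPipeline call mirror the diff; the synthetic white page standing in for a rendered PDF image is an assumption for illustration only.

    import numpy as np
    from augraphy import (
        AugraphyPipeline,
        Brightness,
        InkBleed,
        InkMottling,
        InkShifter,
        Jpeg,
        LowInkPeriodicLines,
        LowInkRandomLines,
        OneOf,
    )

    # Blank white "document" page in BGR layout (placeholder for a rendered PDF)
    image_bgr = np.full((1024, 768, 3), 255, dtype=np.uint8)

    ink_phase = [
        OneOf(
            [InkBleed(p=1), LowInkRandomLines(p=1), LowInkPeriodicLines(p=1), InkMottling(p=1), InkShifter(p=1, text_shift_scale_range=(10, 15))],
            p=0.2,
        ),
    ]
    paper_phase = [OneOf([Brightness(p=0.2), Jpeg(p=1)])]
    post_phase = []  # empty on purpose or else augmentations are too strong

    augmentation_pipeline = AugraphyPipeline(ink_phase=ink_phase, paper_phase=paper_phase, post_phase=post_phase)
    augmented_image_bgr = augmentation_pipeline(image_bgr)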
| @ -64,6 +64,7 @@ class Muon(torch.optim.Optimizer): | ||||
|         weight_decay: The AdamW-style weight decay. | ||||
|         momentum: The momentum. A value of 0.95 here is usually fine. | ||||
|     """ | ||||
| 
 | ||||
|     def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95): | ||||
|         defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) | ||||
|         assert isinstance(params, list) and len(params) >= 1 and isinstance(params[0], torch.nn.Parameter) | ||||
| @ -102,6 +103,7 @@ class SingleDeviceMuon(torch.optim.Optimizer): | ||||
|     """ | ||||
|     Muon variant for usage in non-distributed settings. | ||||
|     """ | ||||
| 
 | ||||
|     def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95): | ||||
|         defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum) | ||||
|         super().__init__(params, defaults) | ||||
| @ -164,6 +166,7 @@ class MuonWithAuxAdam(torch.optim.Optimizer): | ||||
|     optimizer = MuonWithAuxAdam(param_groups) | ||||
|     ``` | ||||
|     """ | ||||
| 
 | ||||
|     def __init__(self, param_groups): | ||||
|         for group in param_groups: | ||||
|             assert "use_muon" in group | ||||
| @ -219,8 +222,7 @@ class MuonWithAuxAdam(torch.optim.Optimizer): | ||||
|                         state["exp_avg_sq"] = torch.zeros_like(p) | ||||
|                         state["step"] = 0 | ||||
|                     state["step"] += 1 | ||||
|                     update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"], | ||||
|                                          state["step"], group["betas"], group["eps"]) | ||||
|                     update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"], state["step"], group["betas"], group["eps"]) | ||||
|                     p.mul_(1 - group["lr"] * group["weight_decay"]) | ||||
|                     p.add_(update, alpha=-group["lr"]) | ||||
| 
 | ||||
| @ -231,6 +233,7 @@ class SingleDeviceMuonWithAuxAdam(torch.optim.Optimizer): | ||||
|     """ | ||||
|     Non-distributed variant of MuonWithAuxAdam. | ||||
|     """ | ||||
| 
 | ||||
|     def __init__(self, param_groups): | ||||
|         for group in param_groups: | ||||
|             assert "use_muon" in group | ||||
| @ -280,8 +283,7 @@ class SingleDeviceMuonWithAuxAdam(torch.optim.Optimizer): | ||||
|                         state["exp_avg_sq"] = torch.zeros_like(p) | ||||
|                         state["step"] = 0 | ||||
|                     state["step"] += 1 | ||||
|                     update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"], | ||||
|                                          state["step"], group["betas"], group["eps"]) | ||||
|                     update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"], state["step"], group["betas"], group["eps"]) | ||||
|                     p.mul_(1 - group["lr"] * group["weight_decay"]) | ||||
|                     p.add_(update, alpha=-group["lr"]) | ||||
| 
 | ||||
|  | ||||
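The MuonWithAuxAdam docstring above builds the optimizer from explicit parameter groups. Below is a hedged sketch for the single-device variant on a toy two-layer model, mirroring the grouping done in train.py further down (hidden weight matrices go to Muon, everything else to Adam). The toy modules and hyperparameter values are illustrative assumptions; the group keys follow those visible in the diff.

    import torch
    from olmocr.train.muon import SingleDeviceMuonWithAuxAdam

    # Toy "body" and "head"; in practice these would be the model's hidden
    # matrices versus its embeddings, head, and scalar parameters.
    body = torch.nn.Linear(64, 64)
    head = torch.nn.Linear(64, 10)

    hidden_matrix_params = [p for p in body.parameters() if p.ndim >= 2]
    adam_params = [p for p in body.parameters() if p.ndim < 2] + list(head.parameters())

    param_groups = [
        dict(params=hidden_matrix_params, lr=0.02, momentum=0.95, weight_decay=0.0, use_muon=True),
        dict(params=adam_params, lr=3e-4, betas=(0.9, 0.95), eps=1e-8, weight_decay=0.0, use_muon=False),
    ]
    optimizer = SingleDeviceMuonWithAuxAdam(param_groups)

    # Standard torch.optim.Optimizer usage
    loss = head(body(torch.randn(8, 64))).square().mean()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()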
| @ -4,26 +4,25 @@ Simple script to test OlmOCR dataset loading with YAML configuration. | ||||
| 
 | ||||
| import argparse | ||||
| import logging | ||||
| import os | ||||
| import math | ||||
| import os | ||||
| import shutil | ||||
| from typing import Any, Dict, Optional | ||||
| 
 | ||||
| import numpy as np | ||||
| import torch | ||||
| from torch.utils.data import ConcatDataset, DataLoader | ||||
| from torch.optim import AdamW | ||||
| from torch.amp import autocast | ||||
| import wandb | ||||
| from torch.amp import autocast | ||||
| from torch.optim import AdamW | ||||
| from torch.utils.data import ConcatDataset, DataLoader | ||||
| from tqdm import tqdm | ||||
| 
 | ||||
| from transformers import ( | ||||
|     AutoProcessor, | ||||
|     get_scheduler, | ||||
|     Qwen2_5_VLForConditionalGeneration, | ||||
|     Qwen2VLForConditionalGeneration, | ||||
|     get_scheduler, | ||||
| ) | ||||
| 
 | ||||
| from typing import Optional, Dict, Any | ||||
| from olmocr.train.config import Config | ||||
| from olmocr.train.dataloader import BaseMarkdownPDFDataset | ||||
| from olmocr.train.muon import SingleDeviceMuonWithAuxAdam | ||||
| @ -37,7 +36,6 @@ logging.basicConfig( | ||||
| logger = logging.getLogger(__name__) | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| class QwenDataCollator: | ||||
|     """Data collator for vision-language models that handles numpy arrays.""" | ||||
| 
 | ||||
| @ -128,10 +126,7 @@ def save_checkpoint( | ||||
| 
 | ||||
|     # Enforce save_total_limit by removing oldest checkpoints | ||||
|     if save_total_limit is not None and save_total_limit > 0: | ||||
|         checkpoints = sorted( | ||||
|             [d for d in os.listdir(output_dir) if d.startswith("checkpoint-")], | ||||
|             key=lambda x: int(x.split("-")[1]) | ||||
|         ) | ||||
|         checkpoints = sorted([d for d in os.listdir(output_dir) if d.startswith("checkpoint-")], key=lambda x: int(x.split("-")[1])) | ||||
|         while len(checkpoints) > save_total_limit: | ||||
|             oldest = checkpoints.pop(0) | ||||
|             shutil.rmtree(os.path.join(output_dir, oldest)) | ||||
| @ -284,7 +279,6 @@ def main(): | ||||
|         if len(dataset) > 0: | ||||
|             eval_datasets[dataset_name] = dataset | ||||
| 
 | ||||
| 
 | ||||
|     # Log total evaluation samples across all datasets | ||||
|     total_eval_samples = sum(len(dataset) for dataset in eval_datasets.values()) | ||||
|     logger.info(f"Total evaluation samples across {len(eval_datasets)} datasets: {total_eval_samples}") | ||||
| @ -316,6 +310,7 @@ def main(): | ||||
|         worker_seed = torch.initial_seed() % 2**32 | ||||
|         np.random.seed(worker_seed) | ||||
|         import random | ||||
| 
 | ||||
|         random.seed(worker_seed) | ||||
| 
 | ||||
|     # Create generator for data loader | ||||
| @ -370,7 +365,7 @@ def main(): | ||||
|         adam_groups = [ | ||||
|             dict(params=head_params, lr=float(config.training.learning_rate) * config.training.muon_lr_multiplier_head, use_muon=False), | ||||
|             dict(params=embed_params, lr=float(config.training.learning_rate) * config.training.muon_lr_multiplier_embed, use_muon=False), | ||||
|             dict(params=scalar_params, lr=float(config.training.learning_rate) * config.training.muon_lr_multiplier_scalar, use_muon=False) | ||||
|             dict(params=scalar_params, lr=float(config.training.learning_rate) * config.training.muon_lr_multiplier_scalar, use_muon=False), | ||||
|         ] | ||||
| 
 | ||||
|         # Add Adam hyperparameters to groups | ||||
| @ -385,7 +380,7 @@ def main(): | ||||
|             lr=float(config.training.learning_rate), | ||||
|             momentum=config.training.muon_momentum, | ||||
|             weight_decay=config.training.weight_decay, | ||||
|             use_muon=True | ||||
|             use_muon=True, | ||||
|         ) | ||||
| 
 | ||||
|         # Combine all groups | ||||
| @ -527,12 +522,7 @@ def main(): | ||||
|                 # Update progress bar with current stats | ||||
|                 current_lr = lr_scheduler.get_last_lr()[0] | ||||
|                 avg_loss = accumulated_loss / num_losses_accumulated if num_losses_accumulated > 0 else 0 | ||||
|                 pbar.set_postfix({ | ||||
|                     'loss': f'{avg_loss:.4f}', | ||||
|                     'lr': f'{current_lr:.2e}', | ||||
|                     'epoch': f'{current_epoch:.2f}', | ||||
|                     'step': global_step | ||||
|                 }) | ||||
|                 pbar.set_postfix({"loss": f"{avg_loss:.4f}", "lr": f"{current_lr:.2e}", "epoch": f"{current_epoch:.2f}", "step": global_step}) | ||||
| 
 | ||||
|                 # Logging | ||||
|                 if config.training.logging_steps > 0 and global_step % config.training.logging_steps == 0: | ||||
| @ -560,8 +550,9 @@ def main(): | ||||
|                     # Update best metric | ||||
|                     current_metric = metrics.get(config.training.metric_for_best_model, None) | ||||
|                     if current_metric is not None: | ||||
|                         if (config.training.greater_is_better and current_metric > best_metric) or \ | ||||
|                            (not config.training.greater_is_better and current_metric < best_metric): | ||||
|                         if (config.training.greater_is_better and current_metric > best_metric) or ( | ||||
|                             not config.training.greater_is_better and current_metric < best_metric | ||||
|                         ): | ||||
|                             best_metric = current_metric | ||||
| 
 | ||||
|                     # Return to training mode | ||||
| @ -570,8 +561,7 @@ def main(): | ||||
|                 # Saving | ||||
|                 if config.training.save_steps > 0 and global_step % config.training.save_steps == 0: | ||||
|                     save_checkpoint( | ||||
|                         model, optimizer, lr_scheduler, current_epoch, global_step, samples_seen, best_metric, | ||||
|                         full_output_dir, config.training.save_total_limit | ||||
|                         model, optimizer, lr_scheduler, current_epoch, global_step, samples_seen, best_metric, full_output_dir, config.training.save_total_limit | ||||
|                     ) | ||||
| 
 | ||||
|             # Check if we've reached our training limit | ||||
| @ -583,10 +573,7 @@ def main(): | ||||
| 
 | ||||
|     # Save the final checkpoint with step number | ||||
|     logger.info(f"Saving final checkpoint at step {global_step}...") | ||||
|     save_checkpoint( | ||||
|         model, optimizer, lr_scheduler, current_epoch, global_step, samples_seen, best_metric, | ||||
|         full_output_dir, config.training.save_total_limit | ||||
|     ) | ||||
|     save_checkpoint(model, optimizer, lr_scheduler, current_epoch, global_step, samples_seen, best_metric, full_output_dir, config.training.save_total_limit) | ||||
| 
 | ||||
|     # Log final training state | ||||
|     final_epoch = samples_seen / len(train_dataset) | ||||
|  | ||||
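The save_checkpoint hunk above enforces save_total_limit by deleting the oldest checkpoint-<step> directories. The same logic pulled out as a self-contained helper for clarity (the prune_checkpoints name is hypothetical; the sort key and removal loop match the reformatted line):

    import os
    import shutil

    def prune_checkpoints(output_dir: str, save_total_limit: int) -> None:
        # Checkpoint directories are named "checkpoint-<step>"; sort them by step number
        checkpoints = sorted(
            [d for d in os.listdir(output_dir) if d.startswith("checkpoint-")],
            key=lambda x: int(x.split("-")[1]),
        )
        # Remove the oldest checkpoints until only save_total_limit remain
        while len(checkpoints) > save_total_limit:
            oldest = checkpoints.pop(0)
            shutil.rmtree(os.path.join(output_dir, oldest))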
| @ -171,7 +171,6 @@ class WorkQueue: | ||||
|         logger.info(f"Initialized queue with {self.size:,} work items") | ||||
|         return self.size | ||||
| 
 | ||||
| 
 | ||||
|     async def get_work(self, worker_lock_timeout_secs: int = 1800) -> Optional[WorkItem]: | ||||
|         """ | ||||
|         Get the next available work item that isn't completed or locked. | ||||
| @ -179,7 +178,6 @@ class WorkQueue: | ||||
|         REFRESH_COMPLETED_HASH_CACHE_MAX_ATTEMPTS = 3 | ||||
|         refresh_completed_hash_attempt = 0 | ||||
| 
 | ||||
| 
 | ||||
|         while True: | ||||
|             try: | ||||
|                 work_item = self._queue.get_nowait() | ||||
| @ -281,11 +279,7 @@ class LocalBackend(Backend): | ||||
|         def _list_completed() -> Set[str]: | ||||
|             if not os.path.isdir(self._done_flags_dir): | ||||
|                 return set() | ||||
|             return { | ||||
|                 f[len("done_") : -len(".flag")] | ||||
|                 for f in os.listdir(self._done_flags_dir) | ||||
|                 if f.startswith("done_") and f.endswith(".flag") | ||||
|             } | ||||
|             return {f[len("done_") : -len(".flag")] for f in os.listdir(self._done_flags_dir) if f.startswith("done_") and f.endswith(".flag")} | ||||
| 
 | ||||
|         return await asyncio.to_thread(_list_completed) | ||||
| 
 | ||||
| @ -299,6 +293,7 @@ class LocalBackend(Backend): | ||||
| 
 | ||||
|     async def _get_object_mtime(self, path: str) -> Optional[datetime.datetime]: | ||||
|         """Internal method to get object mtime.""" | ||||
| 
 | ||||
|         def _get_mtime() -> Optional[datetime.datetime]: | ||||
|             if not os.path.exists(path): | ||||
|                 return None | ||||
|  | ||||
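LocalBackend's reformatted one-liner above scans a directory of done_<hash>.flag files and hands the blocking os.listdir work to asyncio.to_thread so the event loop stays responsive. A minimal standalone sketch of that pattern (the free-standing function and its argument are assumptions; the set comprehension is the one from the diff):

    import asyncio
    import os
    from typing import Set

    async def list_completed(done_flags_dir: str) -> Set[str]:
        def _list_completed() -> Set[str]:
            if not os.path.isdir(done_flags_dir):
                return set()
            # Strip the "done_" prefix and ".flag" suffix to recover the work hash
            return {f[len("done_") : -len(".flag")] for f in os.listdir(done_flags_dir) if f.startswith("done_") and f.endswith(".flag")}

        # Run the blocking filesystem scan in a worker thread
        return await asyncio.to_thread(_list_completed)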
| @ -6,7 +6,7 @@ from unittest.mock import Mock, patch | ||||
| from botocore.exceptions import ClientError | ||||
| 
 | ||||
| # Import the classes we're testing | ||||
| from olmocr.work_queue import WorkQueue, S3Backend, WorkItem | ||||
| from olmocr.work_queue import S3Backend, WorkItem, WorkQueue | ||||
| 
 | ||||
| 
 | ||||
| class TestS3WorkQueue(unittest.TestCase): | ||||
|  | ||||
Jake Poznanski