Lint fixes

Jake Poznanski 2025-08-13 20:21:04 +00:00
parent 05330150ad
commit 93411a80a0
8 changed files with 157 additions and 194 deletions

View File

@@ -49,7 +49,7 @@ from olmocr.s3_utils import (
 )
 from olmocr.train.dataloader import FrontMatterParser
 from olmocr.version import VERSION
-from olmocr.work_queue import WorkQueue, LocalBackend, S3Backend
+from olmocr.work_queue import LocalBackend, S3Backend, WorkQueue
 # Initialize logger
 logger = logging.getLogger(__name__)

View File

@@ -394,9 +394,7 @@ class Config:
 steps.append(FrontMatterParser(front_matter_class=front_matter_class))
 elif step_name == "PDFRenderer":
-steps.append(
-    PDFRenderer(target_longest_image_dim=step_config.get("target_longest_image_dim", 1024))
-)
+steps.append(PDFRenderer(target_longest_image_dim=step_config.get("target_longest_image_dim", 1024)))
 elif step_name == "StaticLengthDocumentAnchoring":
 steps.append(StaticLengthDocumentAnchoring(target_anchor_text_len=step_config.get("target_anchor_text_len", 6000)))
@@ -417,9 +415,7 @@ class Config:
 steps.append(FrontMatterOutputFormat())
 elif step_name == "InstructUserMessages":
-steps.append(InstructUserMessages(
-    prompt_first=step_config.get("prompt_first", False)
-))
+steps.append(InstructUserMessages(prompt_first=step_config.get("prompt_first", False)))
 elif step_name == "LatexBracketNormalizer":
 steps.append(LatexBracketNormalizer())
@@ -462,18 +458,10 @@ class Config:
 steps.append(FilterOutRotatedDocuments())
 elif step_name == "RotationAugmentation":
-steps.append(
-    RotationAugmentation(
-        probability=step_config.get("probability", 0.5)
-    )
-)
+steps.append(RotationAugmentation(probability=step_config.get("probability", 0.5)))
 elif step_name == "AugraphyBasicAugmentations":
-steps.append(
-    AugraphyBasicAugmentations(
-        probability=step_config.get("probability", 0.5)
-    )
-)
+steps.append(AugraphyBasicAugmentations(probability=step_config.get("probability", 0.5)))
 else:
 raise ValueError(f"Unknown pipeline step: {step_name}")

View File

@@ -5,13 +5,11 @@ import re
 from abc import ABC, abstractmethod
 from concurrent.futures import ProcessPoolExecutor, as_completed
 from dataclasses import dataclass, fields
-from functools import reduce
 from io import BytesIO
 from os import PathLike
 from pathlib import Path
 from typing import (
     Any,
-    Callable,
     Dict,
     List,
     Optional,
@@ -144,7 +142,7 @@ class BaseMarkdownPDFDataset(Dataset):
 pbar.update(1)
 # Sort samples by markdown path for consistent ordering across runs
-self.samples.sort(key=lambda x: x['markdown_path'])
+self.samples.sort(key=lambda x: x["markdown_path"])
 logger.info(f"Found {valid_count} valid markdown-PDF pairs")
@@ -551,15 +549,14 @@ class AugraphyBasicAugmentations(PipelineStep):
 # Apply a basic augraphy pipeline
 from augraphy import (
     AugraphyPipeline,
+    Brightness,
     InkBleed,
+    InkMottling,
+    InkShifter,
+    Jpeg,
     LowInkPeriodicLines,
     LowInkRandomLines,
     OneOf,
-    Jpeg,
-    InkMottling,
-    InkShifter,
-    Brightness,
 )
 # Apply geometric transformations first, maintaing scale
@@ -590,25 +587,20 @@
 (width, height),
 flags=cv2.INTER_LINEAR,
 borderMode=cv2.BORDER_CONSTANT,
-borderValue=(255, 255, 255)  # White background for documents
+borderValue=(255, 255, 255),  # White background for documents
 )
 ink_phase = [
 OneOf([InkBleed(p=1), LowInkRandomLines(p=1), LowInkPeriodicLines(p=1), InkMottling(p=1), InkShifter(p=1, text_shift_scale_range=(10, 15))], p=0.2),
 ]
-paper_phase = [
-    OneOf([Brightness(p=0.2), Jpeg(p=1)])
-]
+paper_phase = [OneOf([Brightness(p=0.2), Jpeg(p=1)])]
 post_phase = [
 # Empty on purpose or else augmentations are too strong
 ]
-augmentation_pipeline = AugraphyPipeline(
-    ink_phase=ink_phase, paper_phase=paper_phase, post_phase=post_phase
-)
+augmentation_pipeline = AugraphyPipeline(ink_phase=ink_phase, paper_phase=paper_phase, post_phase=post_phase)
 # Apply augmentations
 augmented_image_bgr = augmentation_pipeline(image_bgr)
@@ -621,12 +613,11 @@ class AugraphyBasicAugmentations(PipelineStep):
 sample["image"] = augmented_image_pil
 # Double-check PIL image size matches original
-assert augmented_image_pil.size == image.size, (
-    f"PIL image size changed during augmentation: {image.size} -> {augmented_image_pil.size}"
-)
+assert augmented_image_pil.size == image.size, f"PIL image size changed during augmentation: {image.size} -> {augmented_image_pil.size}"
 return sample
 @dataclass(frozen=True, slots=True)
 class InstructUserMessages(PipelineStep):
 """Creates instruction-following messages format for training."""

View File

@@ -39,7 +39,7 @@ def muon_update(grad, momentum, beta=0.95, ns_steps=5, nesterov=True):
 if update.ndim == 4:  # for the case of conv filters
 update = update.view(len(update), -1)
 update = zeropower_via_newtonschulz5(update, steps=ns_steps)
-update *= max(1, grad.size(-2) / grad.size(-1))**0.5
+update *= max(1, grad.size(-2) / grad.size(-1)) ** 0.5
 return update
@@ -64,6 +64,7 @@ class Muon(torch.optim.Optimizer):
 weight_decay: The AdamW-style weight decay.
 momentum: The momentum. A value of 0.95 here is usually fine.
 """
 def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95):
 defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum)
 assert isinstance(params, list) and len(params) >= 1 and isinstance(params[0], torch.nn.Parameter)
@@ -81,7 +82,7 @@ class Muon(torch.optim.Optimizer):
 for group in self.param_groups:
 params = group["params"]
 params_pad = params + [torch.empty_like(params[-1])] * (dist.get_world_size() - len(params) % dist.get_world_size())
-for base_i in range(len(params))[::dist.get_world_size()]:
+for base_i in range(len(params))[:: dist.get_world_size()]:
 if base_i + dist.get_rank() < len(params):
 p = params[base_i + dist.get_rank()]
 if p.grad is None:
@@ -93,7 +94,7 @@ class Muon(torch.optim.Optimizer):
 update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
 p.mul_(1 - group["lr"] * group["weight_decay"])
 p.add_(update.reshape(p.shape), alpha=-group["lr"])
-dist.all_gather(params_pad[base_i:base_i + dist.get_world_size()], params_pad[base_i + dist.get_rank()])
+dist.all_gather(params_pad[base_i : base_i + dist.get_world_size()], params_pad[base_i + dist.get_rank()])
 return loss
@@ -102,6 +103,7 @@ class SingleDeviceMuon(torch.optim.Optimizer):
 """
 Muon variant for usage in non-distributed settings.
 """
 def __init__(self, params, lr=0.02, weight_decay=0, momentum=0.95):
 defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum)
 super().__init__(params, defaults)
@@ -132,8 +134,8 @@
 def adam_update(grad, buf1, buf2, step, betas, eps):
 buf1.lerp_(grad, 1 - betas[0])
 buf2.lerp_(grad.square(), 1 - betas[1])
-buf1c = buf1 / (1 - betas[0]**step)
-buf2c = buf2 / (1 - betas[1]**step)
+buf1c = buf1 / (1 - betas[0] ** step)
+buf2c = buf2 / (1 - betas[1] ** step)
 return buf1c / (buf2c.sqrt() + eps)
@@ -164,6 +166,7 @@ class MuonWithAuxAdam(torch.optim.Optimizer):
 optimizer = MuonWithAuxAdam(param_groups)
 ```
 """
 def __init__(self, param_groups):
 for group in param_groups:
 assert "use_muon" in group
@@ -195,7 +198,7 @@
 if group["use_muon"]:
 params = group["params"]
 params_pad = params + [torch.empty_like(params[-1])] * (dist.get_world_size() - len(params) % dist.get_world_size())
-for base_i in range(len(params))[::dist.get_world_size()]:
+for base_i in range(len(params))[:: dist.get_world_size()]:
 if base_i + dist.get_rank() < len(params):
 p = params[base_i + dist.get_rank()]
 if p.grad is None:
@@ -207,7 +210,7 @@
 update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
 p.mul_(1 - group["lr"] * group["weight_decay"])
 p.add_(update.reshape(p.shape), alpha=-group["lr"])
-dist.all_gather(params_pad[base_i:base_i + dist.get_world_size()], params_pad[base_i + dist.get_rank()])
+dist.all_gather(params_pad[base_i : base_i + dist.get_world_size()], params_pad[base_i + dist.get_rank()])
 else:
 for p in group["params"]:
 if p.grad is None:
@@ -219,8 +222,7 @@
 state["exp_avg_sq"] = torch.zeros_like(p)
 state["step"] = 0
 state["step"] += 1
-update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"],
-                     state["step"], group["betas"], group["eps"])
+update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"], state["step"], group["betas"], group["eps"])
 p.mul_(1 - group["lr"] * group["weight_decay"])
 p.add_(update, alpha=-group["lr"])
@@ -231,6 +233,7 @@ class SingleDeviceMuonWithAuxAdam(torch.optim.Optimizer):
 """
 Non-distributed variant of MuonWithAuxAdam.
 """
 def __init__(self, param_groups):
 for group in param_groups:
 assert "use_muon" in group
@@ -280,8 +283,7 @@
 state["exp_avg_sq"] = torch.zeros_like(p)
 state["step"] = 0
 state["step"] += 1
-update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"],
-                     state["step"], group["betas"], group["eps"])
+update = adam_update(p.grad, state["exp_avg"], state["exp_avg_sq"], state["step"], group["betas"], group["eps"])
 p.mul_(1 - group["lr"] * group["weight_decay"])
 p.add_(update, alpha=-group["lr"])

View File

@@ -4,26 +4,25 @@ Simple script to test OlmOCR dataset loading with YAML configuration.
 import argparse
 import logging
-import os
 import math
+import os
 import shutil
+from typing import Any, Dict, Optional
 import numpy as np
 import torch
-from torch.utils.data import ConcatDataset, DataLoader
-from torch.optim import AdamW
-from torch.amp import autocast
 import wandb
+from torch.amp import autocast
+from torch.optim import AdamW
+from torch.utils.data import ConcatDataset, DataLoader
 from tqdm import tqdm
 from transformers import (
     AutoProcessor,
-    get_scheduler,
     Qwen2_5_VLForConditionalGeneration,
     Qwen2VLForConditionalGeneration,
+    get_scheduler,
 )
-from typing import Optional, Dict, Any
 from olmocr.train.config import Config
 from olmocr.train.dataloader import BaseMarkdownPDFDataset
 from olmocr.train.muon import SingleDeviceMuonWithAuxAdam
@@ -37,7 +36,6 @@ logging.basicConfig(
 logger = logging.getLogger(__name__)
 class QwenDataCollator:
 """Data collator for vision-language models that handles numpy arrays."""
@@ -128,10 +126,7 @@ def save_checkpoint(
 # Enforce save_total_limit by removing oldest checkpoints
 if save_total_limit is not None and save_total_limit > 0:
-checkpoints = sorted(
-    [d for d in os.listdir(output_dir) if d.startswith("checkpoint-")],
-    key=lambda x: int(x.split("-")[1])
-)
+checkpoints = sorted([d for d in os.listdir(output_dir) if d.startswith("checkpoint-")], key=lambda x: int(x.split("-")[1]))
 while len(checkpoints) > save_total_limit:
 oldest = checkpoints.pop(0)
 shutil.rmtree(os.path.join(output_dir, oldest))
@@ -284,7 +279,6 @@ def main():
 if len(dataset) > 0:
 eval_datasets[dataset_name] = dataset
 # Log total evaluation samples across all datasets
 total_eval_samples = sum(len(dataset) for dataset in eval_datasets.values())
 logger.info(f"Total evaluation samples across {len(eval_datasets)} datasets: {total_eval_samples}")
@@ -316,6 +310,7 @@
 worker_seed = torch.initial_seed() % 2**32
 np.random.seed(worker_seed)
 import random
 random.seed(worker_seed)
 # Create generator for data loader
@@ -370,7 +365,7 @@
 adam_groups = [
 dict(params=head_params, lr=float(config.training.learning_rate) * config.training.muon_lr_multiplier_head, use_muon=False),
 dict(params=embed_params, lr=float(config.training.learning_rate) * config.training.muon_lr_multiplier_embed, use_muon=False),
-dict(params=scalar_params, lr=float(config.training.learning_rate) * config.training.muon_lr_multiplier_scalar, use_muon=False)
+dict(params=scalar_params, lr=float(config.training.learning_rate) * config.training.muon_lr_multiplier_scalar, use_muon=False),
 ]
 # Add Adam hyperparameters to groups
@@ -385,7 +380,7 @@
 lr=float(config.training.learning_rate),
 momentum=config.training.muon_momentum,
 weight_decay=config.training.weight_decay,
-use_muon=True
+use_muon=True,
 )
 # Combine all groups
@@ -527,12 +522,7 @@
 # Update progress bar with current stats
 current_lr = lr_scheduler.get_last_lr()[0]
 avg_loss = accumulated_loss / num_losses_accumulated if num_losses_accumulated > 0 else 0
-pbar.set_postfix({
-    'loss': f'{avg_loss:.4f}',
-    'lr': f'{current_lr:.2e}',
-    'epoch': f'{current_epoch:.2f}',
-    'step': global_step
-})
+pbar.set_postfix({"loss": f"{avg_loss:.4f}", "lr": f"{current_lr:.2e}", "epoch": f"{current_epoch:.2f}", "step": global_step})
 # Logging
 if config.training.logging_steps > 0 and global_step % config.training.logging_steps == 0:
@@ -560,8 +550,9 @@
 # Update best metric
 current_metric = metrics.get(config.training.metric_for_best_model, None)
 if current_metric is not None:
-if (config.training.greater_is_better and current_metric > best_metric) or \
-    (not config.training.greater_is_better and current_metric < best_metric):
+if (config.training.greater_is_better and current_metric > best_metric) or (
+    not config.training.greater_is_better and current_metric < best_metric
+):
 best_metric = current_metric
 # Return to training mode
@@ -570,8 +561,7 @@
 # Saving
 if config.training.save_steps > 0 and global_step % config.training.save_steps == 0:
 save_checkpoint(
-    model, optimizer, lr_scheduler, current_epoch, global_step, samples_seen, best_metric,
-    full_output_dir, config.training.save_total_limit
+    model, optimizer, lr_scheduler, current_epoch, global_step, samples_seen, best_metric, full_output_dir, config.training.save_total_limit
 )
 # Check if we've reached our training limit
@@ -583,10 +573,7 @@
 # Save the final checkpoint with step number
 logger.info(f"Saving final checkpoint at step {global_step}...")
-save_checkpoint(
-    model, optimizer, lr_scheduler, current_epoch, global_step, samples_seen, best_metric,
-    full_output_dir, config.training.save_total_limit
-)
+save_checkpoint(model, optimizer, lr_scheduler, current_epoch, global_step, samples_seen, best_metric, full_output_dir, config.training.save_total_limit)
 # Log final training state
 final_epoch = samples_seen / len(train_dataset)

View File

@@ -171,7 +171,6 @@ class WorkQueue:
 logger.info(f"Initialized queue with {self.size:,} work items")
 return self.size
 async def get_work(self, worker_lock_timeout_secs: int = 1800) -> Optional[WorkItem]:
 """
 Get the next available work item that isn't completed or locked.
@@ -179,7 +178,6 @@
 REFRESH_COMPLETED_HASH_CACHE_MAX_ATTEMPTS = 3
 refresh_completed_hash_attempt = 0
 while True:
 try:
 work_item = self._queue.get_nowait()
@@ -281,11 +279,7 @@ class LocalBackend(Backend):
 def _list_completed() -> Set[str]:
 if not os.path.isdir(self._done_flags_dir):
 return set()
-return {
-    f[len("done_") : -len(".flag")]
-    for f in os.listdir(self._done_flags_dir)
-    if f.startswith("done_") and f.endswith(".flag")
-}
+return {f[len("done_") : -len(".flag")] for f in os.listdir(self._done_flags_dir) if f.startswith("done_") and f.endswith(".flag")}
 return await asyncio.to_thread(_list_completed)
@@ -299,6 +293,7 @@
 async def _get_object_mtime(self, path: str) -> Optional[datetime.datetime]:
 """Internal method to get object mtime."""
 def _get_mtime() -> Optional[datetime.datetime]:
 if not os.path.exists(path):
 return None

View File

@@ -6,7 +6,7 @@ from unittest.mock import Mock, patch
 from botocore.exceptions import ClientError
 # Import the classes we're testing
-from olmocr.work_queue import WorkQueue, S3Backend, WorkItem
+from olmocr.work_queue import S3Backend, WorkItem, WorkQueue
 class TestS3WorkQueue(unittest.TestCase):