Refactoring

This commit is contained in:
Jake Poznanski 2025-01-27 20:45:28 +00:00
parent c6062677aa
commit 96ae2dd49b
6 changed files with 662 additions and 598 deletions

1
.gitignore vendored
View File

@ -11,6 +11,7 @@ sample200_vllm/*
sample200_sglang/*
pdelfin_testset/*
/*.html
scoreelo.csv
debug.log
birrpipeline-debug.log
beakerpipeline-debug.log

View File

@ -31,7 +31,7 @@ from typing import Optional, Tuple, List, Dict, Set
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
from concurrent.futures.process import BrokenProcessPool
from olmocr.s3_queue import S3WorkQueue, WorkItem
from olmocr.work_queue import S3WorkQueue, LocalWorkQueue
from olmocr.s3_utils import expand_s3_glob, get_s3_bytes, get_s3_bytes_with_backoff, parse_s3_path, download_zstd_csv, upload_zstd_csv, download_directory
from olmocr.data.renderpdf import render_pdf_to_base64png
from olmocr.filter.filter import PdfFilter, Language
@ -420,7 +420,7 @@ async def worker(args, work_queue: S3WorkQueue, semaphore, worker_id):
try:
async with asyncio.TaskGroup() as tg:
dolma_tasks = [tg.create_task(process_pdf(args, worker_id, pdf)) for pdf in work_item.s3_work_paths]
dolma_tasks = [tg.create_task(process_pdf(args, worker_id, pdf)) for pdf in work_item.work_paths]
logger.info(f"Created all tasks for {work_item.hash}")
logger.info(f"Finished TaskGroup for worker on {work_item.hash}")
@ -834,7 +834,7 @@ def print_stats(args):
async def main():
parser = argparse.ArgumentParser(description='Manager for running millions of PDFs through a batch inference pipeline')
parser.add_argument('workspace', help='The S3 path where work will be done e.g., s3://bucket/prefix/')
parser.add_argument('workspace', help='The filesystem path where work will be stored, can be a local folder, or an s3 path if coordinating work with many workers, s3://bucket/prefix/ ')
parser.add_argument('--pdfs', help='Path to add pdfs stored in s3 to the workspace, can be a glob path s3://bucket/prefix/*.pdf or path to file containing list of pdf paths', default=None)
parser.add_argument('--workspace_profile', help='S3 configuration profile for accessing the workspace', default=None)
parser.add_argument('--pdf_profile', help='S3 configuration profile for accessing the raw pdf documents', default=None)
@ -891,7 +891,10 @@ async def main():
check_poppler_version()
# Create work queue
work_queue = S3WorkQueue(workspace_s3, args.workspace)
if args.workspace.startswith("s3://"):
work_queue = S3WorkQueue(workspace_s3, args.workspace)
else:
work_queue = LocalWorkQueue(args.workspace)
if args.pdfs:
logger.info("Got --pdfs argument, going to add to the work queue")

View File

@ -1,302 +0,0 @@
import os
import random
import logging
import hashlib
import tempfile
import datetime
from typing import Optional, Tuple, List, Dict, Set
from dataclasses import dataclass
import asyncio
from functools import partial
from olmocr.s3_utils import (
expand_s3_glob,
download_zstd_csv,
upload_zstd_csv,
parse_s3_path
)
logger = logging.getLogger(__name__)
@dataclass
class WorkItem:
"""Represents a single work item in the queue"""
hash: str
s3_work_paths: List[str]
class S3WorkQueue:
"""
Manages a work queue stored in S3 that coordinates work across multiple workers.
The queue maintains a list of work items, where each work item is a group of s3 paths
that should be processed together.
Each work item gets a hash, and completed work items will have their results
stored in s3://workspace_path/results/output_[hash].jsonl
This is the ground source of truth about which work items are done.
When a worker takes an item off the queue, it will write an empty s3 file to
s3://workspace_path/worker_locks/output_[hash].jsonl
The queue gets randomized on each worker, so workers pull random work items to operate on.
As you pull an item, we will check to see if it has been completed. If yes,
then it will immediately fetch the next item. If a lock file was created within a configurable
timeout (30 mins by default), then that work item is also skipped.
The lock will will be deleted once the worker is done with that item.
"""
def __init__(self, s3_client, workspace_path: str):
"""
Initialize the work queue.
Args:
s3_client: Boto3 S3 client to use for operations
workspace_path: S3 path where work queue and results are stored
"""
self.s3_client = s3_client
self.workspace_path = workspace_path.rstrip('/')
self._index_path = os.path.join(self.workspace_path, "work_index_list.csv.zstd")
self._output_glob = os.path.join(self.workspace_path, "results", "*.jsonl")
self._queue = asyncio.Queue()
@staticmethod
def _compute_workgroup_hash(s3_work_paths: List[str]) -> str:
"""
Compute a deterministic hash for a group of paths.
Args:
s3_work_paths: List of S3 paths
Returns:
SHA1 hash of the sorted paths
"""
sha1 = hashlib.sha1()
for path in sorted(s3_work_paths):
sha1.update(path.encode('utf-8'))
return sha1.hexdigest()
async def populate_queue(self, s3_work_paths: list[str], items_per_group: int) -> None:
"""
Add new items to the work queue.
Args:
s3_work_paths: Each individual s3 path that we will process over
items_per_group: Number of items to group together in a single work item
"""
all_paths = set(s3_work_paths)
logger.info(f"Found {len(all_paths):,} total paths")
# Load existing work groups
existing_lines = await asyncio.to_thread(download_zstd_csv, self.s3_client, self._index_path)
existing_groups = {}
for line in existing_lines:
if line.strip():
parts = line.strip().split(",")
group_hash = parts[0]
group_paths = parts[1:]
existing_groups[group_hash] = group_paths
existing_path_set = {path for paths in existing_groups.values() for path in paths}
# Find new paths to process
new_paths = all_paths - existing_path_set
logger.info(f"{len(new_paths):,} new paths to add to the workspace")
if not new_paths:
return
# Create new work groups
new_groups = []
current_group = []
for path in sorted(new_paths):
current_group.append(path)
if len(current_group) == items_per_group:
group_hash = self._compute_workgroup_hash(current_group)
new_groups.append((group_hash, current_group))
current_group = []
if current_group:
group_hash = self._compute_workgroup_hash(current_group)
new_groups.append((group_hash, current_group))
logger.info(f"Created {len(new_groups):,} new work groups")
# Combine and save updated work groups
combined_groups = existing_groups.copy()
for group_hash, group_paths in new_groups:
combined_groups[group_hash] = group_paths
combined_lines = [
",".join([group_hash] + group_paths)
for group_hash, group_paths in combined_groups.items()
]
if new_groups:
await asyncio.to_thread(
upload_zstd_csv,
self.s3_client,
self._index_path,
combined_lines
)
async def initialize_queue(self) -> None:
"""
Load the work queue from S3 and initialize it for processing.
Removes already completed work items and randomizes the order.
"""
# Load work items and completed items in parallel
download_task = asyncio.to_thread(
download_zstd_csv,
self.s3_client,
self._index_path
)
expand_task = asyncio.to_thread(
expand_s3_glob,
self.s3_client,
self._output_glob
)
work_queue_lines, done_work_items = await asyncio.gather(download_task, expand_task)
# Process work queue lines
work_queue = {
parts[0]: parts[1:]
for line in work_queue_lines
if (parts := line.strip().split(",")) and line.strip()
}
# Get set of completed work hashes
done_work_hashes = {
os.path.basename(item)[len('output_'):-len('.jsonl')]
for item in done_work_items
if os.path.basename(item).startswith('output_')
and os.path.basename(item).endswith('.jsonl')
}
# Find remaining work and shuffle
remaining_work_hashes = set(work_queue) - done_work_hashes
remaining_items = [
WorkItem(hash=hash_, s3_work_paths=work_queue[hash_])
for hash_ in remaining_work_hashes
]
random.shuffle(remaining_items)
# Initialize queue
self._queue = asyncio.Queue()
for item in remaining_items:
await self._queue.put(item)
logger.info(f"Initialized queue with {self._queue.qsize()} work items")
async def is_completed(self, work_hash: str) -> bool:
"""
Check if a work item has been completed.
Args:
work_hash: Hash of the work item to check
Returns:
True if the work is completed, False otherwise
"""
output_s3_path = os.path.join(self.workspace_path, "results", f"output_{work_hash}.jsonl")
bucket, key = parse_s3_path(output_s3_path)
try:
await asyncio.to_thread(
self.s3_client.head_object,
Bucket=bucket,
Key=key
)
return True
except self.s3_client.exceptions.ClientError:
return False
async def get_work(self, worker_lock_timeout_secs: int = 1800) -> Optional[WorkItem]:
"""
Get the next available work item that isn't completed or locked.
Args:
worker_lock_timeout_secs: Number of seconds before considering a worker lock stale (default 30 mins)
Returns:
WorkItem if work is available, None if queue is empty
"""
while True:
try:
work_item = self._queue.get_nowait()
except asyncio.QueueEmpty:
return None
# Check if work is already completed
if await self.is_completed(work_item.hash):
logger.debug(f"Work item {work_item.hash} already completed, skipping")
self._queue.task_done()
continue
# Check for worker lock
lock_path = os.path.join(self.workspace_path, "worker_locks", f"output_{work_item.hash}.jsonl")
bucket, key = parse_s3_path(lock_path)
try:
response = await asyncio.to_thread(
self.s3_client.head_object,
Bucket=bucket,
Key=key
)
# Check if lock is stale
last_modified = response['LastModified']
if (datetime.datetime.now(datetime.timezone.utc) - last_modified).total_seconds() > worker_lock_timeout_secs:
# Lock is stale, we can take this work
logger.debug(f"Found stale lock for {work_item.hash}, taking work item")
else:
# Lock is active, skip this work
logger.debug(f"Work item {work_item.hash} is locked by another worker, skipping")
self._queue.task_done()
continue
except self.s3_client.exceptions.ClientError:
# No lock exists, we can take this work
pass
# Create our lock file
try:
await asyncio.to_thread(
self.s3_client.put_object,
Bucket=bucket,
Key=key,
Body=b''
)
except Exception as e:
logger.warning(f"Failed to create lock file for {work_item.hash}: {e}")
self._queue.task_done()
continue
return work_item
async def mark_done(self, work_item: WorkItem) -> None:
"""
Mark a work item as done by removing its lock file.
Args:
work_item: The WorkItem to mark as done
"""
lock_path = os.path.join(self.workspace_path, "worker_locks", f"output_{work_item.hash}.jsonl")
bucket, key = parse_s3_path(lock_path)
try:
await asyncio.to_thread(
self.s3_client.delete_object,
Bucket=bucket,
Key=key
)
except Exception as e:
logger.warning(f"Failed to delete lock file for {work_item.hash}: {e}")
self._queue.task_done()
@property
def size(self) -> int:
"""Get current size of work queue"""
return self._queue.qsize()

640
olmocr/work_queue.py Normal file
View File

@ -0,0 +1,640 @@
import os
import random
import logging
import hashlib
import tempfile
import datetime
import asyncio
import abc
from typing import Optional, List, Dict, Set
from dataclasses import dataclass
from functools import partial
logger = logging.getLogger(__name__)
@dataclass
class WorkItem:
"""Represents a single work item in the queue"""
hash: str
work_paths: List[str]
class WorkQueue(abc.ABC):
"""
Base class defining the interface for a work queue.
"""
@abc.abstractmethod
async def populate_queue(self, work_paths: List[str], items_per_group: int) -> None:
"""
Add new items to the work queue. The specifics will vary depending on
whether this is a local or S3-backed queue.
Args:
work_paths: Each individual path that we will process over
items_per_group: Number of items to group together in a single work item
"""
pass
@abc.abstractmethod
async def initialize_queue(self) -> None:
"""
Load the work queue from the relevant store (local or remote)
and initialize it for processing.
For example, this might remove already completed work items and randomize
the order before adding them to an internal queue.
"""
pass
@abc.abstractmethod
async def is_completed(self, work_hash: str) -> bool:
"""
Check if a work item has been completed.
Args:
work_hash: Hash of the work item to check
Returns:
True if the work is completed, False otherwise
"""
pass
@abc.abstractmethod
async def get_work(self, worker_lock_timeout_secs: int = 1800) -> Optional[WorkItem]:
"""
Get the next available work item that isn't completed or locked.
Args:
worker_lock_timeout_secs: Number of seconds before considering
a worker lock stale (default 30 mins)
Returns:
WorkItem if work is available, None if queue is empty
"""
pass
@abc.abstractmethod
async def mark_done(self, work_item: WorkItem) -> None:
"""
Mark a work item as done by removing its lock file
or performing any other cleanup.
Args:
work_item: The WorkItem to mark as done
"""
pass
@property
@abc.abstractmethod
def size(self) -> int:
"""Get current size of work queue"""
pass
@staticmethod
def _compute_workgroup_hash(s3_work_paths: List[str]) -> str:
"""
Compute a deterministic hash for a group of paths.
Args:
s3_work_paths: List of paths (local or S3)
Returns:
SHA1 hash of the sorted paths
"""
sha1 = hashlib.sha1()
for path in sorted(s3_work_paths):
sha1.update(path.encode('utf-8'))
return sha1.hexdigest()
# --------------------------------------------------------------------------------------
# Local Helpers for reading/writing the index CSV (compressed with zstd) to disk
# --------------------------------------------------------------------------------------
try:
import zstandard
except ImportError:
zstandard = None
def download_zstd_csv_local(local_path: str) -> List[str]:
"""
Download a zstd-compressed CSV from a local path.
If the file doesn't exist, returns an empty list.
"""
if not os.path.exists(local_path):
return []
if not zstandard:
raise RuntimeError("zstandard package is required for local zstd CSV operations.")
with open(local_path, 'rb') as f:
dctx = zstandard.ZstdDecompressor()
data = dctx.decompress(f.read())
lines = data.decode('utf-8').splitlines()
return lines
def upload_zstd_csv_local(local_path: str, lines: List[str]) -> None:
"""
Upload a zstd-compressed CSV to a local path.
"""
if not zstandard:
raise RuntimeError("zstandard package is required for local zstd CSV operations.")
data = "\n".join(lines).encode('utf-8')
cctx = zstandard.ZstdCompressor()
compressed_data = cctx.compress(data)
# Ensure parent directories exist
os.makedirs(os.path.dirname(local_path), exist_ok=True)
with open(local_path, 'wb') as f:
f.write(compressed_data)
# --------------------------------------------------------------------------------------
# LocalWorkQueue Implementation
# --------------------------------------------------------------------------------------
class LocalWorkQueue(WorkQueue):
"""
A local in-memory and on-disk WorkQueue implementation, which uses
a local workspace directory to store the queue index, lock files,
and completed results for persistent resumption across process restarts.
"""
def __init__(self, workspace_path: str):
"""
Initialize the local work queue.
Args:
workspace_path: Local directory path where the queue index,
results, and locks are stored.
"""
self.workspace_path = os.path.abspath(workspace_path)
os.makedirs(self.workspace_path, exist_ok=True)
# Local index file (compressed)
self._index_path = os.path.join(self.workspace_path, "work_index_list.csv.zstd")
# Output directory for completed tasks
self._results_dir = os.path.join(self.workspace_path, "results")
os.makedirs(self._results_dir, exist_ok=True)
# Directory for lock files
self._locks_dir = os.path.join(self.workspace_path, "worker_locks")
os.makedirs(self._locks_dir, exist_ok=True)
# Internal queue
self._queue = asyncio.Queue()
async def populate_queue(self, s3_work_paths: List[str], items_per_group: int) -> None:
"""
Add new items to the work queue (local version).
Args:
s3_work_paths: Each individual path (local in this context)
that we will process over
items_per_group: Number of items to group together in a single work item
"""
# Treat them as local paths, but keep variable name for consistency
all_paths = set(s3_work_paths)
logger.info(f"Found {len(all_paths):,} total paths")
# Load existing work groups from local index
existing_lines = await asyncio.to_thread(download_zstd_csv_local, self._index_path)
existing_groups = {}
for line in existing_lines:
if line.strip():
parts = line.strip().split(",")
group_hash = parts[0]
group_paths = parts[1:]
existing_groups[group_hash] = group_paths
existing_path_set = {p for paths in existing_groups.values() for p in paths}
new_paths = all_paths - existing_path_set
logger.info(f"{len(new_paths):,} new paths to add to the workspace")
if not new_paths:
return
# Create new work groups
new_groups = []
current_group = []
for path in sorted(new_paths):
current_group.append(path)
if len(current_group) == items_per_group:
group_hash = self._compute_workgroup_hash(current_group)
new_groups.append((group_hash, current_group))
current_group = []
if current_group:
group_hash = self._compute_workgroup_hash(current_group)
new_groups.append((group_hash, current_group))
logger.info(f"Created {len(new_groups):,} new work groups")
# Combine and save updated work groups
combined_groups = existing_groups.copy()
for group_hash, group_paths in new_groups:
combined_groups[group_hash] = group_paths
combined_lines = [
",".join([group_hash] + group_paths)
for group_hash, group_paths in combined_groups.items()
]
if new_groups:
# Write the combined data back to disk in zstd CSV format
await asyncio.to_thread(upload_zstd_csv_local, self._index_path, combined_lines)
async def initialize_queue(self) -> None:
"""
Load the work queue from the local index file and initialize it for processing.
Removes already completed work items and randomizes the order.
"""
# 1) Read the index
work_queue_lines = await asyncio.to_thread(download_zstd_csv_local, self._index_path)
work_queue = {
parts[0]: parts[1:]
for line in work_queue_lines
if (parts := line.strip().split(",")) and line.strip()
}
# 2) Determine which items are completed by scanning local results/*.jsonl
if not os.path.isdir(self._results_dir):
os.makedirs(self._results_dir, exist_ok=True)
done_work_items = [
f for f in os.listdir(self._results_dir)
if f.startswith("output_") and f.endswith(".jsonl")
]
done_work_hashes = {
fn[len('output_'):-len('.jsonl')]
for fn in done_work_items
}
# 3) Filter out completed items
remaining_work_hashes = set(work_queue) - done_work_hashes
remaining_items = [
WorkItem(hash=hash_, s3_work_paths=work_queue[hash_])
for hash_ in remaining_work_hashes
]
random.shuffle(remaining_items)
# 4) Initialize our in-memory queue
self._queue = asyncio.Queue()
for item in remaining_items:
await self._queue.put(item)
logger.info(f"Initialized local queue with {self._queue.qsize()} work items")
async def is_completed(self, work_hash: str) -> bool:
"""
Check if a work item has been completed locally by seeing if
output_{work_hash}.jsonl is present in the results directory.
Args:
work_hash: Hash of the work item to check
"""
output_file = os.path.join(self._results_dir, f"output_{work_hash}.jsonl")
return os.path.exists(output_file)
async def get_work(self, worker_lock_timeout_secs: int = 1800) -> Optional[WorkItem]:
"""
Get the next available work item that isn't completed or locked.
Args:
worker_lock_timeout_secs: Number of seconds before considering
a worker lock stale (default 30 mins)
Returns:
WorkItem if work is available, None if queue is empty
"""
while True:
try:
work_item = self._queue.get_nowait()
except asyncio.QueueEmpty:
return None
# Check if work is already completed
if await self.is_completed(work_item.hash):
logger.debug(f"Work item {work_item.hash} already completed, skipping")
self._queue.task_done()
continue
# Check for worker lock
lock_file = os.path.join(self._locks_dir, f"output_{work_item.hash}.jsonl")
if os.path.exists(lock_file):
# Check modification time
mtime = datetime.datetime.fromtimestamp(os.path.getmtime(lock_file), datetime.timezone.utc)
if (datetime.datetime.now(datetime.timezone.utc) - mtime).total_seconds() > worker_lock_timeout_secs:
# Lock is stale, we can take this work
logger.debug(f"Found stale lock for {work_item.hash}, taking work item")
else:
# Lock is active, skip this work
logger.debug(f"Work item {work_item.hash} is locked by another worker, skipping")
self._queue.task_done()
continue
# Create our lock file (touch an empty file)
try:
with open(lock_file, "wb") as f:
f.write(b"")
except Exception as e:
logger.warning(f"Failed to create lock file for {work_item.hash}: {e}")
self._queue.task_done()
continue
return work_item
async def mark_done(self, work_item: WorkItem) -> None:
"""
Mark a work item as done by removing its lock file.
Args:
work_item: The WorkItem to mark as done
"""
lock_file = os.path.join(self._locks_dir, f"output_{work_item.hash}.jsonl")
if os.path.exists(lock_file):
try:
os.remove(lock_file)
except Exception as e:
logger.warning(f"Failed to delete lock file for {work_item.hash}: {e}")
self._queue.task_done()
@property
def size(self) -> int:
"""Get current size of local work queue"""
return self._queue.qsize()
# --------------------------------------------------------------------------------------
# S3WorkQueue Implementation (Preserves Original Comments)
# --------------------------------------------------------------------------------------
from olmocr.s3_utils import (
expand_s3_glob,
download_zstd_csv,
upload_zstd_csv,
parse_s3_path
)
class S3WorkQueue(WorkQueue):
"""
Manages a work queue stored in S3 that coordinates work across multiple workers.
The queue maintains a list of work items, where each work item is a group of s3 paths
that should be processed together.
Each work item gets a hash, and completed work items will have their results
stored in s3://workspace_path/results/output_[hash].jsonl
This is the ground source of truth about which work items are done.
When a worker takes an item off the queue, it will write an empty s3 file to
s3://workspace_path/worker_locks/output_[hash].jsonl
The queue gets randomized on each worker, so workers pull random work items to operate on.
As you pull an item, we will check to see if it has been completed. If yes,
then it will immediately fetch the next item. If a lock file was created within a configurable
timeout (30 mins by default), then that work item is also skipped.
The lock will will be deleted once the worker is done with that item.
"""
def __init__(self, s3_client, workspace_path: str):
"""
Initialize the work queue.
Args:
s3_client: Boto3 S3 client to use for operations
workspace_path: S3 path where work queue and results are stored
"""
self.s3_client = s3_client
self.workspace_path = workspace_path.rstrip('/')
self._index_path = os.path.join(self.workspace_path, "work_index_list.csv.zstd")
self._output_glob = os.path.join(self.workspace_path, "results", "*.jsonl")
self._queue = asyncio.Queue()
async def populate_queue(self, s3_work_paths: List[str], items_per_group: int) -> None:
"""
Add new items to the work queue.
Args:
s3_work_paths: Each individual s3 path that we will process over
items_per_group: Number of items to group together in a single work item
"""
all_paths = set(s3_work_paths)
logger.info(f"Found {len(all_paths):,} total paths")
# Load existing work groups
existing_lines = await asyncio.to_thread(download_zstd_csv, self.s3_client, self._index_path)
existing_groups = {}
for line in existing_lines:
if line.strip():
parts = line.strip().split(",")
group_hash = parts[0]
group_paths = parts[1:]
existing_groups[group_hash] = group_paths
existing_path_set = {path for paths in existing_groups.values() for path in paths}
# Find new paths to process
new_paths = all_paths - existing_path_set
logger.info(f"{len(new_paths):,} new paths to add to the workspace")
if not new_paths:
return
# Create new work groups
new_groups = []
current_group = []
for path in sorted(new_paths):
current_group.append(path)
if len(current_group) == items_per_group:
group_hash = self._compute_workgroup_hash(current_group)
new_groups.append((group_hash, current_group))
current_group = []
if current_group:
group_hash = self._compute_workgroup_hash(current_group)
new_groups.append((group_hash, current_group))
logger.info(f"Created {len(new_groups):,} new work groups")
# Combine and save updated work groups
combined_groups = existing_groups.copy()
for group_hash, group_paths in new_groups:
combined_groups[group_hash] = group_paths
combined_lines = [
",".join([group_hash] + group_paths)
for group_hash, group_paths in combined_groups.items()
]
if new_groups:
await asyncio.to_thread(
upload_zstd_csv,
self.s3_client,
self._index_path,
combined_lines
)
async def initialize_queue(self) -> None:
"""
Load the work queue from S3 and initialize it for processing.
Removes already completed work items and randomizes the order.
"""
# Load work items and completed items in parallel
download_task = asyncio.to_thread(
download_zstd_csv,
self.s3_client,
self._index_path
)
expand_task = asyncio.to_thread(
expand_s3_glob,
self.s3_client,
self._output_glob
)
work_queue_lines, done_work_items = await asyncio.gather(download_task, expand_task)
# Process work queue lines
work_queue = {
parts[0]: parts[1:]
for line in work_queue_lines
if (parts := line.strip().split(",")) and line.strip()
}
# Get set of completed work hashes
done_work_hashes = {
os.path.basename(item)[len('output_'):-len('.jsonl')]
for item in done_work_items
if os.path.basename(item).startswith('output_')
and os.path.basename(item).endswith('.jsonl')
}
# Find remaining work and shuffle
remaining_work_hashes = set(work_queue) - done_work_hashes
remaining_items = [
WorkItem(hash=hash_, s3_work_paths=work_queue[hash_])
for hash_ in remaining_work_hashes
]
random.shuffle(remaining_items)
# Initialize queue
self._queue = asyncio.Queue()
for item in remaining_items:
await self._queue.put(item)
logger.info(f"Initialized queue with {self._queue.qsize()} work items")
async def is_completed(self, work_hash: str) -> bool:
"""
Check if a work item has been completed.
Args:
work_hash: Hash of the work item to check
Returns:
True if the work is completed, False otherwise
"""
output_s3_path = os.path.join(self.workspace_path, "results", f"output_{work_hash}.jsonl")
bucket, key = parse_s3_path(output_s3_path)
try:
await asyncio.to_thread(
self.s3_client.head_object,
Bucket=bucket,
Key=key
)
return True
except self.s3_client.exceptions.ClientError:
return False
async def get_work(self, worker_lock_timeout_secs: int = 1800) -> Optional[WorkItem]:
"""
Get the next available work item that isn't completed or locked.
Args:
worker_lock_timeout_secs: Number of seconds before considering a worker lock stale (default 30 mins)
Returns:
WorkItem if work is available, None if queue is empty
"""
while True:
try:
work_item = self._queue.get_nowait()
except asyncio.QueueEmpty:
return None
# Check if work is already completed
if await self.is_completed(work_item.hash):
logger.debug(f"Work item {work_item.hash} already completed, skipping")
self._queue.task_done()
continue
# Check for worker lock
lock_path = os.path.join(self.workspace_path, "worker_locks", f"output_{work_item.hash}.jsonl")
bucket, key = parse_s3_path(lock_path)
try:
response = await asyncio.to_thread(
self.s3_client.head_object,
Bucket=bucket,
Key=key
)
# Check if lock is stale
last_modified = response['LastModified']
if (datetime.datetime.now(datetime.timezone.utc) - last_modified).total_seconds() > worker_lock_timeout_secs:
# Lock is stale, we can take this work
logger.debug(f"Found stale lock for {work_item.hash}, taking work item")
else:
# Lock is active, skip this work
logger.debug(f"Work item {work_item.hash} is locked by another worker, skipping")
self._queue.task_done()
continue
except self.s3_client.exceptions.ClientError:
# No lock exists, we can take this work
pass
# Create our lock file
try:
await asyncio.to_thread(
self.s3_client.put_object,
Bucket=bucket,
Key=key,
Body=b''
)
except Exception as e:
logger.warning(f"Failed to create lock file for {work_item.hash}: {e}")
self._queue.task_done()
continue
return work_item
async def mark_done(self, work_item: WorkItem) -> None:
"""
Mark a work item as done by removing its lock file.
Args:
work_item: The WorkItem to mark as done
"""
lock_path = os.path.join(self.workspace_path, "worker_locks", f"output_{work_item.hash}.jsonl")
bucket, key = parse_s3_path(lock_path)
try:
await asyncio.to_thread(
self.s3_client.delete_object,
Bucket=bucket,
Key=key
)
except Exception as e:
logger.warning(f"Failed to delete lock file for {work_item.hash}: {e}")
self._queue.task_done()
@property
def size(self) -> int:
"""Get current size of work queue"""
return self._queue.qsize()

View File

@ -1,278 +0,0 @@
import unittest
from unittest.mock import MagicMock, patch
import hashlib
import json
import os
import base64
from PIL import Image
# Adjust the import path to match where your code resides
from olmocr.birrpipeline import build_dolma_doc, DatabaseManager, build_finetuning_prompt, build_page_query
class TestBuildDolmaDoc(unittest.TestCase):
@patch('olmocr.birrpipeline.DatabaseManager')
@patch('olmocr.birrpipeline.get_s3_bytes')
def test_build_dolma_doc_with_multiple_page_entries(self, mock_get_s3_bytes, mock_DatabaseManager):
# Mock DatabaseManager instance
mock_db_instance = MagicMock()
mock_DatabaseManager.return_value = mock_db_instance
# Define the PDF record
pdf_s3_path = 's3://bucket/pdf/test.pdf'
pdf = DatabaseManager.PDFRecord(s3_path=pdf_s3_path, num_pages=1, status='pending')
# Create multiple BatchInferenceRecord entries for page_num=1
entry_a = DatabaseManager.BatchInferenceRecord(
inference_s3_path='s3://bucket/inference/output1.jsonl',
pdf_s3_path=pdf_s3_path,
page_num=1,
round=0,
start_index=0,
length=100,
finish_reason='stop',
error=None
)
entry_b = DatabaseManager.BatchInferenceRecord(
inference_s3_path='s3://bucket/inference/output2.jsonl',
pdf_s3_path=pdf_s3_path,
page_num=1,
round=0,
start_index=0,
length=100,
finish_reason='stop',
error=None
)
entry_c = DatabaseManager.BatchInferenceRecord(
inference_s3_path='s3://bucket/inference/output3.jsonl',
pdf_s3_path=pdf_s3_path,
page_num=1,
round=0,
start_index=0,
length=100,
finish_reason='stop',
error=None
)
entry_d = DatabaseManager.BatchInferenceRecord(
inference_s3_path='s3://bucket/inference/output4.jsonl',
pdf_s3_path=pdf_s3_path,
page_num=1,
round=0,
start_index=0,
length=100,
finish_reason='stop',
error=None
)
# Set up mock_db_instance.get_index_entries to return all entries
mock_db_instance.get_index_entries.return_value = [entry_a, entry_b, entry_c, entry_d]
# Define get_s3_bytes side effect function
def get_s3_bytes_side_effect(s3_client, s3_path, start_index=None, end_index=None):
if s3_path == 's3://bucket/inference/output1.jsonl':
inner_data = {
"primary_language": "en",
"is_rotation_valid": True,
"rotation_correction": 0,
"is_table": False,
"is_diagram": False,
"natural_text": "Short Text"
}
data = {
"custom_id": f"{pdf_s3_path}-1",
"outputs": [{"text": json.dumps(inner_data)}],
"round": 0
}
elif s3_path == 's3://bucket/inference/output2.jsonl':
inner_data = {
"primary_language": "en",
"is_rotation_valid": False,
"rotation_correction": 90,
"is_table": True,
"is_diagram": False,
"natural_text": "Very Long Text Here that is longer"
}
data = {
"custom_id": f"{pdf_s3_path}-1",
"outputs": [{"text": json.dumps(inner_data)}],
"round": 0
}
elif s3_path == 's3://bucket/inference/output3.jsonl':
inner_data = {
"primary_language": "en",
"is_rotation_valid": True,
"rotation_correction": 0,
"is_table": False,
"is_diagram": True,
"natural_text": "Medium Length Text"
}
data = {
"custom_id": f"{pdf_s3_path}-1",
"outputs": [{"text": json.dumps(inner_data)}],
"round": 0
}
elif s3_path == 's3://bucket/inference/output4.jsonl':
inner_data = {
"primary_language": "en",
"is_rotation_valid": True,
"rotation_correction": 0,
"is_table": False,
"is_diagram": False,
"natural_text": "The Longest Correct Text"
}
data = {
"custom_id": f"{pdf_s3_path}-1",
"outputs": [{"text": json.dumps(inner_data)}],
"round": 0
}
else:
data = {}
line = json.dumps(data) + '\n'
content_bytes = line.encode('utf-8')
return content_bytes
mock_get_s3_bytes.side_effect = get_s3_bytes_side_effect
# Call build_dolma_doc
s3_workspace = 's3://bucket/workspace'
dolma_doc = build_dolma_doc(s3_workspace, pdf)
# Check that the resulting dolma_doc has the expected document_text
expected_text = 'The Longest Correct Text\n'
self.assertIsNotNone(dolma_doc)
self.assertEqual(dolma_doc['text'], expected_text)
# Additional assertions to ensure that the correct page was selected
self.assertEqual(dolma_doc['metadata']['Source-File'], pdf_s3_path)
self.assertEqual(dolma_doc['metadata']['pdf-total-pages'], 1)
self.assertEqual(len(dolma_doc['attributes']['pdf_page_numbers']), 1)
self.assertEqual(dolma_doc['attributes']['pdf_page_numbers'][0][2], 1)
# Ensure that the document ID is correctly computed
expected_id = hashlib.sha1(expected_text.encode()).hexdigest()
self.assertEqual(dolma_doc['id'], expected_id)
class TestBuildPageQuery(unittest.TestCase):
def testNotParsing(self):
file = os.path.join(
os.path.dirname(__file__),
"gnarly_pdfs",
"not_parsing.pdf"
)
for page in range(1,9):
query = build_page_query(file, "not_parsing.pdf", page, 1024, 6000)
print(query)
def testNotParsing2(self):
file = os.path.join(
os.path.dirname(__file__),
"gnarly_pdfs",
"not_parsing2.pdf"
)
for page in range(1,10):
query = build_page_query(file, "not_parsing2.pdf", page, 1024, 6000)
print(query)
def testNotParsingHugeMemoryUsage(self):
file = os.path.join(
os.path.dirname(__file__),
"gnarly_pdfs",
"failing_pdf_pg9.pdf"
)
print("Starting to parse bad pdf")
query = build_page_query(file, "failing_pdf_pg9.pdf", 9, 1024, 6000)
print(query)
def testRotation(self):
# First, generate and save the non-rotated image
query = build_page_query(os.path.join(
os.path.dirname(__file__),
"gnarly_pdfs",
"edgar.pdf"
), "edgar.pdf", 1, 1024, 6000, 0)
# Extract the base64 image from the query
image_content = query["chat_messages"][0]["content"][1]
self.assertEqual(image_content["type"], "image_url")
image_url = image_content["image_url"]["url"]
# Extract base64 string from the data URL
prefix = "data:image/png;base64,"
self.assertTrue(image_url.startswith(prefix))
image_base64 = image_url[len(prefix):]
# Decode the base64 string
image_data = base64.b64decode(image_base64)
# Define the output file path for the non-rotated image
output_image_path = os.path.join(os.path.dirname(__file__), "test_renders", "output_image.png")
# Save the non-rotated image to a file
with open(output_image_path, "wb") as image_file:
image_file.write(image_data)
# Now, generate and save the rotated image (90 degrees clockwise)
query_rotated = build_page_query(os.path.join(
os.path.dirname(__file__),
"gnarly_pdfs",
"edgar.pdf"
), "edgar.pdf", 1, 1024, 6000, 90)
# Extract the base64 image from the rotated query
image_content_rotated = query_rotated["chat_messages"][0]["content"][1]
self.assertEqual(image_content_rotated["type"], "image_url")
image_url_rotated = image_content_rotated["image_url"]["url"]
# Extract base64 string from the data URL for the rotated image
self.assertTrue(image_url_rotated.startswith(prefix))
image_base64_rotated = image_url_rotated[len(prefix):]
# Decode the base64 string for the rotated image
image_data_rotated = base64.b64decode(image_base64_rotated)
# Define the output file path for the rotated image
output_image_rotated_path = os.path.join(os.path.dirname(__file__), "test_renders", "output_image_rotated90.png")
# Save the rotated image to a file
with open(output_image_rotated_path, "wb") as image_file_rotated:
image_file_rotated.write(image_data_rotated)
# Verification Step: Ensure the rotated image is 90 degrees clockwise rotated
# Open both images using PIL
with Image.open(output_image_path) as original_image:
with Image.open(output_image_rotated_path) as rotated_image:
# Compare pixel by pixel
original_pixels = original_image.load()
rotated_pixels = rotated_image.load()
width, height = original_image.size
self.assertEqual(width, rotated_image.size[1])
self.assertEqual(height, rotated_image.size[0])
for x in range(width):
for y in range(height):
self.assertEqual(
original_pixels[x, y], rotated_pixels[height - 1 - y, x],
f"Pixel mismatch at ({x}, {y})"
)
print("Rotation verification passed: The rotated image is correctly rotated 90 degrees clockwise.")
# Run the test
if __name__ == '__main__':
unittest.main()

View File

@ -7,7 +7,7 @@ import hashlib
from typing import List, Dict
# Import the classes we're testing
from olmocr.s3_queue import S3WorkQueue, WorkItem
from olmocr.work_queue import S3WorkQueue, WorkItem
class TestS3WorkQueue(unittest.TestCase):
def setUp(self):
@ -70,8 +70,8 @@ class TestS3WorkQueue(unittest.TestCase):
async def test_populate_queue_new_items(self):
"""Test populating queue with new items"""
# Mock empty existing index
with patch('olmocr.s3_queue.download_zstd_csv', return_value=[]):
with patch('olmocr.s3_queue.upload_zstd_csv') as mock_upload:
with patch('olmocr.work_queue.download_zstd_csv', return_value=[]):
with patch('olmocr.work_queue.upload_zstd_csv') as mock_upload:
await self.work_queue.populate_queue(self.sample_paths, items_per_group=2)
# Verify upload was called with correct data
@ -97,8 +97,8 @@ class TestS3WorkQueue(unittest.TestCase):
existing_hash = S3WorkQueue._compute_workgroup_hash(existing_paths)
existing_line = f"{existing_hash},{existing_paths[0]}"
with patch('olmocr.s3_queue.download_zstd_csv', return_value=[existing_line]):
with patch('olmocr.s3_queue.upload_zstd_csv') as mock_upload:
with patch('olmocr.work_queue.download_zstd_csv', return_value=[existing_line]):
with patch('olmocr.work_queue.upload_zstd_csv') as mock_upload:
await self.work_queue.populate_queue(existing_paths + new_paths, items_per_group=1)
# Verify upload called with both existing and new items
@ -116,8 +116,8 @@ class TestS3WorkQueue(unittest.TestCase):
completed_items = [f"s3://test-bucket/workspace/results/output_{work_hash}.jsonl"]
with patch('olmocr.s3_queue.download_zstd_csv', return_value=[work_line]):
with patch('olmocr.s3_queue.expand_s3_glob', return_value=completed_items):
with patch('olmocr.work_queue.download_zstd_csv', return_value=[work_line]):
with patch('olmocr.work_queue.expand_s3_glob', return_value=completed_items):
await self.work_queue.initialize_queue()
# Queue should be empty since all work is completed
@ -143,7 +143,7 @@ class TestS3WorkQueue(unittest.TestCase):
async def test_get_work(self):
"""Test getting work items"""
# Setup test data
work_item = WorkItem(hash="testhash123", s3_work_paths=["s3://test/file1.pdf"])
work_item = WorkItem(hash="testhash123", work_paths=["s3://test/file1.pdf"])
await self.work_queue._queue.put(work_item)
# Test getting available work
@ -162,7 +162,7 @@ class TestS3WorkQueue(unittest.TestCase):
@async_test
async def test_get_work_completed(self):
"""Test getting work that's already completed"""
work_item = WorkItem(hash="testhash123", s3_work_paths=["s3://test/file1.pdf"])
work_item = WorkItem(hash="testhash123", work_paths=["s3://test/file1.pdf"])
await self.work_queue._queue.put(work_item)
# Simulate completed work
@ -174,7 +174,7 @@ class TestS3WorkQueue(unittest.TestCase):
@async_test
async def test_get_work_locked(self):
"""Test getting work that's locked by another worker"""
work_item = WorkItem(hash="testhash123", s3_work_paths=["s3://test/file1.pdf"])
work_item = WorkItem(hash="testhash123", work_paths=["s3://test/file1.pdf"])
await self.work_queue._queue.put(work_item)
# Simulate active lock
@ -190,7 +190,7 @@ class TestS3WorkQueue(unittest.TestCase):
@async_test
async def test_get_work_stale_lock(self):
"""Test getting work with a stale lock"""
work_item = WorkItem(hash="testhash123", s3_work_paths=["s3://test/file1.pdf"])
work_item = WorkItem(hash="testhash123", work_paths=["s3://test/file1.pdf"])
await self.work_queue._queue.put(work_item)
# Simulate stale lock
@ -206,7 +206,7 @@ class TestS3WorkQueue(unittest.TestCase):
@async_test
async def test_mark_done(self):
"""Test marking work as done"""
work_item = WorkItem(hash="testhash123", s3_work_paths=["s3://test/file1.pdf"])
work_item = WorkItem(hash="testhash123", work_paths=["s3://test/file1.pdf"])
await self.work_queue._queue.put(work_item)
await self.work_queue.mark_done(work_item)
@ -223,10 +223,10 @@ class TestS3WorkQueue(unittest.TestCase):
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
self.loop.run_until_complete(self.work_queue._queue.put(WorkItem(hash="test1", s3_work_paths=["path1"])))
self.loop.run_until_complete(self.work_queue._queue.put(WorkItem(hash="test1", work_paths=["path1"])))
self.assertEqual(self.work_queue.size, 1)
self.loop.run_until_complete(self.work_queue._queue.put(WorkItem(hash="test2", s3_work_paths=["path2"])))
self.loop.run_until_complete(self.work_queue._queue.put(WorkItem(hash="test2", work_paths=["path2"])))
self.assertEqual(self.work_queue.size, 2)
self.loop.close()