New work queue code is cleaner

This commit is contained in:
Jake Poznanski 2025-08-13 20:20:27 +00:00
parent 9a8fa335ae
commit 05330150ad
4 changed files with 376 additions and 567 deletions

View File

@ -49,7 +49,7 @@ from olmocr.s3_utils import (
)
from olmocr.train.dataloader import FrontMatterParser
from olmocr.version import VERSION
from olmocr.work_queue import LocalWorkQueue, S3WorkQueue, WorkQueue
from olmocr.work_queue import WorkQueue, LocalBackend, S3Backend
# Initialize logger
logger = logging.getLogger(__name__)
@ -1104,9 +1104,9 @@ async def main():
# Create work queue
if args.workspace.startswith("s3://"):
work_queue = S3WorkQueue(workspace_s3, args.workspace)
work_queue = WorkQueue(S3Backend(workspace_s3, args.workspace))
else:
work_queue = LocalWorkQueue(args.workspace)
work_queue = WorkQueue(LocalBackend(args.workspace))
if args.pdfs:
logger.info("Got --pdfs argument, going to add to the work queue")

View File

@ -7,9 +7,11 @@ import io
import logging
import os
import random
from asyncio import Queue
from asyncio import Queue, QueueEmpty
from dataclasses import dataclass
from typing import Any, List, Optional
from typing import Any, Dict, List, Optional, Set
import zstandard
from olmocr.s3_utils import (
download_zstd_csv,
@ -20,31 +22,76 @@ from olmocr.s3_utils import (
logger = logging.getLogger(__name__)
# Shared directory names for both local and S3 backends
WORKER_LOCKS_DIR = "worker_locks"
DONE_FLAGS_DIR = "done_flags"
@dataclass
class WorkItem:
"""Represents a single work item in the queue"""
"""Represents a single work item in the queue."""
hash: str
work_paths: List[str]
class WorkQueue(abc.ABC):
class Backend(abc.ABC):
"""Abstract backend for storage operations."""
@abc.abstractmethod
async def load_index_lines(self) -> List[str]:
"""Load raw index lines from storage."""
pass
@abc.abstractmethod
async def save_index_lines(self, lines: List[str]) -> None:
"""Save raw index lines to storage."""
pass
@abc.abstractmethod
async def get_completed_hashes(self) -> Set[str]:
"""Get set of completed work hashes."""
pass
@abc.abstractmethod
async def is_completed(self, work_hash: str) -> bool:
"""Check if a work item has been completed."""
pass
@abc.abstractmethod
async def is_worker_lock_taken(self, work_hash: str, worker_lock_timeout_secs: int = 1800) -> bool:
"""Check if a worker lock is taken and not stale."""
pass
@abc.abstractmethod
async def create_worker_lock(self, work_hash: str) -> None:
"""Create a worker lock for a work hash."""
pass
@abc.abstractmethod
async def delete_worker_lock(self, work_hash: str) -> None:
"""Delete the worker lock for a work hash if it exists."""
pass
@abc.abstractmethod
async def create_done_flag(self, work_hash: str) -> None:
"""Create a done flag for a work hash."""
pass
class WorkQueue:
"""
Base class defining the interface for a work queue.
Manages a work queue with pluggable storage backends (e.g., local or S3).
"""
def __init__(self, backend: Backend):
self.backend = backend
self._queue: Queue[WorkItem] = Queue()
self._completed_hash_cache = set()
@staticmethod
def _encode_csv_row(row: List[str]) -> str:
"""
Encodes a row of data for CSV storage with proper escaping.
Args:
row: List of strings to encode
Returns:
CSV-encoded string with proper escaping of commas and quotes
"""
"""Encodes a row of data for CSV storage with proper escaping."""
output = io.StringIO()
writer = csv.writer(output, quoting=csv.QUOTE_MINIMAL)
writer.writerow(row)
@ -52,586 +99,339 @@ class WorkQueue(abc.ABC):
@staticmethod
def _decode_csv_row(line: str) -> List[str]:
"""
Decodes a CSV row with proper unescaping.
Args:
line: CSV-encoded string
Returns:
List of unescaped string values
"""
"""Decodes a CSV row with proper unescaping."""
return next(csv.reader([line]))
@abc.abstractmethod
async def populate_queue(self, work_paths: List[str], items_per_group: int) -> None:
"""
Add new items to the work queue. The specifics will vary depending on
whether this is a local or S3-backed queue.
Args:
work_paths: Each individual path that we will process over
items_per_group: Number of items to group together in a single work item
"""
pass
@abc.abstractmethod
async def initialize_queue(self) -> int:
"""
Load the work queue from the relevant store (local or remote)
and initialize it for processing.
For example, this might remove already completed work items and randomize
the order before adding them to an internal queue.
"""
pass
@abc.abstractmethod
async def is_completed(self, work_hash: str) -> bool:
"""
Check if a work item has been completed.
Args:
work_hash: Hash of the work item to check
Returns:
True if the work is completed, False otherwise
"""
pass
@abc.abstractmethod
async def get_work(self, worker_lock_timeout_secs: int = 1800) -> Optional[WorkItem]:
"""
Get the next available work item that isn't completed or locked.
Args:
worker_lock_timeout_secs: Number of seconds before considering
a worker lock stale (default 30 mins)
Returns:
WorkItem if work is available, None if queue is empty
"""
pass
@abc.abstractmethod
async def mark_done(self, work_item: WorkItem) -> None:
"""
Mark a work item as done by removing its lock file
or performing any other cleanup.
Args:
work_item: The WorkItem to mark as done
"""
pass
@property
@abc.abstractmethod
def size(self) -> int:
"""Get current size of work queue"""
pass
@staticmethod
def _compute_workgroup_hash(work_paths: List[str]) -> str:
"""
Compute a deterministic hash for a group of paths.
Args:
work_paths: List of paths (local or S3)
Returns:
SHA1 hash of the sorted paths
"""
"""Compute a deterministic hash for a group of paths."""
sha1 = hashlib.sha1()
for path in sorted(work_paths):
sha1.update(path.encode("utf-8"))
return sha1.hexdigest()
# --------------------------------------------------------------------------------------
# Local Helpers for reading/writing the index CSV (compressed with zstd) to disk
# --------------------------------------------------------------------------------------
try:
import zstandard
except ImportError:
zstandard = None
def download_zstd_csv_local(local_path: str) -> List[str]:
"""
Download a zstd-compressed CSV from a local path.
If the file doesn't exist, returns an empty list.
"""
if not os.path.exists(local_path):
return []
if not zstandard:
raise RuntimeError("zstandard package is required for local zstd CSV operations.")
with open(local_path, "rb") as f:
dctx = zstandard.ZstdDecompressor()
data = dctx.decompress(f.read())
lines = data.decode("utf-8").splitlines()
return lines
def upload_zstd_csv_local(local_path: str, lines: List[str]) -> None:
"""
Upload a zstd-compressed CSV to a local path.
"""
if not zstandard:
raise RuntimeError("zstandard package is required for local zstd CSV operations.")
data = "\n".join(lines).encode("utf-8")
cctx = zstandard.ZstdCompressor()
compressed_data = cctx.compress(data)
# Ensure parent directories exist
os.makedirs(os.path.dirname(local_path), exist_ok=True)
with open(local_path, "wb") as f:
f.write(compressed_data)
# --------------------------------------------------------------------------------------
# LocalWorkQueue Implementation
# --------------------------------------------------------------------------------------
class LocalWorkQueue(WorkQueue):
"""
A local in-memory and on-disk WorkQueue implementation, which uses
a local workspace directory to store the queue index, lock files,
and completed results for persistent resumption across process restarts.
"""
def __init__(self, workspace_path: str):
"""
Initialize the local work queue.
Args:
workspace_path: Local directory path where the queue index,
results, and locks are stored.
"""
self.workspace_path = os.path.abspath(workspace_path)
os.makedirs(self.workspace_path, exist_ok=True)
# Local index file (compressed)
self._index_path = os.path.join(self.workspace_path, "work_index_list.csv.zstd")
# Output directory for completed tasks
self._results_dir = os.path.join(self.workspace_path, "results")
os.makedirs(self._results_dir, exist_ok=True)
# Directory for lock files
self._locks_dir = os.path.join(self.workspace_path, "worker_locks")
os.makedirs(self._locks_dir, exist_ok=True)
# Internal queue
self._queue: Queue[Any] = Queue()
async def populate_queue(self, work_paths: List[str], items_per_group: int) -> None:
"""
Add new items to the work queue (local version).
Args:
work_paths: Each individual path (local in this context)
that we will process over
items_per_group: Number of items to group together in a single work item
"""
# Treat them as local paths, but keep variable name for consistency
all_paths = set(work_paths)
logger.info(f"Found {len(all_paths):,} total paths")
# Load existing work groups from local index
existing_lines = await asyncio.to_thread(download_zstd_csv_local, self._index_path)
existing_groups = {}
for line in existing_lines:
def _parse_index_lines(self, lines: List[str]) -> Dict[str, List[str]]:
"""Parse index lines into a dict of hash to paths."""
result = {}
for line in lines:
if line.strip():
parts = self._decode_csv_row(line.strip())
if parts: # Ensure we have at least one part
group_hash = parts[0]
group_paths = parts[1:]
existing_groups[group_hash] = group_paths
parts = self._decode_csv_row(line)
if parts:
result[parts[0]] = parts[1:]
return result
existing_path_set = {p for paths in existing_groups.values() for p in paths}
new_paths = all_paths - existing_path_set
logger.info(f"{len(new_paths):,} new paths to add to the workspace")
if not new_paths:
return
# Create new work groups
new_groups = []
current_group = []
for path in sorted(new_paths):
current_group.append(path)
if len(current_group) == items_per_group:
group_hash = self._compute_workgroup_hash(current_group)
new_groups.append((group_hash, current_group))
current_group = []
if current_group:
group_hash = self._compute_workgroup_hash(current_group)
new_groups.append((group_hash, current_group))
logger.info(f"Created {len(new_groups):,} new work groups")
# Combine and save updated work groups
combined_groups = existing_groups.copy()
for group_hash, group_paths in new_groups:
combined_groups[group_hash] = group_paths
# Use proper CSV encoding with escaping for paths that may contain commas
combined_lines = [self._encode_csv_row([group_hash] + group_paths) for group_hash, group_paths in combined_groups.items()]
if new_groups:
# Write the combined data back to disk in zstd CSV format
await asyncio.to_thread(upload_zstd_csv_local, self._index_path, combined_lines)
async def initialize_queue(self) -> int:
"""
Load the work queue from the local index file and initialize it for processing.
Removes already completed work items and randomizes the order.
"""
# 1) Read the index
work_queue_lines = await asyncio.to_thread(download_zstd_csv_local, self._index_path)
work_queue = {}
for line in work_queue_lines:
if line.strip():
parts = self._decode_csv_row(line.strip())
if parts: # Ensure we have at least one part
work_queue[parts[0]] = parts[1:]
# 2) Determine which items are completed by scanning local results/*.jsonl
if not os.path.isdir(self._results_dir):
os.makedirs(self._results_dir, exist_ok=True)
done_work_items = [f for f in os.listdir(self._results_dir) if f.startswith("output_") and f.endswith(".jsonl")]
done_work_hashes = {fn[len("output_") : -len(".jsonl")] for fn in done_work_items}
# 3) Filter out completed items
remaining_work_hashes = set(work_queue) - done_work_hashes
remaining_items = [WorkItem(hash=hash_, work_paths=work_queue[hash_]) for hash_ in remaining_work_hashes]
random.shuffle(remaining_items)
# 4) Initialize our in-memory queue
self._queue = asyncio.Queue()
for item in remaining_items:
await self._queue.put(item)
logger.info(f"Initialized local queue with {self._queue.qsize()} work items")
return self._queue.qsize()
async def is_completed(self, work_hash: str) -> bool:
"""
Check if a work item has been completed locally by seeing if
output_{work_hash}.jsonl is present in the results directory.
Args:
work_hash: Hash of the work item to check
"""
output_file = os.path.join(self._results_dir, f"output_{work_hash}.jsonl")
return os.path.exists(output_file)
async def get_work(self, worker_lock_timeout_secs: int = 1800) -> Optional[WorkItem]:
"""
Get the next available work item that isn't completed or locked.
Args:
worker_lock_timeout_secs: Number of seconds before considering
a worker lock stale (default 30 mins)
Returns:
WorkItem if work is available, None if queue is empty
"""
while True:
try:
work_item = self._queue.get_nowait()
except asyncio.QueueEmpty:
return None
# Check if work is already completed
if await self.is_completed(work_item.hash):
logger.debug(f"Work item {work_item.hash} already completed, skipping")
self._queue.task_done()
continue
# Check for worker lock
lock_file = os.path.join(self._locks_dir, f"output_{work_item.hash}.jsonl")
if os.path.exists(lock_file):
# Check modification time
mtime = datetime.datetime.fromtimestamp(os.path.getmtime(lock_file), datetime.timezone.utc)
if (datetime.datetime.now(datetime.timezone.utc) - mtime).total_seconds() > worker_lock_timeout_secs:
# Lock is stale, we can take this work
logger.debug(f"Found stale lock for {work_item.hash}, taking work item")
else:
# Lock is active, skip this work
logger.debug(f"Work item {work_item.hash} is locked by another worker, skipping")
self._queue.task_done()
continue
# Create our lock file (touch an empty file)
try:
with open(lock_file, "wb") as f:
f.write(b"")
except Exception as e:
logger.warning(f"Failed to create lock file for {work_item.hash}: {e}")
self._queue.task_done()
continue
return work_item
async def mark_done(self, work_item: WorkItem) -> None:
"""
Mark a work item as done by removing its lock file.
Args:
work_item: The WorkItem to mark as done
"""
lock_file = os.path.join(self._locks_dir, f"output_{work_item.hash}.jsonl")
if os.path.exists(lock_file):
try:
os.remove(lock_file)
except Exception as e:
logger.warning(f"Failed to delete lock file for {work_item.hash}: {e}")
self._queue.task_done()
@property
def size(self) -> int:
"""Get current size of local work queue"""
return self._queue.qsize()
# --------------------------------------------------------------------------------------
# S3WorkQueue Implementation
# --------------------------------------------------------------------------------------
class S3WorkQueue(WorkQueue):
"""
Manages a work queue stored in S3 that coordinates work across multiple workers.
The queue maintains a list of work items, where each work item is a group of s3 paths
that should be processed together.
Each work item gets a hash, and completed work items will have their results
stored in s3://workspace_path/results/output_[hash].jsonl
This is the ground source of truth about which work items are done.
When a worker takes an item off the queue, it will write an empty s3 file to
s3://workspace_path/worker_locks/output_[hash].jsonl
The queue gets randomized on each worker, so workers pull random work items to operate on.
As you pull an item, we will check to see if it has been completed. If yes,
then it will immediately fetch the next item. If a lock file was created within a configurable
timeout (30 mins by default), then that work item is also skipped.
The lock will will be deleted once the worker is done with that item.
"""
def __init__(self, s3_client, workspace_path: str):
"""
Initialize the work queue.
Args:
s3_client: Boto3 S3 client to use for operations
workspace_path: S3 path where work queue and results are stored
"""
self.s3_client = s3_client
self.workspace_path = workspace_path.rstrip("/")
self._index_path = os.path.join(self.workspace_path, "work_index_list.csv.zstd")
self._output_glob = os.path.join(self.workspace_path, "results", "*.jsonl")
self._queue: Queue[Any] = Queue()
def _make_index_lines(self, groups: Dict[str, List[str]]) -> List[str]:
"""Create encoded lines from groups dict."""
return [self._encode_csv_row([group_hash] + group_paths) for group_hash, group_paths in groups.items()]
async def populate_queue(self, work_paths: List[str], items_per_group: int) -> None:
"""
Add new items to the work queue.
Args:
work_paths: Each individual s3 path that we will process over
items_per_group: Number of items to group together in a single work item
"""
all_paths = set(work_paths)
logger.info(f"Found {len(all_paths):,} total paths")
# Load existing work groups
existing_lines = await asyncio.to_thread(download_zstd_csv, self.s3_client, self._index_path)
existing_groups = {}
for line in existing_lines:
if line.strip():
parts = self._decode_csv_row(line.strip())
if parts: # Ensure we have at least one part
group_hash = parts[0]
group_paths = parts[1:]
existing_groups[group_hash] = group_paths
lines = await self.backend.load_index_lines()
existing_groups = self._parse_index_lines(lines)
existing_path_set = {path for paths in existing_groups.values() for path in paths}
# Find new paths to process
new_paths = all_paths - existing_path_set
logger.info(f"{len(new_paths):,} new paths to add to the workspace")
existing_path_set = {p for paths in existing_groups.values() for p in paths}
new_paths = sorted(all_paths - existing_path_set)
if not new_paths:
return
# Create new work groups
new_groups = []
current_group = []
for path in sorted(new_paths):
current_group.append(path)
if len(current_group) == items_per_group:
group_hash = self._compute_workgroup_hash(current_group)
new_groups.append((group_hash, current_group))
current_group = []
if current_group:
group_hash = self._compute_workgroup_hash(current_group)
new_groups.append((group_hash, current_group))
for i in range(0, len(new_paths), items_per_group):
group = new_paths[i : i + items_per_group]
group_hash = self._compute_workgroup_hash(group)
new_groups.append((group_hash, group))
logger.info(f"Created {len(new_groups):,} new work groups")
combined_groups = {**existing_groups, **dict(new_groups)}
combined_lines = self._make_index_lines(combined_groups)
# Combine and save updated work groups
combined_groups = existing_groups.copy()
for group_hash, group_paths in new_groups:
combined_groups[group_hash] = group_paths
# Use proper CSV encoding with escaping for paths that may contain commas
combined_lines = [self._encode_csv_row([group_hash] + group_paths) for group_hash, group_paths in combined_groups.items()]
if new_groups:
await asyncio.to_thread(upload_zstd_csv, self.s3_client, self._index_path, combined_lines)
await self.backend.save_index_lines(combined_lines)
async def initialize_queue(self) -> int:
"""
Load the work queue from S3 and initialize it for processing.
Load the work queue and initialize it for processing.
Removes already completed work items and randomizes the order.
"""
# Load work items and completed items in parallel
download_task = asyncio.to_thread(download_zstd_csv, self.s3_client, self._index_path)
expand_task = asyncio.to_thread(expand_s3_glob, self.s3_client, self._output_glob)
lines = await self.backend.load_index_lines()
work_queue = self._parse_index_lines(lines)
work_queue_lines, done_work_items = await asyncio.gather(download_task, expand_task)
done_hashes = await self.backend.get_completed_hashes()
# Process work queue lines
work_queue = {}
for line in work_queue_lines:
if line.strip():
parts = self._decode_csv_row(line.strip())
if parts: # Ensure we have at least one part
work_queue[parts[0]] = parts[1:]
# Get set of completed work hashes
done_work_hashes = {
os.path.basename(item)[len("output_") : -len(".jsonl")]
for item in done_work_items
if os.path.basename(item).startswith("output_") and os.path.basename(item).endswith(".jsonl")
}
# Find remaining work and shuffle
remaining_work_hashes = set(work_queue) - done_work_hashes
remaining_items = [WorkItem(hash=hash_, work_paths=work_queue[hash_]) for hash_ in remaining_work_hashes]
remaining_hashes = set(work_queue) - done_hashes
remaining_items = [WorkItem(hash=h, work_paths=work_queue[h]) for h in remaining_hashes]
random.shuffle(remaining_items)
# Initialize queue
self._queue = asyncio.Queue()
self._queue = Queue()
for item in remaining_items:
await self._queue.put(item)
logger.info(f"Initialized queue with {self._queue.qsize()} work items")
logger.info(f"Initialized queue with {self.size:,} work items")
return self.size
return self._queue.qsize()
async def is_completed(self, work_hash: str) -> bool:
"""
Check if a work item has been completed.
Args:
work_hash: Hash of the work item to check
Returns:
True if the work is completed, False otherwise
"""
output_s3_path = os.path.join(self.workspace_path, "results", f"output_{work_hash}.jsonl")
bucket, key = parse_s3_path(output_s3_path)
try:
await asyncio.to_thread(self.s3_client.head_object, Bucket=bucket, Key=key)
return True
except self.s3_client.exceptions.ClientError:
return False
async def get_work(self, worker_lock_timeout_secs: int = 1800) -> Optional[WorkItem]:
"""
Get the next available work item that isn't completed or locked.
Args:
worker_lock_timeout_secs: Number of seconds before considering a worker lock stale (default 30 mins)
Returns:
WorkItem if work is available, None if queue is empty
"""
REFRESH_COMPLETED_HASH_CACHE_MAX_ATTEMPTS = 3
refresh_completed_hash_attempt = 0
while True:
try:
work_item = self._queue.get_nowait()
except asyncio.QueueEmpty:
except QueueEmpty:
return None
# Check if work is already completed
if await self.is_completed(work_item.hash):
if work_item.hash in self._completed_hash_cache or await self.backend.is_completed(work_item.hash):
logger.debug(f"Work item {work_item.hash} already completed, skipping")
self._queue.task_done()
refresh_completed_hash_attempt += 1
if refresh_completed_hash_attempt >= REFRESH_COMPLETED_HASH_CACHE_MAX_ATTEMPTS:
logger.info(f"More than {REFRESH_COMPLETED_HASH_CACHE_MAX_ATTEMPTS} queue items already done, refreshing local completed cache fully")
self._completed_hash_cache = await self.backend.get_completed_hashes()
refresh_completed_hash_attempt = 0
continue
# Check for worker lock
lock_path = os.path.join(self.workspace_path, "worker_locks", f"output_{work_item.hash}.jsonl")
bucket, key = parse_s3_path(lock_path)
try:
response = await asyncio.to_thread(self.s3_client.head_object, Bucket=bucket, Key=key)
# Check if lock is stale
last_modified = response["LastModified"]
if (datetime.datetime.now(datetime.timezone.utc) - last_modified).total_seconds() > worker_lock_timeout_secs:
# Lock is stale, we can take this work
logger.debug(f"Found stale lock for {work_item.hash}, taking work item")
else:
# Lock is active, skip this work
logger.debug(f"Work item {work_item.hash} is locked by another worker, skipping")
self._queue.task_done()
continue
except self.s3_client.exceptions.ClientError:
# No lock exists, we can take this work
pass
# Create our lock file
try:
await asyncio.to_thread(self.s3_client.put_object, Bucket=bucket, Key=key, Body=b"")
except Exception as e:
logger.warning(f"Failed to create lock file for {work_item.hash}: {e}")
if await self.backend.is_worker_lock_taken(work_item.hash, worker_lock_timeout_secs):
logger.debug(f"Work item {work_item.hash} is locked by another worker, skipping")
self._queue.task_done()
continue
# Create lock (overwrites if stale)
try:
await self.backend.create_worker_lock(work_item.hash)
except Exception as e:
logger.warning(f"Failed to create lock for {work_item.hash}: {e}")
self._queue.task_done()
continue
refresh_completed_hash_attempt = 0
return work_item
async def mark_done(self, work_item: WorkItem) -> None:
"""
Mark a work item as done by removing its lock file.
Args:
work_item: The WorkItem to mark as done
Mark a work item as done by removing its lock file and creating a done flag.
"""
lock_path = os.path.join(self.workspace_path, "worker_locks", f"output_{work_item.hash}.jsonl")
bucket, key = parse_s3_path(lock_path)
try:
await asyncio.to_thread(self.s3_client.delete_object, Bucket=bucket, Key=key)
except Exception as e:
logger.warning(f"Failed to delete lock file for {work_item.hash}: {e}")
# Create done flag in done_flags_dir
await self.backend.create_done_flag(work_item.hash)
# Remove the worker lock
await self.backend.delete_worker_lock(work_item.hash)
self._queue.task_done()
@property
def size(self) -> int:
"""Get current size of work queue"""
"""Get current size of work queue."""
return self._queue.qsize()
class LocalBackend(Backend):
"""Local file system backend."""
def __init__(self, workspace_path: str):
self.workspace_path = os.path.abspath(workspace_path)
self._index_path = os.path.join(self.workspace_path, "work_index_list.csv.zstd")
self._done_flags_dir = os.path.join(self.workspace_path, DONE_FLAGS_DIR)
self._locks_dir = os.path.join(self.workspace_path, WORKER_LOCKS_DIR)
os.makedirs(self.workspace_path, exist_ok=True)
os.makedirs(self._done_flags_dir, exist_ok=True)
os.makedirs(self._locks_dir, exist_ok=True)
def _download_zstd_csv_local(self, local_path: str) -> List[str]:
"""
Read a zstd-compressed CSV from a local path.
If the file doesn't exist, returns an empty list.
"""
if not os.path.exists(local_path):
return []
with open(local_path, "rb") as f:
dctx = zstandard.ZstdDecompressor()
data = dctx.decompress(f.read())
lines = data.decode("utf-8").splitlines()
return lines
def _upload_zstd_csv_local(self, local_path: str, lines: List[str]) -> None:
"""
Write a zstd-compressed CSV to a local path.
"""
data = "\n".join(lines).encode("utf-8")
cctx = zstandard.ZstdCompressor()
compressed_data = cctx.compress(data)
os.makedirs(os.path.dirname(local_path), exist_ok=True)
with open(local_path, "wb") as f:
f.write(compressed_data)
async def load_index_lines(self) -> List[str]:
return await asyncio.to_thread(self._download_zstd_csv_local, self._index_path)
async def save_index_lines(self, lines: List[str]) -> None:
await asyncio.to_thread(self._upload_zstd_csv_local, self._index_path, lines)
async def get_completed_hashes(self) -> Set[str]:
def _list_completed() -> Set[str]:
if not os.path.isdir(self._done_flags_dir):
return set()
return {
f[len("done_") : -len(".flag")]
for f in os.listdir(self._done_flags_dir)
if f.startswith("done_") and f.endswith(".flag")
}
return await asyncio.to_thread(_list_completed)
def _get_worker_lock_path(self, work_hash: str) -> str:
"""Internal method to get worker lock path."""
return os.path.join(self._locks_dir, f"worker_{work_hash}.lock")
def _get_done_flag_path(self, work_hash: str) -> str:
"""Internal method to get done flag path."""
return os.path.join(self._done_flags_dir, f"done_{work_hash}.flag")
async def _get_object_mtime(self, path: str) -> Optional[datetime.datetime]:
"""Internal method to get object mtime."""
def _get_mtime() -> Optional[datetime.datetime]:
if not os.path.exists(path):
return None
return datetime.datetime.fromtimestamp(os.path.getmtime(path), datetime.timezone.utc)
return await asyncio.to_thread(_get_mtime)
async def is_worker_lock_taken(self, work_hash: str, worker_lock_timeout_secs: int = 1800) -> bool:
"""Check if a worker lock is taken and not stale."""
lock_path = self._get_worker_lock_path(work_hash)
lock_mtime = await self._get_object_mtime(lock_path)
if not lock_mtime:
return False
now = datetime.datetime.now(datetime.timezone.utc)
return (now - lock_mtime).total_seconds() <= worker_lock_timeout_secs
async def create_worker_lock(self, work_hash: str) -> None:
"""Create a worker lock for a work hash."""
lock_path = self._get_worker_lock_path(work_hash)
def _create() -> None:
with open(lock_path, "wb"):
pass
await asyncio.to_thread(_create)
async def delete_worker_lock(self, work_hash: str) -> None:
"""Delete the worker lock for a work hash if it exists."""
lock_path = self._get_worker_lock_path(work_hash)
def _delete() -> None:
if os.path.exists(lock_path):
os.remove(lock_path)
await asyncio.to_thread(_delete)
async def is_completed(self, work_hash: str) -> bool:
"""Check if a work item has been completed."""
done_flag_path = self._get_done_flag_path(work_hash)
return await self._get_object_mtime(done_flag_path) is not None
async def create_done_flag(self, work_hash: str) -> None:
"""Create a done flag for a work hash."""
done_flag_path = self._get_done_flag_path(work_hash)
def _create() -> None:
with open(done_flag_path, "wb"):
pass
await asyncio.to_thread(_create)
class S3Backend(Backend):
"""S3 backend."""
def __init__(self, s3_client: Any, workspace_path: str):
self.s3_client = s3_client
self.workspace_path = workspace_path.rstrip("/")
self._index_path = os.path.join(self.workspace_path, "work_index_list.csv.zstd")
self._output_glob = os.path.join(self.workspace_path, DONE_FLAGS_DIR, "*.flag")
async def load_index_lines(self) -> List[str]:
return await asyncio.to_thread(download_zstd_csv, self.s3_client, self._index_path)
async def save_index_lines(self, lines: List[str]) -> None:
await asyncio.to_thread(upload_zstd_csv, self.s3_client, self._index_path, lines)
async def get_completed_hashes(self) -> Set[str]:
def _list_completed() -> Set[str]:
done_work_items = expand_s3_glob(self.s3_client, self._output_glob)
return {
os.path.basename(item)[len("done_") : -len(".flag")]
for item in done_work_items
if os.path.basename(item).startswith("done_") and os.path.basename(item).endswith(".flag")
}
return await asyncio.to_thread(_list_completed)
def _get_worker_lock_path(self, work_hash: str) -> str:
"""Internal method to get worker lock path."""
return os.path.join(self.workspace_path, WORKER_LOCKS_DIR, f"worker_{work_hash}.lock")
def _get_done_flag_path(self, work_hash: str) -> str:
"""Internal method to get done flag path."""
return os.path.join(self.workspace_path, DONE_FLAGS_DIR, f"done_{work_hash}.flag")
async def _get_object_mtime(self, path: str) -> Optional[datetime.datetime]:
"""Internal method to get object mtime."""
bucket, key = parse_s3_path(path)
def _head_object() -> Optional[datetime.datetime]:
try:
response = self.s3_client.head_object(Bucket=bucket, Key=key)
return response["LastModified"]
except self.s3_client.exceptions.ClientError as e:
if e.response["Error"]["Code"] == "404":
return None
raise
return await asyncio.to_thread(_head_object)
async def is_worker_lock_taken(self, work_hash: str, worker_lock_timeout_secs: int = 1800) -> bool:
"""Check if a worker lock is taken and not stale."""
lock_path = self._get_worker_lock_path(work_hash)
lock_mtime = await self._get_object_mtime(lock_path)
if not lock_mtime:
return False
now = datetime.datetime.now(datetime.timezone.utc)
return (now - lock_mtime).total_seconds() <= worker_lock_timeout_secs
async def create_worker_lock(self, work_hash: str) -> None:
"""Create a worker lock for a work hash."""
lock_path = self._get_worker_lock_path(work_hash)
bucket, key = parse_s3_path(lock_path)
await asyncio.to_thread(self.s3_client.put_object, Bucket=bucket, Key=key, Body=b"")
async def delete_worker_lock(self, work_hash: str) -> None:
"""Delete the worker lock for a work hash if it exists."""
lock_path = self._get_worker_lock_path(work_hash)
bucket, key = parse_s3_path(lock_path)
await asyncio.to_thread(self.s3_client.delete_object, Bucket=bucket, Key=key)
async def is_completed(self, work_hash: str) -> bool:
"""Check if a work item has been completed."""
done_flag_path = self._get_done_flag_path(work_hash)
return await self._get_object_mtime(done_flag_path) is not None
async def create_done_flag(self, work_hash: str) -> None:
"""Create a done flag for a work hash."""
done_flag_path = self._get_done_flag_path(work_hash)
bucket, key = parse_s3_path(done_flag_path)
await asyncio.to_thread(self.s3_client.put_object, Bucket=bucket, Key=key, Body=b"")

View File

@ -64,7 +64,7 @@ class TestImageRotation:
# Extract the image from the result
messages = result["messages"]
content = messages[0]["content"]
image_url = content[0]["image_url"]["url"]
image_url = content[1]["image_url"]["url"]
image_base64 = image_url.split(",")[1]
result_img = base64_to_image(image_base64)
@ -88,7 +88,7 @@ class TestImageRotation:
# Extract the image from the result
messages = result["messages"]
content = messages[0]["content"]
image_url = content[0]["image_url"]["url"]
image_url = content[1]["image_url"]["url"]
image_base64 = image_url.split(",")[1]
result_img = base64_to_image(image_base64)
@ -113,7 +113,7 @@ class TestImageRotation:
# Extract the image from the result
messages = result["messages"]
content = messages[0]["content"]
image_url = content[0]["image_url"]["url"]
image_url = content[1]["image_url"]["url"]
image_base64 = image_url.split(",")[1]
result_img = base64_to_image(image_base64)
@ -138,7 +138,7 @@ class TestImageRotation:
# Extract the image from the result
messages = result["messages"]
content = messages[0]["content"]
image_url = content[0]["image_url"]["url"]
image_url = content[1]["image_url"]["url"]
image_base64 = image_url.split(",")[1]
result_img = base64_to_image(image_base64)
@ -178,7 +178,7 @@ class TestImageRotation:
# Extract the image from the result
messages = result["messages"]
content = messages[0]["content"]
image_url = content[0]["image_url"]["url"]
image_url = content[1]["image_url"]["url"]
image_base64 = image_url.split(",")[1]
result_img = base64_to_image(image_base64)

View File

@ -1,14 +1,12 @@
import asyncio
import datetime
import hashlib
import unittest
from typing import Dict, List
from unittest.mock import Mock, call, patch
from unittest.mock import Mock, patch
from botocore.exceptions import ClientError
# Import the classes we're testing
from olmocr.work_queue import S3WorkQueue, WorkItem
from olmocr.work_queue import WorkQueue, S3Backend, WorkItem
class TestS3WorkQueue(unittest.TestCase):
@ -16,7 +14,8 @@ class TestS3WorkQueue(unittest.TestCase):
"""Set up test fixtures before each test method."""
self.s3_client = Mock()
self.s3_client.exceptions.ClientError = ClientError
self.work_queue = S3WorkQueue(self.s3_client, "s3://test-bucket/workspace")
self.backend = S3Backend(self.s3_client, "s3://test-bucket/workspace")
self.work_queue = WorkQueue(self.backend)
self.sample_paths = [
"s3://test-bucket/data/file1.pdf",
"s3://test-bucket/data/file2.pdf",
@ -35,18 +34,18 @@ class TestS3WorkQueue(unittest.TestCase):
]
# Hash should be the same regardless of order
hash1 = S3WorkQueue._compute_workgroup_hash(paths)
hash2 = S3WorkQueue._compute_workgroup_hash(reversed(paths))
hash1 = WorkQueue._compute_workgroup_hash(paths)
hash2 = WorkQueue._compute_workgroup_hash(reversed(paths))
self.assertEqual(hash1, hash2)
def test_init(self):
"""Test initialization of S3WorkQueue"""
"""Test initialization of S3Backend"""
client = Mock()
queue = S3WorkQueue(client, "s3://test-bucket/workspace/")
backend = S3Backend(client, "s3://test-bucket/workspace/")
self.assertEqual(queue.workspace_path, "s3://test-bucket/workspace")
self.assertEqual(queue._index_path, "s3://test-bucket/workspace/work_index_list.csv.zstd")
self.assertEqual(queue._output_glob, "s3://test-bucket/workspace/results/*.jsonl")
self.assertEqual(backend.workspace_path, "s3://test-bucket/workspace")
self.assertEqual(backend._index_path, "s3://test-bucket/workspace/work_index_list.csv.zstd")
self.assertEqual(backend._output_glob, "s3://test-bucket/workspace/done_flags/*.flag")
def asyncSetUp(self):
"""Set up async test fixtures"""
@ -87,7 +86,7 @@ class TestS3WorkQueue(unittest.TestCase):
# Verify format of uploaded lines
for line in lines:
parts = line.split(",")
parts = WorkQueue._decode_csv_row(line)
self.assertGreaterEqual(len(parts), 2) # Hash + at least one path
self.assertEqual(len(parts[0]), 40) # SHA1 hash length
@ -98,8 +97,8 @@ class TestS3WorkQueue(unittest.TestCase):
new_paths = ["s3://test-bucket/data/new1.pdf"]
# Create existing index content
existing_hash = S3WorkQueue._compute_workgroup_hash(existing_paths)
existing_line = f"{existing_hash},{existing_paths[0]}"
existing_hash = WorkQueue._compute_workgroup_hash(existing_paths)
existing_line = WorkQueue._encode_csv_row([existing_hash] + existing_paths)
with patch("olmocr.work_queue.download_zstd_csv", return_value=[existing_line]):
with patch("olmocr.work_queue.upload_zstd_csv") as mock_upload:
@ -115,17 +114,17 @@ class TestS3WorkQueue(unittest.TestCase):
"""Test queue initialization"""
# Mock work items and completed items
work_paths = ["s3://test/file1.pdf", "s3://test/file2.pdf"]
work_hash = S3WorkQueue._compute_workgroup_hash(work_paths)
work_line = f"{work_hash},{work_paths[0]},{work_paths[1]}"
work_hash = WorkQueue._compute_workgroup_hash(work_paths)
work_line = WorkQueue._encode_csv_row([work_hash] + work_paths)
completed_items = [f"s3://test-bucket/workspace/results/output_{work_hash}.jsonl"]
completed_items = [f"s3://test-bucket/workspace/done_flags/done_{work_hash}.flag"]
with patch("olmocr.work_queue.download_zstd_csv", return_value=[work_line]):
with patch("olmocr.work_queue.expand_s3_glob", return_value=completed_items):
await self.work_queue.initialize_queue()
count = await self.work_queue.initialize_queue()
# Queue should be empty since all work is completed
self.assertTrue(self.work_queue._queue.empty())
self.assertEqual(count, 0)
@async_test
async def test_is_completed(self):
@ -134,11 +133,11 @@ class TestS3WorkQueue(unittest.TestCase):
# Test completed work
self.s3_client.head_object.return_value = {"LastModified": datetime.datetime.now(datetime.timezone.utc)}
self.assertTrue(await self.work_queue.is_completed(work_hash))
self.assertTrue(await self.backend.is_completed(work_hash))
# Test incomplete work
self.s3_client.head_object.side_effect = ClientError({"Error": {"Code": "404", "Message": "Not Found"}}, "HeadObject")
self.assertFalse(await self.work_queue.is_completed(work_hash))
self.assertFalse(await self.backend.is_completed(work_hash))
@async_test
async def test_get_work(self):
@ -154,8 +153,8 @@ class TestS3WorkQueue(unittest.TestCase):
# Verify lock file was created
self.s3_client.put_object.assert_called_once()
bucket, key = self.s3_client.put_object.call_args[1]["Bucket"], self.s3_client.put_object.call_args[1]["Key"]
self.assertTrue(key.endswith(f"output_{work_item.hash}.jsonl"))
key = self.s3_client.put_object.call_args[1]["Key"]
self.assertTrue(key.endswith(f"worker_{work_item.hash}.lock"))
@async_test
async def test_get_work_completed(self):
@ -209,10 +208,17 @@ class TestS3WorkQueue(unittest.TestCase):
await self.work_queue.mark_done(work_item)
# Verify done flag was created and lock file was deleted
# Check put_object was called for done flag
put_calls = self.s3_client.put_object.call_args_list
self.assertEqual(len(put_calls), 1)
done_flag_key = put_calls[0][1]["Key"]
self.assertTrue(done_flag_key.endswith(f"done_{work_item.hash}.flag"))
# Verify lock file was deleted
self.s3_client.delete_object.assert_called_once()
bucket, key = self.s3_client.delete_object.call_args[1]["Bucket"], self.s3_client.delete_object.call_args[1]["Key"]
self.assertTrue(key.endswith(f"output_{work_item.hash}.jsonl"))
key = self.s3_client.delete_object.call_args[1]["Key"]
self.assertTrue(key.endswith(f"worker_{work_item.hash}.lock"))
@async_test
async def test_paths_with_commas(self):
@ -235,8 +241,11 @@ class TestS3WorkQueue(unittest.TestCase):
# Initialize a fresh queue from these lines
await self.work_queue.initialize_queue()
# Mock ClientError for head_object (file doesn't exist)
self.s3_client.head_object.side_effect = ClientError({"Error": {"Code": "404", "Message": "Not Found"}}, "HeadObject")
# Mock ClientError for head_object (file doesn't exist) - need to handle multiple calls
self.s3_client.head_object.side_effect = [
ClientError({"Error": {"Code": "404", "Message": "Not Found"}}, "HeadObject"), # done flag check
ClientError({"Error": {"Code": "404", "Message": "Not Found"}}, "HeadObject"), # worker lock check
]
# Get a work item
work_item = await self.work_queue.get_work()