Mirror of https://github.com/allenai/olmocr.git, synced 2025-12-25 14:15:16 +00:00

New work queue code is cleaner

This commit is contained in: parent 9a8fa335ae, commit 05330150ad
@@ -49,7 +49,7 @@ from olmocr.s3_utils import (
)
from olmocr.train.dataloader import FrontMatterParser
from olmocr.version import VERSION
from olmocr.work_queue import LocalWorkQueue, S3WorkQueue, WorkQueue
from olmocr.work_queue import WorkQueue, LocalBackend, S3Backend

# Initialize logger
logger = logging.getLogger(__name__)
@@ -1104,9 +1104,9 @@ async def main():

    # Create work queue
    if args.workspace.startswith("s3://"):
        work_queue = S3WorkQueue(workspace_s3, args.workspace)
        work_queue = WorkQueue(S3Backend(workspace_s3, args.workspace))
    else:
        work_queue = LocalWorkQueue(args.workspace)
        work_queue = WorkQueue(LocalBackend(args.workspace))

    if args.pdfs:
        logger.info("Got --pdfs argument, going to add to the work queue")

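The call sites above reduce to a single construction pattern: pick a storage backend, then wrap it in the one concrete WorkQueue class. A minimal sketch of that pattern, mirroring the pipeline code in this hunk; the build_work_queue helper and the workspace_s3 boto3 client name are illustrative assumptions, not part of the commit:

from olmocr.work_queue import WorkQueue, LocalBackend, S3Backend

def build_work_queue(workspace: str, workspace_s3=None) -> WorkQueue:
    # S3 workspaces get the S3-backed storage layer; anything else stays on local disk.
    if workspace.startswith("s3://"):
        return WorkQueue(S3Backend(workspace_s3, workspace))
    return WorkQueue(LocalBackend(workspace))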
@@ -7,9 +7,11 @@ import io
import logging
import os
import random
from asyncio import Queue
from asyncio import Queue, QueueEmpty
from dataclasses import dataclass
from typing import Any, List, Optional
from typing import Any, Dict, List, Optional, Set

import zstandard

from olmocr.s3_utils import (
    download_zstd_csv,
@@ -20,31 +22,76 @@ from olmocr.s3_utils import (

logger = logging.getLogger(__name__)

# Shared directory names for both local and S3 backends
WORKER_LOCKS_DIR = "worker_locks"
DONE_FLAGS_DIR = "done_flags"


@dataclass
class WorkItem:
    """Represents a single work item in the queue"""
    """Represents a single work item in the queue."""

    hash: str
    work_paths: List[str]


class WorkQueue(abc.ABC):
class Backend(abc.ABC):
    """Abstract backend for storage operations."""

    @abc.abstractmethod
    async def load_index_lines(self) -> List[str]:
        """Load raw index lines from storage."""
        pass

    @abc.abstractmethod
    async def save_index_lines(self, lines: List[str]) -> None:
        """Save raw index lines to storage."""
        pass

    @abc.abstractmethod
    async def get_completed_hashes(self) -> Set[str]:
        """Get set of completed work hashes."""
        pass

    @abc.abstractmethod
    async def is_completed(self, work_hash: str) -> bool:
        """Check if a work item has been completed."""
        pass

    @abc.abstractmethod
    async def is_worker_lock_taken(self, work_hash: str, worker_lock_timeout_secs: int = 1800) -> bool:
        """Check if a worker lock is taken and not stale."""
        pass

    @abc.abstractmethod
    async def create_worker_lock(self, work_hash: str) -> None:
        """Create a worker lock for a work hash."""
        pass

    @abc.abstractmethod
    async def delete_worker_lock(self, work_hash: str) -> None:
        """Delete the worker lock for a work hash if it exists."""
        pass

    @abc.abstractmethod
    async def create_done_flag(self, work_hash: str) -> None:
        """Create a done flag for a work hash."""
        pass


class WorkQueue:
    """
    Base class defining the interface for a work queue.
    Manages a work queue with pluggable storage backends (e.g., local or S3).
    """

    def __init__(self, backend: Backend):
        self.backend = backend
        self._queue: Queue[WorkItem] = Queue()
        self._completed_hash_cache = set()

    @staticmethod
    def _encode_csv_row(row: List[str]) -> str:
        """
        Encodes a row of data for CSV storage with proper escaping.

        Args:
            row: List of strings to encode

        Returns:
            CSV-encoded string with proper escaping of commas and quotes
        """
        """Encodes a row of data for CSV storage with proper escaping."""
        output = io.StringIO()
        writer = csv.writer(output, quoting=csv.QUOTE_MINIMAL)
        writer.writerow(row)
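The refactor replaces the old per-storage subclasses with one concrete WorkQueue that delegates every storage operation to a Backend. A hedged sketch of the typical lifecycle, based only on the method names visible in this diff (populate_queue, initialize_queue, get_work, mark_done); the workspace path, group size, and process_paths step are illustrative assumptions:

import asyncio
from olmocr.work_queue import WorkQueue, LocalBackend

async def run_worker(workspace_dir: str) -> None:
    queue = WorkQueue(LocalBackend(workspace_dir))
    # Register paths, grouping them into hashed work items (group size is illustrative).
    await queue.populate_queue(["a.pdf", "b.pdf", "c.pdf"], items_per_group=2)
    # Load the index, drop completed items, shuffle, and fill the in-memory queue.
    await queue.initialize_queue()
    while (item := await queue.get_work()) is not None:
        process_paths(item.work_paths)  # hypothetical processing step
        await queue.mark_done(item)

# asyncio.run(run_worker("/tmp/workspace"))  # illustrative entry point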
@@ -52,586 +99,339 @@ class WorkQueue(abc.ABC):

    @staticmethod
    def _decode_csv_row(line: str) -> List[str]:
        """
        Decodes a CSV row with proper unescaping.

        Args:
            line: CSV-encoded string

        Returns:
            List of unescaped string values
        """
        """Decodes a CSV row with proper unescaping."""
        return next(csv.reader([line]))

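The _encode_csv_row/_decode_csv_row pair exists so that paths containing commas or quotes survive the index format; csv.QUOTE_MINIMAL only adds quotes when a field needs them. A small standalone round-trip sketch using the same csv calls as above (the helper names and sample path are illustrative):

import csv, io

def encode_row(row):
    out = io.StringIO()
    csv.writer(out, quoting=csv.QUOTE_MINIMAL).writerow(row)
    return out.getvalue().strip()

def decode_row(line):
    return next(csv.reader([line]))

line = encode_row(["abc123", "s3://bucket/report, final.pdf"])
# line == 'abc123,"s3://bucket/report, final.pdf"'
assert decode_row(line) == ["abc123", "s3://bucket/report, final.pdf"]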
    @abc.abstractmethod
    async def populate_queue(self, work_paths: List[str], items_per_group: int) -> None:
        """
        Add new items to the work queue. The specifics will vary depending on
        whether this is a local or S3-backed queue.

        Args:
            work_paths: Each individual path that we will process over
            items_per_group: Number of items to group together in a single work item
        """
        pass

    @abc.abstractmethod
    async def initialize_queue(self) -> int:
        """
        Load the work queue from the relevant store (local or remote)
        and initialize it for processing.

        For example, this might remove already completed work items and randomize
        the order before adding them to an internal queue.
        """
        pass

    @abc.abstractmethod
    async def is_completed(self, work_hash: str) -> bool:
        """
        Check if a work item has been completed.

        Args:
            work_hash: Hash of the work item to check

        Returns:
            True if the work is completed, False otherwise
        """
        pass

    @abc.abstractmethod
    async def get_work(self, worker_lock_timeout_secs: int = 1800) -> Optional[WorkItem]:
        """
        Get the next available work item that isn't completed or locked.

        Args:
            worker_lock_timeout_secs: Number of seconds before considering
                a worker lock stale (default 30 mins)

        Returns:
            WorkItem if work is available, None if queue is empty
        """
        pass

    @abc.abstractmethod
    async def mark_done(self, work_item: WorkItem) -> None:
        """
        Mark a work item as done by removing its lock file
        or performing any other cleanup.

        Args:
            work_item: The WorkItem to mark as done
        """
        pass

    @property
    @abc.abstractmethod
    def size(self) -> int:
        """Get current size of work queue"""
        pass

    @staticmethod
    def _compute_workgroup_hash(work_paths: List[str]) -> str:
        """
        Compute a deterministic hash for a group of paths.

        Args:
            work_paths: List of paths (local or S3)

        Returns:
            SHA1 hash of the sorted paths
        """
        """Compute a deterministic hash for a group of paths."""
        sha1 = hashlib.sha1()
        for path in sorted(work_paths):
            sha1.update(path.encode("utf-8"))
        return sha1.hexdigest()


# --------------------------------------------------------------------------------------
# Local Helpers for reading/writing the index CSV (compressed with zstd) to disk
# --------------------------------------------------------------------------------------

try:
    import zstandard
except ImportError:
    zstandard = None


def download_zstd_csv_local(local_path: str) -> List[str]:
    """
    Download a zstd-compressed CSV from a local path.
    If the file doesn't exist, returns an empty list.
    """
    if not os.path.exists(local_path):
        return []

    if not zstandard:
        raise RuntimeError("zstandard package is required for local zstd CSV operations.")

    with open(local_path, "rb") as f:
        dctx = zstandard.ZstdDecompressor()
        data = dctx.decompress(f.read())
        lines = data.decode("utf-8").splitlines()
        return lines


def upload_zstd_csv_local(local_path: str, lines: List[str]) -> None:
    """
    Upload a zstd-compressed CSV to a local path.
    """
    if not zstandard:
        raise RuntimeError("zstandard package is required for local zstd CSV operations.")

    data = "\n".join(lines).encode("utf-8")
    cctx = zstandard.ZstdCompressor()
    compressed_data = cctx.compress(data)

    # Ensure parent directories exist
    os.makedirs(os.path.dirname(local_path), exist_ok=True)

    with open(local_path, "wb") as f:
        f.write(compressed_data)

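Both the old local helpers above and the new LocalBackend methods further down store the index as a single zstd-compressed blob of CSV lines; the whole file is compressed and decompressed in one shot rather than streamed. A minimal sketch of that round trip with the zstandard package (the file name is illustrative):

import zstandard

lines = ["deadbeef,s3://bucket/a.pdf", "cafef00d,s3://bucket/b.pdf"]

# Compress the joined lines, as the upload helper does.
payload = zstandard.ZstdCompressor().compress("\n".join(lines).encode("utf-8"))
with open("work_index_list.csv.zstd", "wb") as f:
    f.write(payload)

# Decompress and split back into lines, mirroring the download helper.
with open("work_index_list.csv.zstd", "rb") as f:
    restored = zstandard.ZstdDecompressor().decompress(f.read()).decode("utf-8").splitlines()
assert restored == lines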
# --------------------------------------------------------------------------------------
# LocalWorkQueue Implementation
# --------------------------------------------------------------------------------------


class LocalWorkQueue(WorkQueue):
    """
    A local in-memory and on-disk WorkQueue implementation, which uses
    a local workspace directory to store the queue index, lock files,
    and completed results for persistent resumption across process restarts.
    """

    def __init__(self, workspace_path: str):
        """
        Initialize the local work queue.

        Args:
            workspace_path: Local directory path where the queue index,
                results, and locks are stored.
        """
        self.workspace_path = os.path.abspath(workspace_path)
        os.makedirs(self.workspace_path, exist_ok=True)

        # Local index file (compressed)
        self._index_path = os.path.join(self.workspace_path, "work_index_list.csv.zstd")

        # Output directory for completed tasks
        self._results_dir = os.path.join(self.workspace_path, "results")
        os.makedirs(self._results_dir, exist_ok=True)

        # Directory for lock files
        self._locks_dir = os.path.join(self.workspace_path, "worker_locks")
        os.makedirs(self._locks_dir, exist_ok=True)

        # Internal queue
        self._queue: Queue[Any] = Queue()

    async def populate_queue(self, work_paths: List[str], items_per_group: int) -> None:
        """
        Add new items to the work queue (local version).

        Args:
            work_paths: Each individual path (local in this context)
                that we will process over
            items_per_group: Number of items to group together in a single work item
        """
        # Treat them as local paths, but keep variable name for consistency
        all_paths = set(work_paths)
        logger.info(f"Found {len(all_paths):,} total paths")

        # Load existing work groups from local index
        existing_lines = await asyncio.to_thread(download_zstd_csv_local, self._index_path)
        existing_groups = {}
        for line in existing_lines:
    def _parse_index_lines(self, lines: List[str]) -> Dict[str, List[str]]:
        """Parse index lines into a dict of hash to paths."""
        result = {}
        for line in lines:
            if line.strip():
                parts = self._decode_csv_row(line.strip())
                if parts: # Ensure we have at least one part
                    group_hash = parts[0]
                    group_paths = parts[1:]
                    existing_groups[group_hash] = group_paths
                parts = self._decode_csv_row(line)
                if parts:
                    result[parts[0]] = parts[1:]
        return result

        existing_path_set = {p for paths in existing_groups.values() for p in paths}
        new_paths = all_paths - existing_path_set
        logger.info(f"{len(new_paths):,} new paths to add to the workspace")

        if not new_paths:
            return

        # Create new work groups
        new_groups = []
        current_group = []
        for path in sorted(new_paths):
            current_group.append(path)
            if len(current_group) == items_per_group:
                group_hash = self._compute_workgroup_hash(current_group)
                new_groups.append((group_hash, current_group))
                current_group = []
        if current_group:
            group_hash = self._compute_workgroup_hash(current_group)
            new_groups.append((group_hash, current_group))

        logger.info(f"Created {len(new_groups):,} new work groups")

        # Combine and save updated work groups
        combined_groups = existing_groups.copy()
        for group_hash, group_paths in new_groups:
            combined_groups[group_hash] = group_paths

        # Use proper CSV encoding with escaping for paths that may contain commas
        combined_lines = [self._encode_csv_row([group_hash] + group_paths) for group_hash, group_paths in combined_groups.items()]

        if new_groups:
            # Write the combined data back to disk in zstd CSV format
            await asyncio.to_thread(upload_zstd_csv_local, self._index_path, combined_lines)

    async def initialize_queue(self) -> int:
        """
        Load the work queue from the local index file and initialize it for processing.
        Removes already completed work items and randomizes the order.
        """
        # 1) Read the index
        work_queue_lines = await asyncio.to_thread(download_zstd_csv_local, self._index_path)
        work_queue = {}
        for line in work_queue_lines:
            if line.strip():
                parts = self._decode_csv_row(line.strip())
                if parts: # Ensure we have at least one part
                    work_queue[parts[0]] = parts[1:]

        # 2) Determine which items are completed by scanning local results/*.jsonl
        if not os.path.isdir(self._results_dir):
            os.makedirs(self._results_dir, exist_ok=True)
        done_work_items = [f for f in os.listdir(self._results_dir) if f.startswith("output_") and f.endswith(".jsonl")]
        done_work_hashes = {fn[len("output_") : -len(".jsonl")] for fn in done_work_items}

        # 3) Filter out completed items
        remaining_work_hashes = set(work_queue) - done_work_hashes
        remaining_items = [WorkItem(hash=hash_, work_paths=work_queue[hash_]) for hash_ in remaining_work_hashes]
        random.shuffle(remaining_items)

        # 4) Initialize our in-memory queue
        self._queue = asyncio.Queue()
        for item in remaining_items:
            await self._queue.put(item)

        logger.info(f"Initialized local queue with {self._queue.qsize()} work items")

        return self._queue.qsize()

    async def is_completed(self, work_hash: str) -> bool:
        """
        Check if a work item has been completed locally by seeing if
        output_{work_hash}.jsonl is present in the results directory.

        Args:
            work_hash: Hash of the work item to check
        """
        output_file = os.path.join(self._results_dir, f"output_{work_hash}.jsonl")
        return os.path.exists(output_file)

    async def get_work(self, worker_lock_timeout_secs: int = 1800) -> Optional[WorkItem]:
        """
        Get the next available work item that isn't completed or locked.

        Args:
            worker_lock_timeout_secs: Number of seconds before considering
                a worker lock stale (default 30 mins)

        Returns:
            WorkItem if work is available, None if queue is empty
        """
        while True:
            try:
                work_item = self._queue.get_nowait()
            except asyncio.QueueEmpty:
                return None

            # Check if work is already completed
            if await self.is_completed(work_item.hash):
                logger.debug(f"Work item {work_item.hash} already completed, skipping")
                self._queue.task_done()
                continue

            # Check for worker lock
            lock_file = os.path.join(self._locks_dir, f"output_{work_item.hash}.jsonl")
            if os.path.exists(lock_file):
                # Check modification time
                mtime = datetime.datetime.fromtimestamp(os.path.getmtime(lock_file), datetime.timezone.utc)
                if (datetime.datetime.now(datetime.timezone.utc) - mtime).total_seconds() > worker_lock_timeout_secs:
                    # Lock is stale, we can take this work
                    logger.debug(f"Found stale lock for {work_item.hash}, taking work item")
                else:
                    # Lock is active, skip this work
                    logger.debug(f"Work item {work_item.hash} is locked by another worker, skipping")
                    self._queue.task_done()
                    continue

            # Create our lock file (touch an empty file)
            try:
                with open(lock_file, "wb") as f:
                    f.write(b"")
            except Exception as e:
                logger.warning(f"Failed to create lock file for {work_item.hash}: {e}")
                self._queue.task_done()
                continue

            return work_item

    async def mark_done(self, work_item: WorkItem) -> None:
        """
        Mark a work item as done by removing its lock file.

        Args:
            work_item: The WorkItem to mark as done
        """
        lock_file = os.path.join(self._locks_dir, f"output_{work_item.hash}.jsonl")
        if os.path.exists(lock_file):
            try:
                os.remove(lock_file)
            except Exception as e:
                logger.warning(f"Failed to delete lock file for {work_item.hash}: {e}")
        self._queue.task_done()

    @property
    def size(self) -> int:
        """Get current size of local work queue"""
        return self._queue.qsize()

# --------------------------------------------------------------------------------------
# S3WorkQueue Implementation
# --------------------------------------------------------------------------------------


class S3WorkQueue(WorkQueue):
    """
    Manages a work queue stored in S3 that coordinates work across multiple workers.
    The queue maintains a list of work items, where each work item is a group of s3 paths
    that should be processed together.

    Each work item gets a hash, and completed work items will have their results
    stored in s3://workspace_path/results/output_[hash].jsonl

    This is the ground source of truth about which work items are done.

    When a worker takes an item off the queue, it will write an empty s3 file to
    s3://workspace_path/worker_locks/output_[hash].jsonl

    The queue gets randomized on each worker, so workers pull random work items to operate on.
    As you pull an item, we will check to see if it has been completed. If yes,
    then it will immediately fetch the next item. If a lock file was created within a configurable
    timeout (30 mins by default), then that work item is also skipped.

    The lock will will be deleted once the worker is done with that item.
    """

    def __init__(self, s3_client, workspace_path: str):
        """
        Initialize the work queue.

        Args:
            s3_client: Boto3 S3 client to use for operations
            workspace_path: S3 path where work queue and results are stored
        """
        self.s3_client = s3_client
        self.workspace_path = workspace_path.rstrip("/")

        self._index_path = os.path.join(self.workspace_path, "work_index_list.csv.zstd")
        self._output_glob = os.path.join(self.workspace_path, "results", "*.jsonl")
        self._queue: Queue[Any] = Queue()
    def _make_index_lines(self, groups: Dict[str, List[str]]) -> List[str]:
        """Create encoded lines from groups dict."""
        return [self._encode_csv_row([group_hash] + group_paths) for group_hash, group_paths in groups.items()]

    async def populate_queue(self, work_paths: List[str], items_per_group: int) -> None:
        """
        Add new items to the work queue.

        Args:
            work_paths: Each individual s3 path that we will process over
            items_per_group: Number of items to group together in a single work item
        """
        all_paths = set(work_paths)
        logger.info(f"Found {len(all_paths):,} total paths")

        # Load existing work groups
        existing_lines = await asyncio.to_thread(download_zstd_csv, self.s3_client, self._index_path)
        existing_groups = {}
        for line in existing_lines:
            if line.strip():
                parts = self._decode_csv_row(line.strip())
                if parts: # Ensure we have at least one part
                    group_hash = parts[0]
                    group_paths = parts[1:]
                    existing_groups[group_hash] = group_paths
        lines = await self.backend.load_index_lines()
        existing_groups = self._parse_index_lines(lines)

        existing_path_set = {path for paths in existing_groups.values() for path in paths}

        # Find new paths to process
        new_paths = all_paths - existing_path_set
        logger.info(f"{len(new_paths):,} new paths to add to the workspace")
        existing_path_set = {p for paths in existing_groups.values() for p in paths}
        new_paths = sorted(all_paths - existing_path_set)

        if not new_paths:
            return

        # Create new work groups
        new_groups = []
        current_group = []
        for path in sorted(new_paths):
            current_group.append(path)
            if len(current_group) == items_per_group:
                group_hash = self._compute_workgroup_hash(current_group)
                new_groups.append((group_hash, current_group))
                current_group = []
        if current_group:
            group_hash = self._compute_workgroup_hash(current_group)
            new_groups.append((group_hash, current_group))
        for i in range(0, len(new_paths), items_per_group):
            group = new_paths[i : i + items_per_group]
            group_hash = self._compute_workgroup_hash(group)
            new_groups.append((group_hash, group))

        logger.info(f"Created {len(new_groups):,} new work groups")
        combined_groups = {**existing_groups, **dict(new_groups)}
        combined_lines = self._make_index_lines(combined_groups)

        # Combine and save updated work groups
        combined_groups = existing_groups.copy()
        for group_hash, group_paths in new_groups:
            combined_groups[group_hash] = group_paths

        # Use proper CSV encoding with escaping for paths that may contain commas
        combined_lines = [self._encode_csv_row([group_hash] + group_paths) for group_hash, group_paths in combined_groups.items()]

        if new_groups:
            await asyncio.to_thread(upload_zstd_csv, self.s3_client, self._index_path, combined_lines)
            await self.backend.save_index_lines(combined_lines)

    async def initialize_queue(self) -> int:
        """
        Load the work queue from S3 and initialize it for processing.
        Load the work queue and initialize it for processing.
        Removes already completed work items and randomizes the order.
        """
        # Load work items and completed items in parallel
        download_task = asyncio.to_thread(download_zstd_csv, self.s3_client, self._index_path)
        expand_task = asyncio.to_thread(expand_s3_glob, self.s3_client, self._output_glob)
        lines = await self.backend.load_index_lines()
        work_queue = self._parse_index_lines(lines)

        work_queue_lines, done_work_items = await asyncio.gather(download_task, expand_task)
        done_hashes = await self.backend.get_completed_hashes()

        # Process work queue lines
        work_queue = {}
        for line in work_queue_lines:
            if line.strip():
                parts = self._decode_csv_row(line.strip())
                if parts: # Ensure we have at least one part
                    work_queue[parts[0]] = parts[1:]

        # Get set of completed work hashes
        done_work_hashes = {
            os.path.basename(item)[len("output_") : -len(".jsonl")]
            for item in done_work_items
            if os.path.basename(item).startswith("output_") and os.path.basename(item).endswith(".jsonl")
        }

        # Find remaining work and shuffle
        remaining_work_hashes = set(work_queue) - done_work_hashes
        remaining_items = [WorkItem(hash=hash_, work_paths=work_queue[hash_]) for hash_ in remaining_work_hashes]
        remaining_hashes = set(work_queue) - done_hashes
        remaining_items = [WorkItem(hash=h, work_paths=work_queue[h]) for h in remaining_hashes]
        random.shuffle(remaining_items)

        # Initialize queue
        self._queue = asyncio.Queue()
        self._queue = Queue()
        for item in remaining_items:
            await self._queue.put(item)

        logger.info(f"Initialized queue with {self._queue.qsize()} work items")
        logger.info(f"Initialized queue with {self.size:,} work items")
        return self.size

        return self._queue.qsize()

    async def is_completed(self, work_hash: str) -> bool:
        """
        Check if a work item has been completed.

        Args:
            work_hash: Hash of the work item to check

        Returns:
            True if the work is completed, False otherwise
        """
        output_s3_path = os.path.join(self.workspace_path, "results", f"output_{work_hash}.jsonl")
        bucket, key = parse_s3_path(output_s3_path)

        try:
            await asyncio.to_thread(self.s3_client.head_object, Bucket=bucket, Key=key)
            return True
        except self.s3_client.exceptions.ClientError:
            return False

    async def get_work(self, worker_lock_timeout_secs: int = 1800) -> Optional[WorkItem]:
        """
        Get the next available work item that isn't completed or locked.

        Args:
            worker_lock_timeout_secs: Number of seconds before considering a worker lock stale (default 30 mins)

        Returns:
            WorkItem if work is available, None if queue is empty
        """
        REFRESH_COMPLETED_HASH_CACHE_MAX_ATTEMPTS = 3
        refresh_completed_hash_attempt = 0

        while True:
            try:
                work_item = self._queue.get_nowait()
            except asyncio.QueueEmpty:
            except QueueEmpty:
                return None

            # Check if work is already completed
            if await self.is_completed(work_item.hash):
            if work_item.hash in self._completed_hash_cache or await self.backend.is_completed(work_item.hash):
                logger.debug(f"Work item {work_item.hash} already completed, skipping")
                self._queue.task_done()

                refresh_completed_hash_attempt += 1

                if refresh_completed_hash_attempt >= REFRESH_COMPLETED_HASH_CACHE_MAX_ATTEMPTS:
                    logger.info(f"More than {REFRESH_COMPLETED_HASH_CACHE_MAX_ATTEMPTS} queue items already done, refreshing local completed cache fully")
                    self._completed_hash_cache = await self.backend.get_completed_hashes()
                    refresh_completed_hash_attempt = 0

                continue

            # Check for worker lock
            lock_path = os.path.join(self.workspace_path, "worker_locks", f"output_{work_item.hash}.jsonl")
            bucket, key = parse_s3_path(lock_path)

            try:
                response = await asyncio.to_thread(self.s3_client.head_object, Bucket=bucket, Key=key)

                # Check if lock is stale
                last_modified = response["LastModified"]
                if (datetime.datetime.now(datetime.timezone.utc) - last_modified).total_seconds() > worker_lock_timeout_secs:
                    # Lock is stale, we can take this work
                    logger.debug(f"Found stale lock for {work_item.hash}, taking work item")
                else:
                    # Lock is active, skip this work
                    logger.debug(f"Work item {work_item.hash} is locked by another worker, skipping")
                    self._queue.task_done()
                    continue

            except self.s3_client.exceptions.ClientError:
                # No lock exists, we can take this work
                pass

            # Create our lock file
            try:
                await asyncio.to_thread(self.s3_client.put_object, Bucket=bucket, Key=key, Body=b"")
            except Exception as e:
                logger.warning(f"Failed to create lock file for {work_item.hash}: {e}")
            if await self.backend.is_worker_lock_taken(work_item.hash, worker_lock_timeout_secs):
                logger.debug(f"Work item {work_item.hash} is locked by another worker, skipping")
                self._queue.task_done()
                continue

            # Create lock (overwrites if stale)
            try:
                await self.backend.create_worker_lock(work_item.hash)
            except Exception as e:
                logger.warning(f"Failed to create lock for {work_item.hash}: {e}")
                self._queue.task_done()
                continue

            refresh_completed_hash_attempt = 0
            return work_item

    async def mark_done(self, work_item: WorkItem) -> None:
        """
        Mark a work item as done by removing its lock file.

        Args:
            work_item: The WorkItem to mark as done
        Mark a work item as done by removing its lock file and creating a done flag.
        """
        lock_path = os.path.join(self.workspace_path, "worker_locks", f"output_{work_item.hash}.jsonl")
        bucket, key = parse_s3_path(lock_path)

        try:
            await asyncio.to_thread(self.s3_client.delete_object, Bucket=bucket, Key=key)
        except Exception as e:
            logger.warning(f"Failed to delete lock file for {work_item.hash}: {e}")

        # Create done flag in done_flags_dir
        await self.backend.create_done_flag(work_item.hash)

        # Remove the worker lock
        await self.backend.delete_worker_lock(work_item.hash)
        self._queue.task_done()

    @property
    def size(self) -> int:
        """Get current size of work queue"""
        """Get current size of work queue."""
        return self._queue.qsize()

class LocalBackend(Backend):
    """Local file system backend."""

    def __init__(self, workspace_path: str):
        self.workspace_path = os.path.abspath(workspace_path)
        self._index_path = os.path.join(self.workspace_path, "work_index_list.csv.zstd")
        self._done_flags_dir = os.path.join(self.workspace_path, DONE_FLAGS_DIR)
        self._locks_dir = os.path.join(self.workspace_path, WORKER_LOCKS_DIR)

        os.makedirs(self.workspace_path, exist_ok=True)
        os.makedirs(self._done_flags_dir, exist_ok=True)
        os.makedirs(self._locks_dir, exist_ok=True)

    def _download_zstd_csv_local(self, local_path: str) -> List[str]:
        """
        Read a zstd-compressed CSV from a local path.
        If the file doesn't exist, returns an empty list.
        """
        if not os.path.exists(local_path):
            return []

        with open(local_path, "rb") as f:
            dctx = zstandard.ZstdDecompressor()
            data = dctx.decompress(f.read())
            lines = data.decode("utf-8").splitlines()
            return lines

    def _upload_zstd_csv_local(self, local_path: str, lines: List[str]) -> None:
        """
        Write a zstd-compressed CSV to a local path.
        """
        data = "\n".join(lines).encode("utf-8")
        cctx = zstandard.ZstdCompressor()
        compressed_data = cctx.compress(data)

        os.makedirs(os.path.dirname(local_path), exist_ok=True)
        with open(local_path, "wb") as f:
            f.write(compressed_data)

    async def load_index_lines(self) -> List[str]:
        return await asyncio.to_thread(self._download_zstd_csv_local, self._index_path)

    async def save_index_lines(self, lines: List[str]) -> None:
        await asyncio.to_thread(self._upload_zstd_csv_local, self._index_path, lines)

    async def get_completed_hashes(self) -> Set[str]:
        def _list_completed() -> Set[str]:
            if not os.path.isdir(self._done_flags_dir):
                return set()
            return {
                f[len("done_") : -len(".flag")]
                for f in os.listdir(self._done_flags_dir)
                if f.startswith("done_") and f.endswith(".flag")
            }

        return await asyncio.to_thread(_list_completed)

    def _get_worker_lock_path(self, work_hash: str) -> str:
        """Internal method to get worker lock path."""
        return os.path.join(self._locks_dir, f"worker_{work_hash}.lock")

    def _get_done_flag_path(self, work_hash: str) -> str:
        """Internal method to get done flag path."""
        return os.path.join(self._done_flags_dir, f"done_{work_hash}.flag")

    async def _get_object_mtime(self, path: str) -> Optional[datetime.datetime]:
        """Internal method to get object mtime."""
        def _get_mtime() -> Optional[datetime.datetime]:
            if not os.path.exists(path):
                return None
            return datetime.datetime.fromtimestamp(os.path.getmtime(path), datetime.timezone.utc)

        return await asyncio.to_thread(_get_mtime)

    async def is_worker_lock_taken(self, work_hash: str, worker_lock_timeout_secs: int = 1800) -> bool:
        """Check if a worker lock is taken and not stale."""
        lock_path = self._get_worker_lock_path(work_hash)
        lock_mtime = await self._get_object_mtime(lock_path)

        if not lock_mtime:
            return False

        now = datetime.datetime.now(datetime.timezone.utc)
        return (now - lock_mtime).total_seconds() <= worker_lock_timeout_secs

    async def create_worker_lock(self, work_hash: str) -> None:
        """Create a worker lock for a work hash."""
        lock_path = self._get_worker_lock_path(work_hash)

        def _create() -> None:
            with open(lock_path, "wb"):
                pass

        await asyncio.to_thread(_create)

    async def delete_worker_lock(self, work_hash: str) -> None:
        """Delete the worker lock for a work hash if it exists."""
        lock_path = self._get_worker_lock_path(work_hash)

        def _delete() -> None:
            if os.path.exists(lock_path):
                os.remove(lock_path)

        await asyncio.to_thread(_delete)

    async def is_completed(self, work_hash: str) -> bool:
        """Check if a work item has been completed."""
        done_flag_path = self._get_done_flag_path(work_hash)
        return await self._get_object_mtime(done_flag_path) is not None

    async def create_done_flag(self, work_hash: str) -> None:
        """Create a done flag for a work hash."""
        done_flag_path = self._get_done_flag_path(work_hash)

        def _create() -> None:
            with open(done_flag_path, "wb"):
                pass

        await asyncio.to_thread(_create)

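LocalBackend keeps all coordination state as plain files under the workspace: worker_locks/worker_<hash>.lock and done_flags/done_<hash>.flag. A quick sketch of exercising it directly, based on the methods defined above; the temporary directory and the hash value are illustrative assumptions:

import asyncio
from olmocr.work_queue import LocalBackend

async def demo() -> None:
    backend = LocalBackend("/tmp/olmocr_workspace")            # illustrative path
    h = "0123456789abcdef0123456789abcdef01234567"             # illustrative 40-char hash

    await backend.create_worker_lock(h)
    assert await backend.is_worker_lock_taken(h)               # a fresh lock counts as taken

    await backend.create_done_flag(h)
    await backend.delete_worker_lock(h)
    assert await backend.is_completed(h)
    assert h in await backend.get_completed_hashes()

asyncio.run(demo())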
class S3Backend(Backend):
    """S3 backend."""

    def __init__(self, s3_client: Any, workspace_path: str):
        self.s3_client = s3_client
        self.workspace_path = workspace_path.rstrip("/")
        self._index_path = os.path.join(self.workspace_path, "work_index_list.csv.zstd")
        self._output_glob = os.path.join(self.workspace_path, DONE_FLAGS_DIR, "*.flag")

    async def load_index_lines(self) -> List[str]:
        return await asyncio.to_thread(download_zstd_csv, self.s3_client, self._index_path)

    async def save_index_lines(self, lines: List[str]) -> None:
        await asyncio.to_thread(upload_zstd_csv, self.s3_client, self._index_path, lines)

    async def get_completed_hashes(self) -> Set[str]:
        def _list_completed() -> Set[str]:
            done_work_items = expand_s3_glob(self.s3_client, self._output_glob)
            return {
                os.path.basename(item)[len("done_") : -len(".flag")]
                for item in done_work_items
                if os.path.basename(item).startswith("done_") and os.path.basename(item).endswith(".flag")
            }

        return await asyncio.to_thread(_list_completed)

    def _get_worker_lock_path(self, work_hash: str) -> str:
        """Internal method to get worker lock path."""
        return os.path.join(self.workspace_path, WORKER_LOCKS_DIR, f"worker_{work_hash}.lock")

    def _get_done_flag_path(self, work_hash: str) -> str:
        """Internal method to get done flag path."""
        return os.path.join(self.workspace_path, DONE_FLAGS_DIR, f"done_{work_hash}.flag")

    async def _get_object_mtime(self, path: str) -> Optional[datetime.datetime]:
        """Internal method to get object mtime."""
        bucket, key = parse_s3_path(path)

        def _head_object() -> Optional[datetime.datetime]:
            try:
                response = self.s3_client.head_object(Bucket=bucket, Key=key)
                return response["LastModified"]
            except self.s3_client.exceptions.ClientError as e:
                if e.response["Error"]["Code"] == "404":
                    return None
                raise

        return await asyncio.to_thread(_head_object)

    async def is_worker_lock_taken(self, work_hash: str, worker_lock_timeout_secs: int = 1800) -> bool:
        """Check if a worker lock is taken and not stale."""
        lock_path = self._get_worker_lock_path(work_hash)
        lock_mtime = await self._get_object_mtime(lock_path)

        if not lock_mtime:
            return False

        now = datetime.datetime.now(datetime.timezone.utc)
        return (now - lock_mtime).total_seconds() <= worker_lock_timeout_secs

    async def create_worker_lock(self, work_hash: str) -> None:
        """Create a worker lock for a work hash."""
        lock_path = self._get_worker_lock_path(work_hash)
        bucket, key = parse_s3_path(lock_path)
        await asyncio.to_thread(self.s3_client.put_object, Bucket=bucket, Key=key, Body=b"")

    async def delete_worker_lock(self, work_hash: str) -> None:
        """Delete the worker lock for a work hash if it exists."""
        lock_path = self._get_worker_lock_path(work_hash)
        bucket, key = parse_s3_path(lock_path)
        await asyncio.to_thread(self.s3_client.delete_object, Bucket=bucket, Key=key)

    async def is_completed(self, work_hash: str) -> bool:
        """Check if a work item has been completed."""
        done_flag_path = self._get_done_flag_path(work_hash)
        return await self._get_object_mtime(done_flag_path) is not None

    async def create_done_flag(self, work_hash: str) -> None:
        """Create a done flag for a work hash."""
        done_flag_path = self._get_done_flag_path(work_hash)
        bucket, key = parse_s3_path(done_flag_path)
        await asyncio.to_thread(self.s3_client.put_object, Bucket=bucket, Key=key, Body=b"")
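Because WorkQueue only talks to the Backend interface, other storage layers can be slotted in without touching the queue logic, which is the main payoff of the refactor. A hypothetical in-memory backend, useful only as an illustration and not part of this commit (it assumes Backend is importable from olmocr.work_queue as defined above):

import datetime
from typing import Dict, List, Set
from olmocr.work_queue import Backend  # assumption: Backend is exposed by the module shown above

class InMemoryBackend(Backend):
    """Hypothetical backend that keeps the index, locks, and done flags in plain dicts."""

    def __init__(self) -> None:
        self._index: List[str] = []
        self._locks: Dict[str, datetime.datetime] = {}
        self._done: Set[str] = set()

    async def load_index_lines(self) -> List[str]:
        return list(self._index)

    async def save_index_lines(self, lines: List[str]) -> None:
        self._index = list(lines)

    async def get_completed_hashes(self) -> Set[str]:
        return set(self._done)

    async def is_completed(self, work_hash: str) -> bool:
        return work_hash in self._done

    async def is_worker_lock_taken(self, work_hash: str, worker_lock_timeout_secs: int = 1800) -> bool:
        taken_at = self._locks.get(work_hash)
        if taken_at is None:
            return False
        age = (datetime.datetime.now(datetime.timezone.utc) - taken_at).total_seconds()
        return age <= worker_lock_timeout_secs

    async def create_worker_lock(self, work_hash: str) -> None:
        self._locks[work_hash] = datetime.datetime.now(datetime.timezone.utc)

    async def delete_worker_lock(self, work_hash: str) -> None:
        self._locks.pop(work_hash, None)

    async def create_done_flag(self, work_hash: str) -> None:
        self._done.add(work_hash)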
@@ -64,7 +64,7 @@ class TestImageRotation:
        # Extract the image from the result
        messages = result["messages"]
        content = messages[0]["content"]
        image_url = content[0]["image_url"]["url"]
        image_url = content[1]["image_url"]["url"]
        image_base64 = image_url.split(",")[1]
        result_img = base64_to_image(image_base64)

@@ -88,7 +88,7 @@ class TestImageRotation:
        # Extract the image from the result
        messages = result["messages"]
        content = messages[0]["content"]
        image_url = content[0]["image_url"]["url"]
        image_url = content[1]["image_url"]["url"]
        image_base64 = image_url.split(",")[1]
        result_img = base64_to_image(image_base64)

@@ -113,7 +113,7 @@ class TestImageRotation:
        # Extract the image from the result
        messages = result["messages"]
        content = messages[0]["content"]
        image_url = content[0]["image_url"]["url"]
        image_url = content[1]["image_url"]["url"]
        image_base64 = image_url.split(",")[1]
        result_img = base64_to_image(image_base64)

@@ -138,7 +138,7 @@ class TestImageRotation:
        # Extract the image from the result
        messages = result["messages"]
        content = messages[0]["content"]
        image_url = content[0]["image_url"]["url"]
        image_url = content[1]["image_url"]["url"]
        image_base64 = image_url.split(",")[1]
        result_img = base64_to_image(image_base64)

@@ -178,7 +178,7 @@ class TestImageRotation:
        # Extract the image from the result
        messages = result["messages"]
        content = messages[0]["content"]
        image_url = content[0]["image_url"]["url"]
        image_url = content[1]["image_url"]["url"]
        image_base64 = image_url.split(",")[1]
        result_img = base64_to_image(image_base64)

@@ -1,14 +1,12 @@
import asyncio
import datetime
import hashlib
import unittest
from typing import Dict, List
from unittest.mock import Mock, call, patch
from unittest.mock import Mock, patch

from botocore.exceptions import ClientError

# Import the classes we're testing
from olmocr.work_queue import S3WorkQueue, WorkItem
from olmocr.work_queue import WorkQueue, S3Backend, WorkItem


class TestS3WorkQueue(unittest.TestCase):
@@ -16,7 +14,8 @@ class TestS3WorkQueue(unittest.TestCase):
        """Set up test fixtures before each test method."""
        self.s3_client = Mock()
        self.s3_client.exceptions.ClientError = ClientError
        self.work_queue = S3WorkQueue(self.s3_client, "s3://test-bucket/workspace")
        self.backend = S3Backend(self.s3_client, "s3://test-bucket/workspace")
        self.work_queue = WorkQueue(self.backend)
        self.sample_paths = [
            "s3://test-bucket/data/file1.pdf",
            "s3://test-bucket/data/file2.pdf",
@@ -35,18 +34,18 @@ class TestS3WorkQueue(unittest.TestCase):
        ]

        # Hash should be the same regardless of order
        hash1 = S3WorkQueue._compute_workgroup_hash(paths)
        hash2 = S3WorkQueue._compute_workgroup_hash(reversed(paths))
        hash1 = WorkQueue._compute_workgroup_hash(paths)
        hash2 = WorkQueue._compute_workgroup_hash(reversed(paths))
        self.assertEqual(hash1, hash2)

    def test_init(self):
        """Test initialization of S3WorkQueue"""
        """Test initialization of S3Backend"""
        client = Mock()
        queue = S3WorkQueue(client, "s3://test-bucket/workspace/")
        backend = S3Backend(client, "s3://test-bucket/workspace/")

        self.assertEqual(queue.workspace_path, "s3://test-bucket/workspace")
        self.assertEqual(queue._index_path, "s3://test-bucket/workspace/work_index_list.csv.zstd")
        self.assertEqual(queue._output_glob, "s3://test-bucket/workspace/results/*.jsonl")
        self.assertEqual(backend.workspace_path, "s3://test-bucket/workspace")
        self.assertEqual(backend._index_path, "s3://test-bucket/workspace/work_index_list.csv.zstd")
        self.assertEqual(backend._output_glob, "s3://test-bucket/workspace/done_flags/*.flag")

    def asyncSetUp(self):
        """Set up async test fixtures"""
@@ -87,7 +86,7 @@ class TestS3WorkQueue(unittest.TestCase):

        # Verify format of uploaded lines
        for line in lines:
            parts = line.split(",")
            parts = WorkQueue._decode_csv_row(line)
            self.assertGreaterEqual(len(parts), 2) # Hash + at least one path
            self.assertEqual(len(parts[0]), 40) # SHA1 hash length

@@ -98,8 +97,8 @@ class TestS3WorkQueue(unittest.TestCase):
        new_paths = ["s3://test-bucket/data/new1.pdf"]

        # Create existing index content
        existing_hash = S3WorkQueue._compute_workgroup_hash(existing_paths)
        existing_line = f"{existing_hash},{existing_paths[0]}"
        existing_hash = WorkQueue._compute_workgroup_hash(existing_paths)
        existing_line = WorkQueue._encode_csv_row([existing_hash] + existing_paths)

        with patch("olmocr.work_queue.download_zstd_csv", return_value=[existing_line]):
            with patch("olmocr.work_queue.upload_zstd_csv") as mock_upload:
@@ -115,17 +114,17 @@ class TestS3WorkQueue(unittest.TestCase):
        """Test queue initialization"""
        # Mock work items and completed items
        work_paths = ["s3://test/file1.pdf", "s3://test/file2.pdf"]
        work_hash = S3WorkQueue._compute_workgroup_hash(work_paths)
        work_line = f"{work_hash},{work_paths[0]},{work_paths[1]}"
        work_hash = WorkQueue._compute_workgroup_hash(work_paths)
        work_line = WorkQueue._encode_csv_row([work_hash] + work_paths)

        completed_items = [f"s3://test-bucket/workspace/results/output_{work_hash}.jsonl"]
        completed_items = [f"s3://test-bucket/workspace/done_flags/done_{work_hash}.flag"]

        with patch("olmocr.work_queue.download_zstd_csv", return_value=[work_line]):
            with patch("olmocr.work_queue.expand_s3_glob", return_value=completed_items):
                await self.work_queue.initialize_queue()
                count = await self.work_queue.initialize_queue()

        # Queue should be empty since all work is completed
        self.assertTrue(self.work_queue._queue.empty())
        self.assertEqual(count, 0)

    @async_test
    async def test_is_completed(self):
@@ -134,11 +133,11 @@ class TestS3WorkQueue(unittest.TestCase):

        # Test completed work
        self.s3_client.head_object.return_value = {"LastModified": datetime.datetime.now(datetime.timezone.utc)}
        self.assertTrue(await self.work_queue.is_completed(work_hash))
        self.assertTrue(await self.backend.is_completed(work_hash))

        # Test incomplete work
        self.s3_client.head_object.side_effect = ClientError({"Error": {"Code": "404", "Message": "Not Found"}}, "HeadObject")
        self.assertFalse(await self.work_queue.is_completed(work_hash))
        self.assertFalse(await self.backend.is_completed(work_hash))

    @async_test
    async def test_get_work(self):
@@ -154,8 +153,8 @@ class TestS3WorkQueue(unittest.TestCase):

        # Verify lock file was created
        self.s3_client.put_object.assert_called_once()
        bucket, key = self.s3_client.put_object.call_args[1]["Bucket"], self.s3_client.put_object.call_args[1]["Key"]
        self.assertTrue(key.endswith(f"output_{work_item.hash}.jsonl"))
        key = self.s3_client.put_object.call_args[1]["Key"]
        self.assertTrue(key.endswith(f"worker_{work_item.hash}.lock"))

    @async_test
    async def test_get_work_completed(self):
@@ -209,10 +208,17 @@ class TestS3WorkQueue(unittest.TestCase):

        await self.work_queue.mark_done(work_item)

        # Verify done flag was created and lock file was deleted
        # Check put_object was called for done flag
        put_calls = self.s3_client.put_object.call_args_list
        self.assertEqual(len(put_calls), 1)
        done_flag_key = put_calls[0][1]["Key"]
        self.assertTrue(done_flag_key.endswith(f"done_{work_item.hash}.flag"))

        # Verify lock file was deleted
        self.s3_client.delete_object.assert_called_once()
        bucket, key = self.s3_client.delete_object.call_args[1]["Bucket"], self.s3_client.delete_object.call_args[1]["Key"]
        self.assertTrue(key.endswith(f"output_{work_item.hash}.jsonl"))
        key = self.s3_client.delete_object.call_args[1]["Key"]
        self.assertTrue(key.endswith(f"worker_{work_item.hash}.lock"))

    @async_test
    async def test_paths_with_commas(self):
@@ -235,8 +241,11 @@ class TestS3WorkQueue(unittest.TestCase):
        # Initialize a fresh queue from these lines
        await self.work_queue.initialize_queue()

        # Mock ClientError for head_object (file doesn't exist)
        self.s3_client.head_object.side_effect = ClientError({"Error": {"Code": "404", "Message": "Not Found"}}, "HeadObject")
        # Mock ClientError for head_object (file doesn't exist) - need to handle multiple calls
        self.s3_client.head_object.side_effect = [
            ClientError({"Error": {"Code": "404", "Message": "Not Found"}}, "HeadObject"),  # done flag check
            ClientError({"Error": {"Code": "404", "Message": "Not Found"}}, "HeadObject"),  # worker lock check
        ]

        # Get a work item
        work_item = await self.work_queue.get_work()

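The side_effect list above relies on standard unittest.mock behavior: an iterable side_effect is consumed one element per call, and exception instances in the list are raised rather than returned, so the first head_object call (the done-flag check) and the second (the worker-lock check) each see a 404. A tiny standalone illustration:

from unittest.mock import Mock

m = Mock(side_effect=[ValueError("first call"), "second call result"])
try:
    m()
except ValueError:
    pass                                 # the first element is an exception, so it is raised
assert m() == "second call result"       # the second element is returned normally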