From 96ae2dd49bad104dc2ac70f7b9c99e88512e5368 Mon Sep 17 00:00:00 2001 From: Jake Poznanski Date: Mon, 27 Jan 2025 20:45:28 +0000 Subject: [PATCH] Refactoring --- .gitignore | 1 + olmocr/pipeline.py | 11 +- olmocr/s3_queue.py | 302 ----------------- olmocr/work_queue.py | 640 ++++++++++++++++++++++++++++++++++++ tests/test_birrpipeline.py | 278 ---------------- tests/test_s3_work_queue.py | 28 +- 6 files changed, 662 insertions(+), 598 deletions(-) delete mode 100644 olmocr/s3_queue.py create mode 100644 olmocr/work_queue.py delete mode 100644 tests/test_birrpipeline.py diff --git a/.gitignore b/.gitignore index e0ffe86..3eec6d3 100644 --- a/.gitignore +++ b/.gitignore @@ -11,6 +11,7 @@ sample200_vllm/* sample200_sglang/* pdelfin_testset/* /*.html +scoreelo.csv debug.log birrpipeline-debug.log beakerpipeline-debug.log diff --git a/olmocr/pipeline.py b/olmocr/pipeline.py index ece4473..10a2741 100644 --- a/olmocr/pipeline.py +++ b/olmocr/pipeline.py @@ -31,7 +31,7 @@ from typing import Optional, Tuple, List, Dict, Set from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed from concurrent.futures.process import BrokenProcessPool -from olmocr.s3_queue import S3WorkQueue, WorkItem +from olmocr.work_queue import S3WorkQueue, LocalWorkQueue from olmocr.s3_utils import expand_s3_glob, get_s3_bytes, get_s3_bytes_with_backoff, parse_s3_path, download_zstd_csv, upload_zstd_csv, download_directory from olmocr.data.renderpdf import render_pdf_to_base64png from olmocr.filter.filter import PdfFilter, Language @@ -420,7 +420,7 @@ async def worker(args, work_queue: S3WorkQueue, semaphore, worker_id): try: async with asyncio.TaskGroup() as tg: - dolma_tasks = [tg.create_task(process_pdf(args, worker_id, pdf)) for pdf in work_item.s3_work_paths] + dolma_tasks = [tg.create_task(process_pdf(args, worker_id, pdf)) for pdf in work_item.work_paths] logger.info(f"Created all tasks for {work_item.hash}") logger.info(f"Finished TaskGroup for worker on {work_item.hash}") @@ -834,7 +834,7 @@ def print_stats(args): async def main(): parser = argparse.ArgumentParser(description='Manager for running millions of PDFs through a batch inference pipeline') - parser.add_argument('workspace', help='The S3 path where work will be done e.g., s3://bucket/prefix/') + parser.add_argument('workspace', help='The filesystem path where work will be stored, can be a local folder, or an s3 path if coordinating work with many workers, s3://bucket/prefix/ ') parser.add_argument('--pdfs', help='Path to add pdfs stored in s3 to the workspace, can be a glob path s3://bucket/prefix/*.pdf or path to file containing list of pdf paths', default=None) parser.add_argument('--workspace_profile', help='S3 configuration profile for accessing the workspace', default=None) parser.add_argument('--pdf_profile', help='S3 configuration profile for accessing the raw pdf documents', default=None) @@ -891,7 +891,10 @@ async def main(): check_poppler_version() # Create work queue - work_queue = S3WorkQueue(workspace_s3, args.workspace) + if args.workspace.startswith("s3://"): + work_queue = S3WorkQueue(workspace_s3, args.workspace) + else: + work_queue = LocalWorkQueue(args.workspace) if args.pdfs: logger.info("Got --pdfs argument, going to add to the work queue") diff --git a/olmocr/s3_queue.py b/olmocr/s3_queue.py deleted file mode 100644 index a1295e9..0000000 --- a/olmocr/s3_queue.py +++ /dev/null @@ -1,302 +0,0 @@ -import os -import random -import logging -import hashlib -import tempfile -import datetime -from typing import 
Optional, Tuple, List, Dict, Set -from dataclasses import dataclass -import asyncio -from functools import partial - -from olmocr.s3_utils import ( - expand_s3_glob, - download_zstd_csv, - upload_zstd_csv, - parse_s3_path -) - - -logger = logging.getLogger(__name__) - -@dataclass -class WorkItem: - """Represents a single work item in the queue""" - hash: str - s3_work_paths: List[str] - -class S3WorkQueue: - """ - Manages a work queue stored in S3 that coordinates work across multiple workers. - The queue maintains a list of work items, where each work item is a group of s3 paths - that should be processed together. - - Each work item gets a hash, and completed work items will have their results - stored in s3://workspace_path/results/output_[hash].jsonl - - This is the ground source of truth about which work items are done. - - When a worker takes an item off the queue, it will write an empty s3 file to - s3://workspace_path/worker_locks/output_[hash].jsonl - - The queue gets randomized on each worker, so workers pull random work items to operate on. - As you pull an item, we will check to see if it has been completed. If yes, - then it will immediately fetch the next item. If a lock file was created within a configurable - timeout (30 mins by default), then that work item is also skipped. - - The lock will will be deleted once the worker is done with that item. - """ - def __init__(self, s3_client, workspace_path: str): - """ - Initialize the work queue. - - Args: - s3_client: Boto3 S3 client to use for operations - workspace_path: S3 path where work queue and results are stored - """ - self.s3_client = s3_client - self.workspace_path = workspace_path.rstrip('/') - - self._index_path = os.path.join(self.workspace_path, "work_index_list.csv.zstd") - self._output_glob = os.path.join(self.workspace_path, "results", "*.jsonl") - self._queue = asyncio.Queue() - - @staticmethod - def _compute_workgroup_hash(s3_work_paths: List[str]) -> str: - """ - Compute a deterministic hash for a group of paths. - - Args: - s3_work_paths: List of S3 paths - - Returns: - SHA1 hash of the sorted paths - """ - sha1 = hashlib.sha1() - for path in sorted(s3_work_paths): - sha1.update(path.encode('utf-8')) - return sha1.hexdigest() - - async def populate_queue(self, s3_work_paths: list[str], items_per_group: int) -> None: - """ - Add new items to the work queue. 
- - Args: - s3_work_paths: Each individual s3 path that we will process over - items_per_group: Number of items to group together in a single work item - """ - all_paths = set(s3_work_paths) - logger.info(f"Found {len(all_paths):,} total paths") - - # Load existing work groups - existing_lines = await asyncio.to_thread(download_zstd_csv, self.s3_client, self._index_path) - existing_groups = {} - for line in existing_lines: - if line.strip(): - parts = line.strip().split(",") - group_hash = parts[0] - group_paths = parts[1:] - existing_groups[group_hash] = group_paths - - existing_path_set = {path for paths in existing_groups.values() for path in paths} - - # Find new paths to process - new_paths = all_paths - existing_path_set - logger.info(f"{len(new_paths):,} new paths to add to the workspace") - - if not new_paths: - return - - # Create new work groups - new_groups = [] - current_group = [] - for path in sorted(new_paths): - current_group.append(path) - if len(current_group) == items_per_group: - group_hash = self._compute_workgroup_hash(current_group) - new_groups.append((group_hash, current_group)) - current_group = [] - if current_group: - group_hash = self._compute_workgroup_hash(current_group) - new_groups.append((group_hash, current_group)) - - logger.info(f"Created {len(new_groups):,} new work groups") - - # Combine and save updated work groups - combined_groups = existing_groups.copy() - for group_hash, group_paths in new_groups: - combined_groups[group_hash] = group_paths - - combined_lines = [ - ",".join([group_hash] + group_paths) - for group_hash, group_paths in combined_groups.items() - ] - - if new_groups: - await asyncio.to_thread( - upload_zstd_csv, - self.s3_client, - self._index_path, - combined_lines - ) - - async def initialize_queue(self) -> None: - """ - Load the work queue from S3 and initialize it for processing. - Removes already completed work items and randomizes the order. - """ - # Load work items and completed items in parallel - download_task = asyncio.to_thread( - download_zstd_csv, - self.s3_client, - self._index_path - ) - expand_task = asyncio.to_thread( - expand_s3_glob, - self.s3_client, - self._output_glob - ) - - work_queue_lines, done_work_items = await asyncio.gather(download_task, expand_task) - - # Process work queue lines - work_queue = { - parts[0]: parts[1:] - for line in work_queue_lines - if (parts := line.strip().split(",")) and line.strip() - } - - # Get set of completed work hashes - done_work_hashes = { - os.path.basename(item)[len('output_'):-len('.jsonl')] - for item in done_work_items - if os.path.basename(item).startswith('output_') - and os.path.basename(item).endswith('.jsonl') - } - - # Find remaining work and shuffle - remaining_work_hashes = set(work_queue) - done_work_hashes - remaining_items = [ - WorkItem(hash=hash_, s3_work_paths=work_queue[hash_]) - for hash_ in remaining_work_hashes - ] - random.shuffle(remaining_items) - - # Initialize queue - self._queue = asyncio.Queue() - for item in remaining_items: - await self._queue.put(item) - - logger.info(f"Initialized queue with {self._queue.qsize()} work items") - - async def is_completed(self, work_hash: str) -> bool: - """ - Check if a work item has been completed. 
- - Args: - work_hash: Hash of the work item to check - - Returns: - True if the work is completed, False otherwise - """ - output_s3_path = os.path.join(self.workspace_path, "results", f"output_{work_hash}.jsonl") - bucket, key = parse_s3_path(output_s3_path) - - try: - await asyncio.to_thread( - self.s3_client.head_object, - Bucket=bucket, - Key=key - ) - return True - except self.s3_client.exceptions.ClientError: - return False - - async def get_work(self, worker_lock_timeout_secs: int = 1800) -> Optional[WorkItem]: - """ - Get the next available work item that isn't completed or locked. - - Args: - worker_lock_timeout_secs: Number of seconds before considering a worker lock stale (default 30 mins) - - Returns: - WorkItem if work is available, None if queue is empty - """ - while True: - try: - work_item = self._queue.get_nowait() - except asyncio.QueueEmpty: - return None - - # Check if work is already completed - if await self.is_completed(work_item.hash): - logger.debug(f"Work item {work_item.hash} already completed, skipping") - self._queue.task_done() - continue - - # Check for worker lock - lock_path = os.path.join(self.workspace_path, "worker_locks", f"output_{work_item.hash}.jsonl") - bucket, key = parse_s3_path(lock_path) - - try: - response = await asyncio.to_thread( - self.s3_client.head_object, - Bucket=bucket, - Key=key - ) - - # Check if lock is stale - last_modified = response['LastModified'] - if (datetime.datetime.now(datetime.timezone.utc) - last_modified).total_seconds() > worker_lock_timeout_secs: - # Lock is stale, we can take this work - logger.debug(f"Found stale lock for {work_item.hash}, taking work item") - else: - # Lock is active, skip this work - logger.debug(f"Work item {work_item.hash} is locked by another worker, skipping") - self._queue.task_done() - continue - - except self.s3_client.exceptions.ClientError: - # No lock exists, we can take this work - pass - - # Create our lock file - try: - await asyncio.to_thread( - self.s3_client.put_object, - Bucket=bucket, - Key=key, - Body=b'' - ) - except Exception as e: - logger.warning(f"Failed to create lock file for {work_item.hash}: {e}") - self._queue.task_done() - continue - - return work_item - - async def mark_done(self, work_item: WorkItem) -> None: - """ - Mark a work item as done by removing its lock file. - - Args: - work_item: The WorkItem to mark as done - """ - lock_path = os.path.join(self.workspace_path, "worker_locks", f"output_{work_item.hash}.jsonl") - bucket, key = parse_s3_path(lock_path) - - try: - await asyncio.to_thread( - self.s3_client.delete_object, - Bucket=bucket, - Key=key - ) - except Exception as e: - logger.warning(f"Failed to delete lock file for {work_item.hash}: {e}") - - self._queue.task_done() - - @property - def size(self) -> int: - """Get current size of work queue""" - return self._queue.qsize() \ No newline at end of file diff --git a/olmocr/work_queue.py b/olmocr/work_queue.py new file mode 100644 index 0000000..e0d00d2 --- /dev/null +++ b/olmocr/work_queue.py @@ -0,0 +1,640 @@ +import os +import random +import logging +import hashlib +import tempfile +import datetime +import asyncio +import abc +from typing import Optional, List, Dict, Set +from dataclasses import dataclass + +from functools import partial + +logger = logging.getLogger(__name__) + +@dataclass +class WorkItem: + """Represents a single work item in the queue""" + hash: str + work_paths: List[str] + + +class WorkQueue(abc.ABC): + """ + Base class defining the interface for a work queue. 
+ """ + + @abc.abstractmethod + async def populate_queue(self, work_paths: List[str], items_per_group: int) -> None: + """ + Add new items to the work queue. The specifics will vary depending on + whether this is a local or S3-backed queue. + + Args: + work_paths: Each individual path that we will process over + items_per_group: Number of items to group together in a single work item + """ + pass + + @abc.abstractmethod + async def initialize_queue(self) -> None: + """ + Load the work queue from the relevant store (local or remote) + and initialize it for processing. + + For example, this might remove already completed work items and randomize + the order before adding them to an internal queue. + """ + pass + + @abc.abstractmethod + async def is_completed(self, work_hash: str) -> bool: + """ + Check if a work item has been completed. + + Args: + work_hash: Hash of the work item to check + + Returns: + True if the work is completed, False otherwise + """ + pass + + @abc.abstractmethod + async def get_work(self, worker_lock_timeout_secs: int = 1800) -> Optional[WorkItem]: + """ + Get the next available work item that isn't completed or locked. + + Args: + worker_lock_timeout_secs: Number of seconds before considering + a worker lock stale (default 30 mins) + + Returns: + WorkItem if work is available, None if queue is empty + """ + pass + + @abc.abstractmethod + async def mark_done(self, work_item: WorkItem) -> None: + """ + Mark a work item as done by removing its lock file + or performing any other cleanup. + + Args: + work_item: The WorkItem to mark as done + """ + pass + + @property + @abc.abstractmethod + def size(self) -> int: + """Get current size of work queue""" + pass + + @staticmethod + def _compute_workgroup_hash(s3_work_paths: List[str]) -> str: + """ + Compute a deterministic hash for a group of paths. + + Args: + s3_work_paths: List of paths (local or S3) + + Returns: + SHA1 hash of the sorted paths + """ + sha1 = hashlib.sha1() + for path in sorted(s3_work_paths): + sha1.update(path.encode('utf-8')) + return sha1.hexdigest() + + +# -------------------------------------------------------------------------------------- +# Local Helpers for reading/writing the index CSV (compressed with zstd) to disk +# -------------------------------------------------------------------------------------- + +try: + import zstandard +except ImportError: + zstandard = None + +def download_zstd_csv_local(local_path: str) -> List[str]: + """ + Download a zstd-compressed CSV from a local path. + If the file doesn't exist, returns an empty list. + """ + if not os.path.exists(local_path): + return [] + + if not zstandard: + raise RuntimeError("zstandard package is required for local zstd CSV operations.") + + with open(local_path, 'rb') as f: + dctx = zstandard.ZstdDecompressor() + data = dctx.decompress(f.read()) + lines = data.decode('utf-8').splitlines() + return lines + +def upload_zstd_csv_local(local_path: str, lines: List[str]) -> None: + """ + Upload a zstd-compressed CSV to a local path. 
+ """ + if not zstandard: + raise RuntimeError("zstandard package is required for local zstd CSV operations.") + + data = "\n".join(lines).encode('utf-8') + cctx = zstandard.ZstdCompressor() + compressed_data = cctx.compress(data) + + # Ensure parent directories exist + os.makedirs(os.path.dirname(local_path), exist_ok=True) + + with open(local_path, 'wb') as f: + f.write(compressed_data) + + +# -------------------------------------------------------------------------------------- +# LocalWorkQueue Implementation +# -------------------------------------------------------------------------------------- + +class LocalWorkQueue(WorkQueue): + """ + A local in-memory and on-disk WorkQueue implementation, which uses + a local workspace directory to store the queue index, lock files, + and completed results for persistent resumption across process restarts. + """ + + def __init__(self, workspace_path: str): + """ + Initialize the local work queue. + + Args: + workspace_path: Local directory path where the queue index, + results, and locks are stored. + """ + self.workspace_path = os.path.abspath(workspace_path) + os.makedirs(self.workspace_path, exist_ok=True) + + # Local index file (compressed) + self._index_path = os.path.join(self.workspace_path, "work_index_list.csv.zstd") + + # Output directory for completed tasks + self._results_dir = os.path.join(self.workspace_path, "results") + os.makedirs(self._results_dir, exist_ok=True) + + # Directory for lock files + self._locks_dir = os.path.join(self.workspace_path, "worker_locks") + os.makedirs(self._locks_dir, exist_ok=True) + + # Internal queue + self._queue = asyncio.Queue() + + async def populate_queue(self, s3_work_paths: List[str], items_per_group: int) -> None: + """ + Add new items to the work queue (local version). 
+
+        Args:
+            s3_work_paths: Each individual path (local in this context)
+                           that we will process over
+            items_per_group: Number of items to group together in a single work item
+        """
+        # Treat them as local paths, but keep variable name for consistency
+        all_paths = set(s3_work_paths)
+        logger.info(f"Found {len(all_paths):,} total paths")
+
+        # Load existing work groups from local index
+        existing_lines = await asyncio.to_thread(download_zstd_csv_local, self._index_path)
+        existing_groups = {}
+        for line in existing_lines:
+            if line.strip():
+                parts = line.strip().split(",")
+                group_hash = parts[0]
+                group_paths = parts[1:]
+                existing_groups[group_hash] = group_paths
+
+        existing_path_set = {p for paths in existing_groups.values() for p in paths}
+        new_paths = all_paths - existing_path_set
+        logger.info(f"{len(new_paths):,} new paths to add to the workspace")
+
+        if not new_paths:
+            return
+
+        # Create new work groups
+        new_groups = []
+        current_group = []
+        for path in sorted(new_paths):
+            current_group.append(path)
+            if len(current_group) == items_per_group:
+                group_hash = self._compute_workgroup_hash(current_group)
+                new_groups.append((group_hash, current_group))
+                current_group = []
+        if current_group:
+            group_hash = self._compute_workgroup_hash(current_group)
+            new_groups.append((group_hash, current_group))
+
+        logger.info(f"Created {len(new_groups):,} new work groups")
+
+        # Combine and save updated work groups
+        combined_groups = existing_groups.copy()
+        for group_hash, group_paths in new_groups:
+            combined_groups[group_hash] = group_paths
+
+        combined_lines = [
+            ",".join([group_hash] + group_paths)
+            for group_hash, group_paths in combined_groups.items()
+        ]
+
+        if new_groups:
+            # Write the combined data back to disk in zstd CSV format
+            await asyncio.to_thread(upload_zstd_csv_local, self._index_path, combined_lines)
+
+    async def initialize_queue(self) -> None:
+        """
+        Load the work queue from the local index file and initialize it for processing.
+        Removes already completed work items and randomizes the order.
+        """
+        # 1) Read the index
+        work_queue_lines = await asyncio.to_thread(download_zstd_csv_local, self._index_path)
+        work_queue = {
+            parts[0]: parts[1:]
+            for line in work_queue_lines
+            if (parts := line.strip().split(",")) and line.strip()
+        }
+
+        # 2) Determine which items are completed by scanning local results/*.jsonl
+        if not os.path.isdir(self._results_dir):
+            os.makedirs(self._results_dir, exist_ok=True)
+        done_work_items = [
+            f for f in os.listdir(self._results_dir)
+            if f.startswith("output_") and f.endswith(".jsonl")
+        ]
+        done_work_hashes = {
+            fn[len('output_'):-len('.jsonl')]
+            for fn in done_work_items
+        }
+
+        # 3) Filter out completed items
+        remaining_work_hashes = set(work_queue) - done_work_hashes
+        remaining_items = [
+            WorkItem(hash=hash_, work_paths=work_queue[hash_])
+            for hash_ in remaining_work_hashes
+        ]
+        random.shuffle(remaining_items)
+
+        # 4) Initialize our in-memory queue
+        self._queue = asyncio.Queue()
+        for item in remaining_items:
+            await self._queue.put(item)
+
+        logger.info(f"Initialized local queue with {self._queue.qsize()} work items")
+
+    async def is_completed(self, work_hash: str) -> bool:
+        """
+        Check if a work item has been completed locally by seeing if
+        output_{work_hash}.jsonl is present in the results directory.
+ + Args: + work_hash: Hash of the work item to check + """ + output_file = os.path.join(self._results_dir, f"output_{work_hash}.jsonl") + return os.path.exists(output_file) + + async def get_work(self, worker_lock_timeout_secs: int = 1800) -> Optional[WorkItem]: + """ + Get the next available work item that isn't completed or locked. + + Args: + worker_lock_timeout_secs: Number of seconds before considering + a worker lock stale (default 30 mins) + + Returns: + WorkItem if work is available, None if queue is empty + """ + while True: + try: + work_item = self._queue.get_nowait() + except asyncio.QueueEmpty: + return None + + # Check if work is already completed + if await self.is_completed(work_item.hash): + logger.debug(f"Work item {work_item.hash} already completed, skipping") + self._queue.task_done() + continue + + # Check for worker lock + lock_file = os.path.join(self._locks_dir, f"output_{work_item.hash}.jsonl") + if os.path.exists(lock_file): + # Check modification time + mtime = datetime.datetime.fromtimestamp(os.path.getmtime(lock_file), datetime.timezone.utc) + if (datetime.datetime.now(datetime.timezone.utc) - mtime).total_seconds() > worker_lock_timeout_secs: + # Lock is stale, we can take this work + logger.debug(f"Found stale lock for {work_item.hash}, taking work item") + else: + # Lock is active, skip this work + logger.debug(f"Work item {work_item.hash} is locked by another worker, skipping") + self._queue.task_done() + continue + + # Create our lock file (touch an empty file) + try: + with open(lock_file, "wb") as f: + f.write(b"") + except Exception as e: + logger.warning(f"Failed to create lock file for {work_item.hash}: {e}") + self._queue.task_done() + continue + + return work_item + + async def mark_done(self, work_item: WorkItem) -> None: + """ + Mark a work item as done by removing its lock file. + + Args: + work_item: The WorkItem to mark as done + """ + lock_file = os.path.join(self._locks_dir, f"output_{work_item.hash}.jsonl") + if os.path.exists(lock_file): + try: + os.remove(lock_file) + except Exception as e: + logger.warning(f"Failed to delete lock file for {work_item.hash}: {e}") + self._queue.task_done() + + @property + def size(self) -> int: + """Get current size of local work queue""" + return self._queue.qsize() + + +# -------------------------------------------------------------------------------------- +# S3WorkQueue Implementation (Preserves Original Comments) +# -------------------------------------------------------------------------------------- + +from olmocr.s3_utils import ( + expand_s3_glob, + download_zstd_csv, + upload_zstd_csv, + parse_s3_path +) + +class S3WorkQueue(WorkQueue): + """ + Manages a work queue stored in S3 that coordinates work across multiple workers. + The queue maintains a list of work items, where each work item is a group of s3 paths + that should be processed together. + + Each work item gets a hash, and completed work items will have their results + stored in s3://workspace_path/results/output_[hash].jsonl + + This is the ground source of truth about which work items are done. + + When a worker takes an item off the queue, it will write an empty s3 file to + s3://workspace_path/worker_locks/output_[hash].jsonl + + The queue gets randomized on each worker, so workers pull random work items to operate on. + As you pull an item, we will check to see if it has been completed. If yes, + then it will immediately fetch the next item. 
If a lock file was created within a configurable
+    timeout (30 mins by default), then that work item is also skipped.
+
+    The lock will be deleted once the worker is done with that item.
+    """
+    def __init__(self, s3_client, workspace_path: str):
+        """
+        Initialize the work queue.
+
+        Args:
+            s3_client: Boto3 S3 client to use for operations
+            workspace_path: S3 path where work queue and results are stored
+        """
+        self.s3_client = s3_client
+        self.workspace_path = workspace_path.rstrip('/')
+
+        self._index_path = os.path.join(self.workspace_path, "work_index_list.csv.zstd")
+        self._output_glob = os.path.join(self.workspace_path, "results", "*.jsonl")
+        self._queue = asyncio.Queue()
+
+    async def populate_queue(self, s3_work_paths: List[str], items_per_group: int) -> None:
+        """
+        Add new items to the work queue.
+
+        Args:
+            s3_work_paths: Each individual s3 path that we will process over
+            items_per_group: Number of items to group together in a single work item
+        """
+        all_paths = set(s3_work_paths)
+        logger.info(f"Found {len(all_paths):,} total paths")
+
+        # Load existing work groups
+        existing_lines = await asyncio.to_thread(download_zstd_csv, self.s3_client, self._index_path)
+        existing_groups = {}
+        for line in existing_lines:
+            if line.strip():
+                parts = line.strip().split(",")
+                group_hash = parts[0]
+                group_paths = parts[1:]
+                existing_groups[group_hash] = group_paths
+
+        existing_path_set = {path for paths in existing_groups.values() for path in paths}
+
+        # Find new paths to process
+        new_paths = all_paths - existing_path_set
+        logger.info(f"{len(new_paths):,} new paths to add to the workspace")
+
+        if not new_paths:
+            return
+
+        # Create new work groups
+        new_groups = []
+        current_group = []
+        for path in sorted(new_paths):
+            current_group.append(path)
+            if len(current_group) == items_per_group:
+                group_hash = self._compute_workgroup_hash(current_group)
+                new_groups.append((group_hash, current_group))
+                current_group = []
+        if current_group:
+            group_hash = self._compute_workgroup_hash(current_group)
+            new_groups.append((group_hash, current_group))
+
+        logger.info(f"Created {len(new_groups):,} new work groups")
+
+        # Combine and save updated work groups
+        combined_groups = existing_groups.copy()
+        for group_hash, group_paths in new_groups:
+            combined_groups[group_hash] = group_paths
+
+        combined_lines = [
+            ",".join([group_hash] + group_paths)
+            for group_hash, group_paths in combined_groups.items()
+        ]
+
+        if new_groups:
+            await asyncio.to_thread(
+                upload_zstd_csv,
+                self.s3_client,
+                self._index_path,
+                combined_lines
+            )
+
+    async def initialize_queue(self) -> None:
+        """
+        Load the work queue from S3 and initialize it for processing.
+        Removes already completed work items and randomizes the order.
+ """ + # Load work items and completed items in parallel + download_task = asyncio.to_thread( + download_zstd_csv, + self.s3_client, + self._index_path + ) + expand_task = asyncio.to_thread( + expand_s3_glob, + self.s3_client, + self._output_glob + ) + + work_queue_lines, done_work_items = await asyncio.gather(download_task, expand_task) + + # Process work queue lines + work_queue = { + parts[0]: parts[1:] + for line in work_queue_lines + if (parts := line.strip().split(",")) and line.strip() + } + + # Get set of completed work hashes + done_work_hashes = { + os.path.basename(item)[len('output_'):-len('.jsonl')] + for item in done_work_items + if os.path.basename(item).startswith('output_') + and os.path.basename(item).endswith('.jsonl') + } + + # Find remaining work and shuffle + remaining_work_hashes = set(work_queue) - done_work_hashes + remaining_items = [ + WorkItem(hash=hash_, s3_work_paths=work_queue[hash_]) + for hash_ in remaining_work_hashes + ] + random.shuffle(remaining_items) + + # Initialize queue + self._queue = asyncio.Queue() + for item in remaining_items: + await self._queue.put(item) + + logger.info(f"Initialized queue with {self._queue.qsize()} work items") + + async def is_completed(self, work_hash: str) -> bool: + """ + Check if a work item has been completed. + + Args: + work_hash: Hash of the work item to check + + Returns: + True if the work is completed, False otherwise + """ + output_s3_path = os.path.join(self.workspace_path, "results", f"output_{work_hash}.jsonl") + bucket, key = parse_s3_path(output_s3_path) + + try: + await asyncio.to_thread( + self.s3_client.head_object, + Bucket=bucket, + Key=key + ) + return True + except self.s3_client.exceptions.ClientError: + return False + + async def get_work(self, worker_lock_timeout_secs: int = 1800) -> Optional[WorkItem]: + """ + Get the next available work item that isn't completed or locked. 
+ + Args: + worker_lock_timeout_secs: Number of seconds before considering a worker lock stale (default 30 mins) + + Returns: + WorkItem if work is available, None if queue is empty + """ + while True: + try: + work_item = self._queue.get_nowait() + except asyncio.QueueEmpty: + return None + + # Check if work is already completed + if await self.is_completed(work_item.hash): + logger.debug(f"Work item {work_item.hash} already completed, skipping") + self._queue.task_done() + continue + + # Check for worker lock + lock_path = os.path.join(self.workspace_path, "worker_locks", f"output_{work_item.hash}.jsonl") + bucket, key = parse_s3_path(lock_path) + + try: + response = await asyncio.to_thread( + self.s3_client.head_object, + Bucket=bucket, + Key=key + ) + + # Check if lock is stale + last_modified = response['LastModified'] + if (datetime.datetime.now(datetime.timezone.utc) - last_modified).total_seconds() > worker_lock_timeout_secs: + # Lock is stale, we can take this work + logger.debug(f"Found stale lock for {work_item.hash}, taking work item") + else: + # Lock is active, skip this work + logger.debug(f"Work item {work_item.hash} is locked by another worker, skipping") + self._queue.task_done() + continue + + except self.s3_client.exceptions.ClientError: + # No lock exists, we can take this work + pass + + # Create our lock file + try: + await asyncio.to_thread( + self.s3_client.put_object, + Bucket=bucket, + Key=key, + Body=b'' + ) + except Exception as e: + logger.warning(f"Failed to create lock file for {work_item.hash}: {e}") + self._queue.task_done() + continue + + return work_item + + async def mark_done(self, work_item: WorkItem) -> None: + """ + Mark a work item as done by removing its lock file. + + Args: + work_item: The WorkItem to mark as done + """ + lock_path = os.path.join(self.workspace_path, "worker_locks", f"output_{work_item.hash}.jsonl") + bucket, key = parse_s3_path(lock_path) + + try: + await asyncio.to_thread( + self.s3_client.delete_object, + Bucket=bucket, + Key=key + ) + except Exception as e: + logger.warning(f"Failed to delete lock file for {work_item.hash}: {e}") + + self._queue.task_done() + + @property + def size(self) -> int: + """Get current size of work queue""" + return self._queue.qsize() diff --git a/tests/test_birrpipeline.py b/tests/test_birrpipeline.py deleted file mode 100644 index 03c56f9..0000000 --- a/tests/test_birrpipeline.py +++ /dev/null @@ -1,278 +0,0 @@ -import unittest -from unittest.mock import MagicMock, patch -import hashlib -import json -import os -import base64 -from PIL import Image - -# Adjust the import path to match where your code resides -from olmocr.birrpipeline import build_dolma_doc, DatabaseManager, build_finetuning_prompt, build_page_query - -class TestBuildDolmaDoc(unittest.TestCase): - @patch('olmocr.birrpipeline.DatabaseManager') - @patch('olmocr.birrpipeline.get_s3_bytes') - def test_build_dolma_doc_with_multiple_page_entries(self, mock_get_s3_bytes, mock_DatabaseManager): - # Mock DatabaseManager instance - mock_db_instance = MagicMock() - mock_DatabaseManager.return_value = mock_db_instance - - # Define the PDF record - pdf_s3_path = 's3://bucket/pdf/test.pdf' - pdf = DatabaseManager.PDFRecord(s3_path=pdf_s3_path, num_pages=1, status='pending') - - # Create multiple BatchInferenceRecord entries for page_num=1 - entry_a = DatabaseManager.BatchInferenceRecord( - inference_s3_path='s3://bucket/inference/output1.jsonl', - pdf_s3_path=pdf_s3_path, - page_num=1, - round=0, - start_index=0, - length=100, - 
finish_reason='stop', - error=None - ) - - entry_b = DatabaseManager.BatchInferenceRecord( - inference_s3_path='s3://bucket/inference/output2.jsonl', - pdf_s3_path=pdf_s3_path, - page_num=1, - round=0, - start_index=0, - length=100, - finish_reason='stop', - error=None - ) - - entry_c = DatabaseManager.BatchInferenceRecord( - inference_s3_path='s3://bucket/inference/output3.jsonl', - pdf_s3_path=pdf_s3_path, - page_num=1, - round=0, - start_index=0, - length=100, - finish_reason='stop', - error=None - ) - - entry_d = DatabaseManager.BatchInferenceRecord( - inference_s3_path='s3://bucket/inference/output4.jsonl', - pdf_s3_path=pdf_s3_path, - page_num=1, - round=0, - start_index=0, - length=100, - finish_reason='stop', - error=None - ) - - # Set up mock_db_instance.get_index_entries to return all entries - mock_db_instance.get_index_entries.return_value = [entry_a, entry_b, entry_c, entry_d] - - # Define get_s3_bytes side effect function - def get_s3_bytes_side_effect(s3_client, s3_path, start_index=None, end_index=None): - if s3_path == 's3://bucket/inference/output1.jsonl': - inner_data = { - "primary_language": "en", - "is_rotation_valid": True, - "rotation_correction": 0, - "is_table": False, - "is_diagram": False, - "natural_text": "Short Text" - } - data = { - "custom_id": f"{pdf_s3_path}-1", - "outputs": [{"text": json.dumps(inner_data)}], - "round": 0 - } - elif s3_path == 's3://bucket/inference/output2.jsonl': - inner_data = { - "primary_language": "en", - "is_rotation_valid": False, - "rotation_correction": 90, - "is_table": True, - "is_diagram": False, - "natural_text": "Very Long Text Here that is longer" - } - data = { - "custom_id": f"{pdf_s3_path}-1", - "outputs": [{"text": json.dumps(inner_data)}], - "round": 0 - } - elif s3_path == 's3://bucket/inference/output3.jsonl': - inner_data = { - "primary_language": "en", - "is_rotation_valid": True, - "rotation_correction": 0, - "is_table": False, - "is_diagram": True, - "natural_text": "Medium Length Text" - } - data = { - "custom_id": f"{pdf_s3_path}-1", - "outputs": [{"text": json.dumps(inner_data)}], - "round": 0 - } - elif s3_path == 's3://bucket/inference/output4.jsonl': - inner_data = { - "primary_language": "en", - "is_rotation_valid": True, - "rotation_correction": 0, - "is_table": False, - "is_diagram": False, - "natural_text": "The Longest Correct Text" - } - data = { - "custom_id": f"{pdf_s3_path}-1", - "outputs": [{"text": json.dumps(inner_data)}], - "round": 0 - } - else: - data = {} - - line = json.dumps(data) + '\n' - content_bytes = line.encode('utf-8') - return content_bytes - - mock_get_s3_bytes.side_effect = get_s3_bytes_side_effect - - # Call build_dolma_doc - s3_workspace = 's3://bucket/workspace' - dolma_doc = build_dolma_doc(s3_workspace, pdf) - - # Check that the resulting dolma_doc has the expected document_text - expected_text = 'The Longest Correct Text\n' - - self.assertIsNotNone(dolma_doc) - self.assertEqual(dolma_doc['text'], expected_text) - - # Additional assertions to ensure that the correct page was selected - self.assertEqual(dolma_doc['metadata']['Source-File'], pdf_s3_path) - self.assertEqual(dolma_doc['metadata']['pdf-total-pages'], 1) - self.assertEqual(len(dolma_doc['attributes']['pdf_page_numbers']), 1) - self.assertEqual(dolma_doc['attributes']['pdf_page_numbers'][0][2], 1) - - # Ensure that the document ID is correctly computed - expected_id = hashlib.sha1(expected_text.encode()).hexdigest() - self.assertEqual(dolma_doc['id'], expected_id) - - -class TestBuildPageQuery(unittest.TestCase): 
- def testNotParsing(self): - file = os.path.join( - os.path.dirname(__file__), - "gnarly_pdfs", - "not_parsing.pdf" - ) - - for page in range(1,9): - query = build_page_query(file, "not_parsing.pdf", page, 1024, 6000) - print(query) - - def testNotParsing2(self): - file = os.path.join( - os.path.dirname(__file__), - "gnarly_pdfs", - "not_parsing2.pdf" - ) - - for page in range(1,10): - query = build_page_query(file, "not_parsing2.pdf", page, 1024, 6000) - print(query) - - def testNotParsingHugeMemoryUsage(self): - file = os.path.join( - os.path.dirname(__file__), - "gnarly_pdfs", - "failing_pdf_pg9.pdf" - ) - - print("Starting to parse bad pdf") - - query = build_page_query(file, "failing_pdf_pg9.pdf", 9, 1024, 6000) - - print(query) - - - def testRotation(self): - # First, generate and save the non-rotated image - query = build_page_query(os.path.join( - os.path.dirname(__file__), - "gnarly_pdfs", - "edgar.pdf" - ), "edgar.pdf", 1, 1024, 6000, 0) - - # Extract the base64 image from the query - image_content = query["chat_messages"][0]["content"][1] - self.assertEqual(image_content["type"], "image_url") - image_url = image_content["image_url"]["url"] - - # Extract base64 string from the data URL - prefix = "data:image/png;base64," - self.assertTrue(image_url.startswith(prefix)) - image_base64 = image_url[len(prefix):] - - # Decode the base64 string - image_data = base64.b64decode(image_base64) - - # Define the output file path for the non-rotated image - output_image_path = os.path.join(os.path.dirname(__file__), "test_renders", "output_image.png") - - # Save the non-rotated image to a file - with open(output_image_path, "wb") as image_file: - image_file.write(image_data) - - # Now, generate and save the rotated image (90 degrees clockwise) - query_rotated = build_page_query(os.path.join( - os.path.dirname(__file__), - "gnarly_pdfs", - "edgar.pdf" - ), "edgar.pdf", 1, 1024, 6000, 90) - - # Extract the base64 image from the rotated query - image_content_rotated = query_rotated["chat_messages"][0]["content"][1] - self.assertEqual(image_content_rotated["type"], "image_url") - image_url_rotated = image_content_rotated["image_url"]["url"] - - # Extract base64 string from the data URL for the rotated image - self.assertTrue(image_url_rotated.startswith(prefix)) - image_base64_rotated = image_url_rotated[len(prefix):] - - # Decode the base64 string for the rotated image - image_data_rotated = base64.b64decode(image_base64_rotated) - - # Define the output file path for the rotated image - output_image_rotated_path = os.path.join(os.path.dirname(__file__), "test_renders", "output_image_rotated90.png") - - # Save the rotated image to a file - with open(output_image_rotated_path, "wb") as image_file_rotated: - image_file_rotated.write(image_data_rotated) - - # Verification Step: Ensure the rotated image is 90 degrees clockwise rotated - - # Open both images using PIL - with Image.open(output_image_path) as original_image: - with Image.open(output_image_rotated_path) as rotated_image: - - # Compare pixel by pixel - original_pixels = original_image.load() - rotated_pixels = rotated_image.load() - width, height = original_image.size - - self.assertEqual(width, rotated_image.size[1]) - self.assertEqual(height, rotated_image.size[0]) - - for x in range(width): - for y in range(height): - - self.assertEqual( - original_pixels[x, y], rotated_pixels[height - 1 - y, x], - f"Pixel mismatch at ({x}, {y})" - ) - - print("Rotation verification passed: The rotated image is correctly rotated 90 degrees 
clockwise.") - - -# Run the test -if __name__ == '__main__': - unittest.main() diff --git a/tests/test_s3_work_queue.py b/tests/test_s3_work_queue.py index ecfd949..6b602a3 100644 --- a/tests/test_s3_work_queue.py +++ b/tests/test_s3_work_queue.py @@ -7,7 +7,7 @@ import hashlib from typing import List, Dict # Import the classes we're testing -from olmocr.s3_queue import S3WorkQueue, WorkItem +from olmocr.work_queue import S3WorkQueue, WorkItem class TestS3WorkQueue(unittest.TestCase): def setUp(self): @@ -70,8 +70,8 @@ class TestS3WorkQueue(unittest.TestCase): async def test_populate_queue_new_items(self): """Test populating queue with new items""" # Mock empty existing index - with patch('olmocr.s3_queue.download_zstd_csv', return_value=[]): - with patch('olmocr.s3_queue.upload_zstd_csv') as mock_upload: + with patch('olmocr.work_queue.download_zstd_csv', return_value=[]): + with patch('olmocr.work_queue.upload_zstd_csv') as mock_upload: await self.work_queue.populate_queue(self.sample_paths, items_per_group=2) # Verify upload was called with correct data @@ -97,8 +97,8 @@ class TestS3WorkQueue(unittest.TestCase): existing_hash = S3WorkQueue._compute_workgroup_hash(existing_paths) existing_line = f"{existing_hash},{existing_paths[0]}" - with patch('olmocr.s3_queue.download_zstd_csv', return_value=[existing_line]): - with patch('olmocr.s3_queue.upload_zstd_csv') as mock_upload: + with patch('olmocr.work_queue.download_zstd_csv', return_value=[existing_line]): + with patch('olmocr.work_queue.upload_zstd_csv') as mock_upload: await self.work_queue.populate_queue(existing_paths + new_paths, items_per_group=1) # Verify upload called with both existing and new items @@ -116,8 +116,8 @@ class TestS3WorkQueue(unittest.TestCase): completed_items = [f"s3://test-bucket/workspace/results/output_{work_hash}.jsonl"] - with patch('olmocr.s3_queue.download_zstd_csv', return_value=[work_line]): - with patch('olmocr.s3_queue.expand_s3_glob', return_value=completed_items): + with patch('olmocr.work_queue.download_zstd_csv', return_value=[work_line]): + with patch('olmocr.work_queue.expand_s3_glob', return_value=completed_items): await self.work_queue.initialize_queue() # Queue should be empty since all work is completed @@ -143,7 +143,7 @@ class TestS3WorkQueue(unittest.TestCase): async def test_get_work(self): """Test getting work items""" # Setup test data - work_item = WorkItem(hash="testhash123", s3_work_paths=["s3://test/file1.pdf"]) + work_item = WorkItem(hash="testhash123", work_paths=["s3://test/file1.pdf"]) await self.work_queue._queue.put(work_item) # Test getting available work @@ -162,7 +162,7 @@ class TestS3WorkQueue(unittest.TestCase): @async_test async def test_get_work_completed(self): """Test getting work that's already completed""" - work_item = WorkItem(hash="testhash123", s3_work_paths=["s3://test/file1.pdf"]) + work_item = WorkItem(hash="testhash123", work_paths=["s3://test/file1.pdf"]) await self.work_queue._queue.put(work_item) # Simulate completed work @@ -174,7 +174,7 @@ class TestS3WorkQueue(unittest.TestCase): @async_test async def test_get_work_locked(self): """Test getting work that's locked by another worker""" - work_item = WorkItem(hash="testhash123", s3_work_paths=["s3://test/file1.pdf"]) + work_item = WorkItem(hash="testhash123", work_paths=["s3://test/file1.pdf"]) await self.work_queue._queue.put(work_item) # Simulate active lock @@ -190,7 +190,7 @@ class TestS3WorkQueue(unittest.TestCase): @async_test async def test_get_work_stale_lock(self): """Test getting work with a 
stale lock""" - work_item = WorkItem(hash="testhash123", s3_work_paths=["s3://test/file1.pdf"]) + work_item = WorkItem(hash="testhash123", work_paths=["s3://test/file1.pdf"]) await self.work_queue._queue.put(work_item) # Simulate stale lock @@ -206,7 +206,7 @@ class TestS3WorkQueue(unittest.TestCase): @async_test async def test_mark_done(self): """Test marking work as done""" - work_item = WorkItem(hash="testhash123", s3_work_paths=["s3://test/file1.pdf"]) + work_item = WorkItem(hash="testhash123", work_paths=["s3://test/file1.pdf"]) await self.work_queue._queue.put(work_item) await self.work_queue.mark_done(work_item) @@ -223,10 +223,10 @@ class TestS3WorkQueue(unittest.TestCase): self.loop = asyncio.new_event_loop() asyncio.set_event_loop(self.loop) - self.loop.run_until_complete(self.work_queue._queue.put(WorkItem(hash="test1", s3_work_paths=["path1"]))) + self.loop.run_until_complete(self.work_queue._queue.put(WorkItem(hash="test1", work_paths=["path1"]))) self.assertEqual(self.work_queue.size, 1) - self.loop.run_until_complete(self.work_queue._queue.put(WorkItem(hash="test2", s3_work_paths=["path2"]))) + self.loop.run_until_complete(self.work_queue._queue.put(WorkItem(hash="test2", work_paths=["path2"]))) self.assertEqual(self.work_queue.size, 2) self.loop.close()
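
For reference, the sketch below shows how a driver is expected to use the refactored queue API, mirroring the dispatch added to olmocr/pipeline.py (S3WorkQueue for s3:// workspaces, LocalWorkQueue otherwise). It is not part of the patch; process_one_path is a hypothetical stand-in for the caller's per-document work, and the group size of 100 is an arbitrary example value.

# Reference sketch only -- not part of this patch. process_one_path() is a
# hypothetical stand-in for the caller's per-document processing.
import asyncio

import boto3

from olmocr.work_queue import LocalWorkQueue, S3WorkQueue


async def process_one_path(path: str) -> None:
    print(f"processing {path}")  # placeholder for real work


async def run(workspace: str, pdf_paths: list) -> None:
    # Pick the backend from the workspace path, as olmocr/pipeline.py now does.
    if workspace.startswith("s3://"):
        work_queue = S3WorkQueue(boto3.client("s3"), workspace)
    else:
        work_queue = LocalWorkQueue(workspace)

    # Group the input paths into work items and persist them to the index.
    await work_queue.populate_queue(pdf_paths, items_per_group=100)

    # Drop already-completed groups and shuffle what remains.
    await work_queue.initialize_queue()

    # Each get_work() call writes a lock file; mark_done() removes it. A group
    # only counts as completed once results/output_{hash}.jsonl exists.
    while (work_item := await work_queue.get_work()) is not None:
        for path in work_item.work_paths:
            await process_one_path(path)
        await work_queue.mark_done(work_item)


if __name__ == "__main__":
    asyncio.run(run("./local_workspace", ["doc1.pdf", "doc2.pdf"]))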
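The patch updates tests/test_s3_work_queue.py but adds no coverage for LocalWorkQueue; the sketch below outlines how such a test could look, using only the API defined in olmocr/work_queue.py. It is not included in the patch and assumes the zstandard package is installed, since the local index helpers require it.

# Sketch of a LocalWorkQueue round-trip test (not included in this patch).
# Assumes the zstandard package is available for the local index helpers.
import asyncio
import os
import tempfile
import unittest

from olmocr.work_queue import LocalWorkQueue


class TestLocalWorkQueue(unittest.TestCase):
    def test_round_trip(self):
        async def scenario():
            with tempfile.TemporaryDirectory() as workspace:
                queue = LocalWorkQueue(workspace)

                # Two paths grouped into a single work item.
                await queue.populate_queue(["a.pdf", "b.pdf"], items_per_group=2)
                await queue.initialize_queue()
                self.assertEqual(queue.size, 1)

                work_item = await queue.get_work()
                self.assertIsNotNone(work_item)
                self.assertEqual(sorted(work_item.work_paths), ["a.pdf", "b.pdf"])

                # Taking the item creates a lock file under worker_locks/.
                lock_file = os.path.join(workspace, "worker_locks", f"output_{work_item.hash}.jsonl")
                self.assertTrue(os.path.exists(lock_file))

                # Writing results/output_{hash}.jsonl is what marks the item completed.
                results_file = os.path.join(workspace, "results", f"output_{work_item.hash}.jsonl")
                with open(results_file, "w") as f:
                    f.write("{}\n")

                await queue.mark_done(work_item)
                self.assertFalse(os.path.exists(lock_file))
                self.assertTrue(await queue.is_completed(work_item.hash))

                # Queue is drained, so the next request yields None.
                self.assertIsNone(await queue.get_work())

        asyncio.run(scenario())


if __name__ == "__main__":
    unittest.main()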