mirror of
https://github.com/allenai/olmocr.git
synced 2025-12-02 10:10:44 +00:00
Refactoring
This commit is contained in:
parent
c6062677aa
commit
96ae2dd49b
1
.gitignore
vendored
1
.gitignore
vendored
@ -11,6 +11,7 @@ sample200_vllm/*
|
||||
sample200_sglang/*
|
||||
pdelfin_testset/*
|
||||
/*.html
|
||||
scoreelo.csv
|
||||
debug.log
|
||||
birrpipeline-debug.log
|
||||
beakerpipeline-debug.log
|
||||
|
||||
@ -31,7 +31,7 @@ from typing import Optional, Tuple, List, Dict, Set
|
||||
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
|
||||
from concurrent.futures.process import BrokenProcessPool
|
||||
|
||||
from olmocr.s3_queue import S3WorkQueue, WorkItem
|
||||
from olmocr.work_queue import S3WorkQueue, LocalWorkQueue
|
||||
from olmocr.s3_utils import expand_s3_glob, get_s3_bytes, get_s3_bytes_with_backoff, parse_s3_path, download_zstd_csv, upload_zstd_csv, download_directory
|
||||
from olmocr.data.renderpdf import render_pdf_to_base64png
|
||||
from olmocr.filter.filter import PdfFilter, Language
|
||||
@ -420,7 +420,7 @@ async def worker(args, work_queue: S3WorkQueue, semaphore, worker_id):
|
||||
|
||||
try:
|
||||
async with asyncio.TaskGroup() as tg:
|
||||
dolma_tasks = [tg.create_task(process_pdf(args, worker_id, pdf)) for pdf in work_item.s3_work_paths]
|
||||
dolma_tasks = [tg.create_task(process_pdf(args, worker_id, pdf)) for pdf in work_item.work_paths]
|
||||
logger.info(f"Created all tasks for {work_item.hash}")
|
||||
|
||||
logger.info(f"Finished TaskGroup for worker on {work_item.hash}")
|
||||
@ -834,7 +834,7 @@ def print_stats(args):
|
||||
|
||||
async def main():
|
||||
parser = argparse.ArgumentParser(description='Manager for running millions of PDFs through a batch inference pipeline')
|
||||
parser.add_argument('workspace', help='The S3 path where work will be done e.g., s3://bucket/prefix/')
|
||||
parser.add_argument('workspace', help='The filesystem path where work will be stored, can be a local folder, or an s3 path if coordinating work with many workers, s3://bucket/prefix/ ')
|
||||
parser.add_argument('--pdfs', help='Path to add pdfs stored in s3 to the workspace, can be a glob path s3://bucket/prefix/*.pdf or path to file containing list of pdf paths', default=None)
|
||||
parser.add_argument('--workspace_profile', help='S3 configuration profile for accessing the workspace', default=None)
|
||||
parser.add_argument('--pdf_profile', help='S3 configuration profile for accessing the raw pdf documents', default=None)
|
||||
@ -891,7 +891,10 @@ async def main():
|
||||
check_poppler_version()
|
||||
|
||||
# Create work queue
|
||||
work_queue = S3WorkQueue(workspace_s3, args.workspace)
|
||||
if args.workspace.startswith("s3://"):
|
||||
work_queue = S3WorkQueue(workspace_s3, args.workspace)
|
||||
else:
|
||||
work_queue = LocalWorkQueue(args.workspace)
|
||||
|
||||
if args.pdfs:
|
||||
logger.info("Got --pdfs argument, going to add to the work queue")
|
||||
|
||||
@ -1,302 +0,0 @@
|
||||
import os
|
||||
import random
|
||||
import logging
|
||||
import hashlib
|
||||
import tempfile
|
||||
import datetime
|
||||
from typing import Optional, Tuple, List, Dict, Set
|
||||
from dataclasses import dataclass
|
||||
import asyncio
|
||||
from functools import partial
|
||||
|
||||
from olmocr.s3_utils import (
|
||||
expand_s3_glob,
|
||||
download_zstd_csv,
|
||||
upload_zstd_csv,
|
||||
parse_s3_path
|
||||
)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@dataclass
|
||||
class WorkItem:
|
||||
"""Represents a single work item in the queue"""
|
||||
hash: str
|
||||
s3_work_paths: List[str]
|
||||
|
||||
class S3WorkQueue:
|
||||
"""
|
||||
Manages a work queue stored in S3 that coordinates work across multiple workers.
|
||||
The queue maintains a list of work items, where each work item is a group of s3 paths
|
||||
that should be processed together.
|
||||
|
||||
Each work item gets a hash, and completed work items will have their results
|
||||
stored in s3://workspace_path/results/output_[hash].jsonl
|
||||
|
||||
This is the ground source of truth about which work items are done.
|
||||
|
||||
When a worker takes an item off the queue, it will write an empty s3 file to
|
||||
s3://workspace_path/worker_locks/output_[hash].jsonl
|
||||
|
||||
The queue gets randomized on each worker, so workers pull random work items to operate on.
|
||||
As you pull an item, we will check to see if it has been completed. If yes,
|
||||
then it will immediately fetch the next item. If a lock file was created within a configurable
|
||||
timeout (30 mins by default), then that work item is also skipped.
|
||||
|
||||
The lock will will be deleted once the worker is done with that item.
|
||||
"""
|
||||
def __init__(self, s3_client, workspace_path: str):
|
||||
"""
|
||||
Initialize the work queue.
|
||||
|
||||
Args:
|
||||
s3_client: Boto3 S3 client to use for operations
|
||||
workspace_path: S3 path where work queue and results are stored
|
||||
"""
|
||||
self.s3_client = s3_client
|
||||
self.workspace_path = workspace_path.rstrip('/')
|
||||
|
||||
self._index_path = os.path.join(self.workspace_path, "work_index_list.csv.zstd")
|
||||
self._output_glob = os.path.join(self.workspace_path, "results", "*.jsonl")
|
||||
self._queue = asyncio.Queue()
|
||||
|
||||
@staticmethod
|
||||
def _compute_workgroup_hash(s3_work_paths: List[str]) -> str:
|
||||
"""
|
||||
Compute a deterministic hash for a group of paths.
|
||||
|
||||
Args:
|
||||
s3_work_paths: List of S3 paths
|
||||
|
||||
Returns:
|
||||
SHA1 hash of the sorted paths
|
||||
"""
|
||||
sha1 = hashlib.sha1()
|
||||
for path in sorted(s3_work_paths):
|
||||
sha1.update(path.encode('utf-8'))
|
||||
return sha1.hexdigest()
|
||||
|
||||
async def populate_queue(self, s3_work_paths: list[str], items_per_group: int) -> None:
|
||||
"""
|
||||
Add new items to the work queue.
|
||||
|
||||
Args:
|
||||
s3_work_paths: Each individual s3 path that we will process over
|
||||
items_per_group: Number of items to group together in a single work item
|
||||
"""
|
||||
all_paths = set(s3_work_paths)
|
||||
logger.info(f"Found {len(all_paths):,} total paths")
|
||||
|
||||
# Load existing work groups
|
||||
existing_lines = await asyncio.to_thread(download_zstd_csv, self.s3_client, self._index_path)
|
||||
existing_groups = {}
|
||||
for line in existing_lines:
|
||||
if line.strip():
|
||||
parts = line.strip().split(",")
|
||||
group_hash = parts[0]
|
||||
group_paths = parts[1:]
|
||||
existing_groups[group_hash] = group_paths
|
||||
|
||||
existing_path_set = {path for paths in existing_groups.values() for path in paths}
|
||||
|
||||
# Find new paths to process
|
||||
new_paths = all_paths - existing_path_set
|
||||
logger.info(f"{len(new_paths):,} new paths to add to the workspace")
|
||||
|
||||
if not new_paths:
|
||||
return
|
||||
|
||||
# Create new work groups
|
||||
new_groups = []
|
||||
current_group = []
|
||||
for path in sorted(new_paths):
|
||||
current_group.append(path)
|
||||
if len(current_group) == items_per_group:
|
||||
group_hash = self._compute_workgroup_hash(current_group)
|
||||
new_groups.append((group_hash, current_group))
|
||||
current_group = []
|
||||
if current_group:
|
||||
group_hash = self._compute_workgroup_hash(current_group)
|
||||
new_groups.append((group_hash, current_group))
|
||||
|
||||
logger.info(f"Created {len(new_groups):,} new work groups")
|
||||
|
||||
# Combine and save updated work groups
|
||||
combined_groups = existing_groups.copy()
|
||||
for group_hash, group_paths in new_groups:
|
||||
combined_groups[group_hash] = group_paths
|
||||
|
||||
combined_lines = [
|
||||
",".join([group_hash] + group_paths)
|
||||
for group_hash, group_paths in combined_groups.items()
|
||||
]
|
||||
|
||||
if new_groups:
|
||||
await asyncio.to_thread(
|
||||
upload_zstd_csv,
|
||||
self.s3_client,
|
||||
self._index_path,
|
||||
combined_lines
|
||||
)
|
||||
|
||||
async def initialize_queue(self) -> None:
|
||||
"""
|
||||
Load the work queue from S3 and initialize it for processing.
|
||||
Removes already completed work items and randomizes the order.
|
||||
"""
|
||||
# Load work items and completed items in parallel
|
||||
download_task = asyncio.to_thread(
|
||||
download_zstd_csv,
|
||||
self.s3_client,
|
||||
self._index_path
|
||||
)
|
||||
expand_task = asyncio.to_thread(
|
||||
expand_s3_glob,
|
||||
self.s3_client,
|
||||
self._output_glob
|
||||
)
|
||||
|
||||
work_queue_lines, done_work_items = await asyncio.gather(download_task, expand_task)
|
||||
|
||||
# Process work queue lines
|
||||
work_queue = {
|
||||
parts[0]: parts[1:]
|
||||
for line in work_queue_lines
|
||||
if (parts := line.strip().split(",")) and line.strip()
|
||||
}
|
||||
|
||||
# Get set of completed work hashes
|
||||
done_work_hashes = {
|
||||
os.path.basename(item)[len('output_'):-len('.jsonl')]
|
||||
for item in done_work_items
|
||||
if os.path.basename(item).startswith('output_')
|
||||
and os.path.basename(item).endswith('.jsonl')
|
||||
}
|
||||
|
||||
# Find remaining work and shuffle
|
||||
remaining_work_hashes = set(work_queue) - done_work_hashes
|
||||
remaining_items = [
|
||||
WorkItem(hash=hash_, s3_work_paths=work_queue[hash_])
|
||||
for hash_ in remaining_work_hashes
|
||||
]
|
||||
random.shuffle(remaining_items)
|
||||
|
||||
# Initialize queue
|
||||
self._queue = asyncio.Queue()
|
||||
for item in remaining_items:
|
||||
await self._queue.put(item)
|
||||
|
||||
logger.info(f"Initialized queue with {self._queue.qsize()} work items")
|
||||
|
||||
async def is_completed(self, work_hash: str) -> bool:
|
||||
"""
|
||||
Check if a work item has been completed.
|
||||
|
||||
Args:
|
||||
work_hash: Hash of the work item to check
|
||||
|
||||
Returns:
|
||||
True if the work is completed, False otherwise
|
||||
"""
|
||||
output_s3_path = os.path.join(self.workspace_path, "results", f"output_{work_hash}.jsonl")
|
||||
bucket, key = parse_s3_path(output_s3_path)
|
||||
|
||||
try:
|
||||
await asyncio.to_thread(
|
||||
self.s3_client.head_object,
|
||||
Bucket=bucket,
|
||||
Key=key
|
||||
)
|
||||
return True
|
||||
except self.s3_client.exceptions.ClientError:
|
||||
return False
|
||||
|
||||
async def get_work(self, worker_lock_timeout_secs: int = 1800) -> Optional[WorkItem]:
|
||||
"""
|
||||
Get the next available work item that isn't completed or locked.
|
||||
|
||||
Args:
|
||||
worker_lock_timeout_secs: Number of seconds before considering a worker lock stale (default 30 mins)
|
||||
|
||||
Returns:
|
||||
WorkItem if work is available, None if queue is empty
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
work_item = self._queue.get_nowait()
|
||||
except asyncio.QueueEmpty:
|
||||
return None
|
||||
|
||||
# Check if work is already completed
|
||||
if await self.is_completed(work_item.hash):
|
||||
logger.debug(f"Work item {work_item.hash} already completed, skipping")
|
||||
self._queue.task_done()
|
||||
continue
|
||||
|
||||
# Check for worker lock
|
||||
lock_path = os.path.join(self.workspace_path, "worker_locks", f"output_{work_item.hash}.jsonl")
|
||||
bucket, key = parse_s3_path(lock_path)
|
||||
|
||||
try:
|
||||
response = await asyncio.to_thread(
|
||||
self.s3_client.head_object,
|
||||
Bucket=bucket,
|
||||
Key=key
|
||||
)
|
||||
|
||||
# Check if lock is stale
|
||||
last_modified = response['LastModified']
|
||||
if (datetime.datetime.now(datetime.timezone.utc) - last_modified).total_seconds() > worker_lock_timeout_secs:
|
||||
# Lock is stale, we can take this work
|
||||
logger.debug(f"Found stale lock for {work_item.hash}, taking work item")
|
||||
else:
|
||||
# Lock is active, skip this work
|
||||
logger.debug(f"Work item {work_item.hash} is locked by another worker, skipping")
|
||||
self._queue.task_done()
|
||||
continue
|
||||
|
||||
except self.s3_client.exceptions.ClientError:
|
||||
# No lock exists, we can take this work
|
||||
pass
|
||||
|
||||
# Create our lock file
|
||||
try:
|
||||
await asyncio.to_thread(
|
||||
self.s3_client.put_object,
|
||||
Bucket=bucket,
|
||||
Key=key,
|
||||
Body=b''
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to create lock file for {work_item.hash}: {e}")
|
||||
self._queue.task_done()
|
||||
continue
|
||||
|
||||
return work_item
|
||||
|
||||
async def mark_done(self, work_item: WorkItem) -> None:
|
||||
"""
|
||||
Mark a work item as done by removing its lock file.
|
||||
|
||||
Args:
|
||||
work_item: The WorkItem to mark as done
|
||||
"""
|
||||
lock_path = os.path.join(self.workspace_path, "worker_locks", f"output_{work_item.hash}.jsonl")
|
||||
bucket, key = parse_s3_path(lock_path)
|
||||
|
||||
try:
|
||||
await asyncio.to_thread(
|
||||
self.s3_client.delete_object,
|
||||
Bucket=bucket,
|
||||
Key=key
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to delete lock file for {work_item.hash}: {e}")
|
||||
|
||||
self._queue.task_done()
|
||||
|
||||
@property
|
||||
def size(self) -> int:
|
||||
"""Get current size of work queue"""
|
||||
return self._queue.qsize()
|
||||
640
olmocr/work_queue.py
Normal file
640
olmocr/work_queue.py
Normal file
@ -0,0 +1,640 @@
|
||||
import os
|
||||
import random
|
||||
import logging
|
||||
import hashlib
|
||||
import tempfile
|
||||
import datetime
|
||||
import asyncio
|
||||
import abc
|
||||
from typing import Optional, List, Dict, Set
|
||||
from dataclasses import dataclass
|
||||
|
||||
from functools import partial
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@dataclass
class WorkItem:
    """Represents a single work item in the queue"""
    # Identifier of the work group. The queue implementations in this file
    # fill it with the first comma-separated column of an index line and use
    # it to name the result and lock files (output_<hash>.jsonl).
    hash: str
    # The paths grouped under this work item (the remaining columns of the
    # index line). May be local or S3 paths depending on the queue in use.
    work_paths: List[str]
|
||||
|
||||
|
||||
class WorkQueue(abc.ABC):
    """
    Abstract contract that every work-queue backend must fulfil.

    Concrete subclasses decide where the queue index, results and locks
    live; this class only fixes the method names and signatures, plus a
    shared helper for hashing a group of paths.
    """

    @abc.abstractmethod
    async def populate_queue(self, work_paths: List[str], items_per_group: int) -> None:
        """
        Register new paths with the queue.

        Args:
            work_paths: Individual paths that should eventually be processed
            items_per_group: How many paths are bundled into one work item
        """

    @abc.abstractmethod
    async def initialize_queue(self) -> None:
        """
        Load the stored queue from the backend and make it ready for
        ``get_work`` calls (implementations may, for example, drop items
        that are already completed and shuffle the remainder).
        """

    @abc.abstractmethod
    async def is_completed(self, work_hash: str) -> bool:
        """
        Tell whether the work item identified by ``work_hash`` is finished.

        Args:
            work_hash: Hash of the work item to check

        Returns:
            True if the work is completed, False otherwise
        """

    @abc.abstractmethod
    async def get_work(self, worker_lock_timeout_secs: int = 1800) -> Optional[WorkItem]:
        """
        Hand out the next work item that is neither completed nor locked.

        Args:
            worker_lock_timeout_secs: Age in seconds after which a worker
                lock is considered stale (default 30 mins)

        Returns:
            A WorkItem when work is available, None when the queue is empty
        """

    @abc.abstractmethod
    async def mark_done(self, work_item: WorkItem) -> None:
        """
        Signal that ``work_item`` is finished, releasing its lock or doing
        whatever other cleanup the backend needs.

        Args:
            work_item: The WorkItem to mark as done
        """

    @property
    @abc.abstractmethod
    def size(self) -> int:
        """Number of items currently waiting in the queue."""

    @staticmethod
    def _compute_workgroup_hash(s3_work_paths: List[str]) -> str:
        """
        Derive a deterministic identifier for a group of paths.

        The paths are sorted first, so the result does not depend on the
        order in which they are supplied.

        Args:
            s3_work_paths: List of paths (local or S3)

        Returns:
            Hex SHA1 digest of the sorted paths, concatenated without a
            separator and encoded as UTF-8
        """
        # Feeding the joined string in one go produces the same byte stream
        # as updating the digest path by path.
        concatenated = "".join(sorted(s3_work_paths))
        return hashlib.sha1(concatenated.encode('utf-8')).hexdigest()
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------------------
|
||||
# Local Helpers for reading/writing the index CSV (compressed with zstd) to disk
|
||||
# --------------------------------------------------------------------------------------
|
||||
|
||||
try:
|
||||
import zstandard
|
||||
except ImportError:
|
||||
zstandard = None
|
||||
|
||||
def download_zstd_csv_local(local_path: str) -> List[str]:
    """
    Read a zstd-compressed CSV file from local disk.

    Args:
        local_path: Location of the compressed file

    Returns:
        The decompressed content split into lines, or an empty list when
        no file exists at ``local_path``

    Raises:
        RuntimeError: the file exists but the zstandard package is missing
    """
    if os.path.exists(local_path):
        # Only complain about the missing dependency when there is
        # actually something to decompress.
        if not zstandard:
            raise RuntimeError("zstandard package is required for local zstd CSV operations.")

        with open(local_path, 'rb') as handle:
            raw = zstandard.ZstdDecompressor().decompress(handle.read())
        text = raw.decode('utf-8')
        return text.splitlines()

    return []
|
||||
|
||||
def upload_zstd_csv_local(local_path: str, lines: List[str]) -> None:
    """
    Write lines to a zstd-compressed CSV file on local disk.

    The data is first written to a temporary sibling file and then moved
    over ``local_path`` with ``os.replace``, so an interrupted write does
    not leave a truncated file at ``local_path``.

    Args:
        local_path: Destination of the compressed file
        lines: Lines to store; they are joined with newlines and encoded
            as UTF-8 before compression

    Raises:
        RuntimeError: the zstandard package is not installed
        OSError: the directory or file could not be written
    """
    if not zstandard:
        raise RuntimeError("zstandard package is required for local zstd CSV operations.")

    data = "\n".join(lines).encode('utf-8')
    compressed_data = zstandard.ZstdCompressor().compress(data)

    # A bare filename has an empty dirname, and os.makedirs("") raises;
    # fall back to the current directory in that case.
    parent_dir = os.path.dirname(local_path) or os.curdir
    os.makedirs(parent_dir, exist_ok=True)

    # The pid suffix keeps concurrent writers from sharing a temp file.
    tmp_path = f"{local_path}.{os.getpid()}.tmp"
    try:
        with open(tmp_path, 'wb') as f:
            f.write(compressed_data)
        os.replace(tmp_path, local_path)
    except BaseException:
        # Best-effort cleanup of the partial temp file, then re-raise.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------------------
|
||||
# LocalWorkQueue Implementation
|
||||
# --------------------------------------------------------------------------------------
|
||||
|
||||
class LocalWorkQueue(WorkQueue):
    """
    WorkQueue implementation backed by a local workspace directory.

    Layout of the workspace:

    * ``work_index_list.csv.zstd`` - one line per work group: the group
      hash followed by the group's paths, comma separated.
    * ``results/`` - an ``output_<hash>.jsonl`` file here is what
      ``is_completed`` treats as "this group is done".
    * ``worker_locks/`` - an empty ``output_<hash>.jsonl`` file here marks
      a group as claimed; its modification time is used to detect stale
      locks.

    Pending items are additionally kept in an in-memory ``asyncio.Queue``
    that is rebuilt by ``initialize_queue``.
    """

    def __init__(self, workspace_path: str):
        """
        Initialize the local work queue and create the workspace layout.

        Args:
            workspace_path: Local directory path where the queue index,
                results, and locks are stored. It is made absolute and
                created if missing.
        """
        self.workspace_path = os.path.abspath(workspace_path)
        os.makedirs(self.workspace_path, exist_ok=True)

        # Local index file (compressed)
        self._index_path = os.path.join(self.workspace_path, "work_index_list.csv.zstd")

        # Output directory for completed tasks
        self._results_dir = os.path.join(self.workspace_path, "results")
        os.makedirs(self._results_dir, exist_ok=True)

        # Directory for lock files
        self._locks_dir = os.path.join(self.workspace_path, "worker_locks")
        os.makedirs(self._locks_dir, exist_ok=True)

        # Internal queue
        self._queue = asyncio.Queue()

    @staticmethod
    def _parse_index_lines(lines: List[str]) -> Dict[str, List[str]]:
        """Turn index lines (``hash,path,path,...``) into a hash -> paths mapping, skipping blank lines."""
        groups: Dict[str, List[str]] = {}
        for line in lines:
            stripped = line.strip()
            if stripped:
                group_hash, *group_paths = stripped.split(",")
                groups[group_hash] = group_paths
        return groups

    def _lock_file_for(self, work_hash: str) -> str:
        """Return the lock file path used for the work item with this hash."""
        return os.path.join(self._locks_dir, f"output_{work_hash}.jsonl")

    async def populate_queue(self, s3_work_paths: List[str], items_per_group: int) -> None:
        """
        Add new items to the work queue (local version).

        Paths already present in the index are ignored. The remaining
        paths are sorted, chunked into groups of ``items_per_group`` and
        appended to the on-disk index. The in-memory queue is not touched;
        call ``initialize_queue`` afterwards.

        Args:
            s3_work_paths: Each individual path that we will process over.
                The parameter keeps its historical name; the paths are
                not required to be S3 paths here.
            items_per_group: Number of items to group together in a single work item
        """
        all_paths = set(s3_work_paths)
        logger.info(f"Found {len(all_paths):,} total paths")

        # Load existing work groups from local index
        existing_lines = await asyncio.to_thread(download_zstd_csv_local, self._index_path)
        existing_groups = self._parse_index_lines(existing_lines)

        existing_path_set = {p for paths in existing_groups.values() for p in paths}
        new_paths = all_paths - existing_path_set
        logger.info(f"{len(new_paths):,} new paths to add to the workspace")

        if not new_paths:
            return

        # Create new work groups
        new_groups = []
        current_group = []
        for path in sorted(new_paths):
            current_group.append(path)
            if len(current_group) == items_per_group:
                new_groups.append((self._compute_workgroup_hash(current_group), current_group))
                current_group = []
        if current_group:
            # Trailing, smaller-than-requested group
            new_groups.append((self._compute_workgroup_hash(current_group), current_group))

        logger.info(f"Created {len(new_groups):,} new work groups")

        # Combine and save updated work groups. new_paths is non-empty at
        # this point, so there is always at least one new group to persist.
        combined_groups = existing_groups.copy()
        combined_groups.update(new_groups)

        combined_lines = [
            ",".join([group_hash] + group_paths)
            for group_hash, group_paths in combined_groups.items()
        ]

        # Write the combined data back to disk in zstd CSV format
        await asyncio.to_thread(upload_zstd_csv_local, self._index_path, combined_lines)

    async def initialize_queue(self) -> None:
        """
        Load the work queue from the local index file and initialize it for processing.
        Removes already completed work items and randomizes the order.
        """
        # 1) Read the index
        work_queue_lines = await asyncio.to_thread(download_zstd_csv_local, self._index_path)
        work_queue = self._parse_index_lines(work_queue_lines)

        # 2) Determine which items are completed by scanning local results/*.jsonl
        os.makedirs(self._results_dir, exist_ok=True)
        done_work_hashes = {
            fn[len('output_'):-len('.jsonl')]
            for fn in os.listdir(self._results_dir)
            if fn.startswith("output_") and fn.endswith(".jsonl")
        }

        # 3) Filter out completed items. WorkItem's field is `work_paths`;
        #    passing the old `s3_work_paths` keyword raises TypeError.
        remaining_work_hashes = set(work_queue) - done_work_hashes
        remaining_items = [
            WorkItem(hash=hash_, work_paths=work_queue[hash_])
            for hash_ in remaining_work_hashes
        ]
        random.shuffle(remaining_items)

        # 4) Initialize our in-memory queue
        self._queue = asyncio.Queue()
        for item in remaining_items:
            await self._queue.put(item)

        logger.info(f"Initialized local queue with {self._queue.qsize()} work items")

    async def is_completed(self, work_hash: str) -> bool:
        """
        Check if a work item has been completed locally by seeing if
        output_{work_hash}.jsonl is present in the results directory.

        Args:
            work_hash: Hash of the work item to check

        Returns:
            True if the result file exists, False otherwise
        """
        output_file = os.path.join(self._results_dir, f"output_{work_hash}.jsonl")
        return os.path.exists(output_file)

    async def get_work(self, worker_lock_timeout_secs: int = 1800) -> Optional[WorkItem]:
        """
        Get the next available work item that isn't completed or locked.

        Items that are completed, or that carry a lock younger than the
        timeout, are dropped from the in-memory queue and the next item is
        tried. The returned item has a fresh lock file written for it.

        Args:
            worker_lock_timeout_secs: Number of seconds before considering
                a worker lock stale (default 30 mins)

        Returns:
            WorkItem if work is available, None if queue is empty
        """
        while True:
            try:
                work_item = self._queue.get_nowait()
            except asyncio.QueueEmpty:
                return None

            # Check if work is already completed
            if await self.is_completed(work_item.hash):
                logger.debug(f"Work item {work_item.hash} already completed, skipping")
                self._queue.task_done()
                continue

            # Check for worker lock. Read the mtime directly instead of
            # testing for existence first: the owner may delete the lock
            # between the two calls, which must count as "no lock".
            lock_file = self._lock_file_for(work_item.hash)
            try:
                lock_mtime = os.path.getmtime(lock_file)
            except OSError:
                lock_mtime = None

            if lock_mtime is not None:
                locked_at = datetime.datetime.fromtimestamp(lock_mtime, datetime.timezone.utc)
                lock_age_secs = (datetime.datetime.now(datetime.timezone.utc) - locked_at).total_seconds()
                if lock_age_secs > worker_lock_timeout_secs:
                    # Lock is stale, we can take this work
                    logger.debug(f"Found stale lock for {work_item.hash}, taking work item")
                else:
                    # Lock is active, skip this work
                    logger.debug(f"Work item {work_item.hash} is locked by another worker, skipping")
                    self._queue.task_done()
                    continue

            # Create our lock file (touch an empty file)
            try:
                with open(lock_file, "wb") as f:
                    f.write(b"")
            except OSError as e:
                logger.warning(f"Failed to create lock file for {work_item.hash}: {e}")
                self._queue.task_done()
                continue

            return work_item

    async def mark_done(self, work_item: WorkItem) -> None:
        """
        Mark a work item as done by removing its lock file.

        A lock file that is already gone is not an error; any other
        failure to delete it is logged as a warning.

        Args:
            work_item: The WorkItem to mark as done
        """
        try:
            os.remove(self._lock_file_for(work_item.hash))
        except FileNotFoundError:
            # Nothing to clean up (never locked, or already removed).
            pass
        except OSError as e:
            logger.warning(f"Failed to delete lock file for {work_item.hash}: {e}")
        self._queue.task_done()

    @property
    def size(self) -> int:
        """Get current size of local work queue"""
        return self._queue.qsize()
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------------------
|
||||
# S3WorkQueue Implementation (Preserves Original Comments)
|
||||
# --------------------------------------------------------------------------------------
|
||||
|
||||
from olmocr.s3_utils import (
|
||||
expand_s3_glob,
|
||||
download_zstd_csv,
|
||||
upload_zstd_csv,
|
||||
parse_s3_path
|
||||
)
|
||||
|
||||
class S3WorkQueue(WorkQueue):
|
||||
"""
|
||||
Manages a work queue stored in S3 that coordinates work across multiple workers.
|
||||
The queue maintains a list of work items, where each work item is a group of s3 paths
|
||||
that should be processed together.
|
||||
|
||||
Each work item gets a hash, and completed work items will have their results
|
||||
stored in s3://workspace_path/results/output_[hash].jsonl
|
||||
|
||||
This is the ground source of truth about which work items are done.
|
||||
|
||||
When a worker takes an item off the queue, it will write an empty s3 file to
|
||||
s3://workspace_path/worker_locks/output_[hash].jsonl
|
||||
|
||||
The queue gets randomized on each worker, so workers pull random work items to operate on.
|
||||
As you pull an item, we will check to see if it has been completed. If yes,
|
||||
then it will immediately fetch the next item. If a lock file was created within a configurable
|
||||
timeout (30 mins by default), then that work item is also skipped.
|
||||
|
||||
The lock will be deleted once the worker is done with that item.
|
||||
"""
|
||||
def __init__(self, s3_client, workspace_path: str):
    """
    Initialize the work queue.

    Args:
        s3_client: Boto3 S3 client to use for operations
        workspace_path: S3 path where work queue and results are stored;
            trailing slashes are stripped
    """
    self.s3_client = s3_client
    self.workspace_path = workspace_path.rstrip('/')

    # S3 URIs always use "/" as separator, so build them explicitly
    # instead of with os.path.join, which follows the host OS convention.
    self._index_path = f"{self.workspace_path}/work_index_list.csv.zstd"
    self._output_glob = f"{self.workspace_path}/results/*.jsonl"
    self._queue = asyncio.Queue()
|
||||
|
||||
async def populate_queue(self, s3_work_paths: List[str], items_per_group: int) -> None:
    """
    Register paths that are not yet part of the stored index.

    The index is downloaded, paths that already belong to a group are
    filtered out, and the rest are sorted, chunked into groups of
    ``items_per_group`` and uploaded together with the existing groups.

    Args:
        s3_work_paths: Each individual s3 path that we will process over
        items_per_group: Number of items to group together in a single work item
    """
    requested = set(s3_work_paths)
    logger.info(f"Found {len(requested):,} total paths")

    # Parse the stored index into {hash: [paths...]}
    index_lines = await asyncio.to_thread(download_zstd_csv, self.s3_client, self._index_path)
    known_groups = {}
    for raw_line in index_lines:
        stripped = raw_line.strip()
        if not stripped:
            continue
        digest, *members = stripped.split(",")
        known_groups[digest] = members

    already_indexed = set()
    for members in known_groups.values():
        already_indexed.update(members)

    # Keep only paths the index has not seen before
    fresh_paths = requested - already_indexed
    logger.info(f"{len(fresh_paths):,} new paths to add to the workspace")
    if not fresh_paths:
        return

    # Chunk the fresh paths into groups
    batches = []
    pending = []
    for path in sorted(fresh_paths):
        pending.append(path)
        if len(pending) == items_per_group:
            batches.append((self._compute_workgroup_hash(pending), pending))
            pending = []
    if pending:
        # Left-over paths form one last, smaller group
        batches.append((self._compute_workgroup_hash(pending), pending))

    logger.info(f"Created {len(batches):,} new work groups")

    # Existing groups first, new ones appended (or overwriting on equal hash)
    merged = dict(known_groups)
    merged.update(batches)
    serialized = [",".join([digest, *members]) for digest, members in merged.items()]

    if batches:
        await asyncio.to_thread(upload_zstd_csv, self.s3_client, self._index_path, serialized)
|
||||
|
||||
async def initialize_queue(self) -> None:
    """
    Load the work queue from S3 and initialize it for processing.
    Removes already completed work items and randomizes the order.

    The index file and the listing of result files are fetched
    concurrently. A group counts as completed when a result named
    ``output_<hash>.jsonl`` shows up in the listing.
    """
    # Load work items and completed items in parallel
    work_queue_lines, done_work_items = await asyncio.gather(
        asyncio.to_thread(download_zstd_csv, self.s3_client, self._index_path),
        asyncio.to_thread(expand_s3_glob, self.s3_client, self._output_glob),
    )

    # Process work queue lines: "hash,path,path,..." -> {hash: [paths]}
    work_queue = {}
    for line in work_queue_lines:
        stripped = line.strip()
        if stripped:
            group_hash, *group_paths = stripped.split(",")
            work_queue[group_hash] = group_paths

    # Get set of completed work hashes
    done_work_hashes = set()
    for item in done_work_items:
        filename = os.path.basename(item)
        if filename.startswith('output_') and filename.endswith('.jsonl'):
            done_work_hashes.add(filename[len('output_'):-len('.jsonl')])

    # Find remaining work and shuffle. WorkItem's field is `work_paths`;
    # passing the old `s3_work_paths` keyword raises TypeError.
    remaining_work_hashes = set(work_queue) - done_work_hashes
    remaining_items = [
        WorkItem(hash=hash_, work_paths=work_queue[hash_])
        for hash_ in remaining_work_hashes
    ]
    random.shuffle(remaining_items)

    # Initialize queue
    self._queue = asyncio.Queue()
    for item in remaining_items:
        await self._queue.put(item)

    logger.info(f"Initialized queue with {self._queue.qsize()} work items")
|
||||
|
||||
async def is_completed(self, work_hash: str) -> bool:
    """
    Report whether the output file for a work item already exists.

    The check is a ``head_object`` call (run in a worker thread) against
    ``<workspace_path>/results/output_<work_hash>.jsonl``.

    Args:
        work_hash: Hash identifying the work item

    Returns:
        True when ``head_object`` succeeds; False when it raises the client's
        ``ClientError`` (every ``ClientError`` is treated as "not completed")
    """
    bucket, key = parse_s3_path(
        os.path.join(self.workspace_path, "results", f"output_{work_hash}.jsonl")
    )

    try:
        await asyncio.to_thread(self.s3_client.head_object, Bucket=bucket, Key=key)
    except self.s3_client.exceptions.ClientError:
        return False

    return True
|
||||
|
||||
async def get_work(self, worker_lock_timeout_secs: int = 1800) -> Optional[WorkItem]:
    """
    Pop queued items until one can be claimed, then return it.

    An item is skipped (and ``task_done()`` is called for it) when its output
    already exists, when a lock object for it exists and is not older than the
    timeout, or when writing our own lock object fails.

    Args:
        worker_lock_timeout_secs: Age in seconds beyond which an existing lock is treated as stale and taken over (default 1800, i.e. 30 mins)

    Returns:
        The claimed WorkItem, or None once the local queue has no items left
    """
    while True:
        try:
            candidate = self._queue.get_nowait()
        except asyncio.QueueEmpty:
            return None

        # Someone may have finished this item since the queue was built
        if await self.is_completed(candidate.hash):
            logger.debug(f"Work item {candidate.hash} already completed, skipping")
            self._queue.task_done()
            continue

        # The lock is an object under worker_locks/ named after the output file
        bucket, key = parse_s3_path(
            os.path.join(self.workspace_path, "worker_locks", f"output_{candidate.hash}.jsonl")
        )

        try:
            lock_head = await asyncio.to_thread(self.s3_client.head_object, Bucket=bucket, Key=key)

            # Lock age is measured from the object's LastModified timestamp
            locked_at = lock_head['LastModified']
            lock_age_secs = (datetime.datetime.now(datetime.timezone.utc) - locked_at).total_seconds()
            if not lock_age_secs > worker_lock_timeout_secs:
                # Lock is still fresh: leave this item to its current owner
                logger.debug(f"Work item {candidate.hash} is locked by another worker, skipping")
                self._queue.task_done()
                continue

            # Lock has outlived the timeout, so fall through and overwrite it
            logger.debug(f"Found stale lock for {candidate.hash}, taking work item")
        except self.s3_client.exceptions.ClientError:
            # head_object failed; treated as "no lock exists", so claim the item
            pass

        # Write (or overwrite) the lock as an empty object
        try:
            await asyncio.to_thread(self.s3_client.put_object, Bucket=bucket, Key=key, Body=b'')
        except Exception as exc:
            logger.warning(f"Failed to create lock file for {candidate.hash}: {exc}")
            self._queue.task_done()
            continue

        return candidate
|
||||
|
||||
async def mark_done(self, work_item: WorkItem) -> None:
    """
    Release a work item: delete its lock object and mark the queue task done.

    A failure to delete the lock is only logged as a warning; ``task_done()``
    is called on the queue either way.

    Args:
        work_item: The WorkItem whose lock should be removed
    """
    bucket, key = parse_s3_path(
        os.path.join(self.workspace_path, "worker_locks", f"output_{work_item.hash}.jsonl")
    )

    try:
        await asyncio.to_thread(self.s3_client.delete_object, Bucket=bucket, Key=key)
    except Exception as exc:
        logger.warning(f"Failed to delete lock file for {work_item.hash}: {exc}")

    self._queue.task_done()
|
||||
|
||||
@property
def size(self) -> int:
    """Number of items currently held by the in-memory queue, as reported by ``qsize()``."""
    pending_count = self._queue.qsize()
    return pending_count
|
||||
@ -1,278 +0,0 @@
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import base64
|
||||
from PIL import Image
|
||||
|
||||
# Adjust the import path to match where your code resides
|
||||
from olmocr.birrpipeline import build_dolma_doc, DatabaseManager, build_finetuning_prompt, build_page_query
|
||||
|
||||
class TestBuildDolmaDoc(unittest.TestCase):
    """Tests for ``build_dolma_doc`` with S3 access and the database manager patched out."""

    @patch('olmocr.birrpipeline.DatabaseManager')
    @patch('olmocr.birrpipeline.get_s3_bytes')
    def test_build_dolma_doc_with_multiple_page_entries(self, mock_get_s3_bytes, mock_DatabaseManager):
        """
        Four inference results exist for page 1 of a one-page PDF; the
        document built from them must carry the text of the fourth one.
        """
        # Every DatabaseManager() created inside the pipeline resolves to this mock
        mock_db_instance = MagicMock()
        mock_DatabaseManager.return_value = mock_db_instance

        # The PDF under test
        pdf_s3_path = 's3://bucket/pdf/test.pdf'
        pdf = DatabaseManager.PDFRecord(s3_path=pdf_s3_path, num_pages=1, status='pending')

        # Model response for each inference file, in the order the index returns them
        page_responses = {
            's3://bucket/inference/output1.jsonl': {
                "primary_language": "en",
                "is_rotation_valid": True,
                "rotation_correction": 0,
                "is_table": False,
                "is_diagram": False,
                "natural_text": "Short Text"
            },
            's3://bucket/inference/output2.jsonl': {
                "primary_language": "en",
                "is_rotation_valid": False,
                "rotation_correction": 90,
                "is_table": True,
                "is_diagram": False,
                "natural_text": "Very Long Text Here that is longer"
            },
            's3://bucket/inference/output3.jsonl': {
                "primary_language": "en",
                "is_rotation_valid": True,
                "rotation_correction": 0,
                "is_table": False,
                "is_diagram": True,
                "natural_text": "Medium Length Text"
            },
            's3://bucket/inference/output4.jsonl': {
                "primary_language": "en",
                "is_rotation_valid": True,
                "rotation_correction": 0,
                "is_table": False,
                "is_diagram": False,
                "natural_text": "The Longest Correct Text"
            },
        }

        # One index record per inference file, all pointing at page 1 of the same PDF
        mock_db_instance.get_index_entries.return_value = [
            DatabaseManager.BatchInferenceRecord(
                inference_s3_path=inference_path,
                pdf_s3_path=pdf_s3_path,
                page_num=1,
                round=0,
                start_index=0,
                length=100,
                finish_reason='stop',
                error=None
            )
            for inference_path in page_responses
        ]

        def get_s3_bytes_side_effect(s3_client, s3_path, start_index=None, end_index=None):
            # Serve one JSONL line per known inference file; unknown paths get "{}"
            inner_data = page_responses.get(s3_path)
            if inner_data is None:
                data = {}
            else:
                data = {
                    "custom_id": f"{pdf_s3_path}-1",
                    "outputs": [{"text": json.dumps(inner_data)}],
                    "round": 0
                }
            return (json.dumps(data) + '\n').encode('utf-8')

        mock_get_s3_bytes.side_effect = get_s3_bytes_side_effect

        # Build the document
        s3_workspace = 's3://bucket/workspace'
        dolma_doc = build_dolma_doc(s3_workspace, pdf)

        # The text must be that of output4
        # NOTE(review): presumably chosen as the longest response without rotation/table/diagram flags - confirm against build_dolma_doc
        expected_text = 'The Longest Correct Text\n'

        self.assertIsNotNone(dolma_doc)
        self.assertEqual(dolma_doc['text'], expected_text)

        # Metadata and page-span attributes describe a single page of the source PDF
        self.assertEqual(dolma_doc['metadata']['Source-File'], pdf_s3_path)
        self.assertEqual(dolma_doc['metadata']['pdf-total-pages'], 1)
        self.assertEqual(len(dolma_doc['attributes']['pdf_page_numbers']), 1)
        self.assertEqual(dolma_doc['attributes']['pdf_page_numbers'][0][2], 1)

        # The document id is the SHA-1 hex digest of the text
        expected_id = hashlib.sha1(expected_text.encode()).hexdigest()
        self.assertEqual(dolma_doc['id'], expected_id)
|
||||
|
||||
|
||||
class TestBuildPageQuery(unittest.TestCase):
    """Smoke and rotation tests for ``build_page_query`` against PDFs stored in ``gnarly_pdfs``."""

    @staticmethod
    def _gnarly_pdf(name):
        """Return the path of a fixture PDF in the ``gnarly_pdfs`` folder next to this file."""
        return os.path.join(os.path.dirname(__file__), "gnarly_pdfs", name)

    def _extract_page_png(self, query):
        """Assert the query's second content entry is a base64 PNG data URL and return the decoded bytes."""
        image_content = query["chat_messages"][0]["content"][1]
        self.assertEqual(image_content["type"], "image_url")
        image_url = image_content["image_url"]["url"]

        prefix = "data:image/png;base64,"
        self.assertTrue(image_url.startswith(prefix))
        return base64.b64decode(image_url[len(prefix):])

    def testNotParsing(self):
        """Build a query for pages 1-8 of not_parsing.pdf; passes as long as nothing raises."""
        file = self._gnarly_pdf("not_parsing.pdf")

        for page in range(1, 9):
            query = build_page_query(file, "not_parsing.pdf", page, 1024, 6000)
            print(query)

    def testNotParsing2(self):
        """Build a query for pages 1-9 of not_parsing2.pdf; passes as long as nothing raises."""
        file = self._gnarly_pdf("not_parsing2.pdf")

        for page in range(1, 10):
            query = build_page_query(file, "not_parsing2.pdf", page, 1024, 6000)
            print(query)

    def testNotParsingHugeMemoryUsage(self):
        """Build a query for page 9 of failing_pdf_pg9.pdf; passes as long as nothing raises."""
        file = self._gnarly_pdf("failing_pdf_pg9.pdf")

        print("Starting to parse bad pdf")

        query = build_page_query(file, "failing_pdf_pg9.pdf", 9, 1024, 6000)

        print(query)

    def testRotation(self):
        """
        Render page 1 of edgar.pdf with a final argument of 0 and of 90, save
        both PNGs under ``test_renders``, and verify pixel by pixel that the
        second image is the first rotated 90 degrees clockwise.
        """
        pdf_path = self._gnarly_pdf("edgar.pdf")

        # The renders are written to disk; create the folder first so a clean
        # checkout without test_renders/ does not fail with FileNotFoundError
        render_dir = os.path.join(os.path.dirname(__file__), "test_renders")
        os.makedirs(render_dir, exist_ok=True)

        # First, generate and save the non-rotated image
        query = build_page_query(pdf_path, "edgar.pdf", 1, 1024, 6000, 0)
        image_data = self._extract_page_png(query)

        output_image_path = os.path.join(render_dir, "output_image.png")
        with open(output_image_path, "wb") as image_file:
            image_file.write(image_data)

        # Now, generate and save the rotated image (90 degrees clockwise)
        query_rotated = build_page_query(pdf_path, "edgar.pdf", 1, 1024, 6000, 90)
        image_data_rotated = self._extract_page_png(query_rotated)

        output_image_rotated_path = os.path.join(render_dir, "output_image_rotated90.png")
        with open(output_image_rotated_path, "wb") as image_file_rotated:
            image_file_rotated.write(image_data_rotated)

        # Verification step: compare the two saved images with PIL
        with Image.open(output_image_path) as original_image:
            with Image.open(output_image_rotated_path) as rotated_image:
                original_pixels = original_image.load()
                rotated_pixels = rotated_image.load()
                width, height = original_image.size

                # A quarter turn swaps width and height
                self.assertEqual(width, rotated_image.size[1])
                self.assertEqual(height, rotated_image.size[0])

                # Original pixel (x, y) must land at (height - 1 - y, x) after a clockwise quarter turn
                for x in range(width):
                    for y in range(height):
                        self.assertEqual(
                            original_pixels[x, y], rotated_pixels[height - 1 - y, x],
                            f"Pixel mismatch at ({x}, {y})"
                        )

        print("Rotation verification passed: The rotated image is correctly rotated 90 degrees clockwise.")
|
||||
|
||||
|
||||
# Run every test in this module when the file is executed directly as a script
if __name__ == '__main__':
    unittest.main()
|
||||
@ -7,7 +7,7 @@ import hashlib
|
||||
from typing import List, Dict
|
||||
|
||||
# Import the classes we're testing
|
||||
from olmocr.s3_queue import S3WorkQueue, WorkItem
|
||||
from olmocr.work_queue import S3WorkQueue, WorkItem
|
||||
|
||||
class TestS3WorkQueue(unittest.TestCase):
|
||||
def setUp(self):
|
||||
@ -70,8 +70,8 @@ class TestS3WorkQueue(unittest.TestCase):
|
||||
async def test_populate_queue_new_items(self):
|
||||
"""Test populating queue with new items"""
|
||||
# Mock empty existing index
|
||||
with patch('olmocr.s3_queue.download_zstd_csv', return_value=[]):
|
||||
with patch('olmocr.s3_queue.upload_zstd_csv') as mock_upload:
|
||||
with patch('olmocr.work_queue.download_zstd_csv', return_value=[]):
|
||||
with patch('olmocr.work_queue.upload_zstd_csv') as mock_upload:
|
||||
await self.work_queue.populate_queue(self.sample_paths, items_per_group=2)
|
||||
|
||||
# Verify upload was called with correct data
|
||||
@ -97,8 +97,8 @@ class TestS3WorkQueue(unittest.TestCase):
|
||||
existing_hash = S3WorkQueue._compute_workgroup_hash(existing_paths)
|
||||
existing_line = f"{existing_hash},{existing_paths[0]}"
|
||||
|
||||
with patch('olmocr.s3_queue.download_zstd_csv', return_value=[existing_line]):
|
||||
with patch('olmocr.s3_queue.upload_zstd_csv') as mock_upload:
|
||||
with patch('olmocr.work_queue.download_zstd_csv', return_value=[existing_line]):
|
||||
with patch('olmocr.work_queue.upload_zstd_csv') as mock_upload:
|
||||
await self.work_queue.populate_queue(existing_paths + new_paths, items_per_group=1)
|
||||
|
||||
# Verify upload called with both existing and new items
|
||||
@ -116,8 +116,8 @@ class TestS3WorkQueue(unittest.TestCase):
|
||||
|
||||
completed_items = [f"s3://test-bucket/workspace/results/output_{work_hash}.jsonl"]
|
||||
|
||||
with patch('olmocr.s3_queue.download_zstd_csv', return_value=[work_line]):
|
||||
with patch('olmocr.s3_queue.expand_s3_glob', return_value=completed_items):
|
||||
with patch('olmocr.work_queue.download_zstd_csv', return_value=[work_line]):
|
||||
with patch('olmocr.work_queue.expand_s3_glob', return_value=completed_items):
|
||||
await self.work_queue.initialize_queue()
|
||||
|
||||
# Queue should be empty since all work is completed
|
||||
@ -143,7 +143,7 @@ class TestS3WorkQueue(unittest.TestCase):
|
||||
async def test_get_work(self):
|
||||
"""Test getting work items"""
|
||||
# Setup test data
|
||||
work_item = WorkItem(hash="testhash123", s3_work_paths=["s3://test/file1.pdf"])
|
||||
work_item = WorkItem(hash="testhash123", work_paths=["s3://test/file1.pdf"])
|
||||
await self.work_queue._queue.put(work_item)
|
||||
|
||||
# Test getting available work
|
||||
@ -162,7 +162,7 @@ class TestS3WorkQueue(unittest.TestCase):
|
||||
@async_test
|
||||
async def test_get_work_completed(self):
|
||||
"""Test getting work that's already completed"""
|
||||
work_item = WorkItem(hash="testhash123", s3_work_paths=["s3://test/file1.pdf"])
|
||||
work_item = WorkItem(hash="testhash123", work_paths=["s3://test/file1.pdf"])
|
||||
await self.work_queue._queue.put(work_item)
|
||||
|
||||
# Simulate completed work
|
||||
@ -174,7 +174,7 @@ class TestS3WorkQueue(unittest.TestCase):
|
||||
@async_test
|
||||
async def test_get_work_locked(self):
|
||||
"""Test getting work that's locked by another worker"""
|
||||
work_item = WorkItem(hash="testhash123", s3_work_paths=["s3://test/file1.pdf"])
|
||||
work_item = WorkItem(hash="testhash123", work_paths=["s3://test/file1.pdf"])
|
||||
await self.work_queue._queue.put(work_item)
|
||||
|
||||
# Simulate active lock
|
||||
@ -190,7 +190,7 @@ class TestS3WorkQueue(unittest.TestCase):
|
||||
@async_test
|
||||
async def test_get_work_stale_lock(self):
|
||||
"""Test getting work with a stale lock"""
|
||||
work_item = WorkItem(hash="testhash123", s3_work_paths=["s3://test/file1.pdf"])
|
||||
work_item = WorkItem(hash="testhash123", work_paths=["s3://test/file1.pdf"])
|
||||
await self.work_queue._queue.put(work_item)
|
||||
|
||||
# Simulate stale lock
|
||||
@ -206,7 +206,7 @@ class TestS3WorkQueue(unittest.TestCase):
|
||||
@async_test
|
||||
async def test_mark_done(self):
|
||||
"""Test marking work as done"""
|
||||
work_item = WorkItem(hash="testhash123", s3_work_paths=["s3://test/file1.pdf"])
|
||||
work_item = WorkItem(hash="testhash123", work_paths=["s3://test/file1.pdf"])
|
||||
await self.work_queue._queue.put(work_item)
|
||||
|
||||
await self.work_queue.mark_done(work_item)
|
||||
@ -223,10 +223,10 @@ class TestS3WorkQueue(unittest.TestCase):
|
||||
self.loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(self.loop)
|
||||
|
||||
self.loop.run_until_complete(self.work_queue._queue.put(WorkItem(hash="test1", s3_work_paths=["path1"])))
|
||||
self.loop.run_until_complete(self.work_queue._queue.put(WorkItem(hash="test1", work_paths=["path1"])))
|
||||
self.assertEqual(self.work_queue.size, 1)
|
||||
|
||||
self.loop.run_until_complete(self.work_queue._queue.put(WorkItem(hash="test2", s3_work_paths=["path2"])))
|
||||
self.loop.run_until_complete(self.work_queue._queue.put(WorkItem(hash="test2", work_paths=["path2"])))
|
||||
self.assertEqual(self.work_queue.size, 2)
|
||||
|
||||
self.loop.close()
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user