diff --git a/olmocr/birrpipeline.py b/olmocr/birrpipeline.py
deleted file mode 100644
index 73885fe..0000000
--- a/olmocr/birrpipeline.py
+++ /dev/null
@@ -1,862 +0,0 @@
-import os
-import hashlib
-import boto3
-import sqlite3
-import orjson
-import argparse
-import base64
-import tempfile
-import datetime
-import posixpath
-import threading
-import logging
-import psutil
-import boto3.session
-import urllib3.exceptions
-
-from dataclasses import dataclass
-from pypdf import PdfReader
-from io import BytesIO
-from PIL import Image
-from tqdm import tqdm
-from functools import partial
-from typing import Optional, List, Tuple, Dict, Callable, Any
-from urllib.parse import urlparse
-import concurrent.futures
-from concurrent.futures import ProcessPoolExecutor, as_completed
-
-from olmocr.data.renderpdf import render_pdf_to_base64png
-from olmocr.prompts import build_finetuning_prompt, PageResponse
-from olmocr.prompts.anchor import get_anchor_text
-from olmocr.s3_utils import parse_custom_id, expand_s3_glob, get_s3_bytes, parse_s3_path
-from olmocr.check import check_poppler_version
-
-# Initialize logger
-logger = logging.getLogger(__name__)
-logger.setLevel(logging.INFO)
-
-# File handler for DEBUG level and above with line-by-line flushing
-class FlushFileHandler(logging.FileHandler):
-    def emit(self, record):
-        super().emit(record)
-        self.flush()  # Explicitly flush after every log entry
-
-file_handler = FlushFileHandler('birrpipeline-debug.log', mode='a')
-file_handler.setLevel(logging.DEBUG)
-file_handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
-
-# Add handlers to the logger
-logger.addHandler(file_handler)
-
-# Global s3 client for the whole script, feel free to adjust params if you need it
-workspace_s3 = boto3.client('s3')
-pdf_s3 = boto3.client('s3')
-
-# Quiet logs from pypdf
-logging.getLogger("pypdf").setLevel(logging.ERROR)
-
-
-class DatabaseManager:
-    @dataclass(frozen=True)
-    class BatchInferenceRecord:
-        inference_s3_path: str
-        pdf_s3_path: str
-        page_num: int  # 1 indexed!
-        round: int
-        start_index: int
-        length: int
-        finish_reason: str
-        error: Optional[str]
-
-        def is_usable(self):
-            return self.error is None and self.finish_reason == "stop"
-
-    @dataclass(frozen=True)
-    class PDFRecord:
-        s3_path: str
-        num_pages: int
-        status: str
-
-    def __init__(self, s3_workspace: str, skip_init: bool=False):
-        cache_key = hashlib.sha256(s3_workspace.strip().lower().encode('utf-8')).hexdigest()
-        home_cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'olmocr', cache_key)
-        os.makedirs(home_cache_dir, exist_ok=True)
-        self.db_path = os.path.join(home_cache_dir, 'index.db')
-
-        self.conn = sqlite3.connect(self.db_path)
-        # Enable WAL mode so you can read and write concurrently
-        self.cursor = self.conn.cursor()
-        self.cursor.execute("PRAGMA journal_mode=WAL;")
-
-        if not skip_init:
-            self._initialize_tables()
-
-    def _initialize_tables(self):
-        self.cursor.execute("""
-            CREATE TABLE IF NOT EXISTS page_results (
-                inference_s3_path TEXT,
-                pdf_s3_path TEXT,
-                page_num INTEGER,
-                round INTEGER,
-                start_index BIGINT,
-                length BIGINT,
-                finish_reason TEXT,
-                error TEXT
-            )
-        """)
-        self.cursor.execute("""
-            CREATE INDEX IF NOT EXISTS idx_path ON page_results(pdf_s3_path)
-        """)
-        self.cursor.execute("""
-            CREATE INDEX IF NOT EXISTS idx_inf_path ON page_results(inference_s3_path)
-        """)
-        self.cursor.execute("""
-            CREATE TABLE IF NOT EXISTS pdfs (
-                s3_path TEXT PRIMARY KEY,
-                num_pages INTEGER,
-                status TEXT DEFAULT 'pending'
-            )
-        """)
-        self.cursor.execute("""
-            CREATE TABLE IF NOT EXISTS processed_files (
-                s3_path TEXT PRIMARY KEY,
-                etag TEXT
-            )
-        """)
-        # Generic metadata such as current round
-        self.cursor.execute("""
-            CREATE TABLE IF NOT EXISTS metadata (
-                key TEXT PRIMARY KEY,
-                value TEXT
-            )
-        """)
-
-        self.conn.commit()
-
-    def get_metadata(self, key: str) -> Optional[str]:
-        self.cursor.execute("SELECT value FROM metadata WHERE key=?", (key,))
-        result = self.cursor.fetchone()
-        return result[0] if result else None
-
-    def set_metadata(self, key: str, value: str) -> None:
-        self.cursor.execute("""
-            INSERT INTO metadata (key, value)
-            VALUES (?, ?)
-            ON CONFLICT(key) DO UPDATE SET value=excluded.value
-        """, (key, value))
-        self.conn.commit()
-
-    def is_file_processed(self, s3_path, etag):
-        self.cursor.execute("SELECT etag FROM processed_files WHERE s3_path = ?", (s3_path,))
-        result = self.cursor.fetchone()
-        return result is not None and result[0] == etag
-
-    def update_processed_file(self, s3_path, etag):
-        self.cursor.execute("""
-            INSERT INTO processed_files (s3_path, etag)
-            VALUES (?, ?)
-            ON CONFLICT(s3_path) DO UPDATE SET etag=excluded.etag
-        """, (s3_path, etag))
-        self.conn.commit()
-
-    def clear_index(self):
-        self.cursor.execute("""
-            DELETE FROM processed_files;
-        """)
-        self.cursor.execute("""
-            DELETE FROM page_results;
-        """)
-        self.conn.commit()
-
-    def add_index_entries(self, index_entries: List['BatchInferenceRecord']):
-        if index_entries:
-            self.cursor.executemany("""
-                INSERT INTO page_results (inference_s3_path, pdf_s3_path, page_num, round, start_index, length, finish_reason, error)
-                VALUES (?, ?, ?, ?, ?, ?, ?, ?)
- """, [(entry.inference_s3_path, entry.pdf_s3_path, entry.page_num, entry.round, entry.start_index, entry.length, entry.finish_reason, entry.error) for entry in index_entries]) - self.conn.commit() - - def get_index_entries(self, pdf_s3_path: str) -> List['BatchInferenceRecord']: - self.cursor.execute(""" - SELECT inference_s3_path, pdf_s3_path, page_num, round, start_index, length, finish_reason, error - FROM page_results - WHERE pdf_s3_path = ? - ORDER BY inference_s3_path DESC, start_index ASC, page_num ASC - """, (pdf_s3_path,)) - - rows = self.cursor.fetchall() - - return [ - self.BatchInferenceRecord( - inference_s3_path=row[0], - pdf_s3_path=row[1], - page_num=row[2], - round=row[3], - start_index=row[4], - length=row[5], - finish_reason=row[6], - error=row[7] - ) - for row in rows - ] - - def delete_index_entries_by_inference_s3_path(self, inference_s3_path: str): - self.cursor.execute("DELETE FROM page_results WHERE inference_s3_path = ?", (inference_s3_path,)) - self.conn.commit() - - def get_last_indexed_round(self) -> int: - self.cursor.execute(""" - SELECT MAX(round) - FROM page_results - """) - - result = self.cursor.fetchone() - return -1 if result[0] is None else result[0] - - def pdf_exists(self, s3_path: str) -> bool: - self.cursor.execute("SELECT 1 FROM pdfs WHERE s3_path = ?", (s3_path,)) - return self.cursor.fetchone() is not None - - def add_pdf(self, s3_path: str, num_pages: int, status: str = 'pending') -> None: - try: - self.cursor.execute(""" - INSERT INTO pdfs (s3_path, num_pages, status) - VALUES (?, ?, ?) - """, (s3_path, num_pages, status)) - self.conn.commit() - except sqlite3.IntegrityError: - logger.warning(f"PDF with s3_path '{s3_path}' already exists.") - - def update_pdf_statuses(self, status_updates: Dict[str, str]) -> None: - """ - Update the status of multiple PDFs in the database. - - :param status_updates: A dictionary where each key is an s3_path (str) and - each value is the new status (str) for that PDF. - """ - self.cursor.executemany(""" - UPDATE pdfs - SET status = ? - WHERE s3_path = ? - """, [(new_status, s3_path) for s3_path, new_status in status_updates.items()]) - self.conn.commit() - - def get_pdf(self, s3_path: str) -> Optional['PDFRecord']: - self.cursor.execute(""" - SELECT s3_path, num_pages, status - FROM pdfs - WHERE s3_path = ? - """, (s3_path,)) - - row = self.cursor.fetchone() - - if row: - return self.PDFRecord( - s3_path=row[0], - num_pages=row[1], - status=row[2] - ) - return None - - def get_pdfs_by_status(self, status: str) -> List['PDFRecord']: - self.cursor.execute(""" - SELECT s3_path, num_pages, status - FROM pdfs - WHERE status == ? 
-            ORDER BY s3_path DESC, num_pages DESC
-        """, (status, ))
-
-        rows = self.cursor.fetchall()
-
-        return [
-            self.PDFRecord(
-                s3_path=row[0],
-                num_pages=row[1],
-                status=row[2]
-            )
-            for row in rows
-        ]
-
-    def close(self):
-        self.conn.close()
-
-
-class BatchWriter:
-    def __init__(
-        self,
-        output_prefix: str,
-        max_size_mb: int = 250,
-        after_flush: Optional[Callable[[List[Any]], Any]] = None,
-    ):
-        self.output_prefix = output_prefix
-        self.max_size = max_size_mb * 1024 * 1024  # Convert MB to bytes
-        self.batch_objects = []
-        self.batch_size = 0
-        self.after_flush = after_flush
-        self.threads = []
-        self.temp_file = None  # The temporary file object
-        self.temp_file_path = None  # Path to the temporary file
-
-        parsed = urlparse(output_prefix)
-        self.is_s3 = parsed.scheme in ("s3", "s3a", "s3n")
-
-        if not self.is_s3:
-            os.makedirs(output_prefix, exist_ok=True)
-
-    def write_line(self, obj: Optional[Any]):
-        if obj is None:
-            return
-
-        line_bytes = orjson.dumps(obj)
-        line_size = len(line_bytes) + 1  # +1 for newline
-
-        if self.batch_size + line_size > self.max_size:
-            self._write_batch()
-
-        if self.batch_size == 0:
-            # Open a new temporary file
-            self.temp_file = tempfile.NamedTemporaryFile(mode="wb+", delete=False)
-            self.temp_file_path = self.temp_file.name
-
-        self.temp_file.write(line_bytes + b"\n")
-        self.batch_objects.append(obj)
-        self.batch_size += line_size
-
-    def _write_batch(self):
-        if self.batch_size == 0:
-            return
-
-        # Close the temp file
-        self.temp_file.flush()
-        self.temp_file.close()
-
-        # Start a new thread to upload the temp file
-        thread = threading.Thread(
-            target=self._write_batch_to_file, args=(self.temp_file_path, self.batch_objects)
-        )
-        thread.start()
-        self.threads.append(thread)
-
-        # Reset batch_objects and batch_size
-        self.batch_objects = []
-        self.batch_size = 0
-        self.temp_file = None
-        self.temp_file_path = None
-
-    def _write_batch_to_file(self, temp_file_path: str, batch_objects: List[Any]):
-        # Compute hash based on file content
-        hash_str = self._compute_hash(temp_file_path)
-        output_path = self._get_output_path(hash_str)
-
-        if self.is_s3:
-            bucket, key = parse_s3_path(output_path)
-
-            # Use the s3 client directly
-            try:
-                workspace_s3.upload_file(temp_file_path, bucket, key)
-            except Exception as e:
-                logger.error(f"Failed to upload {temp_file_path} to {output_path}: {e}", exc_info=True)
-        else:
-            # Move the temp file to the output path
-            os.rename(temp_file_path, output_path)
-
-        # After writing, call the after_flush callback if it is set
-        if self.after_flush:
-            self.after_flush(batch_objects)
-
-        os.remove(temp_file_path)
-
-    def _compute_hash(self, temp_file_path: str) -> str:
-        """Compute a 20-character SHA1 hash of the file content."""
-        sha1 = hashlib.sha1()
-        with open(temp_file_path, "rb") as f:
-            while True:
-                data = f.read(1024*1024)
-                if not data:
-                    break
-                sha1.update(data)
-        return sha1.hexdigest()[:20]
-
-    def _get_output_path(self, hash_str: str) -> str:
-        """Generate the full output path with hash in the filename."""
-        parsed = urlparse(self.output_prefix)
-        if self.is_s3:
-            bucket = parsed.netloc
-            key = parsed.path.lstrip("/")
-            if key and not key.endswith("/"):
-                key += "/"
-            full_key = posixpath.join(key, f"output_{hash_str}.jsonl")
-            return f"s3://{bucket}/{full_key}"
-        else:
-            filename = f"output_{hash_str}.jsonl"
-            return os.path.join(self.output_prefix, filename)
-
-    def close(self):
-        self._write_batch()
-        # Wait for all threads to finish
-        for thread in self.threads:
-            thread.join()
-
-
-def build_page_query(local_pdf_path: str, pretty_pdf_path: str, page: int, target_longest_image_dim: int, target_anchor_text_len: int, image_rotation: int=0) -> dict:
-    assert image_rotation in [0, 90, 180, 270], "Invalid image rotation provided in build_page_query"
-    image_base64 = render_pdf_to_base64png(local_pdf_path, page, target_longest_image_dim=target_longest_image_dim)
-
-    if image_rotation != 0:
-        image_bytes = base64.b64decode(image_base64)
-        with Image.open(BytesIO(image_bytes)) as img:
-            rotated_img = img.rotate(-image_rotation, expand=True)
-
-            # Save the rotated image to a bytes buffer
-            buffered = BytesIO()
-            rotated_img.save(buffered, format="PNG")
-
-            # Encode the rotated image back to base64
-            image_base64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
-
-
-    anchor_text = get_anchor_text(local_pdf_path, page, pdf_engine="pdfreport", target_length=target_anchor_text_len)
-
-    return {
-        "custom_id": f"{pretty_pdf_path}-{page}",
-        "chat_messages": [
-            {
-                "role": "user",
-                "content": [
-                    {"type": "text", "text": build_finetuning_prompt(anchor_text)},
-                    {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}}
-                ],
-            }
-        ],
-    }
-
-
-def process_jsonl_content(inference_s3_path: str) -> List[DatabaseManager.BatchInferenceRecord]:
-    content_bytes = get_s3_bytes(workspace_s3, inference_s3_path)
-
-    start_index = 0
-    index_entries = []
-    lines = content_bytes.splitlines(keepends=True)  # Split content into lines as bytes
-    for line in lines:
-        line_length = len(line)  # Length in bytes
-
-        try:
-            # Parse the line directly as JSON
-            data = orjson.loads(line)
-            pdf_s3_path, page_num = parse_custom_id(data["custom_id"])
-
-            if data.get("completion_error", None) is not None:
-                index_entries.append(DatabaseManager.BatchInferenceRecord(
-                    inference_s3_path=inference_s3_path,
-                    pdf_s3_path=pdf_s3_path,
-                    page_num=page_num,
-                    round=data["round"],
-                    start_index=start_index,  # Byte offset in the original file
-                    length=line_length,  # Length in bytes
-                    finish_reason="completion_error",
-                    error=data.get("completion_error", None)
-                ))
-            else:
-                # Try to parse the actual model response JSON
-                assert "outputs" in data and len(data["outputs"]) > 0, "No outputs from model detected"
-
-                try:
-                    model_response_json = orjson.loads(data["outputs"][0]["text"])
-                    page_response = PageResponse(**model_response_json)
-
-                    last_error = data.get("completion_error", None)
-
-                    if not page_response.is_rotation_valid:
-                        last_error = "rotation_invalid"
-
-                    index_entries.append(DatabaseManager.BatchInferenceRecord(
-                        inference_s3_path=inference_s3_path,
-                        pdf_s3_path=pdf_s3_path,
-                        page_num=page_num,
-                        round=data["round"],
-                        start_index=start_index,  # Byte offset in the original file
-                        length=line_length,  # Length in bytes
-                        finish_reason=data["outputs"][0]["finish_reason"],
-                        error=last_error,
-                    ))
-                except Exception as e:
-                    error_type = type(e).__name__
-                    index_entries.append(DatabaseManager.BatchInferenceRecord(
-                        inference_s3_path=inference_s3_path,
-                        pdf_s3_path=pdf_s3_path,
-                        page_num=page_num,
-                        round=data["round"],
-                        start_index=start_index,  # Byte offset in the original file
-                        length=line_length,  # Length in bytes
-                        finish_reason=data["outputs"][0]["finish_reason"],
-                        error=error_type,
-                    ))
-
-        except Exception as e:
-            logger.exception(f"Error processing line in {inference_s3_path}: {e}")
-            # Optionally, you might want to add an index entry indicating an error here
-
-        start_index += line_length  # Increment by the number of bytes
-
-    return index_entries
-
-
-def get_pdf_num_pages(s3_path: str) -> Optional[int]:
-    logger.debug(f"Startng to get_pdf_num_pages for {s3_path}")
-    try:
-        with tempfile.NamedTemporaryFile("wb+", suffix=".pdf") as tf:
-            tf.write(get_s3_bytes(pdf_s3, s3_path))
-            tf.flush()
-
-            reader = PdfReader(tf.name)
-            logger.debug(f"Built reader for {s3_path}")
-            return reader.get_num_pages()
-    except Exception as ex:
-        logger.warning(f"Warning, could not add {s3_path} due to {ex}")
-
-    return None
-
-
-def _get_page_data(page_index_entries: List[DatabaseManager.BatchInferenceRecord]) -> List[PageResponse]:
-    usable_page_data = [get_s3_bytes(workspace_s3, page.inference_s3_path,
-                                     start_index=page.start_index,
-                                     end_index=page.start_index + page.length - 1) for page in page_index_entries]
-
-    usable_page_final_results = []
-    for page_data in usable_page_data:
-        data = orjson.loads(page_data)
-        model_response_json = orjson.loads(data["outputs"][0]["text"])
-        page_response = PageResponse(**model_response_json)
-        usable_page_final_results.append(page_response)
-
-    return usable_page_final_results
-
-
-def build_pdf_queries(s3_workspace: str, pdf: DatabaseManager.PDFRecord, cur_round: int, target_longest_image_dim: int, target_anchor_text_len: int) -> list[dict]:
-    db = DatabaseManager(s3_workspace, skip_init=True)
-
-    existing_pages = db.get_index_entries(pdf.s3_path)
-    new_queries = []
-
-    # Shortcut out of downloading the actual PDF
-    if set(page.page_num for page in existing_pages if page.is_usable()) == set(range(1, pdf.num_pages + 1)):
-        db.close()
-        return []
-
-    try:
-        with tempfile.NamedTemporaryFile("wb+", suffix=".pdf") as tf:
-            tf.write(get_s3_bytes(pdf_s3, pdf.s3_path))
-            tf.flush()
-
-            for target_page_num in range(1, pdf.num_pages + 1):
-                # Is there an existing page that has no error
-                if any(page.is_usable() and page.page_num == target_page_num for page in existing_pages):
-                    continue
-
-                has_errored_previously = sum(page.page_num == target_page_num for page in existing_pages)
-
-                if has_errored_previously:
-                    # Retry the page at least one more time regularly
-                    new_queries.append({**build_page_query(tf.name, pdf.s3_path, target_page_num, target_longest_image_dim, target_anchor_text_len), "round": cur_round})
-
-                    # If the rotation was previously invalid, then apply a rotation
-                    rotated_page_data = _get_page_data([page for page in existing_pages if page.page_num == target_page_num and page.error == "rotation_invalid"])
-                    rotation_corrections = set(page_data.rotation_correction for page_data in rotated_page_data)
-                    for correction in rotation_corrections:
-                        logger.debug(f"Adding {correction}-degree rotation query for {pdf.s3_path}-{target_page_num}")
-                        new_queries.append({**build_page_query(tf.name, pdf.s3_path, target_page_num, target_longest_image_dim, target_anchor_text_len, image_rotation=correction), "round": cur_round})
-
-                    # TODO: Try to provide a smaller prompt hint if that was the error
-                else:
-                    new_queries.append({**build_page_query(tf.name, pdf.s3_path, target_page_num, target_longest_image_dim, target_anchor_text_len), "round": cur_round})
-    except Exception as ex:
-        logger.warning(f"Warning, could not get batch inferences lines for {pdf.s3_path} due to {ex}")
-
-    db.close()
-    return new_queries
-
-def build_dolma_doc(s3_workspace: str, pdf: DatabaseManager.PDFRecord) -> Optional[dict]:
-    db = DatabaseManager(s3_workspace, skip_init=True)
-    existing_pages = db.get_index_entries(pdf.s3_path)
-    document_text = ""
-    last_page_start_index = 0
-    pdf_page_spans = []
-
-    # Error out quickly if this document cannot be assembled
-    for target_page_num in range(1, pdf.num_pages + 1):
-        usable_pages = [page for page in existing_pages if page.is_usable() and page.page_num == target_page_num]
-
-        if len(usable_pages) == 0:
-            db.close()
-            return None
-
-    for target_page_num in range(1, pdf.num_pages + 1):
-        usable_pages = [page for page in existing_pages if page.is_usable() and page.page_num == target_page_num]
-        usable_page_final_results = _get_page_data(usable_pages)
-
-        # Sort the pages:
-        # 1. Prefer pages with `is_rotation_valid` set to True.
-        # 2. Within those, sort by the length of the `natural_text` in descending order.
-        usable_page_final_results.sort(
-            key=lambda page: (not page.is_rotation_valid, -len(page.natural_text or ""))
-        )
-
-        target_page_final_result = usable_page_final_results[0]
-
-        if target_page_final_result.natural_text is not None:
-            document_text += target_page_final_result.natural_text + "\n"
-
-        pdf_page_spans.append([last_page_start_index, len(document_text), target_page_num])
-        last_page_start_index = len(document_text)
-
-    metadata = {
-        "Source-File": pdf.s3_path,
-        "pdf-total-pages": pdf.num_pages,
-    }
-    id_ = hashlib.sha1(document_text.encode()).hexdigest()
-
-    dolma_doc = {
-        "id": id_,
-        "text": document_text,
-        "source": "olmocr",
-        "added": datetime.datetime.now().strftime("%Y-%m-%d"),
-        "created": datetime.datetime.now().strftime("%Y-%m-%d"),
-        "metadata": metadata,
-        "attributes": {
-            "pdf_page_numbers": pdf_page_spans
-        }
-    }
-
-    db.close()
-    return dolma_doc
-
-def mark_pdfs_done(s3_workspace: str, dolma_docs: list[dict]):
-    db = DatabaseManager(s3_workspace, skip_init=True)
-    db.update_pdf_statuses({doc["metadata"]["Source-File"]: "completed" for doc in dolma_docs})
-    db.close()
-
-def get_current_round(s3_workspace: str) -> int:
-    path = s3_workspace[5:]
-    bucket, _, prefix = path.partition('/')
-
-    inference_inputs_prefix = posixpath.join(prefix, 'inference_inputs/')
-    paginator = workspace_s3.get_paginator('list_objects_v2')
-    page_iterator = paginator.paginate(Bucket=bucket, Prefix=inference_inputs_prefix, Delimiter='/')
-
-    round_numbers = []
-    for page in page_iterator:
-        for common_prefix in page.get('CommonPrefixes', []):
-            round_prefix = common_prefix.get('Prefix')
-            # Extract 'round_X' from the prefix
-            round_dir = posixpath.basename(posixpath.dirname(round_prefix))
-            if round_dir.startswith('round_'):
-                try:
-                    round_num = int(round_dir[len('round_'):])
-                    round_numbers.append(round_num)
-                except ValueError:
-                    pass
-    if round_numbers:
-        current_round = max(round_numbers) + 1
-    else:
-        current_round = 0
-    return current_round
-
-
-if __name__ == '__main__':
-    parser = argparse.ArgumentParser(description='Manager for running millions of PDFs through a batch inference pipeline')
-    parser.add_argument('workspace', help='The S3 path where work will be done e.g., s3://bucket/prefix/)')
-    parser.add_argument('--add_pdfs', help='Path to add pdfs stored in s3 to the workspace, can be a glob path s3://bucket/prefix/*.pdf or path to file containing list of pdf paths', default=None)
-    parser.add_argument('--target_longest_image_dim', type=int, help='Dimension on longest side to use for rendering the pdf pages', default=1024)
-    parser.add_argument('--target_anchor_text_len', type=int, help='Maximum amount of anchor text to use (characters)', default=6000)
-    parser.add_argument('--workspace_profile', help='S3 configuration profile for accessing the workspace', default=None)
-    parser.add_argument('--pdf_profile', help='S3 configuration profile for accessing the raw pdf documents', default=None)
-    parser.add_argument('--max_size_mb', type=int, default=250, help='Max file size in MB')
-    parser.add_argument('--workers', type=int, help='Number of workers to run in the processpool')
-    parser.add_argument('--reindex', action='store_true', default=False, help='Reindex all of the page_results')
-    parser.add_argument('--skip_build_queries', action='store_true', default=False, help='Skip generation of new pdf page queries for batch inferencing')
-    args = parser.parse_args()
-
-    if args.workspace_profile:
-        workspace_session = boto3.Session(profile_name=args.workspace_profile)
-        workspace_s3 = workspace_session.client("s3")
-
-    if args.pdf_profile:
-        pdf_session = boto3.Session(profile_name=args.pdf_profile)
-        pdf_s3 = pdf_session.client("s3")
-
-    db = DatabaseManager(args.workspace)
-    logger.info(f"Loaded db at {db.db_path}")
-
-    if args.reindex:
-        db.clear_index()
-        logger.info("Cleared existing index.")
-
-    current_round = get_current_round(args.workspace)
-    logger.info(f"Current round is {current_round}")
-
-    check_poppler_version()
-
-    # One shared executor to rule them all
-    executor = ProcessPoolExecutor(max_workers=args.workers)
-
-    # If you have new PDFs, step one is to add them to the list
-    if args.add_pdfs:
-        if args.add_pdfs.startswith("s3://"):
-            logger.info(f"Querying all PDFs at {args.add_pdfs}")
-
-            all_pdfs = expand_s3_glob(pdf_s3, args.add_pdfs)
-            logger.info(f"Found {len(all_pdfs):,} total pdf paths")
-        elif os.path.exists(args.add_pdfs):
-            with open(args.add_pdfs, "r") as f:
-                all_pdfs = [line.strip() for line in f.readlines() if len(line.strip()) > 0]
-        else:
-            raise ValueError("add_pdfs argument needs to be either an s3 glob search path, or a local file contains pdf paths (one per line)")
-
-        all_pdfs = [pdf for pdf in all_pdfs if not db.pdf_exists(pdf)]
-        logger.info(f"Need to import {len(all_pdfs):,} total new pdf paths")
-
-        future_to_path = {executor.submit(get_pdf_num_pages, s3_path): s3_path for s3_path in all_pdfs}
-        for future in tqdm(as_completed(future_to_path), total=len(future_to_path), desc="Adding PDFs"):
-            s3_path = future_to_path[future]
-            num_pages = future.result()
-            logger.debug(f"Got {num_pages} pages back for {s3_path}")
-            if num_pages and not db.pdf_exists(s3_path):
-                db.add_pdf(s3_path, num_pages, "pending")
-
-        logger.info("Completed adding new PDFs.")
-
-    # Now build an index of all the pages that were processed within the workspace so far
-    logger.info("Indexing all batch inference sent to this workspace")
-    inference_output_paths = expand_s3_glob(workspace_s3, f"{args.workspace}/inference_outputs/*.jsonl")
-
-    inference_output_paths = {
-        s3_path: etag for s3_path, etag in inference_output_paths.items()
-        if not db.is_file_processed(s3_path, etag)
-    }
-
-    logger.info(f"Found {len(inference_output_paths):,} new batch inference results to index")
-    future_to_path = {executor.submit(process_jsonl_content, s3_path): (s3_path, etag) for s3_path, etag in inference_output_paths.items()}
-
-    for future in tqdm(as_completed(future_to_path), total=len(future_to_path), desc="Indexing Inference Results"):
-        s3_path, etag = future_to_path.pop(future)
-        try:
-            inference_records = future.result()
-
-            db.delete_index_entries_by_inference_s3_path(s3_path)
-            db.add_index_entries(inference_records)
-            db.update_processed_file(s3_path, etag=etag)
-        except urllib3.exceptions.SSLError:
-            logger.warning(f"Cannot load inference file {s3_path} due to SSL error, will retry another time")
-        except Exception as e:
-            logger.exception(f"Failed to index inference file {s3_path}: {e}")
-
-    # Now query each pdf, if you have all of the pages needed (all pages present, error is null and finish_reason is stop), then you assemble it into a dolma document and output it
-    # If you don't have every page, or if you have pages with errors, then you output a new batch of inference items to use
-    if db.get_last_indexed_round() < current_round - 1:
-        logger.warning(f"WARNING: No new batch inference results found, you need to run batch inference on {args.workspace}/inference_inputs/round_{current_round - 1}")
-        potentially_done_pdfs = db.get_pdfs_by_status("pending")
-    elif args.skip_build_queries:
-        logger.info(f"Skipping generating new batch inference files")
-        potentially_done_pdfs = db.get_pdfs_by_status("pending")
-    else:
-        logger.info("Creating batch inference files for new PDFs")
-        pdf_list = list(db.get_pdfs_by_status("pending"))
-        pdf_iter = iter(pdf_list)
-        pending_futures = {}
-        potentially_done_pdfs = []
-        lines_written = 0
-        new_inference_writer = BatchWriter(f"{args.workspace}/inference_inputs/round_{current_round}", args.max_size_mb)
-        total_pdfs = len(pdf_list)
-        max_pending = 300
-
-        with tqdm(total=total_pdfs, desc="Building Batch Queries") as pbar:
-            # Submit initial batch of futures
-            for _ in range(min(max_pending, total_pdfs)):
-                pdf = next(pdf_iter)
-                future = executor.submit(
-                    build_pdf_queries, args.workspace, pdf, current_round, args.target_longest_image_dim,args.target_anchor_text_len,
-                )
-                pending_futures[future] = pdf
-
-            while pending_futures:
-                # Wait for the next future to complete
-                done, _ = concurrent.futures.wait(
-                    pending_futures.keys(),
-                    return_when=concurrent.futures.FIRST_COMPLETED,
-                )
-
-                for future in done:
-                    pdf = pending_futures.pop(future)
-                    inference_lines = future.result()
-
-                    if len(inference_lines) == 0:
-                        potentially_done_pdfs.append(pdf)
-
-                    for line in inference_lines:
-                        lines_written += 1
-
-                        if line is not None:
-                            new_inference_writer.write_line(line)
-
-                    pbar.update(1)
-
-                    # Submit a new future if there are more PDFs
-                    try:
-                        pdf = next(pdf_iter)
-                        future = executor.submit(
-                            build_pdf_queries, args.workspace, pdf, current_round, args.target_longest_image_dim,args.target_anchor_text_len,
-                        )
-                        pending_futures[future] = pdf
-                    except StopIteration:
-                        pass  # No more PDFs to process
-
-        new_inference_writer.close()
-
-        if lines_written > 0:
-            logger.info(f"Added {lines_written:,} new batch inference requests")
-
-    # Now, finally, assemble any potentially done docs into dolma documents
-    logger.info(f"Assembling potentially finished PDFs into Dolma documents at {args.workspace}/output")
-    future_to_path = {executor.submit(build_dolma_doc, args.workspace, pdf): pdf for pdf in potentially_done_pdfs}
-    new_output_writer = BatchWriter(f"{args.workspace}/output", args.max_size_mb, after_flush=partial(mark_pdfs_done, args.workspace))
-
-    for future in tqdm(as_completed(future_to_path), total=len(future_to_path), desc="Assembling Dolma Docs"):
-        pdf = future_to_path.pop(future)
-        dolma_doc = future.result()
-
-        if dolma_doc is not None:
-            new_output_writer.write_line(dolma_doc)
-
-    new_output_writer.close()
-
-    logger.info("Final statistics:")
-
-    # Output the number of PDFs in each status "pending" and "completed"
-    pending_pdfs = db.get_pdfs_by_status("pending")
-    completed_pdfs = db.get_pdfs_by_status("completed")
-
-    logger.info(f"Pending PDFs: {len(pending_pdfs):,} ({sum(doc.num_pages for doc in pending_pdfs):,} pages)")
-    logger.info(f"Completed PDFs: {len(completed_pdfs):,} ({sum(doc.num_pages for doc in completed_pdfs):,} pages)")
-
-    # For each round, outputs a report of how many pages were processed, how many had errors, and a breakdown by (error, finish_reason)
-    total_rounds = db.get_last_indexed_round() + 1
-    for round_num in range(total_rounds):
-        db.cursor.execute("""
-            SELECT COUNT(*), error, finish_reason
-            FROM page_results
-            WHERE round = ?
-            GROUP BY error, finish_reason
-        """, (round_num,))
-
-        results = db.cursor.fetchall()
-
-        total_pages = sum(count for count, _, _ in results)
-        logger.info(f"Inference Round {round_num} - {total_pages:,} pages processed:")
-
-        for count, error, finish_reason in results:
-            error_str = error if error is not None else "None"
-            logger.info(f" (error: {error_str}, finish_reason: {finish_reason}) -> {count:,} pages")
-
-    logger.info("Work finished, waiting for all workers to finish cleaning up")
-    executor.shutdown(wait=True)
-    db.close()
diff --git a/olmocr/cappedpool.py b/olmocr/cappedpool.py
deleted file mode 100644
index f6ffc27..0000000
--- a/olmocr/cappedpool.py
+++ /dev/null
@@ -1,116 +0,0 @@
-import concurrent.futures
-import threading
-import queue
-
-class CappedFuture(concurrent.futures.Future):
-    def __init__(self, semaphore):
-        super().__init__()
-        self._semaphore = semaphore
-        self._result_retrieved = False
-        self._underlying_future = None
-        self._condition = threading.Condition()
-
-    def set_underlying_future(self, underlying_future):
-        with self._condition:
-            self._underlying_future = underlying_future
-            # Transfer the result when the underlying future completes
-            underlying_future.add_done_callback(self._transfer_result)
-
-    def _transfer_result(self, underlying_future):
-        if underlying_future.cancelled():
-            self.set_cancelled()
-        elif underlying_future.exception() is not None:
-            self.set_exception(underlying_future.exception())
-        else:
-            try:
-                result = underlying_future.result()
-                self.set_result(result)
-            except Exception as e:
-                self.set_exception(e)
-
-    def result(self, timeout=None):
-        res = super().result(timeout)
-        self._release_semaphore()
-        return res
-
-    def exception(self, timeout=None):
-        exc = super().exception(timeout)
-        self._release_semaphore()
-        return exc
-
-    def _release_semaphore(self):
-        if not self._result_retrieved:
-            self._result_retrieved = True
-            self._semaphore.release()
-
-    def cancel(self):
-        with self._condition:
-            if self._underlying_future is not None:
-                cancelled = self._underlying_future.cancel()
-                if cancelled:
-                    super().cancel()
-                return cancelled
-            else:
-                # Task has not been submitted yet; cancel directly
-                return super().cancel()
-
-    def cancelled(self):
-        return super().cancelled()
-
-    def running(self):
-        with self._condition:
-            if self._underlying_future is not None:
-                return self._underlying_future.running()
-            else:
-                return False
-
-    def done(self):
-        return super().done()
-
-class CappedProcessPoolExecutor(concurrent.futures.Executor):
-    def __init__(self, max_unprocessed=100, max_workers=None):
-        self._max_unprocessed = max_unprocessed
-        self._semaphore = threading.BoundedSemaphore(max_unprocessed)
-        self._task_queue = queue.Queue()
-        self._shutdown = threading.Event()
-        self._shutdown_lock = threading.Lock()
-        self._executor = concurrent.futures.ProcessPoolExecutor(max_workers=max_workers)
-        self._worker_thread = threading.Thread(target=self._worker)
-        self._worker_thread.daemon = True
-        self._worker_thread.start()
-
-    def submit(self, fn, *args, **kwargs):
-        if self._shutdown.is_set():
-            raise RuntimeError('Cannot submit new tasks after shutdown')
-        # Create a CappedFuture to return to the user
-        user_future = CappedFuture(self._semaphore)
-        # Put the task in the queue
-        self._task_queue.put((user_future, fn, args, kwargs))
-        return user_future
-
-    def _worker(self):
-        while True:
-            if self._shutdown.is_set() and self._task_queue.empty():
-                break
-            try:
-                user_future, fn, args, kwargs = self._task_queue.get(timeout=0.1)
-            except queue.Empty:
-                continue
-            self._semaphore.acquire()
-            if user_future.cancelled():
-                self._semaphore.release()
-                continue
-            # Submit the task to the underlying executor
-            try:
-                underlying_future = self._executor.submit(fn, *args, **kwargs)
-                user_future.set_underlying_future(underlying_future)
-            except Exception as e:
-                user_future.set_exception(e)
-                self._semaphore.release()
-                continue
-
-    def shutdown(self, wait=True):
-        with self._shutdown_lock:
-            self._shutdown.set()
-        self._worker_thread.join()
-        self._executor.shutdown(wait=wait)
diff --git a/olmocr/beakerpipeline.py b/olmocr/pipeline.py
similarity index 100%
rename from olmocr/beakerpipeline.py
rename to olmocr/pipeline.py
diff --git a/olmocr/s3_queue.py b/olmocr/s3_queue.py
index c4e3a26..a1295e9 100644
--- a/olmocr/s3_queue.py
+++ b/olmocr/s3_queue.py
@@ -15,7 +15,7 @@
 from olmocr.s3_utils import (
     upload_zstd_csv,
     parse_s3_path
 )
-from pypdf import PdfReader
+

 logger = logging.getLogger(__name__)
diff --git a/tests/test_cappedpool.py b/tests/test_cappedpool.py
deleted file mode 100644
index 3f7912d..0000000
--- a/tests/test_cappedpool.py
+++ /dev/null
@@ -1,99 +0,0 @@
-import unittest
-import time
-import concurrent.futures
-from concurrent.futures import TimeoutError
-
-# Assuming the CappedProcessPoolExecutor code is in a module named 'capped_executor'
-from olmocr.cappedpool import CappedProcessPoolExecutor
-
-# Define functions at the top level to ensure they are picklable by multiprocessing
-
-def square(x):
-    return x * x
-
-def raise_exception():
-    raise ValueError("Test exception")
-
-def sleep_and_return(x, sleep_time):
-    time.sleep(sleep_time)
-    return x
-
-def task(counter, max_counter, counter_lock):
-    with counter_lock:
-        counter.value += 1
-        print(f"Task incrementing counter to {counter.value}")
-        if counter.value > max_counter.value:
-            max_counter.value = counter.value
-    time.sleep(0.5)
-    with counter_lock:
-        counter.value -= 1
-    return True
-
-class TestCappedProcessPoolExecutor(unittest.TestCase):
-
-    def test_basic_functionality(self):
-        """Test that tasks are executed and results are correct."""
-        with CappedProcessPoolExecutor(max_unprocessed=10, max_workers=4) as executor:
-            futures = [executor.submit(square, i) for i in range(10)]
-            results = [f.result() for f in futures]
-            expected = [i * i for i in range(10)]
-            self.assertEqual(results, expected)
-
-    def test_exception_handling(self):
-        """Test that exceptions in tasks are properly raised."""
-        with CappedProcessPoolExecutor(max_unprocessed=10, max_workers=4) as executor:
-            future = executor.submit(raise_exception)
-            with self.assertRaises(ValueError):
-                future.result()
-
-    def test_cancellation(self):
-        """Test that tasks can be cancelled before execution."""
-        with CappedProcessPoolExecutor(max_unprocessed=10, max_workers=4) as executor:
-            future = executor.submit(time.sleep, 5)
-            # Try to cancel immediately
-            cancelled = future.cancel()
-            self.assertTrue(cancelled)
-            self.assertTrue(future.cancelled())
-            # Attempt to get result; should raise CancelledError
-            with self.assertRaises(concurrent.futures.CancelledError):
-                future.result()
-
-    def test_shutdown(self):
-        """Test that the executor shuts down properly and does not accept new tasks."""
-        executor = CappedProcessPoolExecutor(max_unprocessed=10, max_workers=4)
-        future = executor.submit(time.sleep, 1)
-        executor.shutdown(wait=True)
-        with self.assertRaises(RuntimeError):
-            executor.submit(time.sleep, 1)
-
-    def test_capping_behavior(self):
-        """Test that the number of concurrent tasks does not exceed max_unprocessed."""
-        max_unprocessed = 3
-        with CappedProcessPoolExecutor(max_unprocessed=max_unprocessed, max_workers=10) as executor:
-            from multiprocessing import Manager
-
-            manager = Manager()
-            counter = manager.Value('i', 0)
-            max_counter = manager.Value('i', 0)
-            counter_lock = manager.Lock()
-
-            futures = [executor.submit(task, counter, max_counter, counter_lock) for _ in range(10)]
-
-            for index, f in enumerate(futures):
-                print(f"Future {index} returned {f.result()}")
-
-            time.sleep(1)
-
-            print(max_counter.value)
-            self.assertLessEqual(max_counter.value, max_unprocessed)
-
-    def test_submit_after_shutdown(self):
-        """Test that submitting tasks after shutdown raises an error."""
-        executor = CappedProcessPoolExecutor(max_unprocessed=10, max_workers=4)
-        executor.shutdown(wait=True)
-        with self.assertRaises(RuntimeError):
-            executor.submit(square, 2)
-
-
-if __name__ == '__main__':
-    unittest.main()
diff --git a/tests/test_sglang.py b/tests/test_sglang.py
index 5ca5b85..6b3b3c1 100644
--- a/tests/test_sglang.py
+++ b/tests/test_sglang.py
@@ -17,7 +17,7 @@ from io import BytesIO
 from PIL import Image
 from transformers import AutoProcessor, AutoTokenizer, Qwen2VLForConditionalGeneration
 from pathlib import Path
-from olmocr.beakerpipeline import sglang_server_task, sglang_server_ready, build_page_query, SGLANG_SERVER_PORT, render_pdf_to_base64png, get_anchor_text, download_directory
+from olmocr.pipeline import sglang_server_task, sglang_server_ready, build_page_query, SGLANG_SERVER_PORT, render_pdf_to_base64png, get_anchor_text, download_directory
 from olmocr.prompts import PageResponse
 from httpx import AsyncClient
 import torch.nn.functional as F