olmocr/pdelfin/assemblepipeline.py
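"""Manager for running millions of PDFs through a batch inference pipeline.

Summary (based on the code below): the script keeps a per-workspace SQLite cache
under ~/.cache/pdelfin/, catalogs candidate PDFs found on S3, and builds a
byte-offset index of the JSONL batch-inference outputs already written to the
workspace so individual records can be fetched again later.
"""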


import os
import hashlib
import boto3
import sqlite3
import json
import argparse
import glob
import tempfile
import posixpath
from pypdf import PdfReader
from tqdm import tqdm
from typing import Optional
from urllib.parse import urlparse
from concurrent.futures import ProcessPoolExecutor, as_completed
# Global s3 client for the whole script, feel free to adjust params if you need it
s3 = boto3.client('s3')
class DatabaseManager:
    def __init__(self, s3_workspace: str):
        cache_key = hashlib.sha256(s3_workspace.strip().lower().encode('utf-8')).hexdigest()
        home_cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'pdelfin', cache_key)
        os.makedirs(home_cache_dir, exist_ok=True)
        self.db_path = os.path.join(home_cache_dir, 'index.db')

        self.conn = sqlite3.connect(self.db_path)
        self.cursor = self.conn.cursor()
        self._initialize_tables()

    def _initialize_tables(self):
        self.cursor.execute("""
            CREATE TABLE IF NOT EXISTS index_table (
                custom_id TEXT,
                s3_path TEXT,
                start_index BIGINT,
                end_index BIGINT
            )
        """)
        self.cursor.execute("""
            CREATE INDEX IF NOT EXISTS idx_custom_id ON index_table(custom_id)
        """)
        self.cursor.execute("""
            CREATE TABLE IF NOT EXISTS pdfs (
                s3_path TEXT PRIMARY KEY,
                num_pages INTEGER,
                status TEXT DEFAULT 'pending'
            )
        """)
        self.cursor.execute("""
            CREATE TABLE IF NOT EXISTS processed_files (
                s3_path TEXT PRIMARY KEY,
                etag TEXT
            )
        """)

        # Generic metadata such as current round
        self.cursor.execute("""
            CREATE TABLE IF NOT EXISTS metadata (
                key TEXT PRIMARY KEY,
                value TEXT
            )
        """)
        self.cursor.execute("SELECT value FROM metadata WHERE key='round'")
        if self.cursor.fetchone() is None:
            self.cursor.execute("INSERT INTO metadata (key, value) VALUES ('round', '0')")

        self.conn.commit()

    def get_current_round(self):
        self.cursor.execute("SELECT value FROM metadata WHERE key='round'")
        result = self.cursor.fetchone()
        return int(result[0])

    def is_file_processed(self, s3_path, etag):
        self.cursor.execute("SELECT etag FROM processed_files WHERE s3_path = ?", (s3_path,))
        result = self.cursor.fetchone()
        return result is not None and result[0] == etag

    def add_index_entries(self, index_entries):
        if index_entries:
            self.cursor.executemany("""
                INSERT INTO index_table (custom_id, s3_path, start_index, end_index)
                VALUES (?, ?, ?, ?)
            """, index_entries)
            self.conn.commit()

    def update_processed_file(self, s3_path, etag):
        self.cursor.execute("""
            INSERT INTO processed_files (s3_path, etag)
            VALUES (?, ?)
            ON CONFLICT(s3_path) DO UPDATE SET etag=excluded.etag
        """, (s3_path, etag))
        self.conn.commit()

    def pdf_exists(self, s3_path: str) -> bool:
        self.cursor.execute("SELECT 1 FROM pdfs WHERE s3_path = ?", (s3_path,))
        return self.cursor.fetchone() is not None

    def add_pdf(self, s3_path: str, num_pages: int, status: str = 'pending') -> None:
        try:
            self.cursor.execute("""
                INSERT INTO pdfs (s3_path, num_pages, status)
                VALUES (?, ?, ?)
            """, (s3_path, num_pages, status))
            self.conn.commit()
        except sqlite3.IntegrityError:
            print(f"PDF with s3_path '{s3_path}' already exists.")

    def get_pdf_status(self, s3_path: str) -> Optional[str]:
        self.cursor.execute("SELECT status FROM pdfs WHERE s3_path = ?", (s3_path,))
        result = self.cursor.fetchone()
        return result[0] if result else None

    def close(self):
        self.conn.close()

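# Example usage (hypothetical workspace path): open the per-workspace cache and inspect it.
#   db = DatabaseManager("s3://my-bucket/workspace")
#   print(db.get_current_round())   # -> 0 on a freshly created cache
#   db.close()
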
def build_index(s3_path):
    db_manager = DatabaseManager(s3_path)
    bucket, prefix = parse_s3_path(s3_path)

    # List all .json and .jsonl files under s3_path with their ETags
    files = expand_s3_glob(s3_path)

    if not files:
        print("No .json or .jsonl files found in the specified S3 path.")
        db_manager.close()
        return

    # Prepare a list of files that need processing
    files_to_process = [
        (s3_file, etag) for s3_file, etag in files.items()
        if not db_manager.is_file_processed(s3_file, etag)
    ]

    if not files_to_process:
        print("All files are up to date. No processing needed.")
        db_manager.close()
        return

    # Use ProcessPoolExecutor to process files with a tqdm progress bar
    with ProcessPoolExecutor() as executor:
        futures = [
            # expand_s3_glob returns full s3:// paths, so split out the object key for the worker
            executor.submit(process_file, bucket, parse_s3_path(s3_file)[1], etag)
            for s3_file, etag in files_to_process
        ]

        for future in tqdm(as_completed(futures), total=len(futures), desc="Processing files"):
            s3_path, key, etag, index_entries = future.result()
            if index_entries:
                db_manager.add_index_entries(index_entries)

            # Update the processed_files table, keyed by the full s3:// path to match is_file_processed
            db_manager.update_processed_file(s3_path, etag)

    db_manager.close()

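# Example (illustrative prefix): index every batch-output .jsonl file under a workspace.
#   build_index("s3://my-bucket/workspace/*.jsonl")
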
def parse_s3_path(s3_path):
    if not s3_path.startswith('s3://'):
        raise ValueError('s3_path must start with s3://')
    path = s3_path[5:]
    bucket, _, prefix = path.partition('/')
    return bucket, prefix

def expand_s3_glob(s3_glob: str) -> dict[str, str]:
    parsed = urlparse(s3_glob)
    bucket_name = parsed.netloc
    prefix = os.path.dirname(parsed.path.lstrip('/')).rstrip('/') + "/"
    pattern = os.path.basename(parsed.path)

    paginator = s3.get_paginator('list_objects_v2')
    page_iterator = paginator.paginate(Bucket=bucket_name, Prefix=prefix)

    matched_files = {}
    for page in page_iterator:
        for obj in page.get('Contents', []):
            key = obj['Key']
            if glob.fnmatch.fnmatch(key, posixpath.join(prefix, pattern)):
                matched_files[f"s3://{bucket_name}/{key}"] = obj['ETag'].strip('"')

    return matched_files

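# Example (illustrative bucket/prefix): expand_s3_glob("s3://my-bucket/pdfs/*.pdf")
# returns {"s3://my-bucket/pdfs/doc1.pdf": "<etag>", ...} for every object under the
# prefix whose key matches the glob pattern.
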
def process_file(bucket, key, etag):
    s3 = boto3.client('s3')  # Initialize s3 client in the worker process
    s3_path = f's3://{bucket}/{key}'
    try:
        # Get the object
        obj = s3.get_object(Bucket=bucket, Key=key)
        # Read the content as bytes
        content = obj['Body'].read()
        # Process the file as JSONL
        index_entries = process_jsonl_content(content, s3_path)
        # Return the necessary data to the main process
        return s3_path, key, etag, index_entries
    except Exception as e:
        print(f"Error processing file {s3_path}: {e}")
        return s3_path, key, etag, []

def process_jsonl_content(content, s3_path):
    start_index = 0
    index_entries = []

    lines = content.splitlines(keepends=True)
    for line in lines:
        line_length = len(line)
        end_index = start_index + line_length
        try:
            data = json.loads(line)
            custom_id = data.get('custom_id')
            if custom_id:
                index_entries.append((custom_id, s3_path, start_index, end_index))
        except json.JSONDecodeError:
            pass  # Handle JSON decode errors if necessary
        start_index = end_index

    return index_entries

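# Example (illustrative): for content b'{"custom_id": "a"}\n{"custom_id": "b"}\n'
# this returns [("a", s3_path, 0, 19), ("b", s3_path, 19, 38)], i.e. byte ranges
# (start inclusive, end exclusive) that locate each record inside the file.
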
def get_s3_bytes(s3_path: str, start_index: Optional[int] = None, end_index: Optional[int] = None) -> bytes:
    bucket, key = parse_s3_path(s3_path)

    # Build the Range header if start_index and/or end_index are specified
    range_header = None
    if start_index is not None or end_index is not None:
        range_value = f"bytes={start_index or 0}-"
        if end_index is not None:
            range_value += str(end_index)
        range_header = {'Range': range_value}

    if range_header:
        obj = s3.get_object(Bucket=bucket, Key=key, Range=range_header['Range'])
    else:
        obj = s3.get_object(Bucket=bucket, Key=key)

    return obj['Body'].read()

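# Example (illustrative): re-read one indexed JSONL record via its stored byte offsets.
# Note that HTTP Range requests are inclusive of the final byte, while end_index from
# process_jsonl_content is exclusive, so pass end_index - 1 to get exactly one line:
#   line = get_s3_bytes(s3_path, start_index, end_index - 1)
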
def get_pdf_num_pages(s3_path: str) -> Optional[int]:
    try:
        with tempfile.NamedTemporaryFile("wb+", suffix=".pdf") as tf:
            tf.write(get_s3_bytes(s3_path))
            tf.flush()
            reader = PdfReader(tf.name)
            return reader.get_num_pages()
    except Exception as ex:
        print(f"Warning, could not add {s3_path} due to {ex}")

    return None

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Manager for running millions of PDFs through a batch inference pipeline')
    parser.add_argument('workspace', help='The S3 path where work will be done (e.g., s3://bucket/prefix/)')
    parser.add_argument('--pdfs', help='Glob path to PDFs (local or s3)', default=None)
    parser.add_argument('--file_size_limit', type=int, default=250, help='Max file size in MB')
    args = parser.parse_args()

    db = DatabaseManager(args.workspace)
    print(f"Loaded db at {db.db_path}")
    print(f"Current round is {db.get_current_round()}\n")

    # One shared executor to rule them all
    executor = ProcessPoolExecutor()

    # If you have new PDFs, add them to the list
    if args.pdfs:
        assert args.pdfs.startswith("s3://"), "PDFs must live on s3"

        print(f"Querying all PDFs at {args.pdfs}")
        all_pdfs = expand_s3_glob(args.pdfs)
        print(f"Found {len(all_pdfs)} total pdf paths")

        all_pdfs = [pdf for pdf in all_pdfs if not db.pdf_exists(pdf)]
        print(f"Need to import {len(all_pdfs)} total new pdf paths")

        future_to_path = {executor.submit(get_pdf_num_pages, s3_path): s3_path for s3_path in all_pdfs}
        for future in tqdm(as_completed(future_to_path), total=len(future_to_path)):
            s3_path = future_to_path[future]
            num_pages = future.result()
            if num_pages and not db.pdf_exists(s3_path):
                db.add_pdf(s3_path, num_pages, "pending")

        print("\n")

    # Now build an index of all the pages that were processed within the workspace so far
    build_index(f"{args.workspace}/*.jsonl")

    # Now, for each pending book, find all pages which still need to be processed
    # and add them to the next round's batch inference jobs