import os
import hashlib
import boto3
import sqlite3
import json
import argparse
import glob
import tempfile
import datetime
import posixpath
import threading
import logging
import boto3.session
import urllib3.exceptions

from dataclasses import dataclass
from pypdf import PdfReader
from tqdm import tqdm
from functools import partial
from typing import Optional, List, Tuple, Dict, Callable, Any
from urllib.parse import urlparse
from concurrent.futures import ProcessPoolExecutor, as_completed

from pdelfin.data.renderpdf import render_pdf_to_base64png
from pdelfin.prompts import build_finetuning_prompt
from pdelfin.prompts.anchor import get_anchor_text
from pdelfin.s3_utils import parse_custom_id, expand_s3_glob, get_s3_bytes, put_s3_bytes

# Global s3 clients for the whole script, feel free to adjust params if you need it
workspace_s3 = boto3.client('s3')
pdf_s3 = boto3.client('s3')

# Quiet logs from pypdf and smart_open
logging.getLogger("pypdf").setLevel(logging.ERROR)
logging.getLogger("smart_open").setLevel(logging.ERROR)

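# DatabaseManager keeps a local sqlite index of the workspace state: which PDFs are known
# (the `pdfs` table), which inference output lines cover which PDF pages (`page_results`,
# stored as byte ranges into the inference .jsonl files), which inference output files have
# already been indexed (`processed_files`, which records each file's etag), plus a small
# `metadata` key/value table. The database file lives under
# ~/.cache/pdelfin/<sha256 of the workspace path>/index.db, so each workspace gets its own cache.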
class DatabaseManager:
    @dataclass(frozen=True)
    class BatchInferenceRecord:
        inference_s3_path: str
        pdf_s3_path: str
        page_num: int  # 1 indexed!
        round: int
        start_index: int
        length: int
        finish_reason: str
        error: Optional[str]

        def is_usable(self):
            return self.error is None and self.finish_reason == "stop"

    @dataclass(frozen=True)
    class PDFRecord:
        s3_path: str
        num_pages: int
        status: str

    def __init__(self, s3_workspace: str):
        cache_key = hashlib.sha256(s3_workspace.strip().lower().encode('utf-8')).hexdigest()
        home_cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'pdelfin', cache_key)
        os.makedirs(home_cache_dir, exist_ok=True)

        self.db_path = os.path.join(home_cache_dir, 'index.db')
        self.conn = sqlite3.connect(self.db_path)
        self.cursor = self.conn.cursor()
        self._initialize_tables()

    def _initialize_tables(self):
        self.cursor.execute("""
            CREATE TABLE IF NOT EXISTS page_results (
                inference_s3_path TEXT,
                pdf_s3_path TEXT,
                page_num INTEGER,
                round INTEGER,
                start_index BIGINT,
                length BIGINT,
                finish_reason TEXT,
                error TEXT
            )
        """)
        self.cursor.execute("""
            CREATE INDEX IF NOT EXISTS idx_path ON page_results(pdf_s3_path)
        """)
        self.cursor.execute("""
            CREATE TABLE IF NOT EXISTS pdfs (
                s3_path TEXT PRIMARY KEY,
                num_pages INTEGER,
                status TEXT DEFAULT 'pending'
            )
        """)
        self.cursor.execute("""
            CREATE TABLE IF NOT EXISTS processed_files (
                s3_path TEXT PRIMARY KEY,
                etag TEXT
            )
        """)
        # Generic metadata such as current round
        self.cursor.execute("""
            CREATE TABLE IF NOT EXISTS metadata (
                key TEXT PRIMARY KEY,
                value TEXT
            )
        """)
        self.conn.commit()

    def get_metadata(self, key: str) -> Optional[str]:
        self.cursor.execute("SELECT value FROM metadata WHERE key=?", (key,))
        result = self.cursor.fetchone()
        return result[0] if result else None

    def set_metadata(self, key: str, value: str) -> None:
        self.cursor.execute("""
            INSERT INTO metadata (key, value)
            VALUES (?, ?)
            ON CONFLICT(key) DO UPDATE SET value=excluded.value
        """, (key, value))
        self.conn.commit()

    def is_file_processed(self, s3_path, etag):
        self.cursor.execute("SELECT etag FROM processed_files WHERE s3_path = ?", (s3_path,))
        result = self.cursor.fetchone()
        return result is not None and result[0] == etag

    def update_processed_file(self, s3_path, etag):
        self.cursor.execute("""
            INSERT INTO processed_files (s3_path, etag)
            VALUES (?, ?)
            ON CONFLICT(s3_path) DO UPDATE SET etag=excluded.etag
        """, (s3_path, etag))
        self.conn.commit()

    def add_index_entries(self, index_entries: List[BatchInferenceRecord]):
        if index_entries:
            self.cursor.executemany("""
                INSERT INTO page_results (inference_s3_path, pdf_s3_path, page_num, round, start_index, length, finish_reason, error)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?)
            """, [(entry.inference_s3_path, entry.pdf_s3_path, entry.page_num, entry.round,
                   entry.start_index, entry.length, entry.finish_reason, entry.error)
                  for entry in index_entries])
            self.conn.commit()

    def get_index_entries(self, pdf_s3_path: str) -> List[BatchInferenceRecord]:
        self.cursor.execute("""
            SELECT inference_s3_path, pdf_s3_path, page_num, round, start_index, length, finish_reason, error
            FROM page_results
            WHERE pdf_s3_path = ?
            ORDER BY inference_s3_path DESC, start_index ASC, page_num ASC
        """, (pdf_s3_path,))
        rows = self.cursor.fetchall()

        return [
            self.BatchInferenceRecord(
                inference_s3_path=row[0],
                pdf_s3_path=row[1],
                page_num=row[2],
                round=row[3],
                start_index=row[4],
                length=row[5],
                finish_reason=row[6],
                error=row[7]
            )
            for row in rows
        ]

    def get_last_indexed_round(self) -> int:
        self.cursor.execute("""
            SELECT MAX(round) FROM page_results
        """)
        result = self.cursor.fetchone()
        return -1 if result[0] is None else result[0]

    def pdf_exists(self, s3_path: str) -> bool:
        self.cursor.execute("SELECT 1 FROM pdfs WHERE s3_path = ?", (s3_path,))
        return self.cursor.fetchone() is not None

    def add_pdf(self, s3_path: str, num_pages: int, status: str = 'pending') -> None:
        try:
            self.cursor.execute("""
                INSERT INTO pdfs (s3_path, num_pages, status)
                VALUES (?, ?, ?)
            """, (s3_path, num_pages, status))
            self.conn.commit()
        except sqlite3.IntegrityError:
            print(f"PDF with s3_path '{s3_path}' already exists.")

    def update_pdf_status(self, s3_path: str, new_status: str) -> None:
        self.cursor.execute("""
            UPDATE pdfs
            SET status = ?
            WHERE s3_path = ?
        """, (new_status, s3_path))
        self.conn.commit()

    def get_pdf(self, s3_path: str) -> Optional[PDFRecord]:
        self.cursor.execute("""
            SELECT s3_path, num_pages, status FROM pdfs WHERE s3_path = ?
        """, (s3_path,))
        row = self.cursor.fetchone()
        if row:
            return self.PDFRecord(
                s3_path=row[0],
                num_pages=row[1],
                status=row[2]
            )
        return None

    def get_pdfs_by_status(self, status: str) -> List[PDFRecord]:
        self.cursor.execute("""
            SELECT s3_path, num_pages, status FROM pdfs WHERE status == ?
            ORDER BY s3_path DESC, num_pages DESC
        """, (status, ))
        rows = self.cursor.fetchall()
        return [
            self.PDFRecord(
                s3_path=row[0],
                num_pages=row[1],
                status=row[2]
            )
            for row in rows
        ]

    def close(self):
        self.conn.close()

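# Minimal usage sketch for DatabaseManager (illustrative only; nothing in the pipeline calls
# this helper, and the s3 paths below are hypothetical placeholders).
def _example_database_usage():
    db = DatabaseManager("s3://example-bucket/example-workspace/")

    if not db.pdf_exists("s3://example-bucket/pdfs/sample.pdf"):
        db.add_pdf("s3://example-bucket/pdfs/sample.pdf", num_pages=3, status="pending")

    # Each record is one indexed inference result line for one page of this PDF
    for record in db.get_index_entries("s3://example-bucket/pdfs/sample.pdf"):
        print(record.page_num, record.finish_reason, record.error, record.is_usable())

    db.update_pdf_status("s3://example-bucket/pdfs/sample.pdf", "completed")
    db.close()
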
# Writes batches of lines out to a set of files, keeping each file below some maximum size
class BatchWriter:
    def __init__(self, output_prefix: str, max_size_mb: int = 250, after_flush: Optional[Callable[[List[str]], Any]] = None):
        self.output_prefix = output_prefix
        self.max_size = max_size_mb * 1024 * 1024  # Convert MB to bytes
        self.batch = []
        self.batch_size = 0
        self.after_flush = after_flush
        self.threads = []

        parsed = urlparse(output_prefix)
        self.is_s3 = parsed.scheme in ('s3', 's3a', 's3n')

        if not self.is_s3:
            os.makedirs(output_prefix, exist_ok=True)

    def _compute_hash(self, content: str) -> str:
        """Compute a 20-character SHA1 hash of the given content."""
        sha1 = hashlib.sha1()
        sha1.update(content.encode('utf-8'))
        return sha1.hexdigest()[:20]

    def _get_output_path(self, hash_str: str) -> str:
        """Generate the full output path with hash in the filename."""
        parsed = urlparse(self.output_prefix)
        if self.is_s3:
            bucket = parsed.netloc
            key = parsed.path.lstrip('/')
            if key and not key.endswith('/'):
                key += '/'
            full_key = posixpath.join(key, f"output_{hash_str}.jsonl")
            return f"s3://{bucket}/{full_key}"
        else:
            filename = f"output_{hash_str}.jsonl"
            return os.path.join(self.output_prefix, filename)

    def write_line(self, line: Optional[str]):
        if line is None or not line.strip():
            return

        line_size = len(line.encode('utf-8')) + 1  # +1 for newline

        if self.batch_size + line_size > self.max_size:
            self._write_batch()

        self.batch.append(line)
        self.batch_size += line_size

    def _write_batch(self):
        if not self.batch:
            return

        batch_lines = self.batch.copy()
        batch_content = "\n".join(batch_lines) + "\n"
        hash_str = self._compute_hash(batch_content)
        output_path = self._get_output_path(hash_str)

        # Start a new thread to write the batch
        thread = threading.Thread(
            target=self._write_batch_to_file,
            args=(batch_content, output_path, batch_lines)
        )
        thread.start()
        self.threads.append(thread)

        # Clear the batch and batch_size
        self.batch = []
        self.batch_size = 0

    def _write_batch_to_file(self, batch_content: str, output_path: str, batch_lines: List[str]):
        if self.is_s3:
            put_s3_bytes(workspace_s3, output_path, batch_content.encode("utf-8"))
        else:
            with open(output_path, 'w', encoding='utf-8') as f_out:
                f_out.write(batch_content)

        # After writing, call the after_flush callback if it is set
        if self.after_flush:
            self.after_flush(batch_lines)

    def close(self):
        self._write_batch()

        # Wait for all threads to finish
        for thread in self.threads:
            thread.join()

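# Minimal usage sketch for BatchWriter (illustrative only; not called by the pipeline).
# Lines are buffered and flushed on a background thread to output_<sha1 prefix>.jsonl under the
# given prefix whenever the batch would exceed max_size_mb; close() flushes the remainder and
# joins the writer threads, and the optional after_flush callback receives the flushed lines.
# The local output prefix below is a hypothetical placeholder.
def _example_batch_writer_usage():
    writer = BatchWriter("/tmp/example_batches", max_size_mb=1,
                         after_flush=lambda lines: print(f"flushed {len(lines)} lines"))
    for i in range(10):
        writer.write_line(json.dumps({"example_record": i}))
    writer.close()
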
def build_page_query(local_pdf_path: str, pretty_pdf_path: str, page: int) -> dict:
    image_base64 = render_pdf_to_base64png(local_pdf_path, page, 1024)
    anchor_text = get_anchor_text(local_pdf_path, page, pdf_engine="pdfreport")

    return {
        "custom_id": f"{pretty_pdf_path}-{page}",
        "chat_messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": build_finetuning_prompt(anchor_text)},
                    {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}}
                ],
            }
        ],
    }


def process_jsonl_content(inference_s3_path: str) -> List[DatabaseManager.BatchInferenceRecord]:
    content_bytes = get_s3_bytes(workspace_s3, inference_s3_path)
    start_index = 0
    index_entries = []

    lines = content_bytes.splitlines(keepends=True)  # Split content into lines as bytes

    for line in lines:
        line_length = len(line)  # Length in bytes

        try:
            # Decode the line for JSON processing
            line_str = line.decode('utf-8')
            data = json.loads(line_str)

            pdf_s3_path, page_num = parse_custom_id(data["custom_id"])

            if data.get("completion_error", None) is not None:
                index_entries.append(DatabaseManager.BatchInferenceRecord(
                    inference_s3_path=inference_s3_path,
                    pdf_s3_path=pdf_s3_path,
                    page_num=page_num,
                    round=data["round"],
                    start_index=start_index,  # Byte offset in the original file
                    length=line_length,  # Length in bytes
                    finish_reason="completion_error",
                    error=data.get("completion_error", None)
                ))
            else:
                # Try to parse the actual model response JSON
                assert "outputs" in data and len(data["outputs"]) > 0, "No outputs from model detected"

                try:
                    model_response_json = json.loads(data["outputs"][0]["text"])

                    last_error = data.get("completion_error", None)
                    if not model_response_json["is_rotation_valid"]:
                        last_error = "rotation_invalid"

                    index_entries.append(DatabaseManager.BatchInferenceRecord(
                        inference_s3_path=inference_s3_path,
                        pdf_s3_path=pdf_s3_path,
                        page_num=page_num,
                        round=data["round"],
                        start_index=start_index,  # Byte offset in the original file
                        length=line_length,  # Length in bytes
                        finish_reason=data["outputs"][0]["finish_reason"],
                        error=last_error,
                    ))
                except json.JSONDecodeError:
                    index_entries.append(DatabaseManager.BatchInferenceRecord(
                        inference_s3_path=inference_s3_path,
                        pdf_s3_path=pdf_s3_path,
                        page_num=page_num,
                        round=data["round"],
                        start_index=start_index,  # Byte offset in the original file
                        length=line_length,  # Length in bytes
                        finish_reason=data["outputs"][0]["finish_reason"],
                        error="Could not parse model JSON output",
                    ))
        except json.JSONDecodeError:
            print(f"Error with JSON Decoding of inference in {inference_s3_path}")
            # TODO Maybe this needs to add an index error that this json is bad
        except Exception as e:
            print(f"Error processing line: {e}")

        start_index += line_length  # Increment by the number of bytes

    return index_entries


def get_pdf_num_pages(s3_path: str) -> Optional[int]:
    try:
        with tempfile.NamedTemporaryFile("wb+", suffix=".pdf") as tf:
            tf.write(get_s3_bytes(pdf_s3, s3_path))
            tf.flush()

            reader = PdfReader(tf.name)
            return reader.get_num_pages()
    except Exception as ex:
        print(f"Warning, could not add {s3_path} due to {ex}")

    return None


def build_pdf_queries(s3_workspace: str, pdf: DatabaseManager.PDFRecord, cur_round: int) -> list[dict]:
    db = DatabaseManager(s3_workspace)

    existing_pages = db.get_index_entries(pdf.s3_path)
    new_queries = []

    # Shortcut out of downloading the actual PDF
    if set(page.page_num for page in existing_pages if page.is_usable()) == set(range(1, pdf.num_pages + 1)):
        return []

    try:
        with tempfile.NamedTemporaryFile("wb+", suffix=".pdf") as tf:
            tf.write(get_s3_bytes(pdf_s3, pdf.s3_path))
            tf.flush()

            for target_page_num in range(1, pdf.num_pages + 1):
                # Is there an existing page that has no error
                if any(page.is_usable() and page.page_num == target_page_num for page in existing_pages):
                    continue

                has_errored_previously = sum(page.page_num == target_page_num for page in existing_pages)

                if has_errored_previously:
                    # Retry the page at least one more time regularly
                    new_queries.append({**build_page_query(tf.name, pdf.s3_path, target_page_num), "round": cur_round})
                    # TODO: If the rotation was previously invalid, then apply a rotation
                    # TODO: Try to provide a smaller prompt hint
                else:
                    new_queries.append({**build_page_query(tf.name, pdf.s3_path, target_page_num), "round": cur_round})
    except Exception as ex:
        print(f"Warning, could not get batch inferences lines for {pdf.s3_path} due to {ex}")

    return new_queries

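# Note on build_dolma_doc below: because the indexer recorded each result line's byte offset
# (start_index) and byte length in page_results, the assembler only needs to fetch that slice of
# the inference .jsonl file instead of re-reading the whole file; get_s3_bytes is passed
# start_index/end_index for this, which presumably maps to a ranged S3 GET inside pdelfin.s3_utils.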
def build_dolma_doc(s3_workspace: str, pdf: DatabaseManager.PDFRecord) -> Optional[dict]:
    db = DatabaseManager(s3_workspace)
    existing_pages = db.get_index_entries(pdf.s3_path)
    document_text = ""
    last_page_start_index = 0
    pdf_page_spans = []

    for target_page_num in range(1, pdf.num_pages + 1):
        usable_pages = [page for page in existing_pages if page.is_usable() and page.page_num == target_page_num]

        if len(usable_pages) == 0:
            return None

        usable_page_data = [get_s3_bytes(workspace_s3, page.inference_s3_path,
                                         start_index=page.start_index,
                                         end_index=page.start_index + page.length - 1)
                            for page in usable_pages]

        usable_page_final_results = [json.loads(json.loads(page_data.decode("utf-8"))["outputs"][0]["text"])
                                     for page_data in usable_page_data]

        # Sort the pages:
        # 1. Prefer pages with `is_rotation_valid` set to True.
        # 2. Within those, sort by the length of the `natural_text` in descending order.
        usable_page_final_results.sort(
            key=lambda page: (not page["is_rotation_valid"], -len(page["natural_text"] if page["natural_text"] else ""))
        )

        target_page_final_result = usable_page_final_results[0]

        if target_page_final_result["natural_text"] is not None:
            document_text += target_page_final_result["natural_text"] + "\n"

        pdf_page_spans.append([last_page_start_index, len(document_text), target_page_num])
        last_page_start_index = len(document_text)

    metadata = {
        "Source-File": pdf.s3_path,
        "pdf-total-pages": pdf.num_pages,
    }

    id_ = hashlib.sha1(document_text.encode()).hexdigest()

    dolma_doc = {
        "id": id_,
        "text": document_text,
        "source": "pdelfin",
        "added": datetime.datetime.now().strftime("%Y-%m-%d"),
        "created": datetime.datetime.now().strftime("%Y-%m-%d"),
        "metadata": metadata,
        "attributes": {
            "pdf_page_numbers": pdf_page_spans
        }
    }

    return dolma_doc


def mark_pdfs_done(s3_workspace: str, dolma_doc_lines: list[str]):
    db = DatabaseManager(s3_workspace)

    for line in dolma_doc_lines:
        db.update_pdf_status(json.loads(line)["metadata"]["Source-File"], "completed")


def get_current_round(s3_workspace: str) -> int:
    path = s3_workspace[5:]
    bucket, _, prefix = path.partition('/')
    inference_inputs_prefix = posixpath.join(prefix, 'inference_inputs/')

    paginator = workspace_s3.get_paginator('list_objects_v2')
    page_iterator = paginator.paginate(Bucket=bucket, Prefix=inference_inputs_prefix, Delimiter='/')

    round_numbers = []

    for page in page_iterator:
        for common_prefix in page.get('CommonPrefixes', []):
            round_prefix = common_prefix.get('Prefix')
            # Extract 'round_X' from the prefix
            round_dir = posixpath.basename(posixpath.dirname(round_prefix))
            if round_dir.startswith('round_'):
                try:
                    round_num = int(round_dir[len('round_'):])
                    round_numbers.append(round_num)
                except ValueError:
                    pass

    if round_numbers:
        current_round = max(round_numbers) + 1
    else:
        current_round = 0

    return current_round

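# Workspace layout implied by the code above and the driver below (paths are relative to the
# workspace prefix):
#   inference_inputs/round_<N>/output_<hash>.jsonl  - batch inference requests written each round
#   inference_outputs/*.jsonl                       - batch inference results, indexed into sqlite
#   output/output_<hash>.jsonl                      - assembled Dolma documents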
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Manager for running millions of PDFs through a batch inference pipeline')
    parser.add_argument('workspace', help='The S3 path where work will be done (e.g., s3://bucket/prefix/)')
    parser.add_argument('--add_pdfs', help='Path to add pdfs stored in s3 to the workspace, can be a glob path s3://bucket/prefix/*.pdf or a path to a file containing a list of pdf paths', default=None)
    parser.add_argument('--prefilter_lang', help='If set, tries to detect the language of the pdf and only accepts it if it matches (ex. ENGLISH)')
    parser.add_argument('--prefilter_spam', help='If set, tries to detect spammy pdfs and not include them')
    parser.add_argument('--workspace_profile', help='S3 configuration profile for accessing the workspace', default=None)
    parser.add_argument('--pdf_profile', help='S3 configuration profile for accessing the raw pdf documents', default=None)
    parser.add_argument('--max_size_mb', type=int, default=250, help='Max file size in MB')
    args = parser.parse_args()

    if args.workspace_profile:
        workspace_session = boto3.Session(profile_name=args.workspace_profile)
        workspace_s3 = workspace_session.client("s3")

    if args.pdf_profile:
        pdf_session = boto3.Session(profile_name=args.pdf_profile)
        pdf_s3 = pdf_session.client("s3")

    db = DatabaseManager(args.workspace)
    print(f"Loaded db at {db.db_path}")

    current_round = get_current_round(args.workspace)
    print(f"Current round is {current_round}\n")

    # One shared executor to rule them all
    executor = ProcessPoolExecutor()

    # If you have new PDFs, step one is to add them to the list
    if args.add_pdfs:
        if args.add_pdfs.startswith("s3://"):
            print(f"Querying all PDFs at {args.add_pdfs}")

            all_pdfs = expand_s3_glob(pdf_s3, args.add_pdfs)
            print(f"Found {len(all_pdfs):,} total pdf paths")
        elif os.path.exists(args.add_pdfs):
            with open(args.add_pdfs, "r") as f:
                all_pdfs = [line.strip() for line in f.readlines() if len(line.strip()) > 0]
        else:
            raise ValueError("add_pdfs argument needs to be either an s3 glob search path, or a local file containing pdf paths (one per line)")

        all_pdfs = [pdf for pdf in all_pdfs if not db.pdf_exists(pdf)]
        print(f"Need to import {len(all_pdfs):,} total new pdf paths")

        future_to_path = {executor.submit(get_pdf_num_pages, s3_path): s3_path for s3_path in all_pdfs}

        for future in tqdm(as_completed(future_to_path), total=len(future_to_path)):
            s3_path = future_to_path[future]
            num_pages = future.result()
            if num_pages and not db.pdf_exists(s3_path):
                db.add_pdf(s3_path, num_pages, "pending")

        print("\n")

    # Now build an index of all the pages that were processed within the workspace so far
    print("Indexing all batch inference sent to this workspace")
    inference_output_paths = expand_s3_glob(workspace_s3, f"{args.workspace}/inference_outputs/*.jsonl")

    inference_output_paths = [
        (s3_path, etag) for s3_path, etag in inference_output_paths.items()
        if not db.is_file_processed(s3_path, etag)
    ]

    print(f"Found {len(inference_output_paths):,} new batch inference results to index")

    future_to_path = {executor.submit(process_jsonl_content, s3_path): (s3_path, etag) for s3_path, etag in inference_output_paths}

    for future in tqdm(as_completed(future_to_path), total=len(future_to_path)):
        s3_path, etag = future_to_path[future]

        try:
            inference_records = future.result()

            db.add_index_entries(inference_records)
            db.update_processed_file(s3_path, etag=etag)
        except urllib3.exceptions.SSLError:
            print(f"Cannot load inference file {s3_path} due to SSL error, will retry another time")
    # Now query each pdf: if you have all of the pages needed (all pages present, error is null
    # and finish_reason is stop), then you assemble it into a dolma document and output it.
    # If you don't have every page, or if you have pages with errors, then you output a new batch
    # of inference items to use.
    if db.get_last_indexed_round() < current_round - 1:
        print(f"WARNING: No new batch inference results found, you need to run batch inference on {args.workspace}/inference_inputs/round_{current_round - 1}")
        potentially_done_pdfs = db.get_pdfs_by_status("pending")
    else:
        print(f"\nCreating batch inference files for new PDFs")
        future_to_path = {executor.submit(build_pdf_queries, args.workspace, pdf, current_round): pdf for pdf in db.get_pdfs_by_status("pending")}
        potentially_done_pdfs = []
        lines_written = 0
        new_inference_writer = BatchWriter(f"{args.workspace}/inference_inputs/round_{current_round}", args.max_size_mb)

        for future in tqdm(as_completed(future_to_path), total=len(future_to_path)):
            pdf = future_to_path[future]
            inference_lines = future.result()

            if len(inference_lines) == 0:
                potentially_done_pdfs.append(pdf)

            for line in inference_lines:
                lines_written += 1
                if line is not None:
                    new_inference_writer.write_line(json.dumps(line))

        new_inference_writer.close()

        if lines_written > 0:
            print(f"Added {lines_written:,} new batch inference requests")

    # Now, finally, assemble any potentially done docs into dolma documents
    print(f"\nAssembling potentially finished PDFs into Dolma documents at {args.workspace}/output")
    future_to_path = {executor.submit(build_dolma_doc, args.workspace, pdf): pdf for pdf in potentially_done_pdfs}
    new_output_writer = BatchWriter(f"{args.workspace}/output", args.max_size_mb, after_flush=partial(mark_pdfs_done, args.workspace))

    for future in tqdm(as_completed(future_to_path), total=len(future_to_path)):
        pdf = future_to_path[future]
        dolma_doc = future.result()

        if dolma_doc is not None:
            new_output_writer.write_line(json.dumps(dolma_doc))

    new_output_writer.close()

    print("\nFinal statistics:")

    # Output the number of PDFs in each status "pending" and "completed"
    pending_pdfs = db.get_pdfs_by_status("pending")
    completed_pdfs = db.get_pdfs_by_status("completed")
    print(f"Pending PDFs: {len(pending_pdfs):,} ({sum(doc.num_pages for doc in pending_pdfs):,} pages)")
    print(f"Completed PDFs: {len(completed_pdfs):,} ({sum(doc.num_pages for doc in completed_pdfs):,} pages)")

    # For each round, output a report of how many pages were processed, how many had errors,
    # and a breakdown by (error, finish_reason)
    total_rounds = db.get_last_indexed_round() + 1
    for round_num in range(total_rounds):
        db.cursor.execute("""
            SELECT COUNT(*), error, finish_reason
            FROM page_results
            WHERE round = ?
            GROUP BY error, finish_reason
        """, (round_num,))
        results = db.cursor.fetchall()
        total_pages = sum(count for count, _, _ in results)

        print(f"\nInference Round {round_num} - {total_pages:,} pages processed:")
        for count, error, finish_reason in results:
            error_str = error if error is not None else "None"
            print(f"  (error: {error_str}, finish_reason: {finish_reason}) -> {count:,} pages")

    print("\nWork finished, waiting for all workers to finish cleaning up")
    executor.shutdown(wait=True)

    db.close()
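    # Example invocation (sketch; the script filename, bucket names, and profile are hypothetical):
    #   python run_pipeline.py s3://my-bucket/my-workspace/ \
    #       --add_pdfs "s3://my-pdf-bucket/pdfs/*.pdf" \
    #       --pdf_profile my-pdf-profile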