import os
import hashlib
import boto3
import sqlite3
import orjson
import argparse
import uuid
import tempfile
import datetime
import posixpath
import threading
import logging
import boto3.session
import urllib3.exceptions

from dataclasses import dataclass
from pypdf import PdfReader
from tqdm import tqdm
from functools import partial
from typing import Optional, List, Tuple, Dict, Callable, Any
from urllib.parse import urlparse
import concurrent.futures
from concurrent.futures import ProcessPoolExecutor, as_completed

from pdelfin.data.renderpdf import render_pdf_to_base64png
from pdelfin.prompts import build_finetuning_prompt, PageResponse
from pdelfin.prompts.anchor import get_anchor_text
from pdelfin.s3_utils import parse_custom_id, expand_s3_glob, get_s3_bytes, parse_s3_path
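
# High-level flow of this manager script (summarized from the steps in __main__ below):
#   1. --add_pdfs registers new PDFs (and their page counts) in a local sqlite index.
#   2. Any new inference_outputs/*.jsonl files in the workspace are indexed into page_results.
#   3. PDFs with missing or errored pages get new batch inference queries written to
#      inference_inputs/round_N/.
#   4. PDFs whose pages all have usable results are assembled into Dolma documents under output/.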

# Global S3 clients for the whole script; adjust their parameters here if needed
workspace_s3 = boto3.client('s3')
pdf_s3 = boto3.client('s3')

# Quiet noisy logs from pypdf and smart_open
logging.getLogger("pypdf").setLevel(logging.ERROR)
logging.getLogger("smart_open").setLevel(logging.ERROR)


class DatabaseManager:
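    """Local sqlite index of a single S3 workspace.

    The database lives under ~/.cache/pdelfin/<sha256 of the workspace path>/index.db and
    tracks which inference output files have been processed, where each page's result lives
    inside those files, and the pending/completed status of every PDF.
    """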

    @dataclass(frozen=True)
    class BatchInferenceRecord:
        inference_s3_path: str
        pdf_s3_path: str
        page_num: int  # 1-indexed!
        round: int
        start_index: int
        length: int
        finish_reason: str
        error: Optional[str]

        def is_usable(self):
            return self.error is None and self.finish_reason == "stop"

    @dataclass(frozen=True)
    class PDFRecord:
        s3_path: str
        num_pages: int
        status: str

    def __init__(self, s3_workspace: str, skip_init: bool = False):
        cache_key = hashlib.sha256(s3_workspace.strip().lower().encode('utf-8')).hexdigest()
        home_cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'pdelfin', cache_key)
        os.makedirs(home_cache_dir, exist_ok=True)
        self.db_path = os.path.join(home_cache_dir, 'index.db')

        self.conn = sqlite3.connect(self.db_path)
        self.cursor = self.conn.cursor()

        if not skip_init:
            self._initialize_tables()

    def _initialize_tables(self):
        self.cursor.execute("""
            CREATE TABLE IF NOT EXISTS page_results (
                inference_s3_path TEXT,
                pdf_s3_path TEXT,
                page_num INTEGER,
                round INTEGER,
                start_index BIGINT,
                length BIGINT,
                finish_reason TEXT,
                error TEXT
            )
        """)
        self.cursor.execute("""
            CREATE INDEX IF NOT EXISTS idx_path ON page_results(pdf_s3_path)
        """)
        self.cursor.execute("""
            CREATE INDEX IF NOT EXISTS idx_inf_path ON page_results(inference_s3_path)
        """)
        self.cursor.execute("""
            CREATE TABLE IF NOT EXISTS pdfs (
                s3_path TEXT PRIMARY KEY,
                num_pages INTEGER,
                status TEXT DEFAULT 'pending'
            )
        """)
        self.cursor.execute("""
            CREATE TABLE IF NOT EXISTS processed_files (
                s3_path TEXT PRIMARY KEY,
                etag TEXT
            )
        """)
        # Generic metadata, such as the current round
        self.cursor.execute("""
            CREATE TABLE IF NOT EXISTS metadata (
                key TEXT PRIMARY KEY,
                value TEXT
            )
        """)

        self.conn.commit()

    def get_metadata(self, key: str) -> Optional[str]:
        self.cursor.execute("SELECT value FROM metadata WHERE key=?", (key,))
        result = self.cursor.fetchone()
        return result[0] if result else None

    def set_metadata(self, key: str, value: str) -> None:
        self.cursor.execute("""
            INSERT INTO metadata (key, value)
            VALUES (?, ?)
            ON CONFLICT(key) DO UPDATE SET value=excluded.value
        """, (key, value))
        self.conn.commit()

    def is_file_processed(self, s3_path, etag):
        self.cursor.execute("SELECT etag FROM processed_files WHERE s3_path = ?", (s3_path,))
        result = self.cursor.fetchone()
        return result is not None and result[0] == etag

    def update_processed_file(self, s3_path, etag):
        self.cursor.execute("""
            INSERT INTO processed_files (s3_path, etag)
            VALUES (?, ?)
            ON CONFLICT(s3_path) DO UPDATE SET etag = excluded.etag
        """, (s3_path, etag))
        self.conn.commit()

    def clear_index(self):
        self.cursor.execute("DELETE FROM processed_files")
        self.cursor.execute("DELETE FROM page_results")
        self.conn.commit()

    def add_index_entries(self, index_entries: List[BatchInferenceRecord]):
        if index_entries:
            self.cursor.executemany("""
                INSERT INTO page_results (inference_s3_path, pdf_s3_path, page_num, round, start_index, length, finish_reason, error)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?)
            """, [(entry.inference_s3_path, entry.pdf_s3_path, entry.page_num, entry.round, entry.start_index, entry.length, entry.finish_reason, entry.error) for entry in index_entries])
            self.conn.commit()

    def get_index_entries(self, pdf_s3_path: str) -> List[BatchInferenceRecord]:
        self.cursor.execute("""
            SELECT inference_s3_path, pdf_s3_path, page_num, round, start_index, length, finish_reason, error
            FROM page_results
            WHERE pdf_s3_path = ?
            ORDER BY inference_s3_path DESC, start_index ASC, page_num ASC
        """, (pdf_s3_path,))
        rows = self.cursor.fetchall()
        return [
            self.BatchInferenceRecord(
                inference_s3_path=row[0],
                pdf_s3_path=row[1],
                page_num=row[2],
                round=row[3],
                start_index=row[4],
                length=row[5],
                finish_reason=row[6],
                error=row[7],
            )
            for row in rows
        ]

    def delete_index_entries_by_inference_s3_path(self, inference_s3_path: str):
        self.cursor.execute("DELETE FROM page_results WHERE inference_s3_path = ?", (inference_s3_path,))
        self.conn.commit()

    def get_last_indexed_round(self) -> int:
        self.cursor.execute("""
            SELECT MAX(round)
            FROM page_results
        """)
        result = self.cursor.fetchone()
        return -1 if result[0] is None else result[0]

    def pdf_exists(self, s3_path: str) -> bool:
        self.cursor.execute("SELECT 1 FROM pdfs WHERE s3_path = ?", (s3_path,))
        return self.cursor.fetchone() is not None

    def add_pdf(self, s3_path: str, num_pages: int, status: str = 'pending') -> None:
        try:
            self.cursor.execute("""
                INSERT INTO pdfs (s3_path, num_pages, status)
                VALUES (?, ?, ?)
            """, (s3_path, num_pages, status))
            self.conn.commit()
        except sqlite3.IntegrityError:
            print(f"PDF with s3_path '{s3_path}' already exists.")

    def update_pdf_status(self, s3_path: str, new_status: str) -> None:
        self.cursor.execute("""
            UPDATE pdfs
            SET status = ?
            WHERE s3_path = ?
        """, (new_status, s3_path))
        self.conn.commit()

    def get_pdf(self, s3_path: str) -> Optional[PDFRecord]:
        self.cursor.execute("""
            SELECT s3_path, num_pages, status
            FROM pdfs
            WHERE s3_path = ?
        """, (s3_path,))
        row = self.cursor.fetchone()
        if row:
            return self.PDFRecord(
                s3_path=row[0],
                num_pages=row[1],
                status=row[2],
            )
        return None

    def get_pdfs_by_status(self, status: str) -> List[PDFRecord]:
        self.cursor.execute("""
            SELECT s3_path, num_pages, status
            FROM pdfs
            WHERE status = ?
            ORDER BY s3_path DESC, num_pages DESC
        """, (status,))
        rows = self.cursor.fetchall()
        return [
            self.PDFRecord(
                s3_path=row[0],
                num_pages=row[1],
                status=row[2],
            )
            for row in rows
        ]

    def close(self):
        self.conn.close()


class BatchWriter:
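    """Accumulates JSONL lines into batches of roughly max_size_mb and writes each batch
    to output_prefix (a local directory or an s3:// prefix) on a background thread.

    Illustrative usage (paths here are hypothetical):
        writer = BatchWriter("s3://my-bucket/workspace/output", max_size_mb=250)
        writer.write_line({"id": "doc1", "text": "..."})
        writer.close()  # flushes the final batch and joins upload threads
    """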

    def __init__(
        self,
        output_prefix: str,
        max_size_mb: int = 250,
        after_flush: Optional[Callable[[List[Any]], Any]] = None,
    ):
        self.output_prefix = output_prefix
        self.max_size = max_size_mb * 1024 * 1024  # Convert MB to bytes
        self.batch_objects = []
        self.batch_size = 0
        self.after_flush = after_flush
        self.threads = []

        self.temp_file = None  # The temporary file object
        self.temp_file_path = None  # Path to the temporary file

        parsed = urlparse(output_prefix)
        self.is_s3 = parsed.scheme in ("s3", "s3a", "s3n")

        if not self.is_s3:
            os.makedirs(output_prefix, exist_ok=True)

    def write_line(self, obj: Optional[Any]):
        if obj is None:
            return

        line_bytes = orjson.dumps(obj)
        line_size = len(line_bytes) + 1  # +1 for newline

        if self.batch_size + line_size > self.max_size:
            self._write_batch()

        if self.batch_size == 0:
            # Open a new temporary file
            self.temp_file = tempfile.NamedTemporaryFile(mode="wb+", delete=False)
            self.temp_file_path = self.temp_file.name

        self.temp_file.write(line_bytes + b"\n")
        self.batch_objects.append(obj)
        self.batch_size += line_size

    def _write_batch(self):
        if self.batch_size == 0:
            return

        # Close the temp file
        self.temp_file.flush()
        self.temp_file.close()

        # Start a new thread to upload the temp file
        thread = threading.Thread(
            target=self._write_batch_to_file, args=(self.temp_file_path, self.batch_objects)
        )
        thread.start()
        self.threads.append(thread)

        # Reset batch_objects and batch_size
        self.batch_objects = []
        self.batch_size = 0
        self.temp_file = None
        self.temp_file_path = None

    def _write_batch_to_file(self, temp_file_path: str, batch_objects: List[Any]):
        # Compute hash based on file content
        hash_str = self._compute_hash(temp_file_path)
        output_path = self._get_output_path(hash_str)

        if self.is_s3:
            bucket, key = parse_s3_path(output_path)

            # Use the s3 client directly
            try:
                workspace_s3.upload_file(temp_file_path, bucket, key)
            except Exception as e:
                print(f"Failed to upload {temp_file_path} to {output_path}: {e}")
        else:
            # Move the temp file to the output path
            os.rename(temp_file_path, output_path)

        # After writing, call the after_flush callback if it is set
        if self.after_flush:
            self.after_flush(batch_objects)

        # Clean up the temp file; for local output it was already moved by os.rename
        if os.path.exists(temp_file_path):
            os.remove(temp_file_path)

    def _compute_hash(self, temp_file_path: str) -> str:
        """Compute a 20-character SHA1 hash of the file content."""
        sha1 = hashlib.sha1()
        with open(temp_file_path, "rb") as f:
            while True:
                data = f.read(1024 * 1024)
                if not data:
                    break
                sha1.update(data)
        return sha1.hexdigest()[:20]

    def _get_output_path(self, hash_str: str) -> str:
        """Generate the full output path with the hash in the filename."""
        parsed = urlparse(self.output_prefix)
        if self.is_s3:
            bucket = parsed.netloc
            key = parsed.path.lstrip("/")
            if key and not key.endswith("/"):
                key += "/"
            full_key = posixpath.join(key, f"output_{hash_str}.jsonl")
            return f"s3://{bucket}/{full_key}"
        else:
            filename = f"output_{hash_str}.jsonl"
            return os.path.join(self.output_prefix, filename)

    def close(self):
        self._write_batch()

        # Wait for all upload threads to finish
        for thread in self.threads:
            thread.join()


def build_page_query(local_pdf_path: str, pretty_pdf_path: str, page: int, target_longest_image_dim: int, target_anchor_text_len: int) -> dict:
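    """Render one PDF page to a base64 PNG, extract its anchor text, and wrap both into a
    single batch-inference request.

    The returned dict is shaped roughly like (values abbreviated for illustration):
        {
            "custom_id": "<pretty_pdf_path>-<page>",
            "chat_messages": [{"role": "user", "content": [<text prompt>, <image_url>]}],
        }
    """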
    image_base64 = render_pdf_to_base64png(local_pdf_path, page, target_longest_image_dim=target_longest_image_dim)
    anchor_text = get_anchor_text(local_pdf_path, page, pdf_engine="pdfreport", target_length=target_anchor_text_len)

    return {
        "custom_id": f"{pretty_pdf_path}-{page}",
        "chat_messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": build_finetuning_prompt(anchor_text)},
                    {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}}
                ],
            }
        ],
    }


def process_jsonl_content(inference_s3_path: str) -> List[DatabaseManager.BatchInferenceRecord]:
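    """Index one batch-inference output JSONL file from the workspace.

    Rather than storing the model responses themselves, this records, for each line, the
    byte offset and length of that line within the file plus its finish_reason and any
    error, so individual pages can later be re-read (e.g., with ranged S3 reads).
    """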
    content_bytes = get_s3_bytes(workspace_s3, inference_s3_path)

    start_index = 0
    index_entries = []

    lines = content_bytes.splitlines(keepends=True)  # Split content into lines as bytes

    for line in lines:
        line_length = len(line)  # Length in bytes

        try:
            # Parse the line directly as JSON
            data = orjson.loads(line)
            pdf_s3_path, page_num = parse_custom_id(data["custom_id"])

            if data.get("completion_error", None) is not None:
                index_entries.append(DatabaseManager.BatchInferenceRecord(
                    inference_s3_path=inference_s3_path,
                    pdf_s3_path=pdf_s3_path,
                    page_num=page_num,
                    round=data["round"],
                    start_index=start_index,  # Byte offset in the original file
                    length=line_length,  # Length in bytes
                    finish_reason="completion_error",
                    error=data.get("completion_error", None),
                ))
            else:
                # Try to parse the actual model response JSON
                assert "outputs" in data and len(data["outputs"]) > 0, "No outputs from model detected"

                try:
                    model_response_json = orjson.loads(data["outputs"][0]["text"])
                    page_response = PageResponse(**model_response_json)

                    last_error = data.get("completion_error", None)
                    if not page_response.is_rotation_valid:
                        last_error = "rotation_invalid"

                    index_entries.append(DatabaseManager.BatchInferenceRecord(
                        inference_s3_path=inference_s3_path,
                        pdf_s3_path=pdf_s3_path,
                        page_num=page_num,
                        round=data["round"],
                        start_index=start_index,  # Byte offset in the original file
                        length=line_length,  # Length in bytes
                        finish_reason=data["outputs"][0]["finish_reason"],
                        error=last_error,
                    ))
                except Exception as e:
                    error_type = type(e).__name__
                    index_entries.append(DatabaseManager.BatchInferenceRecord(
                        inference_s3_path=inference_s3_path,
                        pdf_s3_path=pdf_s3_path,
                        page_num=page_num,
                        round=data["round"],
                        start_index=start_index,  # Byte offset in the original file
                        length=line_length,  # Length in bytes
                        finish_reason=data["outputs"][0]["finish_reason"],
                        error=error_type,
                    ))
        except Exception as e:
            print(f"Error processing line in {inference_s3_path}: {e}")
            # Optionally, you might want to add an index entry indicating an error here

        start_index += line_length  # Increment by the number of bytes

    return index_entries


def get_pdf_num_pages(s3_path: str) -> Optional[int]:
    try:
        with tempfile.NamedTemporaryFile("wb+", suffix=".pdf") as tf:
            tf.write(get_s3_bytes(pdf_s3, s3_path))
            tf.flush()

            reader = PdfReader(tf.name)
            return reader.get_num_pages()
    except Exception as ex:
        print(f"Warning, could not add {s3_path} due to {ex}")

    return None


def build_pdf_queries(s3_workspace: str, pdf: DatabaseManager.PDFRecord, cur_round: int, target_longest_image_dim: int, target_anchor_text_len: int) -> list[dict]:
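    """Build batch-inference queries for every page of `pdf` that does not yet have a
    usable result, tagging each query with the current round number.

    Returns an empty list when all pages already have usable results, which is the
    signal used downstream to treat the PDF as potentially done.
    """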
    db = DatabaseManager(s3_workspace, skip_init=True)
    existing_pages = db.get_index_entries(pdf.s3_path)
    new_queries = []

    # Shortcut out of downloading the actual PDF if every page already has a usable result
    if set(page.page_num for page in existing_pages if page.is_usable()) == set(range(1, pdf.num_pages + 1)):
        return []

    try:
        with tempfile.NamedTemporaryFile("wb+", suffix=".pdf") as tf:
            tf.write(get_s3_bytes(pdf_s3, pdf.s3_path))
            tf.flush()

            for target_page_num in range(1, pdf.num_pages + 1):
                # Skip pages that already have a usable (error-free) result
                if any(page.is_usable() and page.page_num == target_page_num for page in existing_pages):
                    continue

                has_errored_previously = sum(page.page_num == target_page_num for page in existing_pages)

                if has_errored_previously:
                    # Retry the page at least one more time with the regular prompt
                    new_queries.append({**build_page_query(tf.name, pdf.s3_path, target_page_num, target_longest_image_dim, target_anchor_text_len), "round": cur_round})

                    # TODO: If the rotation was previously invalid, then apply a rotation
                    # TODO: Try to provide a smaller prompt hint
                else:
                    new_queries.append({**build_page_query(tf.name, pdf.s3_path, target_page_num, target_longest_image_dim, target_anchor_text_len), "round": cur_round})
    except Exception as ex:
        print(f"Warning, could not get batch inference lines for {pdf.s3_path} due to {ex}")

    return new_queries


def build_dolma_doc(s3_workspace: str, pdf: DatabaseManager.PDFRecord) -> Optional[dict]:
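    """Assemble a fully processed PDF into a single Dolma document.

    Returns None if any page is still missing a usable result. When several usable
    results exist for a page, the one with a valid rotation and the longest natural
    text wins.
    """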
    db = DatabaseManager(s3_workspace, skip_init=True)
    existing_pages = db.get_index_entries(pdf.s3_path)
    document_text = ""
    last_page_start_index = 0
    pdf_page_spans = []

    for target_page_num in range(1, pdf.num_pages + 1):
        usable_pages = [page for page in existing_pages if page.is_usable() and page.page_num == target_page_num]

        if len(usable_pages) == 0:
            return None

        usable_page_data = [get_s3_bytes(workspace_s3, page.inference_s3_path,
                                         start_index=page.start_index,
                                         end_index=page.start_index + page.length - 1) for page in usable_pages]

        usable_page_final_results = []
        for page_data in usable_page_data:
            data = orjson.loads(page_data)
            model_response_json = orjson.loads(data["outputs"][0]["text"])
            page_response = PageResponse(**model_response_json)

            usable_page_final_results.append(page_response)

        # Sort the pages:
        # 1. Prefer pages with `is_rotation_valid` set to True.
        # 2. Within those, sort by the length of the `natural_text` in descending order.
        usable_page_final_results.sort(
            key=lambda page: (not page.is_rotation_valid, -len(page.natural_text or ""))
        )

        target_page_final_result = usable_page_final_results[0]

        if target_page_final_result.natural_text is not None:
            document_text += target_page_final_result.natural_text + "\n"

        pdf_page_spans.append([last_page_start_index, len(document_text), target_page_num])
        last_page_start_index = len(document_text)

    metadata = {
        "Source-File": pdf.s3_path,
        "pdf-total-pages": pdf.num_pages,
    }

    id_ = hashlib.sha1(document_text.encode()).hexdigest()

    dolma_doc = {
        "id": id_,
        "text": document_text,
        "source": "pdelfin",
        "added": datetime.datetime.now().strftime("%Y-%m-%d"),
        "created": datetime.datetime.now().strftime("%Y-%m-%d"),
        "metadata": metadata,
        "attributes": {
            "pdf_page_numbers": pdf_page_spans
        }
    }

    return dolma_doc


def mark_pdfs_done(s3_workspace: str, dolma_docs: list[dict]):
    db = DatabaseManager(s3_workspace, skip_init=True)

    for doc in dolma_docs:
        db.update_pdf_status(doc["metadata"]["Source-File"], "completed")


def get_current_round(s3_workspace: str) -> int:
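    """Determine the next inference round by listing the existing inference_inputs/round_N/
    prefixes in the workspace and returning max(N) + 1, or 0 if none exist yet.
    """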
    path = s3_workspace[5:]
    bucket, _, prefix = path.partition('/')

    inference_inputs_prefix = posixpath.join(prefix, 'inference_inputs/')
    paginator = workspace_s3.get_paginator('list_objects_v2')
    page_iterator = paginator.paginate(Bucket=bucket, Prefix=inference_inputs_prefix, Delimiter='/')

    round_numbers = []

    for page in page_iterator:
        for common_prefix in page.get('CommonPrefixes', []):
            round_prefix = common_prefix.get('Prefix')
            # Extract 'round_X' from the prefix
            round_dir = posixpath.basename(posixpath.dirname(round_prefix))
            if round_dir.startswith('round_'):
                try:
                    round_num = int(round_dir[len('round_'):])
                    round_numbers.append(round_num)
                except ValueError:
                    pass

    if round_numbers:
        current_round = max(round_numbers) + 1
    else:
        current_round = 0

    return current_round


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Manager for running millions of PDFs through a batch inference pipeline')
    parser.add_argument('workspace', help='The S3 path where work will be done (e.g., s3://bucket/prefix/)')
    parser.add_argument('--add_pdfs', help='PDFs to add to the workspace; either an S3 glob path such as s3://bucket/prefix/*.pdf or a local file containing a list of pdf paths', default=None)
    parser.add_argument('--target_longest_image_dim', type=int, help='Dimension on longest side to use for rendering the pdf pages', default=1024)
    parser.add_argument('--target_anchor_text_len', type=int, help='Maximum amount of anchor text to use (characters)', default=6000)
    parser.add_argument('--workspace_profile', help='S3 configuration profile for accessing the workspace', default=None)
    parser.add_argument('--pdf_profile', help='S3 configuration profile for accessing the raw pdf documents', default=None)
    parser.add_argument('--max_size_mb', type=int, default=250, help='Max file size in MB')
    parser.add_argument('--reindex', action='store_true', default=False, help='Reindex all of the page_results')
    parser.add_argument('--skip_build_queries', action='store_true', default=False, help='Skip generation of new pdf page queries for batch inferencing')
    args = parser.parse_args()

    if args.workspace_profile:
        workspace_session = boto3.Session(profile_name=args.workspace_profile)
        workspace_s3 = workspace_session.client("s3")

    if args.pdf_profile:
        pdf_session = boto3.Session(profile_name=args.pdf_profile)
        pdf_s3 = pdf_session.client("s3")

    db = DatabaseManager(args.workspace)
    print(f"Loaded db at {db.db_path}")

    if args.reindex:
        db.clear_index()

    current_round = get_current_round(args.workspace)
    print(f"Current round is {current_round}\n")

    # One shared executor to rule them all
    executor = ProcessPoolExecutor()

    # If you have new PDFs, step one is to add them to the list
    if args.add_pdfs:
        if args.add_pdfs.startswith("s3://"):
            print(f"Querying all PDFs at {args.add_pdfs}")

            all_pdfs = expand_s3_glob(pdf_s3, args.add_pdfs)
            print(f"Found {len(all_pdfs):,} total pdf paths")
        elif os.path.exists(args.add_pdfs):
            with open(args.add_pdfs, "r") as f:
                all_pdfs = [line.strip() for line in f.readlines() if len(line.strip()) > 0]
        else:
            raise ValueError("add_pdfs argument needs to be either an s3 glob search path, or a local file containing pdf paths (one per line)")

        all_pdfs = [pdf for pdf in all_pdfs if not db.pdf_exists(pdf)]
        print(f"Need to import {len(all_pdfs):,} total new pdf paths")

        future_to_path = {executor.submit(get_pdf_num_pages, s3_path): s3_path for s3_path in all_pdfs}

        for future in tqdm(as_completed(future_to_path), total=len(future_to_path)):
            s3_path = future_to_path[future]
            num_pages = future.result()
            if num_pages and not db.pdf_exists(s3_path):
                db.add_pdf(s3_path, num_pages, "pending")

        print("\n")
    # Now build an index of all the pages that were processed within the workspace so far
    print("Indexing all batch inference results sent to this workspace")
    inference_output_paths = expand_s3_glob(workspace_s3, f"{args.workspace}/inference_outputs/*.jsonl")

    inference_output_paths = {
        s3_path: etag for s3_path, etag in inference_output_paths.items()
        if not db.is_file_processed(s3_path, etag)
    }

    print(f"Found {len(inference_output_paths):,} new batch inference results to index")

    future_to_path = {executor.submit(process_jsonl_content, s3_path): (s3_path, etag) for s3_path, etag in inference_output_paths.items()}

    for future in tqdm(as_completed(future_to_path), total=len(future_to_path)):
        s3_path, etag = future_to_path[future]

        try:
            inference_records = future.result()

            db.delete_index_entries_by_inference_s3_path(s3_path)
            db.add_index_entries(inference_records)
            db.update_processed_file(s3_path, etag=etag)
        except urllib3.exceptions.SSLError:
            print(f"Cannot load inference file {s3_path} due to SSL error, will retry another time")

    # Now check each PDF: if every page has a usable result (no error and finish_reason == "stop"),
    # it can be assembled into a Dolma document and output. Otherwise, a new batch of inference
    # queries is generated for the missing or errored pages.
    if db.get_last_indexed_round() < current_round - 1:
        print(f"WARNING: No new batch inference results found, you need to run batch inference on {args.workspace}/inference_inputs/round_{current_round - 1}")
        potentially_done_pdfs = db.get_pdfs_by_status("pending")
    elif args.skip_build_queries:
        print("Skipping generating new batch inference files")
        potentially_done_pdfs = db.get_pdfs_by_status("pending")
    else:
        print("\nCreating batch inference files for new PDFs")
        pdf_list = list(db.get_pdfs_by_status("pending"))
        pdf_iter = iter(pdf_list)
        pending_futures = {}
        potentially_done_pdfs = []
        lines_written = 0

        new_inference_writer = BatchWriter(f"{args.workspace}/inference_inputs/round_{current_round}", args.max_size_mb)

        total_pdfs = len(pdf_list)
        max_pending = 300

        with tqdm(total=total_pdfs) as pbar:
            # Submit an initial batch of futures
            for _ in range(min(max_pending, total_pdfs)):
                pdf = next(pdf_iter)
                future = executor.submit(
                    build_pdf_queries, args.workspace, pdf, current_round, args.target_longest_image_dim, args.target_anchor_text_len,
                )
                pending_futures[future] = pdf

            while pending_futures:
                # Wait for the next future to complete
                done, _ = concurrent.futures.wait(
                    pending_futures.keys(),
                    return_when=concurrent.futures.FIRST_COMPLETED,
                )

                for future in done:
                    pdf = pending_futures.pop(future)
                    inference_lines = future.result()

                    if len(inference_lines) == 0:
                        potentially_done_pdfs.append(pdf)

                    for line in inference_lines:
                        lines_written += 1

                        if line is not None:
                            new_inference_writer.write_line(line)

                    pbar.update(1)

                    # Submit a new future if there are more PDFs
                    try:
                        pdf = next(pdf_iter)
                        future = executor.submit(
                            build_pdf_queries, args.workspace, pdf, current_round, args.target_longest_image_dim, args.target_anchor_text_len,
                        )
                        pending_futures[future] = pdf
                    except StopIteration:
                        pass  # No more PDFs to process

        new_inference_writer.close()

        if lines_written > 0:
            print(f"Added {lines_written:,} new batch inference requests")

    # Now, finally, assemble any potentially done PDFs into Dolma documents
    print(f"\nAssembling potentially finished PDFs into Dolma documents at {args.workspace}/output")
    future_to_path = {executor.submit(build_dolma_doc, args.workspace, pdf): pdf for pdf in potentially_done_pdfs}
    new_output_writer = BatchWriter(f"{args.workspace}/output", args.max_size_mb, after_flush=partial(mark_pdfs_done, args.workspace))

    for future in tqdm(as_completed(future_to_path), total=len(future_to_path)):
        pdf = future_to_path[future]
        dolma_doc = future.result()

        if dolma_doc is not None:
            new_output_writer.write_line(dolma_doc)

    new_output_writer.close()

    print("\nFinal statistics:")

    # Output the number of PDFs in each status, "pending" and "completed"
    pending_pdfs = db.get_pdfs_by_status("pending")
    completed_pdfs = db.get_pdfs_by_status("completed")
    print(f"Pending PDFs: {len(pending_pdfs):,} ({sum(doc.num_pages for doc in pending_pdfs):,} pages)")
    print(f"Completed PDFs: {len(completed_pdfs):,} ({sum(doc.num_pages for doc in completed_pdfs):,} pages)")

    # For each round, output a report of how many pages were processed, how many had errors,
    # and a breakdown by (error, finish_reason)
    total_rounds = db.get_last_indexed_round() + 1
    for round_num in range(total_rounds):
        db.cursor.execute("""
            SELECT COUNT(*), error, finish_reason
            FROM page_results
            WHERE round = ?
            GROUP BY error, finish_reason
        """, (round_num,))
        results = db.cursor.fetchall()
        total_pages = sum(count for count, _, _ in results)

        print(f"\nInference Round {round_num} - {total_pages:,} pages processed:")
        for count, error, finish_reason in results:
            error_str = error if error is not None else "None"
            print(f"(error: {error_str}, finish_reason: {finish_reason}) -> {count:,} pages")

    print("\nWork finished, waiting for all workers to finish cleaning up")
    executor.shutdown(wait=True)

    db.close()