import os
import hashlib
import boto3
import sqlite3
import json
import argparse
import glob
import tempfile
import datetime
import posixpath
import threading
import logging
import boto3.session
import urllib3.exceptions

from dataclasses import dataclass
from pypdf import PdfReader
from tqdm import tqdm
from functools import partial
from typing import Optional, List, Tuple, Dict, Callable, Any
from urllib.parse import urlparse
from concurrent.futures import ProcessPoolExecutor, as_completed

from pdelfin.data.renderpdf import render_pdf_to_base64png
from pdelfin.prompts import build_finetuning_prompt
from pdelfin.prompts.anchor import get_anchor_text
from pdelfin.s3_utils import parse_s3_path, expand_s3_glob, get_s3_bytes, put_s3_bytes

# Global s3 clients for the whole script, feel free to adjust params if needed
workspace_s3 = boto3.client('s3')
pdf_s3 = boto3.client('s3')

# Quiet logs from pypdf and smart_open
logging.getLogger("pypdf").setLevel(logging.ERROR)
logging.getLogger("smart_open").setLevel(logging.ERROR)


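# Tracks pipeline state in a local sqlite database cached under ~/.cache/pdelfin/<sha256 of workspace path>:
# which PDFs are pending or completed, which inference output files have already been indexed (keyed by etag),
# byte-offset index entries for every inference result line, and generic key/value metadata such as the round.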
class DatabaseManager:
    @dataclass(frozen=True)
    class BatchInferenceRecord:
        inference_s3_path: str
        pdf_s3_path: str
        page_num: int  # 1 indexed!
        round: int
        start_index: int
        length: int
        finish_reason: str
        error: Optional[str]

        def is_usable(self):
            return self.error is None and self.finish_reason == "stop"

    @dataclass(frozen=True)
    class PDFRecord:
        s3_path: str
        num_pages: int
        status: str

    def __init__(self, s3_workspace: str):
        cache_key = hashlib.sha256(s3_workspace.strip().lower().encode('utf-8')).hexdigest()
        home_cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'pdelfin', cache_key)
        os.makedirs(home_cache_dir, exist_ok=True)
        self.db_path = os.path.join(home_cache_dir, 'index.db')

        self.conn = sqlite3.connect(self.db_path)
        self.cursor = self.conn.cursor()
        self._initialize_tables()

    def _initialize_tables(self):
        self.cursor.execute("""
            CREATE TABLE IF NOT EXISTS page_results (
                inference_s3_path TEXT,
                pdf_s3_path TEXT,
                page_num INTEGER,
                round INTEGER,
                start_index BIGINT,
                length BIGINT,
                finish_reason TEXT,
                error TEXT
            )
        """)
        self.cursor.execute("""
            CREATE INDEX IF NOT EXISTS idx_path ON page_results (pdf_s3_path)
        """)
        self.cursor.execute("""
            CREATE TABLE IF NOT EXISTS pdfs (
                s3_path TEXT PRIMARY KEY,
                num_pages INTEGER,
                status TEXT DEFAULT 'pending'
            )
        """)
        self.cursor.execute("""
            CREATE TABLE IF NOT EXISTS processed_files (
                s3_path TEXT PRIMARY KEY,
                etag TEXT
            )
        """)
        # Generic metadata such as current round
        self.cursor.execute("""
            CREATE TABLE IF NOT EXISTS metadata (
                key TEXT PRIMARY KEY,
                value TEXT
            )
        """)

        self.conn.commit()

    def get_metadata(self, key: str) -> Optional[str]:
        self.cursor.execute("SELECT value FROM metadata WHERE key=?", (key,))
        result = self.cursor.fetchone()
        return result[0] if result else None

    def set_metadata(self, key: str, value: str) -> None:
        self.cursor.execute("""
            INSERT INTO metadata (key, value)
            VALUES (?, ?)
            ON CONFLICT (key) DO UPDATE SET value = excluded.value
        """, (key, value))
        self.conn.commit()

    def is_file_processed(self, s3_path, etag):
        self.cursor.execute("SELECT etag FROM processed_files WHERE s3_path = ?", (s3_path,))
        result = self.cursor.fetchone()
        return result is not None and result[0] == etag

    def update_processed_file(self, s3_path, etag):
        self.cursor.execute("""
            INSERT INTO processed_files (s3_path, etag)
            VALUES (?, ?)
            ON CONFLICT (s3_path) DO UPDATE SET etag = excluded.etag
        """, (s3_path, etag))
        self.conn.commit()

    def add_index_entries(self, index_entries: List[BatchInferenceRecord]):
        if index_entries:
            self.cursor.executemany("""
                INSERT INTO page_results (inference_s3_path, pdf_s3_path, page_num, round, start_index, length, finish_reason, error)
                VALUES (?, ?, ?, ?, ?, ?, ?, ?)
            """, [(entry.inference_s3_path, entry.pdf_s3_path, entry.page_num, entry.round,
                   entry.start_index, entry.length, entry.finish_reason, entry.error)
                  for entry in index_entries])
            self.conn.commit()

    def get_index_entries(self, pdf_s3_path: str) -> List[BatchInferenceRecord]:
        self.cursor.execute("""
            SELECT inference_s3_path, pdf_s3_path, page_num, round, start_index, length, finish_reason, error
            FROM page_results
            WHERE pdf_s3_path = ?
            ORDER BY inference_s3_path DESC, start_index ASC, page_num ASC
        """, (pdf_s3_path,))

        rows = self.cursor.fetchall()
        return [
            self.BatchInferenceRecord(
                inference_s3_path=row[0],
                pdf_s3_path=row[1],
                page_num=row[2],
                round=row[3],
                start_index=row[4],
                length=row[5],
                finish_reason=row[6],
                error=row[7],
            )
            for row in rows
        ]

    def get_last_indexed_round(self) -> int:
        self.cursor.execute("""
            SELECT MAX(round)
            FROM page_results
        """)
        result = self.cursor.fetchone()
        return -1 if result[0] is None else result[0]

    def pdf_exists(self, s3_path: str) -> bool:
        self.cursor.execute("SELECT 1 FROM pdfs WHERE s3_path = ?", (s3_path,))
        return self.cursor.fetchone() is not None

    def add_pdf(self, s3_path: str, num_pages: int, status: str = 'pending') -> None:
        try:
            self.cursor.execute("""
                INSERT INTO pdfs (s3_path, num_pages, status)
                VALUES (?, ?, ?)
            """, (s3_path, num_pages, status))
            self.conn.commit()
        except sqlite3.IntegrityError:
            print(f"PDF with s3_path '{s3_path}' already exists.")

    def update_pdf_status(self, s3_path: str, new_status: str) -> None:
        self.cursor.execute("""
            UPDATE pdfs
            SET status = ?
            WHERE s3_path = ?
        """, (new_status, s3_path))
        self.conn.commit()

    def get_pdf(self, s3_path: str) -> Optional[PDFRecord]:
        self.cursor.execute("""
            SELECT s3_path, num_pages, status
            FROM pdfs
            WHERE s3_path = ?
        """, (s3_path,))
        row = self.cursor.fetchone()
        if row:
            return self.PDFRecord(
                s3_path=row[0],
                num_pages=row[1],
                status=row[2],
            )
        return None

    def get_pdfs_by_status(self, status: str) -> List[PDFRecord]:
        self.cursor.execute("""
            SELECT s3_path, num_pages, status
            FROM pdfs
            WHERE status = ?
            ORDER BY s3_path DESC, num_pages DESC
        """, (status,))
        rows = self.cursor.fetchall()
        return [
            self.PDFRecord(
                s3_path=row[0],
                num_pages=row[1],
                status=row[2],
            )
            for row in rows
        ]

    def close(self):
        self.conn.close()


# Writes batches of lines out to a set of files, keeping each file below some maximum size
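# Each full batch is flushed to local disk or S3 on a background thread; close() flushes the final batch and
# joins all writer threads, and the optional after_flush callback receives the lines of each batch once written.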
class BatchWriter:
    def __init__(self, output_prefix: str, max_size_mb: int = 250, after_flush: Optional[Callable[[List[str]], Any]] = None):
        self.output_prefix = output_prefix
        self.max_size = max_size_mb * 1024 * 1024  # Convert MB to bytes
        self.batch = []
        self.batch_size = 0

        self.after_flush = after_flush
        self.threads = []

        parsed = urlparse(output_prefix)
        self.is_s3 = parsed.scheme in ('s3', 's3a', 's3n')

        if not self.is_s3:
            os.makedirs(output_prefix, exist_ok=True)

    def _compute_hash(self, content: str) -> str:
        """Compute a 20-character SHA1 hash of the given content."""
        sha1 = hashlib.sha1()
        sha1.update(content.encode('utf-8'))
        return sha1.hexdigest()[:20]

    def _get_output_path(self, hash_str: str) -> str:
        """Generate the full output path with hash in the filename."""
        parsed = urlparse(self.output_prefix)
        if self.is_s3:
            bucket = parsed.netloc
            key = parsed.path.lstrip('/')
            if key and not key.endswith('/'):
                key += '/'
            full_key = posixpath.join(key, f"output_{hash_str}.jsonl")
            return f"s3://{bucket}/{full_key}"
        else:
            filename = f"output_{hash_str}.jsonl"
            return os.path.join(self.output_prefix, filename)

    def write_line(self, line: Optional[str]):
        if line is None or not line.strip():
            return

        line_size = len(line.encode('utf-8')) + 1  # +1 for newline

        if self.batch_size + line_size > self.max_size:
            self._write_batch()

        self.batch.append(line)
        self.batch_size += line_size

    def _write_batch(self):
        if not self.batch:
            return

        batch_lines = self.batch.copy()
        batch_content = "\n".join(batch_lines) + "\n"
        hash_str = self._compute_hash(batch_content)
        output_path = self._get_output_path(hash_str)

        # Start a new thread to write the batch
        thread = threading.Thread(
            target=self._write_batch_to_file,
            args=(batch_content, output_path, batch_lines)
        )
        thread.start()
        self.threads.append(thread)

        # Clear the batch and batch_size
        self.batch = []
        self.batch_size = 0

    def _write_batch_to_file(self, batch_content: str, output_path: str, batch_lines: List[str]):
        if self.is_s3:
            put_s3_bytes(workspace_s3, output_path, batch_content.encode("utf-8"))
        else:
            with open(output_path, 'w', encoding='utf-8') as f_out:
                f_out.write(batch_content)

        # After writing, call the after_flush callback if it is set
        if self.after_flush:
            self.after_flush(batch_lines)

    def close(self):
        self._write_batch()

        # Wait for all threads to finish
        for thread in self.threads:
            thread.join()


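# Builds a single batch-inference request for one PDF page: renders the page to a base64-encoded PNG
# (target dimension 1024) and pairs it with a finetuning prompt built from the page's anchor text.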
def build_page_query(local_pdf_path: str, pretty_pdf_path: str, page: int) -> dict:
    image_base64 = render_pdf_to_base64png(local_pdf_path, page, 1024)
    anchor_text = get_anchor_text(local_pdf_path, page, pdf_engine="pdfreport")

    return {
        "custom_id": f"{pretty_pdf_path}-{page}",
        "chat_messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": build_finetuning_prompt(anchor_text)},
                    {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}}
                ],
            }
        ],
    }


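# Splits a custom_id of the form "{pdf_s3_path}-{page_num}" back into its S3 path and 1-indexed page number.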
def parse_custom_id(custom_id: str) -> Tuple[str, int]:
    s3_path = custom_id[:custom_id.rindex("-")]
    page_num = int(custom_id[custom_id.rindex("-") + 1:])
    return s3_path, page_num


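# Scans one batch-inference output JSONL file and returns an index record per line, recording each line's
# byte offset and length (so the result can later be re-read with a ranged S3 request) plus its finish_reason
# and any error, so the pipeline can decide which pages are usable and which need to be retried.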
def process_jsonl_content(inference_s3_path: str) -> List[DatabaseManager.BatchInferenceRecord]:
    content_bytes = get_s3_bytes(workspace_s3, inference_s3_path)

    start_index = 0
    index_entries = []

    lines = content_bytes.splitlines(keepends=True)  # Split content into lines as bytes

    for line in lines:
        line_length = len(line)  # Length in bytes

        try:
            # Decode the line for JSON processing
            line_str = line.decode('utf-8')
            data = json.loads(line_str)
            pdf_s3_path, page_num = parse_custom_id(data["custom_id"])

            if data.get("completion_error", None) is not None:
                index_entries.append(DatabaseManager.BatchInferenceRecord(
                    inference_s3_path=inference_s3_path,
                    pdf_s3_path=pdf_s3_path,
                    page_num=page_num,
                    round=data["round"],
                    start_index=start_index,  # Byte offset in the original file
                    length=line_length,  # Length in bytes
                    finish_reason="completion_error",
                    error=data.get("completion_error", None),
                ))
            else:
                # Try to parse the actual model response JSON
                assert "outputs" in data and len(data["outputs"]) > 0, "No outputs from model detected"

                try:
                    model_response_json = json.loads(data["outputs"][0]["text"])

                    index_entries.append(DatabaseManager.BatchInferenceRecord(
                        inference_s3_path=inference_s3_path,
                        pdf_s3_path=pdf_s3_path,
                        page_num=page_num,
                        round=data["round"],
                        start_index=start_index,  # Byte offset in the original file
                        length=line_length,  # Length in bytes
                        finish_reason=data["outputs"][0]["finish_reason"],
                        error=data.get("completion_error", None),
                    ))
                except json.JSONDecodeError:
                    index_entries.append(DatabaseManager.BatchInferenceRecord(
                        inference_s3_path=inference_s3_path,
                        pdf_s3_path=pdf_s3_path,
                        page_num=page_num,
                        round=data["round"],
                        start_index=start_index,  # Byte offset in the original file
                        length=line_length,  # Length in bytes
                        finish_reason=data["outputs"][0]["finish_reason"],
                        error="Could not parse model JSON output",
                    ))
        except json.JSONDecodeError:
            print(f"Error decoding JSON line in {inference_s3_path}")
            # TODO: Consider adding an index entry marking this line as bad JSON so it can be retried
        except Exception as e:
            print(f"Error processing line: {e}")

        start_index += line_length  # Increment by the number of bytes

    return index_entries


def get_pdf_num_pages(s3_path: str) -> Optional[int]:
    try:
        with tempfile.NamedTemporaryFile("wb+", suffix=".pdf") as tf:
            tf.write(get_s3_bytes(pdf_s3, s3_path))
            tf.flush()

            reader = PdfReader(tf.name)
            return reader.get_num_pages()
    except Exception as ex:
        print(f"Warning, could not add {s3_path} due to {ex}")

    return None


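# Builds the batch-inference queries needed for this PDF in the current round: pages that already have a
# usable result are skipped, pages that errored previously are retried with up to 3 duplicate requests,
# and an empty list means every page of the PDF already has a usable result.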
def build_pdf_queries(s3_workspace: str, pdf: DatabaseManager.PDFRecord, cur_round: int) -> list[dict]:
    db = DatabaseManager(s3_workspace)
    existing_pages = db.get_index_entries(pdf.s3_path)
    new_queries = []

    # Shortcut out of downloading the actual PDF if every page already has a usable result
    if set(page.page_num for page in existing_pages if page.is_usable()) == set(range(1, pdf.num_pages + 1)):
        return []

    try:
        with tempfile.NamedTemporaryFile("wb+", suffix=".pdf") as tf:
            tf.write(get_s3_bytes(pdf_s3, pdf.s3_path))
            tf.flush()

            for target_page_num in range(1, pdf.num_pages + 1):
                # Skip pages that already have a usable (error-free, finish_reason == "stop") result
                if any(page.is_usable() and page.page_num == target_page_num for page in existing_pages):
                    continue

                has_errored_previously = sum(page.page_num == target_page_num for page in existing_pages)

                if has_errored_previously:
                    # Retry the page up to 3 times
                    for _ in range(3):
                        new_queries.append({**build_page_query(tf.name, pdf.s3_path, target_page_num), "round": cur_round})

                    # Optionally, more complex retry logic could be implemented here
                else:
                    new_queries.append({**build_page_query(tf.name, pdf.s3_path, target_page_num), "round": cur_round})
    except Exception as ex:
        print(f"Warning, could not get batch inference lines for {pdf.s3_path} due to {ex}")

    return new_queries


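# Assembles a Dolma-format document for a PDF from its per-page inference results, concatenating the
# natural_text of each page in order and recording per-page character spans. Returns None if any page
# is still missing a usable result.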
def build_dolma_doc(s3_workspace: str, pdf: DatabaseManager.PDFRecord) -> Optional[dict]:
    db = DatabaseManager(s3_workspace)
    existing_pages = db.get_index_entries(pdf.s3_path)
    document_text = ""
    last_page_start_index = 0
    pdf_page_spans = []

    for target_page_num in range(1, pdf.num_pages + 1):
        target_pages = [page for page in existing_pages if page.is_usable() and page.page_num == target_page_num]

        if len(target_pages) == 0:
            return None

        target_page = target_pages[0]

        target_row = get_s3_bytes(workspace_s3, target_page.inference_s3_path,
                                  start_index=target_page.start_index,
                                  end_index=target_page.start_index + target_page.length - 1)

        target_data = json.loads(target_row.decode("utf-8"))
        target_output = json.loads(target_data["outputs"][0]["text"])

        if target_output["natural_text"] is not None:
            document_text += target_output["natural_text"] + "\n"

        pdf_page_spans.append([last_page_start_index, len(document_text), target_page_num])
        last_page_start_index = len(document_text)

    metadata = {
        "Source-File": pdf.s3_path,
        "pdf-total-pages": pdf.num_pages,
    }

    id_ = hashlib.sha1(document_text.encode()).hexdigest()

    dolma_doc = {
        "id": id_,
        "text": document_text,
        "source": "pdelfin",
        "added": datetime.datetime.now().strftime("%Y-%m-%d"),
        "created": datetime.datetime.now().strftime("%Y-%m-%d"),
        "metadata": metadata,
        "attributes": {
            "pdf_page_numbers": pdf_page_spans
        }
    }

    return dolma_doc


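# after_flush callback for the output writer: marks the source PDF of every written Dolma doc as completed.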
def mark_pdfs_done(s3_workspace: str, dolma_doc_lines: list[str]):
    db = DatabaseManager(s3_workspace)

    for line in dolma_doc_lines:
        db.update_pdf_status(json.loads(line)["metadata"]["Source-File"], "completed")


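# Determines the current inference round by listing round_N folders under inference_inputs/ in the workspace
# and returning max(N) + 1, or 0 if no rounds exist yet.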
def get_current_round(s3_workspace: str) -> int:
    path = s3_workspace[5:]
    bucket, _, prefix = path.partition('/')
    inference_inputs_prefix = posixpath.join(prefix, 'inference_inputs/')

    paginator = workspace_s3.get_paginator('list_objects_v2')
    page_iterator = paginator.paginate(Bucket=bucket, Prefix=inference_inputs_prefix, Delimiter='/')

    round_numbers = []
    for page in page_iterator:
        for common_prefix in page.get('CommonPrefixes', []):
            round_prefix = common_prefix.get('Prefix')

            # Extract 'round_X' from the prefix
            round_dir = posixpath.basename(posixpath.dirname(round_prefix))
            if round_dir.startswith('round_'):
                try:
                    round_num = int(round_dir[len('round_'):])
                    round_numbers.append(round_num)
                except ValueError:
                    pass

    if round_numbers:
        current_round = max(round_numbers) + 1
    else:
        current_round = 0

    return current_round


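# Pipeline entry point: (1) optionally add new PDFs to the local index, (2) index any new batch inference
# outputs found in the workspace, (3) write a new round of inference queries for pages without a usable
# result, (4) assemble finished PDFs into Dolma documents, and (5) print per-round statistics.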
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Manager for running millions of PDFs through a batch inference pipeline')
    parser.add_argument('workspace', help='The S3 path where work will be done (e.g., s3://bucket/prefix/)')
    parser.add_argument('--add_pdfs', help='Path to add pdfs stored in s3 to the workspace, can be a glob path s3://bucket/prefix/*.pdf or a path to a file containing a list of pdf paths', default=None)
    parser.add_argument('--workspace_profile', help='S3 configuration profile for accessing the workspace', default=None)
    parser.add_argument('--pdf_profile', help='S3 configuration profile for accessing the raw pdf documents', default=None)
    parser.add_argument('--max_size_mb', type=int, default=250, help='Max file size in MB')
    args = parser.parse_args()

    if args.workspace_profile:
        workspace_session = boto3.Session(profile_name=args.workspace_profile)
        workspace_s3 = workspace_session.client("s3")

    if args.pdf_profile:
        pdf_session = boto3.Session(profile_name=args.pdf_profile)
        pdf_s3 = pdf_session.client("s3")

    db = DatabaseManager(args.workspace)
    print(f"Loaded db at {db.db_path}")

    current_round = get_current_round(args.workspace)
    print(f"Current round is {current_round}\n")

    # One shared executor to rule them all
    executor = ProcessPoolExecutor()

    # If you have new PDFs, step one is to add them to the list
    if args.add_pdfs:
        if args.add_pdfs.startswith("s3://"):
            print(f"Querying all PDFs at {args.add_pdfs}")

            all_pdfs = expand_s3_glob(pdf_s3, args.add_pdfs)
            print(f"Found {len(all_pdfs):,} total pdf paths")
        elif os.path.exists(args.add_pdfs):
            with open(args.add_pdfs, "r") as f:
                all_pdfs = [line.strip() for line in f.readlines() if len(line.strip()) > 0]
        else:
            raise ValueError("add_pdfs argument needs to be either an s3 glob search path, or a local file containing pdf paths (one per line)")

        all_pdfs = [pdf for pdf in all_pdfs if not db.pdf_exists(pdf)]
        print(f"Need to import {len(all_pdfs):,} total new pdf paths")

        future_to_path = {executor.submit(get_pdf_num_pages, s3_path): s3_path for s3_path in all_pdfs}

        for future in tqdm(as_completed(future_to_path), total=len(future_to_path)):
            s3_path = future_to_path[future]
            num_pages = future.result()

            if num_pages and not db.pdf_exists(s3_path):
                db.add_pdf(s3_path, num_pages, "pending")

        print("\n")

    # Now build an index of all the pages that were processed within the workspace so far
    print("Indexing all batch inference sent to this workspace")
    inference_output_paths = expand_s3_glob(workspace_s3, f"{args.workspace}/inference_outputs/*.jsonl")

    inference_output_paths = [
        (s3_path, etag) for s3_path, etag in inference_output_paths.items()
        if not db.is_file_processed(s3_path, etag)
    ]

    print(f"Found {len(inference_output_paths):,} new batch inference results to index")

    future_to_path = {executor.submit(process_jsonl_content, s3_path): (s3_path, etag) for s3_path, etag in inference_output_paths}

    for future in tqdm(as_completed(future_to_path), total=len(future_to_path)):
        s3_path, etag = future_to_path[future]

        try:
            inference_records = future.result()
            db.add_index_entries(inference_records)
            db.update_processed_file(s3_path, etag=etag)
        except urllib3.exceptions.SSLError:
            print(f"Cannot load inference file {s3_path} due to SSL error, will retry another time")

    # Now query each PDF. If all of its pages are present with error == null and finish_reason == "stop",
    # it can be assembled into a Dolma document and output. Otherwise, write a new batch of inference
    # requests for the pages that are missing or errored.
    if db.get_last_indexed_round() < current_round - 1:
        print(f"WARNING: No new batch inference results found, you need to run batch inference on {args.workspace}/inference_inputs/round_{current_round - 1}")
        potentially_done_pdfs = db.get_pdfs_by_status("pending")
    else:
        print("\nCreating batch inference files for new PDFs")
        future_to_path = {executor.submit(build_pdf_queries, args.workspace, pdf, current_round): pdf for pdf in db.get_pdfs_by_status("pending")}
        potentially_done_pdfs = []
        lines_written = 0
        new_inference_writer = BatchWriter(f"{args.workspace}/inference_inputs/round_{current_round}", args.max_size_mb)

        for future in tqdm(as_completed(future_to_path), total=len(future_to_path)):
            pdf = future_to_path[future]
            inference_lines = future.result()

            if len(inference_lines) == 0:
                potentially_done_pdfs.append(pdf)

            for line in inference_lines:
                lines_written += 1

                if line is not None:
                    new_inference_writer.write_line(json.dumps(line))

        new_inference_writer.close()

        if lines_written > 0:
            print(f"Added {lines_written:,} new batch inference requests")

    # Now, finally, assemble any potentially done docs into dolma documents
    print(f"\nAssembling potentially finished PDFs into Dolma documents at {args.workspace}/output")
    future_to_path = {executor.submit(build_dolma_doc, args.workspace, pdf): pdf for pdf in potentially_done_pdfs}
    new_output_writer = BatchWriter(f"{args.workspace}/output", args.max_size_mb, after_flush=partial(mark_pdfs_done, args.workspace))

    for future in tqdm(as_completed(future_to_path), total=len(future_to_path)):
        pdf = future_to_path[future]
        dolma_doc = future.result()

        if dolma_doc is not None:
            new_output_writer.write_line(json.dumps(dolma_doc))

    new_output_writer.close()

    print("\nFinal statistics:")

    # Output the number of PDFs in each status "pending" and "completed"
    pending_pdfs = db.get_pdfs_by_status("pending")
    completed_pdfs = db.get_pdfs_by_status("completed")
    print(f"Pending PDFs: {len(pending_pdfs):,} ({sum(doc.num_pages for doc in pending_pdfs):,} pages)")
    print(f"Completed PDFs: {len(completed_pdfs):,} ({sum(doc.num_pages for doc in completed_pdfs):,} pages)")

    # For each round, outputs a report of how many pages were processed, how many had errors, and a breakdown by (error, finish_reason)
    total_rounds = db.get_last_indexed_round() + 1

    for round_num in range(total_rounds):
        db.cursor.execute("""
            SELECT COUNT(*), error, finish_reason
            FROM page_results
            WHERE round = ?
            GROUP BY error, finish_reason
        """, (round_num,))
        results = db.cursor.fetchall()
        total_pages = sum(count for count, _, _ in results)

        print(f"\nInference Round {round_num} - {total_pages:,} pages processed:")
        for count, error, finish_reason in results:
            error_str = error if error is not None else "None"
            print(f"  (error: {error_str}, finish_reason: {finish_reason}) -> {count:,} pages")

    print("\nWork finished, waiting for all workers to finish cleaning up")
    executor.shutdown(wait=True)

    db.close()