import os
import hashlib
import boto3
import sqlite3
import json
import argparse
import glob
import tempfile
import datetime
import posixpath
import smart_open

from dataclasses import dataclass
from pypdf import PdfReader
from tqdm import tqdm
from typing import Optional, List, Tuple, Dict
from urllib.parse import urlparse
from concurrent.futures import ProcessPoolExecutor, as_completed

from pdelfin.data.renderpdf import render_pdf_to_base64png
from pdelfin.prompts import build_finetuning_prompt
from pdelfin.prompts.anchor import get_anchor_text

# Global s3 client for the whole script, feel free to adjust params if you need it
s3 = boto3.client('s3')
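# For example, if you hit throttling while listing or fetching many objects, you could construct the
# client with explicit retry settings instead (a sketch, not required for the script to work):
#
#   import botocore.config
#   s3 = boto3.client('s3', config=botocore.config.Config(retries={'max_attempts': 10, 'mode': 'adaptive'}))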


class DatabaseManager:
    @dataclass(frozen=True)
    class BatchInferenceRecord:
        inference_s3_path: str
        pdf_s3_path: str
        page_num: int  # 1 indexed!
        start_index: int
        length: int
        finish_reason: str
        error: Optional[str]

        def is_usable(self):
            return self.error is None and self.finish_reason == "stop"

    @dataclass(frozen=True)
    class PDFRecord:
        s3_path: str
        num_pages: int
        status: str

    def __init__(self, s3_workspace: str):
        cache_key = hashlib.sha256(s3_workspace.strip().lower().encode('utf-8')).hexdigest()
        home_cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'pdelfin', cache_key)
        os.makedirs(home_cache_dir, exist_ok=True)
        self.db_path = os.path.join(home_cache_dir, 'index.db')

        self.conn = sqlite3.connect(self.db_path)
        self.cursor = self.conn.cursor()
        self._initialize_tables()

    def _initialize_tables(self):
        self.cursor.execute("""
            CREATE TABLE IF NOT EXISTS page_results (
                inference_s3_path TEXT,
                pdf_s3_path TEXT,
                page_num INTEGER,
                start_index BIGINT,
                length BIGINT,
                finish_reason TEXT,
                error TEXT
            )
        """)

        self.cursor.execute("""
            CREATE INDEX IF NOT EXISTS idx_path ON page_results(pdf_s3_path)
        """)

        self.cursor.execute("""
            CREATE TABLE IF NOT EXISTS pdfs (
                s3_path TEXT PRIMARY KEY,
                num_pages INTEGER,
                status TEXT DEFAULT 'pending'
            )
        """)

        self.cursor.execute("""
            CREATE TABLE IF NOT EXISTS processed_files (
                s3_path TEXT PRIMARY KEY,
                etag TEXT
            )
        """)

        # Generic metadata such as current round
        self.cursor.execute("""
            CREATE TABLE IF NOT EXISTS metadata (
                key TEXT PRIMARY KEY,
                value TEXT
            )
        """)

        self.conn.commit()
    def get_metadata(self, key: str) -> Optional[str]:
        self.cursor.execute("SELECT value FROM metadata WHERE key=?", (key,))
        result = self.cursor.fetchone()
        return result[0] if result else None

    def set_metadata(self, key: str, value: str) -> None:
        self.cursor.execute("""
            INSERT INTO metadata (key, value)
            VALUES (?, ?)
            ON CONFLICT (key) DO UPDATE SET value = excluded.value
        """, (key, value))
        self.conn.commit()

    def get_current_round(self):
        round_value = self.get_metadata("round")
        return int(round_value) if round_value else 0

    def is_file_processed(self, s3_path, etag):
        self.cursor.execute("SELECT etag FROM processed_files WHERE s3_path = ?", (s3_path,))
        result = self.cursor.fetchone()
        return result is not None and result[0] == etag

    def add_index_entries(self, index_entries: List[BatchInferenceRecord]):
        if index_entries:
            self.cursor.executemany("""
                INSERT INTO page_results (inference_s3_path, pdf_s3_path, page_num, start_index, length, finish_reason, error)
                VALUES (?, ?, ?, ?, ?, ?, ?)
            """, [(entry.inference_s3_path, entry.pdf_s3_path, entry.page_num, entry.start_index, entry.length, entry.finish_reason, entry.error) for entry in index_entries])

            self.conn.commit()

    def get_index_entries(self, s3_path: str) -> List[BatchInferenceRecord]:
        self.cursor.execute("""
            SELECT inference_s3_path, pdf_s3_path, page_num, start_index, length, finish_reason, error
            FROM page_results
            WHERE pdf_s3_path = ?
            ORDER BY inference_s3_path DESC, start_index ASC, page_num ASC
        """, (s3_path,))

        rows = self.cursor.fetchall()

        return [
            self.BatchInferenceRecord(
                inference_s3_path=row[0],
                pdf_s3_path=row[1],
                page_num=row[2],
                start_index=row[3],
                length=row[4],
                finish_reason=row[5],
                error=row[6],
            )
            for row in rows
        ]

    def update_processed_file(self, s3_path, etag):
        self.cursor.execute("""
            INSERT INTO processed_files (s3_path, etag)
            VALUES (?, ?)
            ON CONFLICT (s3_path) DO UPDATE SET etag = excluded.etag
        """, (s3_path, etag))
        self.conn.commit()

    def pdf_exists(self, s3_path: str) -> bool:
        self.cursor.execute("SELECT 1 FROM pdfs WHERE s3_path = ?", (s3_path,))
        return self.cursor.fetchone() is not None

    def add_pdf(self, s3_path: str, num_pages: int, status: str = 'pending') -> None:
        try:
            self.cursor.execute("""
                INSERT INTO pdfs (s3_path, num_pages, status)
                VALUES (?, ?, ?)
            """, (s3_path, num_pages, status))
            self.conn.commit()
        except sqlite3.IntegrityError:
            print(f"PDF with s3_path '{s3_path}' already exists.")

    def get_pdf(self, s3_path: str) -> Optional[PDFRecord]:
        self.cursor.execute("""
            SELECT s3_path, num_pages, status
            FROM pdfs
            WHERE s3_path = ?
        """, (s3_path,))
        row = self.cursor.fetchone()
        if row:
            return self.PDFRecord(
                s3_path=row[0],
                num_pages=row[1],
                status=row[2],
            )
        return None

    def get_pdfs_by_status(self, status: str) -> List[PDFRecord]:
        self.cursor.execute("""
            SELECT s3_path, num_pages, status
            FROM pdfs
            WHERE status = ?
            ORDER BY s3_path DESC
        """, (status,))
        rows = self.cursor.fetchall()
        return [
            self.PDFRecord(
                s3_path=row[0],
                num_pages=row[1],
                status=row[2],
            )
            for row in rows
        ]

    def close(self):
        self.conn.close()
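
# A minimal usage sketch of DatabaseManager (the workspace and PDF paths below are just examples):
#
#   db = DatabaseManager("s3://my-bucket/my-workspace")
#   db.add_pdf("s3://my-bucket/pdfs/example.pdf", num_pages=3, status="pending")
#   for record in db.get_pdfs_by_status("pending"):
#       print(record.s3_path, record.num_pages)
#   db.close()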


# Writes batches of lines out to a set of files, keeping each file below some maximum size
class BatchWriter:
    def __init__(self, output_prefix: str, max_size_mb: int = 250):
        self.output_prefix = output_prefix
        self.max_size = max_size_mb * 1024 * 1024  # Convert MB to bytes
        self.batch = []
        self.batch_size = 0

        parsed = urlparse(output_prefix)
        self.is_s3 = parsed.scheme in ('s3', 's3a', 's3n')

        if not self.is_s3:
            os.makedirs(output_prefix, exist_ok=True)

    def _compute_hash(self, content: str) -> str:
        """Compute a 20-character SHA1 hash of the given content."""
        sha1 = hashlib.sha1()
        sha1.update(content.encode('utf-8'))
        return sha1.hexdigest()[:20]

    def _get_output_path(self, hash_str: str) -> str:
        """Generate the full output path with hash in the filename."""
        parsed = urlparse(self.output_prefix)
        if self.is_s3:
            bucket = parsed.netloc
            key = parsed.path.lstrip('/')
            if key and not key.endswith('/'):
                key += '/'
            full_key = posixpath.join(key, f"output_{hash_str}.jsonl")
            return f"s3://{bucket}/{full_key}"
        else:
            filename = f"output_{hash_str}.jsonl"
            return os.path.join(self.output_prefix, filename)

    def write_line(self, line: str):
        line_size = len(line.encode('utf-8')) + 1  # +1 for newline

        if self.batch_size + line_size > self.max_size:
            self._write_batch()

        self.batch.append(line)
        self.batch_size += line_size

    def _write_batch(self):
        if not self.batch:
            return

        batch_content = "\n".join(self.batch) + "\n"
        hash_str = self._compute_hash(batch_content)
        output_path = self._get_output_path(hash_str)

        with smart_open.open(output_path, 'w') as f_out:
            f_out.write(batch_content)

        print(f"Wrote batch to {output_path}")

        self.batch = []
        self.batch_size = 0

    def close(self):
        self._write_batch()
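
# A minimal usage sketch of BatchWriter (the prefix below is just an example):
#
#   writer = BatchWriter("s3://my-bucket/my-workspace/output", max_size_mb=250)
#   writer.write_line(json.dumps({"id": "example"}))
#   writer.close()  # flushes the final partial batch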


def parse_s3_path(s3_path):
    if not s3_path.startswith('s3://'):
        raise ValueError('s3_path must start with s3://')
    path = s3_path[5:]
    bucket, _, prefix = path.partition('/')
    return bucket, prefix
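
# For example (hypothetical path): parse_s3_path("s3://my-bucket/some/prefix/doc.pdf")
# returns ("my-bucket", "some/prefix/doc.pdf")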


def expand_s3_glob(s3_glob: str) -> Dict[str, str]:
    parsed = urlparse(s3_glob)
    bucket_name = parsed.netloc
    prefix = os.path.dirname(parsed.path.lstrip('/')).rstrip('/') + "/"
    pattern = os.path.basename(parsed.path)

    paginator = s3.get_paginator('list_objects_v2')
    page_iterator = paginator.paginate(Bucket=bucket_name, Prefix=prefix)

    matched_files = {}
    for page in page_iterator:
        for obj in page.get('Contents', []):
            key = obj['Key']
            if glob.fnmatch.fnmatch(key, posixpath.join(prefix, pattern)):
                matched_files[f"s3://{bucket_name}/{key}"] = obj['ETag'].strip('"')

    return matched_files
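
# For example (hypothetical glob): expand_s3_glob("s3://my-bucket/pdfs/*.pdf") returns a dict mapping
# each matching object, e.g. "s3://my-bucket/pdfs/report.pdf", to its ETag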


def build_page_query(local_pdf_path: str, pretty_pdf_path: str, page: int) -> dict:
    image_base64 = render_pdf_to_base64png(local_pdf_path, page, 1024)
    anchor_text = get_anchor_text(local_pdf_path, page, pdf_engine="pdfreport")

    return {
        "custom_id": f"{pretty_pdf_path}-{page}",
        "chat_messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": build_finetuning_prompt(anchor_text)},
                    {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}}
                ],
            }
        ],
        "temperature": 0.8,
        "max_tokens": 6000,
    }


def get_s3_bytes(s3_path: str, start_index: Optional[int] = None, end_index: Optional[int] = None) -> bytes:
    bucket, key = parse_s3_path(s3_path)

    # Build the range header if start_index and/or end_index are specified
    range_header = None
    if start_index is not None or end_index is not None:
        range_value = f"bytes={start_index or 0}-"
        if end_index is not None:
            range_value += str(end_index)
        range_header = {'Range': range_value}

    if range_header:
        obj = s3.get_object(Bucket=bucket, Key=key, Range=range_header['Range'])
    else:
        obj = s3.get_object(Bucket=bucket, Key=key)

    return obj['Body'].read()
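
# Note that HTTP byte ranges are inclusive on both ends, so (hypothetically)
# get_s3_bytes("s3://my-bucket/file.jsonl", start_index=0, end_index=99) returns the first 100 bytes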


def parse_custom_id(custom_id: str) -> Tuple[str, int]:
    s3_path = custom_id[:custom_id.rindex("-")]
    page_num = int(custom_id[custom_id.rindex("-") + 1:])
    return s3_path, page_num
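
# For example (hypothetical id): parse_custom_id("s3://my-bucket/pdfs/report.pdf-3") returns
# ("s3://my-bucket/pdfs/report.pdf", 3), reversing the custom_id format built in build_page_query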


def process_jsonl_content(inference_s3_path: str) -> List[DatabaseManager.BatchInferenceRecord]:
    content = get_s3_bytes(inference_s3_path).decode("utf-8")

    start_index = 0
    index_entries = []

    lines = content.splitlines(keepends=True)
    for line in lines:
        line_length = len(line)

        try:
            data = json.loads(line)
            pdf_s3_path, page_num = parse_custom_id(data["custom_id"])

            assert "outputs" in data and len(data["outputs"]) > 0, "No outputs from model detected"

            index_entries.append(DatabaseManager.BatchInferenceRecord(
                inference_s3_path=inference_s3_path,
                pdf_s3_path=pdf_s3_path,
                page_num=page_num,
                start_index=start_index,
                length=line_length,
                finish_reason=data["outputs"][0]["finish_reason"],
                error=data.get("completion_error", None),
            ))
        except json.JSONDecodeError:
            pass  # Handle JSON decode errors if necessary
        except Exception as e:
            print(f"Error processing line: {e}")

        start_index += line_length

    return index_entries
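
# Each JSONL line above is assumed to look roughly like (hypothetical example; only the fields read
# above are relied on):
#   {"custom_id": "s3://my-bucket/pdfs/report.pdf-3", "outputs": [{"finish_reason": "stop", ...}], "completion_error": null}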


def get_pdf_num_pages(s3_path: str) -> Optional[int]:
    try:
        with tempfile.NamedTemporaryFile("wb+", suffix=".pdf") as tf:
            tf.write(get_s3_bytes(s3_path))
            tf.flush()

            reader = PdfReader(tf.name)
            return reader.get_num_pages()
    except Exception as ex:
        print(f"Warning, could not add {s3_path} due to {ex}")

    return None


def build_pdf_queries(s3_workspace: str, pdf: DatabaseManager.PDFRecord) -> list[dict]:
    db = DatabaseManager(s3_workspace)
    existing_pages = db.get_index_entries(pdf.s3_path)
    new_queries = []

    # Shortcut out of downloading the actual PDF
    if set(page.page_num for page in existing_pages if page.is_usable()) == set(range(1, pdf.num_pages + 1)):
        return []

    try:
        with tempfile.NamedTemporaryFile("wb+", suffix=".pdf") as tf:
            tf.write(get_s3_bytes(pdf.s3_path))
            tf.flush()

            for target_page_num in range(1, pdf.num_pages + 1):
                # Is there an existing page that has no error
                if any(page.is_usable() and page.page_num == target_page_num for page in existing_pages):
                    continue

                # TODO: Later, you may want to retry with different sampling parameters or do something else
                new_queries.append(build_page_query(tf.name, pdf.s3_path, target_page_num))
    except Exception as ex:
        print(f"Warning, could not get batch inferences lines for {pdf.s3_path} due to {ex}")

    return new_queries


def build_dolma_doc(s3_workspace: str, pdf: DatabaseManager.PDFRecord) -> dict:
    db = DatabaseManager(s3_workspace)
    existing_pages = db.get_index_entries(pdf.s3_path)
    document_text = ""
    last_page_start_index = 0
    pdf_page_spans = []

    for target_page_num in range(1, pdf.num_pages + 1):
        target_page = next(page for page in existing_pages if page.is_usable() and page.page_num == target_page_num)

        # Pull the exact JSONL line for this page back out of the batch inference output file
        target_row = get_s3_bytes(target_page.inference_s3_path,
                                  start_index=target_page.start_index,
                                  end_index=target_page.start_index + target_page.length - 1)

        target_data = json.loads(target_row.decode("utf-8"))

        document_text += target_data["natural_text"] + "\n"
        pdf_page_spans.append([last_page_start_index, len(document_text), target_page_num])
        last_page_start_index = len(document_text)

    metadata = {
        "Source-File": pdf.s3_path,
        "pdf-total-pages": pdf.num_pages,
    }

    id_ = hashlib.sha1(document_text.encode()).hexdigest()

    dolma_doc = {
        "id": id_,
        "text": document_text,
        "source": "pdelfin",
        "added": datetime.datetime.now().strftime("%Y-%m-%d"),
        "created": datetime.datetime.now().strftime("%Y-%m-%d"),
        "metadata": metadata,
        "attributes": {
            "pdf_page_numbers": pdf_page_spans
        }
    }

    return dolma_doc
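
# In the dolma doc built above, "attributes"/"pdf_page_numbers" is a list of
# [start_char_index, end_char_index, page_num] spans into "text", one per PDF page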


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Manager for running millions of PDFs through a batch inference pipeline')
    parser.add_argument('workspace', help='The S3 path where work will be done (e.g., s3://bucket/prefix/)')
    parser.add_argument('--add_pdfs', help='Glob path to add PDFs (s3) to the workspace', default=None)
    parser.add_argument('--max_size_mb', type=int, default=250, help='Max file size in MB')
    args = parser.parse_args()

    db = DatabaseManager(args.workspace)
    print(f"Loaded db at {db.db_path}")
    print(f"Current round is {db.get_current_round()}\n")

    # One shared executor to rule them all
    executor = ProcessPoolExecutor()

    # If you have new PDFs, step one is to add them to the list
    if args.add_pdfs:
        assert args.add_pdfs.startswith("s3://"), "PDFs must live on s3"

        print(f"Querying all PDFs at {args.add_pdfs}")

        all_pdfs = expand_s3_glob(args.add_pdfs)
        print(f"Found {len(all_pdfs)} total pdf paths")

        all_pdfs = [pdf for pdf in all_pdfs if not db.pdf_exists(pdf)]
        print(f"Need to import {len(all_pdfs)} total new pdf paths")

        future_to_path = {executor.submit(get_pdf_num_pages, s3_path): s3_path for s3_path in all_pdfs}
        for future in tqdm(as_completed(future_to_path), total=len(future_to_path)):
            s3_path = future_to_path[future]
            num_pages = future.result()
            if num_pages and not db.pdf_exists(s3_path):
                db.add_pdf(s3_path, num_pages, "pending")

        print("\n")

    # Now build an index of all the pages that were processed within the workspace so far
    print("Indexing all batch inference sent to this workspace")

    inference_output_paths = expand_s3_glob(f"{args.workspace}/inference_outputs/*.jsonl")

    inference_output_paths = [
        (s3_path, etag) for s3_path, etag in inference_output_paths.items()
        if not db.is_file_processed(s3_path, etag)
    ]

    print(f"Found {len(inference_output_paths)} new batch inference results to index")

    future_to_path = {executor.submit(process_jsonl_content, s3_path): (s3_path, etag) for s3_path, etag in inference_output_paths}
    for future in tqdm(as_completed(future_to_path), total=len(future_to_path)):
        s3_path, etag = future_to_path[future]
        inference_records = future.result()

        db.add_index_entries(inference_records)
        db.update_processed_file(s3_path, etag=etag)

    # Now query each pdf. If you have all of the pages needed (all pages present, error is null, and finish_reason is stop),
    # then you assemble it into a dolma document and output it.
    # If you don't have every page, or if you have pages with errors, then you output a new batch of inference items to use.
    future_to_path = {executor.submit(build_pdf_queries, args.workspace, pdf): pdf for pdf in db.get_pdfs_by_status("pending")}
    potentially_done_pdfs = []
    lines_written = 0
    new_inference_writer = BatchWriter(f"{args.workspace}/inference/round_{db.get_current_round()}", args.max_size_mb)

    for future in tqdm(as_completed(future_to_path), total=len(future_to_path)):
        pdf = future_to_path[future]
        inference_lines = future.result()

        if len(inference_lines) == 0:
            potentially_done_pdfs.append(pdf)

        for line in inference_lines:
            lines_written += 1
            new_inference_writer.write_line(json.dumps(line))

    new_inference_writer.close()

    if lines_written > 0:
        db.set_metadata("round", str(db.get_current_round() + 1))

    # Now, finally, assemble any potentially done docs into dolma documents
    future_to_path = {executor.submit(build_dolma_doc, args.workspace, pdf): pdf for pdf in potentially_done_pdfs}
    new_output_writer = BatchWriter(f"{args.workspace}/output", args.max_size_mb)

    for future in tqdm(as_completed(future_to_path), total=len(future_to_path)):
        pdf = future_to_path[future]
        dolma_doc = future.result()

        new_output_writer.write_line(json.dumps(dolma_doc))

    new_output_writer.close()

    # TODO
    # 2. Have a way to apply basic spam + language filter if you can during add pdfs step
    # 3. For retrying, make it so you retry several times with different sampling parameters