2024-09-25 20:27:49 +00:00
import argparse
import json
2025-01-29 15:25:10 -08:00
import logging
import os
2024-09-30 19:54:30 +00:00
import re
2024-09-25 20:27:49 +00:00
import sys
2025-01-10 19:38:42 +00:00
import tempfile
2025-01-29 15:25:10 -08:00
from concurrent . futures import ProcessPoolExecutor , as_completed
from pathlib import Path
2024-09-25 20:27:49 +00:00
2025-01-10 19:38:42 +00:00
import boto3
2024-09-30 19:54:30 +00:00
2024-10-08 18:22:56 +00:00
# Import Plotly for plotting
import plotly . express as px
2025-01-29 15:25:10 -08:00
import smart_open
from olmocr . data . renderpdf import render_pdf_to_base64png
from olmocr . prompts import build_finetuning_prompt
from olmocr . prompts . anchor import get_anchor_text
2024-10-08 18:22:56 +00:00
2024-09-30 19:54:30 +00:00
2024-09-25 20:27:49 +00:00
def setup_logging():
    """Configure root logging: INFO level, timestamped format, stdout handler."""
    stdout_handler = logging.StreamHandler(sys.stdout)
    logging.basicConfig(
        level=logging.INFO,
        format="[%(asctime)s] %(levelname)s: %(message)s",
        handlers=[stdout_handler],
    )
2024-09-25 20:27:49 +00:00
2024-09-30 19:54:30 +00:00
2024-09-30 20:06:45 +00:00
def is_s3_path(path):
    """Return True when *path* (str or os.PathLike) refers to an S3 URL."""
    return f"{path}".startswith("s3://")
2024-09-30 20:06:45 +00:00
2025-01-10 19:38:42 +00:00
def download_pdf_from_s3(s3_path: str, pdf_profile: str) -> str:
    """
    Fetch a PDF from S3 into a temporary local file.

    Args:
        s3_path (str): S3 path in the format s3://bucket/key.
        pdf_profile (str): boto3 profile name; falsy values fall back to the default session.

    Returns:
        str: Path to the downloaded PDF file in the local filesystem.
    """
    # Split "s3://bucket_name/some/folder/file.pdf" into bucket and key
    remainder = s3_path.split("s3://", 1)[1]
    bucket_name, key = remainder.split("/", 1)

    # Honor the requested boto3 profile when one was given
    if pdf_profile:
        session = boto3.Session(profile_name=pdf_profile)
    else:
        session = boto3.Session()
    s3_client = session.client("s3")

    # Reserve a temp file path; close it right away so boto3 can write to it
    tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".pdf")
    tmp_file.close()
    local_path = tmp_file.name

    logging.info(f"Downloading PDF from {s3_path} to {local_path} using profile {pdf_profile}")
    s3_client.download_file(bucket_name, key, local_path)
    return local_path
2024-09-25 20:27:49 +00:00
def transform_json_object(obj):
    """
    Extract and rename the relevant fields of one batch-request JSON record.

    Args:
        obj (dict): Original JSON object with "custom_id" and a "body" sub-dict.

    Returns:
        dict or None: Transformed JSON object, or None when a required key is missing.
    """
    try:
        body = obj["body"]
        return {
            "custom_id": obj["custom_id"],
            "chat_messages": body["messages"],
            "temperature": body["temperature"],
            "max_tokens": body["max_tokens"],
        }
    except KeyError as e:
        logging.error(f"Missing key {e} in object: {obj.get('custom_id', 'unknown')}")
        return None
2024-09-30 19:54:30 +00:00
2025-01-10 19:38:42 +00:00
def process_file(input_file: str, output_file: str, rewrite_prompt_str: bool, pdf_profile: str):
    """
    Process a single JSONL file: read, transform, and write to output.

    Args:
        input_file (str): Path or URL to the input JSONL file.
        output_file (str): Path or URL to the output JSONL file.
        rewrite_prompt_str (bool): Flag to rewrite the prompt string.
        pdf_profile (str): Boto3 profile to use when fetching PDFs from S3.

    Returns:
        list: Lengths of the prompts written to the output file ([] on a file-level failure).
    """
    processed_count = 0
    error_count = 0
    prompt_lengths = []
    try:
        with smart_open.open(input_file, "r", encoding="utf-8") as infile, smart_open.open(output_file, "w", encoding="utf-8") as outfile:
            for line_number, line in enumerate(infile, 1):
                line = line.strip()
                if not line:
                    continue  # Skip empty lines
                try:
                    obj = json.loads(line)
                except json.JSONDecodeError as e:
                    logging.error(f"JSON decode error in file {input_file} at line {line_number}: {e}")
                    error_count += 1
                    continue

                transformed = transform_json_object(obj)

                if transformed is not None and rewrite_prompt_str:
                    # We look for RAW_TEXT_START ... RAW_TEXT_END in the existing content
                    pattern = r"RAW_TEXT_START\s*\n(.*?)\nRAW_TEXT_END"
                    match = re.search(pattern, transformed["chat_messages"][0]["content"][0]["text"], re.DOTALL)
                    if match:
                        # We found raw page text, but we'll attempt to regenerate it
                        goldkey = obj["custom_id"]
                        # goldkey might look like: "s3://bucket/path/to/file.pdf-23"
                        # s3_path = everything up to the last dash; page = everything after it
                        try:
                            s3_path = goldkey[: goldkey.rindex("-")]
                            page = int(goldkey[goldkey.rindex("-") + 1 :])
                        except (ValueError, IndexError) as e:
                            logging.error(f"Could not parse the page number from custom_id {goldkey}: {e}")
                            error_count += 1
                            continue

                        # If the path is an S3 path, download to a local temp file; else assume local
                        if is_s3_path(s3_path):
                            local_pdf_path = download_pdf_from_s3(s3_path, pdf_profile)
                        else:
                            local_pdf_path = s3_path

                        try:
                            # Recalculate the anchor text and re-render the page image
                            raw_page_text = get_anchor_text(local_pdf_path, page, pdf_engine="pdfreport", target_length=6000)
                            image_base64 = render_pdf_to_base64png(local_pdf_path, page, 1024)

                            transformed["chat_messages"][0]["content"][0]["text"] = build_finetuning_prompt(raw_page_text)
                            transformed["chat_messages"][0]["content"][1]["image_url"]["url"] = f"data:image/png;base64,{image_base64}"
                        finally:
                            # Clean up the temp PDF even if regeneration raised,
                            # so a rendering failure cannot leak temp files.
                            if is_s3_path(s3_path):
                                try:
                                    os.remove(local_pdf_path)
                                except OSError as remove_err:
                                    logging.error(f"Failed to remove temporary PDF file {local_pdf_path}: {remove_err}")

                if transformed is not None:
                    prompt_text = transformed["chat_messages"][0]["content"][0]["text"]
                    prompt_length = len(prompt_text)

                    if prompt_length > 6000:
                        # Report through logging (not bare print) so the message
                        # respects the handlers configured in setup_logging()
                        logging.warning(f"{transformed['custom_id']} length {prompt_length}")

                    prompt_lengths.append(prompt_length)
                    outfile.write(json.dumps(transformed) + "\n")
                    processed_count += 1
                else:
                    error_count += 1

        logging.info(f"Processed '{input_file}': {processed_count} records transformed, {error_count} errors.")
        return prompt_lengths
    except Exception as e:
        logging.exception(e)
        logging.error(f"Failed to process file {input_file}: {e}")
        return []
2024-09-25 20:27:49 +00:00
2024-09-30 20:06:45 +00:00
def construct_output_file_path(input_file_path, input_dir, output_dir):
    """
    Given an input file path, input directory, and output directory,
    construct the corresponding output file path.

    Args:
        input_file_path (str): Path to the input file.
        input_dir (str): Path to the input directory; may end in a glob
            component such as "*" or "*.jsonl" for S3 inputs.
        output_dir (str): Path to the output directory.

    Returns:
        str: Path to the output file.
    """
    input_file = Path(input_file_path)

    if str(input_dir).startswith("s3://"):
        # For S3 paths, manually construct the relative path based on the input S3 path
        input_prefix = input_dir.split("s3://")[1]

        # Drop a trailing glob component entirely (e.g. "*" or "*.jsonl").
        # The previous rstrip("*") only removed trailing asterisks, so a
        # pattern like "prefix/*.jsonl" left ".jsonl" in the prefix and
        # produced a wrong relative path.
        head, sep, last = input_prefix.rpartition("/")
        if any(ch in last for ch in "*?["):
            input_prefix = head + sep if sep else ""

        # Remove the 's3://' part from input_file_path and extract the relative part
        input_file_key = input_file_path.split("s3://")[1]
        relative_path = input_file_key[len(input_prefix):].lstrip("/")

        # Construct the output S3 path by appending the relative part to the output S3 directory
        output_file_path = output_dir.rstrip("/") + "/" + relative_path
    else:
        # For local paths, use the existing relative path logic
        input_dir_path = Path(input_dir)
        relative_path = input_file.relative_to(input_dir_path)
        output_file_path = str(Path(output_dir) / relative_path)

    return output_file_path
def list_input_files(input_dir):
    """
    List all JSONL files in the input directory. For S3 locations the glob is
    handled manually by listing objects under the prefix and filtering keys
    against the pattern.

    Args:
        input_dir (str): Path to the input directory or S3 URL.

    Returns:
        list: List of input file paths.
    """
    if not str(input_dir).startswith("s3://"):
        # Local directory: a plain filesystem glob is sufficient
        return [str(path) for path in Path(input_dir).glob("*.jsonl")]

    import fnmatch

    # Parse bucket and prefix
    remainder = input_dir.split("s3://")[1]
    parts = remainder.split("/")
    bucket_name = parts[0]
    path_and_pattern = "/".join(parts[1:])

    # Separate the prefix and pattern
    if "/" in path_and_pattern:
        prefix, _, pattern = path_and_pattern.rpartition("/")
        prefix += "/"
    else:
        prefix = ""
        pattern = path_and_pattern

    # Use a Boto3 session (no specific PDF profile needed here if only listing)
    bucket = boto3.Session().resource("s3").Bucket(bucket_name)

    matches = []
    for obj in bucket.objects.filter(Prefix=prefix):
        if fnmatch.fnmatch(obj.key, f"{prefix}{pattern}"):
            matches.append(f"s3://{bucket_name}/{obj.key}")
    return matches
2024-09-30 20:06:45 +00:00
2024-09-25 20:27:49 +00:00
def main():
    """CLI entry point: transform every JSONL file in parallel and plot prompt lengths."""
    setup_logging()

    parser = argparse.ArgumentParser(description="Transform JSONL files by extracting and renaming specific fields.")
    # NOTE(review): action="store_true" combined with default=True means this
    # flag is effectively always on — confirm whether an off switch is intended.
    parser.add_argument(
        "--rewrite_finetuning_prompt",
        action="store_true",
        default=True,
        help="Rewrite the input prompt from a standard OPENAI instruction format into a finetuned format.",
    )
    parser.add_argument("input_dir", type=str, help="Path to the input directory containing JSONL files. Can be a local path or S3 URL.")
    parser.add_argument("output_dir", type=str, help="Path to the output directory where transformed JSONL files will be saved. Can be a local path or S3 URL.")
    parser.add_argument("--jobs", "-j", type=int, default=20, help="Number of parallel jobs to run (default: 20).")
    parser.add_argument("--pdf_profile", type=str, default=None, help="Boto3 profile to use for downloading PDFs from S3. Defaults to the default session.")
    args = parser.parse_args()

    input_dir = args.input_dir.rstrip("/")
    output_dir = args.output_dir.rstrip("/")
    max_jobs = args.jobs

    # Discover the work to do
    input_files = list_input_files(input_dir)
    if not input_files:
        logging.warning(f"No JSONL files found in '{input_dir}'. Exiting.")
        sys.exit(0)
    logging.info(f"Found {len(input_files)} JSONL files to process.")

    # Pair every input file with its destination path
    tasks = [(src, construct_output_file_path(src, input_dir, output_dir)) for src in input_files]

    # Fan the files out across worker processes
    all_prompt_lengths = []
    with ProcessPoolExecutor(max_workers=max_jobs) as executor:
        future_to_file = {
            executor.submit(process_file, src, dst, args.rewrite_finetuning_prompt, args.pdf_profile): src
            for src, dst in tasks
        }
        for future in as_completed(future_to_file):
            src = future_to_file[future]
            try:
                all_prompt_lengths.extend(future.result())
            except Exception as exc:
                logging.error(f"File {src} generated an exception: {exc}")

    logging.info("All files have been processed.")

    # Plot histogram of prompt lengths
    if not all_prompt_lengths:
        logging.warning("No prompt lengths were collected; histogram will not be generated.")
        return

    fig = px.histogram(all_prompt_lengths, nbins=50, title="Histogram of Prompt Lengths")
    fig.update_xaxes(title="Prompt Length")
    fig.update_yaxes(title="Frequency")
    try:
        fig.write_image("prompt_lengths_histogram.png")
        logging.info("Histogram of prompt lengths has been saved to 'prompt_lengths_histogram.png'.")
    except Exception as e:
        # Static image export needs the optional kaleido backend; fall back to HTML
        logging.error(f"Failed to save the histogram image: {e}")
        logging.error("Please make sure that the 'kaleido' package is installed (pip install -U kaleido).")
        fig.write_html("prompt_lengths_histogram.html")
        logging.info("Histogram of prompt lengths has been saved to 'prompt_lengths_histogram.html'.")
2024-09-30 20:06:45 +00:00
2024-09-25 20:27:49 +00:00
# Script entry point: only run the pipeline when executed directly,
# not when imported (e.g. by the ProcessPoolExecutor worker bootstrap).
if __name__ == "__main__":
    main()