olmocr/pdelfin/data/convertsilver_birr.py

import argparse
import json
import re
from pathlib import Path
from concurrent.futures import ProcessPoolExecutor, as_completed
import sys
import logging
import smart_open
from cached_path import cached_path
from pdelfin.prompts import build_finetuning_prompt
from pdelfin.prompts.anchor import get_anchor_text
# Import Plotly for plotting
import plotly.express as px
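
# This script converts "silver" OpenAI-style batch request JSONL files into the
# flattened chat format used for fine-tuning: each record keeps its custom_id and
# exposes chat_messages, temperature, and max_tokens at the top level. Optionally,
# the prompt's anchor text can be recomputed directly from the source PDF, and a
# histogram of the resulting prompt lengths is written out at the end.
#
# Illustrative invocation (hypothetical paths; assumes S3 credentials are configured):
#   python convertsilver_birr.py --rewrite_finetuning_prompt -j 20 \
#       's3://my-bucket/silver_raw/*' s3://my-bucket/silver_birr
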
def setup_logging():
    """Configure logging for the script."""
    logging.basicConfig(
        level=logging.INFO,
        format='[%(asctime)s] %(levelname)s: %(message)s',
        handlers=[
            logging.StreamHandler(sys.stdout)
        ]
    )


def is_s3_path(path):
    """Check if the given path is an S3 path."""
    return str(path).startswith('s3://')


def transform_json_object(obj):
    """
    Transform a single JSON object by extracting and renaming specific fields.

    Args:
        obj (dict): Original JSON object.

    Returns:
        dict: Transformed JSON object.
    """
    try:
        transformed = {
            "custom_id": obj["custom_id"],
            "chat_messages": obj["body"]["messages"],
            "temperature": obj["body"]["temperature"],
            "max_tokens": obj["body"]["max_tokens"]
        }
        return transformed
    except KeyError as e:
        logging.error(f"Missing key {e} in object: {obj.get('custom_id', 'unknown')}")
        return None
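
# Illustrative sketch of the transformation performed above (field values are hypothetical):
#   input:  {"custom_id": "...", "body": {"messages": [...], "temperature": 0.1, "max_tokens": 3000}}
#   output: {"custom_id": "...", "chat_messages": [...], "temperature": 0.1, "max_tokens": 3000}
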
def process_file(input_file: str, output_file: str, rewrite_prompt_str: bool):
    """
    Process a single JSONL file: read, transform, and write to output.

    Args:
        input_file (str): Path or URL to the input JSONL file.
        output_file (str): Path or URL to the output JSONL file.
        rewrite_prompt_str (bool): Flag to rewrite the prompt string.
    """
    processed_count = 0
    error_count = 0
    prompt_lengths = []

    try:
        with smart_open.open(input_file, 'r', encoding='utf-8') as infile, \
             smart_open.open(output_file, 'w', encoding='utf-8') as outfile:

            for line_number, line in enumerate(infile, 1):
                line = line.strip()
                if not line:
                    continue  # Skip empty lines

                try:
                    obj = json.loads(line)
                except json.JSONDecodeError as e:
                    logging.error(f"JSON decode error in file {input_file} at line {line_number}: {e}")
                    error_count += 1
                    continue

                transformed = transform_json_object(obj)
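
                # When rewrite_prompt_str is set, the raw page text between RAW_TEXT_START and
                # RAW_TEXT_END is discarded and the anchor text is recomputed from the source PDF.
                # The custom_id is assumed to encode the source and page as "<s3_pdf_path>-<page>",
                # e.g. "s3://my-bucket/docs/report.pdf-3" (hypothetical value).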
                if transformed is not None and rewrite_prompt_str:
                    pattern = r"RAW_TEXT_START\s*\n(.*?)\nRAW_TEXT_END"

                    # Use re.DOTALL to ensure that the dot matches newline characters
                    match = re.search(pattern, transformed["chat_messages"][0]["content"][0]["text"], re.DOTALL)

                    if match:
                        raw_page_text = match.group(1).strip()

                        # Ok, now we want to try to see if it's better if we recalculate the anchor text
                        goldkey = obj["custom_id"]
                        s3_path = goldkey[:goldkey.rindex("-")]
                        page = int(goldkey[goldkey.rindex("-") + 1:])

                        # Save the pdf to a temporary cache folder
                        local_pdf_path = cached_path(s3_path, quiet=True)

                        raw_page_text = get_anchor_text(local_pdf_path, page, pdf_engine="pdfreport")
transformed["chat_messages"][0]["content"][0]["text"] = build_finetuning_prompt(raw_page_text)
if transformed is not None:
prompt_text = transformed["chat_messages"][0]["content"][0]["text"]
prompt_length = len(prompt_text)
if prompt_length > 6000:
print(transformed["custom_id"], "length ", prompt_length)
prompt_lengths.append(prompt_length)
outfile.write(json.dumps(transformed) + '\n')
processed_count += 1
else:
error_count += 1
logging.info(f"Processed '{input_file}': {processed_count} records transformed, {error_count} errors.")
return prompt_lengths
except Exception as e:
        logging.exception(e)
        logging.error(f"Failed to process file {input_file}: {e}")
        return []


def construct_output_file_path(input_file_path, input_dir, output_dir):
    """
    Given an input file path, input directory, and output directory,
    construct the corresponding output file path.

    Args:
        input_file_path (str): Path to the input file.
        input_dir (str): Path to the input directory.
        output_dir (str): Path to the output directory.

    Returns:
        str: Path to the output file.
    """
    input_file = Path(input_file_path)

    if is_s3_path(input_dir):
        # For S3 paths, manually construct the relative path based on the input S3 path
        input_prefix = input_dir.split('s3://')[1]
        input_prefix = input_prefix.rstrip('*')  # Remove any glob patterns like *.jsonl

        # Remove the 's3://' part from input_file_path and extract the relative part
        input_file_key = input_file_path.split('s3://')[1]
        relative_path = input_file_key[len(input_prefix):].lstrip('/')

        # Construct the output S3 path by appending the relative part to the output S3 directory
        output_file_path = output_dir.rstrip('/') + '/' + relative_path
    else:
        # For local paths, use the existing relative path logic
        input_dir_path = Path(input_dir)
        relative_path = input_file.relative_to(input_dir_path)
        output_file_path = str(Path(output_dir) / relative_path)

    return output_file_path
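
# Illustrative example of the mapping above (hypothetical paths), assuming the S3
# input_dir was passed with a trailing '*' glob:
#   construct_output_file_path('s3://my-bucket/silver_raw/train-0001.jsonl',
#                              's3://my-bucket/silver_raw/*',
#                              's3://my-bucket/silver_birr')
#   -> 's3://my-bucket/silver_birr/train-0001.jsonl'
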
def list_input_files(input_dir):
    """
    List all JSONL files in the input directory. If input_dir is an S3 path, handle
    globbing manually by listing objects and filtering based on patterns.

    Args:
        input_dir (str): Path to the input directory or S3 URL.

    Returns:
        list: List of input file paths.
    """
    if is_s3_path(input_dir):
        # Use boto3 to list objects in the bucket and filter them against the glob pattern
        import boto3
        import fnmatch

        # Parse bucket and prefix
        bucket_name = input_dir.split('s3://')[1].split('/')[0]
        path_and_pattern = '/'.join(input_dir.split('s3://')[1].split('/')[1:])

        # Separate the prefix and pattern
        if '/' in path_and_pattern:
            prefix = path_and_pattern.rsplit('/', 1)[0] + '/'
            pattern = path_and_pattern.rsplit('/', 1)[1]
        else:
            prefix = ''
            pattern = path_and_pattern

        # Set up S3 resource and bucket
        s3 = boto3.resource('s3')
        bucket = s3.Bucket(bucket_name)

        # Get all objects and filter them manually based on the pattern
        files = []
        for obj in bucket.objects.filter(Prefix=prefix):
            if fnmatch.fnmatch(obj.key, f'{prefix}{pattern}'):
                files.append(f's3://{bucket_name}/{obj.key}')

        return files
    else:
        # Local path handling (with glob pattern)
        input_dir_path = Path(input_dir)
        return [str(p) for p in input_dir_path.glob('*.jsonl')]
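
# Illustrative example for list_input_files above (hypothetical bucket): for an input_dir of
# 's3://my-bucket/silver_raw/*.jsonl', the prefix becomes 'silver_raw/' and the pattern
# '*.jsonl', so every key matching 'silver_raw/*.jsonl' is returned as a full s3:// path.
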
def main():
    setup_logging()

    parser = argparse.ArgumentParser(
        description="Transform JSONL files by extracting and renaming specific fields."
    )
    parser.add_argument(
        '--rewrite_finetuning_prompt',
        action='store_true',
        default=False,
        help="Rewrite the input prompt from the standard OpenAI instruction format into our fine-tuning prompt format."
    )
    parser.add_argument(
        'input_dir',
        type=str,
        help='Path to the input directory containing JSONL files. Can be a local path or S3 URL.'
    )
    parser.add_argument(
        'output_dir',
        type=str,
        help='Path to the output directory where transformed JSONL files will be saved. Can be a local path or S3 URL.'
    )
    parser.add_argument(
        '--jobs', '-j',
        type=int,
        default=20,
        help='Number of parallel jobs to run (default: 20).'
    )

    args = parser.parse_args()

    input_dir = args.input_dir.rstrip('/')
    output_dir = args.output_dir.rstrip('/')
    max_jobs = args.jobs

    # List input files
    input_files = list_input_files(input_dir)

    if not input_files:
        logging.warning(f"No JSONL files found in '{input_dir}'. Exiting.")
        sys.exit(0)

    logging.info(f"Found {len(input_files)} JSONL files to process.")

    # Prepare tasks for parallel processing
    tasks = []
    for input_file in input_files:
        output_file = construct_output_file_path(input_file, input_dir, output_dir)
        tasks.append((input_file, output_file))

    # Process files in parallel
    all_prompt_lengths = []
    with ProcessPoolExecutor(max_workers=max_jobs) as executor:
        future_to_file = {
            executor.submit(process_file, input_file, output_file, args.rewrite_finetuning_prompt): input_file
            for input_file, output_file in tasks
        }

        for future in as_completed(future_to_file):
            input_file = future_to_file[future]
            try:
                prompt_lengths = future.result()
                all_prompt_lengths.extend(prompt_lengths)
            except Exception as exc:
                logging.error(f"File {input_file} generated an exception: {exc}")

    logging.info("All files have been processed.")

    # Plot histogram of prompt lengths
    if all_prompt_lengths:
        fig = px.histogram(all_prompt_lengths, nbins=50, title="Histogram of Prompt Lengths")
        fig.update_xaxes(title="Prompt Length")
        fig.update_yaxes(title="Frequency")

        try:
            fig.write_image("prompt_lengths_histogram.png")
            logging.info("Histogram of prompt lengths has been saved to 'prompt_lengths_histogram.png'.")
        except Exception as e:
            logging.error(f"Failed to save the histogram image: {e}")
            logging.error("Please make sure that the 'kaleido' package is installed (pip install -U kaleido).")
            fig.write_html("prompt_lengths_histogram.html")
            logging.info("Histogram of prompt lengths has been saved to 'prompt_lengths_histogram.html'.")
    else:
        logging.warning("No prompt lengths were collected; histogram will not be generated.")


if __name__ == "__main__":
    main()