runpipeline

Jake Poznanski 2024-10-09 20:29:59 +00:00
parent a90feda42f
commit 8e5809da71


@@ -6,9 +6,10 @@ import base64
 import argparse
 import boto3
 import json
+import hashlib
 from pypdf import PdfReader
 from tqdm import tqdm
-from typing import Generator
+from typing import Generator, List
 from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed
 from urllib.parse import urlparse
@@ -18,6 +19,8 @@ from pdelfin.prompts.anchor import get_anchor_text
 from pdelfin.filter import PdfFilter
 import logging
+import smart_open
+import posixpath  # Import posixpath for S3 path handling
 logging.getLogger("pypdf").setLevel(logging.ERROR)
@@ -51,7 +54,7 @@ def fetch_s3_file(s3_url: str, local_path: str) -> str:
     s3.download_file(bucket_name, key, local_path)
     return local_path

-def process_pdf(pdf_path: str, no_filter: bool) -> Generator[dict, None, None]:
+def process_pdf(pdf_path: str, no_filter: bool) -> List[dict]:
     if pdf_path.startswith("s3://"):
         local_pdf_path = os.path.join("/tmp", os.path.basename(pdf_path))
         fetch_s3_file(pdf_path, local_pdf_path)
@@ -95,11 +98,34 @@ def expand_s3_glob(s3_glob: str) -> list:
     for page in page_iterator:
         for obj in page.get('Contents', []):
             key = obj['Key']
-            if key.endswith('.pdf') and glob.fnmatch.fnmatch(key, prefix + pattern):
+            if key.endswith('.pdf') and glob.fnmatch.fnmatch(key, posixpath.join(prefix, pattern)):
                 matched_files.append(f"s3://{bucket_name}/{key}")

     return matched_files

+def compute_hash(content: str) -> str:
+    """Compute a 20-character SHA1 hash of the given content."""
+    sha1 = hashlib.sha1()
+    sha1.update(content.encode('utf-8'))
+    return sha1.hexdigest()[:20]
+
+def get_smart_open_write_path(output_path: str, hash_str: str) -> str:
+    """Generate the full output path with hash in the filename."""
+    parsed = urlparse(output_path)
+    if parsed.scheme in ('s3', 's3a', 's3n'):
+        bucket = parsed.netloc
+        key = parsed.path.lstrip('/')
+        # Ensure the key is treated as a directory by appending a slash if not present
+        if key and not key.endswith('/'):
+            key += '/'
+        # Use posixpath to correctly join S3 paths
+        full_key = posixpath.join(key, f"output_{hash_str}.jsonl")
+        return f"s3://{bucket}/{full_key}"
+    else:
+        dir_path = output_path
+        filename = f"output_{hash_str}.jsonl"
+        return os.path.join(dir_path, filename)
+
 def main():
     parser = argparse.ArgumentParser(
         description="Given a bunch of PDFs, prepares a mise/birr workflow to run them through a conversion mechanism"
@@ -132,7 +158,7 @@ def main():
         "--output",
         type=str,
         default="mise_batch_data",
-        help="Output destination"
+        help="Output destination (can be a local path or an S3 URI)"
     )
     args = parser.parse_args()
@@ -167,22 +193,31 @@ def main():
     print(f"Loaded and shuffled {len(pdf_paths)} paths to use.")

-    # Rest of the code remains the same
-    cur_file_num = 0
+    # Prepare for output
     output_dir = args.output
-    max_file_size = args.max_size_mb * 1024 * 1024
-    cur_file_size = 0
-    cur_file_path = os.path.join(output_dir, f"output_{cur_file_num}.jsonl")
+    max_file_size = args.max_size_mb * 1024 * 1024  # Convert MB to bytes

-    # Ensure output directory exists
-    os.makedirs(output_dir, exist_ok=True)
+    # Determine if output is S3
+    parsed_output = urlparse(output_dir)
+    is_s3 = parsed_output.scheme in ('s3', 's3a', 's3n')

-    # Open the first file for writing
-    cur_file = open(cur_file_path, 'w')
+    # Initialize variables for batching
+    batch = []
+    batch_size = 0

+    # Counter to track PDFs that produce at least one output
     pdfs_with_output = 0

+    # Function to write a batch
+    def write_batch(batch: List[dict]):
+        nonlocal output_dir
+        if not batch:
+            return
+        batch_content = "\n".join(json.dumps(entry) for entry in batch) + "\n"
+        hash_str = compute_hash(batch_content)
+        output_path_with_hash = get_smart_open_write_path(output_dir, hash_str)
+        with smart_open.open(output_path_with_hash, 'w') as f_out:
+            f_out.write(batch_content)
+        print(f"Wrote batch to {output_path_with_hash}")
+
     # Using ProcessPoolExecutor to process files concurrently
     with ProcessPoolExecutor() as executor:
         futures = []
@@ -200,28 +235,26 @@ def main():
                 for request_obj in request_results:
                     request_json = json.dumps(request_obj)
-                    request_size = len(request_json.encode('utf-8'))  # Calculate size in bytes
+                    request_size = len(request_json.encode('utf-8')) + 1  # +1 for newline

-                    # Check if the current request can fit in the current file
-                    if cur_file_size + request_size > max_file_size:
-                        # Close the current file and create a new one
-                        cur_file.close()
-                        cur_file_num += 1
-                        cur_file_path = os.path.join(output_dir, f"output_{cur_file_num}.jsonl")
-                        cur_file = open(cur_file_path, 'w')
-                        cur_file_size = 0  # Reset file size
+                    # Check if adding this entry would exceed the max size
+                    if batch_size + request_size > max_file_size:
+                        # Write the current batch
+                        write_batch(batch)
+                        # Reset the batch
+                        batch = []
+                        batch_size = 0

-                    # Write the JSON entry to the file
-                    cur_file.write(request_json)
-                    cur_file.write("\n")
-                    cur_file_size += request_size
+                    # Add the entry to the batch
+                    batch.append(request_obj)
+                    batch_size += request_size

                 pb.update(1)
             except Exception as e:
                 print(f"Error processing a PDF: {str(e)}")

-    # Close the last open file
-    cur_file.close()
+    # Write any remaining batch
+    write_batch(batch)

     # Print the number of PDFs that resulted in at least one output
     print(f"Number of sampled PDFs that produced at least one output: {pdfs_with_output}")