Mirror of https://github.com/allenai/olmocr.git (synced 2025-11-01 10:33:57 +00:00)

Commit 8e5809da71 ("runpipeline")
Parent: a90feda42f
@@ -6,9 +6,10 @@ import base64
 import argparse
 import boto3
 import json
 import hashlib
 from pypdf import PdfReader
 from tqdm import tqdm
-from typing import Generator
+from typing import Generator, List
 from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed
 from urllib.parse import urlparse
@@ -18,6 +19,8 @@ from pdelfin.prompts.anchor import get_anchor_text
 from pdelfin.filter import PdfFilter
 
 import logging
 import smart_open
+import posixpath  # Import posixpath for S3 path handling
 
 logging.getLogger("pypdf").setLevel(logging.ERROR)
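Side note on the new import, not part of the commit itself: posixpath joins path segments with forward slashes regardless of the host OS, which is what S3 object keys expect, whereas plain string concatenation (or os.path.join on Windows) can drop or mangle the separator. A minimal illustration, with a hypothetical prefix and pattern:

    # Illustrative sketch: why posixpath is the safe way to build S3 key patterns.
    import posixpath

    prefix = "pdfs/batch1"   # hypothetical S3 key prefix with no trailing slash
    pattern = "*.pdf"
    # Plain concatenation loses the separator: "pdfs/batch1*.pdf"
    print(prefix + pattern)
    # posixpath.join always inserts "/": "pdfs/batch1/*.pdf"
    print(posixpath.join(prefix, pattern))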
@@ -51,7 +54,7 @@ def fetch_s3_file(s3_url: str, local_path: str) -> str:
     s3.download_file(bucket_name, key, local_path)
     return local_path
 
-def process_pdf(pdf_path: str, no_filter: bool) -> Generator[dict, None, None]:
+def process_pdf(pdf_path: str, no_filter: bool) -> List[dict]:
     if pdf_path.startswith("s3://"):
         local_pdf_path = os.path.join("/tmp", os.path.basename(pdf_path))
         fetch_s3_file(pdf_path, local_pdf_path)
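A plausible reason for switching the return type from Generator to List (my reading; the commit message does not say): process_pdf results come back from ProcessPoolExecutor workers, and return values must be pickled across the process boundary, which a generator object cannot be. A self-contained sketch of that pattern with a stub worker (the stub function and its data are hypothetical):

    # Illustrative sketch, not the project's code: why a pool worker returns a list.
    from concurrent.futures import ProcessPoolExecutor, as_completed
    from typing import List

    def process_pdf_stub(pdf_path: str) -> List[dict]:
        # A concrete list pickles cleanly back to the parent process;
        # a generator object would raise a TypeError when pickled.
        return [{"pdf": pdf_path, "page": i} for i in range(3)]

    if __name__ == "__main__":
        with ProcessPoolExecutor() as executor:
            futures = [executor.submit(process_pdf_stub, p) for p in ("a.pdf", "b.pdf")]
            for future in as_completed(futures):
                for entry in future.result():
                    print(entry)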
@@ -95,11 +98,34 @@ def expand_s3_glob(s3_glob: str) -> list:
     for page in page_iterator:
         for obj in page.get('Contents', []):
             key = obj['Key']
-            if key.endswith('.pdf') and glob.fnmatch.fnmatch(key, prefix + pattern):
+            if key.endswith('.pdf') and glob.fnmatch.fnmatch(key, posixpath.join(prefix, pattern)):
                 matched_files.append(f"s3://{bucket_name}/{key}")
 
     return matched_files
 
+def compute_hash(content: str) -> str:
+    """Compute a 20-character SHA1 hash of the given content."""
+    sha1 = hashlib.sha1()
+    sha1.update(content.encode('utf-8'))
+    return sha1.hexdigest()[:20]
+
+def get_smart_open_write_path(output_path: str, hash_str: str) -> str:
+    """Generate the full output path with hash in the filename."""
+    parsed = urlparse(output_path)
+    if parsed.scheme in ('s3', 's3a', 's3n'):
+        bucket = parsed.netloc
+        key = parsed.path.lstrip('/')
+        # Ensure the key is treated as a directory by appending a slash if not present
+        if key and not key.endswith('/'):
+            key += '/'
+        # Use posixpath to correctly join S3 paths
+        full_key = posixpath.join(key, f"output_{hash_str}.jsonl")
+        return f"s3://{bucket}/{full_key}"
+    else:
+        dir_path = output_path
+        filename = f"output_{hash_str}.jsonl"
+        return os.path.join(dir_path, filename)
+
 def main():
     parser = argparse.ArgumentParser(
         description="Given a bunch of PDFs, prepares a mise/birr workflow to run them through a conversion mechanism"
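Taken together, these two new helpers give every batch a filename derived from its own content, so identical batches map to the same key and re-runs overwrite rather than accumulate numbered files. A small sketch of how they compose, using a hypothetical local output directory and S3 prefix:

    # Illustrative sketch reusing the hashing logic shown in the diff above.
    import hashlib
    import posixpath

    def compute_hash(content: str) -> str:
        sha1 = hashlib.sha1()
        sha1.update(content.encode('utf-8'))
        return sha1.hexdigest()[:20]

    batch_content = '{"id": 1}\n{"id": 2}\n'
    hash_str = compute_hash(batch_content)
    # Local destination: <output dir>/output_<hash>.jsonl
    print(f"mise_batch_data/output_{hash_str}.jsonl")
    # S3 destination: the key is joined with posixpath, never os.path
    print("s3://example-bucket/" + posixpath.join("batch_data/", f"output_{hash_str}.jsonl"))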
@@ -132,7 +158,7 @@ def main():
         "--output",
         type=str,
         default="mise_batch_data",
-        help="Output destination"
+        help="Output destination (can be a local path or an S3 URI)"
     )
     args = parser.parse_args()
@@ -167,22 +193,31 @@ def main():
 
     print(f"Loaded and shuffled {len(pdf_paths)} paths to use.")
 
-    # Rest of the code remains the same
-    cur_file_num = 0
+    # Prepare for output
     output_dir = args.output
-    max_file_size = args.max_size_mb * 1024 * 1024
-    cur_file_size = 0
-    cur_file_path = os.path.join(output_dir, f"output_{cur_file_num}.jsonl")
+    max_file_size = args.max_size_mb * 1024 * 1024  # Convert MB to bytes
 
-    # Ensure output directory exists
-    os.makedirs(output_dir, exist_ok=True)
+    # Determine if output is S3
+    parsed_output = urlparse(output_dir)
+    is_s3 = parsed_output.scheme in ('s3', 's3a', 's3n')
 
-    # Open the first file for writing
-    cur_file = open(cur_file_path, 'w')
-
-    # Counter to track PDFs that produce at least one output
+    # Initialize variables for batching
+    batch = []
+    batch_size = 0
     pdfs_with_output = 0
 
+    # Function to write a batch
+    def write_batch(batch: List[dict]):
+        nonlocal output_dir
+        if not batch:
+            return
+        batch_content = "\n".join(json.dumps(entry) for entry in batch) + "\n"
+        hash_str = compute_hash(batch_content)
+        output_path_with_hash = get_smart_open_write_path(output_dir, hash_str)
+        with smart_open.open(output_path_with_hash, 'w') as f_out:
+            f_out.write(batch_content)
+        print(f"Wrote batch to {output_path_with_hash}")
+
     # Using ProcessPoolExecutor to process files concurrently
     with ProcessPoolExecutor() as executor:
         futures = []
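A design observation, not stated in the commit: because write_batch hands smart_open a full URI and names each file by its content hash, one write path serves both local files and S3 objects, replacing the old cur_file_num counter and per-file bookkeeping. A minimal smart_open sketch (a local path here so it runs anywhere; the filename is hypothetical):

    # Illustrative sketch: smart_open.open accepts local paths and s3:// URIs alike.
    import smart_open

    with smart_open.open("output_example.jsonl", "w") as f_out:
        f_out.write('{"demo": true}\n')
    # With AWS credentials configured, the same call form works for e.g.
    # smart_open.open("s3://example-bucket/prefix/output_<hash>.jsonl", "w").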
@@ -200,28 +235,26 @@ def main():
 
             for request_obj in request_results:
                 request_json = json.dumps(request_obj)
-                request_size = len(request_json.encode('utf-8'))  # Calculate size in bytes
+                request_size = len(request_json.encode('utf-8')) + 1  # +1 for newline
 
-                # Check if the current request can fit in the current file
-                if cur_file_size + request_size > max_file_size:
-                    # Close the current file and create a new one
-                    cur_file.close()
-                    cur_file_num += 1
-                    cur_file_path = os.path.join(output_dir, f"output_{cur_file_num}.jsonl")
-                    cur_file = open(cur_file_path, 'w')
-                    cur_file_size = 0  # Reset file size
+                # Check if adding this entry would exceed the max size
+                if batch_size + request_size > max_file_size:
+                    # Write the current batch
+                    write_batch(batch)
+                    # Reset the batch
+                    batch = []
+                    batch_size = 0
 
-                # Write the JSON entry to the file
-                cur_file.write(request_json)
-                cur_file.write("\n")
-                cur_file_size += request_size
+                # Add the entry to the batch
+                batch.append(request_obj)
+                batch_size += request_size
 
             pb.update(1)
         except Exception as e:
             print(f"Error processing a PDF: {str(e)}")
 
-    # Close the last open file
-    cur_file.close()
+    # Write any remaining batch
+    write_batch(batch)
 
     # Print the number of PDFs that resulted in at least one output
     print(f"Number of sampled PDFs that produced at least one output: {pdfs_with_output}")