commit 8e5809da71 (parent a90feda42f)

    runpipeline
@@ -6,9 +6,10 @@ import base64
 import argparse
 import boto3
 import json
+import hashlib
 from pypdf import PdfReader
 from tqdm import tqdm
-from typing import Generator
+from typing import Generator, List
 from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed
 from urllib.parse import urlparse

@@ -18,6 +19,8 @@ from pdelfin.prompts.anchor import get_anchor_text
 from pdelfin.filter import PdfFilter

 import logging
+import smart_open
+import posixpath  # Import posixpath for S3 path handling

 logging.getLogger("pypdf").setLevel(logging.ERROR)

@@ -51,7 +54,7 @@ def fetch_s3_file(s3_url: str, local_path: str) -> str:
     s3.download_file(bucket_name, key, local_path)
     return local_path

-def process_pdf(pdf_path: str, no_filter: bool) -> Generator[dict, None, None]:
+def process_pdf(pdf_path: str, no_filter: bool) -> List[dict]:
     if pdf_path.startswith("s3://"):
         local_pdf_path = os.path.join("/tmp", os.path.basename(pdf_path))
         fetch_s3_file(pdf_path, local_pdf_path)
@@ -95,11 +98,34 @@ def expand_s3_glob(s3_glob: str) -> list:
     for page in page_iterator:
         for obj in page.get('Contents', []):
             key = obj['Key']
-            if key.endswith('.pdf') and glob.fnmatch.fnmatch(key, prefix + pattern):
+            if key.endswith('.pdf') and glob.fnmatch.fnmatch(key, posixpath.join(prefix, pattern)):
                 matched_files.append(f"s3://{bucket_name}/{key}")

     return matched_files

+def compute_hash(content: str) -> str:
+    """Compute a 20-character SHA1 hash of the given content."""
+    sha1 = hashlib.sha1()
+    sha1.update(content.encode('utf-8'))
+    return sha1.hexdigest()[:20]
+
+def get_smart_open_write_path(output_path: str, hash_str: str) -> str:
+    """Generate the full output path with hash in the filename."""
+    parsed = urlparse(output_path)
+    if parsed.scheme in ('s3', 's3a', 's3n'):
+        bucket = parsed.netloc
+        key = parsed.path.lstrip('/')
+        # Ensure the key is treated as a directory by appending a slash if not present
+        if key and not key.endswith('/'):
+            key += '/'
+        # Use posixpath to correctly join S3 paths
+        full_key = posixpath.join(key, f"output_{hash_str}.jsonl")
+        return f"s3://{bucket}/{full_key}"
+    else:
+        dir_path = output_path
+        filename = f"output_{hash_str}.jsonl"
+        return os.path.join(dir_path, filename)
+
 def main():
     parser = argparse.ArgumentParser(
         description="Given a bunch of PDFs, prepares a mise/birr workflow to run them through a conversion mechanism"
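For illustration, the two helpers added above derive the output file name from the batch content itself rather than from a running counter. Below is a minimal sketch of the naming scheme using the same standard-library calls; the batch content and prefixes are invented, not taken from the commit.

import hashlib
import os
import posixpath

# Invented JSONL batch content (not from the commit)
batch_content = '{"custom_id": "doc-1"}\n{"custom_id": "doc-2"}\n'

# Same scheme as compute_hash: first 20 hex chars of the SHA1 of the UTF-8 bytes
hash_str = hashlib.sha1(batch_content.encode('utf-8')).hexdigest()[:20]

# get_smart_open_write_path joins the hashed name under the configured prefix,
# using posixpath for S3 keys and os.path for local directories
print(posixpath.join("workspace/mise_batch_data/", f"output_{hash_str}.jsonl"))
print(os.path.join("mise_batch_data", f"output_{hash_str}.jsonl"))

Because the name depends only on the bytes being written, rerunning the same batch produces the same file name instead of a new sequentially numbered one.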
@@ -132,7 +158,7 @@ def main():
         "--output",
         type=str,
         default="mise_batch_data",
-        help="Output destination"
+        help="Output destination (can be a local path or an S3 URI)"
     )
     args = parser.parse_args()

@@ -167,22 +193,31 @@ def main():

     print(f"Loaded and shuffled {len(pdf_paths)} paths to use.")

-    # Rest of the code remains the same
-    cur_file_num = 0
+    # Prepare for output
     output_dir = args.output
-    max_file_size = args.max_size_mb * 1024 * 1024
-    cur_file_size = 0
-    cur_file_path = os.path.join(output_dir, f"output_{cur_file_num}.jsonl")
+    max_file_size = args.max_size_mb * 1024 * 1024  # Convert MB to bytes

-    # Ensure output directory exists
-    os.makedirs(output_dir, exist_ok=True)
+    # Determine if output is S3
+    parsed_output = urlparse(output_dir)
+    is_s3 = parsed_output.scheme in ('s3', 's3a', 's3n')

-    # Open the first file for writing
-    cur_file = open(cur_file_path, 'w')
-    # Counter to track PDFs that produce at least one output
+    # Initialize variables for batching
+    batch = []
+    batch_size = 0
+
     pdfs_with_output = 0

+    # Function to write a batch
+    def write_batch(batch: List[dict]):
+        nonlocal output_dir
+        if not batch:
+            return
+        batch_content = "\n".join(json.dumps(entry) for entry in batch) + "\n"
+        hash_str = compute_hash(batch_content)
+        output_path_with_hash = get_smart_open_write_path(output_dir, hash_str)
+        with smart_open.open(output_path_with_hash, 'w') as f_out:
+            f_out.write(batch_content)
+        print(f"Wrote batch to {output_path_with_hash}")
+
     # Using ProcessPoolExecutor to process files concurrently
     with ProcessPoolExecutor() as executor:
         futures = []
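The write_batch closure above replaces the old sequentially numbered output_0.jsonl, output_1.jsonl files with one smart_open write per batch under a content-hashed name, so the same code path serves local directories and S3 URIs. A standalone sketch of that flow, assuming smart_open is installed and the target prefix exists; the records and prefixes are invented.

import hashlib
import json
import smart_open  # handles s3:// URIs and plain local paths alike

def write_batch_demo(batch, output_prefix):
    # Serialize the batch as JSON Lines, name the file after a 20-char SHA1
    # of its own content, and write it in a single call.
    if not batch:
        return None
    content = "\n".join(json.dumps(entry) for entry in batch) + "\n"
    hash_str = hashlib.sha1(content.encode("utf-8")).hexdigest()[:20]
    out_path = f"{output_prefix.rstrip('/')}/output_{hash_str}.jsonl"
    with smart_open.open(out_path, "w") as f_out:
        f_out.write(content)
    return out_path

# Hypothetical calls; the same helper works for either destination:
# write_batch_demo([{"custom_id": "doc-1"}], "s3://my-bucket/mise_batch_data")
# write_batch_demo([{"custom_id": "doc-1"}], "/tmp/mise_batch_data")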
@@ -200,28 +235,26 @@ def main():

                 for request_obj in request_results:
                     request_json = json.dumps(request_obj)
-                    request_size = len(request_json.encode('utf-8'))  # Calculate size in bytes
+                    request_size = len(request_json.encode('utf-8')) + 1  # +1 for newline

-                    # Check if the current request can fit in the current file
-                    if cur_file_size + request_size > max_file_size:
-                        # Close the current file and create a new one
-                        cur_file.close()
-                        cur_file_num += 1
-                        cur_file_path = os.path.join(output_dir, f"output_{cur_file_num}.jsonl")
-                        cur_file = open(cur_file_path, 'w')
-                        cur_file_size = 0  # Reset file size
+                    # Check if adding this entry would exceed the max size
+                    if batch_size + request_size > max_file_size:
+                        # Write the current batch
+                        write_batch(batch)
+                        # Reset the batch
+                        batch = []
+                        batch_size = 0

-                    # Write the JSON entry to the file
-                    cur_file.write(request_json)
-                    cur_file.write("\n")
-                    cur_file_size += request_size
+                    # Add the entry to the batch
+                    batch.append(request_obj)
+                    batch_size += request_size

                 pb.update(1)
             except Exception as e:
                 print(f"Error processing a PDF: {str(e)}")

-        # Close the last open file
-        cur_file.close()
+        # Write any remaining batch
+        write_batch(batch)

     # Print the number of PDFs that resulted in at least one output
     print(f"Number of sampled PDFs that produced at least one output: {pdfs_with_output}")
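In the rewritten loop, each serialized request counts its JSON length plus one byte for the newline, and the in-memory batch is flushed via write_batch whenever the next entry would push it past max_file_size. A self-contained sketch of that accounting with a toy 100-byte limit and invented records; flushed batches are collected in a list instead of being written out.

import json

MAX_BYTES = 100  # stand-in for args.max_size_mb * 1024 * 1024
batches = []     # stands in for write_batch() calls

batch, batch_size = [], 0
for request_obj in ({"custom_id": f"doc-{i}", "page": i} for i in range(8)):
    request_json = json.dumps(request_obj)
    request_size = len(request_json.encode("utf-8")) + 1  # +1 for the trailing newline
    if batch_size + request_size > MAX_BYTES:
        batches.append(batch)          # write_batch(batch) in the real code
        batch, batch_size = [], 0
    batch.append(request_obj)
    batch_size += request_size
if batch:
    batches.append(batch)              # final write_batch(batch) after the loop

print([len(b) for b in batches])       # [2, 2, 2, 2] with these toy records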