# The way this script works is it gets a list of pdfs to process
# and an output/scratch folder location, either locally or in s3, to work with.
# On the first run, with an empty output folder, it will queue up each page of each pdf to go into a VLM.
# Then, the user queues up that task in birr, and it outputs to a new subfolder in the same location.
# Then, you run this script again, and it will see that you have some valid output files.
# If so, it will check those output files, and if it has a complete document, it will build a dolma doc
# for it, and that document is considered done.
# For any remaining pages that errored out, or that failed because stop_reason was not "stop" (e.g. over length),
# it will queue up another set of tasks, hopefully much smaller, to send into batch inference again.
# This process keeps going until you run it with the --fallback option, at which point it will
# just use basic text extraction on any remaining pages and assemble the rest of the dolma docs.
#
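# Example invocations (the script name, bucket, and paths below are illustrative placeholders,
# not a documented interface):
#
#   # Queue up every page of every matching PDF as VLM work items:
#   python prepare_birr_workitems.py "s3://my-bucket/pdfs/*.pdf" --output s3://my-bucket/scratch/round1
#
#   # Or drive it from a newline-separated list of PDF paths:
#   python prepare_birr_workitems.py --path_list pdf_paths.txt --output mise_batch_data --max_size_mb 250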

import os
import glob
import fnmatch
import random
import argparse
import boto3
import json
import hashlib

from pypdf import PdfReader
from tqdm import tqdm
from typing import Generator, List
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed
from urllib.parse import urlparse

from pdelfin.data.renderpdf import render_pdf_to_base64png
from pdelfin.prompts import build_finetuning_prompt
from pdelfin.prompts.anchor import get_anchor_text
from pdelfin.filter import PdfFilter

import logging
import smart_open
import posixpath  # For S3 key handling (S3 keys always use forward slashes)

logging.getLogger("pypdf").setLevel(logging.ERROR)

pdf_filter = PdfFilter()


def build_page_query(local_pdf_path: str, pretty_pdf_path: str, page: int) -> dict:
    """Build a single VLM chat request for one page of a PDF."""
    image_base64 = render_pdf_to_base64png(local_pdf_path, page, 1024)
    anchor_text = get_anchor_text(local_pdf_path, page, pdf_engine="pdfreport")

    return {
        "custom_id": f"{pretty_pdf_path}-{page}",
        "chat_messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": build_finetuning_prompt(anchor_text)},
                    {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}},
                ],
            }
        ],
        "temperature": 0.1,
        "max_tokens": 6000,
    }

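# Each dict returned by build_page_query becomes one JSON line in an output .jsonl
# work file, roughly of this shape (values abbreviated for illustration):
#
#   {"custom_id": "s3://bucket/doc.pdf-3",
#    "chat_messages": [{"role": "user",
#                       "content": [{"type": "text", "text": "<prompt built from anchor text>"},
#                                   {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}}]}],
#    "temperature": 0.1,
#    "max_tokens": 6000}
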
def fetch_s3_file(s3_url: str, local_path: str) -> str:
    """Download an S3 object to a local path and return that path."""
    parsed = urlparse(s3_url)
    bucket_name = parsed.netloc
    key = parsed.path.lstrip('/')

    s3 = boto3.client('s3')
    s3.download_file(bucket_name, key, local_path)
    return local_path


def process_pdf(pdf_path: str, no_filter: bool) -> List[dict]:
    """Build a VLM page query for every page of one PDF (local or s3://)."""
    if pdf_path.startswith("s3://"):
        local_pdf_path = os.path.join("/tmp", os.path.basename(pdf_path))
        fetch_s3_file(pdf_path, local_pdf_path)
    else:
        local_pdf_path = pdf_path

    if (not no_filter) and pdf_filter.filter_out_pdf(local_pdf_path):
        print(f"Skipping {local_pdf_path} due to common filter")
        return []

    pretty_pdf_path = pdf_path

    pdf = PdfReader(local_pdf_path)
    num_pages = len(pdf.pages)

    sample_pages = list(range(1, num_pages + 1))
    result = []
    for page in sample_pages:
        try:
            query = build_page_query(local_pdf_path, pretty_pdf_path, page)
            result.append(query)
        except Exception as e:
            print(f"Error processing page {page} of {pdf_path}: {e}")

    return result


def is_glob_pattern(path: str) -> bool:
    """Return True if the path contains shell-style glob characters."""
    return any(char in path for char in ['*', '?', '[', ']'])


def expand_s3_glob(s3_glob: str) -> list:
    """Expand an s3:// glob pattern into the list of matching .pdf object URLs."""
    parsed = urlparse(s3_glob)
    bucket_name = parsed.netloc
    prefix = os.path.dirname(parsed.path.lstrip('/')).rstrip('/') + "/"
    pattern = os.path.basename(parsed.path)

    s3 = boto3.client('s3')
    paginator = s3.get_paginator('list_objects_v2')
    page_iterator = paginator.paginate(Bucket=bucket_name, Prefix=prefix)

    matched_files = []
    for page in page_iterator:
        for obj in page.get('Contents', []):
            key = obj['Key']
            if key.endswith('.pdf') and fnmatch.fnmatch(key, posixpath.join(prefix, pattern)):
                matched_files.append(f"s3://{bucket_name}/{key}")

    return matched_files

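# A quick sanity check of the glob expansion (bucket and key names are hypothetical):
#
#   expand_s3_glob("s3://my-bucket/pdfs/*.pdf")
#   # -> ["s3://my-bucket/pdfs/report_001.pdf", "s3://my-bucket/pdfs/report_002.pdf", ...]
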
def compute_hash(content: str) -> str:
    """Compute a 20-character SHA1 hash of the given content."""
    sha1 = hashlib.sha1()
    sha1.update(content.encode('utf-8'))
    return sha1.hexdigest()[:20]


def get_smart_open_write_path(output_path: str, hash_str: str) -> str:
    """Generate the full output path with hash in the filename."""
    parsed = urlparse(output_path)
    if parsed.scheme in ('s3', 's3a', 's3n'):
        bucket = parsed.netloc
        key = parsed.path.lstrip('/')
        # Ensure the key is treated as a directory by appending a slash if not present
        if key and not key.endswith('/'):
            key += '/'
        # Use posixpath to correctly join S3 paths
        full_key = posixpath.join(key, f"output_{hash_str}.jsonl")
        return f"s3://{bucket}/{full_key}"
    else:
        dir_path = output_path
        filename = f"output_{hash_str}.jsonl"
        return os.path.join(dir_path, filename)

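# Resulting batch files look like the following (hash values are illustrative):
#
#   mise_batch_data/output_3f2a9c1d0b8e7a6f5c4d.jsonl
#   s3://my-bucket/scratch/output_3f2a9c1d0b8e7a6f5c4d.jsonl
#
# where the 20-hex-character suffix is the SHA1 prefix of the batch contents.
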
def main():
    parser = argparse.ArgumentParser(
        description="Given a bunch of PDFs, prepares a mise/birr workflow to run them through a conversion mechanism"
    )
    parser.add_argument(
        "pdf_paths",
        nargs='*',
        help=(
            "List of PDF paths to process. If a single argument contains glob patterns (e.g., *.pdf or s3://bucket/pdfs/*.pdf), "
            "it will be expanded accordingly."
        )
    )
    parser.add_argument(
        "--path_list",
        type=str,
        help="Path to a file containing paths to PDFs, one per line."
    )
    parser.add_argument(
        "--max_size_mb",
        type=int,
        default=250,
        help="Max size in MB of the entries to put in each birr work item"
    )
    parser.add_argument(
        "--no_filter",
        action="store_true",
        help="Disables the basic spam/language filtering so that ALL pdfs listed are used"
    )
    parser.add_argument(
        "--output",
        type=str,
        default="mise_batch_data",
        help="Output destination (can be a local path or an S3 URI)"
    )
    args = parser.parse_args()

    pdf_paths = []

    # Load PDF paths from positional arguments or path_list
    if args.pdf_paths:
        if len(args.pdf_paths) == 1 and is_glob_pattern(args.pdf_paths[0]):
            glob_path = args.pdf_paths[0]
            if glob_path.startswith("s3://"):
                # Handle S3 globbing
                expanded_paths = expand_s3_glob(glob_path)
                pdf_paths.extend(expanded_paths)
            else:
                # Handle local filesystem globbing
                expanded_paths = glob.glob(glob_path, recursive=True)
                pdf_paths.extend(expanded_paths)
        else:
            # Treat positional arguments as list of PDF paths
            pdf_paths.extend(args.pdf_paths)

    if args.path_list:
        with open(args.path_list, 'r') as f:
            for line in f:
                path = line.strip()
                if path:
                    pdf_paths.append(path)

    # Remove duplicates and shuffle
    pdf_paths = list(set(pdf_paths))
    random.shuffle(pdf_paths)

    print(f"Loaded and shuffled {len(pdf_paths)} paths to use.")

    # Prepare for output
    output_dir = args.output
    max_file_size = args.max_size_mb * 1024 * 1024  # Convert MB to bytes

    # Determine if output is S3
    parsed_output = urlparse(output_dir)
    is_s3 = parsed_output.scheme in ('s3', 's3a', 's3n')

    # Initialize variables for batching
    batch = []
    batch_size = 0
    pdfs_with_output = 0

    # Function to write a batch
    def write_batch(batch: List[dict]):
        nonlocal output_dir
        if not batch:
            return
        batch_content = "\n".join(json.dumps(entry) for entry in batch) + "\n"
        hash_str = compute_hash(batch_content)
        output_path_with_hash = get_smart_open_write_path(output_dir, hash_str)
        with smart_open.open(output_path_with_hash, 'w') as f_out:
            f_out.write(batch_content)
        print(f"Wrote batch to {output_path_with_hash}")

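    # Work items are accumulated below until the serialized batch would exceed
    # --max_size_mb, at which point it is flushed to a new output_<hash>.jsonl file;
    # the final write_batch call after the loop handles whatever remains.
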
    # Using ProcessPoolExecutor to process files concurrently
    with ProcessPoolExecutor() as executor:
        futures = []

        with tqdm(desc="Processing PDFs", leave=False, total=len(pdf_paths)) as pb:
            for pdf_path in pdf_paths:
                futures.append(executor.submit(process_pdf, pdf_path, args.no_filter))

            for future in as_completed(futures):
                try:
                    request_results = future.result()  # Get the result from the process

                    if request_results:
                        pdfs_with_output += 1  # Increment if there's at least one result

                    for request_obj in request_results:
                        request_json = json.dumps(request_obj)
                        request_size = len(request_json.encode('utf-8')) + 1  # +1 for newline

                        # Check if adding this entry would exceed the max size
                        if batch_size + request_size > max_file_size:
                            # Write the current batch
                            write_batch(batch)
                            # Reset the batch
                            batch = []
                            batch_size = 0

                        # Add the entry to the batch
                        batch.append(request_obj)
                        batch_size += request_size

                    pb.update(1)
                except Exception as e:
                    print(f"Error processing a PDF: {str(e)}")

    # Write any remaining batch
    write_batch(batch)

    # Print the number of PDFs that resulted in at least one output
    print(f"Number of sampled PDFs that produced at least one output: {pdfs_with_output}")
    print("Now you should run these prompts through mise/birr")


if __name__ == "__main__":
    main()