Mirror of https://github.com/allenai/olmocr.git (synced 2025-11-02 11:04:25 +00:00)

More refactoring

commit af992bd603
parent cd8e28e459
@@ -67,7 +67,7 @@ class DatabaseManager:
             )
         """)
         self.cursor.execute("""
-            CREATE INDEX IF NOT EXISTS idx_path ON page_results(s3_path)
+            CREATE INDEX IF NOT EXISTS idx_path ON page_results(pdf_s3_path)
         """)
         self.cursor.execute("""
             CREATE TABLE IF NOT EXISTS pdfs (
@@ -122,13 +122,13 @@ class DatabaseManager:
         """, [(entry.inference_s3_path, entry.pdf_s3_path, entry.page_num, entry.start_index, entry.length, entry.finish_reason, entry.error) for entry in index_entries])
         self.conn.commit()

-    def get_index_entries(self, s3_path: str) -> List[BatchInferenceRecord]:
+    def get_index_entries(self, pdf_s3_path: str) -> List[BatchInferenceRecord]:
         self.cursor.execute("""
             SELECT inference_s3_path, pdf_s3_path, page_num, start_index, length, finish_reason, error
             FROM page_results
-            WHERE s3_path = ?
-            ORDER BY inference_s3_path DESC start_index ASC page_num ASC
-        """, (s3_path,))
+            WHERE pdf_s3_path = ?
+            ORDER BY inference_s3_path DESC, start_index ASC, page_num ASC
+        """, (pdf_s3_path,))

         rows = self.cursor.fetchall()
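For context, a minimal sketch of what the two hunks above amount to: the index and the lookup now key on the renamed pdf_s3_path column, and the ORDER BY gains the commas it was missing. The column types and the in-memory connection below are assumptions, since the full DatabaseManager class is not shown in this diff; the column names follow the SELECT above.

    import sqlite3

    conn = sqlite3.connect(":memory:")
    cursor = conn.cursor()

    # Assumed minimal page_results table; names come from the SELECT above, types are guesses.
    cursor.execute("""
        CREATE TABLE IF NOT EXISTS page_results (
            inference_s3_path TEXT,
            pdf_s3_path TEXT,
            page_num INTEGER,
            start_index INTEGER,
            length INTEGER,
            finish_reason TEXT,
            error TEXT
        )
    """)
    # The index now targets pdf_s3_path rather than the old s3_path column.
    cursor.execute("CREATE INDEX IF NOT EXISTS idx_path ON page_results(pdf_s3_path)")

    # The corrected lookup: filter and sort by the renamed column, with commas in ORDER BY.
    cursor.execute("""
        SELECT inference_s3_path, pdf_s3_path, page_num, start_index, length, finish_reason, error
        FROM page_results
        WHERE pdf_s3_path = ?
        ORDER BY inference_s3_path DESC, start_index ASC, page_num ASC
    """, ("s3://example-bucket/docs/sample.pdf",))
    rows = cursor.fetchall()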
@@ -1,275 +0,0 @@
# The way this script works is it gets a list of pdfs to process
# and an output/scratch folder location either locally or in s3 to work with
# On the first run, with an empty output folder, it will queue up each page in each pdf to go into a VLM
# Then, the user queues up that task in birr, and it outputs to a new subfolder in the same location
# Then, you run your script again, and it will see that you have some valid output files
# If so, then it will check those output files, and if it has a complete document, it will build a dolma doc for it, and that's considered done
# For any remaining pages that got errored out, or failed due to stop_reason not being "stop" (ex. over length)
# Then, it will queue up another set of tasks, hopefully much smaller, to send into batch inference again
# This process will keep going, until you run it with the --fallback option, at which point it will
# just use a basic text extraction on any remaining pages, and assemble the rest of the dolma docs
#
#
#
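# A hedged illustration of the first-pass invocation this workflow describes, using the
# CLI flags defined in main() below. The script filename and the bucket/prefix names are
# placeholders, not taken from this file:
#
#   python prepare_birr_batch.py "s3://example-bucket/pdfs/*.pdf" \
#       --output s3://example-bucket/scratch/mise_batch_data \
#       --max_size_mb 250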
import os
import glob
import random
import argparse
import boto3
import json
import hashlib
from pypdf import PdfReader
from tqdm import tqdm
from typing import Generator, List
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed
from urllib.parse import urlparse

from pdelfin.data.renderpdf import render_pdf_to_base64png
from pdelfin.prompts import build_finetuning_prompt
from pdelfin.prompts.anchor import get_anchor_text
from pdelfin.filter import PdfFilter

import logging
import smart_open
import posixpath  # Import posixpath for S3 path handling

logging.getLogger("pypdf").setLevel(logging.ERROR)

pdf_filter = PdfFilter()

def build_page_query(local_pdf_path: str, pretty_pdf_path: str, page: int) -> dict:
    image_base64 = render_pdf_to_base64png(local_pdf_path, page, 1024)
    anchor_text = get_anchor_text(local_pdf_path, page, pdf_engine="pdfreport")

    return {
        "custom_id": f"{pretty_pdf_path}-{page}",
        "chat_messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": build_finetuning_prompt(anchor_text)},
                    {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}}
                ],
            }
        ],
        "temperature": 0.1,
        "max_tokens": 6000,
    }

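# For illustration (assumed, not from the original file): each record that build_page_query
# returns is serialized as one JSONL line roughly shaped like the following; the S3 path and
# the truncated fields are placeholders.
#
# {"custom_id": "s3://example-bucket/docs/sample.pdf-1",
#  "chat_messages": [{"role": "user",
#                     "content": [{"type": "text", "text": "<finetuning prompt built from anchor text>"},
#                                 {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}}]}],
#  "temperature": 0.1,
#  "max_tokens": 6000}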
def fetch_s3_file(s3_url: str, local_path: str) -> str:
    parsed = urlparse(s3_url)
    bucket_name = parsed.netloc
    key = parsed.path.lstrip('/')

    s3 = boto3.client('s3')
    s3.download_file(bucket_name, key, local_path)
    return local_path

def process_pdf(pdf_path: str, no_filter: bool) -> List[dict]:
    if pdf_path.startswith("s3://"):
        local_pdf_path = os.path.join("/tmp", os.path.basename(pdf_path))
        fetch_s3_file(pdf_path, local_pdf_path)
    else:
        local_pdf_path = pdf_path

    if (not no_filter) and pdf_filter.filter_out_pdf(local_pdf_path):
        print(f"Skipping {local_pdf_path} due to common filter")
        return []

    pretty_pdf_path = pdf_path

    pdf = PdfReader(local_pdf_path)
    num_pages = len(pdf.pages)

    sample_pages = list(range(1, num_pages + 1))
    result = []
    for page in sample_pages:
        try:
            query = build_page_query(local_pdf_path, pretty_pdf_path, page)
            result.append(query)
        except Exception as e:
            print(f"Error processing page {page} of {pdf_path}: {e}")

    return result

def is_glob_pattern(path: str) -> bool:
    return any(char in path for char in ['*', '?', '[', ']'])

def expand_s3_glob(s3_glob: str) -> list:
    parsed = urlparse(s3_glob)
    bucket_name = parsed.netloc
    prefix = os.path.dirname(parsed.path.lstrip('/')).rstrip('/') + "/"
    pattern = os.path.basename(parsed.path)

    s3 = boto3.client('s3')
    paginator = s3.get_paginator('list_objects_v2')
    page_iterator = paginator.paginate(Bucket=bucket_name, Prefix=prefix)

    matched_files = []
    for page in page_iterator:
        for obj in page.get('Contents', []):
            key = obj['Key']
            if key.endswith('.pdf') and glob.fnmatch.fnmatch(key, posixpath.join(prefix, pattern)):
                matched_files.append(f"s3://{bucket_name}/{key}")

    return matched_files

def compute_hash(content: str) -> str:
    """Compute a 20-character SHA1 hash of the given content."""
    sha1 = hashlib.sha1()
    sha1.update(content.encode('utf-8'))
    return sha1.hexdigest()[:20]

def get_smart_open_write_path(output_path: str, hash_str: str) -> str:
    """Generate the full output path with hash in the filename."""
    parsed = urlparse(output_path)
    if parsed.scheme in ('s3', 's3a', 's3n'):
        bucket = parsed.netloc
        key = parsed.path.lstrip('/')
        # Ensure the key is treated as a directory by appending a slash if not present
        if key and not key.endswith('/'):
            key += '/'
        # Use posixpath to correctly join S3 paths
        full_key = posixpath.join(key, f"output_{hash_str}.jsonl")
        return f"s3://{bucket}/{full_key}"
    else:
        dir_path = output_path
        filename = f"output_{hash_str}.jsonl"
        return os.path.join(dir_path, filename)

def main():
    parser = argparse.ArgumentParser(
        description="Given a bunch of PDFs, prepares a mise/birr workflow to run them through a conversion mechanism"
    )
    parser.add_argument(
        "pdf_paths",
        nargs='*',
        help=(
            "List of PDF paths to process. If a single argument contains glob patterns (e.g., *.pdf or s3://bucket/pdfs/*.pdf), "
            "it will be expanded accordingly."
        )
    )
    parser.add_argument(
        "--path_list",
        type=str,
        help="Path to a file containing paths to PDFs, one per line."
    )
    parser.add_argument(
        "--max_size_mb",
        type=int,
        default=250,
        help="Max number of MBs of entries to put in each birr workitem"
    )
    parser.add_argument(
        "--no_filter",
        action="store_true",
        help="Disables the basic spam/language filtering so that ALL pdfs listed are used"
    )
    parser.add_argument(
        "--output",
        type=str,
        default="mise_batch_data",
        help="Output destination (can be a local path or an S3 URI)"
    )
    args = parser.parse_args()

    pdf_paths = []

    # Load PDF paths from positional arguments or path_list
    if args.pdf_paths:
        for path in args.pdf_paths:
            if is_glob_pattern(path):
                glob_path = path
                if glob_path.startswith("s3://"):
                    # Handle S3 globbing
                    expanded_paths = expand_s3_glob(glob_path)
                    pdf_paths.extend(expanded_paths)
                else:
                    # Handle local filesystem globbing
                    expanded_paths = glob.glob(glob_path, recursive=True)
                    pdf_paths.extend(expanded_paths)
            else:
                pdf_paths.append(path)

    if args.path_list:
        with open(args.path_list, 'r') as f:
            for line in f:
                path = line.strip()
                if path:
                    pdf_paths.append(path)

    # Remove duplicates and shuffle
    pdf_paths = list(set(pdf_paths))
    random.shuffle(pdf_paths)

    print(f"Loaded and shuffled {len(pdf_paths)} paths to use.")

    # Prepare for output
    output_dir = args.output
    max_file_size = args.max_size_mb * 1024 * 1024  # Convert MB to bytes

    # Determine if output is S3
    parsed_output = urlparse(output_dir)
    is_s3 = parsed_output.scheme in ('s3', 's3a', 's3n')

    # Initialize variables for batching
    batch = []
    batch_size = 0
    pdfs_with_output = 0

    # Function to write a batch
    def write_batch(batch: List[dict]):
        nonlocal output_dir
        if not batch:
            return
        batch_content = "\n".join(json.dumps(entry) for entry in batch) + "\n"
        hash_str = compute_hash(batch_content)
        output_path_with_hash = get_smart_open_write_path(output_dir, hash_str)
        with smart_open.open(output_path_with_hash, 'w') as f_out:
            f_out.write(batch_content)
        print(f"Wrote batch to {output_path_with_hash}")

    # Using ProcessPoolExecutor to process files concurrently
    with ProcessPoolExecutor() as executor:
        futures = []

        with tqdm(desc="Processing PDFs", leave=False, total=len(pdf_paths)) as pb:
            for pdf_path in pdf_paths:
                futures.append(executor.submit(process_pdf, pdf_path, args.no_filter))

            for future in as_completed(futures):
                try:
                    request_results = future.result()  # Get the result from the process

                    if request_results:
                        pdfs_with_output += 1  # Increment if there's at least one result

                    for request_obj in request_results:
                        request_json = json.dumps(request_obj)
                        request_size = len(request_json.encode('utf-8')) + 1  # +1 for newline

                        # Check if adding this entry would exceed the max size
                        if batch_size + request_size > max_file_size:
                            # Write the current batch
                            write_batch(batch)
                            # Reset the batch
                            batch = []
                            batch_size = 0

                        # Add the entry to the batch
                        batch.append(request_obj)
                        batch_size += request_size

                    pb.update(1)
                except Exception as e:
                    print(f"Error processing a PDF: {str(e)}")

        # Write any remaining batch
        write_batch(batch)

    # Print the number of PDFs that resulted in at least one output
    print(f"Number of sampled PDFs that produced at least one output: {pdfs_with_output}")
    print(f"Now you should run these prompts through mise/birr")

if __name__ == "__main__":
    main()