diff --git a/pdelfin/assemblepipeline.py b/pdelfin/birrpipeline.py
similarity index 98%
rename from pdelfin/assemblepipeline.py
rename to pdelfin/birrpipeline.py
index 930ae9a..ae5fb8c 100644
--- a/pdelfin/assemblepipeline.py
+++ b/pdelfin/birrpipeline.py
@@ -67,7 +67,7 @@ class DatabaseManager:
             )
         """)
         self.cursor.execute("""
-            CREATE INDEX IF NOT EXISTS idx_path ON page_results(s3_path)
+            CREATE INDEX IF NOT EXISTS idx_path ON page_results(pdf_s3_path)
         """)
         self.cursor.execute("""
             CREATE TABLE IF NOT EXISTS pdfs (
@@ -122,13 +122,13 @@ class DatabaseManager:
         """, [(entry.inference_s3_path, entry.pdf_s3_path, entry.page_num, entry.start_index, entry.length, entry.finish_reason, entry.error) for entry in index_entries])
         self.conn.commit()
 
-    def get_index_entries(self, s3_path: str) -> List[BatchInferenceRecord]:
+    def get_index_entries(self, pdf_s3_path: str) -> List[BatchInferenceRecord]:
         self.cursor.execute("""
             SELECT inference_s3_path, pdf_s3_path, page_num, start_index, length, finish_reason, error
             FROM page_results
-            WHERE s3_path = ?
-            ORDER BY inference_s3_path DESC start_index ASC page_num ASC
-        """, (s3_path,))
+            WHERE pdf_s3_path = ?
+            ORDER BY inference_s3_path DESC, start_index ASC, page_num ASC
+        """, (pdf_s3_path,))
 
         rows = self.cursor.fetchall()
 
diff --git a/pdelfin/runpipeline.py b/pdelfin/runpipeline.py
deleted file mode 100644
index 82f9968..0000000
--- a/pdelfin/runpipeline.py
+++ /dev/null
@@ -1,275 +0,0 @@
-# The way this script works is it gets a list of pdfs to process
-# and an output/scratch folder location either locally or in s3 to work with
-# On the first run, with an empty output folder, it will queue up each page in each pdf to go into a VLM
-# Then, the user queues up that task in birr, and it outputs to a new subfolder in the same location
-# Then, you run your script again, and it will see that you have some valid output files
-# If so, then it will check those output files, and if it has a complete document, it will build a dolma doc for it, and that's considered done
-# For any remaining pages that got errored out, or failed due to stop_reason not being "stop" (ex. over length)
-# Then, it will queue up another set of tasks, hopefully much smaller, to send into batch inference again
-# This process will keep going, until you run it with the --fallback option, at which point it will
-# just use a basic text extraction on any remaining pages, and assemble the rest of the dolma docs
-#
-#
-#
-import os
-import glob
-import random
-import argparse
-import boto3
-import json
-import hashlib
-from pypdf import PdfReader
-from tqdm import tqdm
-from typing import Generator, List
-from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed
-from urllib.parse import urlparse
-
-from pdelfin.data.renderpdf import render_pdf_to_base64png
-from pdelfin.prompts import build_finetuning_prompt
-from pdelfin.prompts.anchor import get_anchor_text
-from pdelfin.filter import PdfFilter
-
-import logging
-import smart_open
-import posixpath # Import posixpath for S3 path handling
-
-logging.getLogger("pypdf").setLevel(logging.ERROR)
-
-pdf_filter = PdfFilter()
-
-def build_page_query(local_pdf_path: str, pretty_pdf_path: str, page: int) -> dict:
-    image_base64 = render_pdf_to_base64png(local_pdf_path, page, 1024)
-    anchor_text = get_anchor_text(local_pdf_path, page, pdf_engine="pdfreport")
-
-    return {
-        "custom_id": f"{pretty_pdf_path}-{page}",
-        "chat_messages": [
-            {
-                "role": "user",
-                "content": [
-                    {"type": "text", "text": build_finetuning_prompt(anchor_text)},
-                    {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}}
-                ],
-            }
-        ],
-        "temperature": 0.1,
-        "max_tokens": 6000,
-    }
-
-def fetch_s3_file(s3_url: str, local_path: str) -> str:
-    parsed = urlparse(s3_url)
-    bucket_name = parsed.netloc
-    key = parsed.path.lstrip('/')
-
-    s3 = boto3.client('s3')
-    s3.download_file(bucket_name, key, local_path)
-    return local_path
-
-def process_pdf(pdf_path: str, no_filter: bool) -> List[dict]:
-    if pdf_path.startswith("s3://"):
-        local_pdf_path = os.path.join("/tmp", os.path.basename(pdf_path))
-        fetch_s3_file(pdf_path, local_pdf_path)
-    else:
-        local_pdf_path = pdf_path
-
-    if (not no_filter) and pdf_filter.filter_out_pdf(local_pdf_path):
-        print(f"Skipping {local_pdf_path} due to common filter")
-        return []
-
-    pretty_pdf_path = pdf_path
-
-    pdf = PdfReader(local_pdf_path)
-    num_pages = len(pdf.pages)
-
-    sample_pages = list(range(1, num_pages + 1))
-    result = []
-    for page in sample_pages:
-        try:
-            query = build_page_query(local_pdf_path, pretty_pdf_path, page)
-            result.append(query)
-        except Exception as e:
-            print(f"Error processing page {page} of {pdf_path}: {e}")
-
-    return result
-
-def is_glob_pattern(path: str) -> bool:
-    return any(char in path for char in ['*', '?', '[', ']'])
-
-def expand_s3_glob(s3_glob: str) -> list:
-    parsed = urlparse(s3_glob)
-    bucket_name = parsed.netloc
-    prefix = os.path.dirname(parsed.path.lstrip('/')).rstrip('/') + "/"
-    pattern = os.path.basename(parsed.path)
-
-    s3 = boto3.client('s3')
-    paginator = s3.get_paginator('list_objects_v2')
-    page_iterator = paginator.paginate(Bucket=bucket_name, Prefix=prefix)
-
-    matched_files = []
-    for page in page_iterator:
-        for obj in page.get('Contents', []):
-            key = obj['Key']
-            if key.endswith('.pdf') and glob.fnmatch.fnmatch(key, posixpath.join(prefix, pattern)):
-                matched_files.append(f"s3://{bucket_name}/{key}")
-
-    return matched_files
-
-def compute_hash(content: str) -> str:
-    """Compute a 20-character SHA1 hash of the given content."""
-    sha1 = hashlib.sha1()
-    sha1.update(content.encode('utf-8'))
-    return sha1.hexdigest()[:20]
-
-def get_smart_open_write_path(output_path: str, hash_str: str) -> str:
-    """Generate the full output path with hash in the filename."""
-    parsed = urlparse(output_path)
-    if parsed.scheme in ('s3', 's3a', 's3n'):
-        bucket = parsed.netloc
-        key = parsed.path.lstrip('/')
-        # Ensure the key is treated as a directory by appending a slash if not present
-        if key and not key.endswith('/'):
-            key += '/'
-        # Use posixpath to correctly join S3 paths
-        full_key = posixpath.join(key, f"output_{hash_str}.jsonl")
-        return f"s3://{bucket}/{full_key}"
-    else:
-        dir_path = output_path
-        filename = f"output_{hash_str}.jsonl"
-        return os.path.join(dir_path, filename)
-
-def main():
-    parser = argparse.ArgumentParser(
-        description="Given a bunch of PDFs, prepares a mise/birr workflow to run them through a conversion mechanism"
-    )
-    parser.add_argument(
-        "pdf_paths",
-        nargs='*',
-        help=(
-            "List of PDF paths to process. If a single argument contains glob patterns (e.g., *.pdf or s3://bucket/pdfs/*.pdf), "
-            "it will be expanded accordingly."
-        )
-    )
-    parser.add_argument(
-        "--path_list",
-        type=str,
-        help="Path to a file containing paths to PDFs, one per line."
-    )
-    parser.add_argument(
-        "--max_size_mb",
-        type=int,
-        default=250,
-        help="Max number of MBs of entries to put in each birr workitem"
-    )
-    parser.add_argument(
-        "--no_filter",
-        action="store_true",
-        help="Disables the basic spam/language filtering so that ALL pdfs listed are used"
-    )
-    parser.add_argument(
-        "--output",
-        type=str,
-        default="mise_batch_data",
-        help="Output destination (can be a local path or an S3 URI)"
-    )
-    args = parser.parse_args()
-
-    pdf_paths = []
-
-    # Load PDF paths from positional arguments or path_list
-    if args.pdf_paths:
-        for path in args.pdf_paths:
-            if is_glob_pattern(path):
-                glob_path = path
-                if glob_path.startswith("s3://"):
-                    # Handle S3 globbing
-                    expanded_paths = expand_s3_glob(glob_path)
-                    pdf_paths.extend(expanded_paths)
-                else:
-                    # Handle local filesystem globbing
-                    expanded_paths = glob.glob(glob_path, recursive=True)
-                    pdf_paths.extend(expanded_paths)
-            else:
-                pdf_paths.append(path)
-
-    if args.path_list:
-        with open(args.path_list, 'r') as f:
-            for line in f:
-                path = line.strip()
-                if path:
-                    pdf_paths.append(path)
-
-    # Remove duplicates and shuffle
-    pdf_paths = list(set(pdf_paths))
-    random.shuffle(pdf_paths)
-
-    print(f"Loaded and shuffled {len(pdf_paths)} paths to use.")
-
-    # Prepare for output
-    output_dir = args.output
-    max_file_size = args.max_size_mb * 1024 * 1024 # Convert MB to bytes
-
-    # Determine if output is S3
-    parsed_output = urlparse(output_dir)
-    is_s3 = parsed_output.scheme in ('s3', 's3a', 's3n')
-
-    # Initialize variables for batching
-    batch = []
-    batch_size = 0
-    pdfs_with_output = 0
-
-    # Function to write a batch
-    def write_batch(batch: List[dict]):
-        nonlocal output_dir
-        if not batch:
-            return
-        batch_content = "\n".join(json.dumps(entry) for entry in batch) + "\n"
-        hash_str = compute_hash(batch_content)
-        output_path_with_hash = get_smart_open_write_path(output_dir, hash_str)
-        with smart_open.open(output_path_with_hash, 'w') as f_out:
-            f_out.write(batch_content)
-        print(f"Wrote batch to {output_path_with_hash}")
-
-    # Using ProcessPoolExecutor to process files concurrently
-    with ProcessPoolExecutor() as executor:
-        futures = []
-
-        with tqdm(desc="Processing PDFs", leave=False, total=len(pdf_paths)) as pb:
-            for pdf_path in pdf_paths:
-                futures.append(executor.submit(process_pdf, pdf_path, args.no_filter))
-
-            for future in as_completed(futures):
-                try:
-                    request_results = future.result() # Get the result from the process
-
-                    if request_results:
-                        pdfs_with_output += 1 # Increment if there's at least one result
-
-                    for request_obj in request_results:
-                        request_json = json.dumps(request_obj)
-                        request_size = len(request_json.encode('utf-8')) + 1 # +1 for newline
-
-                        # Check if adding this entry would exceed the max size
-                        if batch_size + request_size > max_file_size:
-                            # Write the current batch
-                            write_batch(batch)
-                            # Reset the batch
-                            batch = []
-                            batch_size = 0
-
-                        # Add the entry to the batch
-                        batch.append(request_obj)
-                        batch_size += request_size
-
-                    pb.update(1)
-                except Exception as e:
-                    print(f"Error processing a PDF: {str(e)}")
-
-    # Write any remaining batch
-    write_batch(batch)
-
-    # Print the number of PDFs that resulted in at least one output
-    print(f"Number of sampled PDFs that produced at least one output: {pdfs_with_output}")
-    print(f"Now you should run these prompts through mise/birr")
-
-if __name__ == "__main__":
-    main()