From 4669eb71344c6c7b078359d58c202570fe995d15 Mon Sep 17 00:00:00 2001
From: Jake Poznanski
Date: Tue, 15 Oct 2024 16:22:55 +0000
Subject: [PATCH] Adjusting workflow so I can do s2 pdfs

---
 pdelfin/birrpipeline.py | 58 ++++++++++++++++++++++++++---------------
 1 file changed, 37 insertions(+), 21 deletions(-)

diff --git a/pdelfin/birrpipeline.py b/pdelfin/birrpipeline.py
index 5ea818b..5b578f7 100644
--- a/pdelfin/birrpipeline.py
+++ b/pdelfin/birrpipeline.py
@@ -10,6 +10,7 @@ import datetime
 import posixpath
 import threading
 import logging
+import boto3.session
 import urllib3.exceptions
 
 from dataclasses import dataclass
@@ -25,7 +26,8 @@ from pdelfin.prompts import build_finetuning_prompt
 from pdelfin.prompts.anchor import get_anchor_text
 
 # Global s3 client for the whole script, feel free to adjust params if you need it
-s3 = boto3.client('s3')
+workspace_s3 = boto3.client('s3')
+pdf_s3 = boto3.client('s3')
 
 # Quiet logs from pypdf and smart open
 logging.getLogger("pypdf").setLevel(logging.ERROR)
@@ -338,13 +340,13 @@ def parse_s3_path(s3_path):
     bucket, _, prefix = path.partition('/')
     return bucket, prefix
 
-def expand_s3_glob(s3_glob: str) -> Dict[str, str]:
+def expand_s3_glob(s3_client, s3_glob: str) -> Dict[str, str]:
     parsed = urlparse(s3_glob)
     bucket_name = parsed.netloc
     prefix = os.path.dirname(parsed.path.lstrip('/')).rstrip('/') + "/"
     pattern = os.path.basename(parsed.path)
 
-    paginator = s3.get_paginator('list_objects_v2')
+    paginator = s3_client.get_paginator('list_objects_v2')
     page_iterator = paginator.paginate(Bucket=bucket_name, Prefix=prefix)
 
     matched_files = {}
@@ -374,7 +376,7 @@ def build_page_query(local_pdf_path: str, pretty_pdf_path: str, page: int) -> di
     }
 
 
-def get_s3_bytes(s3_path: str, start_index: Optional[int] = None, end_index: Optional[int] = None) -> bytes:
+def get_s3_bytes(s3_client, s3_path: str, start_index: Optional[int] = None, end_index: Optional[int] = None) -> bytes:
     bucket, key = parse_s3_path(s3_path)
 
     # Build the range header if start_index and/or end_index are specified
@@ -393,9 +395,9 @@ def get_s3_bytes(s3_path: str, start_index: Optional[int] = None, end_index: Opt
         range_header = {'Range': range_value}
 
     if range_header:
-        obj = s3.get_object(Bucket=bucket, Key=key, Range=range_header['Range'])
+        obj = s3_client.get_object(Bucket=bucket, Key=key, Range=range_header['Range'])
     else:
-        obj = s3.get_object(Bucket=bucket, Key=key)
+        obj = s3_client.get_object(Bucket=bucket, Key=key)
 
     return obj['Body'].read()
 
@@ -406,7 +408,7 @@ def parse_custom_id(custom_id: str) -> Tuple[str, int]:
     return s3_path, page_num
 
 def process_jsonl_content(inference_s3_path: str) -> List[DatabaseManager.BatchInferenceRecord]:
-    content_bytes = get_s3_bytes(inference_s3_path)
+    content_bytes = get_s3_bytes(workspace_s3, inference_s3_path)
     start_index = 0
     index_entries = []
 
@@ -462,7 +464,7 @@ def process_jsonl_content(inference_s3_path: str) -> List[DatabaseManager.BatchI
 def get_pdf_num_pages(s3_path: str) -> Optional[int]:
     try:
         with tempfile.NamedTemporaryFile("wb+", suffix=".pdf") as tf:
-            tf.write(get_s3_bytes(s3_path))
+            tf.write(get_s3_bytes(pdf_s3, s3_path))
             tf.flush()
 
             reader = PdfReader(tf.name)
@@ -484,7 +486,7 @@ def build_pdf_queries(s3_workspace: str, pdf: DatabaseManager.PDFRecord, cur_rou
 
     try:
         with tempfile.NamedTemporaryFile("wb+", suffix=".pdf") as tf:
-            tf.write(get_s3_bytes(pdf.s3_path))
+            tf.write(get_s3_bytes(pdf_s3, pdf.s3_path))
             tf.flush()
 
             for target_page_num in range(1, pdf.num_pages + 1):
@@ -522,9 +524,9 @@ def build_dolma_doc(s3_workspace: str, pdf: DatabaseManager.PDFRecord) -> Option
 
     target_page = target_pages[0]
 
-    target_row = get_s3_bytes(target_page.inference_s3_path,
-                              start_index=target_page.start_index,
-                              end_index=target_page.start_index+target_page.length - 1)
+    target_row = get_s3_bytes(workspace_s3, target_page.inference_s3_path,
+                              start_index=target_page.start_index,
+                              end_index=target_page.start_index+target_page.length - 1)
 
     target_data = json.loads(target_row.decode("utf-8"))
 
@@ -565,7 +567,7 @@ def mark_pdfs_done(s3_workspace: str, dolma_doc_lines: list[str]):
 def get_current_round(s3_workspace: str) -> int:
     bucket, prefix = parse_s3_path(s3_workspace)
     inference_inputs_prefix = posixpath.join(prefix, 'inference_inputs/')
-    paginator = s3.get_paginator('list_objects_v2')
+    paginator = workspace_s3.get_paginator('list_objects_v2')
     page_iterator = paginator.paginate(Bucket=bucket, Prefix=inference_inputs_prefix, Delimiter='/')
 
     round_numbers = []
@@ -590,10 +592,20 @@ def get_current_round(s3_workspace: str) -> int:
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='Manager for running millions of PDFs through a batch inference pipeline')
     parser.add_argument('workspace', help='The S3 path where work will be done e.g., s3://bucket/prefix/)')
-    parser.add_argument('--add_pdfs', help='Glob path to add PDFs (s3) to the workspace', default=None)
+    parser.add_argument('--add_pdfs', help='S3 pdfs to add to the workspace: either a glob path such as s3://bucket/prefix/*.pdf, or a local file listing one pdf path per line', default=None)
+    parser.add_argument('--workspace_profile', help='S3 configuration profile for accessing the workspace', default=None)
+    parser.add_argument('--pdf_profile', help='S3 configuration profile for accessing the raw pdf documents', default=None)
     parser.add_argument('--max_size_mb', type=int, default=250, help='Max file size in MB')
     args = parser.parse_args()
 
+    if args.workspace_profile:
+        workspace_session = boto3.Session(profile_name=args.workspace_profile)
+        workspace_s3 = workspace_session.client("s3")
+
+    if args.pdf_profile:
+        pdf_session = boto3.Session(profile_name=args.pdf_profile)
+        pdf_s3 = pdf_session.client("s3")
+
     db = DatabaseManager(args.workspace)
     print(f"Loaded db at {db.db_path}")
 
@@ -605,12 +617,16 @@ if __name__ == '__main__':
 
     # If you have new PDFs, step one is to add them to the list
     if args.add_pdfs:
-        assert args.add_pdfs.startswith("s3://"), "PDFs must live on s3"
-
-        print(f"Querying all PDFs at {args.add_pdfs}")
-
-        all_pdfs = expand_s3_glob(args.add_pdfs)
-        print(f"Found {len(all_pdfs):,} total pdf paths")
+        if args.add_pdfs.startswith("s3://"):
+            print(f"Querying all PDFs at {args.add_pdfs}")
+
+            all_pdfs = expand_s3_glob(pdf_s3, args.add_pdfs)
+            print(f"Found {len(all_pdfs):,} total pdf paths")
+        elif os.path.exists(args.add_pdfs):
+            with open(args.add_pdfs, "r") as f:
+                all_pdfs = [line.strip() for line in f.readlines() if len(line.strip()) > 0]
+        else:
+            raise ValueError("add_pdfs argument needs to be either an s3 glob search path, or a local file containing pdf paths (one per line)")
 
         all_pdfs = [pdf for pdf in all_pdfs if not db.pdf_exists(pdf)]
         print(f"Need to import {len(all_pdfs):,} total new pdf paths")
@@ -626,7 +642,7 @@ if __name__ == '__main__':
 
     # Now build an index of all the pages that were processed within the workspace so far
     print("Indexing all batch inference sent to this workspace")
-    inference_output_paths = expand_s3_glob(f"{args.workspace}/inference_outputs/*.jsonl")
f"{args.workspace}/inference_outputs/*.jsonl") inference_output_paths = [ (s3_path, etag) for s3_path, etag in inference_output_paths.items()