Adjusting workflow so I can do s2 pdfs

Jake Poznanski 2024-10-15 16:22:55 +00:00
parent 6d61ae4aa8
commit 4669eb7134


@@ -10,6 +10,7 @@ import datetime
 import posixpath
 import threading
 import logging
+import boto3.session
 import urllib3.exceptions
 from dataclasses import dataclass
@@ -25,7 +26,8 @@ from pdelfin.prompts import build_finetuning_prompt
 from pdelfin.prompts.anchor import get_anchor_text
 
 # Global s3 client for the whole script, feel free to adjust params if you need it
-s3 = boto3.client('s3')
+workspace_s3 = boto3.client('s3')
+pdf_s3 = boto3.client('s3')
 
 # Quiet logs from pypdf and smart open
 logging.getLogger("pypdf").setLevel(logging.ERROR)
@@ -338,13 +340,13 @@ def parse_s3_path(s3_path):
     bucket, _, prefix = path.partition('/')
     return bucket, prefix
 
-def expand_s3_glob(s3_glob: str) -> Dict[str, str]:
+def expand_s3_glob(s3_client, s3_glob: str) -> Dict[str, str]:
     parsed = urlparse(s3_glob)
     bucket_name = parsed.netloc
     prefix = os.path.dirname(parsed.path.lstrip('/')).rstrip('/') + "/"
     pattern = os.path.basename(parsed.path)
 
-    paginator = s3.get_paginator('list_objects_v2')
+    paginator = s3_client.get_paginator('list_objects_v2')
     page_iterator = paginator.paginate(Bucket=bucket_name, Prefix=prefix)
 
     matched_files = {}
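
expand_s3_glob now takes its S3 client explicitly, so the caller decides which credentials the listing runs under. A minimal sketch of calling it, assuming expand_s3_glob from this script is in scope; the bucket and prefix are hypothetical, and as used later in the script the result maps matched s3 paths to their ETags:

    import boto3

    pdf_s3 = boto3.client("s3")  # could instead come from a profile-scoped boto3.Session
    # Hypothetical glob; the returned dict maps "s3://..." paths to ETag strings
    matched = expand_s3_glob(pdf_s3, "s3://example-bucket/papers/*.pdf")
    for path, etag in matched.items():
        print(path, etag)
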
@@ -374,7 +376,7 @@ def build_page_query(local_pdf_path: str, pretty_pdf_path: str, page: int) -> di
     }
 
-def get_s3_bytes(s3_path: str, start_index: Optional[int] = None, end_index: Optional[int] = None) -> bytes:
+def get_s3_bytes(s3_client, s3_path: str, start_index: Optional[int] = None, end_index: Optional[int] = None) -> bytes:
     bucket, key = parse_s3_path(s3_path)
 
     # Build the range header if start_index and/or end_index are specified
@@ -393,9 +395,9 @@ def get_s3_bytes(s3_path: str, start_index: Optional[int] = None, end_index: Opt
         range_header = {'Range': range_value}
 
     if range_header:
-        obj = s3.get_object(Bucket=bucket, Key=key, Range=range_header['Range'])
+        obj = s3_client.get_object(Bucket=bucket, Key=key, Range=range_header['Range'])
     else:
-        obj = s3.get_object(Bucket=bucket, Key=key)
+        obj = s3_client.get_object(Bucket=bucket, Key=key)
 
     return obj['Body'].read()
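
get_s3_bytes likewise takes the client as its first argument, so ranged reads of workspace artifacts and full reads of raw PDFs can go through different credentials. A small sketch of a ranged read, assuming get_s3_bytes from this script is in scope; the path and offsets are made up (in the script, offsets come from the batch-inference index):

    import boto3

    workspace_s3 = boto3.client("s3")
    # Pull a single byte range out of a JSONL output file in the workspace
    row_bytes = get_s3_bytes(workspace_s3,
                             "s3://example-workspace/inference_outputs/example.jsonl",
                             start_index=0, end_index=511)
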
@@ -406,7 +408,7 @@ def parse_custom_id(custom_id: str) -> Tuple[str, int]:
     return s3_path, page_num
 
 def process_jsonl_content(inference_s3_path: str) -> List[DatabaseManager.BatchInferenceRecord]:
-    content_bytes = get_s3_bytes(inference_s3_path)
+    content_bytes = get_s3_bytes(workspace_s3, inference_s3_path)
     start_index = 0
     index_entries = []
 
@@ -462,7 +464,7 @@ def process_jsonl_content(inference_s3_path: str) -> List[DatabaseManager.BatchI
 def get_pdf_num_pages(s3_path: str) -> Optional[int]:
     try:
         with tempfile.NamedTemporaryFile("wb+", suffix=".pdf") as tf:
-            tf.write(get_s3_bytes(s3_path))
+            tf.write(get_s3_bytes(pdf_s3, s3_path))
             tf.flush()
 
             reader = PdfReader(tf.name)
@@ -484,7 +486,7 @@ def build_pdf_queries(s3_workspace: str, pdf: DatabaseManager.PDFRecord, cur_rou
     try:
         with tempfile.NamedTemporaryFile("wb+", suffix=".pdf") as tf:
-            tf.write(get_s3_bytes(pdf.s3_path))
+            tf.write(get_s3_bytes(pdf_s3, pdf.s3_path))
             tf.flush()
 
             for target_page_num in range(1, pdf.num_pages + 1):
@@ -522,9 +524,9 @@ def build_dolma_doc(s3_workspace: str, pdf: DatabaseManager.PDFRecord) -> Option
         target_page = target_pages[0]
 
-        target_row = get_s3_bytes(target_page.inference_s3_path,
-                                  start_index=target_page.start_index,
-                                  end_index=target_page.start_index+target_page.length - 1)
+        target_row = get_s3_bytes(workspace_s3, target_page.inference_s3_path,
+                                  start_index=target_page.start_index,
+                                  end_index=target_page.start_index+target_page.length - 1)
 
         target_data = json.loads(target_row.decode("utf-8"))
@@ -565,7 +567,7 @@ def mark_pdfs_done(s3_workspace: str, dolma_doc_lines: list[str]):
 def get_current_round(s3_workspace: str) -> int:
     bucket, prefix = parse_s3_path(s3_workspace)
     inference_inputs_prefix = posixpath.join(prefix, 'inference_inputs/')
-    paginator = s3.get_paginator('list_objects_v2')
+    paginator = workspace_s3.get_paginator('list_objects_v2')
     page_iterator = paginator.paginate(Bucket=bucket, Prefix=inference_inputs_prefix, Delimiter='/')
 
     round_numbers = []
@@ -590,10 +592,20 @@ def get_current_round(s3_workspace: str) -> int:
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='Manager for running millions of PDFs through a batch inference pipeline')
     parser.add_argument('workspace', help='The S3 path where work will be done e.g., s3://bucket/prefix/)')
-    parser.add_argument('--add_pdfs', help='Glob path to add PDFs (s3) to the workspace', default=None)
+    parser.add_argument('--add_pdfs', help='Path to add pdfs stored in s3 to the workspace, can be a glob path s3://bucket/prefix/*.pdf or path to file containing list of pdf paths', default=None)
+    parser.add_argument('--workspace_profile', help='S3 configuration profile for accessing the workspace', default=None)
+    parser.add_argument('--pdf_profile', help='S3 configuration profile for accessing the raw pdf documents', default=None)
     parser.add_argument('--max_size_mb', type=int, default=250, help='Max file size in MB')
     args = parser.parse_args()
 
+    if args.workspace_profile:
+        workspace_session = boto3.Session(profile_name=args.workspace_profile)
+        workspace_s3 = workspace_session.resource("s3")
+
+    if args.pdf_profile:
+        pdf_session = boto3.Session(profile_name=args.pdf_profile)
+        pdf_s3 = pdf_session.resource("s3")
+
     db = DatabaseManager(args.workspace)
     print(f"Loaded db at {db.db_path}")
@@ -605,12 +617,16 @@ if __name__ == '__main__':
 
     # If you have new PDFs, step one is to add them to the list
     if args.add_pdfs:
-        assert args.add_pdfs.startswith("s3://"), "PDFs must live on s3"
-
-        print(f"Querying all PDFs at {args.add_pdfs}")
-
-        all_pdfs = expand_s3_glob(args.add_pdfs)
-        print(f"Found {len(all_pdfs):,} total pdf paths")
+        if args.add_pdfs.startswith("s3://"):
+            print(f"Querying all PDFs at {args.add_pdfs}")
+            all_pdfs = expand_s3_glob(pdf_s3, args.add_pdfs)
+            print(f"Found {len(all_pdfs):,} total pdf paths")
+        elif os.path.exists(args.add_pdfs):
+            with open(args.add_pdfs, "r") as f:
+                all_pdfs = [line for line in f.readlines() if len(line.strip()) > 0]
+        else:
+            raise ValueError("add_pdfs argument needs to be either an s3 glob search path, or a local file containing pdf paths (one per line)")
 
         all_pdfs = [pdf for pdf in all_pdfs if not db.pdf_exists(pdf)]
         print(f"Need to import {len(all_pdfs):,} total new pdf paths")
@@ -626,7 +642,7 @@ if __name__ == '__main__':
     # Now build an index of all the pages that were processed within the workspace so far
     print("Indexing all batch inference sent to this workspace")
-    inference_output_paths = expand_s3_glob(f"{args.workspace}/inference_outputs/*.jsonl")
+    inference_output_paths = expand_s3_glob(workspace_s3, f"{args.workspace}/inference_outputs/*.jsonl")
 
     inference_output_paths = [
         (s3_path, etag) for s3_path, etag in inference_output_paths.items()