Adjusting workflow so I can do s2 pdfs

Jake Poznanski 2024-10-15 16:22:55 +00:00
parent 6d61ae4aa8
commit 4669eb7134


@@ -10,6 +10,7 @@ import datetime
 import posixpath
 import threading
 import logging
+import boto3.session
 import urllib3.exceptions
 from dataclasses import dataclass
@@ -25,7 +26,8 @@ from pdelfin.prompts import build_finetuning_prompt
 from pdelfin.prompts.anchor import get_anchor_text
 # Global s3 client for the whole script, feel free to adjust params if you need it
-s3 = boto3.client('s3')
+workspace_s3 = boto3.client('s3')
+pdf_s3 = boto3.client('s3')
 # Quiet logs from pypdf and smart open
 logging.getLogger("pypdf").setLevel(logging.ERROR)
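
The single module-level client is split in two so that the workspace and the raw PDFs can be read with different credentials. A minimal sketch, not part of the commit, of swapping either default for a client built from a named AWS profile ("pdf-account" is a placeholder profile name):

    import boto3

    # Illustrative only: bind an S3 client to a specific profile from ~/.aws/config.
    pdf_session = boto3.Session(profile_name="pdf-account")
    pdf_s3 = pdf_session.client("s3")
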
@@ -338,13 +340,13 @@ def parse_s3_path(s3_path):
     bucket, _, prefix = path.partition('/')
     return bucket, prefix
-def expand_s3_glob(s3_glob: str) -> Dict[str, str]:
+def expand_s3_glob(s3_client, s3_glob: str) -> Dict[str, str]:
     parsed = urlparse(s3_glob)
     bucket_name = parsed.netloc
     prefix = os.path.dirname(parsed.path.lstrip('/')).rstrip('/') + "/"
     pattern = os.path.basename(parsed.path)
-    paginator = s3.get_paginator('list_objects_v2')
+    paginator = s3_client.get_paginator('list_objects_v2')
     page_iterator = paginator.paginate(Bucket=bucket_name, Prefix=prefix)
     matched_files = {}
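
The matching loop that follows is outside this hunk; a hedged sketch of how it could consume the paginator, assuming fnmatch-style glob semantics (the real loop may differ):

    import fnmatch
    import posixpath

    # Hypothetical continuation: keep keys whose basename matches the glob's
    # pattern and map each full s3:// path to the object's ETag.
    for page in page_iterator:
        for obj in page.get('Contents', []):
            if fnmatch.fnmatch(posixpath.basename(obj['Key']), pattern):
                matched_files[f"s3://{bucket_name}/{obj['Key']}"] = obj['ETag'].strip('"')
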
@@ -374,7 +376,7 @@ def build_page_query(local_pdf_path: str, pretty_pdf_path: str, page: int) -> di
     }
-def get_s3_bytes(s3_path: str, start_index: Optional[int] = None, end_index: Optional[int] = None) -> bytes:
+def get_s3_bytes(s3_client, s3_path: str, start_index: Optional[int] = None, end_index: Optional[int] = None) -> bytes:
     bucket, key = parse_s3_path(s3_path)
     # Build the range header if start_index and/or end_index are specified
@@ -393,9 +395,9 @@ def get_s3_bytes(s3_path: str, start_index: Optional[int] = None, end_index: Opt
         range_header = {'Range': range_value}
     if range_header:
-        obj = s3.get_object(Bucket=bucket, Key=key, Range=range_header['Range'])
+        obj = s3_client.get_object(Bucket=bucket, Key=key, Range=range_header['Range'])
     else:
-        obj = s3.get_object(Bucket=bucket, Key=key)
+        obj = s3_client.get_object(Bucket=bucket, Key=key)
     return obj['Body'].read()
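
With the client now passed explicitly, each caller chooses which credentials to read with. An illustrative ranged read (bucket and key are placeholders) fetching the first kilobyte of an object via the pdf client:

    # end_index is inclusive, mirroring the HTTP Range header built above.
    header_bytes = get_s3_bytes(pdf_s3, "s3://example-bucket/docs/sample.pdf",
                                start_index=0, end_index=1023)
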
@@ -406,7 +408,7 @@ def parse_custom_id(custom_id: str) -> Tuple[str, int]:
     return s3_path, page_num
 def process_jsonl_content(inference_s3_path: str) -> List[DatabaseManager.BatchInferenceRecord]:
-    content_bytes = get_s3_bytes(inference_s3_path)
+    content_bytes = get_s3_bytes(workspace_s3, inference_s3_path)
     start_index = 0
     index_entries = []
@@ -462,7 +464,7 @@ def process_jsonl_content(inference_s3_path: str) -> List[DatabaseManager.BatchI
 def get_pdf_num_pages(s3_path: str) -> Optional[int]:
     try:
         with tempfile.NamedTemporaryFile("wb+", suffix=".pdf") as tf:
-            tf.write(get_s3_bytes(s3_path))
+            tf.write(get_s3_bytes(pdf_s3, s3_path))
             tf.flush()
             reader = PdfReader(tf.name)
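
The rest of this helper falls outside the hunk; it presumably returns the page count, which with pypdf is simply (hypothetical continuation, not shown in the commit):

    # Inside the with-block, after constructing the reader:
    return len(reader.pages)
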
@@ -484,7 +486,7 @@ def build_pdf_queries(s3_workspace: str, pdf: DatabaseManager.PDFRecord, cur_rou
     try:
         with tempfile.NamedTemporaryFile("wb+", suffix=".pdf") as tf:
-            tf.write(get_s3_bytes(pdf.s3_path))
+            tf.write(get_s3_bytes(pdf_s3, pdf.s3_path))
             tf.flush()
             for target_page_num in range(1, pdf.num_pages + 1):
@@ -522,9 +524,9 @@ def build_dolma_doc(s3_workspace: str, pdf: DatabaseManager.PDFRecord) -> Option
         target_page = target_pages[0]
-        target_row = get_s3_bytes(target_page.inference_s3_path,
-                                  start_index=target_page.start_index,
-                                  end_index=target_page.start_index+target_page.length - 1)
+        target_row = get_s3_bytes(workspace_s3, target_page.inference_s3_path,
+                                  start_index=target_page.start_index,
+                                  end_index=target_page.start_index+target_page.length - 1)
         target_data = json.loads(target_row.decode("utf-8"))
@@ -565,7 +567,7 @@ def mark_pdfs_done(s3_workspace: str, dolma_doc_lines: list[str]):
 def get_current_round(s3_workspace: str) -> int:
     bucket, prefix = parse_s3_path(s3_workspace)
     inference_inputs_prefix = posixpath.join(prefix, 'inference_inputs/')
-    paginator = s3.get_paginator('list_objects_v2')
+    paginator = workspace_s3.get_paginator('list_objects_v2')
     page_iterator = paginator.paginate(Bucket=bucket, Prefix=inference_inputs_prefix, Delimiter='/')
     round_numbers = []
@@ -590,10 +592,20 @@ def get_current_round(s3_workspace: str) -> int:
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='Manager for running millions of PDFs through a batch inference pipeline')
     parser.add_argument('workspace', help='The S3 path where work will be done (e.g., s3://bucket/prefix/)')
-    parser.add_argument('--add_pdfs', help='Glob path to add PDFs (s3) to the workspace', default=None)
+    parser.add_argument('--add_pdfs', help='PDFs to add to the workspace: either an s3 glob path such as s3://bucket/prefix/*.pdf, or a local file listing one pdf path per line', default=None)
+    parser.add_argument('--workspace_profile', help='S3 configuration profile for accessing the workspace', default=None)
+    parser.add_argument('--pdf_profile', help='S3 configuration profile for accessing the raw pdf documents', default=None)
     parser.add_argument('--max_size_mb', type=int, default=250, help='Max file size in MB')
     args = parser.parse_args()
+    if args.workspace_profile:
+        workspace_session = boto3.Session(profile_name=args.workspace_profile)
+        workspace_s3 = workspace_session.client("s3")
+    if args.pdf_profile:
+        pdf_session = boto3.Session(profile_name=args.pdf_profile)
+        pdf_s3 = pdf_session.client("s3")
     db = DatabaseManager(args.workspace)
     print(f"Loaded db at {db.db_path}")
@@ -605,12 +617,16 @@ if __name__ == '__main__':
     # If you have new PDFs, step one is to add them to the list
     if args.add_pdfs:
-        assert args.add_pdfs.startswith("s3://"), "PDFs must live on s3"
-        print(f"Querying all PDFs at {args.add_pdfs}")
-        all_pdfs = expand_s3_glob(args.add_pdfs)
-        print(f"Found {len(all_pdfs):,} total pdf paths")
+        if args.add_pdfs.startswith("s3://"):
+            print(f"Querying all PDFs at {args.add_pdfs}")
+            all_pdfs = expand_s3_glob(pdf_s3, args.add_pdfs)
+            print(f"Found {len(all_pdfs):,} total pdf paths")
+        elif os.path.exists(args.add_pdfs):
+            # Strip trailing newlines so each entry is a clean s3 path
+            with open(args.add_pdfs, "r") as f:
+                all_pdfs = [line.strip() for line in f if len(line.strip()) > 0]
+        else:
+            raise ValueError("add_pdfs must be either an s3 glob search path or a local file containing pdf paths (one per line)")
         all_pdfs = [pdf for pdf in all_pdfs if not db.pdf_exists(pdf)]
         print(f"Need to import {len(all_pdfs):,} total new pdf paths")
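
For the new file-list branch, --add_pdfs expects a plain text file with one s3:// path per line. A hedged example of producing such a list from Python (the file name and paths are placeholders):

    # Illustrative only: write a pdf_list.txt that --add_pdfs can consume.
    paths = [
        "s3://example-pdf-bucket/crawl/0001.pdf",
        "s3://example-pdf-bucket/crawl/0002.pdf",
    ]
    with open("pdf_list.txt", "w") as f:
        f.write("\n".join(paths) + "\n")
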
@@ -626,7 +642,7 @@ if __name__ == '__main__':
     # Now build an index of all the pages that were processed within the workspace so far
     print("Indexing all batch inference sent to this workspace")
-    inference_output_paths = expand_s3_glob(f"{args.workspace}/inference_outputs/*.jsonl")
+    inference_output_paths = expand_s3_glob(workspace_s3, f"{args.workspace}/inference_outputs/*.jsonl")
     inference_output_paths = [
         (s3_path, etag) for s3_path, etag in inference_output_paths.items()