Mirror of https://github.com/allenai/olmocr.git, synced 2025-09-25 16:30:28 +00:00
Adjusting workflow so I can do s2 pdfs

This commit is contained in:
parent 6d61ae4aa8
commit 4669eb7134
@@ -10,6 +10,7 @@ import datetime
 import posixpath
 import threading
 import logging
+import boto3.session
 import urllib3.exceptions
 
 from dataclasses import dataclass
@@ -25,7 +26,8 @@ from pdelfin.prompts import build_finetuning_prompt
 from pdelfin.prompts.anchor import get_anchor_text
 
 # Global s3 client for the whole script, feel free to adjust params if you need it
-s3 = boto3.client('s3')
+workspace_s3 = boto3.client('s3')
+pdf_s3 = boto3.client('s3')
 
 # Quiet logs from pypdf and smart open
 logging.getLogger("pypdf").setLevel(logging.ERROR)
@@ -338,13 +340,13 @@ def parse_s3_path(s3_path):
     bucket, _, prefix = path.partition('/')
     return bucket, prefix
 
-def expand_s3_glob(s3_glob: str) -> Dict[str, str]:
+def expand_s3_glob(s3_client, s3_glob: str) -> Dict[str, str]:
     parsed = urlparse(s3_glob)
     bucket_name = parsed.netloc
     prefix = os.path.dirname(parsed.path.lstrip('/')).rstrip('/') + "/"
     pattern = os.path.basename(parsed.path)
 
-    paginator = s3.get_paginator('list_objects_v2')
+    paginator = s3_client.get_paginator('list_objects_v2')
    page_iterator = paginator.paginate(Bucket=bucket_name, Prefix=prefix)
 
     matched_files = {}
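
The rest of expand_s3_glob is cut off above. For orientation, a self-contained sketch of the same list-and-match pattern, assuming fnmatch-style matching on the key's basename and a {s3_path: ETag} return value; the function, variable, and bucket names below are illustrative, not the repo's:

    import fnmatch
    import os
    from urllib.parse import urlparse

    import boto3

    def list_matching_keys(s3_client, s3_glob: str) -> dict:
        # Split "s3://bucket/prefix/*.pdf" into bucket, key prefix, and filename pattern.
        parsed = urlparse(s3_glob)
        bucket = parsed.netloc
        prefix = os.path.dirname(parsed.path.lstrip("/")).rstrip("/") + "/"
        pattern = os.path.basename(parsed.path)

        matched = {}
        paginator = s3_client.get_paginator("list_objects_v2")
        for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
            for obj in page.get("Contents", []):
                if fnmatch.fnmatch(os.path.basename(obj["Key"]), pattern):
                    matched[f"s3://{bucket}/{obj['Key']}"] = obj["ETag"]
        return matched

    # e.g. list_matching_keys(boto3.client("s3"), "s3://my-bucket/pdfs/*.pdf")  # hypothetical bucket
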
@@ -374,7 +376,7 @@ def build_page_query(local_pdf_path: str, pretty_pdf_path: str, page: int) -> di
     }
 
 
-def get_s3_bytes(s3_path: str, start_index: Optional[int] = None, end_index: Optional[int] = None) -> bytes:
+def get_s3_bytes(s3_client, s3_path: str, start_index: Optional[int] = None, end_index: Optional[int] = None) -> bytes:
     bucket, key = parse_s3_path(s3_path)
 
     # Build the range header if start_index and/or end_index are specified
@@ -393,9 +395,9 @@ def get_s3_bytes(s3_path: str, start_index: Optional[int] = None, end_index: Opt
         range_header = {'Range': range_value}
 
     if range_header:
-        obj = s3.get_object(Bucket=bucket, Key=key, Range=range_header['Range'])
+        obj = s3_client.get_object(Bucket=bucket, Key=key, Range=range_header['Range'])
     else:
-        obj = s3.get_object(Bucket=bucket, Key=key)
+        obj = s3_client.get_object(Bucket=bucket, Key=key)
 
     return obj['Body'].read()
 
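
The range-building logic referenced by the comment above is not shown in this hunk. A minimal sketch of what such a helper can look like, assuming a missing start defaults to byte 0; this is an illustrative helper, not the repo's code:

    from typing import Optional

    def build_range_value(start_index: Optional[int] = None, end_index: Optional[int] = None) -> Optional[str]:
        # HTTP byte ranges are inclusive on both ends; "bytes=500-" means byte 500 through EOF.
        if start_index is None and end_index is None:
            return None
        start = start_index if start_index is not None else 0
        return f"bytes={start}-" if end_index is None else f"bytes={start}-{end_index}"
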
@@ -406,7 +408,7 @@ def parse_custom_id(custom_id: str) -> Tuple[str, int]:
     return s3_path, page_num
 
 def process_jsonl_content(inference_s3_path: str) -> List[DatabaseManager.BatchInferenceRecord]:
-    content_bytes = get_s3_bytes(inference_s3_path)
+    content_bytes = get_s3_bytes(workspace_s3, inference_s3_path)
 
     start_index = 0
     index_entries = []
@@ -462,7 +464,7 @@ def process_jsonl_content(s3_path: str) -> List[DatabaseManager.BatchI
 def get_pdf_num_pages(s3_path: str) -> Optional[int]:
     try:
         with tempfile.NamedTemporaryFile("wb+", suffix=".pdf") as tf:
-            tf.write(get_s3_bytes(s3_path))
+            tf.write(get_s3_bytes(pdf_s3, s3_path))
             tf.flush()
 
             reader = PdfReader(tf.name)
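
As an aside, pypdf's PdfReader also accepts an in-memory binary stream, so a page count can be taken without a temporary file. A small sketch under that assumption; the helper name is illustrative:

    import io
    from typing import Optional

    from pypdf import PdfReader

    def count_pdf_pages(pdf_bytes: bytes) -> Optional[int]:
        # PdfReader accepts any binary file-like object, so BytesIO avoids the temp file.
        try:
            return len(PdfReader(io.BytesIO(pdf_bytes)).pages)
        except Exception:
            return None
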
@@ -484,7 +486,7 @@ def build_pdf_queries(s3_workspace: str, pdf: DatabaseManager.PDFRecord, cur_rou
 
     try:
         with tempfile.NamedTemporaryFile("wb+", suffix=".pdf") as tf:
-            tf.write(get_s3_bytes(pdf.s3_path))
+            tf.write(get_s3_bytes(pdf_s3, pdf.s3_path))
             tf.flush()
 
             for target_page_num in range(1, pdf.num_pages + 1):
@@ -522,9 +524,9 @@ def build_dolma_doc(s3_workspace: str, pdf: DatabaseManager.PDFRecord) -> Option
 
         target_page = target_pages[0]
 
-        target_row = get_s3_bytes(target_page.inference_s3_path,
-                                  start_index=target_page.start_index,
-                                  end_index=target_page.start_index+target_page.length - 1)
+        target_row = get_s3_bytes(workspace_s3, target_page.inference_s3_path,
+                                  start_index=target_page.start_index,
+                                  end_index=target_page.start_index+target_page.length - 1)
 
         target_data = json.loads(target_row.decode("utf-8"))
 
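
The start_index/length bookkeeping lets a single JSONL record be pulled without downloading the whole inference output file. A tiny sketch of that arithmetic with an explicit client; the function and parameter names are illustrative:

    import json

    def fetch_jsonl_record(s3_client, bucket: str, key: str, offset: int, length: int) -> dict:
        # Ranges are inclusive, so offset + length - 1 returns exactly `length` bytes.
        resp = s3_client.get_object(Bucket=bucket, Key=key, Range=f"bytes={offset}-{offset + length - 1}")
        return json.loads(resp["Body"].read().decode("utf-8"))
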
@@ -565,7 +567,7 @@ def mark_pdfs_done(s3_workspace: str, dolma_doc_lines: list[str]):
 def get_current_round(s3_workspace: str) -> int:
     bucket, prefix = parse_s3_path(s3_workspace)
     inference_inputs_prefix = posixpath.join(prefix, 'inference_inputs/')
-    paginator = s3.get_paginator('list_objects_v2')
+    paginator = workspace_s3.get_paginator('list_objects_v2')
     page_iterator = paginator.paginate(Bucket=bucket, Prefix=inference_inputs_prefix, Delimiter='/')
 
     round_numbers = []
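
Listing with Delimiter='/' makes S3 return CommonPrefixes entries, which behave like the immediate subdirectories under the prefix. A short sketch of that idiom; the helper and bucket names are illustrative, and the repo's round-number parsing is not shown here:

    def list_subprefixes(s3_client, bucket: str, prefix: str) -> list[str]:
        # With Delimiter='/', each "directory" under the prefix shows up once in CommonPrefixes.
        paginator = s3_client.get_paginator("list_objects_v2")
        found = []
        for page in paginator.paginate(Bucket=bucket, Prefix=prefix, Delimiter="/"):
            for cp in page.get("CommonPrefixes", []):
                found.append(cp["Prefix"])
        return found

    # e.g. list_subprefixes(workspace_s3, "my-bucket", "workspace/inference_inputs/")  # hypothetical names
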
@@ -590,10 +592,20 @@ def get_current_round(s3_workspace: str) -> int:
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='Manager for running millions of PDFs through a batch inference pipeline')
     parser.add_argument('workspace', help='The S3 path where work will be done e.g., s3://bucket/prefix/)')
-    parser.add_argument('--add_pdfs', help='Glob path to add PDFs (s3) to the workspace', default=None)
+    parser.add_argument('--add_pdfs', help='Path to add pdfs stored in s3 to the workspace, can be a glob path s3://bucket/prefix/*.pdf or path to file containing list of pdf paths', default=None)
+    parser.add_argument('--workspace_profile', help='S3 configuration profile for accessing the workspace', default=None)
+    parser.add_argument('--pdf_profile', help='S3 configuration profile for accessing the raw pdf documents', default=None)
     parser.add_argument('--max_size_mb', type=int, default=250, help='Max file size in MB')
     args = parser.parse_args()
 
+    if args.workspace_profile:
+        workspace_session = boto3.Session(profile_name=args.workspace_profile)
+        workspace_s3 = workspace_session.resource("s3")
+
+    if args.pdf_profile:
+        pdf_session = boto3.Session(profile_name=args.pdf_profile)
+        pdf_s3 = pdf_session.resource("s3")
+
     db = DatabaseManager(args.workspace)
     print(f"Loaded db at {db.db_path}")
 
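
One thing worth noting: get_paginator() and get_object(), as used throughout the script, are defined on boto3's low-level client, whereas Session.resource("s3") returns a higher-level resource (its client is reachable via .meta.client). A minimal sketch that binds profile-scoped clients directly; the profile names in the comments are hypothetical:

    from typing import Optional

    import boto3

    def make_s3_client(profile_name: Optional[str] = None):
        # A named profile comes from ~/.aws/credentials; client("s3") exposes
        # get_object(), get_paginator(), and the other calls used above.
        session = boto3.Session(profile_name=profile_name) if profile_name else boto3.Session()
        return session.client("s3")

    # e.g. workspace_s3 = make_s3_client("workspace-profile")  # hypothetical profile name
    #      pdf_s3 = make_s3_client("s2-pdfs")                  # hypothetical profile name
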
@@ -605,12 +617,16 @@ if __name__ == '__main__':
 
     # If you have new PDFs, step one is to add them to the list
     if args.add_pdfs:
-        assert args.add_pdfs.startswith("s3://"), "PDFs must live on s3"
-
-        print(f"Querying all PDFs at {args.add_pdfs}")
-
-        all_pdfs = expand_s3_glob(args.add_pdfs)
-        print(f"Found {len(all_pdfs):,} total pdf paths")
+        if args.add_pdfs.startswith("s3://"):
+            print(f"Querying all PDFs at {args.add_pdfs}")
+
+            all_pdfs = expand_s3_glob(pdf_s3, args.add_pdfs)
+            print(f"Found {len(all_pdfs):,} total pdf paths")
+        elif os.path.exists(args.add_pdfs):
+            with open(args.add_pdfs, "r") as f:
+                all_pdfs = [line for line in f.readlines() if len(line.strip()) > 0]
+        else:
+            raise ValueError("add_pdfs argument needs to be either an s3 glob search path, or a local file contains pdf paths (one per line)")
 
         all_pdfs = [pdf for pdf in all_pdfs if not db.pdf_exists(pdf)]
         print(f"Need to import {len(all_pdfs):,} total new pdf paths")
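
When --add_pdfs points at a local file instead of an s3:// glob, each non-empty line is treated as one PDF path. A small sketch of reading such a list with the trailing newline stripped; the stripping and the file name are illustrative choices, not part of this commit:

    def read_pdf_list(path: str) -> list[str]:
        # One PDF path per line; skip blank lines and trim surrounding whitespace.
        with open(path, "r") as f:
            return [line.strip() for line in f if line.strip()]

    # e.g. all_pdfs = read_pdf_list("s2_pdf_paths.txt")  # hypothetical file name
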
@@ -626,7 +642,7 @@ if __name__ == '__main__':
 
     # Now build an index of all the pages that were processed within the workspace so far
     print("Indexing all batch inference sent to this workspace")
-    inference_output_paths = expand_s3_glob(f"{args.workspace}/inference_outputs/*.jsonl")
+    inference_output_paths = expand_s3_glob(workspace_s3, f"{args.workspace}/inference_outputs/*.jsonl")
 
     inference_output_paths = [
         (s3_path, etag) for s3_path, etag in inference_output_paths.items()