mirror of https://github.com/allenai/olmocr.git

commit 4669eb7134 (parent 6d61ae4aa8)
Adjusting workflow so I can do s2 pdfs
@@ -10,6 +10,7 @@ import datetime
 import posixpath
 import threading
 import logging
+import boto3.session
 import urllib3.exceptions
 
 from dataclasses import dataclass
@@ -25,7 +26,8 @@ from pdelfin.prompts import build_finetuning_prompt
 from pdelfin.prompts.anchor import get_anchor_text
 
 # Global s3 client for the whole script, feel free to adjust params if you need it
-s3 = boto3.client('s3')
+workspace_s3 = boto3.client('s3')
+pdf_s3 = boto3.client('s3')
 
 # Quiet logs from pypdf and smart open
 logging.getLogger("pypdf").setLevel(logging.ERROR)
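Splitting the single module-level `s3` client into `workspace_s3` and `pdf_s3` lets the workspace bucket and the source-PDF bucket (e.g. the s2 PDFs) be accessed with different AWS credentials. A minimal sketch of the pattern with illustrative profile names; note the sketch builds clients via `Session.client`, while the commit itself rebinds the globals with `Session.resource` in `__main__` further down:

```python
import boto3

# Defaults: both clients resolve the ambient credentials (env vars, default profile, or instance role).
workspace_s3 = boto3.client("s3")
pdf_s3 = boto3.client("s3")

# Later, each global can be rebound to a profile-specific client without touching the other.
workspace_s3 = boto3.Session(profile_name="workspace-account").client("s3")  # illustrative profile name
pdf_s3 = boto3.Session(profile_name="s2-pdf-account").client("s3")           # illustrative profile name
```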
@@ -338,13 +340,13 @@ def parse_s3_path(s3_path):
     bucket, _, prefix = path.partition('/')
     return bucket, prefix
 
-def expand_s3_glob(s3_glob: str) -> Dict[str, str]:
+def expand_s3_glob(s3_client, s3_glob: str) -> Dict[str, str]:
     parsed = urlparse(s3_glob)
     bucket_name = parsed.netloc
     prefix = os.path.dirname(parsed.path.lstrip('/')).rstrip('/') + "/"
     pattern = os.path.basename(parsed.path)
 
-    paginator = s3.get_paginator('list_objects_v2')
+    paginator = s3_client.get_paginator('list_objects_v2')
     page_iterator = paginator.paginate(Bucket=bucket_name, Prefix=prefix)
 
     matched_files = {}
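For reference, a rough sketch of what a glob expansion like `expand_s3_glob` can look like once the client is injected: list everything under the non-wildcard prefix and filter the keys with `fnmatch`. This is an illustration of the approach, not the repository's exact implementation; the basename-matching rule and the ETag-keyed return value are assumptions based on how the result is consumed later (`inference_output_paths.items()` yields `(s3_path, etag)` pairs).

```python
import fnmatch
import os
import posixpath
from typing import Dict
from urllib.parse import urlparse

def expand_s3_glob_sketch(s3_client, s3_glob: str) -> Dict[str, str]:
    """Map s3://bucket/key -> ETag for every object matching the glob (sketch only)."""
    parsed = urlparse(s3_glob)
    bucket_name = parsed.netloc
    prefix = os.path.dirname(parsed.path.lstrip("/")).rstrip("/") + "/"
    pattern = os.path.basename(parsed.path)

    matched_files = {}
    paginator = s3_client.get_paginator("list_objects_v2")
    for page in paginator.paginate(Bucket=bucket_name, Prefix=prefix):
        for obj in page.get("Contents", []):
            key = obj["Key"]
            if fnmatch.fnmatch(posixpath.basename(key), pattern):
                matched_files[f"s3://{bucket_name}/{key}"] = obj["ETag"].strip('"')
    return matched_files
```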
@@ -374,7 +376,7 @@ def build_page_query(local_pdf_path: str, pretty_pdf_path: str, page: int) -> di
     }
 
 
-def get_s3_bytes(s3_path: str, start_index: Optional[int] = None, end_index: Optional[int] = None) -> bytes:
+def get_s3_bytes(s3_client, s3_path: str, start_index: Optional[int] = None, end_index: Optional[int] = None) -> bytes:
     bucket, key = parse_s3_path(s3_path)
 
     # Build the range header if start_index and/or end_index are specified
@@ -393,9 +395,9 @@ def get_s3_bytes(s3_path: str, start_index: Optional[int] = None, end_index: Opt
         range_header = {'Range': range_value}
 
     if range_header:
-        obj = s3.get_object(Bucket=bucket, Key=key, Range=range_header['Range'])
+        obj = s3_client.get_object(Bucket=bucket, Key=key, Range=range_header['Range'])
     else:
-        obj = s3.get_object(Bucket=bucket, Key=key)
+        obj = s3_client.get_object(Bucket=bucket, Key=key)
 
     return obj['Body'].read()
 
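The `Range` header follows the standard HTTP byte-range syntax (`bytes=start-end`, inclusive end), which is what lets `build_dolma_doc` further down pull a single JSONL record out of a large inference output file by offset and length. A small usage sketch with assumed example values:

```python
import json

# Assumed example values: a JSONL record starting at byte 1024 that is 512 bytes long.
s3_path = "s3://example-bucket/workspace/inference_outputs/output.jsonl"  # hypothetical path
start_index, length = 1024, 512

# The end offset is inclusive, hence the -1.
row_bytes = get_s3_bytes(workspace_s3, s3_path,
                         start_index=start_index,
                         end_index=start_index + length - 1)
record = json.loads(row_bytes)
```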
@@ -406,7 +408,7 @@ def parse_custom_id(custom_id: str) -> Tuple[str, int]:
     return s3_path, page_num
 
 def process_jsonl_content(inference_s3_path: str) -> List[DatabaseManager.BatchInferenceRecord]:
-    content_bytes = get_s3_bytes(inference_s3_path)
+    content_bytes = get_s3_bytes(workspace_s3, inference_s3_path)
 
     start_index = 0
     index_entries = []
@@ -462,7 +464,7 @@ def process_jsonl_content(inference_s3_path: str) -> List[DatabaseManager.BatchI
 def get_pdf_num_pages(s3_path: str) -> Optional[int]:
     try:
         with tempfile.NamedTemporaryFile("wb+", suffix=".pdf") as tf:
-            tf.write(get_s3_bytes(s3_path))
+            tf.write(get_s3_bytes(pdf_s3, s3_path))
             tf.flush()
 
             reader = PdfReader(tf.name)
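`get_pdf_num_pages` relies on the common pattern of spooling the S3 bytes into a named temporary file so pypdf can open it by path. The same pattern in isolation, as a minimal self-contained sketch (not the repository's function):

```python
import tempfile
from typing import Optional
from pypdf import PdfReader

def count_pdf_pages(s3_client, s3_path: str) -> Optional[int]:
    """Download a PDF from S3 into a temp file and count its pages (sketch)."""
    try:
        with tempfile.NamedTemporaryFile("wb+", suffix=".pdf") as tf:
            tf.write(get_s3_bytes(s3_client, s3_path))
            tf.flush()  # make sure the bytes are on disk before pypdf reads the file by name
            return len(PdfReader(tf.name).pages)
    except Exception:
        return None
```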
@@ -484,7 +486,7 @@ def build_pdf_queries(s3_workspace: str, pdf: DatabaseManager.PDFRecord, cur_rou
 
     try:
         with tempfile.NamedTemporaryFile("wb+", suffix=".pdf") as tf:
-            tf.write(get_s3_bytes(pdf.s3_path))
+            tf.write(get_s3_bytes(pdf_s3, pdf.s3_path))
             tf.flush()
 
             for target_page_num in range(1, pdf.num_pages + 1):
@@ -522,7 +524,7 @@ def build_dolma_doc(s3_workspace: str, pdf: DatabaseManager.PDFRecord) -> Option
 
         target_page = target_pages[0]
 
-        target_row = get_s3_bytes(target_page.inference_s3_path,
+        target_row = get_s3_bytes(workspace_s3, target_page.inference_s3_path,
                                   start_index=target_page.start_index,
                                   end_index=target_page.start_index+target_page.length - 1)
 
@@ -565,7 +567,7 @@ def mark_pdfs_done(s3_workspace: str, dolma_doc_lines: list[str]):
 def get_current_round(s3_workspace: str) -> int:
     bucket, prefix = parse_s3_path(s3_workspace)
     inference_inputs_prefix = posixpath.join(prefix, 'inference_inputs/')
-    paginator = s3.get_paginator('list_objects_v2')
+    paginator = workspace_s3.get_paginator('list_objects_v2')
     page_iterator = paginator.paginate(Bucket=bucket, Prefix=inference_inputs_prefix, Delimiter='/')
 
     round_numbers = []
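Listing with `Delimiter='/'` makes S3 return the immediate "subdirectories" under `inference_inputs/` as `CommonPrefixes` instead of individual objects, which is how per-round folders can be enumerated. A short sketch of that listing pattern; the `round_N/` naming in the comment is only an assumption for illustration:

```python
def list_round_prefixes(s3_client, bucket: str, inference_inputs_prefix: str) -> list[str]:
    """Return the immediate sub-prefixes under inference_inputs/ (sketch only)."""
    prefixes = []
    paginator = s3_client.get_paginator("list_objects_v2")
    for page in paginator.paginate(Bucket=bucket, Prefix=inference_inputs_prefix, Delimiter="/"):
        for cp in page.get("CommonPrefixes", []):
            prefixes.append(cp["Prefix"])  # e.g. ".../inference_inputs/round_0/" (assumed layout)
    return prefixes
```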
@@ -590,10 +592,20 @@ def get_current_round(s3_workspace: str) -> int:
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='Manager for running millions of PDFs through a batch inference pipeline')
     parser.add_argument('workspace', help='The S3 path where work will be done e.g., s3://bucket/prefix/)')
-    parser.add_argument('--add_pdfs', help='Glob path to add PDFs (s3) to the workspace', default=None)
+    parser.add_argument('--add_pdfs', help='Path to add pdfs stored in s3 to the workspace, can be a glob path s3://bucket/prefix/*.pdf or path to file containing list of pdf paths', default=None)
+    parser.add_argument('--workspace_profile', help='S3 configuration profile for accessing the workspace', default=None)
+    parser.add_argument('--pdf_profile', help='S3 configuration profile for accessing the raw pdf documents', default=None)
     parser.add_argument('--max_size_mb', type=int, default=250, help='Max file size in MB')
     args = parser.parse_args()
 
+    if args.workspace_profile:
+        workspace_session = boto3.Session(profile_name=args.workspace_profile)
+        workspace_s3 = workspace_session.resource("s3")
+
+    if args.pdf_profile:
+        pdf_session = boto3.Session(profile_name=args.pdf_profile)
+        pdf_s3 = pdf_session.resource("s3")
+
     db = DatabaseManager(args.workspace)
     print(f"Loaded db at {db.db_path}")
 
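The new `--workspace_profile` and `--pdf_profile` flags resolve to named profiles in the standard AWS config files (`~/.aws/credentials` / `~/.aws/config`). A small hedged sketch of checking that the requested profiles actually exist before building sessions; the profile names below are placeholders, and this check is not part of the commit:

```python
import boto3

# Placeholder profile names; each must be defined in ~/.aws/credentials or ~/.aws/config.
requested_profiles = ("workspace-account", "s2-pdf-account")

available = boto3.Session().available_profiles
for profile in requested_profiles:
    if profile not in available:
        raise SystemExit(f"Profile {profile!r} not found in your AWS config; "
                         f"available profiles: {available}")
```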
@@ -605,12 +617,16 @@ if __name__ == '__main__':
 
     # If you have new PDFs, step one is to add them to the list
     if args.add_pdfs:
-        assert args.add_pdfs.startswith("s3://"), "PDFs must live on s3"
-
-        print(f"Querying all PDFs at {args.add_pdfs}")
+        if args.add_pdfs.startswith("s3://"):
+            print(f"Querying all PDFs at {args.add_pdfs}")
 
-        all_pdfs = expand_s3_glob(args.add_pdfs)
-        print(f"Found {len(all_pdfs):,} total pdf paths")
+            all_pdfs = expand_s3_glob(pdf_s3, args.add_pdfs)
+            print(f"Found {len(all_pdfs):,} total pdf paths")
+        elif os.path.exists(args.add_pdfs):
+            with open(args.add_pdfs, "r") as f:
+                all_pdfs = [line for line in f.readlines() if len(line.strip()) > 0]
+        else:
+            raise ValueError("add_pdfs argument needs to be either an s3 glob search path, or a local file contains pdf paths (one per line)")
 
         all_pdfs = [pdf for pdf in all_pdfs if not db.pdf_exists(pdf)]
         print(f"Need to import {len(all_pdfs):,} total new pdf paths")
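After this change, `--add_pdfs` accepts either an S3 glob or a local text file listing one PDF path per line. A hedged sketch of a helper that normalizes both forms; this helper does not exist in the repository, it just restates the branch above in isolation:

```python
import os

def load_pdf_paths(add_pdfs: str, pdf_s3_client) -> list[str]:
    """Resolve an --add_pdfs argument into a list of s3:// PDF paths (illustrative only)."""
    if add_pdfs.startswith("s3://"):
        # Glob form, e.g. s3://bucket/prefix/*.pdf
        return list(expand_s3_glob(pdf_s3_client, add_pdfs).keys())
    elif os.path.exists(add_pdfs):
        # Local file form: one s3://... path per line
        with open(add_pdfs, "r") as f:
            return [line.strip() for line in f if line.strip()]
    else:
        raise ValueError("add_pdfs must be an s3 glob or a local file of pdf paths")
```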
@@ -626,7 +642,7 @@ if __name__ == '__main__':
 
     # Now build an index of all the pages that were processed within the workspace so far
     print("Indexing all batch inference sent to this workspace")
-    inference_output_paths = expand_s3_glob(f"{args.workspace}/inference_outputs/*.jsonl")
+    inference_output_paths = expand_s3_glob(workspace_s3, f"{args.workspace}/inference_outputs/*.jsonl")
 
     inference_output_paths = [
         (s3_path, etag) for s3_path, etag in inference_output_paths.items()