mirror of
https://github.com/allenai/olmocr.git
synced 2025-11-03 11:35:29 +00:00
Buildsilver script suppors reservoir sampling so it can sample 100M+ paths now efficiently
This commit is contained in:
parent
8ec9e35f22
commit
b4e9d6a2b8
@ -6,7 +6,6 @@ import base64
|
||||
import argparse
|
||||
import boto3
|
||||
import json
|
||||
from openai import OpenAI
|
||||
from pypdf import PdfReader
|
||||
from tqdm import tqdm
|
||||
from typing import Generator
|
||||
@ -31,8 +30,6 @@ def _build_prompt(base_text: str) -> str:
|
||||
f"RAW_TEXT_START\n{base_text}\nRAW_TEXT_END"
|
||||
)
|
||||
|
||||
# Initialize OpenAI client
|
||||
openai_client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
|
||||
pdf_filter = PdfFilter()
|
||||
|
||||
def build_page_query(local_pdf_path: str, pretty_pdf_path: str, page: int) -> dict:
|
||||
@ -145,30 +142,68 @@ def main():
|
||||
parser.add_argument("--first_n_pages", type=int, default=0, help="Always sample the first N pages of each PDF.")
|
||||
parser.add_argument("--max_sample_pages", type=int, default=15, help="Max number of pages to sample per PDF.")
|
||||
parser.add_argument("--output", type=str, default="openai_batch_data", help="Output destination")
|
||||
parser.add_argument("--reservoir_size", type=int, default=None,
|
||||
help="Size of the reservoir for sampling paths. Defaults to 10x num_sample_docs.")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Load PDF paths from glob or path_list
|
||||
# Set default reservoir_size if not provided
|
||||
if args.reservoir_size is None:
|
||||
args.reservoir_size = 10 * args.num_sample_docs
|
||||
|
||||
# Initialize reservoir sampling variables
|
||||
pdf_paths = []
|
||||
n = 0 # Total number of items seen
|
||||
|
||||
# Load PDF paths from glob or path_list using reservoir sampling
|
||||
if args.glob_path:
|
||||
if args.glob_path.startswith("s3://"):
|
||||
# Handle S3 globbing using boto3
|
||||
# Handle S3 globbing using boto3 with pagination
|
||||
parsed = urlparse(args.glob_path)
|
||||
s3 = boto3.client('s3')
|
||||
bucket_name = parsed.netloc
|
||||
prefix = os.path.dirname(parsed.path.lstrip('/')) + "/"
|
||||
response = s3.list_objects_v2(Bucket=bucket_name, Prefix=prefix)
|
||||
for obj in response.get('Contents', []):
|
||||
if obj['Key'].endswith('.pdf'):
|
||||
pdf_paths.append(f"s3://{bucket_name}/{obj['Key']}")
|
||||
paginator = s3.get_paginator('list_objects_v2')
|
||||
page_iterator = paginator.paginate(Bucket=bucket_name, Prefix=prefix)
|
||||
|
||||
for page in page_iterator:
|
||||
for obj in page.get('Contents', []):
|
||||
if obj['Key'].endswith('.pdf'):
|
||||
n += 1
|
||||
path = f"s3://{bucket_name}/{obj['Key']}"
|
||||
if len(pdf_paths) < args.reservoir_size:
|
||||
pdf_paths.append(path)
|
||||
else:
|
||||
s = random.randint(1, n)
|
||||
if s <= args.reservoir_size:
|
||||
pdf_paths[s - 1] = path
|
||||
else:
|
||||
# Handle local globbing
|
||||
pdf_paths = glob.glob(args.glob_path)
|
||||
# Handle local globbing using glob.iglob()
|
||||
for path in glob.iglob(args.glob_path, recursive=True):
|
||||
n += 1
|
||||
if len(pdf_paths) < args.reservoir_size:
|
||||
pdf_paths.append(path)
|
||||
else:
|
||||
s = random.randint(1, n)
|
||||
if s <= args.reservoir_size:
|
||||
pdf_paths[s - 1] = path
|
||||
elif args.path_list:
|
||||
with open(args.path_list, 'r') as f:
|
||||
pdf_paths = [line.strip() for line in f]
|
||||
for line in f:
|
||||
n += 1
|
||||
path = line.strip()
|
||||
if len(pdf_paths) < args.reservoir_size:
|
||||
pdf_paths.append(path)
|
||||
else:
|
||||
s = random.randint(1, n)
|
||||
if s <= args.reservoir_size:
|
||||
pdf_paths[s - 1] = path
|
||||
|
||||
# Shuffle the reservoir
|
||||
random.shuffle(pdf_paths)
|
||||
|
||||
print(f"Loaded and shuffled {len(pdf_paths)} paths to use.")
|
||||
|
||||
# Rest of the code remains the same
|
||||
cur_file_num = 0
|
||||
output_dir = args.output
|
||||
max_file_size = 99 * 1024 * 1024 # 99MB in bytes
|
||||
@ -184,7 +219,7 @@ def main():
|
||||
# Counter to track PDFs that produce at least one output
|
||||
pdfs_with_output = 0
|
||||
|
||||
# Using ThreadPoolExecutor to process files concurrently
|
||||
# Using ThreadPoolExecutor to process files concurrently
|
||||
with ThreadPoolExecutor(max_workers=60) as executor:
|
||||
futures = []
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user