Mirror of https://github.com/allenai/olmocr.git, synced 2025-11-07 22:18:51 +00:00
Buildsilver script supports reservoir sampling so it can now sample 100M+ paths efficiently
This commit is contained in:
parent
8ec9e35f22
commit
b4e9d6a2b8
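
The sampling logic added in this commit is the classic "Algorithm R" reservoir sampler: stream paths one at a time, fill a fixed-size reservoir first, and after that overwrite a random slot with shrinking probability, so every path seen so far is equally likely to remain. A minimal standalone sketch of the same selection rule the diff adds (the sample_paths helper and its arguments are illustrative, not part of the commit):

import random

def sample_paths(paths, reservoir_size):
    # Keep a uniform random sample of reservoir_size items from a stream
    # of unknown length, using only O(reservoir_size) memory.
    reservoir = []
    n = 0  # total items seen so far
    for path in paths:
        n += 1
        if len(reservoir) < reservoir_size:
            reservoir.append(path)  # fill the reservoir first
        else:
            s = random.randint(1, n)
            if s <= reservoir_size:  # happens with probability reservoir_size / n
                reservoir[s - 1] = path
    return reservoir

Each of the n paths ends up in the reservoir with probability reservoir_size / n, which is what lets the script sample from 100M+ paths without ever holding them all in memory.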
@@ -6,7 +6,6 @@ import base64
 import argparse
 import boto3
 import json
-from openai import OpenAI
 from pypdf import PdfReader
 from tqdm import tqdm
 from typing import Generator
@@ -31,8 +30,6 @@ def _build_prompt(base_text: str) -> str:
         f"RAW_TEXT_START\n{base_text}\nRAW_TEXT_END"
     )

-# Initialize OpenAI client
-openai_client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
 pdf_filter = PdfFilter()

 def build_page_query(local_pdf_path: str, pretty_pdf_path: str, page: int) -> dict:
@@ -145,30 +142,68 @@ def main():
     parser.add_argument("--first_n_pages", type=int, default=0, help="Always sample the first N pages of each PDF.")
     parser.add_argument("--max_sample_pages", type=int, default=15, help="Max number of pages to sample per PDF.")
     parser.add_argument("--output", type=str, default="openai_batch_data", help="Output destination")
+    parser.add_argument("--reservoir_size", type=int, default=None,
+                        help="Size of the reservoir for sampling paths. Defaults to 10x num_sample_docs.")
     args = parser.parse_args()

-    # Load PDF paths from glob or path_list
+    # Set default reservoir_size if not provided
+    if args.reservoir_size is None:
+        args.reservoir_size = 10 * args.num_sample_docs
+
+    # Initialize reservoir sampling variables
     pdf_paths = []
+    n = 0  # Total number of items seen
+
+    # Load PDF paths from glob or path_list using reservoir sampling
     if args.glob_path:
         if args.glob_path.startswith("s3://"):
-            # Handle S3 globbing using boto3
+            # Handle S3 globbing using boto3 with pagination
             parsed = urlparse(args.glob_path)
             s3 = boto3.client('s3')
             bucket_name = parsed.netloc
             prefix = os.path.dirname(parsed.path.lstrip('/')) + "/"
-            response = s3.list_objects_v2(Bucket=bucket_name, Prefix=prefix)
-            for obj in response.get('Contents', []):
-                if obj['Key'].endswith('.pdf'):
-                    pdf_paths.append(f"s3://{bucket_name}/{obj['Key']}")
+            paginator = s3.get_paginator('list_objects_v2')
+            page_iterator = paginator.paginate(Bucket=bucket_name, Prefix=prefix)
+
+            for page in page_iterator:
+                for obj in page.get('Contents', []):
+                    if obj['Key'].endswith('.pdf'):
+                        n += 1
+                        path = f"s3://{bucket_name}/{obj['Key']}"
+                        if len(pdf_paths) < args.reservoir_size:
+                            pdf_paths.append(path)
+                        else:
+                            s = random.randint(1, n)
+                            if s <= args.reservoir_size:
+                                pdf_paths[s - 1] = path
         else:
-            # Handle local globbing
-            pdf_paths = glob.glob(args.glob_path)
+            # Handle local globbing using glob.iglob()
+            for path in glob.iglob(args.glob_path, recursive=True):
+                n += 1
+                if len(pdf_paths) < args.reservoir_size:
+                    pdf_paths.append(path)
+                else:
+                    s = random.randint(1, n)
+                    if s <= args.reservoir_size:
+                        pdf_paths[s - 1] = path
     elif args.path_list:
         with open(args.path_list, 'r') as f:
-            pdf_paths = [line.strip() for line in f]
+            for line in f:
+                n += 1
+                path = line.strip()
+                if len(pdf_paths) < args.reservoir_size:
+                    pdf_paths.append(path)
+                else:
+                    s = random.randint(1, n)
+                    if s <= args.reservoir_size:
+                        pdf_paths[s - 1] = path

+    # Shuffle the reservoir
     random.shuffle(pdf_paths)

+    print(f"Loaded and shuffled {len(pdf_paths)} paths to use.")
+
+    # Rest of the code remains the same
     cur_file_num = 0
     output_dir = args.output
     max_file_size = 99 * 1024 * 1024  # 99MB in bytes
@@ -184,7 +219,7 @@ def main():
     # Counter to track PDFs that produce at least one output
     pdfs_with_output = 0

     # Using ThreadPoolExecutor to process files concurrently
     with ThreadPoolExecutor(max_workers=60) as executor:
         futures = []
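
One note on the S3 branch in the third hunk: a single list_objects_v2 call returns at most 1,000 keys, so the switch to a boto3 paginator is what makes listing very large prefixes possible. A hedged sketch of how such a paginated listing can feed a reservoir sampler (the function name, bucket, and prefix below are illustrative, not from the commit):

import boto3

def iter_s3_pdf_paths(bucket_name, prefix):
    # Stream every .pdf key under the prefix; the paginator transparently
    # issues follow-up list_objects_v2 calls, 1,000 keys at a time.
    s3 = boto3.client('s3')
    paginator = s3.get_paginator('list_objects_v2')
    for page in paginator.paginate(Bucket=bucket_name, Prefix=prefix):
        for obj in page.get('Contents', []):
            if obj['Key'].endswith('.pdf'):
                yield f"s3://{bucket_name}/{obj['Key']}"

Combined with the sampler sketched earlier, pdf_paths = sample_paths(iter_s3_pdf_paths("my-bucket", "pdfs/"), reservoir_size=100_000) keeps memory bounded no matter how many keys the prefix contains.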