Buildsilver script supports reservoir sampling so it can now sample 100M+ paths efficiently

Jake Poznanski 2024-09-30 18:41:18 +00:00
parent 8ec9e35f22
commit b4e9d6a2b8

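The core of the change is classic reservoir sampling (Algorithm R): instead of materializing every path and then sampling, the script streams paths one at a time and keeps a fixed-size, uniformly random sample, so memory stays bounded even over 100M+ keys. On the S3 side this is paired with get_paginator('list_objects_v2'), since a single list_objects_v2 call returns at most 1,000 keys. The helper below is a minimal standalone sketch of the technique for reference; reservoir_sample is illustrative and not part of this commit, which inlines the same logic at each call site.

import random
from typing import Iterable, List

def reservoir_sample(stream: Iterable[str], k: int) -> List[str]:
    """Uniform random sample of k items from a stream of unknown length (Algorithm R)."""
    reservoir: List[str] = []
    n = 0  # total items seen so far
    for item in stream:
        n += 1
        if len(reservoir) < k:
            # Fill phase: the first k items always enter the reservoir.
            reservoir.append(item)
        else:
            # Item n survives with probability k/n, evicting a uniformly
            # chosen slot -- the same randint(1, n) convention as the diff.
            s = random.randint(1, n)
            if s <= k:
                reservoir[s - 1] = item
    return reservoir

Every item seen so far ends up in the reservoir with equal probability k/n, which is what makes the final shuffle below a uniform sample of the whole listing.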

@@ -6,7 +6,6 @@ import base64
 import argparse
 import boto3
 import json
-from openai import OpenAI
 from pypdf import PdfReader
 from tqdm import tqdm
 from typing import Generator
@@ -31,8 +30,6 @@ def _build_prompt(base_text: str) -> str:
         f"RAW_TEXT_START\n{base_text}\nRAW_TEXT_END"
     )
 
-# Initialize OpenAI client
-openai_client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
 
 pdf_filter = PdfFilter()
 
 def build_page_query(local_pdf_path: str, pretty_pdf_path: str, page: int) -> dict:
@@ -145,30 +142,68 @@ def main():
     parser.add_argument("--first_n_pages", type=int, default=0, help="Always sample the first N pages of each PDF.")
     parser.add_argument("--max_sample_pages", type=int, default=15, help="Max number of pages to sample per PDF.")
     parser.add_argument("--output", type=str, default="openai_batch_data", help="Output destination")
+    parser.add_argument("--reservoir_size", type=int, default=None,
+                        help="Size of the reservoir for sampling paths. Defaults to 10x num_sample_docs.")
     args = parser.parse_args()
 
-    # Load PDF paths from glob or path_list
+    # Set default reservoir_size if not provided
+    if args.reservoir_size is None:
+        args.reservoir_size = 10 * args.num_sample_docs
+
+    # Initialize reservoir sampling variables
     pdf_paths = []
+    n = 0  # Total number of items seen
+
+    # Load PDF paths from glob or path_list using reservoir sampling
     if args.glob_path:
         if args.glob_path.startswith("s3://"):
-            # Handle S3 globbing using boto3
+            # Handle S3 globbing using boto3 with pagination
             parsed = urlparse(args.glob_path)
             s3 = boto3.client('s3')
            bucket_name = parsed.netloc
             prefix = os.path.dirname(parsed.path.lstrip('/')) + "/"
-            response = s3.list_objects_v2(Bucket=bucket_name, Prefix=prefix)
-            for obj in response.get('Contents', []):
-                if obj['Key'].endswith('.pdf'):
-                    pdf_paths.append(f"s3://{bucket_name}/{obj['Key']}")
+            paginator = s3.get_paginator('list_objects_v2')
+            page_iterator = paginator.paginate(Bucket=bucket_name, Prefix=prefix)
+
+            for page in page_iterator:
+                for obj in page.get('Contents', []):
+                    if obj['Key'].endswith('.pdf'):
+                        n += 1
+                        path = f"s3://{bucket_name}/{obj['Key']}"
+                        if len(pdf_paths) < args.reservoir_size:
+                            pdf_paths.append(path)
+                        else:
+                            s = random.randint(1, n)
+                            if s <= args.reservoir_size:
+                                pdf_paths[s - 1] = path
         else:
-            # Handle local globbing
-            pdf_paths = glob.glob(args.glob_path)
+            # Handle local globbing using glob.iglob()
+            for path in glob.iglob(args.glob_path, recursive=True):
+                n += 1
+                if len(pdf_paths) < args.reservoir_size:
+                    pdf_paths.append(path)
+                else:
+                    s = random.randint(1, n)
+                    if s <= args.reservoir_size:
+                        pdf_paths[s - 1] = path
     elif args.path_list:
         with open(args.path_list, 'r') as f:
-            pdf_paths = [line.strip() for line in f]
+            for line in f:
+                n += 1
+                path = line.strip()
+                if len(pdf_paths) < args.reservoir_size:
+                    pdf_paths.append(path)
+                else:
+                    s = random.randint(1, n)
+                    if s <= args.reservoir_size:
+                        pdf_paths[s - 1] = path
 
+    # Shuffle the reservoir
     random.shuffle(pdf_paths)
+    print(f"Loaded and shuffled {len(pdf_paths)} paths to use.")
 
+    # Rest of the code remains the same
     cur_file_num = 0
     output_dir = args.output
     max_file_size = 99 * 1024 * 1024  # 99MB in bytes
@@ -184,7 +219,7 @@ def main():
     # Counter to track PDFs that produce at least one output
     pdfs_with_output = 0
 
     # Using ThreadPoolExecutor to process files concurrently
     with ThreadPoolExecutor(max_workers=60) as executor:
         futures = []
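A quick way to sanity-check the uniformity claim is to run the reservoir_sample sketch above over a small known stream and look at the hit counts. This check is illustrative only, not part of the commit:

from collections import Counter

# Draw 10 of 1,000 items, 20,000 times; each item is expected to be
# sampled about 20,000 * 10 / 1,000 = 200 times.
counts = Counter()
for _ in range(20_000):
    counts.update(reservoir_sample((str(i) for i in range(1_000)), k=10))

print(min(counts.values()), max(counts.values()))  # both should be near 200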