Buildsilver script suppors reservoir sampling so it can sample 100M+ paths now efficiently

This commit is contained in:
Jake Poznanski 2024-09-30 18:41:18 +00:00
parent 8ec9e35f22
commit b4e9d6a2b8

View File

@ -6,7 +6,6 @@ import base64
import argparse
import boto3
import json
from openai import OpenAI
from pypdf import PdfReader
from tqdm import tqdm
from typing import Generator
@ -31,8 +30,6 @@ def _build_prompt(base_text: str) -> str:
f"RAW_TEXT_START\n{base_text}\nRAW_TEXT_END"
)
# Initialize OpenAI client
openai_client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
pdf_filter = PdfFilter()
def build_page_query(local_pdf_path: str, pretty_pdf_path: str, page: int) -> dict:
@ -145,30 +142,68 @@ def main():
parser.add_argument("--first_n_pages", type=int, default=0, help="Always sample the first N pages of each PDF.")
parser.add_argument("--max_sample_pages", type=int, default=15, help="Max number of pages to sample per PDF.")
parser.add_argument("--output", type=str, default="openai_batch_data", help="Output destination")
parser.add_argument("--reservoir_size", type=int, default=None,
help="Size of the reservoir for sampling paths. Defaults to 10x num_sample_docs.")
args = parser.parse_args()
# Load PDF paths from glob or path_list
# Set default reservoir_size if not provided
if args.reservoir_size is None:
args.reservoir_size = 10 * args.num_sample_docs
# Initialize reservoir sampling variables
pdf_paths = []
n = 0 # Total number of items seen
# Load PDF paths from glob or path_list using reservoir sampling
if args.glob_path:
if args.glob_path.startswith("s3://"):
# Handle S3 globbing using boto3
# Handle S3 globbing using boto3 with pagination
parsed = urlparse(args.glob_path)
s3 = boto3.client('s3')
bucket_name = parsed.netloc
prefix = os.path.dirname(parsed.path.lstrip('/')) + "/"
response = s3.list_objects_v2(Bucket=bucket_name, Prefix=prefix)
for obj in response.get('Contents', []):
if obj['Key'].endswith('.pdf'):
pdf_paths.append(f"s3://{bucket_name}/{obj['Key']}")
paginator = s3.get_paginator('list_objects_v2')
page_iterator = paginator.paginate(Bucket=bucket_name, Prefix=prefix)
for page in page_iterator:
for obj in page.get('Contents', []):
if obj['Key'].endswith('.pdf'):
n += 1
path = f"s3://{bucket_name}/{obj['Key']}"
if len(pdf_paths) < args.reservoir_size:
pdf_paths.append(path)
else:
s = random.randint(1, n)
if s <= args.reservoir_size:
pdf_paths[s - 1] = path
else:
# Handle local globbing
pdf_paths = glob.glob(args.glob_path)
# Handle local globbing using glob.iglob()
for path in glob.iglob(args.glob_path, recursive=True):
n += 1
if len(pdf_paths) < args.reservoir_size:
pdf_paths.append(path)
else:
s = random.randint(1, n)
if s <= args.reservoir_size:
pdf_paths[s - 1] = path
elif args.path_list:
with open(args.path_list, 'r') as f:
pdf_paths = [line.strip() for line in f]
for line in f:
n += 1
path = line.strip()
if len(pdf_paths) < args.reservoir_size:
pdf_paths.append(path)
else:
s = random.randint(1, n)
if s <= args.reservoir_size:
pdf_paths[s - 1] = path
# Shuffle the reservoir
random.shuffle(pdf_paths)
print(f"Loaded and shuffled {len(pdf_paths)} paths to use.")
# Rest of the code remains the same
cur_file_num = 0
output_dir = args.output
max_file_size = 99 * 1024 * 1024 # 99MB in bytes
@ -184,7 +219,7 @@ def main():
# Counter to track PDFs that produce at least one output
pdfs_with_output = 0
# Using ThreadPoolExecutor to process files concurrently
# Using ThreadPoolExecutor to process files concurrently
with ThreadPoolExecutor(max_workers=60) as executor:
futures = []