diff --git a/olmocr/train/convertjsontoparquet.py b/olmocr/train/convertjsontoparquet.py
index 4f4154f..a8b4670 100644
--- a/olmocr/train/convertjsontoparquet.py
+++ b/olmocr/train/convertjsontoparquet.py
@@ -12,13 +12,17 @@ import glob
 import json
 import re
 import sqlite3
+import tempfile
+import os
+import shutil
 from dataclasses import dataclass
-from typing import Optional, List, Tuple
+from typing import Optional, List, Tuple, Dict, Set
 import concurrent.futures
+import boto3
 from tqdm import tqdm
 import pandas as pd
-
+from pypdf import PdfReader, PdfWriter
 
 
 def parse_pdf_hash(pretty_pdf_path: str) -> Optional[str]:
     """
@@ -27,13 +31,13 @@ def parse_pdf_hash(pretty_pdf_path: str) -> Optional[str]:
     s3://ai2-s2-pdfs/de80/a57e6c57b45796d2e020173227f7eae44232.pdf-1
     it will return "de80a57e6c57b45796d2e020173227f7eae44232".
     """
-    pattern = r"s3://ai2-s2-pdfs/([a-f0-9]{4})/([a-f0-9]+)\.pdf"
+    # Allow an optional "-<page number>" suffix at the end.
+    pattern = r"s3://ai2-s2-pdfs/([a-f0-9]{4})/([a-f0-9]+)\.pdf(?:-\d+)?$"
     match = re.match(pattern, pretty_pdf_path)
     if match:
         return match.group(1) + match.group(2)
     return None
-
 
 def get_uri_from_db(db_path: str, pdf_hash: str) -> Optional[str]:
     """
     Looks up the URL for the given pdf_hash in the sqlite database.
@@ -44,8 +48,7 @@ def get_uri_from_db(db_path: str, pdf_hash: str) -> Optional[str]:
     cursor.execute("SELECT uri FROM pdf_mapping WHERE pdf_hash = ?", (pdf_hash,))
     result = cursor.fetchone()
     conn.close()
-    return result[0] if result else None
-
+    return result[0].strip() if result and result[0] else None
 
 @dataclass(frozen=True)
 class NormalizedEntry:
@@ -70,7 +73,6 @@ class NormalizedEntry:
     def goldkey(self):
         return f"{self.s3_path}-{self.pagenum}"
 
-
 def normalize_json_entry(data: dict) -> NormalizedEntry:
     """
     Normalizes a JSON entry from any of the supported formats.
@@ -107,14 +109,84 @@ def normalize_json_entry(data: dict) -> NormalizedEntry:
     else:
         raise ValueError("Unsupported JSON format")
 
 
+def parse_s3_url(s3_url: str) -> Tuple[str, str]:
+    """
+    Parses an S3 URL of the form s3://bucket/key and returns (bucket, key).
+    """
+    if not s3_url.startswith("s3://"):
+        raise ValueError(f"Invalid S3 URL: {s3_url}")
+    s3_path = s3_url[5:]
+    bucket, key = s3_path.split("/", 1)
+    return bucket, key
-def process_file(file_path: str, db_path: str) -> Tuple[List[dict], int]:
+def download_pdf_to_cache(s3_url: str, cache_dir: str) -> Optional[str]:
+    """
+    Downloads the PDF from the given S3 URL into the specified cache directory.
+    The destination filename is based on the parsed PDF hash.
+    Returns the path to the downloaded PDF.
+    """
+    try:
+        bucket, key = parse_s3_url(s3_url)
+        s3_client = boto3.client('s3')
+        pdf_hash = parse_pdf_hash(s3_url)
+        if not pdf_hash:
+            # Fallback: use a sanitized version of the s3_url
+            pdf_hash = re.sub(r'\W+', '_', s3_url)
+        dest_path = os.path.join(cache_dir, f"{pdf_hash}.pdf")
+        # Avoid re-downloading if already exists
+        if not os.path.exists(dest_path):
+            s3_client.download_file(bucket, key, dest_path)
+        return dest_path
+    except Exception as e:
+        print(f"Error downloading {s3_url}: {e}")
+        return None
+
+def process_pdf_page(s3_url: str, page_number: int, combined_id: str, output_pdf_dir: str, pdf_cache: Dict[str, str]) -> Optional[str]:
+    """
+    Extracts the specified page (1-indexed) from the cached PDF corresponding to s3_url.
+    Writes a new single-page PDF to the output_pdf_dir using the combined_id as the filename.
+    Returns the relative path to the new PDF (e.g., "pdfs/<combined_id>.pdf").
+    """
+    try:
+        local_cached_pdf = pdf_cache.get(s3_url)
+        if not local_cached_pdf or not os.path.exists(local_cached_pdf):
+            print(f"Cached PDF not found for {s3_url}")
+            return None
+        reader = PdfReader(local_cached_pdf)
+        # pypdf uses 0-indexed page numbers
+        page_index = page_number - 1
+        if page_index < 0 or page_index >= len(reader.pages):
+            print(f"Page number {page_number} out of range for PDF {s3_url}")
+            return None
+        writer = PdfWriter()
+        writer.add_page(reader.pages[page_index])
+        output_filename = f"{combined_id}.pdf"
+        output_path = os.path.join(output_pdf_dir, output_filename)
+        with open(output_path, "wb") as f_out:
+            writer.write(f_out)
+        # Return the relative path (assuming pdfs/ folder is relative to the parquet file location)
+        return os.path.join("pdfs", output_filename)
+    except Exception as e:
+        print(f"Error processing PDF page for {s3_url} page {page_number}: {e}")
+        return None
+
+def process_file(file_path: str, db_path: str, output_pdf_dir: str, pdf_cache: Dict[str, str]) -> Tuple[List[dict], int]:
     """
     Process a single file and return a tuple:
-    (list of valid rows, number of rows skipped due to missing URL).
+    (list of valid rows, number of rows skipped due to missing URL or PDF extraction/filtering).
+    For each JSON entry, the function:
+      - Normalizes the JSON.
+      - Skips entries whose response contains the word "resume" (any case) along with either an email address or a phone number.
+      - Extracts the PDF hash and builds the combined id.
+      - Looks up the corresponding URL from the sqlite database.
+      - Extracts the specified page from the cached PDF and writes it to output_pdf_dir.
+      - Outputs a row with "id", "url", "page_number", "response".
     """
     rows = []
     missing_count = 0
+    email_regex = r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b"
+    phone_regex = r"\b(?:\+?\d{1,3}[-.\s]?)?(?:\(?\d{3}\)?[-.\s]?)?\d{3}[-.\s]?\d{4}\b"
+
     try:
         with open(file_path, "r", encoding="utf-8") as f:
             for line_num, line in enumerate(f, start=1):
@@ -133,7 +205,14 @@ def process_file(file_path: str, db_path: str) -> Tuple[List[dict], int]:
                         print(f"Error normalizing entry at {file_path}:{line_num} - {e}")
                         continue
 
-                    # Extract the pdf hash from the s3_path.
+                    # Apply filter: skip if response contains "resume" (any case) and an email or phone number.
+                    response_text = normalized.text if normalized.text else ""
+                    if (re.search(r"resume", response_text, re.IGNORECASE) and
+                            (re.search(email_regex, response_text) or re.search(phone_regex, response_text))):
+                        print(f"Skipping entry due to resume and contact info in response at {file_path}:{line_num}")
+                        continue
+
+                    # Extract the PDF hash from the s3_path.
                     pdf_hash = parse_pdf_hash(normalized.s3_path)
                     if pdf_hash is None:
                         print(f"Could not parse pdf hash from {normalized.s3_path} at {file_path}:{line_num}")
@@ -144,14 +223,18 @@ def process_file(file_path: str, db_path: str) -> Tuple[List[dict], int]:
 
                     # Look up the corresponding URL from the sqlite database.
                     url = get_uri_from_db(db_path, pdf_hash)
-                    if url is not None:
-                        url = url.strip()
-
                     # Skip rows with missing URLs (None or empty after strip)
                     if not url:
                         print(f"Missing URL for pdf hash {pdf_hash} at {file_path}:{line_num}")
                         missing_count += 1
                         continue
+                    # Process PDF: extract the specified page from the cached PDF.
+                    local_pdf_path = process_pdf_page(normalized.s3_path, normalized.pagenum, combined_id, output_pdf_dir, pdf_cache)
+                    if local_pdf_path is None:
+                        print(f"Skipping entry because PDF processing failed for {normalized.s3_path} page {normalized.pagenum} at {file_path}:{line_num}")
+                        missing_count += 1
+                        continue
+
                     row = {
                         "id": combined_id,
                         "url": url,
@@ -163,6 +246,27 @@ def process_file(file_path: str, db_path: str) -> Tuple[List[dict], int]:
         print(f"Error processing file {file_path}: {e}")
     return rows, missing_count
 
+def scan_file_for_s3_urls(file_path: str) -> Set[str]:
+    """
+    Scans a single file and returns a set of unique S3 URLs found in the JSON entries.
+    """
+    urls = set()
+    try:
+        with open(file_path, "r", encoding="utf-8") as f:
+            for line in f:
+                line = line.strip()
+                if not line:
+                    continue
+                try:
+                    data = json.loads(line)
+                    normalized = normalize_json_entry(data)
+                    urls.add(normalized.s3_path)
+                except Exception:
+                    # Skip entries that cannot be normalized
+                    continue
+    except Exception as e:
+        print(f"Error reading file {file_path}: {e}")
+    return urls
 
 def main():
     parser = argparse.ArgumentParser(
@@ -182,19 +286,60 @@ def main():
     files = glob.glob(args.input_dataset)
     print(f"Found {len(files)} files matching pattern: {args.input_dataset}")
 
+    # Determine output directory and create 'pdfs' subfolder.
+    output_abs_path = os.path.abspath(args.output)
+    output_dir = os.path.dirname(output_abs_path)
+    pdfs_dir = os.path.join(output_dir, "pdfs")
+    os.makedirs(pdfs_dir, exist_ok=True)
+
+    # Create a temporary directory for caching PDFs.
+    pdf_cache_dir = tempfile.mkdtemp(prefix="pdf_cache_")
+    print(f"Caching PDFs to temporary directory: {pdf_cache_dir}")
+
+    # ---------------------------------------------------------------------
+    # Step 1: Scan input files to collect all unique S3 URLs using a ProcessPoolExecutor.
+    unique_s3_urls: Set[str] = set()
+    print("Scanning input files to collect unique PDF URLs...")
+    with concurrent.futures.ProcessPoolExecutor() as executor:
+        results = list(tqdm(executor.map(scan_file_for_s3_urls, files), total=len(files), desc="Scanning files"))
+        for url_set in results:
+            unique_s3_urls |= url_set
+
+    print(f"Found {len(unique_s3_urls)} unique PDF URLs.")
+
+    # ---------------------------------------------------------------------
+    # Step 2: Download all unique PDFs to the cache directory.
+    pdf_cache: Dict[str, str] = {}
+    print("Caching PDFs from S3...")
+    with concurrent.futures.ThreadPoolExecutor() as executor:
+        future_to_url = {
+            executor.submit(download_pdf_to_cache, s3_url, pdf_cache_dir): s3_url
+            for s3_url in unique_s3_urls
+        }
+        for future in tqdm(concurrent.futures.as_completed(future_to_url),
+                           total=len(future_to_url), desc="Downloading PDFs"):
+            s3_url = future_to_url[future]
+            try:
+                local_path = future.result()
+                if local_path:
+                    pdf_cache[s3_url] = local_path
+                else:
+                    print(f"Failed to cache PDF for {s3_url}")
+            except Exception as e:
+                print(f"Error caching PDF for {s3_url}: {e}")
+
+    # ---------------------------------------------------------------------
+    # Step 3: Process input files using the precached PDFs.
     all_rows = []
     total_missing = 0
-    # Process files in parallel using ProcessPoolExecutor.
+    print("Processing files...")
     with concurrent.futures.ProcessPoolExecutor() as executor:
         futures = {
-            executor.submit(process_file, file_path, args.db_path): file_path
+            executor.submit(process_file, file_path, args.db_path, pdfs_dir, pdf_cache): file_path
            for file_path in files
         }
-        for future in tqdm(
-            concurrent.futures.as_completed(futures),
-            total=len(futures),
-            desc="Processing files",
-        ):
+        for future in tqdm(concurrent.futures.as_completed(futures),
+                           total=len(futures), desc="Processing files"):
             file_path = futures[future]
             try:
                 rows, missing_count = future.result()
@@ -212,10 +357,16 @@ def main():
         valid_count = len(df)
         total_processed = valid_count + total_missing
         print(f"Successfully wrote {valid_count} rows to {args.output}")
-        print(f"Missing URL rows skipped: {total_missing} out of {total_processed} processed rows")
+        print(f"Rows skipped due to missing URL/PDF or filtering: {total_missing} out of {total_processed} processed rows")
     else:
         print("No valid rows to write. Exiting.")
 
+    # Optionally clean up the PDF cache directory.
+    try:
+        shutil.rmtree(pdf_cache_dir)
+        print(f"Cleaned up PDF cache directory: {pdf_cache_dir}")
+    except Exception as e:
+        print(f"Error cleaning up PDF cache directory: {e}")
 
 if __name__ == "__main__":
     main()