Better converter

Jake Poznanski 2025-02-13 22:30:20 +00:00
parent e369569f99
commit 51cfdbd64f


@@ -12,13 +12,17 @@ import glob
import json
import re
import sqlite3
import tempfile
import os
import shutil
from dataclasses import dataclass
-from typing import Optional, List, Tuple
+from typing import Optional, List, Tuple, Dict, Set
import concurrent.futures
import boto3
from tqdm import tqdm
import pandas as pd
from pypdf import PdfReader, PdfWriter
def parse_pdf_hash(pretty_pdf_path: str) -> Optional[str]:
"""
@@ -27,13 +31,13 @@ def parse_pdf_hash(pretty_pdf_path: str) -> Optional[str]:
s3://ai2-s2-pdfs/de80/a57e6c57b45796d2e020173227f7eae44232.pdf-1
it will return "de80a57e6c57b45796d2e020173227f7eae44232".
"""
pattern = r"s3://ai2-s2-pdfs/([a-f0-9]{4})/([a-f0-9]+)\.pdf"
# Allow an optional "-<number>" at the end.
pattern = r"s3://ai2-s2-pdfs/([a-f0-9]{4})/([a-f0-9]+)\.pdf(?:-\d+)?$"
match = re.match(pattern, pretty_pdf_path)
if match:
return match.group(1) + match.group(2)
return None
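A quick, illustrative sanity check of the widened pattern (a sketch against the function above, not part of the module): the bare key and the page-suffixed form from the docstring normalize to the same hash, while foreign paths fall through to None.

# Hypothetical spot-check of parse_pdf_hash, using the docstring's example path.
example = "s3://ai2-s2-pdfs/de80/a57e6c57b45796d2e020173227f7eae44232.pdf"
assert parse_pdf_hash(example) == "de80a57e6c57b45796d2e020173227f7eae44232"
assert parse_pdf_hash(example + "-1") == "de80a57e6c57b45796d2e020173227f7eae44232"
assert parse_pdf_hash("s3://some-other-bucket/report.pdf") is None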
def get_uri_from_db(db_path: str, pdf_hash: str) -> Optional[str]:
"""
Looks up the URL for the given pdf_hash in the sqlite database.
@@ -44,8 +48,7 @@ def get_uri_from_db(db_path: str, pdf_hash: str) -> Optional[str]:
cursor.execute("SELECT uri FROM pdf_mapping WHERE pdf_hash = ?", (pdf_hash,))
result = cursor.fetchone()
conn.close()
-return result[0] if result else None
+return result[0].strip() if result and result[0] else None
@dataclass(frozen=True)
class NormalizedEntry:
@@ -70,7 +73,6 @@ class NormalizedEntry:
def goldkey(self):
return f"{self.s3_path}-{self.pagenum}"
def normalize_json_entry(data: dict) -> NormalizedEntry:
"""
Normalizes a JSON entry from any of the supported formats.
@@ -107,14 +109,84 @@ def normalize_json_entry(data: dict) -> NormalizedEntry:
else:
raise ValueError("Unsupported JSON format")
def parse_s3_url(s3_url: str) -> Tuple[str, str]:
"""
Parses an S3 URL of the form s3://bucket/key and returns (bucket, key).
"""
if not s3_url.startswith("s3://"):
raise ValueError(f"Invalid S3 URL: {s3_url}")
s3_path = s3_url[5:]
bucket, key = s3_path.split("/", 1)
return bucket, key
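For reference, applying parse_s3_url to the same example path splits it as follows (an illustrative usage sketch of the helper above):

bucket, key = parse_s3_url("s3://ai2-s2-pdfs/de80/a57e6c57b45796d2e020173227f7eae44232.pdf")
assert bucket == "ai2-s2-pdfs"
assert key == "de80/a57e6c57b45796d2e020173227f7eae44232.pdf"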
def process_file(file_path: str, db_path: str) -> Tuple[List[dict], int]:
def download_pdf_to_cache(s3_url: str, cache_dir: str) -> Optional[str]:
"""
Downloads the PDF from the given S3 URL into the specified cache directory.
The destination filename is based on the parsed PDF hash.
Returns the path to the downloaded PDF, or None if the download fails.
"""
try:
bucket, key = parse_s3_url(s3_url)
s3_client = boto3.client('s3')
pdf_hash = parse_pdf_hash(s3_url)
if not pdf_hash:
# Fallback: use a sanitized version of the s3_url
pdf_hash = re.sub(r'\W+', '_', s3_url)
dest_path = os.path.join(cache_dir, f"{pdf_hash}.pdf")
# Avoid re-downloading if the file already exists
if not os.path.exists(dest_path):
s3_client.download_file(bucket, key, dest_path)
return dest_path
except Exception as e:
print(f"Error downloading {s3_url}: {e}")
return None
def process_pdf_page(s3_url: str, page_number: int, combined_id: str, output_pdf_dir: str, pdf_cache: Dict[str, str]) -> Optional[str]:
"""
Extracts the specified page (1-indexed) from the cached PDF corresponding to s3_url.
Writes a new single-page PDF to the output_pdf_dir using the combined_id as the filename.
Returns the relative path to the new PDF (e.g., "pdfs/<combined_id>.pdf"), or None on failure.
"""
try:
local_cached_pdf = pdf_cache.get(s3_url)
if not local_cached_pdf or not os.path.exists(local_cached_pdf):
print(f"Cached PDF not found for {s3_url}")
return None
reader = PdfReader(local_cached_pdf)
# pypdf uses 0-indexed page numbers
page_index = page_number - 1
if page_index < 0 or page_index >= len(reader.pages):
print(f"Page number {page_number} out of range for PDF {s3_url}")
return None
writer = PdfWriter()
writer.add_page(reader.pages[page_index])
output_filename = f"{combined_id}.pdf"
output_path = os.path.join(output_pdf_dir, output_filename)
with open(output_path, "wb") as f_out:
writer.write(f_out)
# Return the path relative to the parquet file location (the pdfs/ folder sits alongside the output file)
return os.path.join("pdfs", output_filename)
except Exception as e:
print(f"Error processing PDF page for {s3_url} page {page_number}: {e}")
return None
def process_file(file_path: str, db_path: str, output_pdf_dir: str, pdf_cache: Dict[str, str]) -> Tuple[List[dict], int]:
"""
Process a single file and return a tuple:
-(list of valid rows, number of rows skipped due to missing URL).
+(list of valid rows, number of rows skipped due to missing URL or PDF extraction/filtering).
For each JSON entry, the function:
- Normalizes the JSON.
- Skips entries whose response contains the word "resume" (any case) along with either an email address or a phone number.
- Extracts the PDF hash and builds the combined id.
- Looks up the corresponding URL from the sqlite database.
- Extracts the specified page from the cached PDF and writes it to output_pdf_dir.
- Outputs a row with "id", "url", "page_number", "response".
"""
rows = []
missing_count = 0
email_regex = r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b"
phone_regex = r"\b(?:\+?\d{1,3}[-.\s]?)?(?:\(?\d{3}\)?[-.\s]?)?\d{3}[-.\s]?\d{4}\b"
try:
with open(file_path, "r", encoding="utf-8") as f:
for line_num, line in enumerate(f, start=1):
@@ -133,7 +205,14 @@ def process_file(file_path: str, db_path: str) -> Tuple[List[dict], int]:
print(f"Error normalizing entry at {file_path}:{line_num} - {e}")
continue
-# Extract the pdf hash from the s3_path.
+# Apply filter: skip if response contains "resume" (any case) and an email or phone number.
+response_text = normalized.text if normalized.text else ""
+if (re.search(r"resume", response_text, re.IGNORECASE) and
+(re.search(email_regex, response_text) or re.search(phone_regex, response_text))):
+print(f"Skipping entry due to resume and contact info in response at {file_path}:{line_num}")
+continue
+# Extract the PDF hash from the s3_path.
pdf_hash = parse_pdf_hash(normalized.s3_path)
if pdf_hash is None:
print(f"Could not parse pdf hash from {normalized.s3_path} at {file_path}:{line_num}")
@@ -144,14 +223,18 @@ def process_file(file_path: str, db_path: str) -> Tuple[List[dict], int]:
# Look up the corresponding URL from the sqlite database.
url = get_uri_from_db(db_path, pdf_hash)
if url is not None:
url = url.strip()
# Skip rows with missing URLs (None or empty after strip)
if not url:
print(f"Missing URL for pdf hash {pdf_hash} at {file_path}:{line_num}")
missing_count += 1
continue
# Process PDF: extract the specified page from the cached PDF.
local_pdf_path = process_pdf_page(normalized.s3_path, normalized.pagenum, combined_id, output_pdf_dir, pdf_cache)
if local_pdf_path is None:
print(f"Skipping entry because PDF processing failed for {normalized.s3_path} page {normalized.pagenum} at {file_path}:{line_num}")
missing_count += 1
continue
row = {
"id": combined_id,
"url": url,
@@ -163,6 +246,27 @@ def process_file(file_path: str, db_path: str) -> Tuple[List[dict], int]:
print(f"Error processing file {file_path}: {e}")
return rows, missing_count
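To make the new resume filter concrete, the sketch below replays the skip condition from process_file on invented sample strings (looks_like_resume is a hypothetical helper, not part of this file): a response is dropped only when it mentions "resume" and also contains an email address or phone number.

import re

# Same patterns as in process_file above.
email_regex = r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b"
phone_regex = r"\b(?:\+?\d{1,3}[-.\s]?)?(?:\(?\d{3}\)?[-.\s]?)?\d{3}[-.\s]?\d{4}\b"

def looks_like_resume(text: str) -> bool:
    # Hypothetical helper mirroring the skip condition: "resume" (any case) plus an email or phone number.
    has_resume = re.search(r"resume", text, re.IGNORECASE)
    has_contact = re.search(email_regex, text) or re.search(phone_regex, text)
    return bool(has_resume and has_contact)

# Invented sample inputs, for illustration only.
assert looks_like_resume("RESUME of Jane Doe, contact jane@example.com")        # skipped
assert not looks_like_resume("Please resume reading on the next page.")         # kept
assert not looks_like_resume("Call 555-123-4567 to schedule an appointment.")   # kept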
def scan_file_for_s3_urls(file_path: str) -> Set[str]:
"""
Scans a single file and returns a set of unique S3 URLs found in the JSON entries.
"""
urls = set()
try:
with open(file_path, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line:
continue
try:
data = json.loads(line)
normalized = normalize_json_entry(data)
urls.add(normalized.s3_path)
except Exception:
# Skip entries that cannot be normalized
continue
except Exception as e:
print(f"Error reading file {file_path}: {e}")
return urls
def main():
parser = argparse.ArgumentParser(
@@ -182,19 +286,60 @@ def main():
files = glob.glob(args.input_dataset)
print(f"Found {len(files)} files matching pattern: {args.input_dataset}")
# Determine output directory and create 'pdfs' subfolder.
output_abs_path = os.path.abspath(args.output)
output_dir = os.path.dirname(output_abs_path)
pdfs_dir = os.path.join(output_dir, "pdfs")
os.makedirs(pdfs_dir, exist_ok=True)
# Create a temporary directory for caching PDFs.
pdf_cache_dir = tempfile.mkdtemp(prefix="pdf_cache_")
print(f"Caching PDFs to temporary directory: {pdf_cache_dir}")
# ---------------------------------------------------------------------
# Step 1: Scan input files to collect all unique S3 URLs using a ProcessPoolExecutor.
unique_s3_urls: Set[str] = set()
print("Scanning input files to collect unique PDF URLs...")
with concurrent.futures.ProcessPoolExecutor() as executor:
results = list(tqdm(executor.map(scan_file_for_s3_urls, files), total=len(files), desc="Scanning files"))
for url_set in results:
unique_s3_urls |= url_set
print(f"Found {len(unique_s3_urls)} unique PDF URLs.")
# ---------------------------------------------------------------------
# Step 2: Download all unique PDFs to the cache directory.
pdf_cache: Dict[str, str] = {}
print("Caching PDFs from S3...")
with concurrent.futures.ThreadPoolExecutor() as executor:
future_to_url = {
executor.submit(download_pdf_to_cache, s3_url, pdf_cache_dir): s3_url
for s3_url in unique_s3_urls
}
for future in tqdm(concurrent.futures.as_completed(future_to_url),
total=len(future_to_url), desc="Downloading PDFs"):
s3_url = future_to_url[future]
try:
local_path = future.result()
if local_path:
pdf_cache[s3_url] = local_path
else:
print(f"Failed to cache PDF for {s3_url}")
except Exception as e:
print(f"Error caching PDF for {s3_url}: {e}")
# ---------------------------------------------------------------------
# Step 3: Process input files using the precached PDFs.
all_rows = []
total_missing = 0
# Process files in parallel using ProcessPoolExecutor.
print("Processing files...")
with concurrent.futures.ProcessPoolExecutor() as executor:
futures = {
-executor.submit(process_file, file_path, args.db_path): file_path
+executor.submit(process_file, file_path, args.db_path, pdfs_dir, pdf_cache): file_path
for file_path in files
}
-for future in tqdm(
-concurrent.futures.as_completed(futures),
-total=len(futures),
-desc="Processing files",
-):
+for future in tqdm(concurrent.futures.as_completed(futures),
+total=len(futures), desc="Processing files"):
file_path = futures[future]
try:
rows, missing_count = future.result()
@@ -212,10 +357,16 @@ def main():
valid_count = len(df)
total_processed = valid_count + total_missing
print(f"Successfully wrote {valid_count} rows to {args.output}")
print(f"Missing URL rows skipped: {total_missing} out of {total_processed} processed rows")
print(f"Rows skipped due to missing URL/PDF or filtering: {total_missing} out of {total_processed} processed rows")
else:
print("No valid rows to write. Exiting.")
# Clean up the temporary PDF cache directory.
try:
shutil.rmtree(pdf_cache_dir)
print(f"Cleaned up PDF cache directory: {pdf_cache_dir}")
except Exception as e:
print(f"Error cleaning up PDF cache directory: {e}")
if __name__ == "__main__":
main()
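A hedged sketch of consuming the converter's output, assuming the parquet was written to a placeholder path: each row carries "id", "url", "page_number", and "response", and the matching single-page PDF sits in the pdfs/ folder next to the parquet file as "<id>.pdf".

import os
import pandas as pd

parquet_path = "output/dataset.parquet"  # placeholder; use whatever path was passed as the script's output argument
df = pd.read_parquet(parquet_path)
print(df[["id", "url", "page_number", "response"]].head())

# Single-page PDFs are written alongside the parquet, named after the row id (see process_pdf_page).
first_pdf = os.path.join(os.path.dirname(parquet_path), "pdfs", f"{df.iloc[0]['id']}.pdf")
print(first_pdf, os.path.exists(first_pdf))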