mirror of https://github.com/allenai/olmocr.git (synced 2025-10-03 12:28:35 +00:00)

Better converter

This commit is contained in:
parent e369569f99
commit 51cfdbd64f
@@ -12,13 +12,17 @@ import glob
import json
import re
import sqlite3
import tempfile
import os
import shutil
from dataclasses import dataclass
from typing import Optional, List, Tuple
from typing import Optional, List, Tuple, Dict, Set
import concurrent.futures

import boto3
from tqdm import tqdm
import pandas as pd

from pypdf import PdfReader, PdfWriter

def parse_pdf_hash(pretty_pdf_path: str) -> Optional[str]:
    """
@@ -27,13 +31,13 @@ def parse_pdf_hash(pretty_pdf_path: str) -> Optional[str]:
    s3://ai2-s2-pdfs/de80/a57e6c57b45796d2e020173227f7eae44232.pdf-1
    it will return "de80a57e6c57b45796d2e020173227f7eae44232".
    """
    pattern = r"s3://ai2-s2-pdfs/([a-f0-9]{4})/([a-f0-9]+)\.pdf"
    # Allow an optional "-<number>" at the end.
    pattern = r"s3://ai2-s2-pdfs/([a-f0-9]{4})/([a-f0-9]+)\.pdf(?:-\d+)?$"
    match = re.match(pattern, pretty_pdf_path)
    if match:
        return match.group(1) + match.group(2)
    return None

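A quick sanity check of the updated pattern, using the example path from the docstring (illustrative, assuming parse_pdf_hash as defined above):

# Illustrative check, not part of the commit.
assert parse_pdf_hash(
    "s3://ai2-s2-pdfs/de80/a57e6c57b45796d2e020173227f7eae44232.pdf-1"
) == "de80a57e6c57b45796d2e020173227f7eae44232"
assert parse_pdf_hash("s3://ai2-s2-pdfs/de80/not-a-hash.pdf") is None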
def get_uri_from_db(db_path: str, pdf_hash: str) -> Optional[str]:
    """
    Looks up the URL for the given pdf_hash in the sqlite database.
@@ -44,8 +48,7 @@ def get_uri_from_db(db_path: str, pdf_hash: str) -> Optional[str]:
    cursor.execute("SELECT uri FROM pdf_mapping WHERE pdf_hash = ?", (pdf_hash,))
    result = cursor.fetchone()
    conn.close()
    return result[0] if result else None

    return result[0].strip() if result and result[0] else None

@dataclass(frozen=True)
class NormalizedEntry:
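For context, a minimal sketch of how the lookup behaves after this change; the pdf_mapping schema is inferred from the query above, and the sample hash and URI values are made up:

# Illustrative only: build a tiny pdf_mapping database and query it.
import os
import sqlite3
import tempfile

db_path = os.path.join(tempfile.mkdtemp(), "mapping.db")
conn = sqlite3.connect(db_path)
conn.execute("CREATE TABLE pdf_mapping (pdf_hash TEXT PRIMARY KEY, uri TEXT)")
conn.execute(
    "INSERT INTO pdf_mapping VALUES (?, ?)",
    ("de80a57e6c57b45796d2e020173227f7eae44232", " https://example.org/paper.pdf \n"),
)
conn.commit()
conn.close()

# Stored URIs now come back with surrounding whitespace stripped, and a
# NULL/empty uri is treated as missing (None).
print(get_uri_from_db(db_path, "de80a57e6c57b45796d2e020173227f7eae44232"))
# -> https://example.org/paper.pdf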
@@ -70,7 +73,6 @@ class NormalizedEntry:
    def goldkey(self):
        return f"{self.s3_path}-{self.pagenum}"


def normalize_json_entry(data: dict) -> NormalizedEntry:
    """
    Normalizes a JSON entry from any of the supported formats.
@@ -107,14 +109,84 @@ def normalize_json_entry(data: dict) -> NormalizedEntry:
    else:
        raise ValueError("Unsupported JSON format")

def parse_s3_url(s3_url: str) -> Tuple[str, str]:
    """
    Parses an S3 URL of the form s3://bucket/key and returns (bucket, key).
    """
    if not s3_url.startswith("s3://"):
        raise ValueError(f"Invalid S3 URL: {s3_url}")
    s3_path = s3_url[5:]
    bucket, key = s3_path.split("/", 1)
    return bucket, key

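A quick illustration of the helper, based on the docstring example from parse_pdf_hash:

# Illustrative only.
bucket, key = parse_s3_url("s3://ai2-s2-pdfs/de80/a57e6c57b45796d2e020173227f7eae44232.pdf")
# bucket == "ai2-s2-pdfs"
# key    == "de80/a57e6c57b45796d2e020173227f7eae44232.pdf"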
def process_file(file_path: str, db_path: str) -> Tuple[List[dict], int]:
def download_pdf_to_cache(s3_url: str, cache_dir: str) -> Optional[str]:
    """
    Downloads the PDF from the given S3 URL into the specified cache directory.
    The destination filename is based on the parsed PDF hash.
    Returns the path to the downloaded PDF.
    """
    try:
        bucket, key = parse_s3_url(s3_url)
        s3_client = boto3.client('s3')
        pdf_hash = parse_pdf_hash(s3_url)
        if not pdf_hash:
            # Fallback: use a sanitized version of the s3_url
            pdf_hash = re.sub(r'\W+', '_', s3_url)
        dest_path = os.path.join(cache_dir, f"{pdf_hash}.pdf")
        # Avoid re-downloading if already exists
        if not os.path.exists(dest_path):
            s3_client.download_file(bucket, key, dest_path)
        return dest_path
    except Exception as e:
        print(f"Error downloading {s3_url}: {e}")
        return None

def process_pdf_page(s3_url: str, page_number: int, combined_id: str, output_pdf_dir: str, pdf_cache: Dict[str, str]) -> Optional[str]:
    """
    Extracts the specified page (1-indexed) from the cached PDF corresponding to s3_url.
    Writes a new single-page PDF to the output_pdf_dir using the combined_id as the filename.
    Returns the relative path to the new PDF (e.g., "pdfs/<combined_id>.pdf").
    """
    try:
        local_cached_pdf = pdf_cache.get(s3_url)
        if not local_cached_pdf or not os.path.exists(local_cached_pdf):
            print(f"Cached PDF not found for {s3_url}")
            return None
        reader = PdfReader(local_cached_pdf)
        # pypdf uses 0-indexed page numbers
        page_index = page_number - 1
        if page_index < 0 or page_index >= len(reader.pages):
            print(f"Page number {page_number} out of range for PDF {s3_url}")
            return None
        writer = PdfWriter()
        writer.add_page(reader.pages[page_index])
        output_filename = f"{combined_id}.pdf"
        output_path = os.path.join(output_pdf_dir, output_filename)
        with open(output_path, "wb") as f_out:
            writer.write(f_out)
        # Return the relative path (assuming pdfs/ folder is relative to the parquet file location)
        return os.path.join("pdfs", output_filename)
    except Exception as e:
        print(f"Error processing PDF page for {s3_url} page {page_number}: {e}")
        return None

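A usage sketch of the two new helpers together; the URL, directories, page number, and id are illustrative values, and a real run needs AWS credentials with access to the bucket:

# Illustrative only.
import os
import tempfile

s3_url = "s3://ai2-s2-pdfs/de80/a57e6c57b45796d2e020173227f7eae44232.pdf"
cache_dir = tempfile.mkdtemp(prefix="pdf_cache_")
pdf_cache = {s3_url: download_pdf_to_cache(s3_url, cache_dir)}

out_dir = "output/pdfs"
os.makedirs(out_dir, exist_ok=True)
rel_path = process_pdf_page(s3_url, 1, "de80a57e6c57b45796d2e020173227f7eae44232-1", out_dir, pdf_cache)
# rel_path == "pdfs/de80a57e6c57b45796d2e020173227f7eae44232-1.pdf" on success, None on failure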
def process_file(file_path: str, db_path: str, output_pdf_dir: str, pdf_cache: Dict[str, str]) -> Tuple[List[dict], int]:
    """
    Process a single file and return a tuple:
    (list of valid rows, number of rows skipped due to missing URL).
    (list of valid rows, number of rows skipped due to missing URL or PDF extraction/filtering).
    For each JSON entry, the function:
    - Normalizes the JSON.
    - Skips entries whose response contains the word "resume" (any case) along with either an email address or a phone number.
    - Extracts the PDF hash and builds the combined id.
    - Looks up the corresponding URL from the sqlite database.
    - Extracts the specified page from the cached PDF and writes it to output_pdf_dir.
    - Outputs a row with "id", "url", "page_number", "response".
    """
    rows = []
    missing_count = 0
    email_regex = r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b"
    phone_regex = r"\b(?:\+?\d{1,3}[-.\s]?)?(?:\(?\d{3}\)?[-.\s]?)?\d{3}[-.\s]?\d{4}\b"

    try:
        with open(file_path, "r", encoding="utf-8") as f:
            for line_num, line in enumerate(f, start=1):
@@ -133,7 +205,14 @@ def process_file(file_path: str, db_path: str) -> Tuple[List[dict], int]:
                    print(f"Error normalizing entry at {file_path}:{line_num} - {e}")
                    continue

                # Extract the pdf hash from the s3_path.
                # Apply filter: skip if response contains "resume" (any case) and an email or phone number.
                response_text = normalized.text if normalized.text else ""
                if (re.search(r"resume", response_text, re.IGNORECASE) and
                        (re.search(email_regex, response_text) or re.search(phone_regex, response_text))):
                    print(f"Skipping entry due to resume and contact info in response at {file_path}:{line_num}")
                    continue

                # Extract the PDF hash from the s3_path.
                pdf_hash = parse_pdf_hash(normalized.s3_path)
                if pdf_hash is None:
                    print(f"Could not parse pdf hash from {normalized.s3_path} at {file_path}:{line_num}")
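To illustrate the new filter, a standalone rewrite of the check on made-up sample strings (not part of the commit):

# Illustrative only.
import re

email_regex = r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b"
phone_regex = r"\b(?:\+?\d{1,3}[-.\s]?)?(?:\(?\d{3}\)?[-.\s]?)?\d{3}[-.\s]?\d{4}\b"

def looks_like_resume(text: str) -> bool:
    # Skip only when "resume" co-occurs with an email address or phone number.
    return bool(re.search(r"resume", text, re.IGNORECASE) and
                (re.search(email_regex, text) or re.search(phone_regex, text)))

looks_like_resume("Jane Doe Resume - jane.doe@example.com - (555) 123-4567")  # True
looks_like_resume("We resume the proof from Section 2.")                      # False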
@@ -144,14 +223,18 @@ def process_file(file_path: str, db_path: str) -> Tuple[List[dict], int]:

                # Look up the corresponding URL from the sqlite database.
                url = get_uri_from_db(db_path, pdf_hash)
                if url is not None:
                    url = url.strip()
                # Skip rows with missing URLs (None or empty after strip)
                if not url:
                    print(f"Missing URL for pdf hash {pdf_hash} at {file_path}:{line_num}")
                    missing_count += 1
                    continue

                # Process PDF: extract the specified page from the cached PDF.
                local_pdf_path = process_pdf_page(normalized.s3_path, normalized.pagenum, combined_id, output_pdf_dir, pdf_cache)
                if local_pdf_path is None:
                    print(f"Skipping entry because PDF processing failed for {normalized.s3_path} page {normalized.pagenum} at {file_path}:{line_num}")
                    missing_count += 1
                    continue

                row = {
                    "id": combined_id,
                    "url": url,
@@ -163,6 +246,27 @@ def process_file(file_path: str, db_path: str) -> Tuple[List[dict], int]:
        print(f"Error processing file {file_path}: {e}")
    return rows, missing_count

def scan_file_for_s3_urls(file_path: str) -> Set[str]:
    """
    Scans a single file and returns a set of unique S3 URLs found in the JSON entries.
    """
    urls = set()
    try:
        with open(file_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    data = json.loads(line)
                    normalized = normalize_json_entry(data)
                    urls.add(normalized.s3_path)
                except Exception:
                    # Skip entries that cannot be normalized
                    continue
    except Exception as e:
        print(f"Error reading file {file_path}: {e}")
    return urls

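A minimal usage sketch; the file name below is made up, and its lines are assumed to be in one of the formats accepted by normalize_json_entry:

# Illustrative only.
urls = scan_file_for_s3_urls("batch_output.jsonl")
# e.g. {"s3://ai2-s2-pdfs/de80/a57e6c57b45796d2e020173227f7eae44232.pdf-1", ...}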
def main():
    parser = argparse.ArgumentParser(
@@ -182,19 +286,60 @@ def main():
    files = glob.glob(args.input_dataset)
    print(f"Found {len(files)} files matching pattern: {args.input_dataset}")

    # Determine output directory and create 'pdfs' subfolder.
    output_abs_path = os.path.abspath(args.output)
    output_dir = os.path.dirname(output_abs_path)
    pdfs_dir = os.path.join(output_dir, "pdfs")
    os.makedirs(pdfs_dir, exist_ok=True)

    # Create a temporary directory for caching PDFs.
    pdf_cache_dir = tempfile.mkdtemp(prefix="pdf_cache_")
    print(f"Caching PDFs to temporary directory: {pdf_cache_dir}")

    # ---------------------------------------------------------------------
    # Step 1: Scan input files to collect all unique S3 URLs using a ProcessPoolExecutor.
    unique_s3_urls: Set[str] = set()
    print("Scanning input files to collect unique PDF URLs...")
    with concurrent.futures.ProcessPoolExecutor() as executor:
        results = list(tqdm(executor.map(scan_file_for_s3_urls, files), total=len(files), desc="Scanning files"))
        for url_set in results:
            unique_s3_urls |= url_set

    print(f"Found {len(unique_s3_urls)} unique PDF URLs.")

    # ---------------------------------------------------------------------
    # Step 2: Download all unique PDFs to the cache directory.
    pdf_cache: Dict[str, str] = {}
    print("Caching PDFs from S3...")
    with concurrent.futures.ThreadPoolExecutor() as executor:
        future_to_url = {
            executor.submit(download_pdf_to_cache, s3_url, pdf_cache_dir): s3_url
            for s3_url in unique_s3_urls
        }
        for future in tqdm(concurrent.futures.as_completed(future_to_url),
                           total=len(future_to_url), desc="Downloading PDFs"):
            s3_url = future_to_url[future]
            try:
                local_path = future.result()
                if local_path:
                    pdf_cache[s3_url] = local_path
                else:
                    print(f"Failed to cache PDF for {s3_url}")
            except Exception as e:
                print(f"Error caching PDF for {s3_url}: {e}")

    # ---------------------------------------------------------------------
    # Step 3: Process input files using the precached PDFs.
    all_rows = []
    total_missing = 0
    # Process files in parallel using ProcessPoolExecutor.
    print("Processing files...")
    with concurrent.futures.ProcessPoolExecutor() as executor:
        futures = {
            executor.submit(process_file, file_path, args.db_path): file_path
            executor.submit(process_file, file_path, args.db_path, pdfs_dir, pdf_cache): file_path
            for file_path in files
        }
        for future in tqdm(
            concurrent.futures.as_completed(futures),
            total=len(futures),
            desc="Processing files",
        ):
        for future in tqdm(concurrent.futures.as_completed(futures),
                           total=len(futures), desc="Processing files"):
            file_path = futures[future]
            try:
                rows, missing_count = future.result()
@@ -212,10 +357,16 @@ def main():
        valid_count = len(df)
        total_processed = valid_count + total_missing
        print(f"Successfully wrote {valid_count} rows to {args.output}")
        print(f"Missing URL rows skipped: {total_missing} out of {total_processed} processed rows")
        print(f"Rows skipped due to missing URL/PDF or filtering: {total_missing} out of {total_processed} processed rows")
    else:
        print("No valid rows to write. Exiting.")

    # Optionally clean up the PDF cache directory.
    try:
        shutil.rmtree(pdf_cache_dir)
        print(f"Cleaned up PDF cache directory: {pdf_cache_dir}")
    except Exception as e:
        print(f"Error cleaning up PDF cache directory: {e}")

if __name__ == "__main__":
    main()
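The reworked main() follows a scan → cache → process pipeline. Below is a minimal, generic sketch of that pattern with hypothetical scan_one / download_one / convert_one helpers standing in for scan_file_for_s3_urls, download_pdf_to_cache, and process_file; it is an outline under those assumptions, not the script's exact code:

# Minimal sketch of the scan -> cache -> process pattern used above.
import concurrent.futures
from functools import partial

def run_pipeline(files, cache_dir):
    # Step 1: CPU-bound scanning in worker processes; union the per-file URL sets.
    with concurrent.futures.ProcessPoolExecutor() as pool:
        urls = list(set().union(*pool.map(scan_one, files)))

    # Step 2: I/O-bound downloads in threads; each unique URL is fetched once.
    with concurrent.futures.ThreadPoolExecutor() as pool:
        cache = dict(zip(urls, pool.map(partial(download_one, cache_dir=cache_dir), urls)))

    # Step 3: CPU-bound conversion in worker processes, reusing the prebuilt cache.
    with concurrent.futures.ProcessPoolExecutor() as pool:
        return list(pool.map(partial(convert_one, cache=cache), files))

The split mirrors the workload: a thread pool suits the network-bound S3 downloads, while process pools parallelize the CPU-bound JSON scanning and PDF page extraction.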