Mirror of https://github.com/allenai/olmocr.git (synced 2025-10-03 04:15:07 +00:00)

Better converter

commit 51cfdbd64f (parent e369569f99)
@@ -12,13 +12,17 @@ import glob
 import json
 import re
 import sqlite3
+import tempfile
+import os
+import shutil
 from dataclasses import dataclass
-from typing import Optional, List, Tuple
+from typing import Optional, List, Tuple, Dict, Set
 import concurrent.futures

+import boto3
 from tqdm import tqdm
 import pandas as pd
+from pypdf import PdfReader, PdfWriter


 def parse_pdf_hash(pretty_pdf_path: str) -> Optional[str]:
     """
@@ -27,13 +31,13 @@ def parse_pdf_hash(pretty_pdf_path: str) -> Optional[str]:
     s3://ai2-s2-pdfs/de80/a57e6c57b45796d2e020173227f7eae44232.pdf-1
     it will return "de80a57e6c57b45796d2e020173227f7eae44232".
     """
-    pattern = r"s3://ai2-s2-pdfs/([a-f0-9]{4})/([a-f0-9]+)\.pdf"
+    # Allow an optional "-<number>" at the end.
+    pattern = r"s3://ai2-s2-pdfs/([a-f0-9]{4})/([a-f0-9]+)\.pdf(?:-\d+)?$"
     match = re.match(pattern, pretty_pdf_path)
     if match:
         return match.group(1) + match.group(2)
     return None


 def get_uri_from_db(db_path: str, pdf_hash: str) -> Optional[str]:
     """
     Looks up the URL for the given pdf_hash in the sqlite database.
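As a quick sanity check of the tightened regex, a minimal sketch using only the example path from the docstring above (not part of the commit):

import re

pattern = r"s3://ai2-s2-pdfs/([a-f0-9]{4})/([a-f0-9]+)\.pdf(?:-\d+)?$"

# The "-1" page suffix from the docstring example is now explicitly accepted.
m = re.match(pattern, "s3://ai2-s2-pdfs/de80/a57e6c57b45796d2e020173227f7eae44232.pdf-1")
print(m.group(1) + m.group(2))  # de80a57e6c57b45796d2e020173227f7eae44232

# The new "$" anchor also rejects trailing garbage that the unanchored pattern tolerated.
print(re.match(pattern, "s3://ai2-s2-pdfs/de80/a57e6c57b45796d2e020173227f7eae44232.pdf.bak"))  # None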
@@ -44,8 +48,7 @@ def get_uri_from_db(db_path: str, pdf_hash: str) -> Optional[str]:
     cursor.execute("SELECT uri FROM pdf_mapping WHERE pdf_hash = ?", (pdf_hash,))
     result = cursor.fetchone()
     conn.close()
-    return result[0] if result else None
+    return result[0].strip() if result and result[0] else None


 @dataclass(frozen=True)
 class NormalizedEntry:
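A self-contained sketch of what the stricter return line guards against, using an in-memory database with made-up rows (not part of the commit):

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE pdf_mapping (pdf_hash TEXT PRIMARY KEY, uri TEXT)")
conn.execute("INSERT INTO pdf_mapping VALUES ('aaaa1111', '  https://example.com/a.pdf\n')")  # stray whitespace
conn.execute("INSERT INTO pdf_mapping VALUES ('bbbb2222', NULL)")                             # NULL uri

cursor = conn.cursor()
cursor.execute("SELECT uri FROM pdf_mapping WHERE pdf_hash = ?", ("aaaa1111",))
result = cursor.fetchone()
print(repr(result[0].strip() if result and result[0] else None))  # 'https://example.com/a.pdf'

cursor.execute("SELECT uri FROM pdf_mapping WHERE pdf_hash = ?", ("bbbb2222",))
result = cursor.fetchone()
# The extra "result[0]" check keeps .strip() from raising on a NULL uri.
print(result[0].strip() if result and result[0] else None)  # None
conn.close()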
@@ -70,7 +73,6 @@ class NormalizedEntry:
     def goldkey(self):
         return f"{self.s3_path}-{self.pagenum}"
-

 def normalize_json_entry(data: dict) -> NormalizedEntry:
     """
     Normalizes a JSON entry from any of the supported formats.
@@ -107,14 +109,84 @@ def normalize_json_entry(data: dict) -> NormalizedEntry:
     else:
         raise ValueError("Unsupported JSON format")

+def parse_s3_url(s3_url: str) -> Tuple[str, str]:
+    """
+    Parses an S3 URL of the form s3://bucket/key and returns (bucket, key).
+    """
+    if not s3_url.startswith("s3://"):
+        raise ValueError(f"Invalid S3 URL: {s3_url}")
+    s3_path = s3_url[5:]
+    bucket, key = s3_path.split("/", 1)
+    return bucket, key
+
-def process_file(file_path: str, db_path: str) -> Tuple[List[dict], int]:
+def download_pdf_to_cache(s3_url: str, cache_dir: str) -> Optional[str]:
+    """
+    Downloads the PDF from the given S3 URL into the specified cache directory.
+    The destination filename is based on the parsed PDF hash.
+    Returns the path to the downloaded PDF.
+    """
+    try:
+        bucket, key = parse_s3_url(s3_url)
+        s3_client = boto3.client('s3')
+        pdf_hash = parse_pdf_hash(s3_url)
+        if not pdf_hash:
+            # Fallback: use a sanitized version of the s3_url
+            pdf_hash = re.sub(r'\W+', '_', s3_url)
+        dest_path = os.path.join(cache_dir, f"{pdf_hash}.pdf")
+        # Avoid re-downloading if already exists
+        if not os.path.exists(dest_path):
+            s3_client.download_file(bucket, key, dest_path)
+        return dest_path
+    except Exception as e:
+        print(f"Error downloading {s3_url}: {e}")
+        return None
+
+def process_pdf_page(s3_url: str, page_number: int, combined_id: str, output_pdf_dir: str, pdf_cache: Dict[str, str]) -> Optional[str]:
+    """
+    Extracts the specified page (1-indexed) from the cached PDF corresponding to s3_url.
+    Writes a new single-page PDF to the output_pdf_dir using the combined_id as the filename.
+    Returns the relative path to the new PDF (e.g., "pdfs/<combined_id>.pdf").
+    """
+    try:
+        local_cached_pdf = pdf_cache.get(s3_url)
+        if not local_cached_pdf or not os.path.exists(local_cached_pdf):
+            print(f"Cached PDF not found for {s3_url}")
+            return None
+        reader = PdfReader(local_cached_pdf)
+        # pypdf uses 0-indexed page numbers
+        page_index = page_number - 1
+        if page_index < 0 or page_index >= len(reader.pages):
+            print(f"Page number {page_number} out of range for PDF {s3_url}")
+            return None
+        writer = PdfWriter()
+        writer.add_page(reader.pages[page_index])
+        output_filename = f"{combined_id}.pdf"
+        output_path = os.path.join(output_pdf_dir, output_filename)
+        with open(output_path, "wb") as f_out:
+            writer.write(f_out)
+        # Return the relative path (assuming pdfs/ folder is relative to the parquet file location)
+        return os.path.join("pdfs", output_filename)
+    except Exception as e:
+        print(f"Error processing PDF page for {s3_url} page {page_number}: {e}")
+        return None
+
+def process_file(file_path: str, db_path: str, output_pdf_dir: str, pdf_cache: Dict[str, str]) -> Tuple[List[dict], int]:
     """
     Process a single file and return a tuple:
-    (list of valid rows, number of rows skipped due to missing URL).
+    (list of valid rows, number of rows skipped due to missing URL or PDF extraction/filtering).
+    For each JSON entry, the function:
+    - Normalizes the JSON.
+    - Skips entries whose response contains the word "resume" (any case) along with either an email address or a phone number.
+    - Extracts the PDF hash and builds the combined id.
+    - Looks up the corresponding URL from the sqlite database.
+    - Extracts the specified page from the cached PDF and writes it to output_pdf_dir.
+    - Outputs a row with "id", "url", "page_number", "response".
     """
     rows = []
     missing_count = 0
+    email_regex = r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b"
+    phone_regex = r"\b(?:\+?\d{1,3}[-.\s]?)?(?:\(?\d{3}\)?[-.\s]?)?\d{3}[-.\s]?\d{4}\b"
+
     try:
         with open(file_path, "r", encoding="utf-8") as f:
             for line_num, line in enumerate(f, start=1):
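A hedged usage sketch of the three new helpers, assuming they are in scope; the cache directory, output folder, and id are invented for illustration, and the download needs AWS credentials with read access to the bucket:

import os
import tempfile

s3_url = "s3://ai2-s2-pdfs/de80/a57e6c57b45796d2e020173227f7eae44232.pdf"
print(parse_s3_url(s3_url))  # ('ai2-s2-pdfs', 'de80/a57e6c57b45796d2e020173227f7eae44232.pdf')

cache_dir = tempfile.mkdtemp(prefix="pdf_cache_")
output_pdf_dir = "pdfs"
os.makedirs(output_pdf_dir, exist_ok=True)

# Download once into the cache, then cut page 1 out into pdfs/<id>.pdf.
local_path = download_pdf_to_cache(s3_url, cache_dir)
if local_path:
    pdf_cache = {s3_url: local_path}
    print(process_pdf_page(s3_url, 1, "example_id_1", output_pdf_dir, pdf_cache))  # "pdfs/example_id_1.pdf" on success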
@@ -133,7 +205,14 @@ def process_file(file_path: str, db_path: str) -> Tuple[List[dict], int]:
                     print(f"Error normalizing entry at {file_path}:{line_num} - {e}")
                     continue

-                # Extract the pdf hash from the s3_path.
+                # Apply filter: skip if response contains "resume" (any case) and an email or phone number.
+                response_text = normalized.text if normalized.text else ""
+                if (re.search(r"resume", response_text, re.IGNORECASE) and
+                    (re.search(email_regex, response_text) or re.search(phone_regex, response_text))):
+                    print(f"Skipping entry due to resume and contact info in response at {file_path}:{line_num}")
+                    continue
+
+                # Extract the PDF hash from the s3_path.
                 pdf_hash = parse_pdf_hash(normalized.s3_path)
                 if pdf_hash is None:
                     print(f"Could not parse pdf hash from {normalized.s3_path} at {file_path}:{line_num}")
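For a sense of what the new filter skips, a minimal sketch using the two regexes added to process_file above (sample strings invented):

import re

email_regex = r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b"
phone_regex = r"\b(?:\+?\d{1,3}[-.\s]?)?(?:\(?\d{3}\)?[-.\s]?)?\d{3}[-.\s]?\d{4}\b"

samples = [
    "RESUME\nJane Doe\njane.doe@example.com",   # skipped: "resume" plus an email
    "Resume of John Smith, (555) 123-4567",     # skipped: "resume" plus a phone number
    "Please resume reading on the next page.",  # kept: "resume" but no contact info
]
for text in samples:
    skip = bool(re.search(r"resume", text, re.IGNORECASE) and
                (re.search(email_regex, text) or re.search(phone_regex, text)))
    print(skip)  # True, True, False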
@@ -144,14 +223,18 @@ def process_file(file_path: str, db_path: str) -> Tuple[List[dict], int]:

                 # Look up the corresponding URL from the sqlite database.
                 url = get_uri_from_db(db_path, pdf_hash)
-                if url is not None:
-                    url = url.strip()
-                # Skip rows with missing URLs (None or empty after strip)
                 if not url:
                     print(f"Missing URL for pdf hash {pdf_hash} at {file_path}:{line_num}")
                     missing_count += 1
                     continue

+                # Process PDF: extract the specified page from the cached PDF.
+                local_pdf_path = process_pdf_page(normalized.s3_path, normalized.pagenum, combined_id, output_pdf_dir, pdf_cache)
+                if local_pdf_path is None:
+                    print(f"Skipping entry because PDF processing failed for {normalized.s3_path} page {normalized.pagenum} at {file_path}:{line_num}")
+                    missing_count += 1
+                    continue
+
                 row = {
                     "id": combined_id,
                     "url": url,
@@ -163,6 +246,27 @@ def process_file(file_path: str, db_path: str) -> Tuple[List[dict], int]:
         print(f"Error processing file {file_path}: {e}")
     return rows, missing_count

+def scan_file_for_s3_urls(file_path: str) -> Set[str]:
+    """
+    Scans a single file and returns a set of unique S3 URLs found in the JSON entries.
+    """
+    urls = set()
+    try:
+        with open(file_path, "r", encoding="utf-8") as f:
+            for line in f:
+                line = line.strip()
+                if not line:
+                    continue
+                try:
+                    data = json.loads(line)
+                    normalized = normalize_json_entry(data)
+                    urls.add(normalized.s3_path)
+                except Exception:
+                    # Skip entries that cannot be normalized
+                    continue
+    except Exception as e:
+        print(f"Error reading file {file_path}: {e}")
+    return urls
+
 def main():
     parser = argparse.ArgumentParser(
@@ -182,19 +286,60 @@ def main():
     files = glob.glob(args.input_dataset)
     print(f"Found {len(files)} files matching pattern: {args.input_dataset}")

+    # Determine output directory and create 'pdfs' subfolder.
+    output_abs_path = os.path.abspath(args.output)
+    output_dir = os.path.dirname(output_abs_path)
+    pdfs_dir = os.path.join(output_dir, "pdfs")
+    os.makedirs(pdfs_dir, exist_ok=True)
+
+    # Create a temporary directory for caching PDFs.
+    pdf_cache_dir = tempfile.mkdtemp(prefix="pdf_cache_")
+    print(f"Caching PDFs to temporary directory: {pdf_cache_dir}")
+
+    # ---------------------------------------------------------------------
+    # Step 1: Scan input files to collect all unique S3 URLs using a ProcessPoolExecutor.
+    unique_s3_urls: Set[str] = set()
+    print("Scanning input files to collect unique PDF URLs...")
+    with concurrent.futures.ProcessPoolExecutor() as executor:
+        results = list(tqdm(executor.map(scan_file_for_s3_urls, files), total=len(files), desc="Scanning files"))
+        for url_set in results:
+            unique_s3_urls |= url_set
+
+    print(f"Found {len(unique_s3_urls)} unique PDF URLs.")
+
+    # ---------------------------------------------------------------------
+    # Step 2: Download all unique PDFs to the cache directory.
+    pdf_cache: Dict[str, str] = {}
+    print("Caching PDFs from S3...")
+    with concurrent.futures.ThreadPoolExecutor() as executor:
+        future_to_url = {
+            executor.submit(download_pdf_to_cache, s3_url, pdf_cache_dir): s3_url
+            for s3_url in unique_s3_urls
+        }
+        for future in tqdm(concurrent.futures.as_completed(future_to_url),
+                           total=len(future_to_url), desc="Downloading PDFs"):
+            s3_url = future_to_url[future]
+            try:
+                local_path = future.result()
+                if local_path:
+                    pdf_cache[s3_url] = local_path
+                else:
+                    print(f"Failed to cache PDF for {s3_url}")
+            except Exception as e:
+                print(f"Error caching PDF for {s3_url}: {e}")
+
+    # ---------------------------------------------------------------------
+    # Step 3: Process input files using the precached PDFs.
     all_rows = []
     total_missing = 0
-    # Process files in parallel using ProcessPoolExecutor.
+    print("Processing files...")
     with concurrent.futures.ProcessPoolExecutor() as executor:
         futures = {
-            executor.submit(process_file, file_path, args.db_path): file_path
+            executor.submit(process_file, file_path, args.db_path, pdfs_dir, pdf_cache): file_path
             for file_path in files
         }
-        for future in tqdm(
-            concurrent.futures.as_completed(futures),
-            total=len(futures),
-            desc="Processing files",
-        ):
+        for future in tqdm(concurrent.futures.as_completed(futures),
+                           total=len(futures), desc="Processing files"):
             file_path = futures[future]
             try:
                 rows, missing_count = future.result()
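Steps 2 and 3 both rely on the submit/as_completed pattern with a future-to-input dict; a stripped-down sketch of that pattern on a trivial stand-in function (names invented, unrelated to the converter):

import concurrent.futures

def work(n: int) -> int:
    return n * n  # stand-in for download_pdf_to_cache / process_file

with concurrent.futures.ThreadPoolExecutor() as executor:
    # Keep a mapping from each future back to the input it was submitted with.
    future_to_input = {executor.submit(work, n): n for n in range(5)}
    for future in concurrent.futures.as_completed(future_to_input):
        n = future_to_input[future]
        try:
            print(n, "->", future.result())  # results arrive in completion order, not submission order
        except Exception as e:
            print(f"work({n}) failed: {e}")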
@@ -212,10 +357,16 @@ def main():
         valid_count = len(df)
         total_processed = valid_count + total_missing
         print(f"Successfully wrote {valid_count} rows to {args.output}")
-        print(f"Missing URL rows skipped: {total_missing} out of {total_processed} processed rows")
+        print(f"Rows skipped due to missing URL/PDF or filtering: {total_missing} out of {total_processed} processed rows")
     else:
         print("No valid rows to write. Exiting.")

+    # Optionally clean up the PDF cache directory.
+    try:
+        shutil.rmtree(pdf_cache_dir)
+        print(f"Cleaned up PDF cache directory: {pdf_cache_dir}")
+    except Exception as e:
+        print(f"Error cleaning up PDF cache directory: {e}")
+
 if __name__ == "__main__":
     main()
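Since tempfile.mkdtemp never removes the directory it creates, the explicit shutil.rmtree above is what reclaims the cache; a tiny standalone sketch of that pairing (file name invented):

import os
import shutil
import tempfile

cache_dir = tempfile.mkdtemp(prefix="pdf_cache_")
try:
    # ... download PDFs into cache_dir and use them ...
    open(os.path.join(cache_dir, "example.pdf"), "wb").close()
finally:
    shutil.rmtree(cache_dir, ignore_errors=True)  # mkdtemp leaves the directory behind otherwise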