Local path supported

Jake Poznanski 2025-08-22 20:11:35 +00:00
parent 5cbe331259
commit f3ea1527ef


@@ -7,6 +7,7 @@ import random
 import re
 import subprocess
 import uuid
+from collections import defaultdict
 from typing import Dict, List

 import pypdf
@@ -25,11 +26,28 @@ from olmocr.data.renderpdf import (
 from olmocr.filter.filter import PdfFilter, Language


-def download_s3_pdf(s3_path, local_path):
-    """Download a PDF from S3 to a local path."""
+def download_s3_pdf(path, local_path):
+    """Download a PDF from S3 or copy from local path."""
     os.makedirs(os.path.dirname(local_path), exist_ok=True)
-    result = subprocess.run(["aws", "s3", "cp", s3_path, local_path], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
-    return result.returncode == 0
+    # Check if it's a local path
+    if os.path.exists(path):
+        # It's a local file, just copy it
+        import shutil
+
+        try:
+            shutil.copy2(path, local_path)
+            return True
+        except Exception as e:
+            print(f"Failed to copy local file {path}: {e}")
+            return False
+    elif path.startswith("s3://"):
+        # It's an S3 path, download it
+        result = subprocess.run(["aws", "s3", "cp", path, local_path], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
+        return result.returncode == 0
+    else:
+        # Assume it's a relative local path that doesn't exist yet
+        print(f"Path not found and doesn't appear to be S3: {path}")
+        return False


 class PreserveTablesConverter(MarkdownConverter):
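With this change the same call site handles both sources; a quick sketch of the dispatch (paths here are hypothetical, for illustration only):

    # Hypothetical paths, not from the repository.
    download_s3_pdf("s3://my-bucket/docs/report.pdf", "/tmp/work/document.pdf")  # routed to `aws s3 cp`
    download_s3_pdf("samples/report.pdf", "/tmp/work/document.pdf")              # routed to shutil.copy2

Note that local paths are detected first via os.path.exists, so the "s3://" branch is only reached for paths that do not exist on disk.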
@@ -1012,7 +1030,7 @@ def generate_tests_from_html(html_content: str, pdf_id: str, page_num: int, verb

 async def process_pdf(pdf_info, args, client, pdf_filter=None):
     """Process a single PDF, render a random page, and create an HTML template."""
-    s3_path, index = pdf_info
+    pdf_path, index = pdf_info

     # Create a unique folder for each PDF in the temp directory
     pdf_id = f"pdf_{index:05d}"
@@ -1022,15 +1040,15 @@ async def process_pdf(pdf_info, args, client, pdf_filter=None):
     # Determine if we should log table test verification
     verbose_table_testing = args.verbose

-    # Download PDF to local temp directory
+    # Download PDF to local temp directory (or copy if local)
     local_pdf_path = os.path.join(temp_pdf_dir, "document.pdf")
-    if not download_s3_pdf(s3_path, local_pdf_path):
-        print(f"Failed to download PDF from {s3_path}")
+    if not download_s3_pdf(pdf_path, local_pdf_path):
+        print(f"Failed to download/copy PDF from {pdf_path}")
         return None

     # Apply filter if enabled
     if pdf_filter and pdf_filter.filter_out_pdf(local_pdf_path):
-        print(f"PDF filtered out: {s3_path}")
+        print(f"PDF filtered out: {pdf_path}")
         return None

     try:
@@ -1039,7 +1057,7 @@ async def process_pdf(pdf_info, args, client, pdf_filter=None):
         num_pages = len(reader.pages)

         if num_pages == 0:
-            print(f"PDF has no pages: {s3_path}")
+            print(f"PDF has no pages: {pdf_path}")
             return None

         # Select a random page
@@ -1052,7 +1070,7 @@ async def process_pdf(pdf_info, args, client, pdf_filter=None):
         # Generate HTML from the image
         html_content = await generate_html_from_image(client, image_base64)
         if not html_content:
-            print(f"Failed to generate HTML for {s3_path}, page {page_num}")
+            print(f"Failed to generate HTML for {pdf_path}, page {page_num}")
             return None

         # Create output directories
@@ -1157,7 +1175,7 @@ async def process_pdf(pdf_info, args, client, pdf_filter=None):

         return {
             "pdf_id": pdf_id,
-            "s3_path": s3_path,
+            "pdf_path": pdf_path,
             "page_number": page_num,
             "html_path": html_path,
             "markdown_path": markdown_path,
@@ -1167,7 +1185,7 @@ async def process_pdf(pdf_info, args, client, pdf_filter=None):
             "num_tests": len(tests),
         }
     except Exception as e:
-        print(f"Error processing {s3_path}: {e}")
+        print(f"Error processing {pdf_path}: {e}")
         return None
     finally:
         # Clean up temp directory for this PDF
@@ -1181,7 +1199,7 @@ async def main():
     logging.getLogger("httpcore").setLevel(logging.WARNING)

     parser = argparse.ArgumentParser(description="Convert PDFs to HTML templates and render with Playwright")
-    parser.add_argument("--input_list", required=True, help="Path to a file containing S3 paths to PDFs")
+    parser.add_argument("--input_list", required=True, help="Path to a file containing S3 paths or local paths to PDFs")
     parser.add_argument("--output_dir", required=True, help="Directory to store extracted pages and tests")
     parser.add_argument("--temp_dir", default="/tmp/mine_tables", help="Directory for temporary files")
     parser.add_argument("--max_tests", type=int, default=100, help="Maximum number of tests to generate")
@@ -1217,7 +1235,7 @@ async def main():
         print("PDF filtering enabled")

     # Reservoir sampling implementation
-    s3_paths = []
+    pdf_paths = []
     with open(args.input_list, "r") as f:
         for i, line in enumerate(tqdm(f)):
             line = line.strip()
@ -1225,18 +1243,18 @@ async def main():
continue
if i < 100000:
s3_paths.append(line)
pdf_paths.append(line)
else:
# Randomly replace elements with decreasing probability
j = random.randint(0, i)
if j < 100000:
s3_paths[j] = line
pdf_paths[j] = line
print(f"Found {len(s3_paths)} PDF paths in input list")
print(f"Found {len(pdf_paths)} PDF paths in input list")
# Shuffle and limit to max_tests
random.shuffle(s3_paths)
s3_paths = s3_paths[: args.max_tests]
random.shuffle(pdf_paths)
pdf_paths = pdf_paths[: args.max_tests]
# Initialize the JSONL file in bench_data folder with the specified name
bench_data_dir = os.path.join(args.output_dir, "bench_data")
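The loop above is standard Algorithm R reservoir sampling with a reservoir of 100000 paths: line i survives with probability 100000/(i+1), giving a uniform sample of an input list of unknown length in one pass. A standalone sketch of the technique (names are illustrative, not the file's code):

    import random

    def reservoir_sample(stream, k):
        """Uniformly sample k items from an iterable of unknown length."""
        reservoir = []
        for i, item in enumerate(stream):
            if i < k:
                reservoir.append(item)  # fill the reservoir first
            else:
                j = random.randint(0, i)  # item survives with probability k/(i+1)
                if j < k:
                    reservoir[j] = item
        return reservoir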
@@ -1246,7 +1264,7 @@ async def main():

     # Counter for test statistics
     test_counter = 0
-    test_types = {"present": 0, "absent": 0, "table": 0, "order": 0}
+    test_types = defaultdict(int)  # Automatically handles any test type
     results = []

     # Initialize an asyncio lock for file access
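Since test_types is now a defaultdict(int), unseen keys start at 0 on first increment, so new test types no longer need to be pre-registered in a literal dict. A minimal sketch (test names are illustrative):

    from collections import defaultdict

    test_types = defaultdict(int)
    for t in ["table", "present", "table", "baseline"]:  # "baseline" was never declared
        test_types[t] += 1
    print(dict(test_types))  # {'table': 2, 'present': 1, 'baseline': 1}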
@@ -1254,7 +1272,7 @@ async def main():

     # Process PDFs in parallel using asyncio
     async def process_with_progress(pdf_info):
-        s3_path = pdf_info[0]
+        pdf_path = pdf_info[0]
         try:
             result = await process_pdf(pdf_info, args, client, pdf_filter)
             if result and result.get("tests"):
@@ -1269,21 +1287,20 @@ async def main():
                     nonlocal test_counter
                     test_counter += len(result["tests"])
                     for test in result["tests"]:
-                        test_type = test.get("type", "")
-                        if test_type in test_types:
-                            test_types[test_type] += 1
+                        test_type = test.get("type", "unknown")
+                        test_types[test_type] += 1
                     print(f"Added {len(result['tests'])} tests from {result['pdf_id']}, total: {test_counter}")
             return result
         except Exception as e:
-            print(f"Error processing {s3_path}: {e}")
+            print(f"Error processing {pdf_path}: {e}")
             return None

     # Create tasks for all PDFs
     tasks = []
-    for i, s3_path in enumerate(s3_paths):
-        tasks.append(process_with_progress((s3_path, i)))
+    for i, pdf_path in enumerate(pdf_paths):
+        tasks.append(process_with_progress((pdf_path, i)))

     # Run tasks with limited concurrency
     semaphore = asyncio.Semaphore(args.parallel)
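The semaphore caps how many process_pdf coroutines run at once at args.parallel. A minimal sketch of the bounded-gather pattern it enables (the wrapper names below are an assumption, not the file's exact code):

    import asyncio

    async def bounded(coro, semaphore):
        async with semaphore:  # at most `parallel` coroutines inside at a time
            return await coro

    async def run_all(coros, parallel):
        semaphore = asyncio.Semaphore(parallel)
        return await asyncio.gather(*(bounded(c, semaphore) for c in coros))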