Local path supported

commit f3ea1527ef
parent 5cbe331259
Author: Jake Poznanski
Date: 2025-08-22 20:11:35 +00:00


```diff
@@ -7,6 +7,7 @@ import random
 import re
 import subprocess
 import uuid
+from collections import defaultdict
 from typing import Dict, List

 import pypdf
@@ -25,11 +26,28 @@ from olmocr.data.renderpdf import (
 from olmocr.filter.filter import PdfFilter, Language


-def download_s3_pdf(s3_path, local_path):
-    """Download a PDF from S3 to a local path."""
+def download_s3_pdf(path, local_path):
+    """Download a PDF from S3 or copy from local path."""
     os.makedirs(os.path.dirname(local_path), exist_ok=True)
-    result = subprocess.run(["aws", "s3", "cp", s3_path, local_path], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
-    return result.returncode == 0
+
+    # Check if it's a local path
+    if os.path.exists(path):
+        # It's a local file, just copy it
+        import shutil
+        try:
+            shutil.copy2(path, local_path)
+            return True
+        except Exception as e:
+            print(f"Failed to copy local file {path}: {e}")
+            return False
+    elif path.startswith("s3://"):
+        # It's an S3 path, download it
+        result = subprocess.run(["aws", "s3", "cp", path, local_path], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
+        return result.returncode == 0
+    else:
+        # Assume it's a relative local path that doesn't exist yet
+        print(f"Path not found and doesn't appear to be S3: {path}")
+        return False


 class PreserveTablesConverter(MarkdownConverter):
```
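One behavioral detail worth noting in the reworked helper: the `os.path.exists` check runs before the `s3://` prefix check, so anything resolvable on disk is copied, and only paths that do not exist locally are tried against S3. A minimal standalone sketch of that dispatch order (the function name and assertions here are illustrative, not part of the commit):

```python
import os


def classify_source(path: str) -> str:
    """Mirror the branch order in download_s3_pdf: an existing local
    file wins before the s3:// prefix is even considered."""
    if os.path.exists(path):
        return "local"
    if path.startswith("s3://"):
        return "s3"
    return "unknown"  # neither on disk nor an S3 URI -> reported as a failure


assert classify_source("s3://bucket/key.pdf") == "s3"          # never exists on disk
assert classify_source("definitely/missing.pdf") == "unknown"  # falls to the error branch
```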
```diff
@@ -1012,7 +1030,7 @@ def generate_tests_from_html(html_content: str, pdf_id: str, page_num: int, verb

 async def process_pdf(pdf_info, args, client, pdf_filter=None):
     """Process a single PDF, render a random page, and create an HTML template."""
-    s3_path, index = pdf_info
+    pdf_path, index = pdf_info

     # Create a unique folder for each PDF in the temp directory
     pdf_id = f"pdf_{index:05d}"
@@ -1022,15 +1040,15 @@ async def process_pdf(pdf_info, args, client, pdf_filter=None):
     # Determine if we should log table test verification
     verbose_table_testing = args.verbose

-    # Download PDF to local temp directory
+    # Download PDF to local temp directory (or copy if local)
     local_pdf_path = os.path.join(temp_pdf_dir, "document.pdf")
-    if not download_s3_pdf(s3_path, local_pdf_path):
-        print(f"Failed to download PDF from {s3_path}")
+    if not download_s3_pdf(pdf_path, local_pdf_path):
+        print(f"Failed to download/copy PDF from {pdf_path}")
         return None

     # Apply filter if enabled
     if pdf_filter and pdf_filter.filter_out_pdf(local_pdf_path):
-        print(f"PDF filtered out: {s3_path}")
+        print(f"PDF filtered out: {pdf_path}")
         return None

     try:
@@ -1039,7 +1057,7 @@ async def process_pdf(pdf_info, args, client, pdf_filter=None):
         num_pages = len(reader.pages)

         if num_pages == 0:
-            print(f"PDF has no pages: {s3_path}")
+            print(f"PDF has no pages: {pdf_path}")
             return None

         # Select a random page
@@ -1052,7 +1070,7 @@ async def process_pdf(pdf_info, args, client, pdf_filter=None):
         # Generate HTML from the image
         html_content = await generate_html_from_image(client, image_base64)
         if not html_content:
-            print(f"Failed to generate HTML for {s3_path}, page {page_num}")
+            print(f"Failed to generate HTML for {pdf_path}, page {page_num}")
             return None

         # Create output directories
@@ -1157,7 +1175,7 @@ async def process_pdf(pdf_info, args, client, pdf_filter=None):

         return {
             "pdf_id": pdf_id,
-            "s3_path": s3_path,
+            "pdf_path": pdf_path,
             "page_number": page_num,
             "html_path": html_path,
             "markdown_path": markdown_path,
@@ -1167,7 +1185,7 @@ async def process_pdf(pdf_info, args, client, pdf_filter=None):
             "num_tests": len(tests),
         }
     except Exception as e:
-        print(f"Error processing {s3_path}: {e}")
+        print(f"Error processing {pdf_path}: {e}")
         return None
     finally:
         # Clean up temp directory for this PDF
```
```diff
@@ -1181,7 +1199,7 @@ async def main():
     logging.getLogger("httpcore").setLevel(logging.WARNING)

     parser = argparse.ArgumentParser(description="Convert PDFs to HTML templates and render with Playwright")
-    parser.add_argument("--input_list", required=True, help="Path to a file containing S3 paths to PDFs")
+    parser.add_argument("--input_list", required=True, help="Path to a file containing S3 paths or local paths to PDFs")
     parser.add_argument("--output_dir", required=True, help="Directory to store extracted pages and tests")
     parser.add_argument("--temp_dir", default="/tmp/mine_tables", help="Directory for temporary files")
     parser.add_argument("--max_tests", type=int, default=100, help="Maximum number of tests to generate")
```
```diff
@@ -1217,7 +1235,7 @@ async def main():
         print("PDF filtering enabled")

     # Reservoir sampling implementation
-    s3_paths = []
+    pdf_paths = []
    with open(args.input_list, "r") as f:
         for i, line in enumerate(tqdm(f)):
             line = line.strip()
@@ -1225,18 +1243,18 @@ async def main():
                 continue

             if i < 100000:
-                s3_paths.append(line)
+                pdf_paths.append(line)
             else:
                 # Randomly replace elements with decreasing probability
                 j = random.randint(0, i)
                 if j < 100000:
-                    s3_paths[j] = line
+                    pdf_paths[j] = line

-    print(f"Found {len(s3_paths)} PDF paths in input list")
+    print(f"Found {len(pdf_paths)} PDF paths in input list")

     # Shuffle and limit to max_tests
-    random.shuffle(s3_paths)
-    s3_paths = s3_paths[: args.max_tests]
+    random.shuffle(pdf_paths)
+    pdf_paths = pdf_paths[: args.max_tests]

     # Initialize the JSONL file in bench_data folder with the specified name
     bench_data_dir = os.path.join(args.output_dir, "bench_data")
```
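The loop above is classic Algorithm R reservoir sampling with a reservoir size of 100000: the first k lines are kept, and each later line i replaces a random slot with probability k/(i+1), so every line in the file ends up in the sample with equal probability k/n without ever loading the whole list into memory. A self-contained sketch with a small reservoir (the helper name is ours):

```python
import random


def reservoir_sample(stream, k):
    """Uniform random sample of k items from a stream of unknown length."""
    reservoir = []
    for i, item in enumerate(stream):
        if i < k:
            reservoir.append(item)
        else:
            # randint is inclusive on both ends, matching the script above
            j = random.randint(0, i)
            if j < k:
                reservoir[j] = item
    return reservoir


print(reservoir_sample(range(1_000_000), 5))  # five indices drawn uniformly
```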
```diff
@@ -1246,7 +1264,7 @@ async def main():

     # Counter for test statistics
     test_counter = 0
-    test_types = {"present": 0, "absent": 0, "table": 0, "order": 0}
+    test_types = defaultdict(int)  # Automatically handles any test type
     results = []

     # Initialize an asyncio lock for file access
```
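Switching the counter to `collections.defaultdict(int)` (hence the new import in the first hunk) means test types no longer have to be pre-registered, which is what lets the membership check disappear from the tallying loop further down. A quick illustration:

```python
from collections import defaultdict

test_types = defaultdict(int)
for t in ["present", "table", "baseline"]:  # "baseline" was never pre-declared
    test_types[t] += 1  # missing keys start at int() == 0

print(dict(test_types))  # {'present': 1, 'table': 1, 'baseline': 1}
```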
```diff
@@ -1254,7 +1272,7 @@ async def main():

     # Process PDFs in parallel using asyncio
     async def process_with_progress(pdf_info):
-        s3_path = pdf_info[0]
+        pdf_path = pdf_info[0]
         try:
             result = await process_pdf(pdf_info, args, client, pdf_filter)
             if result and result.get("tests"):
@@ -1269,21 +1287,20 @@ async def main():
                     nonlocal test_counter
                     test_counter += len(result["tests"])
                     for test in result["tests"]:
-                        test_type = test.get("type", "")
-                        if test_type in test_types:
-                            test_types[test_type] += 1
+                        test_type = test.get("type", "unknown")
+                        test_types[test_type] += 1

                     print(f"Added {len(result['tests'])} tests from {result['pdf_id']}, total: {test_counter}")

             return result
         except Exception as e:
-            print(f"Error processing {s3_path}: {e}")
+            print(f"Error processing {pdf_path}: {e}")
             return None

     # Create tasks for all PDFs
     tasks = []
-    for i, s3_path in enumerate(s3_paths):
-        tasks.append(process_with_progress((s3_path, i)))
+    for i, pdf_path in enumerate(pdf_paths):
+        tasks.append(process_with_progress((pdf_path, i)))

     # Run tasks with limited concurrency
     semaphore = asyncio.Semaphore(args.parallel)
```
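The diff ends just as the concurrency limit is created. For context, the usual asyncio idiom is to gate each coroutine on the semaphore so that at most `args.parallel` PDFs are in flight at once; a sketch of the general pattern (not necessarily this script's exact wrapper):

```python
import asyncio


async def run_limited(coros, parallel):
    """Run coroutines with at most `parallel` executing concurrently."""
    semaphore = asyncio.Semaphore(parallel)

    async def gated(coro):
        async with semaphore:  # released automatically when the task finishes
            return await coro

    return await asyncio.gather(*(gated(c) for c in coros))
```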