Mirror of https://github.com/allenai/olmocr.git (synced 2025-11-03 11:35:29 +00:00)
Local path supported
commit f3ea1527ef
parent 5cbe331259
@@ -7,6 +7,7 @@ import random
 import re
 import subprocess
 import uuid
+from collections import defaultdict
 from typing import Dict, List
 
 import pypdf
@@ -25,11 +26,28 @@ from olmocr.data.renderpdf import (
 from olmocr.filter.filter import PdfFilter, Language
 
 
-def download_s3_pdf(s3_path, local_path):
-    """Download a PDF from S3 to a local path."""
+def download_s3_pdf(path, local_path):
+    """Download a PDF from S3 or copy from local path."""
     os.makedirs(os.path.dirname(local_path), exist_ok=True)
-    result = subprocess.run(["aws", "s3", "cp", s3_path, local_path], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
-    return result.returncode == 0
+
+    # Check if it's a local path
+    if os.path.exists(path):
+        # It's a local file, just copy it
+        import shutil
+        try:
+            shutil.copy2(path, local_path)
+            return True
+        except Exception as e:
+            print(f"Failed to copy local file {path}: {e}")
+            return False
+    elif path.startswith("s3://"):
+        # It's an S3 path, download it
+        result = subprocess.run(["aws", "s3", "cp", path, local_path], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
+        return result.returncode == 0
+    else:
+        # Assume it's a relative local path that doesn't exist yet
+        print(f"Path not found and doesn't appear to be S3: {path}")
+        return False
 
 
 class PreserveTablesConverter(MarkdownConverter):
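Note: with this change, download_s3_pdf keeps its original name but accepts either an s3:// URI or an existing local file. A minimal usage sketch (bucket and destination paths below are illustrative, not from the commit):

    # Both calls go through the same helper: a local file is copied with
    # shutil.copy2, an s3:// URI is fetched via the AWS CLI ("aws s3 cp").
    download_s3_pdf("s3://my-bucket/docs/report.pdf", "/tmp/work/document.pdf")
    download_s3_pdf("samples/report.pdf", "/tmp/work/document.pdf")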
@@ -1012,7 +1030,7 @@ def generate_tests_from_html(html_content: str, pdf_id: str, page_num: int, verb
 
 async def process_pdf(pdf_info, args, client, pdf_filter=None):
     """Process a single PDF, render a random page, and create an HTML template."""
-    s3_path, index = pdf_info
+    pdf_path, index = pdf_info
 
     # Create a unique folder for each PDF in the temp directory
     pdf_id = f"pdf_{index:05d}"
@@ -1022,15 +1040,15 @@ async def process_pdf(pdf_info, args, client, pdf_filter=None):
     # Determine if we should log table test verification
     verbose_table_testing = args.verbose
 
-    # Download PDF to local temp directory
+    # Download PDF to local temp directory (or copy if local)
     local_pdf_path = os.path.join(temp_pdf_dir, "document.pdf")
-    if not download_s3_pdf(s3_path, local_pdf_path):
-        print(f"Failed to download PDF from {s3_path}")
+    if not download_s3_pdf(pdf_path, local_pdf_path):
+        print(f"Failed to download/copy PDF from {pdf_path}")
         return None
 
     # Apply filter if enabled
     if pdf_filter and pdf_filter.filter_out_pdf(local_pdf_path):
-        print(f"PDF filtered out: {s3_path}")
+        print(f"PDF filtered out: {pdf_path}")
         return None
 
     try:
@@ -1039,7 +1057,7 @@ async def process_pdf(pdf_info, args, client, pdf_filter=None):
         num_pages = len(reader.pages)
 
         if num_pages == 0:
-            print(f"PDF has no pages: {s3_path}")
+            print(f"PDF has no pages: {pdf_path}")
             return None
 
         # Select a random page
@@ -1052,7 +1070,7 @@ async def process_pdf(pdf_info, args, client, pdf_filter=None):
         # Generate HTML from the image
         html_content = await generate_html_from_image(client, image_base64)
         if not html_content:
-            print(f"Failed to generate HTML for {s3_path}, page {page_num}")
+            print(f"Failed to generate HTML for {pdf_path}, page {page_num}")
             return None
 
         # Create output directories
@@ -1157,7 +1175,7 @@ async def process_pdf(pdf_info, args, client, pdf_filter=None):
 
         return {
             "pdf_id": pdf_id,
-            "s3_path": s3_path,
+            "pdf_path": pdf_path,
             "page_number": page_num,
             "html_path": html_path,
             "markdown_path": markdown_path,
@@ -1167,7 +1185,7 @@ async def process_pdf(pdf_info, args, client, pdf_filter=None):
             "num_tests": len(tests),
         }
     except Exception as e:
-        print(f"Error processing {s3_path}: {e}")
+        print(f"Error processing {pdf_path}: {e}")
         return None
     finally:
         # Clean up temp directory for this PDF
@@ -1181,7 +1199,7 @@ async def main():
     logging.getLogger("httpcore").setLevel(logging.WARNING)
 
     parser = argparse.ArgumentParser(description="Convert PDFs to HTML templates and render with Playwright")
-    parser.add_argument("--input_list", required=True, help="Path to a file containing S3 paths to PDFs")
+    parser.add_argument("--input_list", required=True, help="Path to a file containing S3 paths or local paths to PDFs")
    parser.add_argument("--output_dir", required=True, help="Directory to store extracted pages and tests")
    parser.add_argument("--temp_dir", default="/tmp/mine_tables", help="Directory for temporary files")
    parser.add_argument("--max_tests", type=int, default=100, help="Maximum number of tests to generate")
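Note: the file passed via --input_list may now mix both path styles, one per line. A plausible input file (paths illustrative):

    s3://my-bucket/pdfs/0001.pdf
    s3://my-bucket/pdfs/0002.pdf
    /data/local_pdfs/report.pdf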
@@ -1217,7 +1235,7 @@ async def main():
         print("PDF filtering enabled")
 
     # Reservoir sampling implementation
-    s3_paths = []
+    pdf_paths = []
     with open(args.input_list, "r") as f:
         for i, line in enumerate(tqdm(f)):
             line = line.strip()
@@ -1225,18 +1243,18 @@ async def main():
                 continue
 
             if i < 100000:
-                s3_paths.append(line)
+                pdf_paths.append(line)
             else:
                 # Randomly replace elements with decreasing probability
                 j = random.randint(0, i)
                 if j < 100000:
-                    s3_paths[j] = line
+                    pdf_paths[j] = line
 
-    print(f"Found {len(s3_paths)} PDF paths in input list")
+    print(f"Found {len(pdf_paths)} PDF paths in input list")
 
     # Shuffle and limit to max_tests
-    random.shuffle(s3_paths)
-    s3_paths = s3_paths[: args.max_tests]
+    random.shuffle(pdf_paths)
+    pdf_paths = pdf_paths[: args.max_tests]
 
     # Initialize the JSONL file in bench_data folder with the specified name
     bench_data_dir = os.path.join(args.output_dir, "bench_data")
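Note: the loop this hunk renames is standard reservoir sampling (Algorithm R): keep the first k items, then replace a random slot with decreasing probability, so every line of the stream survives with probability k/n. A self-contained sketch of the same algorithm (the k=100000 default mirrors the hard-coded constant above; the function name is illustrative):

    import random

    def reservoir_sample(lines, k=100000):
        """Uniformly sample up to k items from a stream in one pass."""
        reservoir = []
        for i, item in enumerate(lines):
            if i < k:
                reservoir.append(item)   # fill phase: keep the first k items
            else:
                j = random.randint(0, i) # pick a slot in [0, i]
                if j < k:
                    reservoir[j] = item  # item survives with probability k/(i+1)
        return reservoir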
@@ -1246,7 +1264,7 @@ async def main():
 
     # Counter for test statistics
     test_counter = 0
-    test_types = {"present": 0, "absent": 0, "table": 0, "order": 0}
+    test_types = defaultdict(int)  # Automatically handles any test type
     results = []
 
     # Initialize an asyncio lock for file access
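Note: switching from a fixed dict to collections.defaultdict(int) means previously unseen test types get counted instead of silently dropped. A minimal illustration (the "baseline" type is hypothetical; "table" and "present" appear in the old code):

    from collections import defaultdict

    counts = defaultdict(int)
    for t in ["table", "present", "baseline"]:  # "baseline" was never pre-declared
        counts[t] += 1                          # missing keys default to 0
    print(dict(counts))                         # {'table': 1, 'present': 1, 'baseline': 1}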
@@ -1254,7 +1272,7 @@ async def main():
 
     # Process PDFs in parallel using asyncio
     async def process_with_progress(pdf_info):
-        s3_path = pdf_info[0]
+        pdf_path = pdf_info[0]
         try:
             result = await process_pdf(pdf_info, args, client, pdf_filter)
             if result and result.get("tests"):
@@ -1269,21 +1287,20 @@ async def main():
                 nonlocal test_counter
                 test_counter += len(result["tests"])
                 for test in result["tests"]:
-                    test_type = test.get("type", "")
-                    if test_type in test_types:
-                        test_types[test_type] += 1
+                    test_type = test.get("type", "unknown")
+                    test_types[test_type] += 1
 
                 print(f"Added {len(result['tests'])} tests from {result['pdf_id']}, total: {test_counter}")
 
             return result
         except Exception as e:
-            print(f"Error processing {s3_path}: {e}")
+            print(f"Error processing {pdf_path}: {e}")
             return None
 
     # Create tasks for all PDFs
     tasks = []
-    for i, s3_path in enumerate(s3_paths):
-        tasks.append(process_with_progress((s3_path, i)))
+    for i, pdf_path in enumerate(pdf_paths):
+        tasks.append(process_with_progress((pdf_path, i)))
 
     # Run tasks with limited concurrency
     semaphore = asyncio.Semaphore(args.parallel)
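Note: the diff ends just as the concurrency cap is created. asyncio.Semaphore is the usual way to bound fan-out when gathering many coroutines; a minimal sketch of the pattern, assuming each task acquires the semaphore around its work (wrapper and function names are illustrative, not from the commit):

    import asyncio

    async def bounded(semaphore, coro):
        # At most `parallel` coroutine bodies run concurrently;
        # the rest wait here until a slot frees up.
        async with semaphore:
            return await coro

    async def run_all(coros, parallel=8):
        semaphore = asyncio.Semaphore(parallel)
        return await asyncio.gather(*(bounded(semaphore, c) for c in coros))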