Local path supported

commit f3ea1527ef
parent 5cbe331259
@@ -7,6 +7,7 @@ import random
 import re
 import subprocess
 import uuid
+from collections import defaultdict
 from typing import Dict, List
 
 import pypdf
@@ -25,11 +26,28 @@ from olmocr.data.renderpdf import (
 from olmocr.filter.filter import PdfFilter, Language
 
 
-def download_s3_pdf(s3_path, local_path):
-    """Download a PDF from S3 to a local path."""
+def download_s3_pdf(path, local_path):
+    """Download a PDF from S3 or copy from local path."""
     os.makedirs(os.path.dirname(local_path), exist_ok=True)
-    result = subprocess.run(["aws", "s3", "cp", s3_path, local_path], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
-    return result.returncode == 0
+    # Check if it's a local path
+    if os.path.exists(path):
+        # It's a local file, just copy it
+        import shutil
+
+        try:
+            shutil.copy2(path, local_path)
+            return True
+        except Exception as e:
+            print(f"Failed to copy local file {path}: {e}")
+            return False
+    elif path.startswith("s3://"):
+        # It's an S3 path, download it
+        result = subprocess.run(["aws", "s3", "cp", path, local_path], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
+        return result.returncode == 0
+    else:
+        # Assume it's a relative local path that doesn't exist yet
+        print(f"Path not found and doesn't appear to be S3: {path}")
+        return False
 
 
 class PreserveTablesConverter(MarkdownConverter):
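
A quick way to sanity-check the new local-file branch, assuming download_s3_pdf is importable or pasted into a REPL (the paths and stub bytes below are made up for illustration):

    import os
    import tempfile

    # Stub "PDF" on disk: enough to exercise the shutil.copy2 branch
    # without touching AWS at all.
    src = os.path.join(tempfile.mkdtemp(), "sample.pdf")
    dst = os.path.join(tempfile.mkdtemp(), "document.pdf")
    with open(src, "wb") as f:
        f.write(b"%PDF-1.4 stub")

    assert download_s3_pdf(src, dst)                      # local file -> copy branch
    assert open(dst, "rb").read().startswith(b"%PDF")
    assert not download_s3_pdf("no/such/file.pdf", dst)   # neither local nor s3:// -> False

Note the ordering: os.path.exists(path) is tested before the s3:// prefix, so an existing local file always wins, and anything that is neither an existing file nor an s3:// URI fails fast with a message.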
@@ -1012,7 +1030,7 @@ def generate_tests_from_html(html_content: str, pdf_id: str, page_num: int, verb
 
 async def process_pdf(pdf_info, args, client, pdf_filter=None):
     """Process a single PDF, render a random page, and create an HTML template."""
-    s3_path, index = pdf_info
+    pdf_path, index = pdf_info
 
     # Create a unique folder for each PDF in the temp directory
     pdf_id = f"pdf_{index:05d}"
@@ -1022,15 +1040,15 @@ async def process_pdf(pdf_info, args, client, pdf_filter=None):
     # Determine if we should log table test verification
     verbose_table_testing = args.verbose
 
-    # Download PDF to local temp directory
+    # Download PDF to local temp directory (or copy if local)
     local_pdf_path = os.path.join(temp_pdf_dir, "document.pdf")
-    if not download_s3_pdf(s3_path, local_pdf_path):
-        print(f"Failed to download PDF from {s3_path}")
+    if not download_s3_pdf(pdf_path, local_pdf_path):
+        print(f"Failed to download/copy PDF from {pdf_path}")
         return None
 
     # Apply filter if enabled
     if pdf_filter and pdf_filter.filter_out_pdf(local_pdf_path):
-        print(f"PDF filtered out: {s3_path}")
+        print(f"PDF filtered out: {pdf_path}")
         return None
 
     try:
@@ -1039,7 +1057,7 @@ async def process_pdf(pdf_info, args, client, pdf_filter=None):
         num_pages = len(reader.pages)
 
         if num_pages == 0:
-            print(f"PDF has no pages: {s3_path}")
+            print(f"PDF has no pages: {pdf_path}")
             return None
 
         # Select a random page
@@ -1052,7 +1070,7 @@ async def process_pdf(pdf_info, args, client, pdf_filter=None):
         # Generate HTML from the image
         html_content = await generate_html_from_image(client, image_base64)
         if not html_content:
-            print(f"Failed to generate HTML for {s3_path}, page {page_num}")
+            print(f"Failed to generate HTML for {pdf_path}, page {page_num}")
             return None
 
         # Create output directories
@@ -1157,7 +1175,7 @@ async def process_pdf(pdf_info, args, client, pdf_filter=None):
 
         return {
             "pdf_id": pdf_id,
-            "s3_path": s3_path,
+            "pdf_path": pdf_path,
             "page_number": page_num,
             "html_path": html_path,
             "markdown_path": markdown_path,
@@ -1167,7 +1185,7 @@ async def process_pdf(pdf_info, args, client, pdf_filter=None):
             "num_tests": len(tests),
         }
     except Exception as e:
-        print(f"Error processing {s3_path}: {e}")
+        print(f"Error processing {pdf_path}: {e}")
         return None
     finally:
         # Clean up temp directory for this PDF
@@ -1181,7 +1199,7 @@ async def main():
    logging.getLogger("httpcore").setLevel(logging.WARNING)
 
    parser = argparse.ArgumentParser(description="Convert PDFs to HTML templates and render with Playwright")
-    parser.add_argument("--input_list", required=True, help="Path to a file containing S3 paths to PDFs")
+    parser.add_argument("--input_list", required=True, help="Path to a file containing S3 paths or local paths to PDFs")
    parser.add_argument("--output_dir", required=True, help="Directory to store extracted pages and tests")
    parser.add_argument("--temp_dir", default="/tmp/mine_tables", help="Directory for temporary files")
    parser.add_argument("--max_tests", type=int, default=100, help="Maximum number of tests to generate")
@@ -1217,7 +1235,7 @@ async def main():
        print("PDF filtering enabled")
 
    # Reservoir sampling implementation
-    s3_paths = []
+    pdf_paths = []
    with open(args.input_list, "r") as f:
        for i, line in enumerate(tqdm(f)):
            line = line.strip()
@@ -1225,18 +1243,18 @@ async def main():
                continue
 
            if i < 100000:
-                s3_paths.append(line)
+                pdf_paths.append(line)
            else:
                # Randomly replace elements with decreasing probability
                j = random.randint(0, i)
                if j < 100000:
-                    s3_paths[j] = line
+                    pdf_paths[j] = line
 
-    print(f"Found {len(s3_paths)} PDF paths in input list")
+    print(f"Found {len(pdf_paths)} PDF paths in input list")
 
    # Shuffle and limit to max_tests
-    random.shuffle(s3_paths)
-    s3_paths = s3_paths[: args.max_tests]
+    random.shuffle(pdf_paths)
+    pdf_paths = pdf_paths[: args.max_tests]
 
    # Initialize the JSONL file in bench_data folder with the specified name
    bench_data_dir = os.path.join(args.output_dir, "bench_data")
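
The loop above is standard reservoir sampling with a reservoir size of 100000: once the reservoir is full, line i replaces a random slot with probability 100000/(i+1), which leaves every line in the final sample with equal probability. A generic, self-contained sketch of the same idea (the function name and the sample size in the demo are illustrative):

    import random

    def reservoir_sample(iterable, k):
        # Uniform random sample of k items from a stream of unknown length.
        sample = []
        for i, item in enumerate(iterable):
            if i < k:
                sample.append(item)
            else:
                # random.randint is inclusive on both ends, as in the script
                j = random.randint(0, i)
                if j < k:
                    sample[j] = item
        return sample

    print(reservoir_sample(range(1000), 3))  # e.g. [412, 1, 877]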
@@ -1246,7 +1264,7 @@ async def main():
 
    # Counter for test statistics
    test_counter = 0
-    test_types = {"present": 0, "absent": 0, "table": 0, "order": 0}
+    test_types = defaultdict(int)  # Automatically handles any test type
    results = []
 
    # Initialize an asyncio lock for file access
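
Switching to defaultdict(int) means test types no longer need to be pre-registered: a missing key is created as 0 on first access, so new types are counted without editing this dict. A minimal illustration (the "math" type is a made-up example):

    from collections import defaultdict

    counts = defaultdict(int)
    for t in ["present", "table", "math", "table"]:
        counts[t] += 1  # "math" was never pre-registered; it starts at 0
    print(dict(counts))  # {'present': 1, 'table': 2, 'math': 1}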
@@ -1254,7 +1272,7 @@ async def main():
 
    # Process PDFs in parallel using asyncio
    async def process_with_progress(pdf_info):
-        s3_path = pdf_info[0]
+        pdf_path = pdf_info[0]
        try:
            result = await process_pdf(pdf_info, args, client, pdf_filter)
            if result and result.get("tests"):
@@ -1269,21 +1287,20 @@ async def main():
                nonlocal test_counter
                test_counter += len(result["tests"])
                for test in result["tests"]:
-                    test_type = test.get("type", "")
-                    if test_type in test_types:
-                        test_types[test_type] += 1
+                    test_type = test.get("type", "unknown")
+                    test_types[test_type] += 1
 
                print(f"Added {len(result['tests'])} tests from {result['pdf_id']}, total: {test_counter}")
 
            return result
        except Exception as e:
-            print(f"Error processing {s3_path}: {e}")
+            print(f"Error processing {pdf_path}: {e}")
            return None
 
    # Create tasks for all PDFs
    tasks = []
-    for i, s3_path in enumerate(s3_paths):
-        tasks.append(process_with_progress((s3_path, i)))
+    for i, pdf_path in enumerate(pdf_paths):
+        tasks.append(process_with_progress((pdf_path, i)))
 
    # Run tasks with limited concurrency
    semaphore = asyncio.Semaphore(args.parallel)
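
The semaphore visible as context above is the usual asyncio pattern for bounding fan-out; a self-contained sketch of that shape (the worker body and the limit of 4 are placeholders for process_pdf and args.parallel):

    import asyncio

    async def worker(i: int, semaphore: asyncio.Semaphore) -> int:
        async with semaphore:         # at most N coroutines run this block at once
            await asyncio.sleep(0.1)  # stand-in for the real per-PDF work
            return i

    async def main() -> None:
        semaphore = asyncio.Semaphore(4)
        results = await asyncio.gather(*(worker(i, semaphore) for i in range(10)))
        print(results)  # [0, 1, ..., 9]

    asyncio.run(main())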