diff --git a/olmocr/bench/synth/mine_html_templates.py b/olmocr/bench/synth/mine_html_templates.py
index f7a0f7c..ca3826a 100644
--- a/olmocr/bench/synth/mine_html_templates.py
+++ b/olmocr/bench/synth/mine_html_templates.py
@@ -1,6 +1,5 @@
 import argparse
 import asyncio
-import concurrent.futures
 import json
 import logging
 import os
@@ -8,11 +7,10 @@ import random
 import re
 import subprocess
 import uuid
-from concurrent.futures import ThreadPoolExecutor
 from typing import Dict, List
 
 import pypdf
-from anthropic import Anthropic
+from anthropic import AsyncAnthropic
 from bs4 import BeautifulSoup
 from markdownify import MarkdownConverter, SPACES
 from playwright.async_api import async_playwright
@@ -226,13 +224,13 @@ def extract_code_block(initial_response):
     return None
 
 
-def generate_html_from_image(client, image_base64):
+async def generate_html_from_image(client, image_base64):
     """Call Claude API to generate HTML from an image using a multi-step prompting strategy."""
     png_width, png_height = get_png_dimensions_from_base64(image_base64)
 
     try:
         # Step 1: Initial analysis and column detection
-        analysis_response = client.messages.create(
+        analysis_response = await client.messages.create(
             model="claude-sonnet-4-20250514",
             max_tokens=2000,
             temperature=0.1,
@@ -261,7 +259,7 @@ def generate_html_from_image(client, image_base64):
                 analysis_text += content.text
 
         # Step 2: Initial HTML generation with detailed layout instructions
-        initial_response = client.messages.create(
+        initial_response = await client.messages.create(
             model="claude-sonnet-4-20250514",
             max_tokens=6000,
             temperature=0.2,
@@ -1012,7 +1010,7 @@ def generate_tests_from_html(html_content: str, pdf_id: str, page_num: int, verb
     return validated_tests
 
 
-def process_pdf(pdf_info, args, client, pdf_filter=None):
+async def process_pdf(pdf_info, args, client, pdf_filter=None):
     """Process a single PDF, render a random page, and create an HTML template."""
     s3_path, index = pdf_info
 
@@ -1047,11 +1045,12 @@ def process_pdf(pdf_info, args, client, pdf_filter=None):
         # Select a random page
         page_num = random.randint(1, num_pages)
 
-        # Render the page as a base64 PNG
-        image_base64 = render_pdf_to_base64png(local_pdf_path, page_num, target_longest_image_dim=2048)
+        # Render the page as a base64 PNG (run in thread pool since it's blocking I/O)
+        loop = asyncio.get_event_loop()
+        image_base64 = await loop.run_in_executor(None, render_pdf_to_base64png, local_pdf_path, page_num, 2048)
 
         # Generate HTML from the image
-        html_content = generate_html_from_image(client, image_base64)
+        html_content = await generate_html_from_image(client, image_base64)
         if not html_content:
             print(f"Failed to generate HTML for {s3_path}, page {page_num}")
             return None
@@ -1117,14 +1116,8 @@ def process_pdf(pdf_info, args, client, pdf_filter=None):
             # Get PNG dimensions
             png_width, png_height = get_png_dimensions_from_base64(image_base64)
 
-            # Run the async function in the synchronous context
-            # Create a new event loop to avoid conflicts
-            loop = asyncio.new_event_loop()
-            asyncio.set_event_loop(loop)
-            try:
-                render_success = loop.run_until_complete(render_pdf_with_playwright(html_content, playwright_pdf_path, png_width, png_height))
-            finally:
-                loop.close()
+            # Run the async function directly since we're already in an async context
+            render_success = await render_pdf_with_playwright(html_content, playwright_pdf_path, png_width, png_height)
 
             if render_success:
                 print(f"Successfully rendered with Playwright: {playwright_pdf_path}")
@@ -1182,7 +1175,7 @@ def process_pdf(pdf_info, args, client, pdf_filter=None):
         subprocess.run(["rm", "-rf", temp_pdf_dir])
 
 
-def main():
+async def main():
     # Configure logging to suppress httpx messages
     logging.getLogger("httpx").setLevel(logging.WARNING)
     logging.getLogger("httpcore").setLevel(logging.WARNING)
@@ -1192,7 +1185,7 @@ def main():
     parser.add_argument("--output_dir", required=True, help="Directory to store extracted pages and tests")
    parser.add_argument("--temp_dir", default="/tmp/mine_tables", help="Directory for temporary files")
     parser.add_argument("--max_tests", type=int, default=100, help="Maximum number of tests to generate")
-    parser.add_argument("--parallel", type=int, default=1, help="Number of parallel threads to use")
+    parser.add_argument("--parallel", type=int, default=1, help="Number of parallel tasks to use")
     parser.add_argument("--api_key", help="Claude API key (or set ANTHROPIC_API_KEY environment variable)")
     parser.add_argument("--skip_playwright", action="store_true", help="Skip Playwright PDF rendering")
     parser.add_argument("--verbose", action="store_true", help="Enable verbose output including table test verification")
@@ -1210,8 +1203,8 @@ def main():
         print("Error: API key not provided. Use --api_key or set ANTHROPIC_API_KEY environment variable.")
         return
 
-    # Initialize Claude client
-    client = Anthropic(api_key=api_key)
+    # Initialize async Claude client
+    client = AsyncAnthropic(api_key=api_key)
 
     # Initialize PDF filter if enabled
     pdf_filter = None
@@ -1256,41 +1249,56 @@ def main():
     test_types = {"present": 0, "absent": 0, "table": 0, "order": 0}
     results = []
 
-    # Initialize a threading lock for file access
-    import threading
+    # Initialize an asyncio lock for file access
+    file_lock = asyncio.Lock()
 
-    file_lock = threading.Lock()
-
-    # Process PDFs in parallel
-    with ThreadPoolExecutor(max_workers=args.parallel) as executor:
-        # Submit all tasks
-        futures = {executor.submit(process_pdf, (s3_path, i), args, client, pdf_filter): s3_path for i, s3_path in enumerate(s3_paths)}
-
-        # Process results as they complete
-        for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="Processing PDFs"):
-            s3_path = futures[future]
-            try:
-                result = future.result()
-                if result and result.get("tests"):
-                    results.append(result)
-
-                    # Append tests to synthetic.json as they're created (JSONL format)
-                    with file_lock:
-                        # Append each test as a separate JSON line
-                        with open(synthetic_json_path, "a") as f:
-                            for test in result["tests"]:
-                                f.write(json.dumps(test) + "\n")
-
-                        # Update counters
-                        test_counter += len(result["tests"])
+    # Process PDFs in parallel using asyncio
+    async def process_with_progress(pdf_info):
+        s3_path = pdf_info[0]
+        try:
+            result = await process_pdf(pdf_info, args, client, pdf_filter)
+            if result and result.get("tests"):
+                # Append tests to synthetic.json as they're created (JSONL format)
+                async with file_lock:
+                    # Append each test as a separate JSON line
+                    with open(synthetic_json_path, "a") as f:
                         for test in result["tests"]:
-                            test_type = test.get("type", "")
-                            if test_type in test_types:
-                                test_types[test_type] += 1
+                            f.write(json.dumps(test) + "\n")
+
+                    # Update counters
+                    nonlocal test_counter
+                    test_counter += len(result["tests"])
+                    for test in result["tests"]:
+                        test_type = test.get("type", "")
+                        if test_type in test_types:
+                            test_types[test_type] += 1
+
+                    print(f"Added {len(result['tests'])} tests from {result['pdf_id']}, total: {test_counter}")
+
+            return result
+        except Exception as e:
+            print(f"Error processing {s3_path}: {e}")
+            return None
 
-                        print(f"Added {len(result['tests'])} tests from {result['pdf_id']}, total: {test_counter}")
-            except Exception as e:
-                print(f"Error processing {s3_path}: {e}")
+    # Create tasks for all PDFs
+    tasks = []
+    for i, s3_path in enumerate(s3_paths):
+        tasks.append(process_with_progress((s3_path, i)))
+
+    # Run tasks with limited concurrency
+    semaphore = asyncio.Semaphore(args.parallel)
+
+    async def bounded_task(task_coro):
+        async with semaphore:
+            return await task_coro
+
+    bounded_tasks = [bounded_task(task) for task in tasks]
+
+    # Process all tasks with progress bar
+    for coro in tqdm(asyncio.as_completed(bounded_tasks), total=len(bounded_tasks), desc="Processing PDFs"):
+        result = await coro
+        if result:
+            results.append(result)
 
     print(f"Generated {len(results)} HTML templates")
 
@@ -1312,4 +1320,4 @@ def main():
 
 
 if __name__ == "__main__":
-    main()
+    asyncio.run(main())
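
Note on the concurrency pattern introduced above (not part of the patch): the rewritten main() bounds the number of in-flight PDFs with an asyncio.Semaphore sized by --parallel and consumes results in completion order via asyncio.as_completed. A minimal, self-contained sketch of that pattern follows, with a hypothetical process() coroutine standing in for process_pdf:

import asyncio
import random


async def process(item: int) -> int:
    # Hypothetical stand-in for process_pdf: simulate some async work.
    await asyncio.sleep(random.random() * 0.1)
    return item * 2


async def main() -> None:
    semaphore = asyncio.Semaphore(4)  # analogous to --parallel

    async def bounded(coro):
        # At most 4 coroutines proceed past this point at a time.
        async with semaphore:
            return await coro

    tasks = [bounded(process(i)) for i in range(20)]

    # Results arrive in completion order, not submission order.
    for finished in asyncio.as_completed(tasks):
        print(await finished)


if __name__ == "__main__":
    asyncio.run(main())

One property of this approach, shared by the patch, is that every coroutine object is created up front; the semaphore limits how many run concurrently, not how many exist.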