mirror of
				https://github.com/allenai/olmocr.git
				synced 2025-11-03 19:45:41 +00:00 
			
		
		
		
	Async mode
This commit is contained in:
		
							parent
							
								
									0df56e958e
								
							
						
					
					
						commit
						5cbe331259
					
				@ -1,6 +1,5 @@
 | 
				
			|||||||
import argparse
 | 
					import argparse
 | 
				
			||||||
import asyncio
 | 
					import asyncio
 | 
				
			||||||
import concurrent.futures
 | 
					 | 
				
			||||||
import json
 | 
					import json
 | 
				
			||||||
import logging
 | 
					import logging
 | 
				
			||||||
import os
 | 
					import os
 | 
				
			||||||
@ -8,11 +7,10 @@ import random
 | 
				
			|||||||
import re
 | 
					import re
 | 
				
			||||||
import subprocess
 | 
					import subprocess
 | 
				
			||||||
import uuid
 | 
					import uuid
 | 
				
			||||||
from concurrent.futures import ThreadPoolExecutor
 | 
					 | 
				
			||||||
from typing import Dict, List
 | 
					from typing import Dict, List
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import pypdf
 | 
					import pypdf
 | 
				
			||||||
from anthropic import Anthropic
 | 
					from anthropic import AsyncAnthropic
 | 
				
			||||||
from bs4 import BeautifulSoup
 | 
					from bs4 import BeautifulSoup
 | 
				
			||||||
from markdownify import MarkdownConverter, SPACES
 | 
					from markdownify import MarkdownConverter, SPACES
 | 
				
			||||||
from playwright.async_api import async_playwright
 | 
					from playwright.async_api import async_playwright
 | 
				
			||||||
@ -226,13 +224,13 @@ def extract_code_block(initial_response):
 | 
				
			|||||||
    return None
 | 
					    return None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def generate_html_from_image(client, image_base64):
 | 
					async def generate_html_from_image(client, image_base64):
 | 
				
			||||||
    """Call Claude API to generate HTML from an image using a multi-step prompting strategy."""
 | 
					    """Call Claude API to generate HTML from an image using a multi-step prompting strategy."""
 | 
				
			||||||
    png_width, png_height = get_png_dimensions_from_base64(image_base64)
 | 
					    png_width, png_height = get_png_dimensions_from_base64(image_base64)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    try:
 | 
					    try:
 | 
				
			||||||
        # Step 1: Initial analysis and column detection
 | 
					        # Step 1: Initial analysis and column detection
 | 
				
			||||||
        analysis_response = client.messages.create(
 | 
					        analysis_response = await client.messages.create(
 | 
				
			||||||
            model="claude-sonnet-4-20250514",
 | 
					            model="claude-sonnet-4-20250514",
 | 
				
			||||||
            max_tokens=2000,
 | 
					            max_tokens=2000,
 | 
				
			||||||
            temperature=0.1,
 | 
					            temperature=0.1,
 | 
				
			||||||
@ -261,7 +259,7 @@ def generate_html_from_image(client, image_base64):
 | 
				
			|||||||
                analysis_text += content.text
 | 
					                analysis_text += content.text
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Step 2: Initial HTML generation with detailed layout instructions
 | 
					        # Step 2: Initial HTML generation with detailed layout instructions
 | 
				
			||||||
        initial_response = client.messages.create(
 | 
					        initial_response = await client.messages.create(
 | 
				
			||||||
            model="claude-sonnet-4-20250514",
 | 
					            model="claude-sonnet-4-20250514",
 | 
				
			||||||
            max_tokens=6000,
 | 
					            max_tokens=6000,
 | 
				
			||||||
            temperature=0.2,
 | 
					            temperature=0.2,
 | 
				
			||||||
@ -1012,7 +1010,7 @@ def generate_tests_from_html(html_content: str, pdf_id: str, page_num: int, verb
 | 
				
			|||||||
    return validated_tests
 | 
					    return validated_tests
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def process_pdf(pdf_info, args, client, pdf_filter=None):
 | 
					async def process_pdf(pdf_info, args, client, pdf_filter=None):
 | 
				
			||||||
    """Process a single PDF, render a random page, and create an HTML template."""
 | 
					    """Process a single PDF, render a random page, and create an HTML template."""
 | 
				
			||||||
    s3_path, index = pdf_info
 | 
					    s3_path, index = pdf_info
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -1047,11 +1045,12 @@ def process_pdf(pdf_info, args, client, pdf_filter=None):
 | 
				
			|||||||
        # Select a random page
 | 
					        # Select a random page
 | 
				
			||||||
        page_num = random.randint(1, num_pages)
 | 
					        page_num = random.randint(1, num_pages)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Render the page as a base64 PNG
 | 
					        # Render the page as a base64 PNG (run in thread pool since it's blocking I/O)
 | 
				
			||||||
        image_base64 = render_pdf_to_base64png(local_pdf_path, page_num, target_longest_image_dim=2048)
 | 
					        loop = asyncio.get_event_loop()
 | 
				
			||||||
 | 
					        image_base64 = await loop.run_in_executor(None, render_pdf_to_base64png, local_pdf_path, page_num, 2048)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Generate HTML from the image
 | 
					        # Generate HTML from the image
 | 
				
			||||||
        html_content = generate_html_from_image(client, image_base64)
 | 
					        html_content = await generate_html_from_image(client, image_base64)
 | 
				
			||||||
        if not html_content:
 | 
					        if not html_content:
 | 
				
			||||||
            print(f"Failed to generate HTML for {s3_path}, page {page_num}")
 | 
					            print(f"Failed to generate HTML for {s3_path}, page {page_num}")
 | 
				
			||||||
            return None
 | 
					            return None
 | 
				
			||||||
@ -1117,14 +1116,8 @@ def process_pdf(pdf_info, args, client, pdf_filter=None):
 | 
				
			|||||||
                # Get PNG dimensions
 | 
					                # Get PNG dimensions
 | 
				
			||||||
                png_width, png_height = get_png_dimensions_from_base64(image_base64)
 | 
					                png_width, png_height = get_png_dimensions_from_base64(image_base64)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                # Run the async function in the synchronous context
 | 
					                # Run the async function directly since we're already in an async context
 | 
				
			||||||
                # Create a new event loop to avoid conflicts
 | 
					                render_success = await render_pdf_with_playwright(html_content, playwright_pdf_path, png_width, png_height)
 | 
				
			||||||
                loop = asyncio.new_event_loop()
 | 
					 | 
				
			||||||
                asyncio.set_event_loop(loop)
 | 
					 | 
				
			||||||
                try:
 | 
					 | 
				
			||||||
                    render_success = loop.run_until_complete(render_pdf_with_playwright(html_content, playwright_pdf_path, png_width, png_height))
 | 
					 | 
				
			||||||
                finally:
 | 
					 | 
				
			||||||
                    loop.close()
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
                if render_success:
 | 
					                if render_success:
 | 
				
			||||||
                    print(f"Successfully rendered with Playwright: {playwright_pdf_path}")
 | 
					                    print(f"Successfully rendered with Playwright: {playwright_pdf_path}")
 | 
				
			||||||
@ -1182,7 +1175,7 @@ def process_pdf(pdf_info, args, client, pdf_filter=None):
 | 
				
			|||||||
            subprocess.run(["rm", "-rf", temp_pdf_dir])
 | 
					            subprocess.run(["rm", "-rf", temp_pdf_dir])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def main():
 | 
					async def main():
 | 
				
			||||||
    # Configure logging to suppress httpx messages
 | 
					    # Configure logging to suppress httpx messages
 | 
				
			||||||
    logging.getLogger("httpx").setLevel(logging.WARNING)
 | 
					    logging.getLogger("httpx").setLevel(logging.WARNING)
 | 
				
			||||||
    logging.getLogger("httpcore").setLevel(logging.WARNING)
 | 
					    logging.getLogger("httpcore").setLevel(logging.WARNING)
 | 
				
			||||||
@ -1192,7 +1185,7 @@ def main():
 | 
				
			|||||||
    parser.add_argument("--output_dir", required=True, help="Directory to store extracted pages and tests")
 | 
					    parser.add_argument("--output_dir", required=True, help="Directory to store extracted pages and tests")
 | 
				
			||||||
    parser.add_argument("--temp_dir", default="/tmp/mine_tables", help="Directory for temporary files")
 | 
					    parser.add_argument("--temp_dir", default="/tmp/mine_tables", help="Directory for temporary files")
 | 
				
			||||||
    parser.add_argument("--max_tests", type=int, default=100, help="Maximum number of tests to generate")
 | 
					    parser.add_argument("--max_tests", type=int, default=100, help="Maximum number of tests to generate")
 | 
				
			||||||
    parser.add_argument("--parallel", type=int, default=1, help="Number of parallel threads to use")
 | 
					    parser.add_argument("--parallel", type=int, default=1, help="Number of parallel tasks to use")
 | 
				
			||||||
    parser.add_argument("--api_key", help="Claude API key (or set ANTHROPIC_API_KEY environment variable)")
 | 
					    parser.add_argument("--api_key", help="Claude API key (or set ANTHROPIC_API_KEY environment variable)")
 | 
				
			||||||
    parser.add_argument("--skip_playwright", action="store_true", help="Skip Playwright PDF rendering")
 | 
					    parser.add_argument("--skip_playwright", action="store_true", help="Skip Playwright PDF rendering")
 | 
				
			||||||
    parser.add_argument("--verbose", action="store_true", help="Enable verbose output including table test verification")
 | 
					    parser.add_argument("--verbose", action="store_true", help="Enable verbose output including table test verification")
 | 
				
			||||||
@ -1210,8 +1203,8 @@ def main():
 | 
				
			|||||||
        print("Error: API key not provided. Use --api_key or set ANTHROPIC_API_KEY environment variable.")
 | 
					        print("Error: API key not provided. Use --api_key or set ANTHROPIC_API_KEY environment variable.")
 | 
				
			||||||
        return
 | 
					        return
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Initialize Claude client
 | 
					    # Initialize async Claude client
 | 
				
			||||||
    client = Anthropic(api_key=api_key)
 | 
					    client = AsyncAnthropic(api_key=api_key)
 | 
				
			||||||
    
 | 
					    
 | 
				
			||||||
    # Initialize PDF filter if enabled
 | 
					    # Initialize PDF filter if enabled
 | 
				
			||||||
    pdf_filter = None
 | 
					    pdf_filter = None
 | 
				
			||||||
@ -1256,32 +1249,24 @@ def main():
 | 
				
			|||||||
    test_types = {"present": 0, "absent": 0, "table": 0, "order": 0}
 | 
					    test_types = {"present": 0, "absent": 0, "table": 0, "order": 0}
 | 
				
			||||||
    results = []
 | 
					    results = []
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Initialize a threading lock for file access
 | 
					    # Initialize an asyncio lock for file access
 | 
				
			||||||
    import threading
 | 
					    file_lock = asyncio.Lock()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    file_lock = threading.Lock()
 | 
					    # Process PDFs in parallel using asyncio
 | 
				
			||||||
 | 
					    async def process_with_progress(pdf_info):
 | 
				
			||||||
    # Process PDFs in parallel
 | 
					        s3_path = pdf_info[0]
 | 
				
			||||||
    with ThreadPoolExecutor(max_workers=args.parallel) as executor:
 | 
					 | 
				
			||||||
        # Submit all tasks
 | 
					 | 
				
			||||||
        futures = {executor.submit(process_pdf, (s3_path, i), args, client, pdf_filter): s3_path for i, s3_path in enumerate(s3_paths)}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        # Process results as they complete
 | 
					 | 
				
			||||||
        for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="Processing PDFs"):
 | 
					 | 
				
			||||||
            s3_path = futures[future]
 | 
					 | 
				
			||||||
        try:
 | 
					        try:
 | 
				
			||||||
                result = future.result()
 | 
					            result = await process_pdf(pdf_info, args, client, pdf_filter)
 | 
				
			||||||
            if result and result.get("tests"):
 | 
					            if result and result.get("tests"):
 | 
				
			||||||
                    results.append(result)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
                # Append tests to synthetic.json as they're created (JSONL format)
 | 
					                # Append tests to synthetic.json as they're created (JSONL format)
 | 
				
			||||||
                    with file_lock:
 | 
					                async with file_lock:
 | 
				
			||||||
                    # Append each test as a separate JSON line
 | 
					                    # Append each test as a separate JSON line
 | 
				
			||||||
                    with open(synthetic_json_path, "a") as f:
 | 
					                    with open(synthetic_json_path, "a") as f:
 | 
				
			||||||
                        for test in result["tests"]:
 | 
					                        for test in result["tests"]:
 | 
				
			||||||
                            f.write(json.dumps(test) + "\n")
 | 
					                            f.write(json.dumps(test) + "\n")
 | 
				
			||||||
                    
 | 
					                    
 | 
				
			||||||
                    # Update counters
 | 
					                    # Update counters
 | 
				
			||||||
 | 
					                    nonlocal test_counter
 | 
				
			||||||
                    test_counter += len(result["tests"])
 | 
					                    test_counter += len(result["tests"])
 | 
				
			||||||
                    for test in result["tests"]:
 | 
					                    for test in result["tests"]:
 | 
				
			||||||
                        test_type = test.get("type", "")
 | 
					                        test_type = test.get("type", "")
 | 
				
			||||||
@ -1289,8 +1274,31 @@ def main():
 | 
				
			|||||||
                            test_types[test_type] += 1
 | 
					                            test_types[test_type] += 1
 | 
				
			||||||
                    
 | 
					                    
 | 
				
			||||||
                    print(f"Added {len(result['tests'])} tests from {result['pdf_id']}, total: {test_counter}")
 | 
					                    print(f"Added {len(result['tests'])} tests from {result['pdf_id']}, total: {test_counter}")
 | 
				
			||||||
 | 
					                
 | 
				
			||||||
 | 
					                return result
 | 
				
			||||||
        except Exception as e:
 | 
					        except Exception as e:
 | 
				
			||||||
            print(f"Error processing {s3_path}: {e}")
 | 
					            print(f"Error processing {s3_path}: {e}")
 | 
				
			||||||
 | 
					            return None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    # Create tasks for all PDFs
 | 
				
			||||||
 | 
					    tasks = []
 | 
				
			||||||
 | 
					    for i, s3_path in enumerate(s3_paths):
 | 
				
			||||||
 | 
					        tasks.append(process_with_progress((s3_path, i)))
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					    # Run tasks with limited concurrency
 | 
				
			||||||
 | 
					    semaphore = asyncio.Semaphore(args.parallel)
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					    async def bounded_task(task_coro):
 | 
				
			||||||
 | 
					        async with semaphore:
 | 
				
			||||||
 | 
					            return await task_coro
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					    bounded_tasks = [bounded_task(task) for task in tasks]
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					    # Process all tasks with progress bar
 | 
				
			||||||
 | 
					    for coro in tqdm(asyncio.as_completed(bounded_tasks), total=len(bounded_tasks), desc="Processing PDFs"):
 | 
				
			||||||
 | 
					        result = await coro
 | 
				
			||||||
 | 
					        if result:
 | 
				
			||||||
 | 
					            results.append(result)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    print(f"Generated {len(results)} HTML templates")
 | 
					    print(f"Generated {len(results)} HTML templates")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -1312,4 +1320,4 @@ def main():
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
if __name__ == "__main__":
 | 
					if __name__ == "__main__":
 | 
				
			||||||
    main()
 | 
					    asyncio.run(main())
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user