mirror of
https://github.com/allenai/olmocr.git
synced 2025-11-02 19:13:53 +00:00
Async mode
This commit is contained in:
parent
0df56e958e
commit
5cbe331259
@ -1,6 +1,5 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
@ -8,11 +7,10 @@ import random
|
||||
import re
|
||||
import subprocess
|
||||
import uuid
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Dict, List
|
||||
|
||||
import pypdf
|
||||
from anthropic import Anthropic
|
||||
from anthropic import AsyncAnthropic
|
||||
from bs4 import BeautifulSoup
|
||||
from markdownify import MarkdownConverter, SPACES
|
||||
from playwright.async_api import async_playwright
|
||||
@ -226,13 +224,13 @@ def extract_code_block(initial_response):
|
||||
return None
|
||||
|
||||
|
||||
def generate_html_from_image(client, image_base64):
|
||||
async def generate_html_from_image(client, image_base64):
|
||||
"""Call Claude API to generate HTML from an image using a multi-step prompting strategy."""
|
||||
png_width, png_height = get_png_dimensions_from_base64(image_base64)
|
||||
|
||||
try:
|
||||
# Step 1: Initial analysis and column detection
|
||||
analysis_response = client.messages.create(
|
||||
analysis_response = await client.messages.create(
|
||||
model="claude-sonnet-4-20250514",
|
||||
max_tokens=2000,
|
||||
temperature=0.1,
|
||||
@ -261,7 +259,7 @@ def generate_html_from_image(client, image_base64):
|
||||
analysis_text += content.text
|
||||
|
||||
# Step 2: Initial HTML generation with detailed layout instructions
|
||||
initial_response = client.messages.create(
|
||||
initial_response = await client.messages.create(
|
||||
model="claude-sonnet-4-20250514",
|
||||
max_tokens=6000,
|
||||
temperature=0.2,
|
||||
@ -1012,7 +1010,7 @@ def generate_tests_from_html(html_content: str, pdf_id: str, page_num: int, verb
|
||||
return validated_tests
|
||||
|
||||
|
||||
def process_pdf(pdf_info, args, client, pdf_filter=None):
|
||||
async def process_pdf(pdf_info, args, client, pdf_filter=None):
|
||||
"""Process a single PDF, render a random page, and create an HTML template."""
|
||||
s3_path, index = pdf_info
|
||||
|
||||
@ -1047,11 +1045,12 @@ def process_pdf(pdf_info, args, client, pdf_filter=None):
|
||||
# Select a random page
|
||||
page_num = random.randint(1, num_pages)
|
||||
|
||||
# Render the page as a base64 PNG
|
||||
image_base64 = render_pdf_to_base64png(local_pdf_path, page_num, target_longest_image_dim=2048)
|
||||
# Render the page as a base64 PNG (run in thread pool since it's blocking I/O)
|
||||
loop = asyncio.get_event_loop()
|
||||
image_base64 = await loop.run_in_executor(None, render_pdf_to_base64png, local_pdf_path, page_num, 2048)
|
||||
|
||||
# Generate HTML from the image
|
||||
html_content = generate_html_from_image(client, image_base64)
|
||||
html_content = await generate_html_from_image(client, image_base64)
|
||||
if not html_content:
|
||||
print(f"Failed to generate HTML for {s3_path}, page {page_num}")
|
||||
return None
|
||||
@ -1117,14 +1116,8 @@ def process_pdf(pdf_info, args, client, pdf_filter=None):
|
||||
# Get PNG dimensions
|
||||
png_width, png_height = get_png_dimensions_from_base64(image_base64)
|
||||
|
||||
# Run the async function in the synchronous context
|
||||
# Create a new event loop to avoid conflicts
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
render_success = loop.run_until_complete(render_pdf_with_playwright(html_content, playwright_pdf_path, png_width, png_height))
|
||||
finally:
|
||||
loop.close()
|
||||
# Run the async function directly since we're already in an async context
|
||||
render_success = await render_pdf_with_playwright(html_content, playwright_pdf_path, png_width, png_height)
|
||||
|
||||
if render_success:
|
||||
print(f"Successfully rendered with Playwright: {playwright_pdf_path}")
|
||||
@ -1182,7 +1175,7 @@ def process_pdf(pdf_info, args, client, pdf_filter=None):
|
||||
subprocess.run(["rm", "-rf", temp_pdf_dir])
|
||||
|
||||
|
||||
def main():
|
||||
async def main():
|
||||
# Configure logging to suppress httpx messages
|
||||
logging.getLogger("httpx").setLevel(logging.WARNING)
|
||||
logging.getLogger("httpcore").setLevel(logging.WARNING)
|
||||
@ -1192,7 +1185,7 @@ def main():
|
||||
parser.add_argument("--output_dir", required=True, help="Directory to store extracted pages and tests")
|
||||
parser.add_argument("--temp_dir", default="/tmp/mine_tables", help="Directory for temporary files")
|
||||
parser.add_argument("--max_tests", type=int, default=100, help="Maximum number of tests to generate")
|
||||
parser.add_argument("--parallel", type=int, default=1, help="Number of parallel threads to use")
|
||||
parser.add_argument("--parallel", type=int, default=1, help="Number of parallel tasks to use")
|
||||
parser.add_argument("--api_key", help="Claude API key (or set ANTHROPIC_API_KEY environment variable)")
|
||||
parser.add_argument("--skip_playwright", action="store_true", help="Skip Playwright PDF rendering")
|
||||
parser.add_argument("--verbose", action="store_true", help="Enable verbose output including table test verification")
|
||||
@ -1210,8 +1203,8 @@ def main():
|
||||
print("Error: API key not provided. Use --api_key or set ANTHROPIC_API_KEY environment variable.")
|
||||
return
|
||||
|
||||
# Initialize Claude client
|
||||
client = Anthropic(api_key=api_key)
|
||||
# Initialize async Claude client
|
||||
client = AsyncAnthropic(api_key=api_key)
|
||||
|
||||
# Initialize PDF filter if enabled
|
||||
pdf_filter = None
|
||||
@ -1256,41 +1249,56 @@ def main():
|
||||
test_types = {"present": 0, "absent": 0, "table": 0, "order": 0}
|
||||
results = []
|
||||
|
||||
# Initialize a threading lock for file access
|
||||
import threading
|
||||
# Initialize an asyncio lock for file access
|
||||
file_lock = asyncio.Lock()
|
||||
|
||||
file_lock = threading.Lock()
|
||||
|
||||
# Process PDFs in parallel
|
||||
with ThreadPoolExecutor(max_workers=args.parallel) as executor:
|
||||
# Submit all tasks
|
||||
futures = {executor.submit(process_pdf, (s3_path, i), args, client, pdf_filter): s3_path for i, s3_path in enumerate(s3_paths)}
|
||||
|
||||
# Process results as they complete
|
||||
for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="Processing PDFs"):
|
||||
s3_path = futures[future]
|
||||
try:
|
||||
result = future.result()
|
||||
if result and result.get("tests"):
|
||||
results.append(result)
|
||||
|
||||
# Append tests to synthetic.json as they're created (JSONL format)
|
||||
with file_lock:
|
||||
# Append each test as a separate JSON line
|
||||
with open(synthetic_json_path, "a") as f:
|
||||
for test in result["tests"]:
|
||||
f.write(json.dumps(test) + "\n")
|
||||
|
||||
# Update counters
|
||||
test_counter += len(result["tests"])
|
||||
# Process PDFs in parallel using asyncio
|
||||
async def process_with_progress(pdf_info):
|
||||
s3_path = pdf_info[0]
|
||||
try:
|
||||
result = await process_pdf(pdf_info, args, client, pdf_filter)
|
||||
if result and result.get("tests"):
|
||||
# Append tests to synthetic.json as they're created (JSONL format)
|
||||
async with file_lock:
|
||||
# Append each test as a separate JSON line
|
||||
with open(synthetic_json_path, "a") as f:
|
||||
for test in result["tests"]:
|
||||
test_type = test.get("type", "")
|
||||
if test_type in test_types:
|
||||
test_types[test_type] += 1
|
||||
f.write(json.dumps(test) + "\n")
|
||||
|
||||
# Update counters
|
||||
nonlocal test_counter
|
||||
test_counter += len(result["tests"])
|
||||
for test in result["tests"]:
|
||||
test_type = test.get("type", "")
|
||||
if test_type in test_types:
|
||||
test_types[test_type] += 1
|
||||
|
||||
print(f"Added {len(result['tests'])} tests from {result['pdf_id']}, total: {test_counter}")
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
print(f"Error processing {s3_path}: {e}")
|
||||
return None
|
||||
|
||||
print(f"Added {len(result['tests'])} tests from {result['pdf_id']}, total: {test_counter}")
|
||||
except Exception as e:
|
||||
print(f"Error processing {s3_path}: {e}")
|
||||
# Create tasks for all PDFs
|
||||
tasks = []
|
||||
for i, s3_path in enumerate(s3_paths):
|
||||
tasks.append(process_with_progress((s3_path, i)))
|
||||
|
||||
# Run tasks with limited concurrency
|
||||
semaphore = asyncio.Semaphore(args.parallel)
|
||||
|
||||
async def bounded_task(task_coro):
|
||||
async with semaphore:
|
||||
return await task_coro
|
||||
|
||||
bounded_tasks = [bounded_task(task) for task in tasks]
|
||||
|
||||
# Process all tasks with progress bar
|
||||
for coro in tqdm(asyncio.as_completed(bounded_tasks), total=len(bounded_tasks), desc="Processing PDFs"):
|
||||
result = await coro
|
||||
if result:
|
||||
results.append(result)
|
||||
|
||||
print(f"Generated {len(results)} HTML templates")
|
||||
|
||||
@ -1312,4 +1320,4 @@ def main():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
asyncio.run(main())
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user