mirror of
https://github.com/allenai/olmocr.git
synced 2025-11-03 03:25:22 +00:00
Async mode
This commit is contained in:
parent
0df56e958e
commit
5cbe331259
@ -1,6 +1,5 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import asyncio
|
import asyncio
|
||||||
import concurrent.futures
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
@ -8,11 +7,10 @@ import random
|
|||||||
import re
|
import re
|
||||||
import subprocess
|
import subprocess
|
||||||
import uuid
|
import uuid
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
|
||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
|
|
||||||
import pypdf
|
import pypdf
|
||||||
from anthropic import Anthropic
|
from anthropic import AsyncAnthropic
|
||||||
from bs4 import BeautifulSoup
|
from bs4 import BeautifulSoup
|
||||||
from markdownify import MarkdownConverter, SPACES
|
from markdownify import MarkdownConverter, SPACES
|
||||||
from playwright.async_api import async_playwright
|
from playwright.async_api import async_playwright
|
||||||
@ -226,13 +224,13 @@ def extract_code_block(initial_response):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def generate_html_from_image(client, image_base64):
|
async def generate_html_from_image(client, image_base64):
|
||||||
"""Call Claude API to generate HTML from an image using a multi-step prompting strategy."""
|
"""Call Claude API to generate HTML from an image using a multi-step prompting strategy."""
|
||||||
png_width, png_height = get_png_dimensions_from_base64(image_base64)
|
png_width, png_height = get_png_dimensions_from_base64(image_base64)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Step 1: Initial analysis and column detection
|
# Step 1: Initial analysis and column detection
|
||||||
analysis_response = client.messages.create(
|
analysis_response = await client.messages.create(
|
||||||
model="claude-sonnet-4-20250514",
|
model="claude-sonnet-4-20250514",
|
||||||
max_tokens=2000,
|
max_tokens=2000,
|
||||||
temperature=0.1,
|
temperature=0.1,
|
||||||
@ -261,7 +259,7 @@ def generate_html_from_image(client, image_base64):
|
|||||||
analysis_text += content.text
|
analysis_text += content.text
|
||||||
|
|
||||||
# Step 2: Initial HTML generation with detailed layout instructions
|
# Step 2: Initial HTML generation with detailed layout instructions
|
||||||
initial_response = client.messages.create(
|
initial_response = await client.messages.create(
|
||||||
model="claude-sonnet-4-20250514",
|
model="claude-sonnet-4-20250514",
|
||||||
max_tokens=6000,
|
max_tokens=6000,
|
||||||
temperature=0.2,
|
temperature=0.2,
|
||||||
@ -1012,7 +1010,7 @@ def generate_tests_from_html(html_content: str, pdf_id: str, page_num: int, verb
|
|||||||
return validated_tests
|
return validated_tests
|
||||||
|
|
||||||
|
|
||||||
def process_pdf(pdf_info, args, client, pdf_filter=None):
|
async def process_pdf(pdf_info, args, client, pdf_filter=None):
|
||||||
"""Process a single PDF, render a random page, and create an HTML template."""
|
"""Process a single PDF, render a random page, and create an HTML template."""
|
||||||
s3_path, index = pdf_info
|
s3_path, index = pdf_info
|
||||||
|
|
||||||
@ -1047,11 +1045,12 @@ def process_pdf(pdf_info, args, client, pdf_filter=None):
|
|||||||
# Select a random page
|
# Select a random page
|
||||||
page_num = random.randint(1, num_pages)
|
page_num = random.randint(1, num_pages)
|
||||||
|
|
||||||
# Render the page as a base64 PNG
|
# Render the page as a base64 PNG (run in thread pool since it's blocking I/O)
|
||||||
image_base64 = render_pdf_to_base64png(local_pdf_path, page_num, target_longest_image_dim=2048)
|
loop = asyncio.get_event_loop()
|
||||||
|
image_base64 = await loop.run_in_executor(None, render_pdf_to_base64png, local_pdf_path, page_num, 2048)
|
||||||
|
|
||||||
# Generate HTML from the image
|
# Generate HTML from the image
|
||||||
html_content = generate_html_from_image(client, image_base64)
|
html_content = await generate_html_from_image(client, image_base64)
|
||||||
if not html_content:
|
if not html_content:
|
||||||
print(f"Failed to generate HTML for {s3_path}, page {page_num}")
|
print(f"Failed to generate HTML for {s3_path}, page {page_num}")
|
||||||
return None
|
return None
|
||||||
@ -1117,14 +1116,8 @@ def process_pdf(pdf_info, args, client, pdf_filter=None):
|
|||||||
# Get PNG dimensions
|
# Get PNG dimensions
|
||||||
png_width, png_height = get_png_dimensions_from_base64(image_base64)
|
png_width, png_height = get_png_dimensions_from_base64(image_base64)
|
||||||
|
|
||||||
# Run the async function in the synchronous context
|
# Run the async function directly since we're already in an async context
|
||||||
# Create a new event loop to avoid conflicts
|
render_success = await render_pdf_with_playwright(html_content, playwright_pdf_path, png_width, png_height)
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
asyncio.set_event_loop(loop)
|
|
||||||
try:
|
|
||||||
render_success = loop.run_until_complete(render_pdf_with_playwright(html_content, playwright_pdf_path, png_width, png_height))
|
|
||||||
finally:
|
|
||||||
loop.close()
|
|
||||||
|
|
||||||
if render_success:
|
if render_success:
|
||||||
print(f"Successfully rendered with Playwright: {playwright_pdf_path}")
|
print(f"Successfully rendered with Playwright: {playwright_pdf_path}")
|
||||||
@ -1182,7 +1175,7 @@ def process_pdf(pdf_info, args, client, pdf_filter=None):
|
|||||||
subprocess.run(["rm", "-rf", temp_pdf_dir])
|
subprocess.run(["rm", "-rf", temp_pdf_dir])
|
||||||
|
|
||||||
|
|
||||||
def main():
|
async def main():
|
||||||
# Configure logging to suppress httpx messages
|
# Configure logging to suppress httpx messages
|
||||||
logging.getLogger("httpx").setLevel(logging.WARNING)
|
logging.getLogger("httpx").setLevel(logging.WARNING)
|
||||||
logging.getLogger("httpcore").setLevel(logging.WARNING)
|
logging.getLogger("httpcore").setLevel(logging.WARNING)
|
||||||
@ -1192,7 +1185,7 @@ def main():
|
|||||||
parser.add_argument("--output_dir", required=True, help="Directory to store extracted pages and tests")
|
parser.add_argument("--output_dir", required=True, help="Directory to store extracted pages and tests")
|
||||||
parser.add_argument("--temp_dir", default="/tmp/mine_tables", help="Directory for temporary files")
|
parser.add_argument("--temp_dir", default="/tmp/mine_tables", help="Directory for temporary files")
|
||||||
parser.add_argument("--max_tests", type=int, default=100, help="Maximum number of tests to generate")
|
parser.add_argument("--max_tests", type=int, default=100, help="Maximum number of tests to generate")
|
||||||
parser.add_argument("--parallel", type=int, default=1, help="Number of parallel threads to use")
|
parser.add_argument("--parallel", type=int, default=1, help="Number of parallel tasks to use")
|
||||||
parser.add_argument("--api_key", help="Claude API key (or set ANTHROPIC_API_KEY environment variable)")
|
parser.add_argument("--api_key", help="Claude API key (or set ANTHROPIC_API_KEY environment variable)")
|
||||||
parser.add_argument("--skip_playwright", action="store_true", help="Skip Playwright PDF rendering")
|
parser.add_argument("--skip_playwright", action="store_true", help="Skip Playwright PDF rendering")
|
||||||
parser.add_argument("--verbose", action="store_true", help="Enable verbose output including table test verification")
|
parser.add_argument("--verbose", action="store_true", help="Enable verbose output including table test verification")
|
||||||
@ -1210,8 +1203,8 @@ def main():
|
|||||||
print("Error: API key not provided. Use --api_key or set ANTHROPIC_API_KEY environment variable.")
|
print("Error: API key not provided. Use --api_key or set ANTHROPIC_API_KEY environment variable.")
|
||||||
return
|
return
|
||||||
|
|
||||||
# Initialize Claude client
|
# Initialize async Claude client
|
||||||
client = Anthropic(api_key=api_key)
|
client = AsyncAnthropic(api_key=api_key)
|
||||||
|
|
||||||
# Initialize PDF filter if enabled
|
# Initialize PDF filter if enabled
|
||||||
pdf_filter = None
|
pdf_filter = None
|
||||||
@ -1256,41 +1249,56 @@ def main():
|
|||||||
test_types = {"present": 0, "absent": 0, "table": 0, "order": 0}
|
test_types = {"present": 0, "absent": 0, "table": 0, "order": 0}
|
||||||
results = []
|
results = []
|
||||||
|
|
||||||
# Initialize a threading lock for file access
|
# Initialize an asyncio lock for file access
|
||||||
import threading
|
file_lock = asyncio.Lock()
|
||||||
|
|
||||||
file_lock = threading.Lock()
|
# Process PDFs in parallel using asyncio
|
||||||
|
async def process_with_progress(pdf_info):
|
||||||
# Process PDFs in parallel
|
s3_path = pdf_info[0]
|
||||||
with ThreadPoolExecutor(max_workers=args.parallel) as executor:
|
try:
|
||||||
# Submit all tasks
|
result = await process_pdf(pdf_info, args, client, pdf_filter)
|
||||||
futures = {executor.submit(process_pdf, (s3_path, i), args, client, pdf_filter): s3_path for i, s3_path in enumerate(s3_paths)}
|
if result and result.get("tests"):
|
||||||
|
# Append tests to synthetic.json as they're created (JSONL format)
|
||||||
# Process results as they complete
|
async with file_lock:
|
||||||
for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="Processing PDFs"):
|
# Append each test as a separate JSON line
|
||||||
s3_path = futures[future]
|
with open(synthetic_json_path, "a") as f:
|
||||||
try:
|
|
||||||
result = future.result()
|
|
||||||
if result and result.get("tests"):
|
|
||||||
results.append(result)
|
|
||||||
|
|
||||||
# Append tests to synthetic.json as they're created (JSONL format)
|
|
||||||
with file_lock:
|
|
||||||
# Append each test as a separate JSON line
|
|
||||||
with open(synthetic_json_path, "a") as f:
|
|
||||||
for test in result["tests"]:
|
|
||||||
f.write(json.dumps(test) + "\n")
|
|
||||||
|
|
||||||
# Update counters
|
|
||||||
test_counter += len(result["tests"])
|
|
||||||
for test in result["tests"]:
|
for test in result["tests"]:
|
||||||
test_type = test.get("type", "")
|
f.write(json.dumps(test) + "\n")
|
||||||
if test_type in test_types:
|
|
||||||
test_types[test_type] += 1
|
# Update counters
|
||||||
|
nonlocal test_counter
|
||||||
|
test_counter += len(result["tests"])
|
||||||
|
for test in result["tests"]:
|
||||||
|
test_type = test.get("type", "")
|
||||||
|
if test_type in test_types:
|
||||||
|
test_types[test_type] += 1
|
||||||
|
|
||||||
|
print(f"Added {len(result['tests'])} tests from {result['pdf_id']}, total: {test_counter}")
|
||||||
|
|
||||||
|
return result
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error processing {s3_path}: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
print(f"Added {len(result['tests'])} tests from {result['pdf_id']}, total: {test_counter}")
|
# Create tasks for all PDFs
|
||||||
except Exception as e:
|
tasks = []
|
||||||
print(f"Error processing {s3_path}: {e}")
|
for i, s3_path in enumerate(s3_paths):
|
||||||
|
tasks.append(process_with_progress((s3_path, i)))
|
||||||
|
|
||||||
|
# Run tasks with limited concurrency
|
||||||
|
semaphore = asyncio.Semaphore(args.parallel)
|
||||||
|
|
||||||
|
async def bounded_task(task_coro):
|
||||||
|
async with semaphore:
|
||||||
|
return await task_coro
|
||||||
|
|
||||||
|
bounded_tasks = [bounded_task(task) for task in tasks]
|
||||||
|
|
||||||
|
# Process all tasks with progress bar
|
||||||
|
for coro in tqdm(asyncio.as_completed(bounded_tasks), total=len(bounded_tasks), desc="Processing PDFs"):
|
||||||
|
result = await coro
|
||||||
|
if result:
|
||||||
|
results.append(result)
|
||||||
|
|
||||||
print(f"Generated {len(results)} HTML templates")
|
print(f"Generated {len(results)} HTML templates")
|
||||||
|
|
||||||
@ -1312,4 +1320,4 @@ def main():
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
asyncio.run(main())
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user