Olmocr runner implemented

This commit is contained in:
Jake Poznanski 2025-02-25 14:25:02 -08:00
parent aac0c1503d
commit d4b902cea2
2 changed files with 132 additions and 43 deletions

View File

@ -2,6 +2,8 @@ import argparse
import glob
import importlib
import os
import asyncio
import inspect
from tqdm import tqdm
@ -32,6 +34,34 @@ def parse_method_arg(method_arg):
return name, kwargs
async def process_pdfs(config, pdf_directory, data_directory, repeats):
    """Convert every PDF in ``pdf_directory`` with each configured method.

    For each candidate in ``config`` the first page of every ``*.pdf`` file is
    converted ``repeats`` times, and each result is written to
    ``<data_directory>/<candidate>/<pdf stem>_<i>.md`` (``i`` starts at 1).

    Args:
        config: Mapping of candidate name to a dict with a ``"method"``
            callable and a ``"kwargs"`` dict of extra keyword arguments.
            The callable may be synchronous or asynchronous.
        pdf_directory: Directory that is searched (non-recursively) for PDFs.
        data_directory: Directory under which one output folder per candidate
            is created.
        repeats: Number of conversions to run per PDF and candidate.
    """
    # The input set does not depend on the candidate, so list it once. Sorting
    # makes the processing order reproducible across runs and file systems.
    pdf_paths = sorted(glob.glob(os.path.join(pdf_directory, "*.pdf")))

    for candidate, spec in config.items():
        print(f"Starting conversion using {candidate} with kwargs: {spec['kwargs']}")
        candidate_output_dir = os.path.join(data_directory, candidate)
        os.makedirs(candidate_output_dir, exist_ok=True)

        method = spec["method"]
        kwargs = spec["kwargs"]

        for pdf_path in tqdm(pdf_paths, desc=candidate):
            # splitext strips only the trailing extension; str.replace would
            # also remove any ".pdf" that appears earlier in the file name.
            base_name = os.path.splitext(os.path.basename(pdf_path))[0]

            for i in range(1, repeats + 1):
                markdown = method(pdf_path, page_num=1, **kwargs)
                # Inspect the result rather than the callable so that async
                # functions wrapped in functools.partial (or callable objects
                # with an async __call__) are awaited as well.
                if inspect.isawaitable(markdown):
                    markdown = await markdown

                output_path = os.path.join(candidate_output_dir, f"{base_name}_{i}.md")
                # Explicit UTF-8: extracted text routinely contains non-ASCII
                # characters and the platform default encoding may not be UTF-8.
                with open(output_path, "w", encoding="utf-8") as out_f:
                    out_f.write(markdown)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run PDF conversion using specified OCR methods and extra parameters.")
parser.add_argument("methods", nargs="+", help="Methods to run in the format method[:key=value ...]. " "Example: gotocr mineru:temperature=2 marker:runs=3")
@ -40,6 +70,7 @@ if __name__ == "__main__":
# Mapping of method names to a tuple: (module path, function name)
available_methods = {
"olmocr": ("olmocr.bench.runners.run_olmocr", "run_olmocr"),
"gotocr": ("olmocr.bench.runners.run_gotocr", "run_gotocr"),
"marker": ("olmocr.bench.runners.run_marker", "run_marker"),
"mineru": ("olmocr.bench.runners.run_mineru", "run_mineru"),
@ -61,18 +92,5 @@ if __name__ == "__main__":
data_directory = os.path.join(os.path.dirname(__file__), "sample_data")
pdf_directory = os.path.join(data_directory, "pdfs")
# Process each PDF using each specified method and repeat the conversion as needed.
for candidate in config.keys():
print(f"Starting conversion using {candidate} with kwargs: {config[candidate]['kwargs']}")
candidate_output_dir = os.path.join(data_directory, candidate)
os.makedirs(candidate_output_dir, exist_ok=True)
for pdf_path in tqdm(glob.glob(os.path.join(pdf_directory, "*.pdf")), desc=candidate):
base_name = os.path.basename(pdf_path).replace(".pdf", "")
# Repeat the conversion as many times as specified.
for i in range(1, args.repeats + 1):
markdown = config[candidate]["method"](pdf_path, page_num=1, **config[candidate]["kwargs"])
output_filename = f"{base_name}_{i}.md"
output_path = os.path.join(candidate_output_dir, output_filename)
with open(output_path, "w") as out_f:
out_f.write(markdown)
# Run the async process function
asyncio.run(process_pdfs(config, pdf_directory, data_directory, args.repeats))

View File

@ -1,35 +1,106 @@
import asyncio
import glob
import json
import logging
import os
import shutil
import sys
import tempfile
from dataclasses import dataclass
from functools import partial
import argparse
from typing import Optional
import json
import olmocr.pipeline
# Import necessary components from olmocr
from olmocr.pipeline import (
sglang_server_host,
sglang_server_ready,
build_page_query,
apost,
SGLANG_SERVER_PORT,
MetricsKeeper,
WorkerTracker
)
from olmocr.prompts import PageResponse
# Set sys.argv as if you were running the script from the command line.
workspace_dir = "olmocr/bench/sample_data/olmocr/workspace"
# Setup basic logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger("olmocr_runner")
sys.argv = [
"pipeline.py", # The script name (can be arbitrary)
"olmocr/bench/sample_data/olmocr/workspace", # Positional argument: workspace
"--pdfs",
*list(glob.glob("olmocr/bench/sample_data/pdfs/*.pdf")), # PDF paths
]
# Basic configuration
@dataclass
class Args:
    """Default settings used by the olmocr benchmark runner.

    An instance is handed to the sglang server launcher, and the two
    ``target_*`` values are forwarded when a page query is built.
    """

    # Model identifier given to the sglang server.
    model: str = "allenai/olmOCR-7B-0225-preview"
    # Chat template name; presumably read by the server launcher (confirm there).
    model_chat_template: str = "qwen2-vl"
    # Context window size; presumably read by the server launcher (confirm there).
    model_max_context: int = 8192

    # Forwarded to the page-query builder as the longest image dimension.
    target_longest_image_dim: int = 1024
    # Forwarded to the page-query builder as the anchor text length.
    target_anchor_text_len: int = 6000
# Call the async main() function.
asyncio.run(olmocr.pipeline.main())
# Now, take a produced jsonl files and unpack them into mds
for jsonl_path in glob.glob(workspace_dir + "/results/*.jsonl"):
with open(jsonl_path, "r") as jsonl_f:
for line in jsonl_f:
data = json.loads(line)
name = os.path.basename(data["metadata"]["Source-File"])
with open(f"olmocr/bench/sample_data/olmocr/{name.replace('.pdf', '.md')}", "w") as out_f:
out_f.write(data["text"])
shutil.rmtree(workspace_dir)
# Strong references to background tasks started by run_olmocr. The event loop
# only keeps weak references to tasks, so a server task that is referenced
# solely by a local variable could be garbage-collected while still running.
_background_tasks = set()


async def run_olmocr(pdf_path: str, page_num: int = 1, temperature: float = 0.8) -> str:
    """
    Process a single page of a PDF using the olmocr pipeline.

    A running sglang server is reused when one answers the readiness probe
    within 5 seconds; otherwise a new server is started in the background and
    left running so later calls can reuse it.

    Args:
        pdf_path: Path to the PDF file
        page_num: Page number to process (1-indexed)
        temperature: Temperature parameter for the model

    Returns:
        The extracted text from the page. Failures while building the query,
        calling the server or parsing its answer are not raised; they are
        reported as a string starting with "Error: ".

    Raises:
        Whatever ``sglang_server_ready`` raises if a freshly started server
        does not become ready (this happens outside the error-to-string
        handling).
    """
    # Lazily create the module-level metrics/tracker objects on first use.
    global metrics, tracker
    if globals().get("metrics") is None:
        metrics = MetricsKeeper(window=60 * 5)
    if globals().get("tracker") is None:
        tracker = WorkerTracker()

    args = Args()
    semaphore = asyncio.Semaphore(1)

    # Ensure a server is running. Any failure of the probe, including the
    # timeout, is treated as "no server available".
    try:
        await asyncio.wait_for(sglang_server_ready(), timeout=5)
        print("Using existing sglang server")
    except Exception:
        print("Starting new sglang server")
        server_task = asyncio.create_task(sglang_server_host(args, semaphore))
        # Keep the task alive after this coroutine returns; drop the reference
        # automatically once the server task finishes.
        _background_tasks.add(server_task)
        server_task.add_done_callback(_background_tasks.discard)
        await sglang_server_ready()

    try:
        query = await build_page_query(
            pdf_path,
            page_num,
            args.target_longest_image_dim,
            args.target_anchor_text_len,
        )
        query["temperature"] = temperature

        url = f"http://localhost:{SGLANG_SERVER_PORT}/v1/chat/completions"
        status_code, response_body = await apost(url, json_data=query)
        if status_code != 200:
            return f"Error: HTTP status {status_code}"

        # The completion content is itself a JSON document describing the page.
        response_data = json.loads(response_body)
        content = response_data["choices"][0]["message"]["content"]
        page_response = PageResponse(**json.loads(content))

        # Token accounting is best-effort: a response without a "usage" block
        # must not turn a successfully parsed page into an error.
        usage = response_data.get("usage") or {}
        metrics.add_metrics(
            sglang_input_tokens=usage.get("prompt_tokens", 0),
            sglang_output_tokens=usage.get("completion_tokens", 0),
        )

        # NOTE(review): natural_text is assumed to be possibly None for pages
        # without text - confirm against PageResponse. Coerce so the declared
        # ``str`` return type holds for callers that write the result to a file.
        return page_response.natural_text or ""
    except Exception as e:
        return f"Error: {type(e).__name__} - {str(e)}"