Mirror of https://github.com/allenai/olmocr.git, synced 2025-06-27 04:00:02 +00:00
Olmocr runner implemented

commit d4b902cea2 (parent aac0c1503d)
@@ -2,6 +2,8 @@ import argparse
import glob
import importlib
import os
import asyncio
import inspect

from tqdm import tqdm

@@ -32,6 +34,34 @@ def parse_method_arg(method_arg):
    return name, kwargs


async def process_pdfs(config, pdf_directory, data_directory, repeats):
    """Process PDFs with both sync and async functions"""
    for candidate in config.keys():
        print(f"Starting conversion using {candidate} with kwargs: {config[candidate]['kwargs']}")
        candidate_output_dir = os.path.join(data_directory, candidate)
        os.makedirs(candidate_output_dir, exist_ok=True)

        method = config[candidate]["method"]
        kwargs = config[candidate]["kwargs"]
        is_async = asyncio.iscoroutinefunction(method)

        for pdf_path in tqdm(glob.glob(os.path.join(pdf_directory, "*.pdf")), desc=candidate):
            base_name = os.path.basename(pdf_path).replace(".pdf", "")

            for i in range(1, repeats + 1):
                if is_async:
                    # Run async function
                    markdown = await method(pdf_path, page_num=1, **kwargs)
                else:
                    # Run synchronous function
                    markdown = method(pdf_path, page_num=1, **kwargs)

                output_filename = f"{base_name}_{i}.md"
                output_path = os.path.join(candidate_output_dir, output_filename)
                with open(output_path, "w") as out_f:
                    out_f.write(markdown)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run PDF conversion using specified OCR methods and extra parameters.")
    parser.add_argument("methods", nargs="+", help="Methods to run in the format method[:key=value ...]. Example: gotocr mineru:temperature=2 marker:runs=3")
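The hunk above feeds the method[:key=value ...] specs through the existing parse_method_arg helper, whose body is not part of this diff. As a rough illustration of the format described in the help text, a parser along the following lines would behave as expected; this is a hypothetical sketch, not the repo's actual implementation, and it leaves values as strings.

# Hypothetical sketch only -- the real parse_method_arg may differ.
def parse_method_arg_sketch(method_arg: str):
    """Split 'name[:key=value[:key=value ...]]' into (name, kwargs)."""
    name, _, rest = method_arg.partition(":")
    kwargs = {}
    for pair in filter(None, rest.split(":")):
        key, _, value = pair.partition("=")
        kwargs[key] = value  # values stay strings; the caller can cast them
    return name, kwargs

# e.g. parse_method_arg_sketch("mineru:temperature=2") -> ("mineru", {"temperature": "2"})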
@@ -40,6 +70,7 @@ if __name__ == "__main__":

    # Mapping of method names to a tuple: (module path, function name)
    available_methods = {
        "olmocr": ("olmocr.bench.runners.run_olmocr", "run_olmocr"),
        "gotocr": ("olmocr.bench.runners.run_gotocr", "run_gotocr"),
        "marker": ("olmocr.bench.runners.run_marker", "run_marker"),
        "mineru": ("olmocr.bench.runners.run_mineru", "run_mineru"),
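This hunk adds the "olmocr" entry to available_methods, but the diff does not show how the (module path, function name) tuples become the config dict that process_pdfs reads. Given the importlib import added at the top and the config[candidate]["method"] / config[candidate]["kwargs"] accesses above, the resolution presumably looks roughly like the sketch below; the exact loop and any names beyond args.methods, parse_method_arg, and available_methods are assumptions.

# Sketch (assumed, not shown in this diff): resolve each requested method
# into the callable + kwargs entry that process_pdfs expects.
config = {}
for method_arg in args.methods:
    name, kwargs = parse_method_arg(method_arg)
    module_path, function_name = available_methods[name]
    module = importlib.import_module(module_path)
    config[name] = {
        "method": getattr(module, function_name),  # may be sync or async
        "kwargs": kwargs,
    }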
@@ -61,18 +92,5 @@ if __name__ == "__main__":
    data_directory = os.path.join(os.path.dirname(__file__), "sample_data")
    pdf_directory = os.path.join(data_directory, "pdfs")

    # Process each PDF using each specified method and repeat the conversion as needed.
    for candidate in config.keys():
        print(f"Starting conversion using {candidate} with kwargs: {config[candidate]['kwargs']}")
        candidate_output_dir = os.path.join(data_directory, candidate)
        os.makedirs(candidate_output_dir, exist_ok=True)

        for pdf_path in tqdm(glob.glob(os.path.join(pdf_directory, "*.pdf")), desc=candidate):
            base_name = os.path.basename(pdf_path).replace(".pdf", "")
            # Repeat the conversion as many times as specified.
            for i in range(1, args.repeats + 1):
                markdown = config[candidate]["method"](pdf_path, page_num=1, **config[candidate]["kwargs"])
                output_filename = f"{base_name}_{i}.md"
                output_path = os.path.join(candidate_output_dir, output_filename)
                with open(output_path, "w") as out_f:
                    out_f.write(markdown)
    # Run the async process function
    asyncio.run(process_pdfs(config, pdf_directory, data_directory, args.repeats))
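With this hunk, the old synchronous per-method loop in __main__ (the 18 removed lines above) is replaced by a single asyncio.run(process_pdfs(...)) call that awaits coroutine-based runners such as the new run_olmocr below and calls plain functions directly. An invocation would look roughly like python convert.py olmocr mineru:temperature=2 --repeats 3, where the script name and the --repeats flag are assumptions; only args.repeats is visible in this diff.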

@@ -1,35 +1,106 @@
import asyncio
import glob
import json
import logging
import os
import shutil
import sys
import tempfile
from dataclasses import dataclass
from functools import partial
import argparse
from typing import Optional
import json

import olmocr.pipeline
# Import necessary components from olmocr
from olmocr.pipeline import (
    sglang_server_host,
    sglang_server_ready,
    build_page_query,
    apost,
    SGLANG_SERVER_PORT,
    MetricsKeeper,
    WorkerTracker
)
from olmocr.prompts import PageResponse

# Set sys.argv as if you were running the script from the command line.

workspace_dir = "olmocr/bench/sample_data/olmocr/workspace"
# Setup basic logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
logger = logging.getLogger("olmocr_runner")

sys.argv = [
    "pipeline.py",  # The script name (can be arbitrary)
    "olmocr/bench/sample_data/olmocr/workspace",  # Positional argument: workspace
    "--pdfs",
    *list(glob.glob("olmocr/bench/sample_data/pdfs/*.pdf")),  # PDF paths
]
# Basic configuration
@dataclass
class Args:
    model: str = "allenai/olmOCR-7B-0225-preview"
    model_chat_template: str = "qwen2-vl"
    model_max_context: int = 8192
    target_longest_image_dim: int = 1024
    target_anchor_text_len: int = 6000

# Call the async main() function.
asyncio.run(olmocr.pipeline.main())

# Now, take the produced jsonl files and unpack them into mds
for jsonl_path in glob.glob(workspace_dir + "/results/*.jsonl"):
    with open(jsonl_path, "r") as jsonl_f:
        for line in jsonl_f:
            data = json.loads(line)

            name = os.path.basename(data["metadata"]["Source-File"])

            with open(f"olmocr/bench/sample_data/olmocr/{name.replace('.pdf', '.md')}", "w") as out_f:
                out_f.write(data["text"])

shutil.rmtree(workspace_dir)
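The unpacking loop above relies on just two fields of each line in results/*.jsonl: metadata["Source-File"] and text. A minimal record that satisfies it could look like the sketch below (reusing the json and os imports from this file); real records written by olmocr.pipeline.main() carry additional fields not shown here, and the path is a placeholder.

# Minimal illustration of a results/*.jsonl line the loop above can consume.
# "sample_data/pdfs/example.pdf" is a placeholder path, not a bundled sample.
example_line = '{"metadata": {"Source-File": "sample_data/pdfs/example.pdf"}, "text": "Extracted text..."}'
data = json.loads(example_line)
assert os.path.basename(data["metadata"]["Source-File"]) == "example.pdf"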
async def run_olmocr(pdf_path: str, page_num: int = 1, temperature: float = 0.8) -> str:
    """
    Process a single page of a PDF using the olmocr pipeline.

    Args:
        pdf_path: Path to the PDF file
        page_num: Page number to process (1-indexed)
        temperature: Temperature parameter for the model

    Returns:
        The extracted text from the page
    """
    # Ensure global variables are initialized
    global metrics, tracker
    if 'metrics' not in globals() or metrics is None:
        metrics = MetricsKeeper(window=60*5)
    if 'tracker' not in globals() or tracker is None:
        tracker = WorkerTracker()

    args = Args()
    semaphore = asyncio.Semaphore(1)

    # Ensure server is running
    server_task = None
    try:
        await asyncio.wait_for(sglang_server_ready(), timeout=5)
        print("Using existing sglang server")
    except Exception:
        print("Starting new sglang server")
        server_task = asyncio.create_task(sglang_server_host(args, semaphore))
        await sglang_server_ready()

    try:
        # Process the page
        query = await build_page_query(
            pdf_path,
            page_num,
            args.target_longest_image_dim,
            args.target_anchor_text_len
        )
        query["temperature"] = temperature

        # Make request and get response
        url = f"http://localhost:{SGLANG_SERVER_PORT}/v1/chat/completions"
        status_code, response_body = await apost(url, json_data=query)

        if status_code != 200:
            return f"Error: HTTP status {status_code}"

        # Parse response
        response_data = json.loads(response_body)
        content = response_data["choices"][0]["message"]["content"]
        model_json = json.loads(content)
        page_response = PageResponse(**model_json)

        # Update metrics
        metrics.add_metrics(
            sglang_input_tokens=response_data["usage"].get("prompt_tokens", 0),
            sglang_output_tokens=response_data["usage"].get("completion_tokens", 0)
        )

        return page_response.natural_text

    except Exception as e:
        return f"Error: {type(e).__name__} - {str(e)}"

    finally:
        # We leave the server running for potential reuse
        # This is more efficient if multiple pages will be processed
        pass
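For reference, calling the new runner on its own (outside process_pdfs) is a matter of wrapping it in asyncio.run; the path below is a placeholder, and the temperature of 0.8 matches the default in the signature above.

# Standalone usage sketch; "path/to/some.pdf" is a placeholder, not a bundled file.
if __name__ == "__main__":
    text = asyncio.run(run_olmocr("path/to/some.pdf", page_num=1, temperature=0.8))
    print(text)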