mirror of
https://github.com/allenai/olmocr.git
synced 2025-06-27 04:00:02 +00:00
Olmocr runner implemented
This commit is contained in:
parent
aac0c1503d
commit
d4b902cea2
@ -2,6 +2,8 @@ import argparse
|
|||||||
import glob
|
import glob
|
||||||
import importlib
|
import importlib
|
||||||
import os
|
import os
|
||||||
|
import asyncio
|
||||||
|
import inspect
|
||||||
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
@ -32,6 +34,34 @@ def parse_method_arg(method_arg):
|
|||||||
return name, kwargs
|
return name, kwargs
|
||||||
|
|
||||||
|
|
||||||
|
async def process_pdfs(config, pdf_directory, data_directory, repeats):
|
||||||
|
"""Process PDFs with both sync and async functions"""
|
||||||
|
for candidate in config.keys():
|
||||||
|
print(f"Starting conversion using {candidate} with kwargs: {config[candidate]['kwargs']}")
|
||||||
|
candidate_output_dir = os.path.join(data_directory, candidate)
|
||||||
|
os.makedirs(candidate_output_dir, exist_ok=True)
|
||||||
|
|
||||||
|
method = config[candidate]["method"]
|
||||||
|
kwargs = config[candidate]["kwargs"]
|
||||||
|
is_async = asyncio.iscoroutinefunction(method)
|
||||||
|
|
||||||
|
for pdf_path in tqdm(glob.glob(os.path.join(pdf_directory, "*.pdf")), desc=candidate):
|
||||||
|
base_name = os.path.basename(pdf_path).replace(".pdf", "")
|
||||||
|
|
||||||
|
for i in range(1, repeats + 1):
|
||||||
|
if is_async:
|
||||||
|
# Run async function
|
||||||
|
markdown = await method(pdf_path, page_num=1, **kwargs)
|
||||||
|
else:
|
||||||
|
# Run synchronous function
|
||||||
|
markdown = method(pdf_path, page_num=1, **kwargs)
|
||||||
|
|
||||||
|
output_filename = f"{base_name}_{i}.md"
|
||||||
|
output_path = os.path.join(candidate_output_dir, output_filename)
|
||||||
|
with open(output_path, "w") as out_f:
|
||||||
|
out_f.write(markdown)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser(description="Run PDF conversion using specified OCR methods and extra parameters.")
|
parser = argparse.ArgumentParser(description="Run PDF conversion using specified OCR methods and extra parameters.")
|
||||||
parser.add_argument("methods", nargs="+", help="Methods to run in the format method[:key=value ...]. " "Example: gotocr mineru:temperature=2 marker:runs=3")
|
parser.add_argument("methods", nargs="+", help="Methods to run in the format method[:key=value ...]. " "Example: gotocr mineru:temperature=2 marker:runs=3")
|
||||||
@ -40,6 +70,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# Mapping of method names to a tuple: (module path, function name)
|
# Mapping of method names to a tuple: (module path, function name)
|
||||||
available_methods = {
|
available_methods = {
|
||||||
|
"olmocr": ("olmocr.bench.runners.run_olmocr", "run_olmocr"),
|
||||||
"gotocr": ("olmocr.bench.runners.run_gotocr", "run_gotocr"),
|
"gotocr": ("olmocr.bench.runners.run_gotocr", "run_gotocr"),
|
||||||
"marker": ("olmocr.bench.runners.run_marker", "run_marker"),
|
"marker": ("olmocr.bench.runners.run_marker", "run_marker"),
|
||||||
"mineru": ("olmocr.bench.runners.run_mineru", "run_mineru"),
|
"mineru": ("olmocr.bench.runners.run_mineru", "run_mineru"),
|
||||||
@ -61,18 +92,5 @@ if __name__ == "__main__":
|
|||||||
data_directory = os.path.join(os.path.dirname(__file__), "sample_data")
|
data_directory = os.path.join(os.path.dirname(__file__), "sample_data")
|
||||||
pdf_directory = os.path.join(data_directory, "pdfs")
|
pdf_directory = os.path.join(data_directory, "pdfs")
|
||||||
|
|
||||||
# Process each PDF using each specified method and repeat the conversion as needed.
|
# Run the async process function
|
||||||
for candidate in config.keys():
|
asyncio.run(process_pdfs(config, pdf_directory, data_directory, args.repeats))
|
||||||
print(f"Starting conversion using {candidate} with kwargs: {config[candidate]['kwargs']}")
|
|
||||||
candidate_output_dir = os.path.join(data_directory, candidate)
|
|
||||||
os.makedirs(candidate_output_dir, exist_ok=True)
|
|
||||||
|
|
||||||
for pdf_path in tqdm(glob.glob(os.path.join(pdf_directory, "*.pdf")), desc=candidate):
|
|
||||||
base_name = os.path.basename(pdf_path).replace(".pdf", "")
|
|
||||||
# Repeat the conversion as many times as specified.
|
|
||||||
for i in range(1, args.repeats + 1):
|
|
||||||
markdown = config[candidate]["method"](pdf_path, page_num=1, **config[candidate]["kwargs"])
|
|
||||||
output_filename = f"{base_name}_{i}.md"
|
|
||||||
output_path = os.path.join(candidate_output_dir, output_filename)
|
|
||||||
with open(output_path, "w") as out_f:
|
|
||||||
out_f.write(markdown)
|
|
@ -1,35 +1,106 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import glob
|
import logging
|
||||||
import json
|
|
||||||
import os
|
import os
|
||||||
import shutil
|
import tempfile
|
||||||
import sys
|
from dataclasses import dataclass
|
||||||
|
from functools import partial
|
||||||
|
import argparse
|
||||||
|
from typing import Optional
|
||||||
|
import json
|
||||||
|
|
||||||
import olmocr.pipeline
|
# Import necessary components from olmocr
|
||||||
|
from olmocr.pipeline import (
|
||||||
|
sglang_server_host,
|
||||||
|
sglang_server_ready,
|
||||||
|
build_page_query,
|
||||||
|
apost,
|
||||||
|
SGLANG_SERVER_PORT,
|
||||||
|
MetricsKeeper,
|
||||||
|
WorkerTracker
|
||||||
|
)
|
||||||
|
from olmocr.prompts import PageResponse
|
||||||
|
|
||||||
# Set sys.argv as if you were running the script from the command line.
|
|
||||||
|
|
||||||
workspace_dir = "olmocr/bench/sample_data/olmocr/workspace"
|
# Setup basic logging
|
||||||
|
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
|
||||||
|
logger = logging.getLogger("olmocr_runner")
|
||||||
|
|
||||||
sys.argv = [
|
# Basic configuration
|
||||||
"pipeline.py", # The script name (can be arbitrary)
|
@dataclass
|
||||||
"olmocr/bench/sample_data/olmocr/workspace", # Positional argument: workspace
|
class Args:
|
||||||
"--pdfs",
|
model: str = "allenai/olmOCR-7B-0225-preview"
|
||||||
*list(glob.glob("olmocr/bench/sample_data/pdfs/*.pdf")), # PDF paths
|
model_chat_template: str = "qwen2-vl"
|
||||||
]
|
model_max_context: int = 8192
|
||||||
|
target_longest_image_dim: int = 1024
|
||||||
|
target_anchor_text_len: int = 6000
|
||||||
|
|
||||||
# Call the async main() function.
|
async def run_olmocr(pdf_path: str, page_num: int = 1, temperature: float = 0.8) -> str:
|
||||||
asyncio.run(olmocr.pipeline.main())
|
"""
|
||||||
|
Process a single page of a PDF using the olmocr pipeline.
|
||||||
# Now, take a produced jsonl files and unpack them into mds
|
|
||||||
for jsonl_path in glob.glob(workspace_dir + "/results/*.jsonl"):
|
Args:
|
||||||
with open(jsonl_path, "r") as jsonl_f:
|
pdf_path: Path to the PDF file
|
||||||
for line in jsonl_f:
|
page_num: Page number to process (1-indexed)
|
||||||
data = json.loads(line)
|
temperature: Temperature parameter for the model
|
||||||
|
|
||||||
name = os.path.basename(data["metadata"]["Source-File"])
|
Returns:
|
||||||
|
The extracted text from the page
|
||||||
with open(f"olmocr/bench/sample_data/olmocr/{name.replace('.pdf', '.md')}", "w") as out_f:
|
"""
|
||||||
out_f.write(data["text"])
|
# Ensure global variables are initialized
|
||||||
|
global metrics, tracker
|
||||||
shutil.rmtree(workspace_dir)
|
if 'metrics' not in globals() or metrics is None:
|
||||||
|
metrics = MetricsKeeper(window=60*5)
|
||||||
|
if 'tracker' not in globals() or tracker is None:
|
||||||
|
tracker = WorkerTracker()
|
||||||
|
|
||||||
|
args = Args()
|
||||||
|
semaphore = asyncio.Semaphore(1)
|
||||||
|
|
||||||
|
# Ensure server is running
|
||||||
|
server_task = None
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(sglang_server_ready(), timeout=5)
|
||||||
|
print("Using existing sglang server")
|
||||||
|
except Exception:
|
||||||
|
print("Starting new sglang server")
|
||||||
|
server_task = asyncio.create_task(sglang_server_host(args, semaphore))
|
||||||
|
await sglang_server_ready()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Process the page
|
||||||
|
query = await build_page_query(
|
||||||
|
pdf_path,
|
||||||
|
page_num,
|
||||||
|
args.target_longest_image_dim,
|
||||||
|
args.target_anchor_text_len
|
||||||
|
)
|
||||||
|
query["temperature"] = temperature
|
||||||
|
|
||||||
|
# Make request and get response
|
||||||
|
url = f"http://localhost:{SGLANG_SERVER_PORT}/v1/chat/completions"
|
||||||
|
status_code, response_body = await apost(url, json_data=query)
|
||||||
|
|
||||||
|
if status_code != 200:
|
||||||
|
return f"Error: HTTP status {status_code}"
|
||||||
|
|
||||||
|
# Parse response
|
||||||
|
response_data = json.loads(response_body)
|
||||||
|
content = response_data["choices"][0]["message"]["content"]
|
||||||
|
model_json = json.loads(content)
|
||||||
|
page_response = PageResponse(**model_json)
|
||||||
|
|
||||||
|
# Update metrics
|
||||||
|
metrics.add_metrics(
|
||||||
|
sglang_input_tokens=response_data["usage"].get("prompt_tokens", 0),
|
||||||
|
sglang_output_tokens=response_data["usage"].get("completion_tokens", 0)
|
||||||
|
)
|
||||||
|
|
||||||
|
return page_response.natural_text
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return f"Error: {type(e).__name__} - {str(e)}"
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# We leave the server running for potential reuse
|
||||||
|
# This is more efficient if multiple pages will be processed
|
||||||
|
pass
|
Loading…
x
Reference in New Issue
Block a user