mirror of
https://github.com/allenai/olmocr.git
synced 2025-10-26 15:44:17 +00:00
fixed dotsocr runner
This commit is contained in:
parent
796c021ab8
commit
4f7623c429
@ -1,4 +1,5 @@
|
||||
import base64
|
||||
import os
|
||||
from io import BytesIO
|
||||
|
||||
import torch
|
||||
@ -8,6 +9,10 @@ from qwen_vl_utils import process_vision_info
|
||||
|
||||
from olmocr.data.renderpdf import render_pdf_to_base64png
|
||||
|
||||
# DotsOCR requires LOCAL_RANK to be present in the environment; default to
# rank 0 for single-process runs (a pre-existing value is left untouched).
os.environ.setdefault("LOCAL_RANK", "0")
|
||||
|
||||
# Module-level cache for the model and processor so they are loaded at most
# once per process. Inference runs on the GPU when one is visible, otherwise
# falls back to CPU.
if torch.cuda.is_available():
    _device = "cuda"
else:
    _device = "cpu"
_model = None
|
||||
@ -27,6 +32,7 @@ def load_model(model_name: str = "rednote-hilab/dots.ocr"):
|
||||
"""
|
||||
global _model, _processor
|
||||
if _model is None or _processor is None:
|
||||
# Load model following the official repo pattern
|
||||
_model = AutoModelForCausalLM.from_pretrained(
|
||||
model_name,
|
||||
attn_implementation="flash_attention_2",
|
||||
@ -117,7 +123,7 @@ def run_dotsocr(
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
inputs = inputs.to(_device)
|
||||
inputs = inputs.to("cuda")
|
||||
|
||||
# Inference: Generation of the output
|
||||
with torch.no_grad():
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user