fixed dotsocr runner

This commit is contained in:
aman-17 2025-09-19 16:14:05 -07:00
parent 796c021ab8
commit 4f7623c429

View File

@ -1,4 +1,5 @@
import base64
import os
from io import BytesIO
import torch
@ -8,6 +9,10 @@ from qwen_vl_utils import process_vision_info
from olmocr.data.renderpdf import render_pdf_to_base64png
# DotsOCR requires LOCAL_RANK to be present in the environment.
# setdefault only writes "0" when the variable is missing, so a value
# already exported by the caller or a launcher is left untouched.
os.environ.setdefault("LOCAL_RANK", "0")
# Global cache for the model and processor.
# _device: "cuda" when torch reports a usable CUDA runtime, otherwise "cpu".
_device = "cuda" if torch.cuda.is_available() else "cpu"
# _model starts empty; load_model() checks it for None before loading.
_model = None
@ -27,6 +32,7 @@ def load_model(model_name: str = "rednote-hilab/dots.ocr"):
"""
global _model, _processor
if _model is None or _processor is None:
# Load model following the official repo pattern
_model = AutoModelForCausalLM.from_pretrained(
model_name,
attn_implementation="flash_attention_2",
@ -117,7 +123,7 @@ def run_dotsocr(
return_tensors="pt",
)
inputs = inputs.to(_device)
inputs = inputs.to("cuda")
# Inference: Generation of the output
with torch.no_grad():