diff --git a/olmocr/bench/runners/run_dotsocr.py b/olmocr/bench/runners/run_dotsocr.py index 911f0e5..123665c 100644 --- a/olmocr/bench/runners/run_dotsocr.py +++ b/olmocr/bench/runners/run_dotsocr.py @@ -1,4 +1,5 @@ import base64 +import os from io import BytesIO import torch @@ -8,6 +9,10 @@ from qwen_vl_utils import process_vision_info from olmocr.data.renderpdf import render_pdf_to_base64png +# Set LOCAL_RANK as required by DotsOCR +if "LOCAL_RANK" not in os.environ: + os.environ["LOCAL_RANK"] = "0" + # Global cache for the model and processor. _device = "cuda" if torch.cuda.is_available() else "cpu" _model = None @@ -27,6 +32,7 @@ def load_model(model_name: str = "rednote-hilab/dots.ocr"): """ global _model, _processor if _model is None or _processor is None: + # Load model following the official repo pattern _model = AutoModelForCausalLM.from_pretrained( model_name, attn_implementation="flash_attention_2", @@ -117,7 +123,7 @@ def run_dotsocr( return_tensors="pt", ) - inputs = inputs.to(_device) + inputs = inputs.to("cuda") # Inference: Generation of the output with torch.no_grad():