mirror of https://github.com/allenai/olmocr.git, synced 2025-10-26 23:53:31 +00:00

fixed dotsocr runner

This commit is contained in:
  parent 796c021ab8
  commit 4f7623c429
@@ -1,4 +1,5 @@
 import base64
+import os
 from io import BytesIO
 
 import torch
@@ -8,6 +9,10 @@ from qwen_vl_utils import process_vision_info
 
 from olmocr.data.renderpdf import render_pdf_to_base64png
 
+# Set LOCAL_RANK as required by DotsOCR
+if "LOCAL_RANK" not in os.environ:
+    os.environ["LOCAL_RANK"] = "0"
+
 # Global cache for the model and processor.
 _device = "cuda" if torch.cuda.is_available() else "cpu"
 _model = None
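The hunk above makes sure LOCAL_RANK exists before the model is constructed, since DotsOCR expects the variable to be set. A minimal sketch of the same guard written as a reusable helper; the helper name is illustrative and not part of olmocr:

import os

def ensure_local_rank(default: str = "0") -> None:
    # Launchers such as torchrun already export LOCAL_RANK; setdefault only
    # fills in a value for plain `python script.py` runs and never overrides
    # a launcher-provided rank. (Hypothetical helper, not in the olmocr repo.)
    os.environ.setdefault("LOCAL_RANK", default)

ensure_local_rank()  # call before load_model() so the variable exists at load time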
@@ -27,6 +32,7 @@ def load_model(model_name: str = "rednote-hilab/dots.ocr"):
     """
     global _model, _processor
     if _model is None or _processor is None:
+        # Load model following the official repo pattern
         _model = AutoModelForCausalLM.from_pretrained(
             model_name,
             attn_implementation="flash_attention_2",
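For context, a sketch of the lazy-loading pattern this hunk sits in: a module-level cache plus AutoModelForCausalLM/AutoProcessor. Only the model name and attn_implementation="flash_attention_2" come from the diff; the dtype, device_map, and trust_remote_code kwargs are assumptions based on the published dots.ocr example, not on this commit:

import torch
from transformers import AutoModelForCausalLM, AutoProcessor

_model = None
_processor = None

def load_model(model_name: str = "rednote-hilab/dots.ocr"):
    # Load the model and processor once and reuse them on later calls.
    global _model, _processor
    if _model is None or _processor is None:
        _model = AutoModelForCausalLM.from_pretrained(
            model_name,
            attn_implementation="flash_attention_2",  # shown in the hunk above
            torch_dtype=torch.bfloat16,               # assumed, per the model card
            device_map="auto",                        # assumed
            trust_remote_code=True,                   # assumed; dots.ocr ships custom code
        )
        _processor = AutoProcessor.from_pretrained(
            model_name,
            trust_remote_code=True,                   # assumed
        )
    return _model, _processor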
@@ -117,7 +123,7 @@ def run_dotsocr(
         return_tensors="pt",
     )
 
-    inputs = inputs.to(_device)
+    inputs = inputs.to("cuda")
 
     # Inference: Generation of the output
     with torch.no_grad():
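The changed line pins the processed batch to CUDA, consistent with the flash_attention_2 load path above. A sketch of the surrounding generation step in the style of the qwen_vl_utils examples; the function name, messages structure, and token budget are illustrative, not taken from this diff:

import torch
from qwen_vl_utils import process_vision_info

def generate_ocr(model, processor, messages, max_new_tokens: int = 1024) -> str:
    # Build the chat prompt and extract image/video tensors from the messages.
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)

    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to("cuda")  # matches the fixed line above

    # Inference: generate without tracking gradients.
    with torch.no_grad():
        generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)

    # Decode only the newly generated tokens, dropping the prompt.
    trimmed = [out[len(inp):] for inp, out in zip(inputs.input_ids, generated_ids)]
    return processor.batch_decode(trimmed, skip_special_tokens=True)[0]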