diff --git a/olmocr/bench/runners/run_dotsocr.py b/olmocr/bench/runners/run_dotsocr.py
index 123665c..3f381a2 100644
--- a/olmocr/bench/runners/run_dotsocr.py
+++ b/olmocr/bench/runners/run_dotsocr.py
@@ -1,5 +1,4 @@
 import base64
-import os
 from io import BytesIO
 
 import torch
@@ -9,17 +8,12 @@ from qwen_vl_utils import process_vision_info
 
 from olmocr.data.renderpdf import render_pdf_to_base64png
 
-# Set LOCAL_RANK as required by DotsOCR
-if "LOCAL_RANK" not in os.environ:
-    os.environ["LOCAL_RANK"] = "0"
 
-# Global cache for the model and processor.
-_device = "cuda" if torch.cuda.is_available() else "cpu"
 _model = None
 _processor = None
 
 
-def load_model(model_name: str = "rednote-hilab/dots.ocr"):
+def load_model(model_name: str = "./weights/DotsOCR"):
     """
     Load the DotsOCR model and processor if they haven't been loaded already.
 
@@ -32,12 +26,12 @@ def load_model(model_name: str = "rednote-hilab/dots.ocr"):
     """
     global _model, _processor
     if _model is None or _processor is None:
-        # Load model following the official repo pattern
         _model = AutoModelForCausalLM.from_pretrained(
             model_name,
-            attn_implementation="flash_attention_2",
             torch_dtype=torch.bfloat16,
             device_map="auto",
+            attn_implementation="flash_attention_2",
+            low_cpu_mem_usage=True,
             trust_remote_code=True
         )
         _processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
@@ -47,7 +41,7 @@ def load_model(model_name: str = "rednote-hilab/dots.ocr"):
 def run_dotsocr(
     pdf_path: str,
     page_num: int = 1,
-    model_name: str = "rednote-hilab/dots.ocr",
+    model_name: str = "./weights/DotsOCR",
     target_longest_image_dim: int = 1024
 ) -> str:
     """
@@ -59,7 +53,7 @@ def run_dotsocr(
     Args:
         pdf_path (str): The local path to the PDF file.
         page_num (int): The page number to process (default: 1).
-        model_name (str): Hugging Face model name (default: "rednote-hilab/dots.ocr").
+        model_name (str): Hugging Face model name (default: "./weights/DotsOCR").
         target_longest_image_dim (int): Target dimension for the longest side of the image (default: 1024).
 
     Returns:
@@ -75,24 +69,7 @@ def run_dotsocr(
     image = Image.open(BytesIO(base64.b64decode(image_base64)))
 
     # Define the prompt for layout extraction
-    prompt = """Please output the layout information from the PDF image, including each layout element's bbox, its category, and the corresponding text content within the bbox.
-
-1. Bbox format: [x1, y1, x2, y2]
-
-2. Layout Categories: The possible categories are ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title'].
-
-3. Text Extraction & Formatting Rules:
-    - Picture: For the 'Picture' category, the text field should be omitted.
-    - Formula: Format its text as LaTeX.
-    - Table: Format its text as HTML.
-    - All Others (Text, Title, etc.): Format their text as Markdown.
-
-4. Constraints:
-    - The output text must be the original text from the image, with no translation.
-    - All layout elements must be sorted according to human reading order.
-
-5. Final Output: The entire output must be a single JSON object.
-"""
+    prompt = """Extract the text content from this image."""
 
     messages = [
         {
@@ -126,8 +103,8 @@ def run_dotsocr(
     inputs = inputs.to("cuda")
 
     # Inference: Generation of the output
-    with torch.no_grad():
-        generated_ids = model.generate(**inputs, max_new_tokens=24000)
+    # with torch.no_grad():
+    generated_ids = model.generate(**inputs, max_new_tokens=4096)
     generated_ids_trimmed = [
         out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
     ]
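
For reference, a minimal sketch of how the changed runner might be exercised after this diff. The PDF path is a placeholder; it assumes the DotsOCR weights have been placed at `./weights/DotsOCR` (the new default) and that a CUDA device is available, since the runner moves inputs to `"cuda"` unconditionally:

```python
from olmocr.bench.runners.run_dotsocr import run_dotsocr

# "sample.pdf" is a placeholder path; substitute any local PDF.
# The first call loads the model into the module-level cache
# (_model / _processor); subsequent calls reuse it.
text = run_dotsocr("sample.pdf", page_num=1)
print(text)
```

Note that with the simplified prompt, the output is plain extracted text rather than the JSON layout object produced by the original layout-extraction prompt.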