diff --git a/olmocr/bench/runners/run_dotsocr.py b/olmocr/bench/runners/run_dotsocr.py
index ff8b410..24764a5 100644
--- a/olmocr/bench/runners/run_dotsocr.py
+++ b/olmocr/bench/runners/run_dotsocr.py
@@ -3,12 +3,11 @@
 from io import BytesIO
 
 import torch
 from PIL import Image
-from transformers import AutoModelForCausalLM, AutoProcessor
 from qwen_vl_utils import process_vision_info
+from transformers import AutoModelForCausalLM, AutoProcessor
 
 from olmocr.data.renderpdf import render_pdf_to_base64png
 
-
 _model = None
 _processor = None
@@ -27,23 +26,13 @@ def load_model(model_name: str = "./weights/DotsOCR"):
     global _model, _processor
     if _model is None or _processor is None:
         _model = AutoModelForCausalLM.from_pretrained(
-            model_name,
-            torch_dtype=torch.bfloat16,
-            device_map="auto",
-            attn_implementation="flash_attention_2",
-            low_cpu_mem_usage=True,
-            trust_remote_code=True
+            model_name, torch_dtype=torch.bfloat16, device_map="auto", attn_implementation="flash_attention_2", low_cpu_mem_usage=True, trust_remote_code=True
         )
         _processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
     return _model, _processor
 
 
-def run_dotsocr(
-    pdf_path: str,
-    page_num: int = 1,
-    model_name: str = "./weights/DotsOCR",
-    target_longest_image_dim: int = 1024
-) -> str:
+def run_dotsocr(pdf_path: str, page_num: int = 1, model_name: str = "./weights/DotsOCR", target_longest_image_dim: int = 1024) -> str:
     """
     Convert page of a PDF file to structured layout information using DotsOCR.
 
@@ -71,25 +60,10 @@ def run_dotsocr(
     # Define the prompt for layout extraction
     prompt = """Extract the text content from this image."""
 
-    messages = [
-        {
-            "role": "user",
-            "content": [
-                {
-                    "type": "image",
-                    "image": image
-                },
-                {"type": "text", "text": prompt}
-            ]
-        }
-    ]
+    messages = [{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": prompt}]}]
 
     # Preparation for inference
-    text = processor.apply_chat_template(
-        messages,
-        tokenize=False,
-        add_generation_prompt=True
-    )
+    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
     image_inputs, video_inputs = process_vision_info(messages)
 
     inputs = processor(
@@ -105,17 +79,13 @@ def run_dotsocr(
     with torch.no_grad():
         generated_ids = model.generate(**inputs, max_new_tokens=4096)
 
-    generated_ids_trimmed = [
-        out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
-    ]
+    generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
 
-    output_text = processor.batch_decode(
-        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
-    )
+    output_text = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)
 
     del inputs
     del generated_ids
     del generated_ids_trimmed
     torch.cuda.empty_cache()
 
-    return output_text[0] if output_text else ""
\ No newline at end of file
+    return output_text[0] if output_text else ""
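
The change is a mechanical formatting pass with no behavioral difference: imports are sorted, multi-line call sites are collapsed onto single lines, and the missing trailing newline at end of file is fixed. For context, a minimal usage sketch of the runner follows, assuming the module path matches the file location shown in the diff header, that DotsOCR weights are already present at the default ./weights/DotsOCR, and that a CUDA device is available (the model loads with device_map="auto" and flash_attention_2); the PDF path is a hypothetical placeholder.

    # Hypothetical invocation of the DotsOCR benchmark runner.
    # "sample.pdf" is a placeholder; any single- or multi-page PDF works,
    # with page_num selecting the 1-indexed page to transcribe.
    from olmocr.bench.runners.run_dotsocr import run_dotsocr

    page_text = run_dotsocr("sample.pdf", page_num=1)
    print(page_text)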