Mirror of https://github.com/allenai/olmocr.git
fixed lint and style

commit e1bc7b8861 (parent 7dc6a4b2a5)
@@ -3,12 +3,11 @@ from io import BytesIO
 import torch
 from PIL import Image
-from transformers import AutoModelForCausalLM, AutoProcessor
 from qwen_vl_utils import process_vision_info
+from transformers import AutoModelForCausalLM, AutoProcessor
 
-
 from olmocr.data.renderpdf import render_pdf_to_base64png
 
 
 _model = None
 _processor = None
 
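For orientation, the import block after this hunk reads as below, reconstructed from the diff and its hunk context (`from io import BytesIO`); the lint pass only moved the `transformers` import so that third-party imports are alphabetized, isort-style, and dropped a stray blank line.

    from io import BytesIO

    import torch
    from PIL import Image
    from qwen_vl_utils import process_vision_info
    from transformers import AutoModelForCausalLM, AutoProcessor

    from olmocr.data.renderpdf import render_pdf_to_base64png

    _model = None
    _processor = None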
@@ -27,23 +26,13 @@ def load_model(model_name: str = "./weights/DotsOCR"):
     global _model, _processor
     if _model is None or _processor is None:
         _model = AutoModelForCausalLM.from_pretrained(
-            model_name,
-            torch_dtype=torch.bfloat16,
-            device_map="auto",
-            attn_implementation="flash_attention_2",
-            low_cpu_mem_usage=True,
-            trust_remote_code=True
+            model_name, torch_dtype=torch.bfloat16, device_map="auto", attn_implementation="flash_attention_2", low_cpu_mem_usage=True, trust_remote_code=True
         )
         _processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
     return _model, _processor
 
 
-def run_dotsocr(
-    pdf_path: str,
-    page_num: int = 1,
-    model_name: str = "./weights/DotsOCR",
-    target_longest_image_dim: int = 1024
-) -> str:
+def run_dotsocr(pdf_path: str, page_num: int = 1, model_name: str = "./weights/DotsOCR", target_longest_image_dim: int = 1024) -> str:
     """
     Convert page of a PDF file to structured layout information using DotsOCR.
 
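The code this hunk reflows caches the model and processor in the module-level globals `_model` and `_processor`, so `load_model` pays the loading cost once and every later call returns the same objects. A minimal usage sketch, assuming the DotsOCR weights sit at the default `./weights/DotsOCR` path and that `sample.pdf` is a stand-in for any local PDF:

    # Illustrative only: load_model() caches in module globals, so the second
    # call returns the identical objects instead of reloading the weights.
    model, processor = load_model()
    model_again, processor_again = load_model()
    assert model is model_again and processor is processor_again

    # One call per page; "sample.pdf" is a hypothetical input file.
    page_text = run_dotsocr("sample.pdf", page_num=1, target_longest_image_dim=1024)
    print(page_text)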
@@ -71,25 +60,10 @@ def run_dotsocr(
     # Define the prompt for layout extraction
     prompt = """Extract the text content from this image."""
 
-    messages = [
-        {
-            "role": "user",
-            "content": [
-                {
-                    "type": "image",
-                    "image": image
-                },
-                {"type": "text", "text": prompt}
-            ]
-        }
-    ]
+    messages = [{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": prompt}]}]
 
     # Preparation for inference
-    text = processor.apply_chat_template(
-        messages,
-        tokenize=False,
-        add_generation_prompt=True
-    )
+    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
 
     image_inputs, video_inputs = process_vision_info(messages)
     inputs = processor(
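This hunk is the standard Qwen-VL preprocessing recipe: build a chat-style message list holding the page image, lay it out as a prompt string with `apply_chat_template`, pull the image back out with `process_vision_info`, and hand both to the processor. A sketch of how the pieces fit together, assuming the earlier part of the function decodes the rendered page the way the imports suggest; the `image`/`prompt` construction before `messages` is a reconstruction, not verbatim repository code:

    import base64
    from io import BytesIO

    from PIL import Image
    from qwen_vl_utils import process_vision_info

    from olmocr.data.renderpdf import render_pdf_to_base64png

    pdf_path, page_num, target_longest_image_dim = "sample.pdf", 1, 1024  # example inputs
    model, processor = load_model()  # cached pair from the function shown above

    # Render one PDF page to a base64 PNG and decode it into a PIL image.
    base64_png = render_pdf_to_base64png(pdf_path, page_num, target_longest_image_dim=target_longest_image_dim)
    image = Image.open(BytesIO(base64.b64decode(base64_png)))

    prompt = """Extract the text content from this image."""
    messages = [{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": prompt}]}]

    # apply_chat_template() renders the message list as the model's prompt string;
    # process_vision_info() extracts the images (and videos, unused here) so the
    # processor can tokenize text and preprocess pixels in one batch.
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = processor(text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt")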
@@ -105,13 +79,9 @@ def run_dotsocr(
     with torch.no_grad():
         generated_ids = model.generate(**inputs, max_new_tokens=4096)
 
-    generated_ids_trimmed = [
-        out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
-    ]
+    generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
 
-    output_text = processor.batch_decode(
-        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
-    )
+    output_text = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)
 
     del inputs
     del generated_ids
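One detail worth spelling out: `model.generate` returns each sequence as the prompt tokens followed by the new tokens, so the trimming comprehension slices off the first `len(in_ids)` positions of every row, leaving only the freshly generated ids for `batch_decode`. A toy illustration with made-up token ids, no model required:

    import torch

    input_ids = torch.tensor([[101, 7592, 2088]])                  # 3 prompt tokens
    generated_ids = torch.tensor([[101, 7592, 2088, 42, 7, 102]])  # prompt + 3 new tokens

    trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(input_ids, generated_ids)]
    print(trimmed)  # [tensor([ 42,   7, 102])] -- only the generated suffix remains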