mirror of https://github.com/allenai/olmocr.git (synced 2025-10-11 16:22:29 +00:00)

fixed lint and style

commit e1bc7b8861, parent 7dc6a4b2a5
@@ -3,12 +3,11 @@ from io import BytesIO
-
 import torch
 from PIL import Image
-from transformers import AutoModelForCausalLM, AutoProcessor
 from qwen_vl_utils import process_vision_info
+from transformers import AutoModelForCausalLM, AutoProcessor

 from olmocr.data.renderpdf import render_pdf_to_base64png


 _model = None
 _processor = None

@@ -27,23 +26,13 @@ def load_model(model_name: str = "./weights/DotsOCR"):
     global _model, _processor
     if _model is None or _processor is None:
         _model = AutoModelForCausalLM.from_pretrained(
-            model_name,
-            torch_dtype=torch.bfloat16,
-            device_map="auto",
-            attn_implementation="flash_attention_2",
-            low_cpu_mem_usage=True,
-            trust_remote_code=True
+            model_name, torch_dtype=torch.bfloat16, device_map="auto", attn_implementation="flash_attention_2", low_cpu_mem_usage=True, trust_remote_code=True
         )
         _processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
     return _model, _processor


-def run_dotsocr(
-    pdf_path: str,
-    page_num: int = 1,
-    model_name: str = "./weights/DotsOCR",
-    target_longest_image_dim: int = 1024
-) -> str:
+def run_dotsocr(pdf_path: str, page_num: int = 1, model_name: str = "./weights/DotsOCR", target_longest_image_dim: int = 1024) -> str:
     """
     Convert page of a PDF file to structured layout information using DotsOCR.

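Read as continuous code, the lazy loader after this commit is sketched below. It is assembled from the two hunks above; the function's docstring and anything outside the diff context are omitted, so treat it as a reconstruction rather than the file verbatim.

import torch
from transformers import AutoModelForCausalLM, AutoProcessor

# Module-level cache so repeated calls reuse the same weights and processor.
_model = None
_processor = None


def load_model(model_name: str = "./weights/DotsOCR"):
    # Load the DotsOCR weights once; later calls return the cached pair.
    global _model, _processor
    if _model is None or _processor is None:
        _model = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=torch.bfloat16, device_map="auto", attn_implementation="flash_attention_2", low_cpu_mem_usage=True, trust_remote_code=True
        )
        _processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
    return _model, _processor

Note that attn_implementation="flash_attention_2" requires the flash-attn package at runtime and device_map="auto" relies on accelerate; both are prerequisites of the existing code, not something this commit changes.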
@@ -71,25 +60,10 @@ def run_dotsocr(
     # Define the prompt for layout extraction
     prompt = """Extract the text content from this image."""

-    messages = [
-        {
-            "role": "user",
-            "content": [
-                {
-                    "type": "image",
-                    "image": image
-                },
-                {"type": "text", "text": prompt}
-            ]
-        }
-    ]
+    messages = [{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": prompt}]}]

     # Preparation for inference
-    text = processor.apply_chat_template(
-        messages,
-        tokenize=False,
-        add_generation_prompt=True
-    )
+    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

     image_inputs, video_inputs = process_vision_info(messages)
     inputs = processor(
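Likewise, the inference-preparation block of run_dotsocr now reads as the sketch below. The earlier part of the function (rendering the page with render_pdf_to_base64png and decoding it into a PIL image named image) sits outside the diff context, so the sketch assumes it.

# Assumes: model, processor = load_model(model_name) and `image` is a PIL.Image
# of the rendered PDF page; both happen earlier in run_dotsocr, outside this diff.
prompt = """Extract the text content from this image."""

# Single user turn carrying the page image plus the text instruction.
messages = [{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": prompt}]}]

# Render the chat template to a prompt string, then collect the vision inputs.
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
image_inputs, video_inputs = process_vision_info(messages)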
@@ -105,13 +79,9 @@ def run_dotsocr(
     with torch.no_grad():
         generated_ids = model.generate(**inputs, max_new_tokens=4096)

-    generated_ids_trimmed = [
-        out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
-    ]
+    generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]

-    output_text = processor.batch_decode(
-        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
-    )
+    output_text = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)

     del inputs
     del generated_ids
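Finally, the generation and decoding tail as collapsed by this commit. The return statement falls outside the hunk; presumably the function returns output_text[0], since batch_decode yields a list while run_dotsocr is annotated -> str.

with torch.no_grad():
    # Generate up to 4096 new tokens of OCR output without tracking gradients.
    generated_ids = model.generate(**inputs, max_new_tokens=4096)

# Strip the prompt tokens from each sequence so only newly generated text remains.
generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]

output_text = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)

# Free the large tensors eagerly before returning.
del inputs
del generated_ids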