fixed lint and style

This commit is contained in:
aman-17 2025-09-26 19:44:14 +00:00
parent 7dc6a4b2a5
commit e1bc7b8861

View File

@ -3,12 +3,11 @@ from io import BytesIO
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor
from qwen_vl_utils import process_vision_info
from transformers import AutoModelForCausalLM, AutoProcessor
from olmocr.data.renderpdf import render_pdf_to_base64png
_model = None
_processor = None
@ -27,23 +26,13 @@ def load_model(model_name: str = "./weights/DotsOCR"):
global _model, _processor
if _model is None or _processor is None:
_model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map="auto",
attn_implementation="flash_attention_2",
low_cpu_mem_usage=True,
trust_remote_code=True
model_name, torch_dtype=torch.bfloat16, device_map="auto", attn_implementation="flash_attention_2", low_cpu_mem_usage=True, trust_remote_code=True
)
_processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
return _model, _processor
def run_dotsocr(
pdf_path: str,
page_num: int = 1,
model_name: str = "./weights/DotsOCR",
target_longest_image_dim: int = 1024
) -> str:
def run_dotsocr(pdf_path: str, page_num: int = 1, model_name: str = "./weights/DotsOCR", target_longest_image_dim: int = 1024) -> str:
"""
Convert page of a PDF file to structured layout information using DotsOCR.
@ -71,25 +60,10 @@ def run_dotsocr(
# Define the prompt for layout extraction
prompt = """Extract the text content from this image."""
messages = [
{
"role": "user",
"content": [
{
"type": "image",
"image": image
},
{"type": "text", "text": prompt}
]
}
]
messages = [{"role": "user", "content": [{"type": "image", "image": image}, {"type": "text", "text": prompt}]}]
# Preparation for inference
text = processor.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
@ -105,17 +79,13 @@ def run_dotsocr(
with torch.no_grad():
generated_ids = model.generate(**inputs, max_new_tokens=4096)
generated_ids_trimmed = [
out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
output_text = processor.batch_decode(
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
output_text = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)
del inputs
del generated_ids
del generated_ids_trimmed
torch.cuda.empty_cache()
return output_text[0] if output_text else ""
return output_text[0] if output_text else ""