mirror of https://github.com/allenai/olmocr.git (synced 2025-09-15)
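"""Quick single-page inference check for a fine-tuned Qwen2-VL checkpoint.

Renders one page of a test PDF to a base64 PNG, builds the fine-tuning prompt from the
page's anchor text, runs generation, and prints the decoded output.
"""
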
import os
import json
import base64
import logging
import time
from io import BytesIO
from PIL import Image
from functools import partial
from logging import Logger
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Optional
from tqdm import tqdm

import accelerate
import torch
import torch.distributed
from datasets.utils import disable_progress_bars
from datasets.utils.logging import set_verbosity
from peft import LoraConfig, get_peft_model  # pyright: ignore
from transformers import (
    AutoModelForCausalLM,
    Trainer,
    TrainerCallback,
    TrainingArguments,
    Qwen2VLForConditionalGeneration,
    AutoProcessor,
    Qwen2VLConfig,
)

from pdelfin.data.renderpdf import render_pdf_to_base64png
from pdelfin.prompts.anchor import get_anchor_text
from pdelfin.prompts.prompts import build_finetuning_prompt

from pdelfin.train.dataprep import prepare_data_for_qwen2_inference


def build_page_query(local_pdf_path: str, page: int) -> dict:
    # Render the requested page to a base64-encoded PNG and pull the page's anchor
    # text, which is used to build the fine-tuning prompt.
    image_base64 = render_pdf_to_base64png(local_pdf_path, page, 1024)
    anchor_text = get_anchor_text(local_pdf_path, page, pdf_engine="pdfreport")

    return {
        "input_prompt_text": build_finetuning_prompt(anchor_text),
        "input_prompt_image_base64": image_base64,
    }


@torch.no_grad()
def run_inference(model_name: str):
    config = Qwen2VLConfig.from_pretrained(model_name)
    processor = AutoProcessor.from_pretrained(model_name)

    # If it doesn't load, change the type:mrope key to "default"
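    # A minimal sketch of the workaround above, assuming the checkpoint's config.json
    # carries a rope_scaling dict whose "type" entry is "mrope" (uncomment only if
    # loading fails on that key):
    # if getattr(config, "rope_scaling", None) and config.rope_scaling.get("type") == "mrope":
    #     config.rope_scaling["type"] = "default"
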
    model = Qwen2VLForConditionalGeneration.from_pretrained(model_name, device_map="auto", config=config)
    model.eval()

    # Build a single-page query from one of the test PDFs shipped with the repo.
    query = build_page_query(os.path.join(os.path.dirname(__file__), "..", "..", "tests", "gnarly_pdfs", "overrun_on_pg8.pdf"), 8)

    inputs = prepare_data_for_qwen2_inference(query, processor)

    print(inputs)

    # The prepared inputs are numpy arrays; add a batch dimension and move them to
    # the GPU before generation.
    inputs = {
        x: torch.from_numpy(y).unsqueeze(0).to("cuda")
        for (x, y) in inputs.items()
    }

    output_ids = model.generate(**inputs, temperature=0.8, do_sample=True, max_new_tokens=1500)

    # Strip the prompt tokens so only the newly generated continuation is decoded.
    generated_ids = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(inputs["input_ids"], output_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )
    print(output_text)


def main():
    run_inference(model_name="/root/model")


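# Note: "/root/model" above is the hard-coded path to the fine-tuned checkpoint this
# script expects; as a sketch, any Qwen2-VL-compatible local directory or Hugging Face
# hub id (e.g. the hypothetical substitute run_inference(model_name="Qwen/Qwen2-VL-2B-Instruct"))
# should exercise the same code path.
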
if __name__ == "__main__":
    main()