# olmocr/pdelfin/train/inference.py
# (original listing: 87 lines, 2.4 KiB, Python)

import os
import json
import base64
import logging
import time
from io import BytesIO
from PIL import Image
from functools import partial
from logging import Logger
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Optional
from tqdm import tqdm
import accelerate
import torch
import torch.distributed
from datasets.utils import disable_progress_bars
from datasets.utils.logging import set_verbosity
from peft import LoraConfig, get_peft_model # pyright: ignore
from transformers import (
AutoModelForCausalLM,
Trainer,
TrainerCallback,
TrainingArguments,
Qwen2VLForConditionalGeneration,
AutoProcessor,
Qwen2VLConfig
)
from pdelfin.data.renderpdf import render_pdf_to_base64png
from pdelfin.prompts.anchor import get_anchor_text
from pdelfin.prompts.prompts import build_finetuning_prompt
from pdelfin.train.dataprep import prepare_data_for_qwen2_inference
def build_page_query(
    local_pdf_path: str,
    page: int,
    target_longest_image_dim: int = 1024,
    pdf_engine: str = "pdfreport",
) -> dict:
    """Build the model query payload for a single PDF page.

    The page is rendered to a base64 PNG and anchor text is extracted from it.
    Both are then packaged in the dict shape that this script passes to
    ``prepare_data_for_qwen2_inference``.

    Args:
        local_pdf_path: Path to a PDF on the local filesystem.
        page: Page number forwarded unchanged to both the renderer and the
            anchor-text extractor.
        target_longest_image_dim: Size hint passed as the third argument to
            ``render_pdf_to_base64png`` (1024 by default, as before). It is
            presumably the longest image side in pixels; confirm against
            renderpdf.
        pdf_engine: Engine name forwarded to ``get_anchor_text`` ("pdfreport"
            by default, as before).

    Returns:
        A dict with keys ``"input_prompt_text"`` (the fine-tuning prompt built
        from the anchor text) and ``"input_prompt_image_base64"`` (the
        rendered page).
    """
    # Render first, then extract text: the same call order as the original.
    image_base64 = render_pdf_to_base64png(local_pdf_path, page, target_longest_image_dim)
    anchor_text = get_anchor_text(local_pdf_path, page, pdf_engine=pdf_engine)
    return {
        "input_prompt_text": build_finetuning_prompt(anchor_text),
        "input_prompt_image_base64": image_base64,
    }
@torch.no_grad()
def run_inference(
    model_name: str,
    pdf_path: Optional[str] = None,
    page: int = 8,
    max_new_tokens: int = 1500,
    temperature: float = 0.8,
):
    """Load a Qwen2-VL checkpoint, query it with one PDF page, and print the result.

    Args:
        model_name: Path or hub id passed to ``from_pretrained`` for the config,
            the processor and the model.
        pdf_path: PDF to query. When ``None`` (the default), the
            ``tests/gnarly_pdfs/overrun_on_pg8.pdf`` fixture is used, resolved
            relative to this file.
        page: Page number handed to ``build_page_query`` (default 8, which
            matches the default fixture).
        max_new_tokens: Generation budget passed to ``model.generate``.
        temperature: Sampling temperature. Sampling is always enabled
            (``do_sample=True``), so output can differ between runs.

    Prints the prepared inputs, then the decoded list of generated strings.
    """
    if pdf_path is None:
        pdf_path = os.path.join(
            os.path.dirname(__file__), "..", "..", "tests", "gnarly_pdfs", "overrun_on_pg8.pdf"
        )

    config = Qwen2VLConfig.from_pretrained(model_name)
    processor = AutoProcessor.from_pretrained(model_name)
    # If it doesn't load, change the type:mrope key (in the checkpoint's config) to "default"
    model = Qwen2VLForConditionalGeneration.from_pretrained(model_name, device_map="auto", config=config)
    model.eval()

    query = build_page_query(pdf_path, page)
    inputs = prepare_data_for_qwen2_inference(query, processor)
    print(inputs)

    # device_map="auto" decides at load time where the model lives (GPU(s) or
    # CPU), so send inputs to model.device instead of assuming "cuda".
    # NOTE(review): assumes prepare_data_for_qwen2_inference returns unbatched
    # numpy arrays (hence from_numpy + unsqueeze(0)). Confirm against dataprep.
    inputs = {
        key: torch.from_numpy(value).unsqueeze(0).to(model.device)
        for key, value in inputs.items()
    }

    output_ids = model.generate(
        **inputs, temperature=temperature, do_sample=True, max_new_tokens=max_new_tokens
    )

    # generate() returns prompt + continuation per row; keep only the new tokens.
    # Distinct loop names avoid shadowing the output_ids tensor being iterated.
    generated_ids = [
        row_output[len(row_input):]
        for row_input, row_output in zip(inputs["input_ids"], output_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )
    print(output_text)
def main():
    """Script entry point: run the demo inference against the model at /root/model."""
    checkpoint_dir = "/root/model"
    run_inference(model_name=checkpoint_dir)
# Run the demo only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()