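"""Smoke-test inference for a fine-tuned Qwen2-VL PDF-to-text model.

Renders a page from one of the test PDFs, builds the fine-tuning prompt from
its anchor text, and prints the model's generated transcription.
"""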
import os

import torch
from transformers import (
    AutoProcessor,
    Qwen2VLConfig,
    Qwen2VLForConditionalGeneration,
)

from pdelfin.data.renderpdf import render_pdf_to_base64png
from pdelfin.prompts.anchor import get_anchor_text
from pdelfin.prompts.prompts import build_finetuning_prompt
from pdelfin.train.dataprep import prepare_data_for_qwen2_inference


def build_page_query(local_pdf_path: str, page: int) -> dict:
    # Render the page to a base64-encoded PNG and extract its anchor text,
    # mirroring the inputs used during fine-tuning.
    image_base64 = render_pdf_to_base64png(local_pdf_path, page, 1024)
    anchor_text = get_anchor_text(local_pdf_path, page, pdf_engine="pdfreport")

    return {
        "input_prompt_text": build_finetuning_prompt(anchor_text),
        "input_prompt_image_base64": image_base64,
    }


@torch.no_grad()
def run_inference(model_name: str):
    config = Qwen2VLConfig.from_pretrained(model_name)
    processor = AutoProcessor.from_pretrained(model_name)

    # If the config fails to load, change the rope_scaling "type": "mrope" entry to "default".
    model = Qwen2VLForConditionalGeneration.from_pretrained(model_name, device_map="auto", config=config)
    model.eval()

    # Build a query from page 8 of one of the gnarly test PDFs.
    query = build_page_query(
        os.path.join(os.path.dirname(__file__), "..", "..", "tests", "gnarly_pdfs", "overrun_on_pg8.pdf"),
        8,
    )

    inputs = prepare_data_for_qwen2_inference(query, processor)
    print(inputs)

    # Add a batch dimension and move each numpy array onto the GPU.
    inputs = {
        key: torch.from_numpy(value).unsqueeze(0).to("cuda")
        for key, value in inputs.items()
    }

    output_ids = model.generate(**inputs, temperature=0.8, do_sample=True, max_new_tokens=1500)

    # Strip the prompt tokens so only the newly generated tokens are decoded.
    generated_ids = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(inputs["input_ids"], output_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )
    print(output_text)


def main():
    run_inference(model_name="/root/model")


if __name__ == "__main__":
    main()