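"""Smoke-test inference script: load a fine-tuned Qwen2-VL checkpoint, render a
single PDF page, build its anchor-text prompt, and print the decoded output."""
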
import os

import torch
from transformers import (
    AutoProcessor,
    Qwen2VLConfig,
    Qwen2VLForConditionalGeneration,
)

from pdelfin.data.renderpdf import render_pdf_to_base64png
from pdelfin.prompts.anchor import get_anchor_text
from pdelfin.prompts.prompts import build_finetuning_prompt
from pdelfin.train.dataprep import prepare_data_for_qwen2_inference


def build_page_query(local_pdf_path: str, page: int) -> dict:
    # Render the page to a base64-encoded PNG at a 1024 px target size, and
    # extract the anchor text that grounds the prompt in the PDF's text layer.
    image_base64 = render_pdf_to_base64png(local_pdf_path, page, 1024)
    anchor_text = get_anchor_text(local_pdf_path, page, pdf_engine="pdfreport")

    return {
        "input_prompt_text": build_finetuning_prompt(anchor_text),
        "input_prompt_image_base64": image_base64,
    }

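# For illustration only: a hypothetical call like build_page_query("doc.pdf", 1)
# returns a dict of the shape
#     {"input_prompt_text": "<prompt built from the anchor text>",
#      "input_prompt_image_base64": "iVBORw0KGgo..."}
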
@torch.no_grad()
def run_inference(model_name: str):
    config = Qwen2VLConfig.from_pretrained(model_name)
    processor = AutoProcessor.from_pretrained(model_name)

    # If the checkpoint doesn't load, change the rope_scaling "type": "mrope"
    # key in its config to "default".
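    # A sketch of that workaround (assuming the config carries a rope_scaling
    # dict, as Qwen2-VL checkpoints typically do):
    #
    #     if getattr(config, "rope_scaling", None) and config.rope_scaling.get("type") == "mrope":
    #         config.rope_scaling["type"] = "default"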

    model = Qwen2VLForConditionalGeneration.from_pretrained(model_name, device_map="auto", config=config)
    model.eval()

    # Run one page of a known-tricky test PDF through the model end to end.
    query = build_page_query(
        os.path.join(os.path.dirname(__file__), "..", "..", "tests", "gnarly_pdfs", "overrun_on_pg8.pdf"),
        8,
    )

    inputs = prepare_data_for_qwen2_inference(query, processor)

    print(inputs)

    # The prepared inputs are numpy arrays: wrap each as a tensor, add a batch
    # dimension, and move it onto the GPU.
    inputs = {
        x: torch.from_numpy(y).unsqueeze(0).to("cuda")
        for (x, y) in inputs.items()
    }
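
    # Sampling (do_sample=True, temperature=0.8) varies between runs; passing
    # do_sample=False instead would decode greedily for a repeatable check.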
    output_ids = model.generate(**inputs, temperature=0.8, do_sample=True, max_new_tokens=1500)
    # generate() returns prompt + completion; slice off the prompt tokens so
    # only the newly generated text gets decoded.
    generated_ids = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs["input_ids"], output_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )
    print(output_text)


def main():
    run_inference(model_name="/root/model")


if __name__ == "__main__":
    main()
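

# Usage sketch (assumptions: a fine-tuned Qwen2-VL checkpoint at /root/model,
# a CUDA GPU, and the repo's tests/gnarly_pdfs fixtures two directories up
# from this file):
#
#     python <this_script>.py
#
# The raw input tensors print first, followed by the decoded page text.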