diff --git a/scripts/hf_local_test.py b/scripts/hf_local_test.py
new file mode 100644
index 0000000..65a8cf6
--- /dev/null
+++ b/scripts/hf_local_test.py
@@ -0,0 +1,65 @@
+import torch
+import base64
+import urllib.request
+
+from io import BytesIO
+from PIL import Image
+from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
+
+from olmocr.data.renderpdf import render_pdf_to_base64png
+from olmocr.prompts import build_no_anchoring_v4_yaml_prompt
+
+# Initialize the model
+model = Qwen2_5_VLForConditionalGeneration.from_pretrained("allenai/olmOCR-7B-1025", torch_dtype=torch.bfloat16).eval()
+processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+model.to(device)
+
+# Grab a sample PDF
+urllib.request.urlretrieve("https://olmocr.allenai.org/papers/olmocr.pdf", "./paper.pdf")
+
+# Render page 1 to an image
+image_base64 = render_pdf_to_base64png("./paper.pdf", 1, target_longest_image_dim=1288)
+
+
+# Build the full prompt
+messages = [
+    {
+        "role": "user",
+        "content": [
+            {"type": "text", "text": build_no_anchoring_v4_yaml_prompt()},
+            {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}},
+        ],
+    }
+]
+
+# Apply the chat template and processor
+text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+main_image = Image.open(BytesIO(base64.b64decode(image_base64)))
+
+inputs = processor(
+    text=[text],
+    images=[main_image],
+    padding=True,
+    return_tensors="pt",
+)
+inputs = {key: value.to(device) for (key, value) in inputs.items()}
+
+
+# Generate the output
+output = model.generate(
+    **inputs,
+    temperature=0.1,
+    max_new_tokens=50,
+    num_return_sequences=1,
+    do_sample=True,
+)
+
+# Decode only the newly generated tokens, skipping the prompt
+prompt_length = inputs["input_ids"].shape[1]
+new_tokens = output[:, prompt_length:]
+text_output = processor.tokenizer.batch_decode(
+    new_tokens, skip_special_tokens=True
+)
+
+print(text_output)
\ No newline at end of file