diff --git a/olmocr/train/compress_checkpoint.py b/olmocr/train/compress_checkpoint.py
index 97e1852..cd1f85c 100755
--- a/olmocr/train/compress_checkpoint.py
+++ b/olmocr/train/compress_checkpoint.py
@@ -31,6 +31,7 @@ import torch
 from llmcompressor import oneshot
 from PIL import Image
 from transformers import AutoTokenizer, Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration, AutoProcessor
+from qwen_vl_utils import process_vision_info
 from olmocr.s3_utils import parse_s3_path
 from olmocr.pipeline import build_page_query
@@ -81,36 +82,24 @@ async def prepare_calibration_dataset(pdf_paths: List[str], processor) -> List[dict]:
         # Extract the image and text from the query
         messages = query["messages"]
-        if messages and len(messages) > 0:
-            content = messages[0]["content"]
-
-            # Extract image data and text
-            image_data = None
-            text = None
-
-            for item in content:
-                if item["type"] == "image_url":
-                    image_data = item["image_url"]["url"]
-                elif item["type"] == "text":
-                    text = item["text"]
-
-            if image_data and text:
-                # Convert base64 image to PIL Image
-                # Remove data URL prefix
-                base64_str = image_data.split(",")[1] if "," in image_data else image_data
-                image_bytes = base64.b64decode(base64_str)
-                image = Image.open(BytesIO(image_bytes))
-
-                # Process with the model's processor
-                inputs = processor(
-                    text=[text],
-                    images=[image],
-                    padding=False,
-                    truncation=True,
-                    max_length=4096
-                )
-
-                dataset.append(inputs)
+
+        # Render the chat template to a prompt string and let qwen_vl_utils
+        # pull the image/video inputs out of the same messages.
+        text = processor.apply_chat_template(
+            messages, tokenize=False, add_generation_prompt=True
+        )
+        image_inputs, video_inputs = process_vision_info(messages)
+
+        # Tokenize the prompt together with the extracted vision inputs.
+        inputs = processor(
+            text=[text],
+            images=image_inputs,
+            videos=video_inputs,
+            padding=False,
+            max_length=8192,
+            truncation=True,
+        )
+        dataset.append(inputs)
 
     return dataset
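
For context, a minimal self-contained sketch of the new preparation path is below. The checkpoint name, prompt text, and message shape are illustrative (the real messages come from `build_page_query`); this only demonstrates the `apply_chat_template` + `process_vision_info` flow the diff switches to, using the Qwen-style `{"type": "image", "image": ...}` content item that qwen_vl_utils documents.

```python
import base64
from io import BytesIO

from PIL import Image
from qwen_vl_utils import process_vision_info
from transformers import AutoProcessor

# Any Qwen2-VL-family checkpoint works; the 2B model is just a small example.
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")

# Stand-in for a rendered PDF page: a blank image encoded as a base64 data URL,
# the same transport format the removed parsing code handled by hand.
buf = BytesIO()
Image.new("RGB", (64, 64), "white").save(buf, format="PNG")
data_url = "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode()

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": data_url},
            {"type": "text", "text": "Transcribe the text on this page."},
        ],
    }
]

# Same three steps as the new loop body in the diff.
text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
    text=[text],
    images=image_inputs,
    videos=video_inputs,
    padding=False,
    max_length=8192,
    truncation=True,
)
print(inputs.keys())  # e.g. input_ids, attention_mask, pixel_values, image_grid_thw
```

`process_vision_info` resolves each image entry (local path, URL, or base64 data URL) to a PIL image itself, which is why the hand-rolled base64 decoding removed above is no longer needed.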