From 01360ba21dfcd82067dbe05fa2e4606d3a576540 Mon Sep 17 00:00:00 2001
From: Jake Poznanski
Date: Mon, 14 Jul 2025 20:56:51 +0000
Subject: [PATCH] Compressor script

---
 olmocr/train/compress_checkpoint.py | 58 ++++++++++++++++++++++-------
 1 file changed, 44 insertions(+), 14 deletions(-)

diff --git a/olmocr/train/compress_checkpoint.py b/olmocr/train/compress_checkpoint.py
index b7defc4..e76ba30 100755
--- a/olmocr/train/compress_checkpoint.py
+++ b/olmocr/train/compress_checkpoint.py
@@ -17,14 +17,13 @@ import json
 import os
 import shutil
 import tempfile
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union
 
 import boto3
 import torch
 from llmcompressor import oneshot
 from llmcompressor.modifiers.quantization import QuantizationModifier
-from smart_open import smart_open
-from transformers import AutoTokenizer, Qwen2VLForConditionalGeneration
+from transformers import AutoTokenizer, Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration
 
 from olmocr.s3_utils import parse_s3_path
 
@@ -74,7 +73,7 @@ def upload_local_to_s3(local_dir: str, bucket: str, prefix: str) -> None:
             print(f" Uploaded {rel_path}")
 
 
-def load_model_and_tokenizer(source_path: str) -> Tuple[Qwen2VLForConditionalGeneration, AutoTokenizer, Optional[str]]:
+def load_model_and_tokenizer(source_path: str) -> Tuple[Union[Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration], AutoTokenizer, Optional[str]]:
     """Load model and tokenizer from source path (local or S3)."""
     if is_s3_path(source_path):
         # Download from S3 to temporary directory
@@ -86,12 +85,48 @@ def load_model_and_tokenizer(source_path: str) -> Tuple[Qwen2VLForConditionalGen
         model_path = source_path
         temp_dir = None
 
+    # Read config to determine model architecture
+    config_path = os.path.join(model_path, "config.json")
+    with open(config_path, "r") as f:
+        config = json.load(f)
+
+    # Get model name from config (transformers serializes it as "_name_or_path")
+    model_name = config.get("_name_or_path", config.get("name_or_path", ""))
+
     print(f"Loading model from {model_path}...")
-    model = Qwen2VLForConditionalGeneration.from_pretrained(
-        model_path,
-        device_map="auto",
-        torch_dtype="auto"
-    )
+
+    # Load appropriate model class based on name
+    if "Qwen2.5-VL" in model_name:
+        print("Detected Qwen2.5-VL model")
+        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+            model_path,
+            device_map="auto",
+            torch_dtype="auto"
+        )
+    elif "Qwen2-VL" in model_name:
+        print("Detected Qwen2-VL model")
+        model = Qwen2VLForConditionalGeneration.from_pretrained(
+            model_path,
+            device_map="auto",
+            torch_dtype="auto"
+        )
+    else:
+        # Default to checking architectures list
+        architectures = config.get("architectures", [])
+        if "Qwen2_5_VLForConditionalGeneration" in architectures:
+            print("Detected Qwen2.5-VL model from architectures")
+            model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+                model_path,
+                device_map="auto",
+                torch_dtype="auto"
+            )
+        else:
+            print("Detected Qwen2-VL model from architectures")
+            model = Qwen2VLForConditionalGeneration.from_pretrained(
+                model_path,
+                device_map="auto",
+                torch_dtype="auto"
+            )
 
     print(f"Loading tokenizer from {model_path}...")
     tokenizer = AutoTokenizer.from_pretrained(model_path)
@@ -101,11 +136,6 @@ def load_model_and_tokenizer(source_path: str) -> Tuple[Qwen2VLForConditionalGen
 
 
 def compress_checkpoint(source_path: str, dest_path: str) -> None:
     """Compress OlmOCR checkpoint using FP8 quantization."""
-    # First, validate the source checkpoint
-    config_path = os.path.join(source_path, "config.json")
-    if is_s3_path(source_path):
-        config_path = f"{source_path}/config.json"
-
     # Load model and tokenizer
     model, tokenizer, temp_source_dir = load_model_and_tokenizer(source_path)
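
Note: for reference, below is a minimal standalone sketch of the config-based
model-class dispatch this patch adds. The checkpoint path is hypothetical, and
the sketch assumes a transformers release that ships
Qwen2_5_VLForConditionalGeneration; the fallback order mirrors the patch
itself, not any official transformers API.

    import json
    import os

    from transformers import (
        Qwen2VLForConditionalGeneration,
        Qwen2_5_VLForConditionalGeneration,
    )

    def pick_model_class(model_path: str):
        """Mirror the patch: prefer the saved model name, then architectures."""
        with open(os.path.join(model_path, "config.json"), "r") as f:
            config = json.load(f)

        # transformers serializes the source model name under "_name_or_path"
        model_name = config.get("_name_or_path", "")
        if "Qwen2.5-VL" in model_name:
            return Qwen2_5_VLForConditionalGeneration
        if "Qwen2-VL" in model_name:
            return Qwen2VLForConditionalGeneration

        # Fall back to the architectures list when the name is uninformative
        if "Qwen2_5_VLForConditionalGeneration" in config.get("architectures", []):
            return Qwen2_5_VLForConditionalGeneration
        return Qwen2VLForConditionalGeneration

    if __name__ == "__main__":
        # Hypothetical local checkpoint directory containing config.json
        checkpoint = "/path/to/checkpoint"
        model_cls = pick_model_class(checkpoint)
        model = model_cls.from_pretrained(checkpoint, device_map="auto", torch_dtype="auto")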