Compressor script

Jake Poznanski 2025-07-14 20:56:51 +00:00
parent 1ede76d0b2
commit 01360ba21d

@@ -17,14 +17,13 @@ import json
 import os
 import shutil
 import tempfile
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union
 import boto3
 import torch
 from llmcompressor import oneshot
 from llmcompressor.modifiers.quantization import QuantizationModifier
 from smart_open import smart_open
-from transformers import AutoTokenizer, Qwen2VLForConditionalGeneration
+from transformers import AutoTokenizer, Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration
 from olmocr.s3_utils import parse_s3_path
@@ -74,7 +73,7 @@ def upload_local_to_s3(local_dir: str, bucket: str, prefix: str) -> None:
         print(f"  Uploaded {rel_path}")
-def load_model_and_tokenizer(source_path: str) -> Tuple[Qwen2VLForConditionalGeneration, AutoTokenizer, Optional[str]]:
+def load_model_and_tokenizer(source_path: str) -> Tuple[Union[Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration], AutoTokenizer, Optional[str]]:
     """Load model and tokenizer from source path (local or S3)."""
     if is_s3_path(source_path):
         # Download from S3 to temporary directory
@@ -86,12 +85,48 @@ def load_model_and_tokenizer(source_path: str) -> Tuple[Qwen2VLForConditionalGen
         model_path = source_path
         temp_dir = None
 
+    # Read config to determine model architecture
+    config_path = os.path.join(model_path, "config.json")
+    with open(config_path, "r") as f:
+        config = json.load(f)
+
+    # Get model name from config
+    model_name = config.get("name_or_path", "")
+
     print(f"Loading model from {model_path}...")
-    model = Qwen2VLForConditionalGeneration.from_pretrained(
-        model_path,
-        device_map="auto",
-        torch_dtype="auto"
-    )
+
+    # Load appropriate model class based on name
+    if "Qwen2.5-VL" in model_name:
+        print("Detected Qwen2.5-VL model")
+        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+            model_path,
+            device_map="auto",
+            torch_dtype="auto"
+        )
+    elif "Qwen2-VL" in model_name:
+        print("Detected Qwen2-VL model")
+        model = Qwen2VLForConditionalGeneration.from_pretrained(
+            model_path,
+            device_map="auto",
+            torch_dtype="auto"
+        )
+    else:
+        # Default to checking architectures list
+        architectures = config.get("architectures", [])
+        if "Qwen2_5_VLForConditionalGeneration" in architectures:
+            print("Detected Qwen2.5-VL model from architectures")
+            model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+                model_path,
+                device_map="auto",
+                torch_dtype="auto"
+            )
+        else:
+            print("Detected Qwen2-VL model from architectures")
+            model = Qwen2VLForConditionalGeneration.from_pretrained(
+                model_path,
+                device_map="auto",
+                torch_dtype="auto"
+            )
 
     print(f"Loading tokenizer from {model_path}...")
     tokenizer = AutoTokenizer.from_pretrained(model_path)
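
The new dispatch repeats the same from_pretrained call in four branches. An equivalent, more compact formulation would resolve the model class first and call from_pretrained once. A minimal sketch of that alternative, not part of this commit (the helper name pick_model_class is hypothetical; it uses the same two signals the commit checks, name_or_path first, then architectures):

import json
import os
from typing import Type, Union

from transformers import Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration

QwenVLClass = Type[Union[Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration]]

# Architecture strings as they appear in config.json, mapped to model classes.
_ARCH_TO_CLASS = {
    "Qwen2_5_VLForConditionalGeneration": Qwen2_5_VLForConditionalGeneration,
    "Qwen2VLForConditionalGeneration": Qwen2VLForConditionalGeneration,
}

def pick_model_class(model_path: str) -> QwenVLClass:
    """Hypothetical helper: choose the model class from the checkpoint config."""
    with open(os.path.join(model_path, "config.json"), "r") as f:
        config = json.load(f)
    model_name = config.get("name_or_path", "")
    if "Qwen2.5-VL" in model_name:
        return Qwen2_5_VLForConditionalGeneration
    if "Qwen2-VL" in model_name:
        return Qwen2VLForConditionalGeneration
    for arch in config.get("architectures", []):
        if arch in _ARCH_TO_CLASS:
            return _ARCH_TO_CLASS[arch]
    # Same default the commit falls back to when nothing matches
    return Qwen2VLForConditionalGeneration

# Usage: model = pick_model_class(model_path).from_pretrained(
#     model_path, device_map="auto", torch_dtype="auto")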
@@ -101,11 +136,6 @@ def load_model_and_tokenizer(source_path: str) -> Tuple[Qwen2VLForConditionalGen
 
 def compress_checkpoint(source_path: str, dest_path: str) -> None:
     """Compress OlmOCR checkpoint using FP8 quantization."""
-    # First, validate the source checkpoint
-    config_path = os.path.join(source_path, "config.json")
-    if is_s3_path(source_path):
-        config_path = f"{source_path}/config.json"
-
     # Load model and tokenizer
     model, tokenizer, temp_source_dir = load_model_and_tokenizer(source_path)
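
The diff does not show the quantization step itself, but the script's unchanged imports (oneshot, QuantizationModifier) point to llmcompressor's one-shot FP8 flow. A hedged sketch of that pattern, following llmcompressor's documented FP8_DYNAMIC usage rather than this file's actual body (the scheme, the ignore list, and the output directory are assumptions):

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier

SAVE_DIR = "olmocr-fp8"  # hypothetical output directory

# Dynamic FP8 quantization of Linear layers; lm_head is commonly left
# unquantized to preserve output quality (assumed here, not confirmed by the diff).
recipe = QuantizationModifier(targets="Linear", scheme="FP8_DYNAMIC", ignore=["lm_head"])

# model and tokenizer come from load_model_and_tokenizer(source_path) above.
oneshot(model=model, recipe=recipe)
model.save_pretrained(SAVE_DIR, save_compressed=True)
tokenizer.save_pretrained(SAVE_DIR)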