Compressor script

Jake Poznanski 2025-07-14 20:56:51 +00:00
parent 1ede76d0b2
commit 01360ba21d

@@ -17,14 +17,13 @@
 import json
 import os
 import shutil
 import tempfile
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union
 
 import boto3
 import torch
 from llmcompressor import oneshot
 from llmcompressor.modifiers.quantization import QuantizationModifier
-from smart_open import smart_open
-from transformers import AutoTokenizer, Qwen2VLForConditionalGeneration
+from transformers import AutoTokenizer, Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration
 
 from olmocr.s3_utils import parse_s3_path
@@ -74,7 +73,7 @@ def upload_local_to_s3(local_dir: str, bucket: str, prefix: str) -> None:
         print(f"  Uploaded {rel_path}")
 
 
-def load_model_and_tokenizer(source_path: str) -> Tuple[Qwen2VLForConditionalGeneration, AutoTokenizer, Optional[str]]:
+def load_model_and_tokenizer(source_path: str) -> Tuple[Union[Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration], AutoTokenizer, Optional[str]]:
     """Load model and tokenizer from source path (local or S3)."""
     if is_s3_path(source_path):
         # Download from S3 to temporary directory
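Note: the Optional[str] in the return type is the temporary directory created when the checkpoint is downloaded from S3 (None for local paths), so the caller is expected to clean it up. A minimal caller sketch, with an illustrative S3 path:

    # Sketch of a caller; the bucket and key below are illustrative, not from this repo.
    model, tokenizer, temp_source_dir = load_model_and_tokenizer("s3://example-bucket/checkpoints/olmocr-7b")
    try:
        ...  # quantize and save the model here
    finally:
        if temp_source_dir is not None:
            shutil.rmtree(temp_source_dir)  # remove the temporary S3 download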
@@ -86,12 +85,48 @@ def load_model_and_tokenizer(source_path: str) -> Tuple[Qwen2VLForConditionalGen
         model_path = source_path
         temp_dir = None
 
+    # Read config to determine model architecture
+    config_path = os.path.join(model_path, "config.json")
+    with open(config_path, "r") as f:
+        config = json.load(f)
+
+    # Get model name from config
+    model_name = config.get("name_or_path", "")
+
     print(f"Loading model from {model_path}...")
-    model = Qwen2VLForConditionalGeneration.from_pretrained(
-        model_path,
-        device_map="auto",
-        torch_dtype="auto"
-    )
+
+    # Load appropriate model class based on name
+    if "Qwen2.5-VL" in model_name:
+        print("Detected Qwen2.5-VL model")
+        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+            model_path,
+            device_map="auto",
+            torch_dtype="auto"
+        )
+    elif "Qwen2-VL" in model_name:
+        print("Detected Qwen2-VL model")
+        model = Qwen2VLForConditionalGeneration.from_pretrained(
+            model_path,
+            device_map="auto",
+            torch_dtype="auto"
+        )
+    else:
+        # Default to checking architectures list
+        architectures = config.get("architectures", [])
+        if "Qwen2_5_VLForConditionalGeneration" in architectures:
+            print("Detected Qwen2.5-VL model from architectures")
+            model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+                model_path,
+                device_map="auto",
+                torch_dtype="auto"
+            )
+        else:
+            print("Detected Qwen2-VL model from architectures")
+            model = Qwen2VLForConditionalGeneration.from_pretrained(
+                model_path,
+                device_map="auto",
+                torch_dtype="auto"
+            )
 
     print(f"Loading tokenizer from {model_path}...")
     tokenizer = AutoTokenizer.from_pretrained(model_path)
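The detection above consults only two fields of the checkpoint's config.json: name_or_path first, then architectures as a fallback when the name is missing or unrecognized. An illustrative stand-in for a Qwen2.5-VL checkpoint's config (values are examples, not taken from the commit):

    # Illustrative config values only; real checkpoints may differ.
    config = {
        "name_or_path": "Qwen/Qwen2.5-VL-7B-Instruct",
        "architectures": ["Qwen2_5_VLForConditionalGeneration"],
    }
    # First branch keys off the model name...
    assert "Qwen2.5-VL" in config.get("name_or_path", "")
    # ...and the fallback keys off the architectures list.
    assert "Qwen2_5_VLForConditionalGeneration" in config.get("architectures", [])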
@@ -101,11 +136,6 @@ def load_model_and_tokenizer(source_path: str) -> Tuple[Qwen2VLForConditionalGen
 
 
 def compress_checkpoint(source_path: str, dest_path: str) -> None:
     """Compress OlmOCR checkpoint using FP8 quantization."""
-    # First, validate the source checkpoint
-    config_path = os.path.join(source_path, "config.json")
-    if is_s3_path(source_path):
-        config_path = f"{source_path}/config.json"
-
     # Load model and tokenizer
     model, tokenizer, temp_source_dir = load_model_and_tokenizer(source_path)
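The oneshot and QuantizationModifier imports at the top indicate that the remainder of compress_checkpoint (not shown in this diff) applies llm-compressor's FP8 quantization. A sketch of how that step typically looks; the recipe settings, ignore list, and output directory are assumptions rather than code from this commit:

    # Assumed sketch of the FP8 step; settings reflect typical llm-compressor usage, not this diff.
    recipe = QuantizationModifier(
        targets="Linear",                    # quantize Linear layers...
        scheme="FP8_DYNAMIC",                # ...to FP8 weights with dynamic activation scales
        ignore=["lm_head", "re:visual.*"],   # assumed: keep the LM head and vision tower unquantized
    )
    oneshot(model=model, recipe=recipe)

    # Save the compressed weights next to the tokenizer (directory is illustrative).
    save_dir = "/tmp/olmocr-fp8"
    model.save_pretrained(save_dir, save_compressed=True)
    tokenizer.save_pretrained(save_dir)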