Cleaning up compress and prepare checkpoint scripts

Jake Poznanski 2025-07-14 20:36:20 +00:00
parent a5a0cd7478
commit 1ede76d0b2
3 changed files with 190 additions and 28 deletions

compress_checkpoint.py · View File

@@ -0,0 +1,190 @@
#!/usr/bin/env python3
"""
Compresses OlmOCR checkpoints using FP8 quantization:
1. Loads the model from the source (local or S3)
2. Applies FP8 dynamic quantization
3. Saves the compressed model to the destination (local or S3)

Usage:
    python compress_checkpoint.py <source_path> <destination_path>

    source_path: Path to checkpoint (local or S3)
    destination_path: Where to save compressed checkpoint (local or S3)
"""
import argparse
import json
import os
import shutil
import tempfile
from typing import Optional, Tuple

import boto3
import torch
from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
from smart_open import smart_open
from transformers import AutoTokenizer, Qwen2VLForConditionalGeneration

from olmocr.s3_utils import parse_s3_path

s3_client = boto3.client("s3")


def is_s3_path(path: str) -> bool:
    """Check if a path is an S3 path."""
    return path.startswith("s3://")


def download_s3_to_local(bucket: str, prefix: str, local_dir: str) -> None:
    """Download all files from an S3 prefix to a local directory."""
    os.makedirs(local_dir, exist_ok=True)
    paginator = s3_client.get_paginator("list_objects_v2")
    pages = paginator.paginate(Bucket=bucket, Prefix=prefix)

    print(f"Downloading checkpoint from s3://{bucket}/{prefix} to {local_dir}...")
    for page in pages:
        for obj in page.get("Contents", []):
            key = obj["Key"]
            if key.endswith("/"):
                continue
            rel_path = os.path.relpath(key, prefix)
            local_path = os.path.join(local_dir, rel_path)
            os.makedirs(os.path.dirname(local_path), exist_ok=True)
            s3_client.download_file(bucket, key, local_path)
            print(f"  Downloaded {rel_path}")


def upload_local_to_s3(local_dir: str, bucket: str, prefix: str) -> None:
    """Upload all files from a local directory to S3."""
    print(f"Uploading compressed checkpoint from {local_dir} to s3://{bucket}/{prefix}...")
    for root, _, files in os.walk(local_dir):
        for file in files:
            local_path = os.path.join(root, file)
            rel_path = os.path.relpath(local_path, local_dir)
            s3_key = os.path.join(prefix, rel_path)
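            # Note: os.path.join yields "/" separators on POSIX hosts; building
            # the key with "/".join(...) would be safer if this ever runs on Windows.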
            s3_client.upload_file(local_path, bucket, s3_key)
            print(f"  Uploaded {rel_path}")


def load_model_and_tokenizer(source_path: str) -> Tuple[Qwen2VLForConditionalGeneration, AutoTokenizer, Optional[str]]:
    """Load model and tokenizer from source path (local or S3)."""
    if is_s3_path(source_path):
        # Download from S3 to a temporary directory
        temp_dir = tempfile.mkdtemp()
        bucket, prefix = parse_s3_path(source_path)
        download_s3_to_local(bucket, prefix, temp_dir)
        model_path = temp_dir
    else:
        model_path = source_path
        temp_dir = None

    print(f"Loading model from {model_path}...")
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        model_path,
        device_map="auto",
        torch_dtype="auto",
    )
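    # device_map="auto" shards the model across available GPUs via accelerate,
    # and torch_dtype="auto" keeps the dtype recorded in the checkpoint config.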
print(f"Loading tokenizer from {model_path}...")
tokenizer = AutoTokenizer.from_pretrained(model_path)
return model, tokenizer, temp_dir


def compress_checkpoint(source_path: str, dest_path: str) -> None:
    """Compress OlmOCR checkpoint using FP8 quantization."""
    # First, validate the source checkpoint by making sure its config.json loads
    config_path = os.path.join(source_path, "config.json")
    if is_s3_path(source_path):
        config_path = f"{source_path}/config.json"
    with smart_open(config_path, "r") as f:
        json.load(f)

    # Load model and tokenizer
    model, tokenizer, temp_source_dir = load_model_and_tokenizer(source_path)

    try:
        # Configure FP8 dynamic quantization
        print("\nApplying FP8 dynamic quantization...")
        recipe = QuantizationModifier(
            targets="Linear",
            scheme="FP8_DYNAMIC",
            ignore=["re:.*lm_head", "re:visual.*"],
        )
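        # FP8_DYNAMIC applies static per-channel FP8 scales to the Linear
        # weights and dynamic per-token scales to activations, so no
        # calibration dataset is needed. The ignore patterns leave lm_head and
        # the Qwen2-VL vision tower ("visual.*") at their original precision.
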
        # Apply the quantization
        oneshot(model=model, recipe=recipe)
        print("✓ Quantization completed successfully")

        # Save the compressed model
        if is_s3_path(dest_path):
            # Save to a temporary directory first, then upload to S3
            with tempfile.TemporaryDirectory() as temp_dest_dir:
                print("\nSaving compressed model to temporary directory...")
                model.save_pretrained(temp_dest_dir)
                tokenizer.save_pretrained(temp_dest_dir)

                # Upload to S3
                bucket, prefix = parse_s3_path(dest_path)
                upload_local_to_s3(temp_dest_dir, bucket, prefix)
        else:
            # Save directly to the local destination
            print(f"\nSaving compressed model to {dest_path}...")
            os.makedirs(dest_path, exist_ok=True)
            model.save_pretrained(dest_path)
            tokenizer.save_pretrained(dest_path)

        print(f"\n✓ Successfully compressed checkpoint and saved to {dest_path}")
    finally:
        # Clean up the temporary source directory if needed
        if temp_source_dir:
            print(f"Cleaning up temporary directory {temp_source_dir}...")
            shutil.rmtree(temp_source_dir)

        # Free up GPU memory
        del model
        torch.cuda.empty_cache()


def main():
    parser = argparse.ArgumentParser(
        description="Compress OlmOCR checkpoint using FP8 quantization",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Local to local
  python compress_checkpoint.py /path/to/checkpoint /path/to/compressed

  # S3 to S3
  python compress_checkpoint.py s3://bucket/checkpoint s3://bucket/compressed

  # S3 to local
  python compress_checkpoint.py s3://bucket/checkpoint /path/to/compressed

  # Local to S3
  python compress_checkpoint.py /path/to/checkpoint s3://bucket/compressed
        """,
    )
    parser.add_argument("source", help="Source checkpoint path (local or S3)")
    parser.add_argument("destination", help="Destination path for compressed checkpoint (local or S3)")
    args = parser.parse_args()

    try:
        compress_checkpoint(args.source, args.destination)
    except Exception as e:
        print(f"\n❌ Error: {e}")
        return 1

    return 0


if __name__ == "__main__":
    exit(main())
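
A quick way to sanity-check the output (a sketch, not part of this commit): llmcompressor records its settings in a quantization_config block inside the saved config.json, which can be read back with the same smart_open helper the script already uses. The S3 path below is a placeholder.

# Sanity-check sketch: inspect the quantization metadata that llmcompressor
# writes into the compressed checkpoint's config.json (placeholder path).
import json

from smart_open import smart_open

with smart_open("s3://bucket/compressed/config.json", "r") as f:
    config = json.load(f)

print(config["quantization_config"]["quant_method"])  # expected: "compressed-tensors"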

View File

@@ -1,28 +0,0 @@
# pip install llmcompressor
from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
from transformers import AutoTokenizer, Qwen2VLForConditionalGeneration

MODEL_ID = "/home/ubuntu/olmocr/olmOCR-7B-0225-preview"

model = Qwen2VLForConditionalGeneration.from_pretrained(MODEL_ID, device_map="auto", torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Configure the simple PTQ quantization
# recipe = QuantizationModifier(
#     targets="Linear", scheme="FP8_DYNAMIC", ignore=["lm_head"])

# Configure pre-defined qwen2vl recipe
recipe = QuantizationModifier(
    targets="Linear",
    scheme="FP8_DYNAMIC",
    ignore=["re:.*lm_head", "re:visual.*"],
)

# Apply the quantization algorithm.
oneshot(model=model, recipe=recipe)

# Save the model, named after the final path component of the checkpoint.
SAVE_DIR = MODEL_ID.split("/")[-1] + "-FP8-Dynamic-Recipe"
model.save_pretrained(SAVE_DIR)
tokenizer.save_pretrained(SAVE_DIR)
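
For reference, a checkpoint compressed this way loads for inference without extra flags, since the quantization config travels with it (a minimal sketch, assuming vLLM is installed; the model path is the SAVE_DIR produced above).

# Minimal inference sketch (assumes vLLM; it detects the compressed-tensors
# FP8 config from config.json, so no quantization flag is needed).
from vllm import LLM, SamplingParams

llm = LLM(model="olmOCR-7B-0225-preview-FP8-Dynamic-Recipe")
outputs = llm.generate(["Hello"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)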