Mirror of https://github.com/allenai/olmocr.git, synced 2025-10-17 11:12:33 +00:00

Cleaning up compress and prepare checkpoint scripts

parent a5a0cd7478
commit 1ede76d0b2

olmocr/train/compress_checkpoint.py | 190 + (new executable file)

@@ -0,0 +1,190 @@
#!/usr/bin/env python3
"""
Compresses OlmOCR checkpoints using FP8 quantization:
1. Loads model from source (local or S3)
2. Applies FP8 dynamic quantization
3. Saves compressed model to destination (local or S3)

Usage:
    python compress_checkpoint.py <source_path> <destination_path>

    source_path: Path to checkpoint (local or S3)
    destination_path: Where to save compressed checkpoint (local or S3)
"""

import argparse
import json
import os
import shutil
import tempfile
from typing import Optional, Tuple

import boto3
import torch
from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
from smart_open import smart_open
from transformers import AutoTokenizer, Qwen2VLForConditionalGeneration

from olmocr.s3_utils import parse_s3_path

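# Module-level S3 client shared by the download/upload helpers below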
s3_client = boto3.client("s3")


def is_s3_path(path: str) -> bool:
    """Check if a path is an S3 path."""
    return path.startswith("s3://")


def download_s3_to_local(bucket: str, prefix: str, local_dir: str) -> None:
    """Download all files from S3 prefix to local directory."""
    os.makedirs(local_dir, exist_ok=True)

    paginator = s3_client.get_paginator("list_objects_v2")
    pages = paginator.paginate(Bucket=bucket, Prefix=prefix)

    print(f"Downloading checkpoint from s3://{bucket}/{prefix} to {local_dir}...")

    for page in pages:
        for obj in page.get("Contents", []):
            key = obj["Key"]
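            # Skip S3 "directory" placeholder keys; there is nothing to download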
            if key.endswith("/"):
                continue

            rel_path = os.path.relpath(key, prefix)
            local_path = os.path.join(local_dir, rel_path)

            os.makedirs(os.path.dirname(local_path), exist_ok=True)
            s3_client.download_file(bucket, key, local_path)
            print(f" Downloaded {rel_path}")


def upload_local_to_s3(local_dir: str, bucket: str, prefix: str) -> None:
    """Upload all files from local directory to S3."""
    print(f"Uploading compressed checkpoint from {local_dir} to s3://{bucket}/{prefix}...")

    for root, _, files in os.walk(local_dir):
        for file in files:
            local_path = os.path.join(root, file)
            rel_path = os.path.relpath(local_path, local_dir)
            s3_key = os.path.join(prefix, rel_path)

            s3_client.upload_file(local_path, bucket, s3_key)
            print(f" Uploaded {rel_path}")


def load_model_and_tokenizer(source_path: str) -> Tuple[Qwen2VLForConditionalGeneration, AutoTokenizer, Optional[str]]:
    """Load model and tokenizer from source path (local or S3).

    Returns the model, the tokenizer, and the temporary download directory
    (None for local sources), which the caller is responsible for deleting.
    """
    if is_s3_path(source_path):
        # Download from S3 to temporary directory
        temp_dir = tempfile.mkdtemp()
        bucket, prefix = parse_s3_path(source_path)
        download_s3_to_local(bucket, prefix, temp_dir)
        model_path = temp_dir
    else:
        model_path = source_path
        temp_dir = None

    print(f"Loading model from {model_path}...")
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        model_path,
        device_map="auto",
        torch_dtype="auto",
    )

    print(f"Loading tokenizer from {model_path}...")
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    return model, tokenizer, temp_dir


def compress_checkpoint(source_path: str, dest_path: str) -> None:
    """Compress OlmOCR checkpoint using FP8 quantization."""
    # First, validate the source checkpoint by checking that its config.json
    # exists and is parseable
    config_path = os.path.join(source_path, "config.json")
    if is_s3_path(source_path):
        config_path = f"{source_path}/config.json"
    with smart_open(config_path, "r") as f:
        json.load(f)

    # Load model and tokenizer
    model, tokenizer, temp_source_dir = load_model_and_tokenizer(source_path)

    try:
        # Configure FP8 dynamic quantization
        print("\nApplying FP8 dynamic quantization...")
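        # Note: FP8_DYNAMIC is data-free: weights get static per-channel FP8
        # scales at compression time, while activation scales are computed per
        # token at inference, so no calibration dataset is needed. Excluding
        # lm_head and the vision tower (visual.*) follows the usual Qwen2-VL
        # recipe, keeping the most quantization-sensitive layers at full
        # precision.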
        recipe = QuantizationModifier(
            targets="Linear",
            scheme="FP8_DYNAMIC",
            ignore=["re:.*lm_head", "re:visual.*"],
        )

        # Apply the quantization
        oneshot(model=model, recipe=recipe)
        print("✓ Quantization completed successfully")

        # Save the compressed model
        if is_s3_path(dest_path):
            # Save to temporary directory first, then upload to S3
            with tempfile.TemporaryDirectory() as temp_dest_dir:
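                # (save_pretrained writes to a local directory, so for S3
                # destinations the files are staged here and then uploaded)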
print(f"\nSaving compressed model to temporary directory...")
|
||||
model.save_pretrained(temp_dest_dir)
|
||||
tokenizer.save_pretrained(temp_dest_dir)
|
||||
|
||||
# Upload to S3
|
||||
bucket, prefix = parse_s3_path(dest_path)
|
||||
upload_local_to_s3(temp_dest_dir, bucket, prefix)
|
||||
else:
|
||||
# Save directly to local destination
|
||||
print(f"\nSaving compressed model to {dest_path}...")
|
||||
os.makedirs(dest_path, exist_ok=True)
|
||||
model.save_pretrained(dest_path)
|
||||
tokenizer.save_pretrained(dest_path)
|
||||
|
||||
print(f"\n✓ Successfully compressed checkpoint and saved to {dest_path}")
|
||||
|
||||
finally:
|
||||
# Clean up temporary source directory if needed
|
||||
if temp_source_dir:
|
||||
print(f"Cleaning up temporary directory {temp_source_dir}...")
|
||||
shutil.rmtree(temp_source_dir)
|
||||
|
||||
# Free up GPU memory
|
||||
del model
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||

def main():
    parser = argparse.ArgumentParser(
        description="Compress OlmOCR checkpoint using FP8 quantization",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Local to local
  python compress_checkpoint.py /path/to/checkpoint /path/to/compressed

  # S3 to S3
  python compress_checkpoint.py s3://bucket/checkpoint s3://bucket/compressed

  # S3 to local
  python compress_checkpoint.py s3://bucket/checkpoint /path/to/compressed

  # Local to S3
  python compress_checkpoint.py /path/to/checkpoint s3://bucket/compressed
""",
    )
    parser.add_argument("source", help="Source checkpoint path (local or S3)")
    parser.add_argument("destination", help="Destination path for compressed checkpoint (local or S3)")

    args = parser.parse_args()

    try:
        compress_checkpoint(args.source, args.destination)
    except Exception as e:
        print(f"\n❌ Error: {e}")
        return 1

    return 0


if __name__ == "__main__":
    exit(main())

@@ -1,28 +0,0 @@
# pip install llmcompressor
from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier
from transformers import AutoTokenizer, Qwen2VLForConditionalGeneration

MODEL_ID = "/home/ubuntu/olmocr/olmOCR-7B-0225-preview"

model = Qwen2VLForConditionalGeneration.from_pretrained(MODEL_ID, device_map="auto", torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Configure the simple PTQ quantization
# recipe = QuantizationModifier(
#     targets="Linear", scheme="FP8_DYNAMIC", ignore=["lm_head"])

# Configure pre-defined qwen2vl recipe
recipe = QuantizationModifier(
    targets="Linear",
    scheme="FP8_DYNAMIC",
    ignore=["re:.*lm_head", "re:visual.*"],
)

# Apply the quantization algorithm.
oneshot(model=model, recipe=recipe)

# Save the model.
SAVE_DIR = MODEL_ID.split("/")[1] + "-FP8-Dynamic-Recipe"
model.save_pretrained(SAVE_DIR)
tokenizer.save_pretrained(SAVE_DIR)
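
For reference, a checkpoint compressed this way carries compressed-tensors FP8 metadata in its saved config, so inference engines that understand that format can usually load it directly. A minimal sketch using vLLM (an assumption, not part of this commit; the checkpoint path is illustrative):

    from vllm import LLM, SamplingParams

    # vLLM picks up the FP8 compressed-tensors scheme from the model config;
    # no extra quantization flag is needed. Path below is a placeholder.
    llm = LLM(model="/path/to/compressed")
    outputs = llm.generate(["Hello"], SamplingParams(max_tokens=16))
    print(outputs[0].outputs[0].text)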