Working on calibration set for compressor; seems like qwen2.5 is not working

Jake Poznanski 2025-07-15 18:59:48 +00:00
parent 3f9fc8bd1b
commit b5f480d19d


@@ -2,33 +2,121 @@
"""
Compresses OlmOCR checkpoints using FP8 quantization:
1. Loads model from source (local or S3)
2. Applies FP8 dynamic quantization
2. Applies FP8 dynamic quantization with optional calibration dataset
3. Saves compressed model to destination (local or S3)
Usage:
python compress_checkpoint.py <source_path> <destination_path> [--recipe <recipe_path>]
python compress_checkpoint.py <source_path> <destination_path> --recipe <recipe_path> [--num-calibration-samples N]
source_path: Path to checkpoint (local or S3)
destination_path: Where to save compressed checkpoint (local or S3)
recipe_path: Optional path to quantization config YAML file
recipe_path: Path to quantization config YAML file
num_calibration_samples: Number of calibration samples to use (default: 100)
"""
import argparse
import asyncio
import base64
import json
import os
import random
import shutil
import tempfile
from typing import Optional, Tuple, Union
from io import BytesIO
from typing import Optional, Tuple, Union, List
import boto3
import torch
from llmcompressor import oneshot
from transformers import AutoTokenizer, Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration
from PIL import Image
from transformers import AutoTokenizer, Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration, AutoProcessor
from olmocr.s3_utils import parse_s3_path
from olmocr.pipeline import build_page_query
s3_client = boto3.client("s3")
CALIBRATION_S3_PATH = "s3://ai2-oe-data/jakep/olmocr/olmOCR-mix-0225/benchmark_set"
def download_calibration_pdfs(num_samples: int) -> Tuple[List[str], str]:
"""Download calibration PDFs from S3; return their local paths and the temp directory holding them."""
bucket, prefix = parse_s3_path(CALIBRATION_S3_PATH)
# Create temporary directory for PDFs
temp_dir = tempfile.mkdtemp()
print(f"Downloading calibration PDFs to {temp_dir}...")
# List all PDFs in the calibration dataset
paginator = s3_client.get_paginator("list_objects_v2")
pages = paginator.paginate(Bucket=bucket, Prefix=prefix)
pdf_keys = []
for page in pages:
for obj in page.get("Contents", []):
key = obj["Key"]
if key.endswith(".pdf"):
pdf_keys.append(key)
# Randomly sample PDFs
if len(pdf_keys) > num_samples:
pdf_keys = random.sample(pdf_keys, num_samples)
# Download the PDFs
local_paths = []
for key in pdf_keys:
filename = os.path.basename(key)
local_path = os.path.join(temp_dir, filename)
s3_client.download_file(bucket, key, local_path)
local_paths.append(local_path)
print(f" Downloaded {filename}")
print(f"Downloaded {len(local_paths)} calibration PDFs")
return local_paths, temp_dir
async def prepare_calibration_dataset(pdf_paths: List[str], processor) -> List[dict]:
"""Prepare calibration dataset from PDFs using build_page_query."""
dataset = []
for pdf_path in pdf_paths:
# Get first page of each PDF (page 0)
query = await build_page_query(pdf_path, page=0, target_longest_image_dim=1024)
# Extract the image and text from the query
messages = query["messages"]
if messages:
content = messages[0]["content"]
# Extract image data and text
image_data = None
text = None
for item in content:
if item["type"] == "image_url":
image_data = item["image_url"]["url"]
elif item["type"] == "text":
text = item["text"]
if image_data and text:
# Convert base64 image to PIL Image
# Remove data URL prefix
base64_str = image_data.split(",")[1] if "," in image_data else image_data
image_bytes = base64.b64decode(base64_str)
image = Image.open(BytesIO(image_bytes))
# Process with the model's processor
inputs = processor(
text=[text],
images=[image],
padding=False,
truncation=True,
max_length=4096
)
dataset.append(inputs)
return dataset
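The extraction loop above assumes build_page_query returns an OpenAI-style chat request. A rough sketch of the assumed structure, inferred from the parsing logic rather than from olmocr.pipeline itself:

# Assumed shape of build_page_query's return value (inferred from the
# parsing above; not a documented contract of olmocr.pipeline):
query = {
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "..."},
                {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}},
            ],
        }
    ],
}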
def is_s3_path(path: str) -> bool:
@@ -150,7 +238,12 @@ def copy_additional_files(source_path: str, dest_path: str, temp_source_dir: Opt
shutil.copy2(source_file, dest_file)
def compress_checkpoint(source_path: str, dest_path: str, recipe_path: str) -> None:
def data_collator(batch):
"""Simple data collator for calibration dataset."""
return {key: torch.tensor(value) for key, value in batch[0].items()}
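Note that the collator only tensorizes batch[0], which assumes oneshot drives calibration with single-sample batches. A hypothetical illustration of that contract (sample values invented):

# oneshot is assumed to call data_collator with a one-element list per step.
sample = {"input_ids": [[1, 2, 3]], "attention_mask": [[1, 1, 1]]}
collated = data_collator([sample])
# collated["input_ids"] == torch.tensor([[1, 2, 3]])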
def compress_checkpoint(source_path: str, dest_path: str, recipe_path: str, num_calibration_samples: int = 100) -> None:
"""Compress OlmOCR checkpoint using FP8 quantization."""
# Load model and tokenizer
model, tokenizer, temp_source_dir = load_model_and_tokenizer(source_path)
@@ -162,9 +255,38 @@ def compress_checkpoint(source_path: str, dest_path: str, recipe_path: str) -> N
print(f"{name}: shape={list(param.shape)}, dtype={param.dtype}")
print("=========================\n")
# Prepare calibration dataset if requested
dataset = None
temp_pdf_dir = None
if num_calibration_samples > 0:
print(f"\nPreparing calibration dataset with {num_calibration_samples} samples...")
# Load processor for the model
processor = AutoProcessor.from_pretrained(temp_source_dir or source_path)
# Download PDFs
pdf_paths, temp_pdf_dir = download_calibration_pdfs(num_calibration_samples)
# Prepare dataset
dataset = asyncio.run(prepare_calibration_dataset(pdf_paths, processor))
print(f"✓ Prepared {len(dataset)} calibration samples")
# Apply quantization using provided recipe
print(f"\nApplying quantization using recipe: {recipe_path}")
oneshot(model=model, recipe=recipe_path)
if dataset:
oneshot(
model=model,
recipe=recipe_path,
dataset=dataset,
num_calibration_samples=len(dataset),
data_collator=data_collator
)
else:
oneshot(model=model, recipe=recipe_path)
print("✓ Quantization completed successfully")
# Save the compressed model
@@ -199,6 +321,11 @@ def compress_checkpoint(source_path: str, dest_path: str, recipe_path: str) -> N
print(f"Cleaning up temporary directory {temp_source_dir}...")
shutil.rmtree(temp_source_dir)
# Clean up temporary PDF directory if needed
if temp_pdf_dir:
print(f"Cleaning up temporary PDF directory {temp_pdf_dir}...")
shutil.rmtree(temp_pdf_dir)
# Free up GPU memory
del model
torch.cuda.empty_cache()
@@ -226,11 +353,13 @@ Examples:
parser.add_argument("source", help="Source checkpoint path (local or S3)")
parser.add_argument("destination", help="Destination path for compressed checkpoint (local or S3)")
parser.add_argument("--recipe", required=True, help="Path to quantization recipe YAML file")
parser.add_argument("--num-calibration-samples", type=int, default=100,
help="Number of calibration samples to use from benchmark set (default: 100, set to 0 to disable)")
args = parser.parse_args()
try:
compress_checkpoint(args.source, args.destination, args.recipe)
compress_checkpoint(args.source, args.destination, args.recipe, args.num_calibration_samples)
except Exception as e:
print(f"\n❌ Error: {e}")
return 1