Mirror of https://github.com/allenai/olmocr.git (synced 2025-10-11 08:12:22 +00:00)
Working on calibration set for compressor, seems like qwen2.5 is not working
This commit is contained in:
parent 3f9fc8bd1b
commit b5f480d19d
@@ -2,33 +2,121 @@
 """
 Compresses OlmOCR checkpoints using FP8 quantization:
 1. Loads model from source (local or S3)
-2. Applies FP8 dynamic quantization
+2. Applies FP8 dynamic quantization with optional calibration dataset
 3. Saves compressed model to destination (local or S3)
 
 Usage:
-    python compress_checkpoint.py <source_path> <destination_path> [--recipe <recipe_path>]
+    python compress_checkpoint.py <source_path> <destination_path> --recipe <recipe_path> [--num-calibration-samples N]
 
     source_path: Path to checkpoint (local or S3)
     destination_path: Where to save compressed checkpoint (local or S3)
-    recipe_path: Optional path to quantization config YAML file
+    recipe_path: Path to quantization config YAML file
+    num_calibration_samples: Number of calibration samples to use (default: 100)
 """
 
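For reference, llmcompressor recipes of the kind --recipe expects are small YAML files. A minimal sketch of what such a recipe might contain is below; the scheme and ignore patterns are illustrative assumptions, not the recipe actually used with this commit:

    # fp8_recipe.yaml -- hypothetical example, not part of this commit
    quant_stage:
      quant_modifiers:
        QuantizationModifier:
          targets: ["Linear"]
          scheme: "FP8_DYNAMIC"
          ignore: ["lm_head", "re:visual.*"]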
 import argparse
+import asyncio
+import base64
 import json
 import os
+import random
 import shutil
 import tempfile
-from typing import Optional, Tuple, Union
+from io import BytesIO
+from typing import Optional, Tuple, Union, List
 
 import boto3
 import torch
 from llmcompressor import oneshot
-from transformers import AutoTokenizer, Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration
+from PIL import Image
+from transformers import AutoTokenizer, Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration, AutoProcessor
 
 from olmocr.s3_utils import parse_s3_path
+from olmocr.pipeline import build_page_query
 
 
 s3_client = boto3.client("s3")
+CALIBRATION_S3_PATH = "s3://ai2-oe-data/jakep/olmocr/olmOCR-mix-0225/benchmark_set"
 
 
+def download_calibration_pdfs(num_samples: int) -> Tuple[List[str], str]:
+    """Download calibration PDFs from S3; return the local paths and the temp directory holding them."""
+    bucket, prefix = parse_s3_path(CALIBRATION_S3_PATH)
+
+    # Create temporary directory for PDFs
+    temp_dir = tempfile.mkdtemp()
+    print(f"Downloading calibration PDFs to {temp_dir}...")
+
+    # List all PDFs in the calibration dataset
+    paginator = s3_client.get_paginator("list_objects_v2")
+    pages = paginator.paginate(Bucket=bucket, Prefix=prefix)
+
+    pdf_keys = []
+    for page in pages:
+        for obj in page.get("Contents", []):
+            key = obj["Key"]
+            if key.endswith(".pdf"):
+                pdf_keys.append(key)
+
+    # Randomly sample PDFs
+    if len(pdf_keys) > num_samples:
+        pdf_keys = random.sample(pdf_keys, num_samples)
+
+    # Download the PDFs
+    local_paths = []
+    for key in pdf_keys:
+        filename = os.path.basename(key)
+        local_path = os.path.join(temp_dir, filename)
+        s3_client.download_file(bucket, key, local_path)
+        local_paths.append(local_path)
+        print(f"  Downloaded {filename}")
+
+    print(f"Downloaded {len(local_paths)} calibration PDFs")
+    return local_paths, temp_dir
+
+
+async def prepare_calibration_dataset(pdf_paths: List[str], processor) -> List[dict]:
+    """Prepare calibration dataset from PDFs using build_page_query."""
+    dataset = []
+
+    for pdf_path in pdf_paths:
+        # Get first page of each PDF (page 0)
+        query = await build_page_query(pdf_path, page=0, target_longest_image_dim=1024)
+
+        # Extract the image and text from the query
+        messages = query["messages"]
+        if messages and len(messages) > 0:
+            content = messages[0]["content"]
+
+            # Extract image data and text
+            image_data = None
+            text = None
+
+            for item in content:
+                if item["type"] == "image_url":
+                    image_data = item["image_url"]["url"]
+                elif item["type"] == "text":
+                    text = item["text"]
+
+            if image_data and text:
+                # Convert base64 image to PIL Image, stripping any data URL prefix
+                base64_str = image_data.split(",")[1] if "," in image_data else image_data
+                image_bytes = base64.b64decode(base64_str)
+                image = Image.open(BytesIO(image_bytes))
+
+                # Process with the model's processor
+                inputs = processor(
+                    text=[text],
+                    images=[image],
+                    padding=False,
+                    truncation=True,
+                    max_length=4096,
+                )
+
+                dataset.append(inputs)
+
+    return dataset
+
+
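The loop above assumes build_page_query returns an OpenAI-style chat request whose first message carries an image_url part (a base64 data URL) and a text part. A sketch of that assumed shape, inferred from the loop rather than copied from olmocr.pipeline, with an illustrative prompt string:

    query = {
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBORw0K..."}},
                    {"type": "text", "text": "Attached is one page of a document..."},  # illustrative
                ],
            }
        ],
    }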
 def is_s3_path(path: str) -> bool:
@@ -150,7 +238,12 @@ def copy_additional_files(source_path: str, dest_path: str, temp_source_dir: Opt
         shutil.copy2(source_file, dest_file)
 
 
-def compress_checkpoint(source_path: str, dest_path: str, recipe_path: str) -> None:
+def data_collator(batch):
+    """Simple data collator for calibration dataset."""
+    return {key: torch.tensor(value) for key, value in batch[0].items()}
+
+
+def compress_checkpoint(source_path: str, dest_path: str, recipe_path: str, num_calibration_samples: int = 100) -> None:
     """Compress OlmOCR checkpoint using FP8 quantization."""
     # Load model and tokenizer
     model, tokenizer, temp_source_dir = load_model_and_tokenizer(source_path)
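The batch[0] indexing in data_collator works only if the calibration loop feeds single-sample batches, which is the assumption here; each sample is a processor output whose values are nested lists that torch.tensor re-boxes. A minimal illustration with made-up values:

    sample = {"input_ids": [[151644, 872, 198]], "attention_mask": [[1, 1, 1]]}
    collated = data_collator([sample])
    # collated["input_ids"] is tensor([[151644, 872, 198]])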
@@ -162,9 +255,38 @@ def compress_checkpoint(source_path: str, dest_path: str, recipe_path: str) -> N
             print(f"{name}: shape={list(param.shape)}, dtype={param.dtype}")
     print("=========================\n")
 
+    # Prepare calibration dataset if requested
+    dataset = None
+    temp_pdf_dir = None
+
+    if num_calibration_samples > 0:
+        print(f"\nPreparing calibration dataset with {num_calibration_samples} samples...")
+
+        # Load processor for the model
+        processor = AutoProcessor.from_pretrained(source_path if not temp_source_dir else temp_source_dir)
+
+        # Download PDFs
+        pdf_paths, temp_pdf_dir = download_calibration_pdfs(num_calibration_samples)
+
+        # Prepare dataset
+        dataset = asyncio.run(prepare_calibration_dataset(pdf_paths, processor))
+
+        print(f"✓ Prepared {len(dataset)} calibration samples")
+
     # Apply quantization using provided recipe
     print(f"\nApplying quantization using recipe: {recipe_path}")
-    oneshot(model=model, recipe=recipe_path)
+
+    if dataset:
+        oneshot(
+            model=model,
+            recipe=recipe_path,
+            dataset=dataset,
+            num_calibration_samples=len(dataset),
+            data_collator=data_collator,
+        )
+    else:
+        oneshot(model=model, recipe=recipe_path)
 
     print("✓ Quantization completed successfully")
 
     # Save the compressed model
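Note that if the recipe uses a dynamic FP8 scheme, activation scales are computed at inference time, so the calibration samples mainly drive the forward passes rather than the activation statistics; static schemes are where a calibration set matters most. A quick sanity check that layers were actually quantized might look like the sketch below (the weight_scale attribute name follows compressed-tensors conventions and is an assumption, not something this commit checks):

    quantized = [n for n, m in model.named_modules() if hasattr(m, "weight_scale")]
    print(f"{len(quantized)} modules carry FP8 weight scales")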
@@ -199,6 +321,11 @@ def compress_checkpoint(source_path: str, dest_path: str, recipe_path: str) -> N
         print(f"Cleaning up temporary directory {temp_source_dir}...")
         shutil.rmtree(temp_source_dir)
 
+    # Clean up temporary PDF directory if needed
+    if temp_pdf_dir:
+        print(f"Cleaning up temporary PDF directory {temp_pdf_dir}...")
+        shutil.rmtree(temp_pdf_dir)
+
     # Free up GPU memory
     del model
     torch.cuda.empty_cache()
@@ -226,11 +353,13 @@ Examples:
     parser.add_argument("source", help="Source checkpoint path (local or S3)")
     parser.add_argument("destination", help="Destination path for compressed checkpoint (local or S3)")
     parser.add_argument("--recipe", required=True, help="Path to quantization recipe YAML file")
+    parser.add_argument("--num-calibration-samples", type=int, default=100,
+                        help="Number of calibration samples to use from benchmark set (default: 100, set to 0 to disable)")
 
     args = parser.parse_args()
 
     try:
-        compress_checkpoint(args.source, args.destination, args.recipe)
+        compress_checkpoint(args.source, args.destination, args.recipe, args.num_calibration_samples)
     except Exception as e:
         print(f"\n❌ Error: {e}")
         return 1
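Taken together with the argument parser above, a typical invocation would look like this (paths and filenames are hypothetical):

    python compress_checkpoint.py s3://my-bucket/checkpoints/olmocr-7b ./olmocr-7b-fp8 \
        --recipe fp8_recipe.yaml --num-calibration-samples 50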