mirror of https://github.com/allenai/olmocr.git
synced 2025-10-11 16:22:29 +00:00

commit b5f480d19d (parent 3f9fc8bd1b)
Working on calibration set for compressor; seems like qwen2.5 is not working
@@ -2,33 +2,121 @@
 """
 Compresses OlmOCR checkpoints using FP8 quantization:
 1. Loads model from source (local or S3)
-2. Applies FP8 dynamic quantization
+2. Applies FP8 dynamic quantization with optional calibration dataset
 3. Saves compressed model to destination (local or S3)
 
 Usage:
-    python compress_checkpoint.py <source_path> <destination_path> [--recipe <recipe_path>]
+    python compress_checkpoint.py <source_path> <destination_path> --recipe <recipe_path> [--num-calibration-samples N]
 
     source_path: Path to checkpoint (local or S3)
    destination_path: Where to save compressed checkpoint (local or S3)
-    recipe_path: Optional path to quantization config YAML file
+    recipe_path: Path to quantization config YAML file
+    num_calibration_samples: Number of calibration samples to use (default: 100)
 """
 
 import argparse
+import asyncio
+import base64
 import json
 import os
+import random
 import shutil
 import tempfile
-from typing import Optional, Tuple, Union
+from io import BytesIO
+from typing import Optional, Tuple, Union, List
 
 import boto3
 import torch
 from llmcompressor import oneshot
-from transformers import AutoTokenizer, Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration
+from PIL import Image
+from transformers import AutoTokenizer, Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration, AutoProcessor
 
 from olmocr.s3_utils import parse_s3_path
+from olmocr.pipeline import build_page_query
 
 
 s3_client = boto3.client("s3")
+CALIBRATION_S3_PATH = "s3://ai2-oe-data/jakep/olmocr/olmOCR-mix-0225/benchmark_set"
 
 
+def download_calibration_pdfs(num_samples: int) -> Tuple[List[str], str]:
+    """Download calibration PDFs from S3; return local paths and the temp directory."""
+    bucket, prefix = parse_s3_path(CALIBRATION_S3_PATH)
+
+    # Create temporary directory for PDFs
+    temp_dir = tempfile.mkdtemp()
+    print(f"Downloading calibration PDFs to {temp_dir}...")
+
+    # List all PDFs in the calibration dataset
+    paginator = s3_client.get_paginator("list_objects_v2")
+    pages = paginator.paginate(Bucket=bucket, Prefix=prefix)
+
+    pdf_keys = []
+    for page in pages:
+        for obj in page.get("Contents", []):
+            key = obj["Key"]
+            if key.endswith(".pdf"):
+                pdf_keys.append(key)
+
+    # Randomly sample PDFs
+    if len(pdf_keys) > num_samples:
+        pdf_keys = random.sample(pdf_keys, num_samples)
+
+    # Download the PDFs
+    local_paths = []
+    for key in pdf_keys:
+        filename = os.path.basename(key)
+        local_path = os.path.join(temp_dir, filename)
+        s3_client.download_file(bucket, key, local_path)
+        local_paths.append(local_path)
+        print(f" Downloaded {filename}")
+
+    print(f"Downloaded {len(local_paths)} calibration PDFs")
+    return local_paths, temp_dir
+
+
+async def prepare_calibration_dataset(pdf_paths: List[str], processor) -> List[dict]:
+    """Prepare calibration dataset from PDFs using build_page_query."""
+    dataset = []
+
+    for pdf_path in pdf_paths:
+        # Get first page of each PDF (page 0)
+        query = await build_page_query(pdf_path, page=0, target_longest_image_dim=1024)
+
+        # Extract the image and text from the query
+        messages = query["messages"]
+        if messages and len(messages) > 0:
+            content = messages[0]["content"]
+
+            # Extract image data and text
+            image_data = None
+            text = None
+
+            for item in content:
+                if item["type"] == "image_url":
+                    image_data = item["image_url"]["url"]
+                elif item["type"] == "text":
+                    text = item["text"]
+
+            if image_data and text:
+                # Convert base64 image to PIL Image
+                # Remove data URL prefix
+                base64_str = image_data.split(",")[1] if "," in image_data else image_data
+                image_bytes = base64.b64decode(base64_str)
+                image = Image.open(BytesIO(image_bytes))
+
+                # Process with the model's processor
+                inputs = processor(
+                    text=[text],
+                    images=[image],
+                    padding=False,
+                    truncation=True,
+                    max_length=4096
+                )
+
+                dataset.append(inputs)
+
+    return dataset
+
+
 def is_s3_path(path: str) -> bool:
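Note: build_page_query returns an OpenAI-style chat payload in which the page image arrives as a base64 data URL; that is why prepare_calibration_dataset decodes the URL back into a PIL image before handing text and image to the processor. A minimal sanity check of the resulting samples, as a sketch (the checkpoint name is only an example, and this assumes the two helpers above are importable):

import asyncio

from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")  # example model; substitute your checkpoint

# Build processor inputs for page 0 of a handful of PDFs.
pdf_paths, temp_pdf_dir = download_calibration_pdfs(num_samples=4)
dataset = asyncio.run(prepare_calibration_dataset(pdf_paths, processor))

# Each sample is a BatchFeature; printing shapes confirms that both token
# ids and image features were produced before oneshot ever runs.
for key, value in dataset[0].items():
    print(key, getattr(value, "shape", type(value)))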
@@ -150,7 +238,12 @@ def copy_additional_files(source_path: str, dest_path: str, temp_source_dir: Opt
         shutil.copy2(source_file, dest_file)
 
 
-def compress_checkpoint(source_path: str, dest_path: str, recipe_path: str) -> None:
+def data_collator(batch):
+    """Simple data collator for calibration dataset."""
+    return {key: torch.tensor(value) for key, value in batch[0].items()}
+
+
+def compress_checkpoint(source_path: str, dest_path: str, recipe_path: str, num_calibration_samples: int = 100) -> None:
     """Compress OlmOCR checkpoint using FP8 quantization."""
     # Load model and tokenizer
     model, tokenizer, temp_source_dir = load_model_and_tokenizer(source_path)
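Note: llmcompressor hands the data_collator a list of examples per calibration batch; this collator assumes a batch size of 1 and simply tensorizes every field of the single example. A self-contained sketch of that behavior with made-up inputs:

import torch

def data_collator(batch):
    # Assumes batch size 1: tensorize each field of the lone example.
    return {key: torch.tensor(value) for key, value in batch[0].items()}

fake_batch = [{"input_ids": [[1, 2, 3]], "attention_mask": [[1, 1, 1]]}]
collated = data_collator(fake_batch)
print(collated["input_ids"].shape)  # torch.Size([1, 3])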
@@ -162,9 +255,38 @@ def compress_checkpoint(source_path: str, dest_path: str, recipe_path: str) -> N
         print(f"{name}: shape={list(param.shape)}, dtype={param.dtype}")
     print("=========================\n")
 
+    # Prepare calibration dataset if requested
+    dataset = None
+    temp_pdf_dir = None
+
+    if num_calibration_samples > 0:
+        print(f"\nPreparing calibration dataset with {num_calibration_samples} samples...")
+
+        # Load processor for the model
+        processor = AutoProcessor.from_pretrained(source_path if not temp_source_dir else temp_source_dir)
+
+        # Download PDFs
+        pdf_paths, temp_pdf_dir = download_calibration_pdfs(num_calibration_samples)
+
+        # Prepare dataset
+        dataset = asyncio.run(prepare_calibration_dataset(pdf_paths, processor))
+
+        print(f"✓ Prepared {len(dataset)} calibration samples")
+
     # Apply quantization using provided recipe
     print(f"\nApplying quantization using recipe: {recipe_path}")
-    oneshot(model=model, recipe=recipe_path)
+    if dataset:
+        oneshot(
+            model=model,
+            recipe=recipe_path,
+            dataset=dataset,
+            num_calibration_samples=len(dataset),
+            data_collator=data_collator
+        )
+    else:
+        oneshot(model=model, recipe=recipe_path)
 
     print("✓ Quantization completed successfully")
 
     # Save the compressed model
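Note: with a pure FP8-dynamic scheme, activation scales are computed at runtime, so the no-dataset branch remains valid; the calibration path is what makes recipes with static activation quantization usable. The recipe file itself is not part of this diff; a plausible minimal shape, written as an inline YAML string in Python (an assumption based on llmcompressor's QuantizationModifier, not the repo's actual recipe), would be:

# Assumed recipe shape; llmcompressor accepts an inline recipe string in
# place of a file path. Ignoring "re:visual.*" keeps the vision tower
# unquantized, a common choice for Qwen2-VL-family models.
FP8_DYNAMIC_RECIPE = """
quant_stage:
  quant_modifiers:
    QuantizationModifier:
      targets: ["Linear"]
      ignore: ["lm_head", "re:visual.*"]
      scheme: "FP8_DYNAMIC"
"""

# oneshot(model=model, recipe=FP8_DYNAMIC_RECIPE)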
@@ -199,6 +321,11 @@ def compress_checkpoint(source_path: str, dest_path: str, recipe_path: str) -> N
         print(f"Cleaning up temporary directory {temp_source_dir}...")
         shutil.rmtree(temp_source_dir)
 
+    # Clean up temporary PDF directory if needed
+    if temp_pdf_dir:
+        print(f"Cleaning up temporary PDF directory {temp_pdf_dir}...")
+        shutil.rmtree(temp_pdf_dir)
+
     # Free up GPU memory
     del model
     torch.cuda.empty_cache()
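Note: the cleanup is manual, so an exception inside oneshot would leave both temporary directories behind. A sketch of a more defensive arrangement (an alternative, not what this commit does; it reuses the names from the function above):

# try/finally guarantees the PDF directory is removed even if quantization
# raises partway through.
temp_pdf_dir = None
try:
    pdf_paths, temp_pdf_dir = download_calibration_pdfs(num_calibration_samples)
    dataset = asyncio.run(prepare_calibration_dataset(pdf_paths, processor))
    oneshot(model=model, recipe=recipe_path, dataset=dataset,
            num_calibration_samples=len(dataset), data_collator=data_collator)
finally:
    if temp_pdf_dir:
        shutil.rmtree(temp_pdf_dir, ignore_errors=True)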
@@ -226,11 +353,13 @@ Examples:
     parser.add_argument("source", help="Source checkpoint path (local or S3)")
     parser.add_argument("destination", help="Destination path for compressed checkpoint (local or S3)")
     parser.add_argument("--recipe", required=True, help="Path to quantization recipe YAML file")
+    parser.add_argument("--num-calibration-samples", type=int, default=100,
+                        help="Number of calibration samples to use from benchmark set (default: 100, set to 0 to disable)")
+
     args = parser.parse_args()
 
     try:
-        compress_checkpoint(args.source, args.destination, args.recipe)
+        compress_checkpoint(args.source, args.destination, args.recipe, args.num_calibration_samples)
     except Exception as e:
         print(f"\n❌ Error: {e}")
         return 1
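Note: the hunk header above shows this argparse block sits under an Examples: epilog in the file. In the same spirit as the module docstring, typical invocations would look like (bucket and file names hypothetical):

    python compress_checkpoint.py s3://my-bucket/checkpoints/olmocr ./olmocr-fp8 --recipe recipes/fp8.yaml
    python compress_checkpoint.py ./olmocr ./olmocr-fp8 --recipe recipes/fp8.yaml --num-calibration-samples 0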