diff --git a/olmocr/train/compress_checkpoint.py b/olmocr/train/compress_checkpoint.py
index 8ae7514..e2fbffb 100755
--- a/olmocr/train/compress_checkpoint.py
+++ b/olmocr/train/compress_checkpoint.py
@@ -2,33 +2,121 @@
 """
 Compresses OlmOCR checkpoints using FP8 quantization:
 1. Loads model from source (local or S3)
-2. Applies FP8 dynamic quantization
+2. Applies FP8 dynamic quantization with optional calibration dataset
 3. Saves compressed model to destination (local or S3)
 
 Usage:
-    python compress_checkpoint.py <source_path> <destination_path> [--recipe <recipe_path>]
+    python compress_checkpoint.py <source_path> <destination_path> --recipe <recipe_path> [--num-calibration-samples N]
 
     source_path: Path to checkpoint (local or S3)
     destination_path: Where to save compressed checkpoint (local or S3)
-    recipe_path: Optional path to quantization config YAML file
+    recipe_path: Path to quantization config YAML file
+    num_calibration_samples: Number of calibration samples to use (default: 100)
 """
 
 import argparse
+import asyncio
+import base64
 import json
 import os
+import random
 import shutil
 import tempfile
-from typing import Optional, Tuple, Union
+from io import BytesIO
+from typing import Optional, Tuple, Union, List
 
 import boto3
 import torch
 from llmcompressor import oneshot
-from transformers import AutoTokenizer, Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration
+from PIL import Image
+from transformers import AutoTokenizer, Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration, AutoProcessor
 
 from olmocr.s3_utils import parse_s3_path
+from olmocr.pipeline import build_page_query
 
 s3_client = boto3.client("s3")
 
+CALIBRATION_S3_PATH = "s3://ai2-oe-data/jakep/olmocr/olmOCR-mix-0225/benchmark_set"
+
+
+def download_calibration_pdfs(num_samples: int) -> Tuple[List[str], str]:
+    """Download calibration PDFs from S3; return (local paths, temp dir holding them)."""
+    bucket, prefix = parse_s3_path(CALIBRATION_S3_PATH)
+
+    # Create temporary directory for PDFs
+    temp_dir = tempfile.mkdtemp()
+    print(f"Downloading calibration PDFs to {temp_dir}...")
+
+    # List all PDFs in the calibration dataset
+    paginator = s3_client.get_paginator("list_objects_v2")
+    pages = paginator.paginate(Bucket=bucket, Prefix=prefix)
+
+    pdf_keys = []
+    for page in pages:
+        for obj in page.get("Contents", []):
+            key = obj["Key"]
+            if key.endswith(".pdf"):
+                pdf_keys.append(key)
+
+    # Randomly sample PDFs
+    if len(pdf_keys) > num_samples:
+        pdf_keys = random.sample(pdf_keys, num_samples)
+
+    # Download the PDFs
+    local_paths = []
+    for key in pdf_keys:
+        filename = os.path.basename(key)
+        local_path = os.path.join(temp_dir, filename)
+        s3_client.download_file(bucket, key, local_path)
+        local_paths.append(local_path)
+        print(f"  Downloaded {filename}")
+
+    print(f"Downloaded {len(local_paths)} calibration PDFs")
+    return local_paths, temp_dir
+
+
+async def prepare_calibration_dataset(pdf_paths: List[str], processor) -> List[dict]:
+    """Prepare calibration dataset from PDFs using build_page_query."""
+    dataset = []
+
+    for pdf_path in pdf_paths:
+        # Get first page of each PDF (page 0)
+        query = await build_page_query(pdf_path, page=0, target_longest_image_dim=1024)
+
+        # Extract the image and text from the query
+        messages = query["messages"]
+        if messages and len(messages) > 0:
+            content = messages[0]["content"]
+
+            # Extract image data and text
+            image_data = None
+            text = None
+
+            for item in content:
+                if item["type"] == "image_url":
+                    image_data = item["image_url"]["url"]
+                elif item["type"] == "text":
+                    text = item["text"]
+
+            if image_data and text:
+                # Convert base64 image to PIL Image
+                # Remove data URL prefix
+                base64_str = image_data.split(",")[1] if "," in image_data else image_data
+                image_bytes = base64.b64decode(base64_str)
+                image = Image.open(BytesIO(image_bytes))
+
+                # Process with the model's processor
+                inputs = processor(
+                    text=[text],
+                    images=[image],
+                    padding=False,
+                    truncation=True,
+                    max_length=4096
+                )
+
+                dataset.append(inputs)
+
+    return dataset
 
 
 def is_s3_path(path: str) -> bool:
@@ -150,7 +238,12 @@ def copy_additional_files(source_path: str, dest_path: str, temp_source_dir: Opt
         shutil.copy2(source_file, dest_file)
 
 
-def compress_checkpoint(source_path: str, dest_path: str, recipe_path: str) -> None:
+def data_collator(batch):
+    """Simple data collator for calibration dataset."""
+    return {key: torch.tensor(value) for key, value in batch[0].items()}
+
+
+def compress_checkpoint(source_path: str, dest_path: str, recipe_path: str, num_calibration_samples: int = 100) -> None:
     """Compress OlmOCR checkpoint using FP8 quantization."""
     # Load model and tokenizer
     model, tokenizer, temp_source_dir = load_model_and_tokenizer(source_path)
@@ -162,9 +255,38 @@ def compress_checkpoint(source_path: str, dest_path: str, recipe_path: str) -> N
             print(f"{name}: shape={list(param.shape)}, dtype={param.dtype}")
         print("=========================\n")
 
+    # Prepare calibration dataset if requested
+    dataset = None
+    temp_pdf_dir = None
+
+    if num_calibration_samples > 0:
+        print(f"\nPreparing calibration dataset with {num_calibration_samples} samples...")
+
+        # Load processor for the model
+        processor = AutoProcessor.from_pretrained(source_path if not temp_source_dir else temp_source_dir)
+
+        # Download PDFs
+        pdf_paths, temp_pdf_dir = download_calibration_pdfs(num_calibration_samples)
+
+        # Prepare dataset
+        dataset = asyncio.run(prepare_calibration_dataset(pdf_paths, processor))
+
+        print(f"āœ“ Prepared {len(dataset)} calibration samples")
+
     # Apply quantization using provided recipe
     print(f"\nApplying quantization using recipe: {recipe_path}")
-    oneshot(model=model, recipe=recipe_path)
+
+    if dataset:
+        oneshot(
+            model=model,
+            recipe=recipe_path,
+            dataset=dataset,
+            num_calibration_samples=len(dataset),
+            data_collator=data_collator
+        )
+    else:
+        oneshot(model=model, recipe=recipe_path)
+
     print("āœ“ Quantization completed successfully")
 
     # Save the compressed model
@@ -199,6 +321,11 @@
         print(f"Cleaning up temporary directory {temp_source_dir}...")
         shutil.rmtree(temp_source_dir)
 
+    # Clean up temporary PDF directory if needed
+    if temp_pdf_dir:
+        print(f"Cleaning up temporary PDF directory {temp_pdf_dir}...")
+        shutil.rmtree(temp_pdf_dir)
+
     # Free up GPU memory
     del model
     torch.cuda.empty_cache()
@@ -226,11 +353,13 @@ Examples:
     parser.add_argument("source", help="Source checkpoint path (local or S3)")
     parser.add_argument("destination", help="Destination path for compressed checkpoint (local or S3)")
     parser.add_argument("--recipe", required=True, help="Path to quantization recipe YAML file")
+    parser.add_argument("--num-calibration-samples", type=int, default=100,
+                        help="Number of calibration samples to use from benchmark set (default: 100, set to 0 to disable)")
 
     args = parser.parse_args()
 
     try:
-        compress_checkpoint(args.source, args.destination, args.recipe)
+        compress_checkpoint(args.source, args.destination, args.recipe, args.num_calibration_samples)
     except Exception as e:
         print(f"\nāŒ Error: {e}")
         return 1