Mirror of https://github.com/allenai/olmocr.git, synced 2025-10-11 16:22:29 +00:00
Compressor script
commit 01360ba21d
parent 1ede76d0b2
@@ -17,14 +17,13 @@ import json
 import os
 import shutil
 import tempfile
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union

 import boto3
 import torch
 from llmcompressor import oneshot
 from llmcompressor.modifiers.quantization import QuantizationModifier
 from smart_open import smart_open
-from transformers import AutoTokenizer, Qwen2VLForConditionalGeneration
+from transformers import AutoTokenizer, Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration

 from olmocr.s3_utils import parse_s3_path

@@ -74,7 +73,7 @@ def upload_local_to_s3(local_dir: str, bucket: str, prefix: str) -> None:
             print(f" Uploaded {rel_path}")


-def load_model_and_tokenizer(source_path: str) -> Tuple[Qwen2VLForConditionalGeneration, AutoTokenizer, Optional[str]]:
+def load_model_and_tokenizer(source_path: str) -> Tuple[Union[Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration], AutoTokenizer, Optional[str]]:
     """Load model and tokenizer from source path (local or S3)."""
     if is_s3_path(source_path):
         # Download from S3 to temporary directory
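For context, the third element of the tuple returned by load_model_and_tokenizer is the temporary directory created when the checkpoint is pulled down from S3 (None for local paths). A minimal caller sketch under that assumption — the S3 URI and the cleanup step are illustrative, not part of this commit:

import shutil

# Hypothetical caller: load the checkpoint, use it, then clean up any S3 download.
model, tokenizer, temp_dir = load_model_and_tokenizer("s3://example-bucket/checkpoints/olmocr")
try:
    # The model is either Qwen2VLForConditionalGeneration or Qwen2_5_VLForConditionalGeneration.
    print(type(model).__name__)
finally:
    if temp_dir is not None:
        shutil.rmtree(temp_dir)  # remove the temporary local copy downloaded from S3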
@@ -86,12 +85,48 @@ def load_model_and_tokenizer(source_path: str) -> Tuple[Qwen2VLForConditionalGen
         model_path = source_path
         temp_dir = None

+    # Read config to determine model architecture
+    config_path = os.path.join(model_path, "config.json")
+    with open(config_path, "r") as f:
+        config = json.load(f)
+
+    # Get model name from config
+    model_name = config.get("name_or_path", "")
+
     print(f"Loading model from {model_path}...")
-    model = Qwen2VLForConditionalGeneration.from_pretrained(
-        model_path,
-        device_map="auto",
-        torch_dtype="auto"
-    )
+    # Load appropriate model class based on name
+    if "Qwen2.5-VL" in model_name:
+        print("Detected Qwen2.5-VL model")
+        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+            model_path,
+            device_map="auto",
+            torch_dtype="auto"
+        )
+    elif "Qwen2-VL" in model_name:
+        print("Detected Qwen2-VL model")
+        model = Qwen2VLForConditionalGeneration.from_pretrained(
+            model_path,
+            device_map="auto",
+            torch_dtype="auto"
+        )
+    else:
+        # Default to checking architectures list
+        architectures = config.get("architectures", [])
+        if "Qwen2_5_VLForConditionalGeneration" in architectures:
+            print("Detected Qwen2.5-VL model from architectures")
+            model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+                model_path,
+                device_map="auto",
+                torch_dtype="auto"
+            )
+        else:
+            print("Detected Qwen2-VL model from architectures")
+            model = Qwen2VLForConditionalGeneration.from_pretrained(
+                model_path,
+                device_map="auto",
+                torch_dtype="auto"
+            )

     print(f"Loading tokenizer from {model_path}...")
     tokenizer = AutoTokenizer.from_pretrained(model_path)
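To make the new branching order explicit: the config's name_or_path string is checked first, and the architectures list is consulted only as a fallback. The standalone helper below is an illustrative restatement of that logic; the function name and string-return convention are invented for this sketch, not part of the commit:

import json
import os

def detect_model_class_name(model_path: str) -> str:
    """Return the transformers class name the script would load (illustrative only)."""
    with open(os.path.join(model_path, "config.json"), "r") as f:
        config = json.load(f)
    name = config.get("name_or_path", "")
    if "Qwen2.5-VL" in name:
        return "Qwen2_5_VLForConditionalGeneration"
    if "Qwen2-VL" in name:
        return "Qwen2VLForConditionalGeneration"
    # Fallback: inspect the architectures list recorded in config.json
    if "Qwen2_5_VLForConditionalGeneration" in config.get("architectures", []):
        return "Qwen2_5_VLForConditionalGeneration"
    return "Qwen2VLForConditionalGeneration"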
@@ -101,11 +136,6 @@ def load_model_and_tokenizer(source_path: str) -> Tuple[Qwen2VLForConditionalGen

 def compress_checkpoint(source_path: str, dest_path: str) -> None:
     """Compress OlmOCR checkpoint using FP8 quantization."""
-    # First, validate the source checkpoint
-    config_path = os.path.join(source_path, "config.json")
-    if is_s3_path(source_path):
-        config_path = f"{source_path}/config.json"
-
     # Load model and tokenizer
     model, tokenizer, temp_source_dir = load_model_and_tokenizer(source_path)

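The imports added at the top of the file (oneshot and QuantizationModifier from llmcompressor) are what the FP8 step relies on, but the quantization call itself lies outside the hunks shown here. The following is only a sketch of how those two pieces are commonly combined for FP8 quantization; the scheme name, ignore list, and save directory are assumptions, not taken from this commit:

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier

# Assumed recipe: dynamic FP8 quantization of Linear layers, leaving the LM head
# (and typically the vision tower) unquantized.
recipe = QuantizationModifier(targets="Linear", scheme="FP8_DYNAMIC", ignore=["lm_head"])

# Apply the recipe in one shot; the dynamic FP8 scheme needs no calibration data.
oneshot(model=model, recipe=recipe)

# Save the compressed checkpoint; the output path is a placeholder.
model.save_pretrained("/tmp/olmocr-fp8", save_compressed=True)
tokenizer.save_pretrained("/tmp/olmocr-fp8")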