Mirror of https://github.com/allenai/olmocr.git (synced 2025-12-16 17:56:25 +00:00)
Compressor script
parent 1ede76d0b2
commit 01360ba21d
@@ -17,14 +17,13 @@ import json
 import os
 import shutil
 import tempfile
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union

 import boto3
 import torch
 from llmcompressor import oneshot
 from llmcompressor.modifiers.quantization import QuantizationModifier
-from smart_open import smart_open
-from transformers import AutoTokenizer, Qwen2VLForConditionalGeneration
+from transformers import AutoTokenizer, Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration

 from olmocr.s3_utils import parse_s3_path

@@ -74,7 +73,7 @@ def upload_local_to_s3(local_dir: str, bucket: str, prefix: str) -> None:
             print(f" Uploaded {rel_path}")


-def load_model_and_tokenizer(source_path: str) -> Tuple[Qwen2VLForConditionalGeneration, AutoTokenizer, Optional[str]]:
+def load_model_and_tokenizer(source_path: str) -> Tuple[Union[Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration], AutoTokenizer, Optional[str]]:
     """Load model and tokenizer from source path (local or S3)."""
     if is_s3_path(source_path):
         # Download from S3 to temporary directory
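Note (not part of the commit): the third element of the widened return type is the path of the temporary directory created when the checkpoint is downloaded from S3, and None for local checkpoints. A minimal caller sketch, assuming the caller is responsible for cleaning up that scratch directory; the try/finally shape and the S3 URI are illustrative, not taken from the script:

import shutil

# load_model_and_tokenizer is the function defined in this script.
model, tokenizer, temp_source_dir = load_model_and_tokenizer("s3://my-bucket/checkpoints/step-1000")  # hypothetical URI
try:
    pass  # quantize / inspect / re-save the model here
finally:
    if temp_source_dir is not None:
        shutil.rmtree(temp_source_dir)  # remove the S3 download scratch dir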
@@ -86,12 +85,48 @@ def load_model_and_tokenizer(source_path: str) -> Tuple[Qwen2VLForConditionalGen
         model_path = source_path
         temp_dir = None

+    # Read config to determine model architecture
+    config_path = os.path.join(model_path, "config.json")
+    with open(config_path, "r") as f:
+        config = json.load(f)
+
+    # Get model name from config
+    model_name = config.get("name_or_path", "")
+
     print(f"Loading model from {model_path}...")
-    model = Qwen2VLForConditionalGeneration.from_pretrained(
-        model_path,
-        device_map="auto",
-        torch_dtype="auto"
-    )
+
+    # Load appropriate model class based on name
+    if "Qwen2.5-VL" in model_name:
+        print("Detected Qwen2.5-VL model")
+        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+            model_path,
+            device_map="auto",
+            torch_dtype="auto"
+        )
+    elif "Qwen2-VL" in model_name:
+        print("Detected Qwen2-VL model")
+        model = Qwen2VLForConditionalGeneration.from_pretrained(
+            model_path,
+            device_map="auto",
+            torch_dtype="auto"
+        )
+    else:
+        # Default to checking architectures list
+        architectures = config.get("architectures", [])
+        if "Qwen2_5_VLForConditionalGeneration" in architectures:
+            print("Detected Qwen2.5-VL model from architectures")
+            model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+                model_path,
+                device_map="auto",
+                torch_dtype="auto"
+            )
+        else:
+            print("Detected Qwen2-VL model from architectures")
+            model = Qwen2VLForConditionalGeneration.from_pretrained(
+                model_path,
+                device_map="auto",
+                torch_dtype="auto"
+            )

     print(f"Loading tokenizer from {model_path}...")
     tokenizer = AutoTokenizer.from_pretrained(model_path)
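Note (not part of the commit): the new branch keys first on the "name_or_path" field of config.json and falls back to its "architectures" list. The four near-identical from_pretrained blocks could also be collapsed by picking the class first and loading once; a sketch of that alternative, where the example field values are assumptions rather than values taken from the repo:

import json
import os

from transformers import Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration

def pick_model_class(model_path: str):
    """Return the model class implied by config.json (illustrative helper, not in the script)."""
    with open(os.path.join(model_path, "config.json"), "r") as f:
        config = json.load(f)
    name = config.get("name_or_path", "")    # e.g. "Qwen/Qwen2.5-VL-7B-Instruct"
    archs = config.get("architectures", [])  # e.g. ["Qwen2_5_VLForConditionalGeneration"]
    if "Qwen2.5-VL" in name or "Qwen2_5_VLForConditionalGeneration" in archs:
        return Qwen2_5_VLForConditionalGeneration
    return Qwen2VLForConditionalGeneration

# Single load call instead of four branches:
# model = pick_model_class(model_path).from_pretrained(model_path, device_map="auto", torch_dtype="auto")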
@@ -101,11 +136,6 @@ def load_model_and_tokenizer(source_path: str) -> Tuple[Qwen2VLForConditionalGen

 def compress_checkpoint(source_path: str, dest_path: str) -> None:
     """Compress OlmOCR checkpoint using FP8 quantization."""
-    # First, validate the source checkpoint
-    config_path = os.path.join(source_path, "config.json")
-    if is_s3_path(source_path):
-        config_path = f"{source_path}/config.json"
-
     # Load model and tokenizer
     model, tokenizer, temp_source_dir = load_model_and_tokenizer(source_path)

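Note (not part of the commit): the hunks above do not show the quantization step itself, but the imports (oneshot, QuantizationModifier) and the docstring point at llm-compressor's data-free FP8 flow. A minimal sketch of that flow, using the library's stock FP8_DYNAMIC recipe; the targets/ignore values and output directory are assumptions, not necessarily what this script does:

from llmcompressor import oneshot
from llmcompressor.modifiers.quantization import QuantizationModifier

# Stock data-free FP8 recipe: quantize Linear layers, skip the LM head (assumed values).
recipe = QuantizationModifier(targets="Linear", scheme="FP8_DYNAMIC", ignore=["lm_head"])

# model/tokenizer are the objects returned by load_model_and_tokenizer above.
oneshot(model=model, recipe=recipe)
model.save_pretrained("compressed-checkpoint")      # hypothetical output directory
tokenizer.save_pretrained("compressed-checkpoint")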