Mirror of https://github.com/allenai/olmocr.git, synced 2025-11-07 22:18:51 +00:00
Qwen checkpoint fixer script
parent 2c7323d1c4
commit b6543a4f65
@@ -1,25 +1,45 @@
 import argparse
 import os
-import tempfile
-import boto3
-from tqdm import tqdm
-from transformers import AutoModel, Qwen2VLForConditionalGeneration
+import json
 from smart_open import smart_open


 def main():
     parser = argparse.ArgumentParser(description='Fix up a Qwen2VL checkpoint saved on s3 or otherwise, so that it will load properly in vllm/birr')
     parser.add_argument('s3_path', type=str, help='S3 path to the Hugging Face checkpoint.')
     args = parser.parse_args()

-    # Create a temporary directory to store the model files
-
-    # Rewrite the config.json from the official repo, this fixes a weird bug with the rope scaling configuration
-    with smart_open("https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct/resolve/main/config.json", "r") as newf:
-        new_config = newf.read()
-
-    with smart_open(os.path.join(args.s3_path, "config.json"), "w") as oldf:
-        oldf.write(new_config)
+    qwen_replacement_files = [
+        # Config is special to fix rope config
+        "s3://ai2-oe-data/artifacts/Qwen2-VL-7B-Instruct/config.json",
+
+        # Tokenizer and preprocessor are just not saved in the usual flow
+        "https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct/resolve/main/tokenizer.json",
+        "https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct/resolve/main/tokenizer_config.json",
+        "https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct/resolve/main/vocab.json",
+        "https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct/resolve/main/merges.txt",
+        "https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct/resolve/main/generation_config.json",
+        "https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct/resolve/main/chat_template.json",
+        "https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct/resolve/main/preprocessor_config.json",
+    ]
+
+    # Now, download the config.json from the original path and verify the architectures
+    config_path = os.path.join(args.s3_path, "config.json")
+
+    with smart_open(config_path, 'r') as f:
+        config_data = json.load(f)
+
+    assert config_data["architectures"] == ["Qwen2VLForConditionalGeneration"]
+
+    # Iterate over each file in the replacement list
+    for replacement_file in qwen_replacement_files:
+        filename = os.path.basename(replacement_file)
+        dest_path = os.path.join(args.s3_path, filename)
+
+        with smart_open(replacement_file, 'rb') as src_file:
+            data = src_file.read()
+
+        with smart_open(dest_path, 'wb') as dest_file:
+            dest_file.write(data)

     print("Model updated successfully.")
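As a quick sanity check after running the fixer, the patched checkpoint can be re-opened with the same smart_open/json pattern the script itself uses. The snippet below is a minimal sketch and is not part of this commit; the S3 path and the shortened file list are placeholders.

# Minimal sketch (not part of this commit): sanity-check a patched checkpoint.
# The S3 path below is a placeholder; point it at your own checkpoint directory.
import json
import os

from smart_open import smart_open

checkpoint = "s3://my-bucket/checkpoints/my-qwen2vl-finetune"  # placeholder

# Same architecture check the script performs before copying files in.
with smart_open(os.path.join(checkpoint, "config.json"), "r") as f:
    config = json.load(f)
assert config["architectures"] == ["Qwen2VLForConditionalGeneration"]

# Open a few of the replacement files to confirm they were copied;
# smart_open raises an error if an object is missing.
for name in ["tokenizer.json", "tokenizer_config.json", "preprocessor_config.json", "chat_template.json"]:
    with smart_open(os.path.join(checkpoint, name), "rb") as f:
        f.read(1)

print("Checkpoint looks complete.")

The fixer itself takes a single positional argument, the s3_path of the checkpoint to patch (as declared in its argparse setup), so it is run as a plain Python script against that S3 prefix.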