prepare checkpoint script fixes

This commit is contained in:
Jake Poznanski 2025-09-04 22:15:55 +00:00
parent c720c02d83
commit 8f88a98e5d

View File

@ -31,6 +31,7 @@ Examples:
import argparse
import concurrent.futures
import fnmatch
import json
import os
import shutil
@ -59,11 +60,20 @@ TOKENIZER_FILES = ["chat_template.json", "merges.txt", "preprocessor_config.json
# Architecture names, presumably compared against a checkpoint's config elsewhere in the file — TODO confirm.
SUPPORTED_ARCHITECTURES = ["Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration"]
# Files to exclude from copying (training-related files)
# NOTE(review): this assignment is overwritten by the EXCLUDED_FILES assignment
# two lines below; it appears to be the pre-change side of the diff — confirm
# which of the two lines is actually present in the file.
EXCLUDED_FILES = {"optimizer.pt", "scheduler.pt", "rng_state.pth", "trainer_state.json", "training_args.bin"}
# Supports exact matches and glob patterns
# The "*.pt" and "*.pth" globs also match the explicit .pt/.pth names listed before them.
EXCLUDED_FILES = {"optimizer.pt", "scheduler.pt", "rng_state.pth", "trainer_state.json", "training_args.bin", "*.pt", "*.pth"}
# Module-level S3 client, created once when the module is imported.
s3_client = boto3.client("s3")
def should_exclude_file(filename: str) -> bool:
    """Return True when ``filename`` matches any pattern in EXCLUDED_FILES.

    Every entry of EXCLUDED_FILES is treated as a shell-style glob, so a
    plain name such as "trainer_state.json" matches only itself, while a
    pattern such as "*.pt" matches every name ending in ".pt".

    Args:
        filename: Name to test. The callers visible in this file pass a bare
            file name (an ``os.walk`` entry or ``os.path.basename`` of a key).

    Returns:
        True if at least one pattern matches, otherwise False.
    """
    # NOTE: fnmatch.fnmatch normalises case with os.path.normcase, so matching
    # is case-insensitive on Windows and case-sensitive on POSIX systems.
    return any(fnmatch.fnmatch(filename, pattern) for pattern in EXCLUDED_FILES)
def is_s3_path(path: str) -> bool:
    """Return True when ``path`` begins with the "s3://" scheme prefix.

    This is a case-sensitive prefix comparison only; nothing after the
    prefix (bucket name, key) is validated.

    Args:
        path: Path string to inspect.

    Returns:
        True if ``path`` starts with "s3://", otherwise False.
    """
    return path.startswith("s3://")
@ -123,7 +133,7 @@ def copy_local_to_local(source_dir: str, dest_dir: str) -> None:
files_to_copy = []
for root, _, files in os.walk(source_dir):
for file in files:
if file in EXCLUDED_FILES:
if should_exclude_file(file):
print(f"Skipping excluded file: {file}")
continue
src_path = os.path.join(root, file)
@ -164,7 +174,7 @@ def copy_s3_to_local(source_bucket: str, source_prefix: str, dest_dir: str) -> N
continue
filename = os.path.basename(key)
if filename in EXCLUDED_FILES:
if should_exclude_file(filename):
print(f"Skipping excluded file: {filename}")
continue
@ -187,7 +197,7 @@ def copy_local_to_s3(source_dir: str, dest_bucket: str, dest_prefix: str) -> Non
upload_tasks = []
for root, _, files in os.walk(source_dir):
for file in files:
if file in EXCLUDED_FILES:
if should_exclude_file(file):
print(f"Skipping excluded file: {file}")
continue
local_path = os.path.join(root, file)
@ -218,7 +228,7 @@ def copy_s3_to_s3(source_bucket: str, source_prefix: str, dest_bucket: str, dest
continue
filename = os.path.basename(key)
if filename in EXCLUDED_FILES:
if should_exclude_file(filename):
print(f"Skipping excluded file: {filename}")
continue