Better prepare checkpoint script

This commit is contained in:
Jake Poznanski 2025-07-01 16:44:19 +00:00
parent 8dcfdd0418
commit c6c1fbd0eb

View File

@ -43,6 +43,15 @@ TOKENIZER_FILES = [
# Expected model architecture # Expected model architecture
EXPECTED_ARCHITECTURE = "Qwen2_5_VLForConditionalGeneration" EXPECTED_ARCHITECTURE = "Qwen2_5_VLForConditionalGeneration"
# Files to exclude from copying (training-related files)
EXCLUDED_FILES = {
"optimizer.pt",
"scheduler.pt",
"rng_state.pth",
"trainer_state.json",
"training_args.bin"
}
s3_client = boto3.client("s3") s3_client = boto3.client("s3")
@ -92,6 +101,9 @@ def copy_local_to_local(source_dir: str, dest_dir: str) -> None:
files_to_copy = [] files_to_copy = []
for root, _, files in os.walk(source_dir): for root, _, files in os.walk(source_dir):
for file in files: for file in files:
if file in EXCLUDED_FILES:
print(f"Skipping excluded file: {file}")
continue
src_path = os.path.join(root, file) src_path = os.path.join(root, file)
rel_path = os.path.relpath(src_path, source_dir) rel_path = os.path.relpath(src_path, source_dir)
files_to_copy.append((src_path, os.path.join(dest_dir, rel_path))) files_to_copy.append((src_path, os.path.join(dest_dir, rel_path)))
@ -129,6 +141,11 @@ def copy_s3_to_local(source_bucket: str, source_prefix: str, dest_dir: str) -> N
if key.endswith("/"): if key.endswith("/"):
continue continue
filename = os.path.basename(key)
if filename in EXCLUDED_FILES:
print(f"Skipping excluded file: {filename}")
continue
rel_path = os.path.relpath(key, source_prefix) rel_path = os.path.relpath(key, source_prefix)
local_path = os.path.join(dest_dir, rel_path) local_path = os.path.join(dest_dir, rel_path)
download_tasks.append((source_bucket, key, local_path)) download_tasks.append((source_bucket, key, local_path))
@ -151,6 +168,9 @@ def copy_local_to_s3(source_dir: str, dest_bucket: str, dest_prefix: str) -> Non
upload_tasks = [] upload_tasks = []
for root, _, files in os.walk(source_dir): for root, _, files in os.walk(source_dir):
for file in files: for file in files:
if file in EXCLUDED_FILES:
print(f"Skipping excluded file: {file}")
continue
local_path = os.path.join(root, file) local_path = os.path.join(root, file)
rel_path = os.path.relpath(local_path, source_dir) rel_path = os.path.relpath(local_path, source_dir)
s3_key = os.path.join(dest_prefix, rel_path) s3_key = os.path.join(dest_prefix, rel_path)
@ -181,6 +201,11 @@ def copy_s3_to_s3(source_bucket: str, source_prefix: str, dest_bucket: str, dest
if key.endswith("/"): if key.endswith("/"):
continue continue
filename = os.path.basename(key)
if filename in EXCLUDED_FILES:
print(f"Skipping excluded file: {filename}")
continue
rel_path = os.path.relpath(key, source_prefix) rel_path = os.path.relpath(key, source_prefix)
dest_key = os.path.join(dest_prefix, rel_path) dest_key = os.path.join(dest_prefix, rel_path)
copy_source = {"Bucket": source_bucket, "Key": key} copy_source = {"Bucket": source_bucket, "Key": key}