From 8f88a98e5d10d9b2fc8bdba18a88f2fb62ce78fc Mon Sep 17 00:00:00 2001 From: Jake Poznanski Date: Thu, 4 Sep 2025 22:15:55 +0000 Subject: [PATCH] prepare checkpoint script fixes --- olmocr/train/prepare_checkpoint.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/olmocr/train/prepare_checkpoint.py b/olmocr/train/prepare_checkpoint.py index 1ebe4c6..e8b6443 100755 --- a/olmocr/train/prepare_checkpoint.py +++ b/olmocr/train/prepare_checkpoint.py @@ -31,6 +31,7 @@ Examples: import argparse import concurrent.futures +import fnmatch import json import os import shutil @@ -59,11 +60,20 @@ TOKENIZER_FILES = ["chat_template.json", "merges.txt", "preprocessor_config.json SUPPORTED_ARCHITECTURES = ["Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration"] # Files to exclude from copying (training-related files) -EXCLUDED_FILES = {"optimizer.pt", "scheduler.pt", "rng_state.pth", "trainer_state.json", "training_args.bin"} +# Supports exact matches and glob patterns +EXCLUDED_FILES = {"optimizer.pt", "scheduler.pt", "rng_state.pth", "trainer_state.json", "training_args.bin", "*.pt", "*.pth"} s3_client = boto3.client("s3") +def should_exclude_file(filename: str) -> bool: + """Check if a file should be excluded based on EXCLUDED_FILES patterns.""" + for pattern in EXCLUDED_FILES: + if fnmatch.fnmatch(filename, pattern): + return True + return False + + def is_s3_path(path: str) -> bool: """Check if a path is an S3 path.""" return path.startswith("s3://") @@ -123,7 +133,7 @@ def copy_local_to_local(source_dir: str, dest_dir: str) -> None: files_to_copy = [] for root, _, files in os.walk(source_dir): for file in files: - if file in EXCLUDED_FILES: + if should_exclude_file(file): print(f"Skipping excluded file: {file}") continue src_path = os.path.join(root, file) @@ -164,7 +174,7 @@ def copy_s3_to_local(source_bucket: str, source_prefix: str, dest_dir: str) -> N continue filename = os.path.basename(key) - if filename in EXCLUDED_FILES: + if should_exclude_file(filename): print(f"Skipping excluded file: {filename}") continue @@ -187,7 +197,7 @@ def copy_local_to_s3(source_dir: str, dest_bucket: str, dest_prefix: str) -> Non upload_tasks = [] for root, _, files in os.walk(source_dir): for file in files: - if file in EXCLUDED_FILES: + if should_exclude_file(file): print(f"Skipping excluded file: {file}") continue local_path = os.path.join(root, file) @@ -218,7 +228,7 @@ def copy_s3_to_s3(source_bucket: str, source_prefix: str, dest_bucket: str, dest continue filename = os.path.basename(key) - if filename in EXCLUDED_FILES: + if should_exclude_file(filename): print(f"Skipping excluded file: {filename}") continue