From c6c1fbd0eb2b18427baa49c48c311b6a819eee9c Mon Sep 17 00:00:00 2001 From: Jake Poznanski Date: Tue, 1 Jul 2025 16:44:19 +0000 Subject: [PATCH] Better prepare checkpoint script --- olmocr/train/prepare_olmocr_checkpoint.py | 25 +++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/olmocr/train/prepare_olmocr_checkpoint.py b/olmocr/train/prepare_olmocr_checkpoint.py index c8f30b0..8392922 100755 --- a/olmocr/train/prepare_olmocr_checkpoint.py +++ b/olmocr/train/prepare_olmocr_checkpoint.py @@ -43,6 +43,15 @@ TOKENIZER_FILES = [ # Expected model architecture EXPECTED_ARCHITECTURE = "Qwen2_5_VLForConditionalGeneration" +# Files to exclude from copying (training-related files) +EXCLUDED_FILES = { + "optimizer.pt", + "scheduler.pt", + "rng_state.pth", + "trainer_state.json", + "training_args.bin" +} + s3_client = boto3.client("s3") @@ -92,6 +101,9 @@ def copy_local_to_local(source_dir: str, dest_dir: str) -> None: files_to_copy = [] for root, _, files in os.walk(source_dir): for file in files: + if file in EXCLUDED_FILES: + print(f"Skipping excluded file: {file}") + continue src_path = os.path.join(root, file) rel_path = os.path.relpath(src_path, source_dir) files_to_copy.append((src_path, os.path.join(dest_dir, rel_path))) @@ -129,6 +141,11 @@ def copy_s3_to_local(source_bucket: str, source_prefix: str, dest_dir: str) -> N if key.endswith("/"): continue + filename = os.path.basename(key) + if filename in EXCLUDED_FILES: + print(f"Skipping excluded file: {filename}") + continue + rel_path = os.path.relpath(key, source_prefix) local_path = os.path.join(dest_dir, rel_path) download_tasks.append((source_bucket, key, local_path)) @@ -151,6 +168,9 @@ def copy_local_to_s3(source_dir: str, dest_bucket: str, dest_prefix: str) -> Non upload_tasks = [] for root, _, files in os.walk(source_dir): for file in files: + if file in EXCLUDED_FILES: + print(f"Skipping excluded file: {file}") + continue local_path = os.path.join(root, file) rel_path = os.path.relpath(local_path, source_dir) s3_key = os.path.join(dest_prefix, rel_path) @@ -181,6 +201,11 @@ def copy_s3_to_s3(source_bucket: str, source_prefix: str, dest_bucket: str, dest if key.endswith("/"): continue + filename = os.path.basename(key) + if filename in EXCLUDED_FILES: + print(f"Skipping excluded file: {filename}") + continue + rel_path = os.path.relpath(key, source_prefix) dest_key = os.path.join(dest_prefix, rel_path) copy_source = {"Bucket": source_bucket, "Key": key}