mirror of
https://github.com/allenai/olmocr.git
synced 2025-10-11 08:12:22 +00:00
prepare checkpoint script fixes
This commit is contained in:
parent
c720c02d83
commit
8f88a98e5d
@ -31,6 +31,7 @@ Examples:
|
||||
|
||||
import argparse
|
||||
import concurrent.futures
|
||||
import fnmatch
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
@ -59,11 +60,20 @@ TOKENIZER_FILES = ["chat_template.json", "merges.txt", "preprocessor_config.json
|
||||
SUPPORTED_ARCHITECTURES = ["Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration"]
|
||||
|
||||
# Files to exclude from copying (training-related files)
|
||||
EXCLUDED_FILES = {"optimizer.pt", "scheduler.pt", "rng_state.pth", "trainer_state.json", "training_args.bin"}
|
||||
# Supports exact matches and glob patterns
|
||||
EXCLUDED_FILES = {"optimizer.pt", "scheduler.pt", "rng_state.pth", "trainer_state.json", "training_args.bin", "*.pt", "*.pth"}
|
||||
|
||||
s3_client = boto3.client("s3")
|
||||
|
||||
|
||||
def should_exclude_file(filename: str) -> bool:
|
||||
"""Check if a file should be excluded based on EXCLUDED_FILES patterns."""
|
||||
for pattern in EXCLUDED_FILES:
|
||||
if fnmatch.fnmatch(filename, pattern):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def is_s3_path(path: str) -> bool:
|
||||
"""Check if a path is an S3 path."""
|
||||
return path.startswith("s3://")
|
||||
@ -123,7 +133,7 @@ def copy_local_to_local(source_dir: str, dest_dir: str) -> None:
|
||||
files_to_copy = []
|
||||
for root, _, files in os.walk(source_dir):
|
||||
for file in files:
|
||||
if file in EXCLUDED_FILES:
|
||||
if should_exclude_file(file):
|
||||
print(f"Skipping excluded file: {file}")
|
||||
continue
|
||||
src_path = os.path.join(root, file)
|
||||
@ -164,7 +174,7 @@ def copy_s3_to_local(source_bucket: str, source_prefix: str, dest_dir: str) -> N
|
||||
continue
|
||||
|
||||
filename = os.path.basename(key)
|
||||
if filename in EXCLUDED_FILES:
|
||||
if should_exclude_file(filename):
|
||||
print(f"Skipping excluded file: {filename}")
|
||||
continue
|
||||
|
||||
@ -187,7 +197,7 @@ def copy_local_to_s3(source_dir: str, dest_bucket: str, dest_prefix: str) -> Non
|
||||
upload_tasks = []
|
||||
for root, _, files in os.walk(source_dir):
|
||||
for file in files:
|
||||
if file in EXCLUDED_FILES:
|
||||
if should_exclude_file(file):
|
||||
print(f"Skipping excluded file: {file}")
|
||||
continue
|
||||
local_path = os.path.join(root, file)
|
||||
@ -218,7 +228,7 @@ def copy_s3_to_s3(source_bucket: str, source_prefix: str, dest_bucket: str, dest
|
||||
continue
|
||||
|
||||
filename = os.path.basename(key)
|
||||
if filename in EXCLUDED_FILES:
|
||||
if should_exclude_file(filename):
|
||||
print(f"Skipping excluded file: {filename}")
|
||||
continue
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user