mirror of https://github.com/allenai/olmocr.git, synced 2025-10-12 16:52:20 +00:00
prepare checkpoint script fixes

commit 8f88a98e5d
parent c720c02d83
@@ -31,6 +31,7 @@ Examples:
 
 import argparse
 import concurrent.futures
+import fnmatch
 import json
 import os
 import shutil
@@ -59,11 +60,20 @@ TOKENIZER_FILES = ["chat_template.json", "merges.txt", "preprocessor_config.json
 SUPPORTED_ARCHITECTURES = ["Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration"]
 
 # Files to exclude from copying (training-related files)
-EXCLUDED_FILES = {"optimizer.pt", "scheduler.pt", "rng_state.pth", "trainer_state.json", "training_args.bin"}
+# Supports exact matches and glob patterns
+EXCLUDED_FILES = {"optimizer.pt", "scheduler.pt", "rng_state.pth", "trainer_state.json", "training_args.bin", "*.pt", "*.pth"}
 
 s3_client = boto3.client("s3")
 
 
+def should_exclude_file(filename: str) -> bool:
+    """Check if a file should be excluded based on EXCLUDED_FILES patterns."""
+    for pattern in EXCLUDED_FILES:
+        if fnmatch.fnmatch(filename, pattern):
+            return True
+    return False
+
+
 def is_s3_path(path: str) -> bool:
     """Check if a path is an S3 path."""
     return path.startswith("s3://")
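A quick sanity check of the pattern-based filter introduced above (a standalone sketch, not part of the commit; the checkpoint filenames are hypothetical examples):

    import fnmatch

    patterns = {"optimizer.pt", "scheduler.pt", "rng_state.pth", "trainer_state.json", "training_args.bin", "*.pt", "*.pth"}

    def excluded(filename: str) -> bool:
        # Mirrors should_exclude_file: any exact or glob match excludes the file.
        return any(fnmatch.fnmatch(filename, pattern) for pattern in patterns)

    print(excluded("global_step500.pt"))   # True via "*.pt"; plain set membership would miss it
    print(excluded("rng_state_3.pth"))     # True via "*.pth"
    print(excluded("optimizer.pt"))        # True, an exact name also matches as a pattern
    print(excluded("model.safetensors"))   # False, model weights are still copied
    print(excluded("config.json"))         # False, config/tokenizer files are kept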
@@ -123,7 +133,7 @@ def copy_local_to_local(source_dir: str, dest_dir: str) -> None:
     files_to_copy = []
     for root, _, files in os.walk(source_dir):
         for file in files:
-            if file in EXCLUDED_FILES:
+            if should_exclude_file(file):
                 print(f"Skipping excluded file: {file}")
                 continue
             src_path = os.path.join(root, file)
@@ -164,7 +174,7 @@ def copy_s3_to_local(source_bucket: str, source_prefix: str, dest_dir: str) -> N
             continue
 
         filename = os.path.basename(key)
-        if filename in EXCLUDED_FILES:
+        if should_exclude_file(filename):
             print(f"Skipping excluded file: {filename}")
             continue
 
@@ -187,7 +197,7 @@ def copy_local_to_s3(source_dir: str, dest_bucket: str, dest_prefix: str) -> Non
     upload_tasks = []
     for root, _, files in os.walk(source_dir):
         for file in files:
-            if file in EXCLUDED_FILES:
+            if should_exclude_file(file):
                 print(f"Skipping excluded file: {file}")
                 continue
             local_path = os.path.join(root, file)
@@ -218,7 +228,7 @@ def copy_s3_to_s3(source_bucket: str, source_prefix: str, dest_bucket: str, dest
             continue
 
         filename = os.path.basename(key)
-        if filename in EXCLUDED_FILES:
+        if should_exclude_file(filename):
             print(f"Skipping excluded file: {filename}")
             continue
 
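The four call-site hunks above all apply the same change: swap the exact-name set lookup for the new helper. A minimal standalone sketch of that walk-and-filter pattern (assuming should_exclude_file from the diff above is in scope; the directory names are hypothetical):

    import os
    import shutil

    def copy_checkpoint_files(source_dir: str, dest_dir: str) -> None:
        # Walk the checkpoint directory and copy everything that is not excluded.
        for root, _, files in os.walk(source_dir):
            for file in files:
                if should_exclude_file(file):  # helper added in this commit
                    print(f"Skipping excluded file: {file}")
                    continue
                src_path = os.path.join(root, file)
                rel_path = os.path.relpath(src_path, source_dir)
                dst_path = os.path.join(dest_dir, rel_path)
                os.makedirs(os.path.dirname(dst_path), exist_ok=True)
                shutil.copy2(src_path, dst_path)

    # e.g. copy_checkpoint_files("/tmp/checkpoints/step1000", "/tmp/olmocr-model")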