mirror of
https://github.com/allenai/olmocr.git
synced 2025-10-11 08:12:22 +00:00
Better prepare checkpoint script
This commit is contained in:
parent
8dcfdd0418
commit
c6c1fbd0eb
@ -43,6 +43,15 @@ TOKENIZER_FILES = [
|
||||
# Expected model architecture
|
||||
EXPECTED_ARCHITECTURE = "Qwen2_5_VLForConditionalGeneration"
|
||||
|
||||
# Files to exclude from copying (training-related files)
|
||||
EXCLUDED_FILES = {
|
||||
"optimizer.pt",
|
||||
"scheduler.pt",
|
||||
"rng_state.pth",
|
||||
"trainer_state.json",
|
||||
"training_args.bin"
|
||||
}
|
||||
|
||||
s3_client = boto3.client("s3")
|
||||
|
||||
|
||||
@ -92,6 +101,9 @@ def copy_local_to_local(source_dir: str, dest_dir: str) -> None:
|
||||
files_to_copy = []
|
||||
for root, _, files in os.walk(source_dir):
|
||||
for file in files:
|
||||
if file in EXCLUDED_FILES:
|
||||
print(f"Skipping excluded file: {file}")
|
||||
continue
|
||||
src_path = os.path.join(root, file)
|
||||
rel_path = os.path.relpath(src_path, source_dir)
|
||||
files_to_copy.append((src_path, os.path.join(dest_dir, rel_path)))
|
||||
@ -129,6 +141,11 @@ def copy_s3_to_local(source_bucket: str, source_prefix: str, dest_dir: str) -> N
|
||||
if key.endswith("/"):
|
||||
continue
|
||||
|
||||
filename = os.path.basename(key)
|
||||
if filename in EXCLUDED_FILES:
|
||||
print(f"Skipping excluded file: {filename}")
|
||||
continue
|
||||
|
||||
rel_path = os.path.relpath(key, source_prefix)
|
||||
local_path = os.path.join(dest_dir, rel_path)
|
||||
download_tasks.append((source_bucket, key, local_path))
|
||||
@ -151,6 +168,9 @@ def copy_local_to_s3(source_dir: str, dest_bucket: str, dest_prefix: str) -> Non
|
||||
upload_tasks = []
|
||||
for root, _, files in os.walk(source_dir):
|
||||
for file in files:
|
||||
if file in EXCLUDED_FILES:
|
||||
print(f"Skipping excluded file: {file}")
|
||||
continue
|
||||
local_path = os.path.join(root, file)
|
||||
rel_path = os.path.relpath(local_path, source_dir)
|
||||
s3_key = os.path.join(dest_prefix, rel_path)
|
||||
@ -181,6 +201,11 @@ def copy_s3_to_s3(source_bucket: str, source_prefix: str, dest_bucket: str, dest
|
||||
if key.endswith("/"):
|
||||
continue
|
||||
|
||||
filename = os.path.basename(key)
|
||||
if filename in EXCLUDED_FILES:
|
||||
print(f"Skipping excluded file: {filename}")
|
||||
continue
|
||||
|
||||
rel_path = os.path.relpath(key, source_prefix)
|
||||
dest_key = os.path.join(dest_prefix, rel_path)
|
||||
copy_source = {"Bucket": source_bucket, "Key": key}
|
||||
|
Loading…
x
Reference in New Issue
Block a user