mirror of
https://github.com/allenai/olmocr.git
synced 2025-10-11 16:22:29 +00:00
Better prepare checkpoint script
This commit is contained in:
parent
8dcfdd0418
commit
c6c1fbd0eb
@ -43,6 +43,15 @@ TOKENIZER_FILES = [
|
|||||||
# Expected model architecture
|
# Expected model architecture
|
||||||
EXPECTED_ARCHITECTURE = "Qwen2_5_VLForConditionalGeneration"
|
EXPECTED_ARCHITECTURE = "Qwen2_5_VLForConditionalGeneration"
|
||||||
|
|
||||||
|
# Files to exclude from copying (training-related files)
|
||||||
|
EXCLUDED_FILES = {
|
||||||
|
"optimizer.pt",
|
||||||
|
"scheduler.pt",
|
||||||
|
"rng_state.pth",
|
||||||
|
"trainer_state.json",
|
||||||
|
"training_args.bin"
|
||||||
|
}
|
||||||
|
|
||||||
s3_client = boto3.client("s3")
|
s3_client = boto3.client("s3")
|
||||||
|
|
||||||
|
|
||||||
@ -92,6 +101,9 @@ def copy_local_to_local(source_dir: str, dest_dir: str) -> None:
|
|||||||
files_to_copy = []
|
files_to_copy = []
|
||||||
for root, _, files in os.walk(source_dir):
|
for root, _, files in os.walk(source_dir):
|
||||||
for file in files:
|
for file in files:
|
||||||
|
if file in EXCLUDED_FILES:
|
||||||
|
print(f"Skipping excluded file: {file}")
|
||||||
|
continue
|
||||||
src_path = os.path.join(root, file)
|
src_path = os.path.join(root, file)
|
||||||
rel_path = os.path.relpath(src_path, source_dir)
|
rel_path = os.path.relpath(src_path, source_dir)
|
||||||
files_to_copy.append((src_path, os.path.join(dest_dir, rel_path)))
|
files_to_copy.append((src_path, os.path.join(dest_dir, rel_path)))
|
||||||
@ -129,6 +141,11 @@ def copy_s3_to_local(source_bucket: str, source_prefix: str, dest_dir: str) -> N
|
|||||||
if key.endswith("/"):
|
if key.endswith("/"):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
filename = os.path.basename(key)
|
||||||
|
if filename in EXCLUDED_FILES:
|
||||||
|
print(f"Skipping excluded file: {filename}")
|
||||||
|
continue
|
||||||
|
|
||||||
rel_path = os.path.relpath(key, source_prefix)
|
rel_path = os.path.relpath(key, source_prefix)
|
||||||
local_path = os.path.join(dest_dir, rel_path)
|
local_path = os.path.join(dest_dir, rel_path)
|
||||||
download_tasks.append((source_bucket, key, local_path))
|
download_tasks.append((source_bucket, key, local_path))
|
||||||
@ -151,6 +168,9 @@ def copy_local_to_s3(source_dir: str, dest_bucket: str, dest_prefix: str) -> Non
|
|||||||
upload_tasks = []
|
upload_tasks = []
|
||||||
for root, _, files in os.walk(source_dir):
|
for root, _, files in os.walk(source_dir):
|
||||||
for file in files:
|
for file in files:
|
||||||
|
if file in EXCLUDED_FILES:
|
||||||
|
print(f"Skipping excluded file: {file}")
|
||||||
|
continue
|
||||||
local_path = os.path.join(root, file)
|
local_path = os.path.join(root, file)
|
||||||
rel_path = os.path.relpath(local_path, source_dir)
|
rel_path = os.path.relpath(local_path, source_dir)
|
||||||
s3_key = os.path.join(dest_prefix, rel_path)
|
s3_key = os.path.join(dest_prefix, rel_path)
|
||||||
@ -181,6 +201,11 @@ def copy_s3_to_s3(source_bucket: str, source_prefix: str, dest_bucket: str, dest
|
|||||||
if key.endswith("/"):
|
if key.endswith("/"):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
filename = os.path.basename(key)
|
||||||
|
if filename in EXCLUDED_FILES:
|
||||||
|
print(f"Skipping excluded file: {filename}")
|
||||||
|
continue
|
||||||
|
|
||||||
rel_path = os.path.relpath(key, source_prefix)
|
rel_path = os.path.relpath(key, source_prefix)
|
||||||
dest_key = os.path.join(dest_prefix, rel_path)
|
dest_key = os.path.join(dest_prefix, rel_path)
|
||||||
copy_source = {"Bucket": source_bucket, "Key": key}
|
copy_source = {"Bucket": source_bucket, "Key": key}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user