diff --git a/scripts/train/grpotrainer-beaker.sh b/scripts/train/grpotrainer-beaker.sh index aac3991..24881f0 100755 --- a/scripts/train/grpotrainer-beaker.sh +++ b/scripts/train/grpotrainer-beaker.sh @@ -128,17 +128,17 @@ model_sync_commands = [] modified_args = list(python_args) for i in range(len(modified_args)): if modified_args[i] == "--model_name" and i + 1 < len(modified_args): - model_path = modified_args[i + 1] + model_path = modified_args[i + 1].rstrip('/') if model_path.startswith("s3://"): # Extract checkpoint name from S3 path (last part of path) - checkpoint_name = model_path.rstrip('/').split('/')[-1] + checkpoint_name = model_path.split('/')[-1] local_model_path = f"/data/models/{checkpoint_name}" # Create sync commands model_sync_commands = [ f"echo 'Syncing model from S3: {model_path}'", "mkdir -p /data/models", - f"s5cmd sync '{model_path}' '{local_model_path}/'", + f"s5cmd sync '{model_path}/*' '{local_model_path}/'", ] # Replace S3 path with local path in arguments