Adding some params

This commit is contained in:
Jake Poznanski 2025-08-26 20:46:06 +00:00
parent 4b3660debd
commit 3be381b375

View File

@ -5,6 +5,7 @@ set -e
# Parse beaker-specific arguments
SKIP_DOCKER_BUILD=false
PREEMPTIBLE=false
EXP_NAME=""
# Store all arguments to pass to python command
PYTHON_ARGS=()
@ -19,12 +20,17 @@ while [[ $# -gt 0 ]]; do
PREEMPTIBLE=true
shift
;;
--name)
EXP_NAME="$2"
shift 2
;;
--help|-h)
echo "Usage: $0 [beaker-options] [grpo-training-options]"
echo ""
echo "Beaker-specific options:"
echo " --skip-docker-build Skip Docker build"
echo " --preemptible Use preemptible instances"
echo " --name NAME Experiment name (used in output directory)"
echo ""
echo "All other arguments are forwarded to python -m olmocr.train.grpo_train"
echo "Run 'python -m olmocr.train.grpo_train --help' to see available training options"
@ -89,6 +95,7 @@ echo "Beaker user: $BEAKER_USER"
cat << 'EOF' > /tmp/run_grpo_experiment.py
import sys
import shlex
import os
from beaker import Beaker, ExperimentSpec, TaskSpec, TaskContext, ResultSpec, TaskResources, ImageSource, Priority, Constraints, EnvVar, DataMount
# Get parameters from command line
@ -97,8 +104,9 @@ beaker_user = sys.argv[2]
git_branch = sys.argv[3]
git_hash = sys.argv[4]
preemptible = sys.argv[5] == "true"
exp_name = sys.argv[6] # Empty string if not provided
# All remaining arguments are the python command arguments
python_args = sys.argv[6:]
python_args = sys.argv[7:]
# Initialize Beaker client
b = Beaker.from_env(default_workspace="ai2/olmocr")
@ -159,7 +167,16 @@ if "--train_bench_data_folder" not in arg_str:
if "--eval_bench_data_folder" not in arg_str:
grpo_cmd.append("--eval_bench_data_folder /data/olmOCR-bench/bench_data")
if "--output_dir" not in arg_str:
grpo_cmd.append("--output_dir /weka/oe-training-default/jakep/olmocr-grpo-checkpoints")
output_dir = "/weka/oe-training-default/jakep/olmocr-grpo-checkpoints"
# Build subdirectory based on exp_name and BEAKER_WORKLOAD_ID
beaker_workload_id = os.environ.get("BEAKER_WORKLOAD_ID")
if exp_name and beaker_workload_id:
output_dir = f"{output_dir}/{exp_name}-{beaker_workload_id}"
elif beaker_workload_id:
output_dir = f"{output_dir}/{beaker_workload_id}"
elif exp_name:
output_dir = f"{output_dir}/{exp_name}"
grpo_cmd.append(f"--output_dir {output_dir}")
# Add all the (possibly modified) arguments
grpo_cmd.extend(modified_args)
@ -228,6 +245,7 @@ $PYTHON /tmp/run_grpo_experiment.py \
"$GIT_BRANCH" \
"$GIT_HASH" \
"$PREEMPTIBLE" \
"$EXP_NAME" \
"${PYTHON_ARGS[@]}"
# Clean up temporary file