mirror of
https://github.com/allenai/olmocr.git
synced 2025-10-11 08:12:22 +00:00
Adding some params
This commit is contained in:
parent
4b3660debd
commit
3be381b375
@ -5,6 +5,7 @@ set -e
|
||||
# Parse beaker-specific arguments
|
||||
SKIP_DOCKER_BUILD=false
|
||||
PREEMPTIBLE=false
|
||||
EXP_NAME=""
|
||||
|
||||
# Store all arguments to pass to python command
|
||||
PYTHON_ARGS=()
|
||||
@ -19,12 +20,17 @@ while [[ $# -gt 0 ]]; do
|
||||
PREEMPTIBLE=true
|
||||
shift
|
||||
;;
|
||||
--name)
|
||||
EXP_NAME="$2"
|
||||
shift 2
|
||||
;;
|
||||
--help|-h)
|
||||
echo "Usage: $0 [beaker-options] [grpo-training-options]"
|
||||
echo ""
|
||||
echo "Beaker-specific options:"
|
||||
echo " --skip-docker-build Skip Docker build"
|
||||
echo " --preemptible Use preemptible instances"
|
||||
echo " --name NAME Experiment name (used in output directory)"
|
||||
echo ""
|
||||
echo "All other arguments are forwarded to python -m olmocr.train.grpo_train"
|
||||
echo "Run 'python -m olmocr.train.grpo_train --help' to see available training options"
|
||||
@ -89,6 +95,7 @@ echo "Beaker user: $BEAKER_USER"
|
||||
cat << 'EOF' > /tmp/run_grpo_experiment.py
|
||||
import sys
|
||||
import shlex
|
||||
import os
|
||||
from beaker import Beaker, ExperimentSpec, TaskSpec, TaskContext, ResultSpec, TaskResources, ImageSource, Priority, Constraints, EnvVar, DataMount
|
||||
|
||||
# Get parameters from command line
|
||||
@ -97,8 +104,9 @@ beaker_user = sys.argv[2]
|
||||
git_branch = sys.argv[3]
|
||||
git_hash = sys.argv[4]
|
||||
preemptible = sys.argv[5] == "true"
|
||||
exp_name = sys.argv[6] # Empty string if not provided
|
||||
# All remaining arguments are the python command arguments
|
||||
python_args = sys.argv[6:]
|
||||
python_args = sys.argv[7:]
|
||||
|
||||
# Initialize Beaker client
|
||||
b = Beaker.from_env(default_workspace="ai2/olmocr")
|
||||
@ -159,7 +167,16 @@ if "--train_bench_data_folder" not in arg_str:
|
||||
if "--eval_bench_data_folder" not in arg_str:
|
||||
grpo_cmd.append("--eval_bench_data_folder /data/olmOCR-bench/bench_data")
|
||||
if "--output_dir" not in arg_str:
|
||||
grpo_cmd.append("--output_dir /weka/oe-training-default/jakep/olmocr-grpo-checkpoints")
|
||||
output_dir = "/weka/oe-training-default/jakep/olmocr-grpo-checkpoints"
|
||||
# Build subdirectory based on exp_name and BEAKER_WORKLOAD_ID
|
||||
beaker_workload_id = os.environ.get("BEAKER_WORKLOAD_ID")
|
||||
if exp_name and beaker_workload_id:
|
||||
output_dir = f"{output_dir}/{exp_name}-{beaker_workload_id}"
|
||||
elif beaker_workload_id:
|
||||
output_dir = f"{output_dir}/{beaker_workload_id}"
|
||||
elif exp_name:
|
||||
output_dir = f"{output_dir}/{exp_name}"
|
||||
grpo_cmd.append(f"--output_dir {output_dir}")
|
||||
|
||||
# Add all the (possibly modified) arguments
|
||||
grpo_cmd.extend(modified_args)
|
||||
@ -228,6 +245,7 @@ $PYTHON /tmp/run_grpo_experiment.py \
|
||||
"$GIT_BRANCH" \
|
||||
"$GIT_HASH" \
|
||||
"$PREEMPTIBLE" \
|
||||
"$EXP_NAME" \
|
||||
"${PYTHON_ARGS[@]}"
|
||||
|
||||
# Clean up temporary file
|
||||
|
Loading…
x
Reference in New Issue
Block a user