Adding more cmd line args

This commit is contained in:
Jake Poznanski 2025-08-26 18:28:41 +00:00
parent 9671f6847c
commit 3433c8f5f2

View File

@ -405,6 +405,38 @@ def main():
default=None,
help="Weights & Biases run name (default: auto-generated)"
)
# Loss formulation selector; argparse rejects any value outside `choices`.
parser.add_argument(
"--loss_type",
type=str,
default="bnpo",
choices=["bnpo", "grpo", "exo"],
help="Loss formulation to use (default: bnpo)"
)
# Reward scaling is enabled by default. Because this flag is `store_true` with
# default=True, passing --scale_rewards does not change the parsed value; it is
# the explicit positive counterpart of --no_scale_rewards.
parser.add_argument(
"--scale_rewards",
action="store_true",
default=True,
help="Whether to scale rewards by their standard deviation (default: True)"
)
# Stores False into the same `scale_rewards` destination (via `dest`), so this
# is the flag that actually switches reward scaling off.
parser.add_argument(
"--no_scale_rewards",
action="store_false",
dest="scale_rewards",
help="Disable reward scaling"
)
# KL coefficient; per the help text, the default of 0.0 means no reference model.
parser.add_argument(
"--beta",
type=float,
default=0.0,
help="KL coefficient for reference model (default: 0.0, no reference model)"
)
# Granularity of the importance sampling ratios: "token" or "sequence".
parser.add_argument(
"--importance_sampling_level",
type=str,
default="token",
choices=["token", "sequence"],
help="Level for importance sampling ratios (default: token)"
)
# Parse the command line; the options above become attributes of `args`
# (args.loss_type, args.scale_rewards, args.beta, args.importance_sampling_level).
args = parser.parse_args()
@ -503,6 +535,12 @@ def main():
bf16=True,
dataloader_num_workers=0,
# GRPO-specific parameters
loss_type=args.loss_type,
scale_rewards=args.scale_rewards,
beta=args.beta,
importance_sampling_level=args.importance_sampling_level,
# Vllm setup to speed up generation
use_vllm=True,
vllm_mode="colocate",