From 3433c8f5f2de7c27f41b55edd9655e666452dc5e Mon Sep 17 00:00:00 2001
From: Jake Poznanski
Date: Tue, 26 Aug 2025 18:28:41 +0000
Subject: [PATCH] Adding more cmd line args

---
 olmocr/train/grpo_train.py | 38 ++++++++++++++++++++++++++++++++++++++
 1 file changed, 38 insertions(+)

diff --git a/olmocr/train/grpo_train.py b/olmocr/train/grpo_train.py
index f7e904a..4e30830 100644
--- a/olmocr/train/grpo_train.py
+++ b/olmocr/train/grpo_train.py
@@ -405,6 +405,38 @@ def main():
         default=None,
         help="Weights & Biases run name (default: auto-generated)"
     )
+    parser.add_argument(
+        "--loss_type",
+        type=str,
+        default="bnpo",
+        choices=["bnpo", "grpo", "exo"],
+        help="Loss formulation to use (default: bnpo)"
+    )
+    parser.add_argument(
+        "--scale_rewards",
+        action="store_true",
+        default=True,
+        help="Whether to scale rewards by their standard deviation (default: True)"
+    )
+    parser.add_argument(
+        "--no_scale_rewards",
+        action="store_false",
+        dest="scale_rewards",
+        help="Disable reward scaling"
+    )
+    parser.add_argument(
+        "--beta",
+        type=float,
+        default=0.0,
+        help="KL coefficient for reference model (default: 0.0, no reference model)"
+    )
+    parser.add_argument(
+        "--importance_sampling_level",
+        type=str,
+        default="token",
+        choices=["token", "sequence"],
+        help="Level for importance sampling ratios (default: token)"
+    )
 
     args = parser.parse_args()
 
@@ -502,6 +534,12 @@ def main():
         remove_unused_columns=False,
         bf16=True,
         dataloader_num_workers=0,
+
+        # GRPO-specific parameters
+        loss_type=args.loss_type,
+        scale_rewards=args.scale_rewards,
+        beta=args.beta,
+        importance_sampling_level=args.importance_sampling_level,
 
         # Vllm setup to speed up generation
         use_vllm=True,
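The --scale_rewards / --no_scale_rewards pair added above uses the standard argparse pattern of two flags writing to one destination: store_true with default=True keeps reward scaling on unless the store_false flag switches it off. Below is a minimal, self-contained sketch of that pattern; the argument definitions are copied from the patch, but the standalone parser and the demo at the bottom are illustrative only, not the actual grpo_train.py entry point.

import argparse

# Illustrative parser isolating the paired-flag pattern from the patch above.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--scale_rewards",
    action="store_true",
    default=True,
    help="Whether to scale rewards by their standard deviation (default: True)",
)
parser.add_argument(
    "--no_scale_rewards",
    action="store_false",
    dest="scale_rewards",
    help="Disable reward scaling",
)

# Both flags resolve to the single args.scale_rewards attribute.
print(parser.parse_args([]).scale_rewards)                      # True (default)
print(parser.parse_args(["--no_scale_rewards"]).scale_rewards)  # False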