diff --git a/olmocr/train/grpo_train.py b/olmocr/train/grpo_train.py index 61859c0..b0c4822 100644 --- a/olmocr/train/grpo_train.py +++ b/olmocr/train/grpo_train.py @@ -408,9 +408,7 @@ def main(): temperature=0.7, report_to=report_to, remove_unused_columns=False, - torch_dtype=torch.bfloat16, bf16=True, - gradient_checkpointing=True, dataloader_num_workers=0, )