diff --git a/olmocr/train/grpo_train.py b/olmocr/train/grpo_train.py index 5d05d9a..5677959 100644 --- a/olmocr/train/grpo_train.py +++ b/olmocr/train/grpo_train.py @@ -473,6 +473,13 @@ def main(): trust_remote_code=True, ) + # Load model + logger.info(f"Loading model: {args.model_name}") + if "Qwen2-VL" in args.model_name: + model_class = Qwen2VLForConditionalGeneration + else: + model_class = Qwen2_5_VLForConditionalGeneration + model = model_class.from_pretrained( args.model_name, torch_dtype=torch.bfloat16,