diff --git a/olmocr/train/configs/example_config.yaml b/olmocr/train/configs/example_config.yaml
index b9d7c17..10c6d42 100644
--- a/olmocr/train/configs/example_config.yaml
+++ b/olmocr/train/configs/example_config.yaml
@@ -10,7 +10,7 @@ model:
   trust_remote_code: true
   torch_dtype: auto
   use_flash_attention: true
-  attn_implementation: flash_attention_2
+  attn_implementation: sdpa
 
   # LoRA settings (disabled by default)
   use_lora: false
diff --git a/olmocr/train/train.py b/olmocr/train/train.py
index 5ec9630..f69c066 100644
--- a/olmocr/train/train.py
+++ b/olmocr/train/train.py
@@ -8,6 +8,7 @@ import logging
 from transformers import (
     AutoProcessor,
     Qwen2VLForConditionalGeneration,
+    Qwen2_5_VLForConditionalGeneration,
     Trainer,
     TrainingArguments,
     EarlyStoppingCallback
@@ -92,13 +93,24 @@ def main():
 
     # Load model
     logger.info(f"Loading model: {config.model.name}")
-    model = Qwen2VLForConditionalGeneration.from_pretrained(
-        config.model.name,
-        torch_dtype=getattr(torch, config.model.torch_dtype) if config.model.torch_dtype != "auto" else "auto",
-        device_map=config.model.device_map,
-        trust_remote_code=config.model.trust_remote_code,
-        attn_implementation=config.model.attn_implementation if config.model.use_flash_attention else None,
-    )
+    if "Qwen2.5-VL" in config.model.name:
+        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+            config.model.name,
+            torch_dtype=getattr(torch, config.model.torch_dtype) if config.model.torch_dtype != "auto" else "auto",
+            device_map=config.model.device_map,
+            trust_remote_code=config.model.trust_remote_code,
+            attn_implementation=config.model.attn_implementation if config.model.use_flash_attention else None,
+        )
+    elif "Qwen2-VL" in config.model.name:
+        model = Qwen2VLForConditionalGeneration.from_pretrained(
+            config.model.name,
+            torch_dtype=getattr(torch, config.model.torch_dtype) if config.model.torch_dtype != "auto" else "auto",
+            device_map=config.model.device_map,
+            trust_remote_code=config.model.trust_remote_code,
+            attn_implementation=config.model.attn_implementation if config.model.use_flash_attention else None,
+        )
+    else:
+        raise NotImplementedError()
 
     # Enable gradient checkpointing if configured
     if config.training.gradient_checkpointing:
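
Both branches of the new `if`/`elif` pass identical `from_pretrained` arguments and differ only in the model class. A possible follow-up (not part of this diff) is to make the class selection table-driven, which also makes the fallback error self-describing. This is a minimal sketch assuming the `config` object used in `train.py`; `MODEL_CLASSES` and `load_model` are hypothetical names:

```python
import torch
from transformers import (
    Qwen2VLForConditionalGeneration,
    Qwen2_5_VLForConditionalGeneration,
)

# Map a substring of the checkpoint name to the matching model class.
# "Qwen2-VL" is not a substring of "Qwen2.5-VL", so the lookup is unambiguous.
MODEL_CLASSES = {
    "Qwen2.5-VL": Qwen2_5_VLForConditionalGeneration,
    "Qwen2-VL": Qwen2VLForConditionalGeneration,
}

def load_model(config):
    """Pick the model class from the checkpoint name and load it once."""
    for key, cls in MODEL_CLASSES.items():
        if key in config.model.name:
            return cls.from_pretrained(
                config.model.name,
                torch_dtype=getattr(torch, config.model.torch_dtype) if config.model.torch_dtype != "auto" else "auto",
                device_map=config.model.device_map,
                trust_remote_code=config.model.trust_remote_code,
                attn_implementation=config.model.attn_implementation if config.model.use_flash_attention else None,
            )
    raise NotImplementedError(f"No model class registered for {config.model.name}")
```

Adding a new model family would then only require a new `MODEL_CLASSES` entry rather than another duplicated `from_pretrained` call.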