This commit is contained in:
Jake Poznanski 2025-06-27 16:59:33 +00:00
parent b96454b786
commit 850b598db1
2 changed files with 20 additions and 8 deletions

View File

@@ -10,7 +10,7 @@ model:
trust_remote_code: true
torch_dtype: auto
use_flash_attention: true
attn_implementation: flash_attention_2
attn_implementation: sdpa
# LoRA settings (disabled by default)
use_lora: false

View File

@@ -8,6 +8,7 @@ import logging
from transformers import (
AutoProcessor,
Qwen2VLForConditionalGeneration,
Qwen2_5_VLForConditionalGeneration,
Trainer,
TrainingArguments,
EarlyStoppingCallback
@@ -92,13 +93,24 @@ def main():
# Load model
logger.info(f"Loading model: {config.model.name}")
model = Qwen2VLForConditionalGeneration.from_pretrained(
config.model.name,
torch_dtype=getattr(torch, config.model.torch_dtype) if config.model.torch_dtype != "auto" else "auto",
device_map=config.model.device_map,
trust_remote_code=config.model.trust_remote_code,
attn_implementation=config.model.attn_implementation if config.model.use_flash_attention else None,
)
if "Qwen2.5-VL" in config.model.name:
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
config.model.name,
torch_dtype=getattr(torch, config.model.torch_dtype) if config.model.torch_dtype != "auto" else "auto",
device_map=config.model.device_map,
trust_remote_code=config.model.trust_remote_code,
attn_implementation=config.model.attn_implementation if config.model.use_flash_attention else None,
)
elif "Qwen2-VL" in config.model.name:
model = Qwen2VLForConditionalGeneration.from_pretrained(
config.model.name,
torch_dtype=getattr(torch, config.model.torch_dtype) if config.model.torch_dtype != "auto" else "auto",
device_map=config.model.device_map,
trust_remote_code=config.model.trust_remote_code,
attn_implementation=config.model.attn_implementation if config.model.use_flash_attention else None,
)
else:
raise NotImplementedError()
# Enable gradient checkpointing if configured
if config.training.gradient_checkpointing: