mirror of
https://github.com/allenai/olmocr.git
synced 2025-11-14 17:38:12 +00:00
Sdpa
This commit is contained in:
parent
b96454b786
commit
850b598db1
@ -10,7 +10,7 @@ model:
|
|||||||
trust_remote_code: true
|
trust_remote_code: true
|
||||||
torch_dtype: auto
|
torch_dtype: auto
|
||||||
use_flash_attention: true
|
use_flash_attention: true
|
||||||
attn_implementation: flash_attention_2
|
attn_implementation: sdpa
|
||||||
|
|
||||||
# LoRA settings (disabled by default)
|
# LoRA settings (disabled by default)
|
||||||
use_lora: false
|
use_lora: false
|
||||||
|
|||||||
@ -8,6 +8,7 @@ import logging
|
|||||||
from transformers import (
|
from transformers import (
|
||||||
AutoProcessor,
|
AutoProcessor,
|
||||||
Qwen2VLForConditionalGeneration,
|
Qwen2VLForConditionalGeneration,
|
||||||
|
Qwen2_5_VLForConditionalGeneration,
|
||||||
Trainer,
|
Trainer,
|
||||||
TrainingArguments,
|
TrainingArguments,
|
||||||
EarlyStoppingCallback
|
EarlyStoppingCallback
|
||||||
@ -92,6 +93,15 @@ def main():
|
|||||||
|
|
||||||
# Load model
|
# Load model
|
||||||
logger.info(f"Loading model: {config.model.name}")
|
logger.info(f"Loading model: {config.model.name}")
|
||||||
|
if "Qwen2.5-VL" in config.model.name:
|
||||||
|
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
||||||
|
config.model.name,
|
||||||
|
torch_dtype=getattr(torch, config.model.torch_dtype) if config.model.torch_dtype != "auto" else "auto",
|
||||||
|
device_map=config.model.device_map,
|
||||||
|
trust_remote_code=config.model.trust_remote_code,
|
||||||
|
attn_implementation=config.model.attn_implementation if config.model.use_flash_attention else None,
|
||||||
|
)
|
||||||
|
elif "Qwen2-VL" in config.model.name:
|
||||||
model = Qwen2VLForConditionalGeneration.from_pretrained(
|
model = Qwen2VLForConditionalGeneration.from_pretrained(
|
||||||
config.model.name,
|
config.model.name,
|
||||||
torch_dtype=getattr(torch, config.model.torch_dtype) if config.model.torch_dtype != "auto" else "auto",
|
torch_dtype=getattr(torch, config.model.torch_dtype) if config.model.torch_dtype != "auto" else "auto",
|
||||||
@ -99,6 +109,8 @@ def main():
|
|||||||
trust_remote_code=config.model.trust_remote_code,
|
trust_remote_code=config.model.trust_remote_code,
|
||||||
attn_implementation=config.model.attn_implementation if config.model.use_flash_attention else None,
|
attn_implementation=config.model.attn_implementation if config.model.use_flash_attention else None,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
# Enable gradient checkpointing if configured
|
# Enable gradient checkpointing if configured
|
||||||
if config.training.gradient_checkpointing:
|
if config.training.gradient_checkpointing:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user