mirror of https://github.com/allenai/olmocr.git (synced 2025-11-03 19:45:41 +00:00)

Sdpa

commit 850b598db1
parent b96454b786
@@ -10,7 +10,7 @@ model:
   trust_remote_code: true
   torch_dtype: auto
   use_flash_attention: true
-  attn_implementation: flash_attention_2
+  attn_implementation: sdpa
 
   # LoRA settings (disabled by default)
   use_lora: false
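The config value above is passed straight through to Hugging Face Transformers' from_pretrained() (see the training-script hunk below). As a minimal sketch of what the new sdpa value selects, assuming an illustrative model id that is not taken from this commit: sdpa picks PyTorch's built-in scaled dot-product attention kernel, so no flash-attn package needs to be installed.

# Minimal sketch (not from this repo; the model id is illustrative) of how
# Transformers consumes the attn_implementation value from the config.
import torch
from transformers import Qwen2VLForConditionalGeneration

model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-7B-Instruct",
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",  # PyTorch scaled dot-product attention
)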
@@ -8,6 +8,7 @@ import logging
 from transformers import (
     AutoProcessor,
     Qwen2VLForConditionalGeneration,
+    Qwen2_5_VLForConditionalGeneration,
     Trainer,
     TrainingArguments,
     EarlyStoppingCallback

@@ -92,13 +93,24 @@ def main():
 
     # Load model
     logger.info(f"Loading model: {config.model.name}")
-    model = Qwen2VLForConditionalGeneration.from_pretrained(
-        config.model.name,
-        torch_dtype=getattr(torch, config.model.torch_dtype) if config.model.torch_dtype != "auto" else "auto",
-        device_map=config.model.device_map,
-        trust_remote_code=config.model.trust_remote_code,
-        attn_implementation=config.model.attn_implementation if config.model.use_flash_attention else None,
-    )
+    if "Qwen2.5-VL" in config.model.name:
+        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+            config.model.name,
+            torch_dtype=getattr(torch, config.model.torch_dtype) if config.model.torch_dtype != "auto" else "auto",
+            device_map=config.model.device_map,
+            trust_remote_code=config.model.trust_remote_code,
+            attn_implementation=config.model.attn_implementation if config.model.use_flash_attention else None,
+        )
+    elif "Qwen2-VL" in config.model.name:
+        model = Qwen2VLForConditionalGeneration.from_pretrained(
+            config.model.name,
+            torch_dtype=getattr(torch, config.model.torch_dtype) if config.model.torch_dtype != "auto" else "auto",
+            device_map=config.model.device_map,
+            trust_remote_code=config.model.trust_remote_code,
+            attn_implementation=config.model.attn_implementation if config.model.use_flash_attention else None,
+        )
+    else:
+        raise NotImplementedError()
 
     # Enable gradient checkpointing if configured
     if config.training.gradient_checkpointing:
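The new branching repeats the same from_pretrained() keyword arguments for each model class, and the use_flash_attention flag now gates the generic attn_implementation value (including sdpa). A hedged sketch of one way to express the same dispatch without the duplication; load_model and MODEL_CLASSES are my own names, not the commit's, and config is assumed to have the same shape as in the script above:

import torch
from transformers import (
    Qwen2VLForConditionalGeneration,
    Qwen2_5_VLForConditionalGeneration,
)

# Map a substring of the model name to its class; "Qwen2.5-VL" is listed
# first so the more specific name is tried before "Qwen2-VL".
MODEL_CLASSES = {
    "Qwen2.5-VL": Qwen2_5_VLForConditionalGeneration,
    "Qwen2-VL": Qwen2VLForConditionalGeneration,
}

def load_model(config):
    for key, cls in MODEL_CLASSES.items():
        if key in config.model.name:
            return cls.from_pretrained(
                config.model.name,
                torch_dtype=getattr(torch, config.model.torch_dtype)
                if config.model.torch_dtype != "auto"
                else "auto",
                device_map=config.model.device_map,
                trust_remote_code=config.model.trust_remote_code,
                attn_implementation=config.model.attn_implementation
                if config.model.use_flash_attention
                else None,
            )
    raise NotImplementedError(f"No model class registered for {config.model.name}")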