mirror of
				https://github.com/AUTOMATIC1111/stable-diffusion-webui.git
				synced 2025-10-31 01:54:44 +00:00 
			
		
		
		
	Merge pull request #14791 from AUTOMATIC1111/fix-mha-manual-cast
Fix dtype error in MHA layer/change dtype checking mechanism for manual cast
This commit is contained in:
		
						commit
						ce168ab5db
					
				| @ -4,7 +4,6 @@ from functools import lru_cache | ||||
| 
 | ||||
| import torch | ||||
| from modules import errors, shared | ||||
| from modules import torch_utils | ||||
| 
 | ||||
| if sys.platform == "darwin": | ||||
|     from modules import mac_specific | ||||
| @ -141,7 +140,12 @@ def manual_cast_forward(target_dtype): | ||||
|             args = [arg.to(target_dtype) if isinstance(arg, torch.Tensor) else arg for arg in args] | ||||
|             kwargs = {k: v.to(target_dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()} | ||||
| 
 | ||||
|         org_dtype = torch_utils.get_param(self).dtype | ||||
|         org_dtype = target_dtype | ||||
|         for param in self.parameters(): | ||||
|             if param.dtype != target_dtype: | ||||
|                 org_dtype = param.dtype | ||||
|                 break | ||||
| 
 | ||||
|         if org_dtype != target_dtype: | ||||
|             self.to(target_dtype) | ||||
|         result = self.org_forward(*args, **kwargs) | ||||
| @ -170,7 +174,7 @@ def manual_cast(target_dtype): | ||||
|             continue | ||||
|         applied = True | ||||
|         org_forward = module_type.forward | ||||
|         if module_type == torch.nn.MultiheadAttention and has_xpu(): | ||||
|         if module_type == torch.nn.MultiheadAttention: | ||||
|             module_type.forward = manual_cast_forward(torch.float32) | ||||
|         else: | ||||
|             module_type.forward = manual_cast_forward(target_dtype) | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 AUTOMATIC1111
						AUTOMATIC1111