from __future__ import annotations
import math
import psutil

import torch
from torch import einsum

from ldm.util import default
from einops import rearrange

from modules import shared, errors, devices, sub_quadratic_attention
from modules.hypernetworks import hypernetwork

import ldm.modules.attention
import ldm.modules.diffusionmodules.model

import sgm.modules.attention
import sgm.modules.diffusionmodules.model

diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
sgm_diffusionmodules_model_AttnBlock_forward = sgm.modules.diffusionmodules.model.AttnBlock.forward

class SdOptimization:
    name: str = None
    label: str | None = None
    cmd_opt: str | None = None
    priority: int = 0

    def title(self):
        if self.label is None:
            return self.name

        return f"{self.name} - {self.label}"

    def is_available(self):
        return True

    def apply(self):
        pass

    def undo(self):
        ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward

        sgm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
        sgm.modules.diffusionmodules.model.AttnBlock.forward = sgm_diffusionmodules_model_AttnBlock_forward

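# The optimizations below all work the same way: apply() monkey-patches
# CrossAttention.forward (and usually AttnBlock.forward) in both the ldm and sgm
# packages, while SdOptimization.undo() above resets CrossAttention.forward to the
# hypernetwork default and AttnBlock.forward to the originals saved at import time.
# The replacement functions referenced by the apply() methods (xformers_attention_forward,
# sdp_attnblock_forward and so on) are expected to be defined further down in this module.
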
class SdOptimizationXformers(SdOptimization):
    name = "xformers"
    cmd_opt = "xformers"
    priority = 100

    def is_available(self):
        return shared.cmd_opts.force_enable_xformers or (shared.xformers_available and torch.cuda.is_available() and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0))

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = xformers_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward
        sgm.modules.attention.CrossAttention.forward = xformers_attention_forward
        sgm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward

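# Note on the xformers check above: --force-enable-xformers skips all verification;
# otherwise the optimization is offered only when the xformers import succeeded
# (shared.xformers_available, set by the guard below list_optimizers) and the CUDA
# device reports a compute capability between 6.0 and 9.0 inclusive.
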
class SdOptimizationSdpNoMem(SdOptimization):
    name = "sdp-no-mem"
    label = "scaled dot product without memory efficient attention"
    cmd_opt = "opt_sdp_no_mem_attention"
    priority = 80

    def is_available(self):
        return hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(torch.nn.functional.scaled_dot_product_attention)

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward
        sgm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward
        sgm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward

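# Minimal sketch (an assumption for illustration, not the forward this class actually
# assigns) of what "scaled dot product without memory efficient attention" amounts to:
# run torch's SDP kernel with the memory-efficient backend disabled, keeping the
# flash/math backends enabled.
def _sdp_no_mem_sketch(q, k, v):
    with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
        return torch.nn.functional.scaled_dot_product_attention(q, k, v)
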
class SdOptimizationSdp(SdOptimizationSdpNoMem):
    name = "sdp"
    label = "scaled dot product"
    cmd_opt = "opt_sdp_attention"
    priority = 70

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward
        sgm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward
        sgm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward

class SdOptimizationSubQuad(SdOptimization):
    name = "sub-quadratic"
    cmd_opt = "opt_sub_quad_attention"
    priority = 10

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = sub_quad_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = sub_quad_attnblock_forward
        sgm.modules.attention.CrossAttention.forward = sub_quad_attention_forward
        sgm.modules.diffusionmodules.model.AttnBlock.forward = sub_quad_attnblock_forward

class SdOptimizationV1(SdOptimization):
    name = "V1"
    label = "original v1"
    cmd_opt = "opt_split_attention_v1"
    priority = 10

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
        sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1

class SdOptimizationInvokeAI(SdOptimization):
    name = "InvokeAI"
    cmd_opt = "opt_split_attention_invokeai"

    @property
    def priority(self):
        return 1000 if not torch.cuda.is_available() else 10

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI
        sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI

class SdOptimizationDoggettx(SdOptimization):
    name = "Doggettx"
    cmd_opt = "opt_split_attention"
    priority = 90

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward
        sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward
        sgm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward

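# The priority values declared above appear to encode a default preference order when
# several optimizations are usable: xformers (100) > Doggettx (90) > sdp-no-mem (80) >
# sdp (70) > sub-quadratic / V1 (10), with InvokeAI jumping to 1000 when no CUDA device
# is present.
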
def list_optimizers(res):
    res.extend([
        SdOptimizationXformers(),
        SdOptimizationSdpNoMem(),
        SdOptimizationSdp(),
        SdOptimizationSubQuad(),
        SdOptimizationV1(),
        SdOptimizationInvokeAI(),
        SdOptimizationDoggettx(),
    ])

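# Illustrative helper (an assumption about how callers might consume list_optimizers;
# not code from the original module): gather the optimizers, keep those whose
# is_available() check passes, and apply the one with the highest priority.
def _example_pick_and_apply_optimizer():
    candidates = []
    list_optimizers(candidates)
    available = [opt for opt in candidates if opt.is_available()]
    if not available:
        return None
    best = max(available, key=lambda opt: opt.priority)
    best.apply()
    return best
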
if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
    try:
        import xformers.ops
        shared.xformers_available = True
    except Exception:
        errors.report("Cannot import xformers", exc_info=True)

def get_available_vram():
    if shared.device.type == 'cuda':
        stats = torch.cuda.memory_stats(shared.device)
        mem_active = stats['active_bytes.all.current']
        mem_reserved = stats['reserved_bytes.all.current']
        mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
        mem_free_torch = mem_reserved - mem_active
        mem_free_total = mem_free_cuda + mem_free_torch
        return mem_free_total
    else:
        return psutil.virtual_memory().available

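# get_available_vram() above estimates usable memory as the free memory reported by the
# CUDA driver plus memory torch has reserved but is not actively using. Worked example
# with assumed numbers for the Doggettx path further below: if the attention scores would
# need roughly 12 GiB while only 4 GiB are available, the query is split into
# 2 ** ceil(log2(12 / 4)) = 4 slices so each slice's score tensor fits.
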
# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
def split_cross_attention_forward_v1(self, x, context=None, mask=None, **kwargs):
    h = self.heads

    q_in = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    k_in = self.to_k(context_k)
    v_in = self.to_v(context_v)
    del context, context_k, context_v, x

    q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q_in, k_in, v_in))
    del q_in, k_in, v_in

    dtype = q.dtype
    if shared.opts.upcast_attn:
        q, k, v = q.float(), k.float(), v.float()

    with devices.without_autocast(disable=not shared.opts.upcast_attn):
        r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
        for i in range(0, q.shape[0], 2):
            end = i + 2
            s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
            s1 *= self.scale

            s2 = s1.softmax(dim=-1)
            del s1

            r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end])
            del s2
        del q, k, v

    r1 = r1.to(dtype)

    r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
    del r1

    return self.to_out(r2)

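# Reference form of the computation above (illustrative only; tensor names and shapes are
# assumptions): ignoring the two-heads-at-a-time slicing, the einsum pair plus softmax is
# ordinary scaled dot-product attention over (batch * heads, tokens, dim_head) tensors.
def _unsliced_attention_reference(q, k, v, scale):
    scores = torch.einsum('b i d, b j d -> b i j', q, k) * scale
    return torch.einsum('b i j, b j d -> b i d', scores.softmax(dim=-1), v)
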
								# taken from https://github.com/Doggettx/stable-diffusion and modified  
						 
					
						
							
								
									
										
										
										
											2023-07-13 09:30:33 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								def  split_cross_attention_forward ( self ,  x ,  context = None ,  mask = None ,  * * kwargs ) :  
						 
					
						
							
								
									
										
										
										
											2022-10-02 15:03:39 +03:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    h  =  self . heads 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    q_in  =  self . to_q ( x ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    context  =  default ( context ,  x ) 
							 
						 
					
						
							
								
									
										
										
										
											2022-10-07 10:17:52 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-21 08:36:07 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    context_k ,  context_v  =  hypernetwork . apply_hypernetworks ( shared . loaded_hypernetworks ,  context ) 
							 
						 
					
						
							
								
									
										
										
										
											2022-10-11 11:09:51 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    k_in  =  self . to_k ( context_k ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    v_in  =  self . to_v ( context_v ) 
							 
						 
					
						
							
								
									
										
										
										
											2022-10-07 10:17:52 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-25 00:23:10 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    dtype  =  q_in . dtype 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    if  shared . opts . upcast_attn : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        q_in ,  k_in ,  v_in  =  q_in . float ( ) ,  k_in . float ( ) ,  v_in  if  v_in . device . type  ==  ' mps '  else  v_in . float ( ) 
							 
						 
					
						
							
								
									
										
										
										
											2022-10-02 15:03:39 +03:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-25 00:23:10 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    with  devices . without_autocast ( disable = not  shared . opts . upcast_attn ) : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        k_in  =  k_in  *  self . scale 
							 
						 
					
						
							
								
									
										
										
										
											2023-05-11 18:28:15 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-25 00:23:10 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        del  context ,  x 
							 
						 
					
						
							
								
									
										
										
										
											2023-05-11 18:28:15 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-05-10 11:05:02 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        q ,  k ,  v  =  ( rearrange ( t ,  ' b n (h d) -> (b h) n d ' ,  h = h )  for  t  in  ( q_in ,  k_in ,  v_in ) ) 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-25 00:23:10 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        del  q_in ,  k_in ,  v_in 
							 
						 
					
						
							
								
									
										
										
										
											2023-05-11 18:28:15 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-25 00:23:10 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        r1  =  torch . zeros ( q . shape [ 0 ] ,  q . shape [ 1 ] ,  v . shape [ 2 ] ,  device = q . device ,  dtype = q . dtype ) 
							 
						 
					
						
							
								
									
										
										
										
											2023-05-11 18:28:15 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-25 00:23:10 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        mem_free_total  =  get_available_vram ( ) 
							 
						 
					
						
							
								
									
										
										
										
											2023-05-11 18:28:15 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-25 00:23:10 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        gb  =  1024  * *  3 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        tensor_size  =  q . shape [ 0 ]  *  q . shape [ 1 ]  *  k . shape [ 1 ]  *  q . element_size ( ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        modifier  =  3  if  q . element_size ( )  ==  2  else  2.5 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        mem_required  =  tensor_size  *  modifier 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        steps  =  1 
							 
						 
					
						
							
								
									
										
										
										
											2023-05-11 18:28:15 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-25 00:23:10 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        if  mem_required  >  mem_free_total : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            steps  =  2  * *  ( math . ceil ( math . log ( mem_required  /  mem_free_total ,  2 ) ) ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB " 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            #       f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}") 

        if steps > 64:
            max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
            raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
                               f'Need: {mem_required / 64 / gb:0.1f}GB free, Have: {mem_free_total / gb:0.1f}GB free')
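
        # 64 is the deepest slicing attempted; past that even a 1/64 slice of
        # the projected memory (mem_required / 64) does not fit, so failing
        # early with a resolution hint beats thrashing.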

        slice_size = q.shape[1] // steps
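        # The loop below touches slice_size query rows at a time, so only a
        # (batch*heads, slice_size, k_tokens) score matrix is live at once
        # instead of the full quadratic attention matrix.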
        for i in range(0, q.shape[1], slice_size):
            end = min(i + slice_size, q.shape[1])
            s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k)

            s2 = s1.softmax(dim=-1, dtype=q.dtype)
            del s1

            r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
            del s2

        del q, k, v

    r1 = r1.to(dtype)

    r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
    del r1

    return self.to_out(r2)


# -- Taken from https://github.com/invoke-ai/InvokeAI and modified --

mem_total_gb = psutil.virtual_memory().total // (1 << 30)


def einsum_op_compvis(q, k, v):
    s = einsum('b i d, b j d -> b i j', q, k)
    s = s.softmax(dim=-1, dtype=s.dtype)
    return einsum('b i j, b j d -> b i d', s, v)
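
# For reference: q above is (batch*heads, q_tokens, head_dim) and k/v are
# (batch*heads, k_tokens, head_dim); the intermediate s is the full
# (batch*heads, q_tokens, k_tokens) attention matrix, which is exactly what
# the slicing helpers below avoid materializing in one piece.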


def einsum_op_slice_0(q, k, v, slice_size):
    r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
    for i in range(0, q.shape[0], slice_size):
        end = i + slice_size
        r[i:end] = einsum_op_compvis(q[i:end], k[i:end], v[i:end])
    return r


def einsum_op_slice_1(q, k, v, slice_size):
    r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
    for i in range(0, q.shape[1], slice_size):
        end = i + slice_size
        r[:, i:end] = einsum_op_compvis(q[:, i:end], k, v)
    return r
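
# einsum_op_slice_0 slices along the batch*heads dimension and
# einsum_op_slice_1 along the query-token dimension; both trade one large
# matmul for several smaller ones at the same total FLOP count.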


def einsum_op_mps_v1(q, k, v):
    if q.shape[0] * q.shape[1] <= 2**16:  # (512x512) max q.shape[1]: 4096
        return einsum_op_compvis(q, k, v)
    else:
        slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
        if slice_size % 4096 == 0:
            slice_size -= 1
        return einsum_op_slice_1(q, k, v, slice_size)
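
# The 2**16 threshold works out exactly for 512x512 (illustrative): the 64x64
# latent gives q.shape[1] = 4096, and with 8 heads and the usual cond+uncond
# batch of 2 the folded batch dimension is 16, so 16 * 4096 = 65536 = 2**16.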


def einsum_op_mps_v2(q, k, v):
    if mem_total_gb > 8 and q.shape[0] * q.shape[1] <= 2**16:
        return einsum_op_compvis(q, k, v)
    else:
        return einsum_op_slice_0(q, k, v, 1)


def einsum_op_tensor_mem(q, k, v, max_tensor_mb):
    size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20)
    if size_mb <= max_tensor_mb:
        return einsum_op_compvis(q, k, v)
    div = 1 << int((size_mb - 1) / max_tensor_mb).bit_length()
    if div <= q.shape[0]:
        return einsum_op_slice_0(q, k, v, q.shape[0] // div)
    return einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1))
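
# Worked example (illustrative): a projected size_mb of 512 with
# max_tensor_mb = 32 gives int(511 / 32) = 15, whose bit_length is 4, so
# div = 16 and each slice's score matrix stays around 32 MB; slicing along
# dim 0 is preferred whenever it is large enough to split that many ways.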


def einsum_op_cuda(q, k, v):
    stats = torch.cuda.memory_stats(q.device)
    mem_active = stats['active_bytes.all.current']
    mem_reserved = stats['reserved_bytes.all.current']
    mem_free_cuda, _ = torch.cuda.mem_get_info(q.device)
    mem_free_torch = mem_reserved - mem_active
    mem_free_total = mem_free_cuda + mem_free_torch
    # Divide by a safety factor, as there's copying and fragmentation
    return einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20))
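
# Free memory here is driver-free VRAM plus memory that PyTorch's caching
# allocator has reserved but is not actively using, since the allocator can
# hand the latter back out without new driver allocations.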


def einsum_op(q, k, v):
    if q.device.type == 'cuda':
        return einsum_op_cuda(q, k, v)

    if q.device.type == 'mps':
        if mem_total_gb >= 32 and q.shape[0] % 32 != 0 and q.shape[0] * q.shape[1] < 2**18:
            return einsum_op_mps_v1(q, k, v)
        return einsum_op_mps_v2(q, k, v)

    # Smaller slices are faster due to L2/L3/SLC caches.
    # Tested on i7 with 8MB L3 cache.
    return einsum_op_tensor_mem(q, k, v, 32)
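
# On CPU the 32 MB cap is a cache-friendliness choice rather than a capacity
# limit: slices that fit in the last-level cache avoid repeatedly streaming
# the score matrix through main memory.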


def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None, **kwargs):
    h = self.heads

    q = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    k = self.to_k(context_k)
    v = self.to_v(context_v)
    del context, context_k, context_v, x

    dtype = q.dtype
    if shared.opts.upcast_attn:
        q, k, v = q.float(), k.float(), v if v.device.type == 'mps' else v.float()

    with devices.without_autocast(disable=not shared.opts.upcast_attn):
        k = k * self.scale

        q, k, v = (rearrange(t, 'b n (h d) -> (b h) n d', h=h) for t in (q, k, v))
        r = einsum_op(q, k, v)
    r = r.to(dtype)

    return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))
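
# Flow, for reference: project x to q/k/v, fold the heads into the batch
# dimension, dispatch to the device-appropriate einsum_op, then unfold the
# heads and apply the output projection.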


# -- End of code from https://github.com/invoke-ai/InvokeAI --


# Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1
# The sub_quad_attention_forward function is under the MIT License listed under Memory Efficient Attention in the Licenses section of the web UI interface
def sub_quad_attention_forward(self, x, context=None, mask=None, **kwargs):
    assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor."

    h = self.heads

    q = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    k = self.to_k(context_k)
    v = self.to_v(context_v)
    del context, context_k, context_v, x

    q = q.unflatten(-1, (h, -1)).transpose(1, 2).flatten(end_dim=1)
    k = k.unflatten(-1, (h, -1)).transpose(1, 2).flatten(end_dim=1)
    v = v.unflatten(-1, (h, -1)).transpose(1, 2).flatten(end_dim=1)

    if q.device.type == 'mps':
        q, k, v = q.contiguous(), k.contiguous(), v.contiguous()

    dtype = q.dtype
    if shared.opts.upcast_attn:
        q, k = q.float(), k.float()

    x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)

    x = x.to(dtype)

    x = x.unflatten(0, (-1, h)).transpose(1, 2).flatten(start_dim=2)

    out_proj, dropout = self.to_out
    x = out_proj(x)
    x = dropout(x)

    return x
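
# The unflatten/transpose/flatten chain above is equivalent to
# rearrange(t, 'b n (h d) -> (b h) n d', h=h) and back, just written without
# the einops dependency.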


def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold=None, use_checkpoint=True):
    bytes_per_token = torch.finfo(q.dtype).bits // 8
    batch_x_heads, q_tokens, _ = q.shape
    _, k_tokens, _ = k.shape
    qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens

    if chunk_threshold is None:
        chunk_threshold_bytes = int(get_available_vram() * 0.9) if q.device.type == 'mps' else int(get_available_vram() * 0.7)
    elif chunk_threshold == 0:
        chunk_threshold_bytes = None
    else:
        chunk_threshold_bytes = int(0.01 * chunk_threshold * get_available_vram())

    if kv_chunk_size_min is None and chunk_threshold_bytes is not None:
        kv_chunk_size_min = chunk_threshold_bytes // (batch_x_heads * bytes_per_token * (k.shape[2] + v.shape[2]))
    elif kv_chunk_size_min == 0:
        kv_chunk_size_min = None
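
    # Above, chunk_threshold is read as a percentage of available VRAM:
    # None picks a default headroom (90% on mps, 70% elsewhere), 0 disables
    # the threshold, and any other value is taken literally as a percent.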

    if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes:
        # the big matmul fits into our memory limit; do everything in 1 chunk,
        # i.e. send it down the unchunked fast-path
        kv_chunk_size = k_tokens

    with devices.without_autocast(disable=q.dtype == v.dtype):
        return sub_quadratic_attention.efficient_dot_product_attention(
            q,
            k,
            v,
            query_chunk_size=q_chunk_size,
            kv_chunk_size=kv_chunk_size,
            kv_chunk_size_min=kv_chunk_size_min,
            use_checkpoint=use_checkpoint,
        )
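
# The chunked algorithm follows Rabe & Staats, "Self-attention Does Not Need
# O(n^2) Memory": partial softmax results are computed per chunk pair and
# renormalized as chunks are combined, so the full score matrix is never
# materialized.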


def get_xformers_flash_attention_op(q, k, v):
    if not shared.cmd_opts.xformers_flash_attention:
        return None

    try:
        flash_attention_op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp
        fw, bw = flash_attention_op
        if fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v, attn_bias=None)):
            return flash_attention_op
    except Exception as e:
        errors.display_once(e, "enabling flash attention")

    return None
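
# Returning None lets xformers pick an operator on its own; the explicit
# flash-attention op is used only when requested on the command line and its
# forward op reports support for these exact inputs.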


def xformers_attention_forward(self, x, context=None, mask=None, **kwargs):
    h = self.heads
    q_in = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    k_in = self.to_k(context_k)
    v_in = self.to_v(context_v)

    q, k, v = (rearrange(t, 'b n (h d) -> b n h d', h=h) for t in (q_in, k_in, v_in))
    del q_in, k_in, v_in

    dtype = q.dtype
    if shared.opts.upcast_attn:
        q, k, v = q.float(), k.float(), v.float()

    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=get_xformers_flash_attention_op(q, k, v))

    out = out.to(dtype)

    out = rearrange(out, 'b n h d -> b n (h d)', h=h)
    return self.to_out(out)
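
# Note the layout difference: xformers takes heads as their own axis,
# (batch, tokens, heads, head_dim), rather than folded into the batch
# dimension as in the InvokeAI path above.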


# Based on Diffusers usage of scaled dot product attention from https://github.com/huggingface/diffusers/blob/c7da8fd23359a22d0df2741688b5b4f33c26df21/src/diffusers/models/cross_attention.py
# The scaled_dot_product_attention_forward function contains parts of code under Apache-2.0 license listed under Scaled Dot Product Attention in the Licenses section of the web UI interface
								def  scaled_dot_product_attention_forward ( self ,  x ,  context = None ,  mask = None ,  * * kwargs ) :  
						 
					
						
							
								
									
										
										
										
    batch_size, sequence_length, inner_dim = x.shape

    if mask is not None:
        mask = self.prepare_attention_mask(mask, sequence_length, batch_size)
        mask = mask.view(batch_size, self.heads, -1, mask.shape[-1])

    h = self.heads
    q_in = self.to_q(x)
    context = default(context, x)

    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
    k_in = self.to_k(context_k)
    v_in = self.to_v(context_v)

    head_dim = inner_dim // h
    q = q_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
    k = k_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
    v = v_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
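
    # Shape note: after the view/transpose above, q, k and v are
    # (batch, heads, seq_len, head_dim), the layout that
    # torch.nn.functional.scaled_dot_product_attention expects.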

    del q_in, k_in, v_in

    dtype = q.dtype
    if shared.opts.upcast_attn:
        q, k, v = q.float(), k.float(), v.float()

    # the output of sdp = (batch, num_heads, seq_len, head_dim)
    hidden_states = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
    )

    hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, h * head_dim)
    hidden_states = hidden_states.to(dtype)

    # linear proj
    hidden_states = self.to_out[0](hidden_states)
    # dropout
    hidden_states = self.to_out[1](hidden_states)
    return hidden_states
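

# Same attention as above, but with PyTorch's memory-efficient SDP backend
# disabled: sdp_kernel() only permits the flash and math kernels here.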
def scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None, **kwargs):
    with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
        return scaled_dot_product_attention_forward(self, x, context, mask)
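

# Sliced attention for ldm's AttnBlock: computes softmax(q @ k / sqrt(c)) @ v
# over row slices of q, sized so the intermediate hw x hw attention matrix
# fits in the VRAM reported as free.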
def cross_attention_attnblock_forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q1 = self.q(h_)
        k1 = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q1.shape

        q2 = q1.reshape(b, c, h*w)
        del q1

        q = q2.permute(0, 2, 1)   # b,hw,c
        del q2

        k = k1.reshape(b, c, h*w)  # b,c,hw
        del k1

        h_ = torch.zeros_like(k, device=q.device)

        mem_free_total = get_available_vram()

        tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
        mem_required = tensor_size * 2.5
        steps = 1

        if mem_required > mem_free_total:
            steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))

        slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
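        # steps is rounded up to the next power of two so each slice's attention
        # matrix fits in free VRAM: e.g. mem_required = 3 GiB against
        # mem_free_total = 1 GiB gives steps = 2**ceil(log2(3)) = 4 slices.
        # When q.shape[1] is not evenly divisible by steps, slice_size falls
        # back to the full q.shape[1] and the loop below runs as a single pass.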
        for i in range(0, q.shape[1], slice_size):
            end = i + slice_size

            w1 = torch.bmm(q[:, i:end], k)     # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
            w2 = w1 * (int(c)**(-0.5))
            del w1
            w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
            del w2

            # attend to values
            v1 = v.reshape(b, c, h*w)
            w4 = w3.permute(0, 2, 1)   # b,hw,hw (first hw of k, second of q)
            del w3

            h_[:, :, i:end] = torch.bmm(v1, w4)     # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
            del v1, w4

        h2 = h_.reshape(b, c, h, w)
        del h_

        h3 = self.proj_out(h2)
        del h2

        h3 += x

        return h3
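

# xformers_attnblock_forward falls back to the sliced
# cross_attention_attnblock_forward above whenever xformers raises
# NotImplementedError for the given shapes or dtypes.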
def xformers_attnblock_forward(self, x):
    try:
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)
        b, c, h, w = q.shape
        q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
        dtype = q.dtype
        if shared.opts.upcast_attn:
            q, k = q.float(), k.float()
        q = q.contiguous()
        k = k.contiguous()
        v = v.contiguous()
        out = xformers.ops.memory_efficient_attention(q, k, v, op=get_xformers_flash_attention_op(q, k, v))
        out = out.to(dtype)
        out = rearrange(out, 'b (h w) c -> b c h w', h=h)
        out = self.proj_out(out)
        return x + out
    except NotImplementedError:
        return cross_attention_attnblock_forward(self, x)


def sdp_attnblock_forward(self, x):
    h_ = x
    h_ = self.norm(h_)
    q = self.q(h_)
    k = self.k(h_)
    v = self.v(h_)
    b, c, h, w = q.shape
    q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
    dtype = q.dtype
    if shared.opts.upcast_attn:
        # Upcast v along with q and k: with only q and k upcast,
        # scaled_dot_product_attention fails with "Expected query, key, and
        # value to have the same dtype" when the "Upcast cross attention layer
        # to float32" option is combined with --opt-sdp-attention.
        q, k, v = q.float(), k.float(), v.float()
    q = q.contiguous()
    k = k.contiguous()
    v = v.contiguous()
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
    out = out.to(dtype)
    out = rearrange(out, 'b (h w) c -> b c h w', h=h)
    out = self.proj_out(out)
    return x + out
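

# AttnBlock counterpart of scaled_dot_product_no_mem_attention_forward: the
# same sdp_kernel() restriction to the flash and math backends, delegating to
# sdp_attnblock_forward.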
def sdp_no_mem_attnblock_forward(self, x):
    with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False):
        return sdp_attnblock_forward(self, x)


def sub_quad_attnblock_forward(self, x):
    h_ = x
    h_ = self.norm(h_)
    q = self.q(h_)
    k = self.k(h_)
    v = self.v(h_)
    b, c, h, w = q.shape
    q, k, v = (rearrange(t, 'b c h w -> b (h w) c') for t in (q, k, v))
    q = q.contiguous()
    k = k.contiguous()
    v = v.contiguous()
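    # q/kv chunk sizes and the chunk threshold are taken from the webui's
    # sub_quad_* command line options (via shared.cmd_opts); use_checkpoint
    # follows self.training so checkpointing is only active during training.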
    out = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)
    out = rearrange(out, 'b (h w) c -> b c h w', h=h)
    out = self.proj_out(out)
    return x + out