2022-12-27 08:50:55 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								# original source:  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								#   https://github.com/AminRezaei0x443/memory-efficient-attention/blob/1bc0d9e6ac5f82ea43a375135c4e1d3896ee1694/memory_efficient_attention/attention_torch.py  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								# license:  
						 
					
						
							
								
									
										
										
										
											2023-01-06 16:42:47 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								#   MIT License (see Memory Efficient Attention under the Licenses section in the web UI interface for the full license)  
						 
					
						
							
								
									
										
										
										
											2022-12-27 08:50:55 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								# credit:  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								#   Amin Rezaei (original author)  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								#   Alex Birch (optimized algorithm for 3D tensors, at the expense of removing bias, masking and callbacks)  
						 
					
						
							
								
									
										
										
										
											2023-01-05 04:37:17 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								#   brkirch (modified to use torch.narrow instead of dynamic_slice implementation)  
						 
					
						
							
								
									
										
										
										
											2022-12-27 08:50:55 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								# implementation of:  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								#   "Self-attention Does Not Need O(n^2) Memory":
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								#   https://arxiv.org/abs/2112.05682v2  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								from  functools  import  partial  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								import  torch  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								from  torch  import  Tensor  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								from  torch . utils . checkpoint  import  checkpoint  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								import  math  
						 
					
						
							
								
									
										
										
										
											2023-01-07 13:08:21 -07:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								from  typing  import  Optional ,  NamedTuple ,  List  
						 
					
						
							
								
									
										
										
										
											2022-12-27 08:50:55 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-09 20:08:48 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-05 04:37:17 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								def narrow_trunc(
								    input: Tensor,
								    dim: int,
								    start: int,
								    length: int
								) -> Tensor:
								    """Narrow `input` along `dim`, truncating the window at the end of the tensor.

								    Behaves like `torch.narrow(input, dim, start, length)`, except that when
								    fewer than `length` elements remain after `start`, the window is shortened
								    to end at the last element instead of raising.

								    Args:
								        input: tensor to narrow (the name shadows the builtin, but it is part
								            of the signature and is therefore kept).
								        dim: dimension along which to narrow.
								        start: index of the first element of the window.
								        length: requested window length.

								    Returns:
								        The result of `torch.narrow` with the (possibly shortened) length.
								    """
								    # Number of elements available from `start` to the end of `dim`.
								    remaining = input.shape[dim] - start
								    # min() keeps `length` when the full window fits and falls back to
								    # `remaining` only when the window would overrun the tensor.
								    return torch.narrow(input, dim, start, min(length, remaining))
							 
						 
					
						
							
								
									
										
										
										
											2022-12-27 08:50:55 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-09 20:08:48 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-12-27 08:50:55 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								class AttnChunk(NamedTuple):
								    """Partial attention result for a single key/value chunk.

								    Instances are built by `_summarize_chunk` and combined across chunks in
								    `_query_chunk_attention`.
								    """
								    # exp(scores - max_score) multiplied (bmm) with the value chunk
								    exp_values: Tensor
								    # exp(scores - max_score) summed over the last (key) dimension
								    exp_weights_sum: Tensor
								    # per-query maximum score of this chunk, used as the shift before exp
								    max_score: Tensor
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-09 20:08:48 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								class SummarizeChunk:
								    """Call signature of a per-chunk summarizer: (query, key, value) -> AttnChunk.

								    Used as the annotation of the `summarize_chunk` parameter of
								    `_query_chunk_attention`; the stub itself does nothing.
								    """

								    @staticmethod
								    def __call__(
								        query: Tensor,
								        key: Tensor,
								        value: Tensor,
								    ) -> AttnChunk: ...
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-09 20:08:48 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								class ComputeQueryChunkAttn:
								    """Call signature of a per-query-chunk attention function.

								    Describes a callable taking (query, key, value) and returning a Tensor;
								    the stub itself does nothing.
								    NOTE(review): its users are outside the visible part of this file.
								    """

								    @staticmethod
								    def __call__(
								        query: Tensor,
								        key: Tensor,
								        value: Tensor,
								    ) -> Tensor: ...
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-09 20:08:48 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-12-27 08:50:55 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								def _summarize_chunk(
								    query: Tensor,
								    key: Tensor,
								    value: Tensor,
								    scale: float,
								) -> AttnChunk:
								    """Compute the un-normalized softmax pieces for one key/value chunk.

								    Args:
								        query: query tensor; batched-matmul'd against `key` transposed on
								            dims 1 and 2, so all inputs are 3-D.
								        key: key chunk.
								        value: value chunk.
								        scale: factor applied to the query-key product (baddbmm `alpha`).

								    Returns:
								        AttnChunk of
								        - exp_values: exp(scores - max) @ value,
								        - exp_weights_sum: exp(scores - max) summed over the last dim,
								        - max_score: the per-row maximum score (last dim squeezed).
								    """
								    # beta=0 makes baddbmm ignore the `input` operand, so an uninitialized
								    # 1x1x1 placeholder is enough; the result is scale * (query @ key^T).
								    attn_weights = torch.baddbmm(
								        torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
								        query,
								        key.transpose(1, 2),
								        alpha=scale,
								        beta=0,
								    )
								    max_score, _ = torch.max(attn_weights, -1, keepdim=True)
								    # The maximum is only a shift subtracted before exp; keep it out of autograd.
								    max_score = max_score.detach()
								    exp_weights = torch.exp(attn_weights - max_score)
								    if query.device.type == 'mps':
								        # On MPS the product is taken directly, without dtype casts.
								        exp_values = torch.bmm(exp_weights, value)
								    else:
								        # Multiply in the weights' dtype, then cast back to the value dtype.
								        exp_values = torch.bmm(exp_weights, value.to(exp_weights.dtype)).to(value.dtype)
								    max_score = max_score.squeeze(-1)
								    return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score)
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-09 20:08:48 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-12-27 08:50:55 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								def _query_chunk_attention(
								    query: Tensor,
								    key: Tensor,
								    value: Tensor,
								    summarize_chunk: SummarizeChunk,
								    kv_chunk_size: int,
								) -> Tensor:
								    """Attend `query` to `key`/`value` by processing key/value in chunks.

								    The key/value token dimension (dim 1) is split into windows of
								    `kv_chunk_size` (the last one may be shorter). Each window is summarized
								    by `summarize_chunk`, and the per-chunk results are rescaled to a common
								    maximum and combined into the final normalized output.

								    Args:
								        query: query tensor, passed unchanged to `summarize_chunk`.
								        key: 3-D key tensor; dim 1 is the token dimension that is chunked.
								        value: value tensor, chunked along dim 1 in step with `key`.
								        summarize_chunk: callable (query, key_chunk, value_chunk) -> AttnChunk.
								        kv_chunk_size: number of key/value tokens per chunk.

								    Returns:
								        Sum of rescaled chunk values divided by the sum of rescaled chunk weights.
								    """
								    _, k_tokens, _ = key.shape

								    def chunk_scanner(chunk_idx: int) -> AttnChunk:
								        # chunk_idx is the starting token offset of the chunk, not its ordinal.
								        key_chunk = narrow_trunc(key, 1, chunk_idx, kv_chunk_size)
								        value_chunk = narrow_trunc(value, 1, chunk_idx, kv_chunk_size)
								        return summarize_chunk(query, key_chunk, value_chunk)

								    # Plain Python ints (range) match the `chunk_idx: int` annotation and avoid
								    # handing 0-dim tensors to torch.narrow.
								    chunks: List[AttnChunk] = [
								        chunk_scanner(chunk_start) for chunk_start in range(0, k_tokens, kv_chunk_size)
								    ]
								    # Stack each AttnChunk field across chunks along a new leading dimension.
								    acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks)))
								    chunk_values, chunk_weights, chunk_max = acc_chunk

								    # Each chunk was shifted by its own maximum; rescale all of them to the
								    # maximum over chunks so they can be summed consistently.
								    global_max, _ = torch.max(chunk_max, 0, keepdim=True)
								    max_diffs = torch.exp(chunk_max - global_max)
								    chunk_values *= torch.unsqueeze(max_diffs, -1)
								    chunk_weights *= max_diffs

								    all_values = chunk_values.sum(dim=0)
								    all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0)
								    return all_values / all_weights
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-09 20:08:48 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-12-27 08:50:55 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								# TODO: refactor CrossAttention#get_attention_scores to share code with this  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								def  _get_attention_scores_no_kv_chunking (  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    query :  Tensor , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    key :  Tensor , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    value :  Tensor , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    scale :  float , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								)  - >  Tensor :  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    attn_scores  =  torch . baddbmm ( 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        torch . empty ( 1 ,  1 ,  1 ,  device = query . device ,  dtype = query . dtype ) , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        query , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        key . transpose ( 1 , 2 ) , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        alpha = scale , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        beta = 0 , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    attn_probs  =  attn_scores . softmax ( dim = - 1 ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    del  attn_scores 
							 
						 
					
						
							
								
									
										
										
										
											2023-01-25 00:23:10 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    hidden_states_slice  =  torch . bmm ( attn_probs ,  value )  if  query . device . type  ==  ' mps '  else  torch . bmm ( attn_probs ,  value . to ( attn_probs . dtype ) ) . to ( value . dtype ) 
							 
						 
					
						
							
								
									
										
										
										
											2022-12-27 08:50:55 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    return  hidden_states_slice 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-09 20:08:48 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-12-27 08:50:55 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								class ScannedChunk(NamedTuple):
								    """Pairs a chunk index with the AttnChunk computed for it.

								    NOTE(review): no use of this type appears in the visible part of the file.
								    """
								    # index identifying the chunk -- presumably its start offset; TODO confirm
								    chunk_idx: int
								    # partial attention result for that chunk
								    attn_chunk: AttnChunk
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2023-01-09 20:08:48 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-12-27 08:50:55 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								def  efficient_dot_product_attention (  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    query :  Tensor , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    key :  Tensor , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    value :  Tensor , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    query_chunk_size = 1024 , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    kv_chunk_size :  Optional [ int ]  =  None , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    kv_chunk_size_min :  Optional [ int ]  =  None , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    use_checkpoint = True , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								) :  
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    """ Computes efficient dot-product attention given query, key, and value. 
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      This  is  efficient  version  of  attention  presented  in 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      https : / / arxiv . org / abs / 2112.05682 v2  which  comes  with  O ( sqrt ( n ) )  memory  requirements . 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      Args : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        query :  queries  for  calculating  attention  with  shape  of 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          ` [ batch  *  num_heads ,  tokens ,  channels_per_head ] ` . 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        key :  keys  for  calculating  attention  with  shape  of 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          ` [ batch  *  num_heads ,  tokens ,  channels_per_head ] ` . 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        value :  values  to  be  used  in  attention  with  shape  of 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          ` [ batch  *  num_heads ,  tokens ,  channels_per_head ] ` . 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        query_chunk_size :  int :  query  chunks  size 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        kv_chunk_size :  Optional [ int ] :  key / value  chunks  size .  if  None :  defaults  to  sqrt ( key_tokens ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        kv_chunk_size_min :  Optional [ int ] :  key / value  minimum  chunk  size .  only  considered  when  kv_chunk_size  is  None .  changes  ` sqrt ( key_tokens ) `  into  ` max ( sqrt ( key_tokens ) ,  kv_chunk_size_min ) ` ,  to  ensure  our  chunk  sizes  don ' t get too small (smaller chunks = more chunks = less concurrent work done). 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        use_checkpoint :  bool :  whether  to  use  checkpointing  ( recommended  True  for  training ,  False  for  inference ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      Returns : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        Output  of  shape  ` [ batch  *  num_heads ,  query_tokens ,  channels_per_head ] ` . 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      """ 
 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    batch_x_heads ,  q_tokens ,  q_channels_per_head  =  query . shape 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    _ ,  k_tokens ,  _  =  key . shape 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    scale  =  q_channels_per_head  * *  - 0.5 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    kv_chunk_size  =  min ( kv_chunk_size  or  int ( math . sqrt ( k_tokens ) ) ,  k_tokens ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    if  kv_chunk_size_min  is  not  None : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        kv_chunk_size  =  max ( kv_chunk_size ,  kv_chunk_size_min ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    def get_query_chunk(chunk_idx: int) -> Tensor:
								        """Return one slice of `query` taken along dimension 1.

								        Despite its name, `chunk_idx` is used as the starting offset along
								        dimension 1 (the caller passes `i * query_chunk_size`), not as a
								        chunk ordinal.  The requested length is capped at `q_tokens`.

								        NOTE(review): `narrow_trunc` is defined elsewhere in this file; it
								        presumably shortens the final slice when fewer than the requested
								        number of tokens remain -- confirm against its definition.
								        """
								        # Never ask for more tokens than the query actually has.
								        chunk_len = min(query_chunk_size, q_tokens)
								        return narrow_trunc(query, 1, chunk_idx, chunk_len)
							 
						 
					
						
							
								
									
										
										
										
											2023-05-11 18:28:15 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-12-27 08:50:55 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    summarize_chunk :  SummarizeChunk  =  partial ( _summarize_chunk ,  scale = scale ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    summarize_chunk :  SummarizeChunk  =  partial ( checkpoint ,  summarize_chunk )  if  use_checkpoint  else  summarize_chunk 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    compute_query_chunk_attn :  ComputeQueryChunkAttn  =  partial ( 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        _get_attention_scores_no_kv_chunking , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        scale = scale 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    )  if  k_tokens  < =  kv_chunk_size  else  ( 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        # fast-path for when there's just 1 key-value chunk per query chunk (this is just sliced attention btw) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        partial ( 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            _query_chunk_attention , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            kv_chunk_size = kv_chunk_size , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            summarize_chunk = summarize_chunk , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    if  q_tokens  < =  query_chunk_size : 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        # fast-path for when there's just 1 query chunk 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        return  compute_query_chunk_attn ( 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            query = query , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            key = key , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            value = value , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        ) 
							 
						 
					
						
							
								
									
										
										
										
											2023-05-11 07:45:05 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    res  =  torch . zeros_like ( query ) 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    for  i  in  range ( math . ceil ( q_tokens  /  query_chunk_size ) ) : 
							 
						 
					
						
							
								
									
										
										
										
											2023-05-10 22:05:18 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        attn_scores  =  compute_query_chunk_attn ( 
							 
						 
					
						
							
								
									
										
										
										
											2022-12-27 08:50:55 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								            query = get_query_chunk ( i  *  query_chunk_size ) , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            key = key , 
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            value = value , 
							 
						 
					
						
							
								
									
										
										
										
											2023-05-10 22:05:18 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								        ) 
							 
						 
					
						
							
								
									
										
										
										
											2023-05-11 07:45:05 +03:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        res [ : ,  i  *  query_chunk_size : i  *  query_chunk_size  +  attn_scores . shape [ 1 ] ,  : ]  =  attn_scores 
							 
						 
					
						
							
								
									
										
										
										
											2023-05-10 22:05:18 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								
							 
						 
					
						
							
								
									
										
										
										
											2022-12-27 08:50:55 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    return  res