# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
#   - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch

import torch
import torch.nn as nn


class SelfAttention_v1(nn.Module):

    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value

        attn_scores = queries @ keys.T  # omega
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )

        context_vec = attn_weights @ values
        return context_vec
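

# Usage sketch (added for illustration, not part of the original module): the seed,
# variable names, and toy sizes below are assumptions chosen for a small demo.
if __name__ == "__main__":
    torch.manual_seed(123)
    demo_inputs = torch.rand(6, 3)                 # 6 tokens, d_in=3
    sa_v1 = SelfAttention_v1(d_in=3, d_out=2)
    print(sa_v1(demo_inputs).shape)                # torch.Size([6, 2])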


class SelfAttention_v2(nn.Module):

    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)

        context_vec = attn_weights @ values
        return context_vec
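

# Usage sketch (illustrative, not from the original file): SelfAttention_v2 behaves
# like v1 but draws its weights from nn.Linear layers; the sizes are assumptions.
if __name__ == "__main__":
    torch.manual_seed(123)
    demo_inputs = torch.rand(6, 3)                 # 6 tokens, d_in=3
    sa_v2 = SelfAttention_v2(d_in=3, d_out=2, qkv_bias=False)
    print(sa_v2(demo_inputs).shape)                # torch.Size([6, 2])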


class CausalAttention(nn.Module):

    def __init__(self, d_in, d_out, context_length,
                 dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)  # New
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))  # New

    def forward(self, x):
        b, num_tokens, d_in = x.shape  # New batch dimension b
        # For inputs where `num_tokens` exceeds `context_length`, this will result in errors
        # in the mask creation further below.
        # In practice, this is not a problem since the LLM (chapters 4-7) ensures that inputs
        # do not exceed `context_length` before reaching this forward method.
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        attn_scores = queries @ keys.transpose(1, 2)  # Changed transpose
        # `:num_tokens` accounts for cases where the number of tokens in the batch
        # is smaller than the supported `context_length`
        attn_scores.masked_fill_(  # New, _ ops are in-place
            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )
        attn_weights = self.dropout(attn_weights)  # New

        context_vec = attn_weights @ values
        return context_vec
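

# Usage sketch (illustrative, not from the original file): CausalAttention expects a
# batched input of shape (b, num_tokens, d_in) with num_tokens <= context_length;
# the random batch and sizes below are assumptions for the demo.
if __name__ == "__main__":
    torch.manual_seed(123)
    demo_batch = torch.rand(2, 6, 3)               # (b=2, num_tokens=6, d_in=3)
    ca = CausalAttention(d_in=3, d_out=2, context_length=6, dropout=0.0)
    print(ca(demo_batch).shape)                    # torch.Size([2, 6, 2])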


class MultiHeadAttentionWrapper(nn.Module):

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        self.heads = nn.ModuleList(
            [CausalAttention(d_in, d_out, context_length, dropout, qkv_bias)
             for _ in range(num_heads)]
        )

    def forward(self, x):
        return torch.cat([head(x) for head in self.heads], dim=-1)
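

# Usage sketch (illustrative, not from the original file): each head returns d_out
# features, so the concatenated output has num_heads * d_out features per token;
# the sizes below are assumptions for a small demo.
if __name__ == "__main__":
    torch.manual_seed(123)
    demo_batch = torch.rand(2, 6, 3)               # (b=2, num_tokens=6, d_in=3)
    mha_wrapper = MultiHeadAttentionWrapper(
        d_in=3, d_out=2, context_length=6, dropout=0.0, num_heads=2)
    print(mha_wrapper(demo_batch).shape)           # torch.Size([2, 6, 4])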


class MultiHeadAttention(nn.Module):

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads  # Reduce the projection dim to match desired output dim

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs
        self.dropout = nn.Dropout(dropout)
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        b, num_tokens, d_in = x.shape

        keys = self.W_key(x)  # Shape: (b, num_tokens, d_out)
        queries = self.W_query(x)
        values = self.W_value(x)

        # We implicitly split the matrix by adding a `num_heads` dimension
        # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        # Compute scaled dot-product attention (aka self-attention) with a causal mask
        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head

        # Original mask truncated to the number of tokens and converted to boolean
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

        # Use the mask to fill attention scores
        attn_scores.masked_fill_(mask_bool, -torch.inf)

        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Shape: (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2)

        # Combine heads, where self.d_out = self.num_heads * self.head_dim
        context_vec = context_vec.reshape(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec)  # optional projection

        return context_vec
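

# Usage sketch (illustrative, not from the original file): here d_out is the total
# output dimension that is split across heads (head_dim = d_out // num_heads), so
# d_out must be divisible by num_heads; the sizes below are assumptions.
if __name__ == "__main__":
    torch.manual_seed(123)
    demo_batch = torch.rand(2, 6, 3)               # (b=2, num_tokens=6, d_in=3)
    mha = MultiHeadAttention(
        d_in=3, d_out=4, context_length=6, dropout=0.0, num_heads=2)
    print(mha(demo_batch).shape)                   # torch.Size([2, 6, 4])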


######################
# Bonus
######################


class PyTorchMultiHeadAttention(nn.Module):

    def __init__(self, d_in, d_out, num_heads, dropout=0.0, qkv_bias=False):
        super().__init__()

        assert d_out % num_heads == 0, "embed_dim is indivisible by num_heads"

        self.num_heads = num_heads
        self.head_dim = d_out // num_heads
        self.d_out = d_out

        self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)
        self.proj = nn.Linear(d_out, d_out)
        self.dropout = dropout

    def forward(self, x):
        batch_size, num_tokens, embed_dim = x.shape

        # (b, num_tokens, embed_dim) --> (b, num_tokens, 3 * embed_dim)
        qkv = self.qkv(x)

        # (b, num_tokens, 3 * embed_dim) --> (b, num_tokens, 3, num_heads, head_dim)
        qkv = qkv.view(batch_size, num_tokens, 3, self.num_heads, self.head_dim)

        # (b, num_tokens, 3, num_heads, head_dim) --> (3, b, num_heads, num_tokens, head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)

        # (3, b, num_heads, num_tokens, head_dim) -> 3 times (b, num_heads, num_tokens, head_dim)
        queries, keys, values = qkv

        use_dropout = 0. if not self.training else self.dropout

        context_vec = nn.functional.scaled_dot_product_attention(
            queries, keys, values, attn_mask=None, dropout_p=use_dropout, is_causal=True)

        # Combine heads, where self.d_out = self.num_heads * self.head_dim
        context_vec = context_vec.transpose(1, 2).contiguous().view(batch_size, num_tokens, self.d_out)

        context_vec = self.proj(context_vec)

        return context_vec
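

# Usage sketch (illustrative, not from the original file): this variant relies on
# PyTorch's built-in scaled_dot_product_attention (available since PyTorch 2.0) and
# applies the causal mask via `is_causal=True`; the sizes below are assumptions.
if __name__ == "__main__":
    torch.manual_seed(123)
    demo_batch = torch.rand(2, 6, 3)               # (b=2, num_tokens=6, d_in=3)
    mha_fast = PyTorchMultiHeadAttention(d_in=3, d_out=4, num_heads=2, dropout=0.0)
    print(mha_fast(demo_batch).shape)              # torch.Size([2, 6, 4])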