# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
import torch
import torch.nn as nn


class SelfAttention_v1(nn.Module):

    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value

        attn_scores = queries @ keys.T  # omega
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )

        context_vec = attn_weights @ values
        return context_vec
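
# A minimal usage sketch (not part of the original listing; the toy shapes and
# seed below are assumptions for illustration): a 6-token sequence of
# 3-dimensional embeddings passed through SelfAttention_v1.
if __name__ == "__main__":
    torch.manual_seed(123)
    example_inputs = torch.rand(6, 3)          # (num_tokens, d_in)
    sa_v1 = SelfAttention_v1(d_in=3, d_out=2)
    print(sa_v1(example_inputs).shape)         # torch.Size([6, 2])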


class SelfAttention_v2(nn.Module):

    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        context_vec = attn_weights @ values
        return context_vec
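
# A minimal usage sketch (not part of the original listing; toy shapes assumed):
# SelfAttention_v2 behaves like SelfAttention_v1 but uses nn.Linear layers,
# so its weights follow PyTorch's default initialization scheme.
if __name__ == "__main__":
    torch.manual_seed(123)
    example_inputs = torch.rand(6, 3)          # (num_tokens, d_in)
    sa_v2 = SelfAttention_v2(d_in=3, d_out=2)
    print(sa_v2(example_inputs).shape)         # torch.Size([6, 2])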


class CausalAttention(nn.Module):

    def __init__(self, d_in, d_out, context_length,
                 dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)  # New
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))  # New

    def forward(self, x):
        b, num_tokens, d_in = x.shape  # New batch dimension b
        # For inputs where `num_tokens` exceeds `context_length`, this will result in errors
        # in the mask creation further below.
        # In practice, this is not a problem since the LLM (chapters 4-7) ensures that inputs
        # do not exceed `context_length` before reaching this forward method.

        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        attn_scores = queries @ keys.transpose(1, 2)  # Changed transpose
        attn_scores.masked_fill_(  # New, _ ops are in-place
            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)  # `:num_tokens` to account for cases where the number of tokens in the batch is smaller than the supported context_size
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )
        attn_weights = self.dropout(attn_weights)  # New

        context_vec = attn_weights @ values
        return context_vec
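
# A minimal usage sketch (not part of the original listing; toy shapes assumed):
# CausalAttention expects a batched input of shape (b, num_tokens, d_in) and
# masks attention to future positions.
if __name__ == "__main__":
    torch.manual_seed(123)
    example_batch = torch.rand(2, 6, 3)        # (b, num_tokens, d_in)
    ca = CausalAttention(d_in=3, d_out=2, context_length=6, dropout=0.0)
    print(ca(example_batch).shape)             # torch.Size([2, 6, 2])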


class MultiHeadAttentionWrapper(nn.Module):

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        self.heads = nn.ModuleList(
            [CausalAttention(d_in, d_out, context_length, dropout, qkv_bias)
             for _ in range(num_heads)]
        )

    def forward(self, x):
        return torch.cat([head(x) for head in self.heads], dim=-1)
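
# A minimal usage sketch (not part of the original listing; toy shapes assumed):
# the wrapper concatenates the outputs of num_heads independent CausalAttention
# modules, so the final dimension is d_out * num_heads.
if __name__ == "__main__":
    torch.manual_seed(123)
    example_batch = torch.rand(2, 6, 3)        # (b, num_tokens, d_in)
    mha_wrapper = MultiHeadAttentionWrapper(
        d_in=3, d_out=2, context_length=6, dropout=0.0, num_heads=2)
    print(mha_wrapper(example_batch).shape)    # torch.Size([2, 6, 4])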


class MultiHeadAttention(nn.Module):

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads  # Reduce the projection dim to match desired output dim

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs
        self.dropout = nn.Dropout(dropout)
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        b, num_tokens, d_in = x.shape

        keys = self.W_key(x)  # Shape: (b, num_tokens, d_out)
        queries = self.W_query(x)
        values = self.W_value(x)

        # We implicitly split the matrix by adding a `num_heads` dimension
        # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        # Compute scaled dot-product attention (aka self-attention) with a causal mask
        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head

        # Original mask truncated to the number of tokens and converted to boolean
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

        # Use the mask to fill attention scores
        attn_scores.masked_fill_(mask_bool, -torch.inf)

        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Shape: (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2)

        # Combine heads, where self.d_out = self.num_heads * self.head_dim
        context_vec = context_vec.reshape(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec)  # optional projection

        return context_vec
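
# A minimal usage sketch (not part of the original listing; toy shapes assumed):
# unlike the wrapper above, MultiHeadAttention splits a single d_out projection
# across num_heads, so the output dimension equals d_out.
if __name__ == "__main__":
    torch.manual_seed(123)
    example_batch = torch.rand(2, 6, 3)        # (b, num_tokens, d_in)
    mha = MultiHeadAttention(
        d_in=3, d_out=4, context_length=6, dropout=0.0, num_heads=2)
    print(mha(example_batch).shape)            # torch.Size([2, 6, 4])

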
######################
# Bonus
######################
class PyTorchMultiHeadAttention(nn.Module):

    def __init__(self, d_in, d_out, num_heads, dropout=0.0, qkv_bias=False):
        super().__init__()

        assert d_out % num_heads == 0, "embed_dim is indivisible by num_heads"

        self.num_heads = num_heads
        self.head_dim = d_out // num_heads
        self.d_out = d_out

        self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)
        self.proj = nn.Linear(d_out, d_out)
        self.dropout = dropout

    def forward(self, x):
        batch_size, num_tokens, embed_dim = x.shape

        # (b, num_tokens, embed_dim) --> (b, num_tokens, 3 * embed_dim)
        qkv = self.qkv(x)

        # (b, num_tokens, 3 * embed_dim) --> (b, num_tokens, 3, num_heads, head_dim)
        qkv = qkv.view(batch_size, num_tokens, 3, self.num_heads, self.head_dim)

        # (b, num_tokens, 3, num_heads, head_dim) --> (3, b, num_heads, num_tokens, head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)

        # (3, b, num_heads, num_tokens, head_dim) -> 3 times (b, num_heads, num_tokens, head_dim)
        queries, keys, values = qkv

        use_dropout = 0. if not self.training else self.dropout

        context_vec = nn.functional.scaled_dot_product_attention(
            queries, keys, values, attn_mask=None, dropout_p=use_dropout, is_causal=True)

        # Combine heads, where self.d_out = self.num_heads * self.head_dim
        context_vec = context_vec.transpose(1, 2).contiguous().view(batch_size, num_tokens, self.d_out)

        context_vec = self.proj(context_vec)

        return context_vec
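
# A minimal usage sketch (not part of the original listing; toy shapes assumed):
# the same toy batch run through the scaled_dot_product_attention-based variant.
# No context_length argument is needed because the causal mask is created
# internally via is_causal=True.
if __name__ == "__main__":
    torch.manual_seed(123)
    example_batch = torch.rand(2, 6, 3)        # (b, num_tokens, d_in)
    mha_pt = PyTorchMultiHeadAttention(d_in=3, d_out=4, num_heads=2)
    print(mha_pt(example_batch).shape)         # torch.Size([2, 6, 4])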