readability improvements

rasbt 2024-01-15 07:36:19 -06:00
parent a7b4880179
commit 9e85f13ba9
2 changed files with 53 additions and 32 deletions


@@ -1658,11 +1658,18 @@
"\n",
" # Compute scaled dot-product attention (aka self-attention) with a causal mask\n",
" attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head\n",
" attn_scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens].unsqueeze(0).unsqueeze(0), -torch.inf)\n",
" # Original mask truncated to the number of tokens and converted to boolean\n",
" mask_bool = self.mask.bool()[:num_tokens, :num_tokens]\n",
" # Unsqueeze the mask twice to match dimensions\n",
" mask_unsqueezed = mask_bool.unsqueeze(0).unsqueeze(0)\n",
" # Use the unsqueezed mask to fill attention scores\n",
" attn_scores.masked_fill_(mask_unsqueezed, -torch.inf)\n",
" \n",
" attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n",
" attn_weights = self.dropout(attn_weights)\n",
"\n",
" context_vec = (attn_weights @ values).transpose(1, 2) # Shape: (b, num_tokens, n_heads, head_dim)\n",
" # Shape: (b, num_tokens, num_heads, head_dim)\n",
" context_vec = (attn_weights @ values).transpose(1, 2) \n",
" \n",
" # Combine heads, where self.d_out = self.num_heads * self.head_dim\n",
" context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)\n",


@@ -35,6 +35,7 @@
"source": [
"import tiktoken\n",
"import torch\n",
"import torch.nn as nn\n",
"from torch.utils.data import Dataset, DataLoader\n",
"\n",
"\n",
@@ -86,8 +87,8 @@
"block_size = max_len\n",
"\n",
"\n",
"token_embedding_layer = torch.nn.Embedding(vocab_size, output_dim)\n",
"pos_embedding_layer = torch.nn.Embedding(vocab_size, output_dim)\n",
"token_embedding_layer = nn.Embedding(vocab_size, output_dim)\n",
"pos_embedding_layer = nn.Embedding(vocab_size, output_dim)\n",
"\n",
"max_length = 4\n",
"dataloader = create_dataloader(raw_text, batch_size=8, max_length=max_length, stride=5)"
@@ -152,17 +153,15 @@
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"class CausalSelfAttention(torch.nn.Module):\n",
"class CausalSelfAttention(nn.Module):\n",
"\n",
" def __init__(self, d_in, d_out, block_size, dropout):\n",
" super().__init__()\n",
" self.d_out = d_out\n",
" self.W_query = torch.nn.Linear(d_in, d_out, bias=False)\n",
" self.W_key = torch.nn.Linear(d_in, d_out, bias=False)\n",
" self.W_value = torch.nn.Linear(d_in, d_out, bias=False)\n",
" self.dropout = torch.nn.Dropout(dropout) # New\n",
" self.W_query = nn.Linear(d_in, d_out, bias=False)\n",
" self.W_key = nn.Linear(d_in, d_out, bias=False)\n",
" self.W_value = nn.Linear(d_in, d_out, bias=False)\n",
" self.dropout = nn.Dropout(dropout) # New\n",
" self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1)) # New\n",
"\n",
" def forward(self, x):\n",
@@ -181,14 +180,14 @@
" return context_vec\n",
"\n",
"\n",
"class MultiHeadAttentionWrapper(torch.nn.Module):\n",
"class MultiHeadAttentionWrapper(nn.Module):\n",
" def __init__(self, d_in, d_out, block_size, dropout, num_heads):\n",
" super().__init__()\n",
" self.heads = torch.nn.ModuleList(\n",
" self.heads = nn.ModuleList(\n",
" [CausalSelfAttention(d_in, d_out, block_size, dropout) \n",
" for _ in range(num_heads)]\n",
" )\n",
" self.out_proj = torch.nn.Linear(d_out*num_heads, d_out*num_heads)\n",
" self.out_proj = nn.Linear(d_out*num_heads, d_out*num_heads)\n",
"\n",
" def forward(self, x):\n",
" context_vec = torch.cat([head(x) for head in self.heads], dim=-1)\n",
@@ -241,10 +240,7 @@
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"\n",
"class MultiHeadAttention(torch.nn.Module):\n",
"class MultiHeadAttention(nn.Module):\n",
" def __init__(self, d_in, d_out, block_size, dropout, num_heads):\n",
" super().__init__()\n",
" assert d_out % num_heads == 0, \"d_out must be divisible by n_heads\"\n",
@@ -253,30 +249,48 @@
" self.num_heads = num_heads\n",
" self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim\n",
"\n",
" self.W_query = torch.nn.Linear(d_in, d_out, bias=False)\n",
" self.W_key = torch.nn.Linear(d_in, d_out, bias=False)\n",
" self.W_value = torch.nn.Linear(d_in, d_out, bias=False)\n",
" self.out_proj = torch.nn.Linear(d_out, d_out) # Linear layer to combine head outputs\n",
" self.dropout = torch.nn.Dropout(dropout)\n",
" self.W_query = nn.Linear(d_in, d_out, bias=False)\n",
" self.W_key = nn.Linear(d_in, d_out, bias=False)\n",
" self.W_value = nn.Linear(d_in, d_out, bias=False)\n",
" self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs\n",
" self.dropout = nn.Dropout(dropout)\n",
" self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1))\n",
"\n",
" def forward(self, x):\n",
" b, n_tokens, d_in = x.shape\n",
" b, num_tokens, d_in = x.shape\n",
"\n",
" # Split into multiple heads\n",
" keys = self.W_key(x).view(b, n_tokens, self.num_heads, self.head_dim).transpose(1, 2)\n",
" queries = self.W_query(x).view(b, n_tokens, self.num_heads, self.head_dim).transpose(1, 2)\n",
" values = self.W_value(x).view(b, n_tokens, self.num_heads, self.head_dim).transpose(1, 2)\n",
" keys = self.W_key(x) # Shape: (b, num_tokens, d_out)\n",
" queries = self.W_query(x)\n",
" values = self.W_value(x)\n",
"\n",
" # Compute scaled dot-product attention for each head\n",
" # We implicitly split the matrix by adding a `num_heads` dimension\n",
" # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)\n",
" keys = keys.view(b, num_tokens, self.num_heads, self.head_dim) \n",
" values = values.view(b, num_tokens, self.num_heads, self.head_dim)\n",
" queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)\n",
"\n",
" # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)\n",
" keys = keys.transpose(1, 2)\n",
" queries = queries.transpose(1, 2)\n",
" values = values.transpose(1, 2)\n",
"\n",
" # Compute scaled dot-product attention (aka self-attention) with a causal mask\n",
" attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head\n",
" attn_scores.masked_fill_(self.mask.bool()[:n_tokens, :n_tokens].unsqueeze(0).unsqueeze(0), -torch.inf)\n",
" # Original mask truncated to the number of tokens and converted to boolean\n",
" mask_bool = self.mask.bool()[:num_tokens, :num_tokens]\n",
" # Unsqueeze the mask twice to match dimensions\n",
" mask_unsqueezed = mask_bool.unsqueeze(0).unsqueeze(0)\n",
" # Use the unsqueezed mask to fill attention scores\n",
" attn_scores.masked_fill_(mask_unsqueezed, -torch.inf)\n",
" \n",
" attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n",
" attn_weights = self.dropout(attn_weights)\n",
" context_vec = (attn_weights @ values).transpose(1, 2) # Shape: (b, T, n_heads, head_dim)\n",
"\n",
" # Shape: (b, num_tokens, num_heads, head_dim)\n",
" context_vec = (attn_weights @ values).transpose(1, 2) \n",
" \n",
" # Combine heads, where self.d_out = self.num_heads * self.head_dim\n",
" context_vec = context_vec.contiguous().view(b, n_tokens, self.d_out)\n",
" context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)\n",
" context_vec = self.out_proj(context_vec) # optional projection\n",
"\n",
" return context_vec"