LLMs-from-scratch/ch03/01_main-chapter-code/multihead-attention.ipynb

333 lines
10 KiB
Plaintext
Raw Normal View History

2023-12-09 17:13:56 -06:00
{
"cells": [
{
"cell_type": "markdown",
"id": "6f678e62-7bcb-4405-86ae-dce94f494303",
"metadata": {},
"source": [
"# Multi-head Attention Plus Data Loading"
]
},
{
"cell_type": "markdown",
"id": "070000fc-a7b7-4c56-a2c0-a938d413a790",
"metadata": {},
"source": [
"The complete chapter code is located in [ch03.ipynb](./ch03.ipynb).\n",
"\n",
"This notebook contains the main takeaway, multihead-attention implementation (plus the data loading pipeline from chapter 2)"
]
},
{
"cell_type": "markdown",
"id": "3f60dc93-281d-447e-941f-aede0c7ff7fc",
"metadata": {},
"source": [
"## Data Loader from Chapter 2"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "0ed4b7db-3b47-4fd3-a4a6-5f4ed5dd166e",
"metadata": {},
"outputs": [],
"source": [
"import tiktoken\n",
"import torch\n",
"from torch.utils.data import Dataset, DataLoader\n",
"\n",
"\n",
"class GPTDatasetV1(Dataset):\n",
" def __init__(self, txt, tokenizer, max_length, stride):\n",
" self.tokenizer = tokenizer\n",
" self.input_ids = []\n",
" self.target_ids = []\n",
"\n",
" # Tokenize the entire text\n",
" token_ids = tokenizer.encode(txt)\n",
"\n",
" # Use a sliding window to chunk the book into overlapping sequences of max_length\n",
" for i in range(0, len(token_ids) - max_length, stride):\n",
" input_chunk = token_ids[i:i + max_length]\n",
" target_chunk = token_ids[i + 1: i + max_length + 1]\n",
" self.input_ids.append(torch.tensor(input_chunk))\n",
" self.target_ids.append(torch.tensor(target_chunk))\n",
"\n",
" def __len__(self):\n",
" return len(self.input_ids)\n",
"\n",
" def __getitem__(self, idx):\n",
" return self.input_ids[idx], self.target_ids[idx]\n",
"\n",
"\n",
"def create_dataloader(txt, batch_size=4, max_length=256, stride=128):\n",
" # Initialize the tokenizer\n",
" tokenizer = tiktoken.get_encoding(\"gpt2\")\n",
"\n",
" # Create dataset\n",
" dataset = GPTDatasetV1(txt, tokenizer, max_length, stride)\n",
"\n",
" # Create dataloader\n",
" dataloader = DataLoader(dataset, batch_size=batch_size)\n",
"\n",
" return dataloader\n",
"\n",
"\n",
"with open(\"small-text-sample.txt\", \"r\", encoding=\"utf-8\") as f:\n",
" raw_text = f.read()\n",
"\n",
"tokenizer = tiktoken.get_encoding(\"gpt2\")\n",
"encoded_text = tokenizer.encode(raw_text)\n",
"\n",
"vocab_size = 50257\n",
"output_dim = 256\n",
"token_embedding_layer = torch.nn.Embedding(vocab_size, output_dim)\n",
"pos_embedding_layer = torch.nn.Embedding(vocab_size, output_dim)\n",
"\n",
"max_length = 4\n",
"dataloader = create_dataloader(raw_text, batch_size=8, max_length=max_length, stride=5)"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "664397bc-6daa-4b88-90aa-e8fc1fbd5846",
"metadata": {},
"outputs": [],
"source": [
"for batch in dataloader:\n",
" x, y = batch\n",
"\n",
" token_embeddings = token_embedding_layer(x)\n",
" pos_embeddings = pos_embedding_layer(torch.arange(max_length))\n",
"\n",
" input_embeddings = token_embeddings + pos_embeddings\n",
"\n",
" break"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "d3664332-e6bb-447e-8b96-203aafde8b24",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([8, 4, 256])\n"
]
}
],
"source": [
"print(input_embeddings.shape)"
]
},
{
"cell_type": "markdown",
"id": "bd298bf4-e320-40c1-9084-6526d07e6d5d",
"metadata": {},
"source": [
"# Multi-head Attention from Chapter 3"
]
},
{
"cell_type": "markdown",
"id": "58b2297b-a001-49fd-994c-b1700866cd01",
"metadata": {},
"source": [
"## Variant A: Simple implementation"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "a44e682d-1c3c-445d-85fa-b142f89f8503",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"class CausalSelfAttention(torch.nn.Module):\n",
"\n",
" def __init__(self, d_in, d_out, block_size, dropout):\n",
" super().__init__()\n",
" self.d_out = d_out\n",
" self.W_query = torch.nn.Linear(d_in, d_out, bias=False)\n",
" self.W_key = torch.nn.Linear(d_in, d_out, bias=False)\n",
" self.W_value = torch.nn.Linear(d_in, d_out, bias=False)\n",
" self.dropout = torch.nn.Dropout(dropout) # New\n",
" self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1)) # New\n",
"\n",
" def forward(self, x):\n",
" b, n_tokens, d_in = x.shape # New batch dimension b\n",
" keys = self.W_key(x)\n",
" queries = self.W_query(x)\n",
" values = self.W_value(x)\n",
"\n",
" attn_scores = queries @ keys.transpose(1, 2) # Changed transpose\n",
" attn_scores.masked_fill_( # New, _ ops are in-place\n",
" self.mask.bool()[:n_tokens, :n_tokens], -torch.inf) \n",
" attn_weights = torch.softmax(attn_scores / self.d_out**0.5, dim=1)\n",
" attn_weights = self.dropout(attn_weights) # New\n",
"\n",
" context_vec = attn_weights @ values\n",
" return context_vec\n",
"\n",
"\n",
"class MultiHeadAttentionWrapper(torch.nn.Module):\n",
" def __init__(self, d_in, d_out, block_size, dropout, num_heads):\n",
" super().__init__()\n",
" self.heads = torch.nn.ModuleList(\n",
" [CausalSelfAttention(d_in, d_out, block_size, dropout) \n",
" for _ in range(num_heads)]\n",
" )\n",
" self.out_proj = torch.nn.Linear(d_out*num_heads, d_out*num_heads)\n",
"\n",
" def forward(self, x):\n",
" context_vec = torch.cat([head(x) for head in self.heads], dim=-1)\n",
" return self.out_proj(context_vec)"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "7898551e-f582-48ac-9f66-3632abe2a93f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"context_vecs.shape: torch.Size([8, 4, 256])\n"
]
}
],
"source": [
"torch.manual_seed(123)\n",
"\n",
"block_size = max_length\n",
"d_in = output_dim\n",
"\n",
"num_heads=2\n",
"d_out = d_in // num_heads\n",
"\n",
"mha = MultiHeadAttentionWrapper(d_in, d_out, block_size, 0.0, num_heads)\n",
"\n",
"batch = input_embeddings\n",
"context_vecs = mha(batch)\n",
"\n",
"print(\"context_vecs.shape:\", context_vecs.shape)"
]
},
{
"cell_type": "markdown",
"id": "1e288239-5146-424d-97fe-74024ae711b9",
"metadata": {},
"source": [
"## Variant B: Alternative implementation"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "2773c09d-c136-4372-a2be-04b58d292842",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"\n",
"class MultiHeadAttention(torch.nn.Module):\n",
" def __init__(self, d_in, d_out, block_size, dropout, num_heads):\n",
" super().__init__()\n",
" assert d_out % num_heads == 0, \"d_out must be divisible by n_heads\"\n",
"\n",
" self.d_out = d_out\n",
" self.num_heads = num_heads\n",
" self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim\n",
"\n",
" self.W_query = torch.nn.Linear(d_in, d_out, bias=False)\n",
" self.W_key = torch.nn.Linear(d_in, d_out, bias=False)\n",
" self.W_value = torch.nn.Linear(d_in, d_out, bias=False)\n",
" self.out_proj = torch.nn.Linear(d_out, d_out) # Linear layer to combine head outputs\n",
" self.dropout = torch.nn.Dropout(dropout)\n",
" self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1))\n",
"\n",
" def forward(self, x):\n",
" b, n_tokens, d_in = x.shape\n",
"\n",
" # Split into multiple heads\n",
" keys = self.W_key(x).view(b, n_tokens, self.num_heads, self.head_dim).transpose(1, 2)\n",
" queries = self.W_query(x).view(b, n_tokens, self.num_heads, self.head_dim).transpose(1, 2)\n",
" values = self.W_value(x).view(b, n_tokens, self.num_heads, self.head_dim).transpose(1, 2)\n",
"\n",
" # Compute scaled dot-product attention for each head\n",
" attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head\n",
" attn_scores.masked_fill_(self.mask.bool()[:n_tokens, :n_tokens].unsqueeze(0).unsqueeze(0), -torch.inf)\n",
" attn_weights = torch.softmax(attn_scores / self.head_dim**0.5, dim=-1)\n",
" attn_weights = self.dropout(attn_weights)\n",
" context_vec = (attn_weights @ values).transpose(1, 2) # Shape: (b, T, n_heads, head_dim)\n",
" \n",
" # Combine heads, where self.d_out = self.num_heads * self.head_dim\n",
" context_vec = context_vec.contiguous().view(b, n_tokens, self.d_out)\n",
" context_vec = self.out_proj(context_vec) # optional projection\n",
"\n",
" return context_vec"
]
},
{
"cell_type": "code",
"execution_count": 35,
"id": "779fdd04-0152-4308-af08-840800a7f395",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"context_vecs.shape: torch.Size([8, 4, 256])\n"
]
}
],
"source": [
"torch.manual_seed(123)\n",
"\n",
"block_size = max_length\n",
"d_in = output_dim\n",
"d_out = d_in\n",
"\n",
"mha = MultiHeadAttention(d_in, d_out, block_size, 0.0, num_heads=2)\n",
"\n",
"batch = input_embeddings\n",
"context_vecs = mha(batch)\n",
"\n",
"print(\"context_vecs.shape:\", context_vecs.shape)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}