mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-11-01 10:20:00 +00:00
Merge pull request #14 from rasbt/update
Small cosmetic updates and exercise solutions for chapter 3
This commit is contained in:
commit
32267e3253
@ -323,7 +323,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.4"
|
||||
"version": "3.10.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@ -26,7 +26,7 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"torch version: 2.0.1\n"
|
||||
"torch version: 2.1.0\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -935,8 +935,8 @@
|
||||
" return context_vec\n",
|
||||
"\n",
|
||||
"torch.manual_seed(123)\n",
|
||||
"sa = SelfAttention_v1(d_in, d_out)\n",
|
||||
"print(sa(inputs))"
|
||||
"sa_v1 = SelfAttention_v1(d_in, d_out)\n",
|
||||
"print(sa_v1(inputs))"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -989,8 +989,8 @@
|
||||
" return context_vec\n",
|
||||
"\n",
|
||||
"torch.manual_seed(789)\n",
|
||||
"sa = SelfAttention_v2(d_in, d_out)\n",
|
||||
"print(sa(inputs))"
|
||||
"sa_v2 = SelfAttention_v2(d_in, d_out)\n",
|
||||
"print(sa_v2(inputs))"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -1006,7 +1006,7 @@
|
||||
"id": "c5025b37-0f2c-4a67-a7cb-1286af7026ab",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## 3.5 Hiding future words with causal self-attention"
|
||||
"## 3.5 Hiding future words with causal attention"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -1078,7 +1078,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 28,
|
||||
"execution_count": 26,
|
||||
"id": "43f3d2e3-185b-4184-9f98-edde5e6df746",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@ -1111,7 +1111,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 29,
|
||||
"execution_count": 27,
|
||||
"id": "9f531e2e-f4d2-4fea-a87f-4c132e48b9e7",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@ -1151,7 +1151,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 30,
|
||||
"execution_count": 28,
|
||||
"id": "6d392083-fd81-4f70-9bdf-8db985e673d6",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@ -1185,7 +1185,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 27,
|
||||
"execution_count": 29,
|
||||
"id": "a2be2f43-9cf0-44f6-8d8b-68ef2fb3cc39",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@ -1218,7 +1218,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 28,
|
||||
"execution_count": 30,
|
||||
"id": "b1cd6d7f-16f2-43c1-915e-0824f1a4bc52",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@ -1280,7 +1280,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 29,
|
||||
"execution_count": 31,
|
||||
"id": "0de578db-8289-41d6-b377-ef645751e33f",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@ -1307,7 +1307,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 30,
|
||||
"execution_count": 32,
|
||||
"id": "b16c5edb-942b-458c-8e95-25e4e355381e",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@ -1329,19 +1329,27 @@
|
||||
"print(dropout(attn_weights))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "cdc14639-5f0f-4840-aa9d-8eb36ea90fb7",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### 3.5.3 Implementing a compact causal self-attention class"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "09c41d29-1933-43dc-ada6-2dbb56287204",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"- Now, we are ready to implement a working implementation of self-attention, including the causal and dropout masks. \n",
|
||||
"- One more thing is to implement the code to handle batches consisting of more than one input so that our `CausalSelfAttention` class supports the batch outputs produced by the data loader we implemented in chapter 2.\n",
|
||||
"- One more thing is to implement the code to handle batches consisting of more than one input so that our `CausalAttention` class supports the batch outputs produced by the data loader we implemented in chapter 2.\n",
|
||||
"- For simplicity, to simulate such batch input, we duplicate the input text example:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 31,
|
||||
"execution_count": 33,
|
||||
"id": "977a5fa7-a9d5-4e2e-8a32-8e0331ccfe28",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@ -1358,17 +1366,9 @@
|
||||
"print(batch.shape) # 2 inputs with 6 tokens each, and each token has embedding dimension 3"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "cdc14639-5f0f-4840-aa9d-8eb36ea90fb7",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### 3.5.3 Implementing a compact causal self-attention class"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 32,
|
||||
"execution_count": 34,
|
||||
"id": "60d8c2eb-2d8e-4d2c-99bc-9eef8cc53ca0",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@ -1394,7 +1394,7 @@
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"class CausalSelfAttention(nn.Module):\n",
|
||||
"class CausalAttention(nn.Module):\n",
|
||||
"\n",
|
||||
" def __init__(self, d_in, d_out, block_size, dropout):\n",
|
||||
" super().__init__()\n",
|
||||
@ -1423,9 +1423,9 @@
|
||||
"torch.manual_seed(123)\n",
|
||||
"\n",
|
||||
"block_size = batch.shape[1]\n",
|
||||
"csa = CausalSelfAttention(d_in, d_out, block_size, 0.0)\n",
|
||||
"ca = CausalAttention(d_in, d_out, block_size, 0.0)\n",
|
||||
"\n",
|
||||
"context_vecs = csa(batch)\n",
|
||||
"context_vecs = ca(batch)\n",
|
||||
"\n",
|
||||
"print(context_vecs)\n",
|
||||
"print(\"context_vecs.shape:\", context_vecs.shape)"
|
||||
@ -1475,7 +1475,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 33,
|
||||
"execution_count": 42,
|
||||
"id": "b9a66e11-7105-4bb4-be84-041f1a1f3bd2",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@ -1506,7 +1506,7 @@
|
||||
" def __init__(self, d_in, d_out, block_size, dropout, num_heads):\n",
|
||||
" super().__init__()\n",
|
||||
" self.heads = nn.ModuleList(\n",
|
||||
" [CausalSelfAttention(d_in, d_out, block_size, dropout) \n",
|
||||
" [CausalAttention(d_in, d_out, block_size, dropout) \n",
|
||||
" for _ in range(num_heads)]\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
@ -1516,7 +1516,8 @@
|
||||
"\n",
|
||||
"torch.manual_seed(123)\n",
|
||||
"\n",
|
||||
"block_size = batch.shape[1]\n",
|
||||
"block_size = batch.shape[1] # This is the number of tokens\n",
|
||||
"d_in, d_out = 3, 2\n",
|
||||
"mha = MultiHeadAttentionWrapper(d_in, d_out, block_size, 0.0, num_heads=2)\n",
|
||||
"\n",
|
||||
"context_vecs = mha(batch)\n",
|
||||
@ -1537,7 +1538,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 34,
|
||||
"execution_count": 36,
|
||||
"id": "dc9a4375-068b-4b2a-aabb-a29347ca5ecd",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@ -1587,14 +1588,14 @@
|
||||
"id": "f4b48d0d-71ba-4fa0-b714-ca80cabcb6f7",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"- While the above is an intuitive and fully functional implementation of multi-head attention (wrapping the single-head attention `CausalSelfAttention` implementation from earlier), we can write a stand-alone class called `MultiHeadAttention` to achieve the same.\n",
|
||||
"- While the above is an intuitive and fully functional implementation of multi-head attention (wrapping the single-head attention `CausalAttention` implementation from earlier), we can write a stand-alone class called `MultiHeadAttention` to achieve the same.\n",
|
||||
"\n",
|
||||
"- We don't concatenate single attention heads for this stand-alone `MultiHeadAttention` class. Instead, we create single W_query, W_key, and W_value weight matrices and then split those into individual matrices for each attention head:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 35,
|
||||
"execution_count": 37,
|
||||
"id": "110b0188-6e9e-4e56-a988-10523c6c8538",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@ -1637,34 +1638,33 @@
|
||||
" self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1))\n",
|
||||
"\n",
|
||||
" def forward(self, x):\n",
|
||||
" b, n_tokens, d_in = x.shape\n",
|
||||
" # (b, n_heads, T) -> (b, T, n_heads, head_dim)\n",
|
||||
" b, num_tokens, d_in = x.shape\n",
|
||||
"\n",
|
||||
" keys = self.W_key(x) # Shape: (b, T, d_out)\n",
|
||||
" keys = self.W_key(x) # Shape: (b, num_tokens, d_out)\n",
|
||||
" queries = self.W_query(x)\n",
|
||||
" values = self.W_value(x)\n",
|
||||
"\n",
|
||||
" # We implicitely split the matrix by adding a `num_heads` dimension\n",
|
||||
" # Unroll last dim: (b, T, d_out) -> (b, T, num_heads, head_dim)\n",
|
||||
" keys = keys.view(b, n_tokens, self.num_heads, self.head_dim) \n",
|
||||
" values = values.view(b, n_tokens, self.num_heads, self.head_dim)\n",
|
||||
" queries = queries.view(b, n_tokens, self.num_heads, self.head_dim)\n",
|
||||
" # We implicitly split the matrix by adding a `num_heads` dimension\n",
|
||||
" # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)\n",
|
||||
" keys = keys.view(b, num_tokens, self.num_heads, self.head_dim) \n",
|
||||
" values = values.view(b, num_tokens, self.num_heads, self.head_dim)\n",
|
||||
" queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)\n",
|
||||
"\n",
|
||||
" # Transpose: (b, T, num_heads, head_dim) -> (b, num_heads, T, head_dim)\n",
|
||||
" # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)\n",
|
||||
" keys = keys.transpose(1, 2)\n",
|
||||
" queries = queries.transpose(1, 2)\n",
|
||||
" values = values.transpose(1, 2)\n",
|
||||
"\n",
|
||||
" # Compute scaled dot-product attention\n",
|
||||
" # Compute scaled dot-product attention (aka self-attention) with a causal mask\n",
|
||||
" attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head\n",
|
||||
" attn_scores.masked_fill_(self.mask.bool()[:n_tokens, :n_tokens].unsqueeze(0).unsqueeze(0), -torch.inf)\n",
|
||||
" attn_scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens].unsqueeze(0).unsqueeze(0), -torch.inf)\n",
|
||||
" attn_weights = torch.softmax(attn_scores / self.head_dim**0.5, dim=-1)\n",
|
||||
" attn_weights = self.dropout(attn_weights)\n",
|
||||
"\n",
|
||||
" context_vec = (attn_weights @ values).transpose(1, 2) # Shape: (b, T, n_heads, head_dim)\n",
|
||||
" context_vec = (attn_weights @ values).transpose(1, 2) # Shape: (b, num_tokens, n_heads, head_dim)\n",
|
||||
" \n",
|
||||
" # Combine heads, where self.d_out = self.num_heads * self.head_dim\n",
|
||||
" context_vec = context_vec.contiguous().view(b, n_tokens, self.d_out)\n",
|
||||
" context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)\n",
|
||||
" context_vec = self.out_proj(context_vec) # optional projection\n",
|
||||
"\n",
|
||||
" return context_vec\n",
|
||||
@ -1709,7 +1709,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 36,
|
||||
"execution_count": 38,
|
||||
"id": "e8cfc1ae-78ab-4faa-bc73-98bd054806c9",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@ -1752,7 +1752,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 37,
|
||||
"execution_count": 39,
|
||||
"id": "053760f1-1a02-42f0-b3bf-3d939e407039",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@ -1782,6 +1782,36 @@
|
||||
"print(\"\\nSecond head:\\n\", second_res)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 45,
|
||||
"id": "08c2a3fd-e674-4d69-9ef4-ea94b788e937",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"2360064"
|
||||
]
|
||||
},
|
||||
"execution_count": 45,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"block_size = 1024\n",
|
||||
"d_in, d_out = 768, 768\n",
|
||||
"num_heads = 12\n",
|
||||
"\n",
|
||||
"mha = MultiHeadAttention(d_in, d_out, block_size, 0.0, num_heads)\n",
|
||||
"\n",
|
||||
"def count_parameters(model):\n",
|
||||
" return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
|
||||
"\n",
|
||||
"count_parameters(mha)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "dec671bf-7938-4304-ad1e-75d9920e7f43",
|
||||
|
||||
308
ch03/01_main-chapter-code/exercise-solutions.ipynb
Normal file
308
ch03/01_main-chapter-code/exercise-solutions.ipynb
Normal file
@ -0,0 +1,308 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "51c9672d-8d0c-470d-ac2d-1271f8ec3f14",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Chapter 3 Exercise solutions"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "33dfa199-9aee-41d4-a64b-7e3811b9a616",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Exercise 3.1"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "5fee2cf5-61c3-4167-81b5-44ea155bbaf2",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch\n",
|
||||
"\n",
|
||||
"inputs = torch.tensor(\n",
|
||||
" [[0.43, 0.15, 0.89], # Your (x^1)\n",
|
||||
" [0.55, 0.87, 0.66], # journey (x^2)\n",
|
||||
" [0.57, 0.85, 0.64], # starts (x^3)\n",
|
||||
" [0.22, 0.58, 0.33], # with (x^4)\n",
|
||||
" [0.77, 0.25, 0.10], # one (x^5)\n",
|
||||
" [0.05, 0.80, 0.55]] # step (x^6)\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"d_in, d_out = 3, 2"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 58,
|
||||
"id": "62ea289c-41cd-4416-89dd-dde6383a6f70",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import torch.nn as nn\n",
|
||||
"\n",
|
||||
"class SelfAttention_v1(nn.Module):\n",
|
||||
"\n",
|
||||
" def __init__(self, d_in, d_out):\n",
|
||||
" super().__init__()\n",
|
||||
" self.d_out = d_out\n",
|
||||
" self.W_query = nn.Parameter(torch.rand(d_in, d_out))\n",
|
||||
" self.W_key = nn.Parameter(torch.rand(d_in, d_out))\n",
|
||||
" self.W_value = nn.Parameter(torch.rand(d_in, d_out))\n",
|
||||
"\n",
|
||||
" def forward(self, x):\n",
|
||||
" keys = x @ self.W_key\n",
|
||||
" queries = x @ self.W_query\n",
|
||||
" values = x @ self.W_value\n",
|
||||
" \n",
|
||||
" attn_scores = queries @ keys.T # omega\n",
|
||||
" attn_weights = torch.softmax(attn_scores / self.d_out**0.5, dim=-1)\n",
|
||||
"\n",
|
||||
" context_vec = attn_weights @ values\n",
|
||||
" return context_vec\n",
|
||||
"\n",
|
||||
"torch.manual_seed(123)\n",
|
||||
"sa_v1 = SelfAttention_v1(d_in, d_out)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 59,
|
||||
"id": "7b035143-f4e8-45fb-b398-dec1bd5153d4",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class SelfAttention_v2(nn.Module):\n",
|
||||
"\n",
|
||||
" def __init__(self, d_in, d_out):\n",
|
||||
" super().__init__()\n",
|
||||
" self.d_out = d_out\n",
|
||||
" self.W_query = nn.Linear(d_in, d_out, bias=False)\n",
|
||||
" self.W_key = nn.Linear(d_in, d_out, bias=False)\n",
|
||||
" self.W_value = nn.Linear(d_in, d_out, bias=False)\n",
|
||||
"\n",
|
||||
" def forward(self, x):\n",
|
||||
" keys = self.W_key(x)\n",
|
||||
" queries = self.W_query(x)\n",
|
||||
" values = self.W_value(x)\n",
|
||||
" \n",
|
||||
" attn_scores = queries @ keys.T\n",
|
||||
" attn_weights = torch.softmax(attn_scores / self.d_out**0.5, dim=1)\n",
|
||||
"\n",
|
||||
" context_vec = attn_weights @ values\n",
|
||||
" return context_vec\n",
|
||||
"\n",
|
||||
"torch.manual_seed(123)\n",
|
||||
"sa_v2 = SelfAttention_v2(d_in, d_out)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 60,
|
||||
"id": "7591d79c-c30e-406d-adfd-20c12eb448f6",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sa_v1.W_query = torch.nn.Parameter(sa_v2.W_query.weight.T)\n",
|
||||
"sa_v1.W_key = torch.nn.Parameter(sa_v2.W_key.weight.T)\n",
|
||||
"sa_v1.W_value = torch.nn.Parameter(sa_v2.W_value.weight.T)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 61,
|
||||
"id": "ddd0f54f-6bce-46cc-a428-17c2a56557d0",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"tensor([[-0.5337, -0.1051],\n",
|
||||
" [-0.5323, -0.1080],\n",
|
||||
" [-0.5323, -0.1079],\n",
|
||||
" [-0.5297, -0.1076],\n",
|
||||
" [-0.5311, -0.1066],\n",
|
||||
" [-0.5299, -0.1081]], grad_fn=<MmBackward0>)"
|
||||
]
|
||||
},
|
||||
"execution_count": 61,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"sa_v1(inputs)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 62,
|
||||
"id": "340908f8-1144-4ddd-a9e1-a1c5c3d592f5",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"tensor([[-0.5337, -0.1051],\n",
|
||||
" [-0.5323, -0.1080],\n",
|
||||
" [-0.5323, -0.1079],\n",
|
||||
" [-0.5297, -0.1076],\n",
|
||||
" [-0.5311, -0.1066],\n",
|
||||
" [-0.5299, -0.1081]], grad_fn=<MmBackward0>)"
|
||||
]
|
||||
},
|
||||
"execution_count": 62,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"sa_v2(inputs)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "33543edb-46b5-4b01-8704-f7f101230544",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Exercise 3.2"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "0588e209-1644-496a-8dae-7630b4ef9083",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"If we want to have an output dimension of 2, as earlier in single-head attention, we can have to change the projection dimension `d_out` to 1:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "18e748ef-3106-4e11-a781-b230b74a0cef",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"```python\n",
|
||||
"torch.manual_seed(123)\n",
|
||||
"\n",
|
||||
"d_out = 1\n",
|
||||
"mha = MultiHeadAttentionWrapper(d_in, d_out, block_size, 0.0, num_heads=2)\n",
|
||||
"\n",
|
||||
"context_vecs = mha(batch)\n",
|
||||
"\n",
|
||||
"print(context_vecs)\n",
|
||||
"print(\"context_vecs.shape:\", context_vecs.shape)\n",
|
||||
"```"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "78234544-d989-4f71-ac28-85a7ec1e6b7b",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"```\n",
|
||||
"tensor([[[-9.1476e-02, 3.4164e-02],\n",
|
||||
" [-2.6796e-01, -1.3427e-03],\n",
|
||||
" [-4.8421e-01, -4.8909e-02],\n",
|
||||
" [-6.4808e-01, -1.0625e-01],\n",
|
||||
" [-8.8380e-01, -1.7140e-01],\n",
|
||||
" [-1.4744e+00, -3.4327e-01]],\n",
|
||||
"\n",
|
||||
" [[-9.1476e-02, 3.4164e-02],\n",
|
||||
" [-2.6796e-01, -1.3427e-03],\n",
|
||||
" [-4.8421e-01, -4.8909e-02],\n",
|
||||
" [-6.4808e-01, -1.0625e-01],\n",
|
||||
" [-8.8380e-01, -1.7140e-01],\n",
|
||||
" [-1.4744e+00, -3.4327e-01]]], grad_fn=<CatBackward0>)\n",
|
||||
"context_vecs.shape: torch.Size([2, 6, 2])\n",
|
||||
"```"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "92bdabcb-06cf-4576-b810-d883bbd313ba",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Exercise 3.3"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "84c9b963-d01f-46e6-96bf-8eb2a54c5e42",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"```python\n",
|
||||
"block_size = 1024\n",
|
||||
"d_in, d_out = 768, 768\n",
|
||||
"num_heads = 12\n",
|
||||
"\n",
|
||||
"mha = MultiHeadAttention(d_in, d_out, block_size, 0.0, num_heads)\n",
|
||||
"```"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "375d5290-8e8b-4149-958e-1efb58a69191",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Optionally, the number of parameters is as follows:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "6d7e603c-1658-4da9-9c0b-ef4bc72832b4",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"```python\n",
|
||||
"def count_parameters(model):\n",
|
||||
" return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
|
||||
"\n",
|
||||
"count_parameters(mha)\n",
|
||||
"```"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "51ba00bd-feb0-4424-84cb-7c2b1f908779",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"```\n",
|
||||
"2360064 # (2.36 M)\n",
|
||||
"```"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "a56c1d47-9b95-4bd1-a517-580a6f779c52",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"The GPT-2 model has 117M parameters in total, but as we can see, most of its parameters are not in the multi-head attention module itself."
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user