automatically run on gpu or cpu

rasbt 2024-03-07 20:14:03 -06:00
parent c5b17c3d67
commit 404f48aa74


@@ -3,7 +3,9 @@
   {
    "cell_type": "markdown",
    "id": "6f678e62-7bcb-4405-86ae-dce94f494303",
-   "metadata": {},
+   "metadata": {
+    "id": "6f678e62-7bcb-4405-86ae-dce94f494303"
+   },
    "source": [
     "# Efficient Multi-Head Attention Implementations"
    ]
@@ -11,7 +13,9 @@
   {
    "cell_type": "markdown",
    "id": "2f9bb1b6-a1e5-4e0a-884d-0f31b374a8d6",
-   "metadata": {},
+   "metadata": {
+    "id": "2f9bb1b6-a1e5-4e0a-884d-0f31b374a8d6"
+   },
    "source": [
     "## Multi-head attention implementations from chapter 3"
    ]
@@ -20,46 +24,68 @@
    "cell_type": "code",
    "execution_count": 1,
    "id": "7898551e-f582-48ac-9f66-3632abe2a93f",
-   "metadata": {},
-   "outputs": [],
+   "metadata": {
+    "id": "7898551e-f582-48ac-9f66-3632abe2a93f",
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    },
+    "outputId": "840126fe-fffa-46d4-9717-41aef89d5052"
+   },
+   "outputs": [
+    {
+     "output_type": "stream",
+     "name": "stdout",
+     "text": [
+      "Running on cuda\n"
+     ]
+    }
+   ],
    "source": [
     "import torch\n",
     "\n",
     "torch.manual_seed(123)\n",
+    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
+    "print(f\"Running on {device}\")\n",
     "\n",
     "batch_size = 8\n",
     "context_len = 1024\n",
     "embed_dim = 768\n",
-    "embeddings = torch.randn((batch_size, context_len, embed_dim))"
+    "embeddings = torch.randn((batch_size, context_len, embed_dim), device=device)"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": 2,
    "id": "297c93ed-aec0-4896-bb89-42c4b294d3d1",
-   "metadata": {},
+   "metadata": {
+    "id": "297c93ed-aec0-4896-bb89-42c4b294d3d1",
+    "outputId": "5af9d36b-37c9-4f6e-c370-58a46db02632",
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    }
+   },
    "outputs": [
     {
-     "name": "stdout",
      "output_type": "stream",
+     "name": "stdout",
      "text": [
       "torch.Size([8, 1024, 9216])\n"
      ]
     }
    ],
    "source": [
-    "from ch03 import MultiHeadAttentionWrapper as Ch03_MHA_1\n",
+    "from ch03 import MultiHeadAttentionWrapper as Ch03_MHA_Wrapper\n",
     "\n",
-    "mha_ch03_1 = Ch03_MHA_1(\n",
+    "mha_ch03_wrapper = Ch03_MHA_Wrapper(\n",
     "    d_in=embed_dim,\n",
     "    d_out=embed_dim,\n",
     "    block_size=context_len,\n",
     "    dropout=0.0,\n",
     "    num_heads=12,\n",
     "    qkv_bias=False\n",
-    ")\n",
+    ").to(device)\n",
     "\n",
-    "out = mha_ch03_1(embeddings)\n",
+    "out = mha_ch03_wrapper(embeddings)\n",
     "print(out.shape)"
    ]
   },
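The pattern this commit applies throughout is the standard PyTorch device-agnostic idiom: pick the device once, create tensors on it, and move modules to it. A minimal self-contained sketch (variable names here are illustrative, not the notebook's):

    import torch
    import torch.nn as nn

    # Use the GPU when one is available, otherwise fall back to the CPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Tensors can be allocated directly on the target device...
    x = torch.randn((8, 1024, 768), device=device)

    # ...and modules are moved there with .to(device); inputs and parameters
    # must live on the same device, or PyTorch raises a runtime error
    layer = nn.Linear(768, 768).to(device)
    print(layer(x).device)  # cuda:0 on a GPU machine, cpu otherwise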
@@ -67,36 +93,44 @@
    "cell_type": "code",
    "execution_count": 3,
    "id": "4ee6a61b-d25c-4a0c-8a59-f285544e3710",
-   "metadata": {},
+   "metadata": {
+    "id": "4ee6a61b-d25c-4a0c-8a59-f285544e3710",
+    "outputId": "1c7ffc71-3b51-4ee8-beab-261625b1473e",
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    }
+   },
    "outputs": [
     {
-     "name": "stdout",
      "output_type": "stream",
+     "name": "stdout",
      "text": [
       "torch.Size([8, 1024, 768])\n"
      ]
     }
    ],
    "source": [
-    "from ch03 import MultiHeadAttention as Ch03_MHA_2\n",
+    "from ch03 import MultiHeadAttention as Ch03_MHA\n",
     "\n",
-    "mha_ch03_2 = Ch03_MHA_2(\n",
+    "mha_ch03 = Ch03_MHA(\n",
     "    d_in=embed_dim,\n",
     "    d_out=embed_dim,\n",
     "    block_size=context_len,\n",
     "    dropout=0.0,\n",
     "    num_heads=12,\n",
     "    qkv_bias=False\n",
-    ")\n",
+    ").to(device)\n",
     "\n",
-    "out = mha_ch03_2(embeddings)\n",
+    "out = mha_ch03(embeddings)\n",
     "print(out.shape)"
    ]
   },
   {
    "cell_type": "markdown",
    "id": "73cd11da-ea3b-4081-b483-c4965dfefbc4",
-   "metadata": {},
+   "metadata": {
+    "id": "73cd11da-ea3b-4081-b483-c4965dfefbc4"
+   },
    "source": [
     "## An alternative multi-head attention with combined weights"
    ]
@@ -104,7 +138,9 @@
   {
    "cell_type": "markdown",
    "id": "1fa1a5ea-eaff-4d2d-aaf0-b34cdb6fd4dd",
-   "metadata": {},
+   "metadata": {
+    "id": "1fa1a5ea-eaff-4d2d-aaf0-b34cdb6fd4dd"
+   },
    "source": [
     "- The code for the `MultiHeadAttentionAlt` class below is based on code that was kindly shared by [Rayed Bin Wahed](https://github.com/rasbt/LLMs-from-scratch/discussions/51)\n",
     "- The main difference between the `MultiHeadAttentionAlt` class and the `MultiHeadAttention` class used in chapter 3 is that `MultiHeadAttentionAlt` uses a single weight matrix, `self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)`, instead of separate weight matrices:\n",
@@ -112,7 +148,7 @@
     "  - `self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)`\n",
     "  - `self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)`\n",
     "  - `self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)`\n",
-    " \n",
+    "\n",
     "- Here, `self.qkv` combines all three weight matrices `self.W_query`, `self.W_key`, and `self.W_value` to carry out the query, key, and value computation in a single step\n",
     "- Using `q, k, v = qkv.unbind(0)`, we obtain the individual query, key, and value tensors, which are then used similarly to the query, key, and value tensors in the `MultiHeadAttention` class in chapter 3"
    ]
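To make the `unbind` step concrete, here is a small shape-only sketch (the dimensions match this notebook's settings; the standalone tensor is our illustration, not the class's internal state):

    import torch

    b, num_heads, num_tokens, head_dim = 8, 12, 1024, 64

    # After the fused QKV projection and a reshape/permute, dimension 0
    # indexes the three roles: queries, keys, and values
    qkv = torch.randn(3, b, num_heads, num_tokens, head_dim)

    # unbind(0) splits along dim 0 without copying, returning three views
    q, k, v = qkv.unbind(0)
    print(q.shape)  # torch.Size([8, 12, 1024, 64])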
@@ -121,11 +157,17 @@
    "cell_type": "code",
    "execution_count": 4,
    "id": "9a6bd0a2-f27c-4602-afa0-c96cd295c1a6",
-   "metadata": {},
+   "metadata": {
+    "id": "9a6bd0a2-f27c-4602-afa0-c96cd295c1a6",
+    "outputId": "3c225fe5-73a9-4df0-c513-6296f4bb5261",
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    }
+   },
    "outputs": [
     {
-     "name": "stdout",
      "output_type": "stream",
+     "name": "stdout",
      "text": [
       "torch.Size([8, 1024, 768])\n"
      ]
@@ -135,7 +177,7 @@
     "import torch.nn as nn\n",
     "\n",
     "\n",
-    "class MultiHeadAttentionAlt(nn.Module):\n",
+    "class MultiHeadAttentionCombinedQKV(nn.Module):\n",
     "    def __init__(self, d_in, d_out, num_heads, block_size, dropout=0.0, qkv_bias=False):\n",
     "        super().__init__()\n",
     "\n",
@@ -173,7 +215,7 @@
     "        attn_scores = attn_scores.masked_fill(\n",
     "            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf\n",
     "        )\n",
-    "        \n",
+    "\n",
     "        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n",
     "        attn_weights = self.dropout(attn_weights)\n",
     "\n",
@@ -191,23 +233,25 @@
     "        return context_vec\n",
     "\n",
     "\n",
-    "mha_alt = MultiHeadAttentionAlt(\n",
+    "mha_combined_qkv = MultiHeadAttentionCombinedQKV(\n",
     "    d_in=embed_dim,\n",
     "    d_out=embed_dim,\n",
     "    block_size=context_len,\n",
     "    dropout=0.0,\n",
     "    num_heads=12,\n",
     "    qkv_bias=False\n",
-    ")\n",
+    ").to(device)\n",
     "\n",
-    "out = mha_alt(embeddings)\n",
+    "out = mha_combined_qkv(embeddings)\n",
     "print(out.shape)"
    ]
   },
   {
    "cell_type": "markdown",
    "id": "48a042d3-ee78-4c29-bf63-d92fe6706632",
-   "metadata": {},
+   "metadata": {
+    "id": "48a042d3-ee78-4c29-bf63-d92fe6706632"
+   },
    "source": [
     "## Multihead attention with PyTorch's scaled dot product attention"
    ]
@@ -215,7 +259,9 @@
   {
    "cell_type": "markdown",
    "id": "f78e346f-3b85-44e6-9feb-f01131381148",
-   "metadata": {},
+   "metadata": {
+    "id": "f78e346f-3b85-44e6-9feb-f01131381148"
+   },
    "source": [
     "- The implementation below uses PyTorch's [`scaled_dot_product_attention`](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) function, which implements a memory-optimized version of self-attention called [flash attention](https://arxiv.org/abs/2205.14135)"
    ]
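At its core, the function replaces the manual mask/softmax/dropout sequence with one call; a minimal causal sketch (shapes chosen to match this notebook's settings; `dropout_p` and `is_causal` are part of the documented signature):

    import torch
    import torch.nn.functional as F

    q = torch.randn(8, 12, 1024, 64)  # (batch, heads, tokens, head_dim)
    k = torch.randn(8, 12, 1024, 64)
    v = torch.randn(8, 12, 1024, 64)

    # is_causal=True applies the causal mask internally, so no mask buffer
    # is needed; the 1/sqrt(head_dim) scaling is also handled internally
    context = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True)
    print(context.shape)  # torch.Size([8, 12, 1024, 64])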
@@ -224,7 +270,9 @@
    "cell_type": "code",
    "execution_count": 5,
    "id": "1b8e5a0d-1f65-4a03-bf6e-723f0cc428f5",
-   "metadata": {},
+   "metadata": {
+    "id": "1b8e5a0d-1f65-4a03-bf6e-723f0cc428f5"
+   },
    "outputs": [],
    "source": [
     "class MultiHeadAttentionPyTorch(nn.Module):\n",
@@ -275,11 +323,17 @@
    "cell_type": "code",
    "execution_count": 6,
    "id": "fbc8ba92-3471-41cb-b1b2-4c0ef5be392b",
-   "metadata": {},
+   "metadata": {
+    "id": "fbc8ba92-3471-41cb-b1b2-4c0ef5be392b",
+    "outputId": "f3e7933d-16d3-45e5-f03d-610319004579",
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    }
+   },
    "outputs": [
     {
-     "name": "stdout",
      "output_type": "stream",
+     "name": "stdout",
      "text": [
       "torch.Size([8, 1024, 768])\n"
      ]
@@ -293,7 +347,7 @@
     "    dropout=0.0,\n",
     "    num_heads=12,\n",
     "    qkv_bias=False\n",
-    ")\n",
+    ").to(device)\n",
     "\n",
     "out = mha_pytorch(embeddings)\n",
     "print(out.shape)"
@@ -302,7 +356,9 @@
   {
    "cell_type": "markdown",
    "id": "8877de71-f84f-4f6d-bc87-7552013b6301",
-   "metadata": {},
+   "metadata": {
+    "id": "8877de71-f84f-4f6d-bc87-7552013b6301"
+   },
    "source": [
     "## Speed comparison"
    ]
@@ -311,67 +367,91 @@
    "cell_type": "code",
    "execution_count": 7,
    "id": "a97c0b2e-6593-49d8-98bc-2267b3aa610f",
-   "metadata": {},
+   "metadata": {
+    "id": "a97c0b2e-6593-49d8-98bc-2267b3aa610f",
+    "outputId": "bb928da8-6ac0-4d15-cf12-4903d73708fc",
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    }
+   },
    "outputs": [
     {
-     "name": "stdout",
      "output_type": "stream",
+     "name": "stdout",
      "text": [
-      "879 ms ± 4.01 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
+      "41.1 ms ± 9.08 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
      ]
     }
    ],
    "source": [
-    "%timeit mha_ch03_1(embeddings)"
+    "%timeit mha_ch03_wrapper(embeddings)"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": 8,
    "id": "19db9c2c-8e75-431a-8eef-0b4d8284e6e6",
-   "metadata": {},
+   "metadata": {
+    "id": "19db9c2c-8e75-431a-8eef-0b4d8284e6e6",
+    "outputId": "54f8e05e-0cb2-4e4a-cacd-27a309a3be8b",
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    }
+   },
    "outputs": [
     {
-     "name": "stdout",
      "output_type": "stream",
+     "name": "stdout",
      "text": [
-      "259 ms ± 7.91 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
+      "6.58 ms ± 582 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
      ]
     }
    ],
    "source": [
-    "%timeit mha_ch03_2(embeddings)"
+    "%timeit mha_ch03(embeddings)"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": 9,
    "id": "aa526ee0-7a88-4f34-a49a-f8f97da83779",
-   "metadata": {},
+   "metadata": {
+    "id": "aa526ee0-7a88-4f34-a49a-f8f97da83779",
+    "outputId": "415e959e-b648-4f1e-f05e-8b8444e74bee",
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    }
+   },
    "outputs": [
     {
-     "name": "stdout",
      "output_type": "stream",
+     "name": "stdout",
      "text": [
-      "290 ms ± 2.58 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
+      "7.2 ms ± 327 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
      ]
     }
    ],
    "source": [
-    "%timeit mha_alt(embeddings)"
+    "%timeit mha_combined_qkv(embeddings)"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": 10,
    "id": "cc2b4256-16d8-4c34-9fd0-d4b4af0e60fa",
-   "metadata": {},
+   "metadata": {
+    "id": "cc2b4256-16d8-4c34-9fd0-d4b4af0e60fa",
+    "outputId": "05b7c696-1b97-4f18-8430-481bb8940b6b",
+    "colab": {
+     "base_uri": "https://localhost:8080/"
+    }
+   },
    "outputs": [
     {
-     "name": "stdout",
      "output_type": "stream",
+     "name": "stdout",
      "text": [
-      "91.5 ms ± 1.04 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
+      "2.38 ms ± 386 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n"
      ]
     }
    ],
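One caveat when reading the GPU timings above: CUDA kernels launch asynchronously, so a naive wall-clock measurement can under-report work still queued on the device. An explicit loop with synchronization makes the assumption visible (the `time_forward` helper below is ours, not the notebook's):

    import time
    import torch

    def time_forward(module, x, n_iters=100):
        # Hypothetical helper (not from the notebook): warm up first so
        # one-time CUDA initialization cost is excluded from the measurement
        for _ in range(10):
            module(x)
        if x.is_cuda:
            torch.cuda.synchronize()  # wait for all queued kernels to finish
        start = time.perf_counter()
        for _ in range(n_iters):
            module(x)
        if x.is_cuda:
            torch.cuda.synchronize()  # flush the queue before stopping the clock
        return (time.perf_counter() - start) / n_iters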
@@ -382,8 +462,7 @@
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "Python 3 (ipykernel)",
-   "language": "python",
+   "display_name": "Python 3",
    "name": "python3"
   },
   "language_info": {
@@ -397,7 +476,13 @@
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
-  }
+  },
+  "colab": {
+   "provenance": [],
+   "machine_shape": "hm",
+   "gpuType": "A100"
+  },
+  "accelerator": "GPU"
  },
 "nbformat": 4,
 "nbformat_minor": 5