diff --git a/ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb b/ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb index f935060..d1b969f 100644 --- a/ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb +++ b/ch03/02_bonus_efficient-multihead-attention/mha-implementations.ipynb @@ -1,618 +1,859 @@ { - "cells": [ - { - "cell_type": "markdown", - "id": "6f678e62-7bcb-4405-86ae-dce94f494303", - "metadata": { - "id": "6f678e62-7bcb-4405-86ae-dce94f494303" - }, - "source": [ - "# Efficient Multi-Head Attention Implementations" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "7898551e-f582-48ac-9f66-3632abe2a93f", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "7898551e-f582-48ac-9f66-3632abe2a93f", - "outputId": "02205088-47f1-4fc1-83a4-dd0be4cd64dd" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Running on cuda\n" - ] - } - ], - "source": [ - "import torch\n", - "\n", - "torch.manual_seed(123)\n", - "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", - "print(f\"Running on {device}\")\n", - "\n", - "batch_size = 8\n", - "context_len = 1024\n", - "embed_dim = 768\n", - "embeddings = torch.randn((batch_size, context_len, embed_dim), device=device)" - ] - }, - { - "cell_type": "markdown", - "id": "2f9bb1b6-a1e5-4e0a-884d-0f31b374a8d6", - "metadata": { - "id": "2f9bb1b6-a1e5-4e0a-884d-0f31b374a8d6" - }, - "source": [ - "## 1) CausalAttention MHA wrapper class from chapter 3" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "297c93ed-aec0-4896-bb89-42c4b294d3d1", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "297c93ed-aec0-4896-bb89-42c4b294d3d1", - "outputId": "a1eefc3c-21ea-463e-e75e-06af9f6262dd" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "torch.Size([8, 1024, 9216])\n" - ] - } - ], - "source": [ - "from ch03 import MultiHeadAttentionWrapper as Ch03_MHA_Wrapper\n", - "\n", - "mha_ch03_wrapper = Ch03_MHA_Wrapper(\n", - " d_in=embed_dim,\n", - " d_out=embed_dim,\n", - " block_size=context_len,\n", - " dropout=0.0,\n", - " num_heads=12,\n", - " qkv_bias=False\n", - ").to(device)\n", - "\n", - "out = mha_ch03_wrapper(embeddings)\n", - "print(out.shape)" - ] - }, - { - "cell_type": "markdown", - "id": "21930804-b327-40b1-8e63-94dcad39ce7b", - "metadata": { - "id": "21930804-b327-40b1-8e63-94dcad39ce7b" - }, - "source": [ - "## 2) The multi-head attention class from chapter 3" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "4ee6a61b-d25c-4a0c-8a59-f285544e3710", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "4ee6a61b-d25c-4a0c-8a59-f285544e3710", - "outputId": "c66ee5fd-b0cd-4ab4-e097-4d64902ea0d0" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "torch.Size([8, 1024, 768])\n" - ] - } - ], - "source": [ - "from ch03 import MultiHeadAttention as Ch03_MHA\n", - "\n", - "mha_ch03 = Ch03_MHA(\n", - " d_in=embed_dim,\n", - " d_out=embed_dim,\n", - " block_size=context_len,\n", - " dropout=0.0,\n", - " num_heads=12,\n", - " qkv_bias=False\n", - ").to(device)\n", - "\n", - "out = mha_ch03(embeddings)\n", - "print(out.shape)" - ] - }, - { - "cell_type": "markdown", - "id": "73cd11da-ea3b-4081-b483-c4965dfefbc4", - "metadata": { - "id": "73cd11da-ea3b-4081-b483-c4965dfefbc4" - }, - "source": [ - "## 3) An alternative multi-head attention with 
combined weights" - ] - }, - { - "cell_type": "markdown", - "id": "1fa1a5ea-eaff-4d2d-aaf0-b34cdb6fd4dd", - "metadata": { - "id": "1fa1a5ea-eaff-4d2d-aaf0-b34cdb6fd4dd" - }, - "source": [ - "- The code for the `MultiHeadAttentionAlt` class below is based on code that was kindly shared by [Rayed Bin Wahed](https://github.com/rasbt/LLMs-from-scratch/discussions/51)\n", - "- The main difference between the `MultiHeadAttentionAlt` class and the `MultiHeadAttention` class used in chapter 3 is that `MultiHeadAttentionAlt` uses a single weight matrix, `self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)` instead of separate weight matrices:\n", - "\n", - " - `self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)`\n", - " - `self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)`\n", - " - `self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)`\n", - "\n", - "- Here, `self.qkv` combines all three weight matrices `self.W_query`, `self.W_key`, and `self.W_value` to carry out the query, key, and value computation in a single step\n", - "- Using `q, k, v = qkv.unbind(0)`, we obtain the individual query, key, and value tensors, which are then used similarly to the query, key, and value tensors in the `MultiHeadAttention` class in chapter 3" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "9a6bd0a2-f27c-4602-afa0-c96cd295c1a6", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "9a6bd0a2-f27c-4602-afa0-c96cd295c1a6", - "outputId": "9c4ffbe8-6684-429c-b86a-b68121341a4c" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "torch.Size([8, 1024, 768])\n" - ] - } - ], - "source": [ - "import torch.nn as nn\n", - "\n", - "\n", - "class MultiHeadAttentionCombinedQKV(nn.Module):\n", - " def __init__(self, d_in, d_out, num_heads, block_size, dropout=0.0, qkv_bias=False):\n", - " super().__init__()\n", - "\n", - " assert d_out % num_heads == 0, \"embed_dim is indivisible by num_heads\"\n", - "\n", - " self.num_heads = num_heads\n", - " self.block_size = block_size\n", - " self.head_dim = d_out // num_heads\n", - "\n", - " self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)\n", - " self.proj = nn.Linear(d_in, d_out)\n", - " self.dropout = nn.Dropout(dropout)\n", - "\n", - " self.register_buffer(\n", - " \"mask\", torch.triu(torch.ones(block_size, block_size), diagonal=1)\n", - " )\n", - "\n", - " def forward(self, x):\n", - " batch_size, num_tokens, embed_dim = x.shape\n", - "\n", - " # (b, num_tokens, embed_dim) --> (b, num_tokens, 3 * embed_dim)\n", - " qkv = self.qkv(x)\n", - "\n", - " # (b, num_tokens, 3 * embed_dim) --> (b, num_tokens, 3, num_heads, head_dim)\n", - " qkv = qkv.reshape(batch_size, num_tokens, 3, self.num_heads, self.head_dim)\n", - "\n", - " # (b, num_tokens, 3, num_heads, head_dim) --> (3, b, num_heads, num_tokens, head_dim)\n", - " qkv = qkv.permute(2, 0, 3, 1, 4)\n", - "\n", - " # (3, b, num_heads, num_tokens, head_dim) -> 3 times (b, num_head, num_tokens, head_dim)\n", - " queries, keys, values = qkv.unbind(0)\n", - "\n", - " # (b, num_heads, num_tokens, head_dim) --> (b, num_heads, num_tokens, num_tokens)\n", - " attn_scores = queries @ keys.transpose(-2, -1)\n", - " attn_scores = attn_scores.masked_fill(\n", - " self.mask.bool()[:num_tokens, :num_tokens], -torch.inf\n", - " )\n", - "\n", - " attn_weights = torch.softmax(attn_scores / keys.shape[-1]**-0.5, dim=-1)\n", - " attn_weights = self.dropout(attn_weights)\n", - "\n", - " # (b, num_heads, num_tokens, num_tokens) --> (b, num_heads, num_tokens, 
head_dim)\n", - " context_vec = attn_weights @ values\n", - "\n", - " # (b, num_heads, num_tokens, head_dim) --> (b, num_tokens, num_heads, head_dim)\n", - " context_vec = context_vec.transpose(1, 2)\n", - "\n", - " # (b, num_tokens, num_heads, head_dim) --> (b, num_tokens, embed_dim)\n", - " context_vec = context_vec.reshape(batch_size, num_tokens, embed_dim)\n", - "\n", - " context_vec = self.proj(context_vec)\n", - "\n", - " return context_vec\n", - "\n", - "\n", - "mha_combined_qkv = MultiHeadAttentionCombinedQKV(\n", - " d_in=embed_dim,\n", - " d_out=embed_dim,\n", - " block_size=context_len,\n", - " dropout=0.0,\n", - " num_heads=12,\n", - " qkv_bias=False\n", - ").to(device)\n", - "\n", - "out = mha_combined_qkv(embeddings)\n", - "print(out.shape)" - ] - }, - { - "cell_type": "markdown", - "id": "48a042d3-ee78-4c29-bf63-d92fe6706632", - "metadata": { - "id": "48a042d3-ee78-4c29-bf63-d92fe6706632" - }, - "source": [ - "## 4) Multihead attention with PyTorch's scaled dot product attention" - ] - }, - { - "cell_type": "markdown", - "id": "f78e346f-3b85-44e6-9feb-f01131381148", - "metadata": { - "id": "f78e346f-3b85-44e6-9feb-f01131381148" - }, - "source": [ - "- The implementation below uses PyTorch's [`scaled_dot_product_attention`](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) function, which implements a memory-optimized version of self-attention calld [flash attention](https://arxiv.org/abs/2205.14135)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "1b8e5a0d-1f65-4a03-bf6e-723f0cc428f5", - "metadata": { - "id": "1b8e5a0d-1f65-4a03-bf6e-723f0cc428f5" - }, - "outputs": [], - "source": [ - "class MHAPyTorchScaledDotProduct(nn.Module):\n", - " def __init__(self, d_in, d_out, num_heads, block_size, dropout=0.0, qkv_bias=False):\n", - " super().__init__()\n", - "\n", - " assert d_out % num_heads == 0, \"embed_dim is indivisible by num_heads\"\n", - "\n", - " self.num_heads = num_heads\n", - " self.block_size = block_size\n", - " self.head_dim = d_out // num_heads\n", - " self.d_out = d_out\n", - "\n", - " self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)\n", - " self.proj = nn.Linear(d_in, d_out)\n", - " self.dropout = dropout\n", - "\n", - " self.register_buffer(\n", - " \"mask\", torch.triu(torch.ones(block_size, block_size), diagonal=1)\n", - " )\n", - "\n", - " def forward(self, x):\n", - " batch_size, num_tokens, embed_dim = x.shape\n", - "\n", - " # (b, num_tokens, embed_dim) --> (b, num_tokens, 3 * embed_dim)\n", - " qkv = self.qkv(x)\n", - "\n", - " # (b, num_tokens, 3 * embed_dim) --> (b, num_tokens, 3, num_heads, head_dim)\n", - " qkv = qkv.reshape(batch_size, num_tokens, 3, self.num_heads, self.head_dim)\n", - "\n", - " # (b, num_tokens, 3, num_heads, head_dim) --> (3, b, num_heads, num_tokens, head_dim)\n", - " qkv = qkv.permute(2, 0, 3, 1, 4)\n", - "\n", - " # (3, b, num_heads, num_tokens, head_dim) -> 3 times (b, num_heads, num_tokens, head_dim)\n", - " queries, keys, values = qkv.unbind(0)\n", - "\n", - " use_dropout = 0. 
if not self.training else self.dropout\n", - " context_vec = nn.functional.scaled_dot_product_attention(\n", - " queries, keys, values, attn_mask=None, dropout_p=use_dropout, is_causal=True)\n", - "\n", - " # Combine heads, where self.d_out = self.num_heads * self.head_dim\n", - " context_vec = context_vec.transpose(1, 2).contiguous().view(batch_size, num_tokens, self.d_out)\n", - "\n", - " return context_vec" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "fbc8ba92-3471-41cb-b1b2-4c0ef5be392b", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "fbc8ba92-3471-41cb-b1b2-4c0ef5be392b", - "outputId": "027b5a66-4e17-49e8-9e80-9c70eaf201ab" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "torch.Size([8, 1024, 768])\n" - ] - } - ], - "source": [ - "mha_pytorch_scaled = MHAPyTorchScaledDotProduct(\n", - " d_in=embed_dim,\n", - " d_out=embed_dim,\n", - " block_size=context_len,\n", - " dropout=0.0,\n", - " num_heads=12,\n", - " qkv_bias=False\n", - ").to(device)\n", - "\n", - "out = mha_pytorch_scaled(embeddings)\n", - "print(out.shape)" - ] - }, - { - "cell_type": "markdown", - "id": "351c318f-4835-4d74-8d58-a070222447c4", - "metadata": { - "id": "351c318f-4835-4d74-8d58-a070222447c4" - }, - "source": [ - "## 5) Using PyTorch's torch.nn.MultiheadAttention" - ] - }, - { - "cell_type": "markdown", - "id": "74a6d060-6324-48fa-a35c-cb09f2a48965", - "metadata": { - "id": "74a6d060-6324-48fa-a35c-cb09f2a48965" - }, - "source": [ - "- Below, we use PyTorch's [torch.nn.MultiheadAttention](https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html) implementation" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "3799c7ef-3155-42c6-a829-f95656453ae0", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "3799c7ef-3155-42c6-a829-f95656453ae0", - "outputId": "9d9afbbd-2e85-44cb-afc9-8cb3c91e8368" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "torch.Size([8, 1024, 768])\n" - ] - } - ], - "source": [ - "class MHAPyTorchClass(nn.Module):\n", - " def __init__(self, d_in, d_out, num_heads, block_size, dropout=0.0, qkv_bias=False):\n", - " super().__init__()\n", - "\n", - " self.block_size = block_size\n", - " self.multihead_attn = nn.MultiheadAttention(\n", - " embed_dim=d_out,\n", - " num_heads=num_heads,\n", - " dropout=dropout,\n", - " bias=qkv_bias,\n", - " add_bias_kv=qkv_bias,\n", - " batch_first=True\n", - " )\n", - "\n", - " self.proj = nn.Linear(d_out, d_out)\n", - " self.register_buffer(\"mask\", torch.triu(torch.ones(block_size, block_size), diagonal=1).bool())\n", - "\n", - " def forward(self, x):\n", - " batch_size, num_tokens, _ = x.shape\n", - "\n", - " # Ensure attn_mask is compatible with expected shape and `batch_first=True`\n", - " # No need to manually adjust for num_heads; ensure it's right for the sequence\n", - " if self.block_size >= num_tokens:\n", - " attn_mask = self.mask[:num_tokens, :num_tokens]\n", - " else:\n", - " attn_mask = self.mask[:self.block_size, :self.block_size]\n", - "\n", - " # attn_mask broadcasting will handle batch_size dimension implicitly\n", - " attn_output, _ = self.multihead_attn(x, x, x, attn_mask=attn_mask)\n", - "\n", - " output = self.proj(attn_output)\n", - "\n", - " return output\n", - "\n", - "\n", - "mha_pytorch_class = MHAPyTorchClass(\n", - " d_in=embed_dim,\n", - " d_out=embed_dim,\n", - " block_size=context_len,\n", - " dropout=0.0,\n", - " num_heads=12,\n", - 
" qkv_bias=False\n", - ").to(device)\n", - "\n", - "out = mha_pytorch_class(embeddings)\n", - "print(out.shape)" - ] - }, - { - "cell_type": "markdown", - "id": "8877de71-f84f-4f6d-bc87-7552013b6301", - "metadata": { - "id": "8877de71-f84f-4f6d-bc87-7552013b6301" - }, - "source": [ - "## Speed comparison" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "a97c0b2e-6593-49d8-98bc-2267b3aa610f", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "a97c0b2e-6593-49d8-98bc-2267b3aa610f", - "outputId": "ebe635b2-5c03-4e9b-da3a-951d308acf7b" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "41.1 ms ± 12.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" - ] - } - ], - "source": [ - "## 1) CausalAttention MHA wrapper class from chapter 3\n", - "%timeit mha_ch03_wrapper(embeddings)" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "19db9c2c-8e75-431a-8eef-0b4d8284e6e6", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "19db9c2c-8e75-431a-8eef-0b4d8284e6e6", - "outputId": "c6e7bcff-661c-45a6-da82-b1e3f89cf761" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "6.58 ms ± 143 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" - ] - } - ], - "source": [ - "## 2) The multi-head attention class from chapter 3\n", - "%timeit mha_ch03(embeddings)" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "aa526ee0-7a88-4f34-a49a-f8f97da83779", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "aa526ee0-7a88-4f34-a49a-f8f97da83779", - "outputId": "92b634f8-43f8-468f-87a1-bb774b64c212" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "7.19 ms ± 294 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" - ] - } - ], - "source": [ - "## 3) An alternative multi-head attention with combined weights\n", - "%timeit mha_combined_qkv(embeddings)" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "cc2b4256-16d8-4c34-9fd0-d4b4af0e60fa", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "cc2b4256-16d8-4c34-9fd0-d4b4af0e60fa", - "outputId": "80c6e314-0771-470e-b090-628984ce2d85" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "2.37 ms ± 432 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n" - ] - } - ], - "source": [ - "## 4) Multihead attention with PyTorch's scaled dot product attention\n", - "%timeit mha_pytorch_scaled(embeddings)" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "0f209e70-ebb6-4a1a-b608-1ff42e41c01d", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "0f209e70-ebb6-4a1a-b608-1ff42e41c01d", - "outputId": "3cd37b53-04d4-4dd0-9450-6fc8ebaac083" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "6.66 ms ± 397 ns per loop (mean ± std. dev. 
of 7 runs, 100 loops each)\n" - ] - } - ], - "source": [ - "## 5) Using PyTorch's torch.nn.MultiheadAttention\n", - "%timeit mha_pytorch_class(embeddings)" - ] - } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "gpuType": "A100", - "machine_shape": "hm", - "provenance": [] - }, - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.6" - } + "cells": [ + { + "cell_type": "markdown", + "id": "6f678e62-7bcb-4405-86ae-dce94f494303", + "metadata": { + "id": "6f678e62-7bcb-4405-86ae-dce94f494303" + }, + "source": [ + "# Efficient Multi-Head Attention Implementations" + ] }, - "nbformat": 4, - "nbformat_minor": 5 -} \ No newline at end of file + { + "cell_type": "code", + "execution_count": 1, + "id": "7898551e-f582-48ac-9f66-3632abe2a93f", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "7898551e-f582-48ac-9f66-3632abe2a93f", + "outputId": "999a54ca-36b5-4e26-e25a-94953b4d1590" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Running on cpu\n" + ] + } + ], + "source": [ + "import torch\n", + "\n", + "torch.manual_seed(123)\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "print(f\"Running on {device}\")\n", + "\n", + "batch_size = 8\n", + "context_len = 1024\n", + "embed_dim = 768\n", + "embeddings = torch.randn((batch_size, context_len, embed_dim), device=device)" + ] + }, + { + "cell_type": "markdown", + "id": "2f9bb1b6-a1e5-4e0a-884d-0f31b374a8d6", + "metadata": { + "id": "2f9bb1b6-a1e5-4e0a-884d-0f31b374a8d6" + }, + "source": [ + "## 1) CausalAttention MHA wrapper class from chapter 3" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "297c93ed-aec0-4896-bb89-42c4b294d3d1", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "297c93ed-aec0-4896-bb89-42c4b294d3d1", + "outputId": "4d694519-8283-49a6-b27f-c2ae06a2fa4e" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([8, 1024, 9216])\n" + ] + } + ], + "source": [ + "from ch03 import MultiHeadAttentionWrapper as Ch03_MHA_Wrapper\n", + "\n", + "mha_ch03_wrapper = Ch03_MHA_Wrapper(\n", + " d_in=embed_dim,\n", + " d_out=embed_dim,\n", + " block_size=context_len,\n", + " dropout=0.0,\n", + " num_heads=12,\n", + " qkv_bias=False\n", + ").to(device)\n", + "\n", + "out = mha_ch03_wrapper(embeddings)\n", + "print(out.shape)" + ] + }, + { + "cell_type": "markdown", + "id": "21930804-b327-40b1-8e63-94dcad39ce7b", + "metadata": { + "id": "21930804-b327-40b1-8e63-94dcad39ce7b" + }, + "source": [ + "## 2) The multi-head attention class from chapter 3" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "4ee6a61b-d25c-4a0c-8a59-f285544e3710", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "4ee6a61b-d25c-4a0c-8a59-f285544e3710", + "outputId": "f7cd4e03-5622-4dd6-a9e0-31b80e15838b" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([8, 1024, 768])\n" + ] + } + ], + "source": [ + "from ch03 import MultiHeadAttention as Ch03_MHA\n", + "\n", + "mha_ch03 = Ch03_MHA(\n", + " d_in=embed_dim,\n", + " d_out=embed_dim,\n", + " block_size=context_len,\n", + " 
dropout=0.0,\n", + " num_heads=12,\n", + " qkv_bias=False\n", + ").to(device)\n", + "\n", + "out = mha_ch03(embeddings)\n", + "print(out.shape)" + ] + }, + { + "cell_type": "markdown", + "id": "73cd11da-ea3b-4081-b483-c4965dfefbc4", + "metadata": { + "id": "73cd11da-ea3b-4081-b483-c4965dfefbc4" + }, + "source": [ + "## 3) An alternative multi-head attention with combined weights" + ] + }, + { + "cell_type": "markdown", + "id": "1fa1a5ea-eaff-4d2d-aaf0-b34cdb6fd4dd", + "metadata": { + "id": "1fa1a5ea-eaff-4d2d-aaf0-b34cdb6fd4dd" + }, + "source": [ + "- The code for the `MultiHeadAttentionCombinedQKV` class below is based on code that was kindly shared by [Rayed Bin Wahed](https://github.com/rasbt/LLMs-from-scratch/discussions/51)\n", + "- The main difference between the `MultiHeadAttentionCombinedQKV` class and the `MultiHeadAttention` class used in chapter 3 is that `MultiHeadAttentionCombinedQKV` uses a single weight matrix, `self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)`, instead of three separate weight matrices:\n", + "\n", + " - `self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)`\n", + " - `self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)`\n", + " - `self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)`\n", + "\n", + "- Here, `self.qkv` combines all three weight matrices `self.W_query`, `self.W_key`, and `self.W_value` to carry out the query, key, and value computation in a single step\n", + "- Using `q, k, v = qkv.unbind(0)`, we obtain the individual query, key, and value tensors, which are then used similarly to the query, key, and value tensors in the `MultiHeadAttention` class in chapter 3\n", + "- A short sanity check after the code cell below verifies that the combined projection is equivalent to the three separate projections" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "9a6bd0a2-f27c-4602-afa0-c96cd295c1a6", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "9a6bd0a2-f27c-4602-afa0-c96cd295c1a6", + "outputId": "86c643c3-c1dd-4021-d7e6-b79e3cea3d99" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([8, 1024, 768])\n" + ] + } + ], + "source": [ + "import torch.nn as nn\n", + "\n", + "\n", + "class MultiHeadAttentionCombinedQKV(nn.Module):\n", + " def __init__(self, d_in, d_out, num_heads, block_size, dropout=0.0, qkv_bias=False):\n", + " super().__init__()\n", + "\n", + " assert d_out % num_heads == 0, \"embed_dim is indivisible by num_heads\"\n", + "\n", + " self.num_heads = num_heads\n", + " self.block_size = block_size\n", + " self.head_dim = d_out // num_heads\n", + "\n", + " self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)\n", + " self.proj = nn.Linear(d_in, d_out)\n", + " self.dropout = nn.Dropout(dropout)\n", + "\n", + " self.register_buffer(\n", + " \"mask\", torch.triu(torch.ones(block_size, block_size), diagonal=1)\n", + " )\n", + "\n", + " def forward(self, x):\n", + " batch_size, num_tokens, embed_dim = x.shape\n", + "\n", + " # (b, num_tokens, embed_dim) --> (b, num_tokens, 3 * embed_dim)\n", + " qkv = self.qkv(x)\n", + "\n", + " # (b, num_tokens, 3 * embed_dim) --> (b, num_tokens, 3, num_heads, head_dim)\n", + " qkv = qkv.reshape(batch_size, num_tokens, 3, self.num_heads, self.head_dim)\n", + "\n", + " # (b, num_tokens, 3, num_heads, head_dim) --> (3, b, num_heads, num_tokens, head_dim)\n", + " qkv = qkv.permute(2, 0, 3, 1, 4)\n", + "\n", + " # (3, b, num_heads, num_tokens, head_dim) -> 3 times (b, num_heads, num_tokens, head_dim)\n", + " queries, keys, values = qkv.unbind(0)\n", + "\n", + " # (b, num_heads, num_tokens, head_dim) --> (b, num_heads, num_tokens, num_tokens)\n", + " attn_scores = queries @
keys.transpose(-2, -1)\n", + " attn_scores = attn_scores.masked_fill(\n", + " self.mask.bool()[:num_tokens, :num_tokens], -torch.inf\n", + " )\n", + "\n", + " # Scale the attention scores by 1/sqrt(head_dim) before the softmax\n", + " attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n", + " attn_weights = self.dropout(attn_weights)\n", + "\n", + " # (b, num_heads, num_tokens, num_tokens) --> (b, num_heads, num_tokens, head_dim)\n", + " context_vec = attn_weights @ values\n", + "\n", + " # (b, num_heads, num_tokens, head_dim) --> (b, num_tokens, num_heads, head_dim)\n", + " context_vec = context_vec.transpose(1, 2)\n", + "\n", + " # (b, num_tokens, num_heads, head_dim) --> (b, num_tokens, embed_dim)\n", + " context_vec = context_vec.reshape(batch_size, num_tokens, embed_dim)\n", + "\n", + " context_vec = self.proj(context_vec)\n", + "\n", + " return context_vec\n", + "\n", + "\n", + "mha_combined_qkv = MultiHeadAttentionCombinedQKV(\n", + " d_in=embed_dim,\n", + " d_out=embed_dim,\n", + " block_size=context_len,\n", + " dropout=0.0,\n", + " num_heads=12,\n", + " qkv_bias=False\n", + ").to(device)\n", + "\n", + "out = mha_combined_qkv(embeddings)\n", + "print(out.shape)" + ] + },
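 + { + "cell_type": "markdown", + "id": "5b0e9c7a-2f4d-4e8a-9c3b-1a6d8e2f4b0c", + "metadata": {}, + "source": [ + "- As a quick sanity check (the following two cells are new and not part of the chapter 3 code), we can verify that the combined QKV projection is equivalent to the three separate projections: after copying the three weight matrices into the corresponding rows of a single `nn.Linear(d_in, 3 * d_out)` layer, both variants should produce identical queries, keys, and values" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7c2d4f8e-9a1b-4c6d-8e3f-0b5a7d9c1e2f", + "metadata": {}, + "outputs": [], + "source": [ + "# NEW cell (sanity check): a combined QKV projection matches three\n", + "# separate query, key, and value projections that share its weights\n", + "torch.manual_seed(123)\n", + "example_batch = torch.randn(2, 4, embed_dim, device=device)\n", + "\n", + "W_query = nn.Linear(embed_dim, embed_dim, bias=False).to(device)\n", + "W_key = nn.Linear(embed_dim, embed_dim, bias=False).to(device)\n", + "W_value = nn.Linear(embed_dim, embed_dim, bias=False).to(device)\n", + "qkv_combined = nn.Linear(embed_dim, 3 * embed_dim, bias=False).to(device)\n", + "\n", + "# Copy the three separate weight matrices into the combined projection\n", + "with torch.no_grad():\n", + "    qkv_combined.weight[:embed_dim] = W_query.weight\n", + "    qkv_combined.weight[embed_dim:2 * embed_dim] = W_key.weight\n", + "    qkv_combined.weight[2 * embed_dim:] = W_value.weight\n", + "\n", + "q, k, v = qkv_combined(example_batch).chunk(3, dim=-1)\n", + "print(torch.allclose(q, W_query(example_batch), atol=1e-5))\n", + "print(torch.allclose(k, W_key(example_batch), atol=1e-5))\n", + "print(torch.allclose(v, W_value(example_batch), atol=1e-5))" + ] + }, + { + "cell_type": "markdown", + "id": "48a042d3-ee78-4c29-bf63-d92fe6706632", + "metadata": { + "id": "48a042d3-ee78-4c29-bf63-d92fe6706632" + }, + "source": [ + "## 4) Multihead attention with PyTorch's scaled dot product attention" + ] + }, + { + "cell_type": "markdown", + "id": "f78e346f-3b85-44e6-9feb-f01131381148", + "metadata": { + "id": "f78e346f-3b85-44e6-9feb-f01131381148" + }, + "source": [ + "- The implementation below uses PyTorch's [`scaled_dot_product_attention`](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) function, which implements a memory-optimized version of self-attention called [flash attention](https://arxiv.org/abs/2205.14135)\n", + "- The next (new) cell illustrates the function on a tiny example before we use it inside a full MHA class" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3e9f1a2b-6c4d-4f7e-8a5b-2d0c9e7f4a1d", + "metadata": {}, + "outputs": [], + "source": [ + "# NEW cell (illustration only): scaled_dot_product_attention with\n", + "# is_causal=True should match a manual causal attention computation\n", + "import torch.nn.functional as F\n", + "\n", + "torch.manual_seed(123)\n", + "b, h, t, d = 1, 2, 4, 8  # batch, num_heads, num_tokens, head_dim\n", + "q, k, v = torch.randn(3, b, h, t, d, device=device).unbind(0)\n", + "\n", + "# Manual reference: scale by 1/sqrt(head_dim), mask future tokens, softmax\n", + "scores = q @ k.transpose(-2, -1) / d**0.5\n", + "causal_mask = torch.triu(torch.ones(t, t, dtype=torch.bool, device=device), diagonal=1)\n", + "manual = torch.softmax(scores.masked_fill(causal_mask, -torch.inf), dim=-1) @ v\n", + "\n", + "fused = F.scaled_dot_product_attention(q, k, v, is_causal=True)\n", + "print(torch.allclose(manual, fused, atol=1e-5))" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "1b8e5a0d-1f65-4a03-bf6e-723f0cc428f5", + "metadata": { + "id": "1b8e5a0d-1f65-4a03-bf6e-723f0cc428f5" + }, + "outputs": [], + "source": [ + "class MHAPyTorchScaledDotProduct(nn.Module):\n", + " def __init__(self, d_in, d_out, num_heads, block_size, dropout=0.0, qkv_bias=False):\n", + " super().__init__()\n", + "\n", + " assert d_out % num_heads == 0, \"embed_dim is indivisible by num_heads\"\n", + "\n", + " self.num_heads = num_heads\n", + " self.block_size = block_size\n", + " self.head_dim = d_out // num_heads\n", + " self.d_out = d_out\n", + "\n", + " self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)\n", + " self.proj = nn.Linear(d_in, d_out)\n", + " self.dropout = dropout\n", + "\n", + " self.register_buffer(\n", + " \"mask\", torch.triu(torch.ones(block_size, block_size), diagonal=1)\n", + " )\n", + "\n", + " def forward(self, x):\n", + " batch_size, num_tokens, embed_dim = x.shape\n", + "\n", + " # (b, num_tokens, embed_dim) --> (b, num_tokens, 3 * embed_dim)\n", + " qkv = self.qkv(x)\n", + "\n", + " # (b, num_tokens, 3 * embed_dim) --> (b, num_tokens, 3, num_heads, head_dim)\n", + " qkv = qkv.reshape(batch_size, num_tokens, 3, self.num_heads, self.head_dim)\n", + "\n", + " # (b, num_tokens, 3, num_heads, head_dim) --> (3, b, num_heads, num_tokens, head_dim)\n", + " qkv = qkv.permute(2, 0, 3, 1, 4)\n", + "\n", + " # (3, b, num_heads, num_tokens, head_dim) -> 3 times (b, num_heads, num_tokens, head_dim)\n", + " queries, keys, values = qkv.unbind(0)\n", + "\n", + " use_dropout = 0.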
if not self.training else self.dropout\n", + " context_vec = nn.functional.scaled_dot_product_attention(\n", + " queries, keys, values, attn_mask=None, dropout_p=use_dropout, is_causal=True)\n", + "\n", + " # Combine heads, where self.d_out = self.num_heads * self.head_dim\n", + " context_vec = context_vec.transpose(1, 2).contiguous().view(batch_size, num_tokens, self.d_out)\n", + "\n", + " return context_vec" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "fbc8ba92-3471-41cb-b1b2-4c0ef5be392b", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "fbc8ba92-3471-41cb-b1b2-4c0ef5be392b", + "outputId": "ff45b9c9-19a7-4769-efaf-b9f417a31631" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([8, 1024, 768])\n" + ] + } + ], + "source": [ + "mha_pytorch_scaled = MHAPyTorchScaledDotProduct(\n", + " d_in=embed_dim,\n", + " d_out=embed_dim,\n", + " block_size=context_len,\n", + " dropout=0.0,\n", + " num_heads=12,\n", + " qkv_bias=False\n", + ").to(device)\n", + "\n", + "out = mha_pytorch_scaled(embeddings)\n", + "print(out.shape)" + ] + }, + { + "cell_type": "markdown", + "id": "351c318f-4835-4d74-8d58-a070222447c4", + "metadata": { + "id": "351c318f-4835-4d74-8d58-a070222447c4" + }, + "source": [ + "## 5) Using PyTorch's torch.nn.MultiheadAttention" + ] + }, + { + "cell_type": "markdown", + "id": "74a6d060-6324-48fa-a35c-cb09f2a48965", + "metadata": { + "id": "74a6d060-6324-48fa-a35c-cb09f2a48965" + }, + "source": [ + "- Below, we use PyTorch's [torch.nn.MultiheadAttention](https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html) implementation" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "3799c7ef-3155-42c6-a829-f95656453ae0", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "3799c7ef-3155-42c6-a829-f95656453ae0", + "outputId": "1a6ce118-be18-4d7b-99bc-a0c9bb78f8b1" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([8, 1024, 768])\n" + ] + } + ], + "source": [ + "import torch.nn as nn\n", + "\n", + "\n", + "class MHAPyTorchClass(nn.Module):\n", + " def __init__(self, d_in, d_out, num_heads, block_size, dropout=0.0, qkv_bias=False, need_weights=True):\n", + " super().__init__()\n", + "\n", + " self.block_size = block_size\n", + " self.multihead_attn = nn.MultiheadAttention(\n", + " embed_dim=d_out,\n", + " num_heads=num_heads,\n", + " dropout=dropout,\n", + " bias=qkv_bias,\n", + " add_bias_kv=qkv_bias,\n", + " batch_first=True,\n", + " )\n", + "\n", + " self.need_weights = need_weights\n", + " self.proj = nn.Linear(d_out, d_out)\n", + " self.register_buffer(\"mask\", torch.triu(torch.ones(block_size, block_size), diagonal=1).bool())\n", + "\n", + " def forward(self, x):\n", + " batch_size, num_tokens, _ = x.shape\n", + "\n", + " # Slice the buffered causal mask to match the current sequence length\n", + " # (a 2D attn_mask needs no manual adjustment for num_heads)\n", + " if self.block_size >= num_tokens:\n", + " attn_mask = self.mask[:num_tokens, :num_tokens]\n", + " else:\n", + " attn_mask = self.mask[:self.block_size, :self.block_size]\n", + "\n", + " # attn_mask broadcasting will handle batch_size dimension implicitly\n", + " attn_output, _ = self.multihead_attn(\n", + " x, x, x, attn_mask=attn_mask, need_weights=self.need_weights\n", + " )\n", + "\n", + " output = self.proj(attn_output)\n", + "\n", + " return output\n", + "\n", + "\n", +
"mha_pytorch_class_default = MHAPyTorchClass(\n", + " d_in=embed_dim,\n", + " d_out=embed_dim,\n", + " block_size=context_len,\n", + " dropout=0.0,\n", + " num_heads=12,\n", + " qkv_bias=False\n", + ").to(device)\n", + "\n", + "out = mha_pytorch_class_default(embeddings)\n", + "print(out.shape)" + ] + }, + { + "cell_type": "markdown", + "id": "a3953bff-1056-4de2-bfd1-dfccf659eee4", + "metadata": {}, + "source": [ + "## 6) Using PyTorch's torch.nn.MultiheadAttention with `scaled_dot_product_attention`" + ] + }, + { + "cell_type": "markdown", + "id": "d2164859-31a0-4537-b4fb-27d57675ba77", + "metadata": { + "id": "d2164859-31a0-4537-b4fb-27d57675ba77" + }, + "source": [ + "- Set `need_weights` (default `True`) to need_weights=False so that MultiheadAttention uses `scaled_dot_product_attention` [according to the documentation](https://github.com/pytorch/pytorch/blob/71d020262793542974cf13b30f2a9099773f015c/torch/nn/modules/activation.py#L1096)\n", + "\n", + "> need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.\n", + " Set ``need_weights=False`` to use the optimized ``scaled_dot_product_attention``\n", + " and achieve the best performance for MHA.\n", + " Default: ``True``." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "4a4c2afe-5e1f-4bd7-a118-67031176f147", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "4a4c2afe-5e1f-4bd7-a118-67031176f147", + "outputId": "ef1a0698-5b18-426d-df0f-00bdbaeeaccc" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([8, 1024, 768])\n" + ] + } + ], + "source": [ + "mha_pytorch_class_noweights = MHAPyTorchClass(\n", + " d_in=embed_dim,\n", + " d_out=embed_dim,\n", + " block_size=context_len,\n", + " dropout=0.0,\n", + " num_heads=12,\n", + " qkv_bias=False,\n", + " need_weights=False # NEW!\n", + ").to(device)\n", + "\n", + "out = mha_pytorch_class_noweights(embeddings)\n", + "print(out.shape)" + ] + }, + { + "cell_type": "markdown", + "id": "8877de71-f84f-4f6d-bc87-7552013b6301", + "metadata": { + "id": "8877de71-f84f-4f6d-bc87-7552013b6301" + }, + "source": [ + "## Speed comparison (M1 Macbook Air CPU)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "a97c0b2e-6593-49d8-98bc-2267b3aa610f", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "a97c0b2e-6593-49d8-98bc-2267b3aa610f", + "outputId": "ebe635b2-5c03-4e9b-da3a-951d308acf7b" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "914 ms ± 50.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" + ] + } + ], + "source": [ + "## 1) CausalAttention MHA wrapper class from chapter 3\n", + "%timeit mha_ch03_wrapper(embeddings)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "19db9c2c-8e75-431a-8eef-0b4d8284e6e6", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "19db9c2c-8e75-431a-8eef-0b4d8284e6e6", + "outputId": "c6e7bcff-661c-45a6-da82-b1e3f89cf761" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "252 ms ± 9.04 ms per loop (mean ± std. dev. 
of 7 runs, 1 loop each)\n" + ] + } + ], + "source": [ + "## 2) The multi-head attention class from chapter 3\n", + "%timeit mha_ch03(embeddings)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "aa526ee0-7a88-4f34-a49a-f8f97da83779", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "aa526ee0-7a88-4f34-a49a-f8f97da83779", + "outputId": "92b634f8-43f8-468f-87a1-bb774b64c212" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "300 ms ± 8.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" + ] + } + ], + "source": [ + "## 3) An alternative multi-head attention with combined weights\n", + "%timeit mha_combined_qkv(embeddings)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "cc2b4256-16d8-4c34-9fd0-d4b4af0e60fa", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "cc2b4256-16d8-4c34-9fd0-d4b4af0e60fa", + "outputId": "80c6e314-0771-470e-b090-628984ce2d85" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "94.2 ms ± 1.6 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" + ] + } + ], + "source": [ + "## 4) Multihead attention with PyTorch's scaled dot product attention\n", + "%timeit mha_pytorch_scaled(embeddings)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "0f209e70-ebb6-4a1a-b608-1ff42e41c01d", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "0f209e70-ebb6-4a1a-b608-1ff42e41c01d", + "outputId": "3cd37b53-04d4-4dd0-9450-6fc8ebaac083" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "297 ms ± 2.37 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" + ] + } + ], + "source": [ + "## 5) Using PyTorch's torch.nn.MultiheadAttention\n", + "%timeit mha_pytorch_class_default(embeddings)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "3f4968c2-8d40-4ab9-8dba-052b4f77d756", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "274 ms ± 2.19 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" + ] + } + ], + "source": [ + "## 6) Using PyTorch's torch.nn.MultiheadAttention disabling `need_weights`\n", + "%timeit mha_pytorch_class_noweights(embeddings)" + ] + }, + { + "cell_type": "markdown", + "id": "a78ff594-6cc2-496d-a302-789fa104c3c9", + "metadata": { + "id": "8877de71-f84f-4f6d-bc87-7552013b6301" + }, + "source": [ + "## Speed comparison (Nvidia A100 GPU)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "4d21edc6", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "a97c0b2e-6593-49d8-98bc-2267b3aa610f", + "outputId": "ebe635b2-5c03-4e9b-da3a-951d308acf7b" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "914 ms ± 50.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" + ] + } + ], + "source": [ + "## 1) CausalAttention MHA wrapper class from chapter 3\n", + "%timeit mha_ch03_wrapper(embeddings)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "98fda51b", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "19db9c2c-8e75-431a-8eef-0b4d8284e6e6", + "outputId": "c6e7bcff-661c-45a6-da82-b1e3f89cf761" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "252 ms ± 9.04 ms per loop (mean ± std. dev. 
of 7 runs, 1 loop each)\n" + ] + } + ], + "source": [ + "## 2) The multi-head attention class from chapter 3\n", + "%timeit mha_ch03(embeddings)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "af8d11a7", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "aa526ee0-7a88-4f34-a49a-f8f97da83779", + "outputId": "92b634f8-43f8-468f-87a1-bb774b64c212" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "300 ms ± 8.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" + ] + } + ], + "source": [ + "## 3) An alternative multi-head attention with combined weights\n", + "%timeit mha_combined_qkv(embeddings)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "de1e9a77", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "cc2b4256-16d8-4c34-9fd0-d4b4af0e60fa", + "outputId": "80c6e314-0771-470e-b090-628984ce2d85" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "94.2 ms ± 1.6 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n" + ] + } + ], + "source": [ + "## 4) Multihead attention with PyTorch's scaled dot product attention\n", + "%timeit mha_pytorch_scaled(embeddings)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "481e3fea", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "0f209e70-ebb6-4a1a-b608-1ff42e41c01d", + "outputId": "3cd37b53-04d4-4dd0-9450-6fc8ebaac083" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "297 ms ± 2.37 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" + ] + } + ], + "source": [ + "## 5) Using PyTorch's torch.nn.MultiheadAttention\n", + "%timeit mha_pytorch_class_default(embeddings)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "9d52a9eb", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "274 ms ± 2.19 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" + ] + } + ], + "source": [ + "## 6) Using PyTorch's torch.nn.MultiheadAttention disabling `need_weights`\n", + "%timeit mha_pytorch_class_noweights(embeddings)" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "A100", + "machine_shape": "hm", + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}