{ "cells": [ { "cell_type": "markdown", "id": "6f678e62-7bcb-4405-86ae-dce94f494303", "metadata": { "id": "6f678e62-7bcb-4405-86ae-dce94f494303" }, "source": [ "# Efficient Multi-Head Attention Implementations" ] }, { "cell_type": "code", "execution_count": 1, "id": "7898551e-f582-48ac-9f66-3632abe2a93f", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "7898551e-f582-48ac-9f66-3632abe2a93f", "outputId": "999a54ca-36b5-4e26-e25a-94953b4d1590" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Running on cpu\n" ] } ], "source": [ "import torch\n", "\n", "torch.manual_seed(123)\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "print(f\"Running on {device}\")\n", "\n", "batch_size = 8\n", "context_len = 1024\n", "embed_dim = 768\n", "embeddings = torch.randn((batch_size, context_len, embed_dim), device=device)" ] }, { "cell_type": "markdown", "id": "2f9bb1b6-a1e5-4e0a-884d-0f31b374a8d6", "metadata": { "id": "2f9bb1b6-a1e5-4e0a-884d-0f31b374a8d6" }, "source": [ "## 1) CausalAttention MHA wrapper class from chapter 3" ] }, { "cell_type": "code", "execution_count": 2, "id": "297c93ed-aec0-4896-bb89-42c4b294d3d1", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "297c93ed-aec0-4896-bb89-42c4b294d3d1", "outputId": "4d694519-8283-49a6-b27f-c2ae06a2fa4e" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([8, 1024, 9216])\n" ] } ], "source": [ "from ch03 import MultiHeadAttentionWrapper as Ch03_MHA_Wrapper\n", "\n", "mha_ch03_wrapper = Ch03_MHA_Wrapper(\n", " d_in=embed_dim,\n", " d_out=embed_dim,\n", " block_size=context_len,\n", " dropout=0.0,\n", " num_heads=12,\n", " qkv_bias=False\n", ").to(device)\n", "\n", "out = mha_ch03_wrapper(embeddings)\n", "print(out.shape)" ] }, { "cell_type": "markdown", "id": "21930804-b327-40b1-8e63-94dcad39ce7b", "metadata": { "id": "21930804-b327-40b1-8e63-94dcad39ce7b" }, "source": [ "## 2) The multi-head 
attention class from chapter 3" ] }, { "cell_type": "code", "execution_count": 3, "id": "4ee6a61b-d25c-4a0c-8a59-f285544e3710", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "4ee6a61b-d25c-4a0c-8a59-f285544e3710", "outputId": "f7cd4e03-5622-4dd6-a9e0-31b80e15838b" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([8, 1024, 768])\n" ] } ], "source": [ "from ch03 import MultiHeadAttention as Ch03_MHA\n", "\n", "mha_ch03 = Ch03_MHA(\n", "    d_in=embed_dim,\n", "    d_out=embed_dim,\n", "    block_size=context_len,\n", "    dropout=0.0,\n", "    num_heads=12,\n", "    qkv_bias=False\n", ").to(device)\n", "\n", "out = mha_ch03(embeddings)\n", "print(out.shape)" ] }, { "cell_type": "markdown", "id": "73cd11da-ea3b-4081-b483-c4965dfefbc4", "metadata": { "id": "73cd11da-ea3b-4081-b483-c4965dfefbc4" }, "source": [ "## 3) An alternative multi-head attention with combined weights" ] }, { "cell_type": "markdown", "id": "1fa1a5ea-eaff-4d2d-aaf0-b34cdb6fd4dd", "metadata": { "id": "1fa1a5ea-eaff-4d2d-aaf0-b34cdb6fd4dd" }, "source": [ "- The code for the `MultiHeadAttentionCombinedQKV` class below is based on code that was kindly shared by [Rayed Bin Wahed](https://github.com/rasbt/LLMs-from-scratch/discussions/51)\n", "- The main difference between the `MultiHeadAttentionCombinedQKV` class and the `MultiHeadAttention` class used in chapter 3 is that `MultiHeadAttentionCombinedQKV` uses a single weight matrix, `self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)`, instead of three separate weight matrices:\n", "\n", "  - `self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)`\n", "  - `self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)`\n", "  - `self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)`\n", "\n", "- Here, `self.qkv` combines all three weight matrices `self.W_query`, `self.W_key`, and `self.W_value` to carry out the query, key, and value computation in a single step\n", "- Using `queries, keys, values = qkv.unbind(0)`, we obtain the individual query, key, and value tensors, 
which are then used similarly to the query, key, and value tensors in the `MultiHeadAttention` class in chapter 3" ] }, { "cell_type": "code", "execution_count": 4, "id": "9a6bd0a2-f27c-4602-afa0-c96cd295c1a6", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "9a6bd0a2-f27c-4602-afa0-c96cd295c1a6", "outputId": "86c643c3-c1dd-4021-d7e6-b79e3cea3d99" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([8, 1024, 768])\n" ] } ], "source": [ "import torch.nn as nn\n", "\n", "\n", "class MultiHeadAttentionCombinedQKV(nn.Module):\n", "    def __init__(self, d_in, d_out, num_heads, block_size, dropout=0.0, qkv_bias=False):\n", "        super().__init__()\n", "\n", "        assert d_out % num_heads == 0, \"d_out must be divisible by num_heads\"\n", "\n", "        self.num_heads = num_heads\n", "        self.block_size = block_size\n", "        self.head_dim = d_out // num_heads\n", "\n", "        self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)\n", "        # The output projection maps the concatenated heads (d_out) back to d_out\n", "        self.proj = nn.Linear(d_out, d_out)\n", "        self.dropout = nn.Dropout(dropout)\n", "\n", "        self.register_buffer(\n", "            \"mask\", torch.triu(torch.ones(block_size, block_size), diagonal=1)\n", "        )\n", "\n", "    def forward(self, x):\n", "        batch_size, num_tokens, embed_dim = x.shape\n", "\n", "        # (b, num_tokens, embed_dim) --> (b, num_tokens, 3 * embed_dim)\n", "        qkv = self.qkv(x)\n", "\n", "        # (b, num_tokens, 3 * embed_dim) --> (b, num_tokens, 3, num_heads, head_dim)\n", "        qkv = qkv.reshape(batch_size, num_tokens, 3, self.num_heads, self.head_dim)\n", "\n", "        # (b, num_tokens, 3, num_heads, head_dim) --> (3, b, num_heads, num_tokens, head_dim)\n", "        qkv = qkv.permute(2, 0, 3, 1, 4)\n", "\n", "        # (3, b, num_heads, num_tokens, head_dim) --> 3 times (b, num_heads, num_tokens, head_dim)\n", "        queries, keys, values = qkv.unbind(0)\n", "\n", "        # (b, num_heads, num_tokens, head_dim) --> (b, num_heads, num_tokens, num_tokens)\n", "        attn_scores = queries @ keys.transpose(-2, -1)\n", "        attn_scores = attn_scores.masked_fill(\n", " 
self.mask.bool()[:num_tokens, :num_tokens], -torch.inf\n", "        )\n", "\n", "        # Scale the scores by sqrt(head_dim) before applying softmax\n", "        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n", "        attn_weights = self.dropout(attn_weights)\n", "\n", "        # (b, num_heads, num_tokens, num_tokens) --> (b, num_heads, num_tokens, head_dim)\n", "        context_vec = attn_weights @ values\n", "\n", "        # (b, num_heads, num_tokens, head_dim) --> (b, num_tokens, num_heads, head_dim)\n", "        context_vec = context_vec.transpose(1, 2)\n", "\n", "        # (b, num_tokens, num_heads, head_dim) --> (b, num_tokens, embed_dim)\n", "        context_vec = context_vec.reshape(batch_size, num_tokens, embed_dim)\n", "\n", "        context_vec = self.proj(context_vec)\n", "\n", "        return context_vec\n", "\n", "\n", "mha_combined_qkv = MultiHeadAttentionCombinedQKV(\n", "    d_in=embed_dim,\n", "    d_out=embed_dim,\n", "    block_size=context_len,\n", "    dropout=0.0,\n", "    num_heads=12,\n", "    qkv_bias=False\n", ").to(device)\n", "\n", "out = mha_combined_qkv(embeddings)\n", "print(out.shape)" ] }, { "cell_type": "markdown", "id": "48a042d3-ee78-4c29-bf63-d92fe6706632", "metadata": { "id": "48a042d3-ee78-4c29-bf63-d92fe6706632" }, "source": [ "## 4) Multihead attention with PyTorch's scaled dot product attention" ] }, { "cell_type": "markdown", "id": "f78e346f-3b85-44e6-9feb-f01131381148", "metadata": { "id": "f78e346f-3b85-44e6-9feb-f01131381148" }, "source": [ "- The implementation below uses PyTorch's [`scaled_dot_product_attention`](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) function, which can dispatch to a memory-optimized version of self-attention called [flash attention](https://arxiv.org/abs/2205.14135) on supported hardware" ] }, { "cell_type": "code", "execution_count": 5, "id": "1b8e5a0d-1f65-4a03-bf6e-723f0cc428f5", "metadata": { "id": "1b8e5a0d-1f65-4a03-bf6e-723f0cc428f5" }, "outputs": [], "source": [ "class MHAPyTorchScaledDotProduct(nn.Module):\n", "    def __init__(self, d_in, d_out, num_heads, block_size, dropout=0.0, 
qkv_bias=False):\n", "        super().__init__()\n", "\n", "        assert d_out % num_heads == 0, \"d_out must be divisible by num_heads\"\n", "\n", "        self.num_heads = num_heads\n", "        self.block_size = block_size\n", "        self.head_dim = d_out // num_heads\n", "        self.d_out = d_out\n", "\n", "        self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)\n", "        # The output projection maps the concatenated heads (d_out) back to d_out\n", "        self.proj = nn.Linear(d_out, d_out)\n", "        self.dropout = dropout\n", "\n", "        self.register_buffer(\n", "            \"mask\", torch.triu(torch.ones(block_size, block_size), diagonal=1)\n", "        )\n", "\n", "    def forward(self, x):\n", "        batch_size, num_tokens, embed_dim = x.shape\n", "\n", "        # (b, num_tokens, embed_dim) --> (b, num_tokens, 3 * embed_dim)\n", "        qkv = self.qkv(x)\n", "\n", "        # (b, num_tokens, 3 * embed_dim) --> (b, num_tokens, 3, num_heads, head_dim)\n", "        qkv = qkv.reshape(batch_size, num_tokens, 3, self.num_heads, self.head_dim)\n", "\n", "        # (b, num_tokens, 3, num_heads, head_dim) --> (3, b, num_heads, num_tokens, head_dim)\n", "        qkv = qkv.permute(2, 0, 3, 1, 4)\n", "\n", "        # (3, b, num_heads, num_tokens, head_dim) --> 3 times (b, num_heads, num_tokens, head_dim)\n", "        queries, keys, values = qkv.unbind(0)\n", "\n", "        use_dropout = 0. 
if not self.training else self.dropout\n", "        context_vec = nn.functional.scaled_dot_product_attention(\n", "            queries, keys, values, attn_mask=None, dropout_p=use_dropout, is_causal=True)\n", "\n", "        # Combine heads, where self.d_out = self.num_heads * self.head_dim\n", "        context_vec = context_vec.transpose(1, 2).contiguous().view(batch_size, num_tokens, self.d_out)\n", "\n", "        # Apply the output projection, as the other implementations do\n", "        context_vec = self.proj(context_vec)\n", "\n", "        return context_vec" ] }, { "cell_type": "code", "execution_count": 6, "id": "fbc8ba92-3471-41cb-b1b2-4c0ef5be392b", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "fbc8ba92-3471-41cb-b1b2-4c0ef5be392b", "outputId": "ff45b9c9-19a7-4769-efaf-b9f417a31631" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([8, 1024, 768])\n" ] } ], "source": [ "mha_pytorch_scaled = MHAPyTorchScaledDotProduct(\n", "    d_in=embed_dim,\n", "    d_out=embed_dim,\n", "    block_size=context_len,\n", "    dropout=0.0,\n", "    num_heads=12,\n", "    qkv_bias=False\n", ").to(device)\n", "\n", "out = mha_pytorch_scaled(embeddings)\n", "print(out.shape)" ] }, { "cell_type": "markdown", "id": "351c318f-4835-4d74-8d58-a070222447c4", "metadata": { "id": "351c318f-4835-4d74-8d58-a070222447c4" }, "source": [ "## 5) Using PyTorch's torch.nn.MultiheadAttention" ] }, { "cell_type": "markdown", "id": "74a6d060-6324-48fa-a35c-cb09f2a48965", "metadata": { "id": "74a6d060-6324-48fa-a35c-cb09f2a48965" }, "source": [ "- Below, we use PyTorch's [torch.nn.MultiheadAttention](https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html) implementation" ] }, { "cell_type": "code", "execution_count": 7, "id": "3799c7ef-3155-42c6-a829-f95656453ae0", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "3799c7ef-3155-42c6-a829-f95656453ae0", "outputId": "1a6ce118-be18-4d7b-99bc-a0c9bb78f8b1" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([8, 1024, 768])\n" ] } ], "source": [ "import torch.nn as nn\n", "\n", "\n", "class 
MHAPyTorchClass(nn.Module):\n", "    def __init__(self, d_in, d_out, num_heads, block_size, dropout=0.0, qkv_bias=False, need_weights=True):\n", "        super().__init__()\n", "\n", "        self.block_size = block_size\n", "        self.multihead_attn = nn.MultiheadAttention(\n", "            embed_dim=d_out,\n", "            num_heads=num_heads,\n", "            dropout=dropout,\n", "            bias=qkv_bias,\n", "            add_bias_kv=qkv_bias,\n", "            batch_first=True,\n", "        )\n", "\n", "        self.need_weights = need_weights\n", "        self.proj = nn.Linear(d_out, d_out)\n", "        self.register_buffer(\"mask\", torch.triu(torch.ones(block_size, block_size), diagonal=1).bool())\n", "\n", "    def forward(self, x):\n", "        batch_size, num_tokens, _ = x.shape\n", "\n", "        # Slice the causal mask to the current sequence length\n", "        # The 2D mask has shape (num_tokens, num_tokens) and is shared across batch and heads\n", "        if self.block_size >= num_tokens:\n", "            attn_mask = self.mask[:num_tokens, :num_tokens]\n", "        else:\n", "            attn_mask = self.mask[:self.block_size, :self.block_size]\n", "\n", "        # attn_mask broadcasting will handle batch_size dimension implicitly\n", "        attn_output, _ = self.multihead_attn(\n", "            x, x, x, attn_mask=attn_mask, need_weights=self.need_weights\n", "        )\n", "\n", "        output = self.proj(attn_output)\n", "\n", "        return output\n", "\n", "\n", "mha_pytorch_class_default = MHAPyTorchClass(\n", "    d_in=embed_dim,\n", "    d_out=embed_dim,\n", "    block_size=context_len,\n", "    dropout=0.0,\n", "    num_heads=12,\n", "    qkv_bias=False\n", ").to(device)\n", "\n", "out = mha_pytorch_class_default(embeddings)\n", "print(out.shape)" ] }, { "cell_type": "markdown", "id": "a3953bff-1056-4de2-bfd1-dfccf659eee4", "metadata": {}, "source": [ "## 6) Using PyTorch's torch.nn.MultiheadAttention with `scaled_dot_product_attention`" ] }, { "cell_type": "markdown", "id": "d2164859-31a0-4537-b4fb-27d57675ba77", "metadata": { "id": "d2164859-31a0-4537-b4fb-27d57675ba77" }, "source": [ "- Set `need_weights` (default `True`) to `False` so that 
`MultiheadAttention` uses `scaled_dot_product_attention` [according to the documentation](https://github.com/pytorch/pytorch/blob/71d020262793542974cf13b30f2a9099773f015c/torch/nn/modules/activation.py#L1096)\n", "\n", "> need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.\n", " Set ``need_weights=False`` to use the optimized ``scaled_dot_product_attention``\n", " and achieve the best performance for MHA.\n", " Default: ``True``." ] }, { "cell_type": "code", "execution_count": 8, "id": "4a4c2afe-5e1f-4bd7-a118-67031176f147", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "4a4c2afe-5e1f-4bd7-a118-67031176f147", "outputId": "ef1a0698-5b18-426d-df0f-00bdbaeeaccc" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([8, 1024, 768])\n" ] } ], "source": [ "mha_pytorch_class_noweights = MHAPyTorchClass(\n", "    d_in=embed_dim,\n", "    d_out=embed_dim,\n", "    block_size=context_len,\n", "    dropout=0.0,\n", "    num_heads=12,\n", "    qkv_bias=False,\n", "    need_weights=False  # NEW!\n", ").to(device)\n", "\n", "out = mha_pytorch_class_noweights(embeddings)\n", "print(out.shape)" ] }, { "cell_type": "markdown", "id": "8877de71-f84f-4f6d-bc87-7552013b6301", "metadata": { "id": "8877de71-f84f-4f6d-bc87-7552013b6301" }, "source": [ "## Speed comparison (M1 MacBook Air CPU)" ] }, { "cell_type": "code", "execution_count": 9, "id": "a97c0b2e-6593-49d8-98bc-2267b3aa610f", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "a97c0b2e-6593-49d8-98bc-2267b3aa610f", "outputId": "ebe635b2-5c03-4e9b-da3a-951d308acf7b" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "914 ms ± 50.4 ms per loop (mean ± std. dev. 
of 7 runs, 1 loop each)\n" ] } ], "source": [ "## 1) CausalAttention MHA wrapper class from chapter 3\n", "%timeit mha_ch03_wrapper(embeddings)" ] }, { "cell_type": "code", "execution_count": 10, "id": "19db9c2c-8e75-431a-8eef-0b4d8284e6e6", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "19db9c2c-8e75-431a-8eef-0b4d8284e6e6", "outputId": "c6e7bcff-661c-45a6-da82-b1e3f89cf761" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "252 ms ± 9.04 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" ] } ], "source": [ "## 2) The multi-head attention class from chapter 3\n", "%timeit mha_ch03(embeddings)" ] }, { "cell_type": "code", "execution_count": 11, "id": "aa526ee0-7a88-4f34-a49a-f8f97da83779", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "aa526ee0-7a88-4f34-a49a-f8f97da83779", "outputId": "92b634f8-43f8-468f-87a1-bb774b64c212" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "300 ms ± 8.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" ] } ], "source": [ "## 3) An alternative multi-head attention with combined weights\n", "%timeit mha_combined_qkv(embeddings)" ] }, { "cell_type": "code", "execution_count": 12, "id": "cc2b4256-16d8-4c34-9fd0-d4b4af0e60fa", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "cc2b4256-16d8-4c34-9fd0-d4b4af0e60fa", "outputId": "80c6e314-0771-470e-b090-628984ce2d85" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "94.2 ms ± 1.6 ms per loop (mean ± std. dev. 
of 7 runs, 10 loops each)\n" ] } ], "source": [ "## 4) Multihead attention with PyTorch's scaled dot product attention\n", "%timeit mha_pytorch_scaled(embeddings)" ] }, { "cell_type": "code", "execution_count": 13, "id": "0f209e70-ebb6-4a1a-b608-1ff42e41c01d", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "0f209e70-ebb6-4a1a-b608-1ff42e41c01d", "outputId": "3cd37b53-04d4-4dd0-9450-6fc8ebaac083" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "297 ms ± 2.37 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" ] } ], "source": [ "## 5) Using PyTorch's torch.nn.MultiheadAttention\n", "%timeit mha_pytorch_class_default(embeddings)" ] }, { "cell_type": "code", "execution_count": 14, "id": "3f4968c2-8d40-4ab9-8dba-052b4f77d756", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "274 ms ± 2.19 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" ] } ], "source": [ "## 6) Using PyTorch's torch.nn.MultiheadAttention disabling `need_weights`\n", "%timeit mha_pytorch_class_noweights(embeddings)" ] }, { "cell_type": "markdown", "id": "a78ff594-6cc2-496d-a302-789fa104c3c9", "metadata": { "id": "8877de71-f84f-4f6d-bc87-7552013b6301" }, "source": [ "## Speed comparison (Nvidia A100 GPU)" ] }, { "cell_type": "code", "execution_count": 9, "id": "707a2a14-a089-48a8-88aa-d328e1e0a9d0", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "a97c0b2e-6593-49d8-98bc-2267b3aa610f", "outputId": "a0ea49e8-3907-4021-a8ee-a7541d19f650" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "41 ms ± 6.1 µs per loop (mean ± std. dev. 
of 7 runs, 100 loops each)\n" ] } ], "source": [ "## 1) CausalAttention MHA wrapper class from chapter 3\n", "%timeit mha_ch03_wrapper(embeddings)" ] }, { "cell_type": "code", "execution_count": 10, "id": "8686dd69-3655-40e4-a57b-a2c55532a010", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "19db9c2c-8e75-431a-8eef-0b4d8284e6e6", "outputId": "d11c08c3-d741-4427-9b45-9b1e8882cf80" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "6.57 ms ± 996 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" ] } ], "source": [ "## 2) The multi-head attention class from chapter 3\n", "%timeit mha_ch03(embeddings)" ] }, { "cell_type": "code", "execution_count": 11, "id": "2209d7df-e54b-4910-ae2b-c78cf684d9bf", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "aa526ee0-7a88-4f34-a49a-f8f97da83779", "outputId": "8e69637e-4d6a-4b25-acb0-2fa829167932" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "7.19 ms ± 433 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" ] } ], "source": [ "## 3) An alternative multi-head attention with combined weights\n", "%timeit mha_combined_qkv(embeddings)" ] }, { "cell_type": "code", "execution_count": 12, "id": "1075abe2-4839-4fd6-af3e-c09bb3651e26", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "cc2b4256-16d8-4c34-9fd0-d4b4af0e60fa", "outputId": "17aee8fc-8e6c-4a7f-d8fc-0880fd816467" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2.37 ms ± 228 ns per loop (mean ± std. dev. 
of 7 runs, 1000 loops each)\n" ] } ], "source": [ "## 4) Multihead attention with PyTorch's scaled dot product attention\n", "%timeit mha_pytorch_scaled(embeddings)" ] }, { "cell_type": "code", "execution_count": 13, "id": "868e3670-8edc-47bc-9e06-eb505e44dc9d", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "0f209e70-ebb6-4a1a-b608-1ff42e41c01d", "outputId": "47737f93-1762-4561-9bb5-a8ac356ff11b" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "6.66 ms ± 241 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" ] } ], "source": [ "## 5) Using PyTorch's torch.nn.MultiheadAttention\n", "%timeit mha_pytorch_class_default(embeddings)" ] }, { "cell_type": "code", "execution_count": 14, "id": "944870e6-de54-4e3b-a455-b8f21f6f92c8", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "3f4968c2-8d40-4ab9-8dba-052b4f77d756", "outputId": "9fba58fe-b3d1-43ae-ec15-45b57792785f" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "4.52 ms ± 243 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)\n" ] } ], "source": [ "## 6) Using PyTorch's torch.nn.MultiheadAttention disabling `need_weights`\n", "%timeit mha_pytorch_class_noweights(embeddings)" ] } ], "metadata": { "accelerator": "GPU", "colab": { "gpuType": "A100", "machine_shape": "hm", "provenance": [] }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 }