2024-03-13 08:34:39 -05:00
{
"cells": [
{
"cell_type": "markdown",
"id": "6f678e62-7bcb-4405-86ae-dce94f494303",
"metadata": {
"id": "6f678e62-7bcb-4405-86ae-dce94f494303"
},
"source": [
"# Efficient Multi-Head Attention Implementations"
]
},
{
"cell_type": "markdown",
"id": "b742938a-4bfc-4527-a1f1-d5963508967d",
"metadata": {
"id": "b742938a-4bfc-4527-a1f1-d5963508967d"
},
"source": [
"This code notebook compares different ways to implement causal multi-head attention used in decoder-style LLMs like GPT, Llama, etc."
]
},
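{
"cell_type": "markdown",
"id": "sketch-causal-mask-md",
"metadata": {},
"source": [
"- In *causal* attention, each token may attend only to itself and to earlier tokens. Most implementations below enforce this with an upper-triangular mask created via `torch.triu(..., diagonal=1)`. Section 4 uses `is_causal=True` instead.\n",
"- The next cell is a minimal, illustrative sketch of such a mask on a toy 4-token sequence. The sizes and the uniform all-zero scores are assumptions chosen for readability. Masked positions are filled with `-inf` before the softmax, so they receive exactly zero attention weight."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-causal-mask-code",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# Illustrative sketch: causal mask for a toy sequence of 4 tokens\n",
"toy_num_tokens = 4\n",
"# True above the diagonal marks future positions that must be hidden\n",
"toy_mask = torch.triu(torch.ones(toy_num_tokens, toy_num_tokens), diagonal=1).bool()\n",
"print(toy_mask)\n",
"\n",
"# With uniform (all-zero) scores, each row spreads its weight evenly over the visible tokens\n",
"toy_scores = torch.zeros(toy_num_tokens, toy_num_tokens)\n",
"toy_weights = torch.softmax(toy_scores.masked_fill(toy_mask, -torch.inf), dim=-1)\n",
"print(toy_weights)"
]
},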
{
"cell_type": "code",
"execution_count": 1,
"id": "7898551e-f582-48ac-9f66-3632abe2a93f",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "7898551e-f582-48ac-9f66-3632abe2a93f",
"outputId": "7d088260-3fa1-44f2-bd65-2a46e289f9d4"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"PyTorch version: 2.1.0\n",
"Running on cpu\n"
]
}
],
"source": [
"import torch\n",
"\n",
"torch.manual_seed(123)\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"print(f\"PyTorch version: {torch.__version__}\")\n",
"print(f\"Running on {device}\")\n",
"\n",
"batch_size = 8\n",
"context_len = 1024\n",
"embed_dim = 768\n",
"embeddings = torch.randn((batch_size, context_len, embed_dim), device=device)"
]
},
{
"cell_type": "markdown",
"id": "2f9bb1b6-a1e5-4e0a-884d-0f31b374a8d6",
"metadata": {
"id": "2f9bb1b6-a1e5-4e0a-884d-0f31b374a8d6"
},
"source": [
"## 1) CausalAttention MHA wrapper class from chapter 3"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "297c93ed-aec0-4896-bb89-42c4b294d3d1",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "297c93ed-aec0-4896-bb89-42c4b294d3d1",
"outputId": "f8a33752-2cd6-4101-8feb-9d1699984719"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([8, 1024, 768])\n"
]
}
],
"source": [
"from ch03 import MultiHeadAttentionWrapper as Ch03_MHA_Wrapper\n",
"\n",
"mha_ch03_wrapper = Ch03_MHA_Wrapper(\n",
" d_in=embed_dim,\n",
" d_out=embed_dim//12,\n",
" block_size=context_len,\n",
" dropout=0.0,\n",
" num_heads=12,\n",
" qkv_bias=False\n",
").to(device)\n",
"\n",
"out = mha_ch03_wrapper(embeddings)\n",
"print(out.shape)"
]
},
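{
"cell_type": "markdown",
"id": "sketch-wrapper-md",
"metadata": {},
"source": [
"- Judging from the heading and the output shape, the chapter 3 wrapper runs `num_heads` independent single-head causal attention modules. It then concatenates their outputs along the embedding dimension, which is why `d_out=embed_dim//12` above still yields 768 output features.\n",
"- The hypothetical `ToyCausalHead` class in the next cell is a simplified stand-in written only for illustration. It is not the chapter 3 code, and it omits dropout and a stored mask buffer."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-wrapper-code",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"\n",
"\n",
"class ToyCausalHead(nn.Module):\n",
"    # Hypothetical single-head causal attention, for illustration only\n",
"    def __init__(self, d_in, d_out):\n",
"        super().__init__()\n",
"        self.W_query = nn.Linear(d_in, d_out, bias=False)\n",
"        self.W_key = nn.Linear(d_in, d_out, bias=False)\n",
"        self.W_value = nn.Linear(d_in, d_out, bias=False)\n",
"\n",
"    def forward(self, x):\n",
"        num_tokens = x.shape[1]\n",
"        queries, keys, values = self.W_query(x), self.W_key(x), self.W_value(x)\n",
"        scores = queries @ keys.transpose(1, 2) / keys.shape[-1] ** 0.5\n",
"        future = torch.triu(torch.ones(num_tokens, num_tokens, dtype=torch.bool, device=x.device), diagonal=1)\n",
"        weights = torch.softmax(scores.masked_fill(future, -torch.inf), dim=-1)\n",
"        return weights @ values\n",
"\n",
"\n",
"# 12 heads with 64 output features each, concatenated along the last dimension\n",
"toy_heads = nn.ModuleList([ToyCausalHead(768, 64) for _ in range(12)])\n",
"toy_x = torch.randn(2, 10, 768)\n",
"toy_out = torch.cat([head(toy_x) for head in toy_heads], dim=-1)\n",
"print(toy_out.shape)  # torch.Size([2, 10, 768])"
]
},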
{
"cell_type": "markdown",
"id": "21930804-b327-40b1-8e63-94dcad39ce7b",
"metadata": {
"id": "21930804-b327-40b1-8e63-94dcad39ce7b"
},
"source": [
"## 2) The multi-head attention class from chapter 3"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "4ee6a61b-d25c-4a0c-8a59-f285544e3710",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "4ee6a61b-d25c-4a0c-8a59-f285544e3710",
"outputId": "b704a040-3547-422c-ecda-df9982a2da35"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([8, 1024, 768])\n"
]
}
],
"source": [
"from ch03 import MultiHeadAttention as Ch03_MHA\n",
"\n",
"mha_ch03 = Ch03_MHA(\n",
" d_in=embed_dim,\n",
" d_out=embed_dim,\n",
" block_size=context_len,\n",
" dropout=0.0,\n",
" num_heads=12,\n",
" qkv_bias=False\n",
").to(device)\n",
"\n",
"out = mha_ch03(embeddings)\n",
"print(out.shape)"
]
},
{
"cell_type": "markdown",
"id": "73cd11da-ea3b-4081-b483-c4965dfefbc4",
"metadata": {
"id": "73cd11da-ea3b-4081-b483-c4965dfefbc4"
},
"source": [
"## 3) An alternative multi-head attention with combined weights"
]
},
{
"cell_type": "markdown",
"id": "1fa1a5ea-eaff-4d2d-aaf0-b34cdb6fd4dd",
"metadata": {
"id": "1fa1a5ea-eaff-4d2d-aaf0-b34cdb6fd4dd"
},
"source": [
"- The code for the `MultiHeadAttentionAlt` class below is based on code that was kindly shared by [Rayed Bin Wahed](https://github.com/rasbt/LLMs-from-scratch/discussions/51)\n",
"- The main difference between the `MultiHeadAttentionAlt` class and the `MultiHeadAttention` class used in chapter 3 is that `MultiHeadAttentionAlt` uses a single weight matrix, `self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)` instead of separate weight matrices:\n",
"\n",
" - `self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)`\n",
" - `self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)`\n",
" - `self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)`\n",
"\n",
"- Here, `self.qkv` combines all three weight matrices `self.W_query`, `self.W_key`, and `self.W_value` to carry out the query, key, and value computation in a single step\n",
"- Using `q, k, v = qkv.unbind(0)`, we obtain the individual query, key, and value tensors, which are then used similarly to the query, key, and value tensors in the `MultiHeadAttention` class in chapter 3"
]
},
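{
"cell_type": "markdown",
"id": "sketch-combined-qkv-md",
"metadata": {},
"source": [
"- The next cell is a quick sanity check of the combined-weights idea. Toy sizes, random weights, and plain matrix multiplications instead of `nn.Linear` are assumptions made for illustration.\n",
"- `nn.Linear` stores its weight as `(out_features, in_features)` and computes `x @ W.T` (plus an optional bias); the sketch mirrors that convention.\n",
"- It stacks three separate weight matrices row-wise into one `(3 * d_out, d_in)` matrix. It then checks that a single combined projection, split into three chunks, reproduces the separate query, key, and value projections."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-combined-qkv-code",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# Toy sizes chosen for illustration; a fixed generator keeps the check reproducible\n",
"gen = torch.Generator().manual_seed(0)\n",
"toy_d_in, toy_d_out = 6, 4\n",
"toy_x = torch.randn(2, 5, toy_d_in, generator=gen)  # (batch, num_tokens, d_in)\n",
"\n",
"# Weight matrices stored as (d_out, d_in), like nn.Linear.weight\n",
"W_query = torch.randn(toy_d_out, toy_d_in, generator=gen)\n",
"W_key = torch.randn(toy_d_out, toy_d_in, generator=gen)\n",
"W_value = torch.randn(toy_d_out, toy_d_in, generator=gen)\n",
"\n",
"# Three separate projections, as in the chapter 3 MultiHeadAttention class\n",
"q_sep, k_sep, v_sep = toy_x @ W_query.T, toy_x @ W_key.T, toy_x @ W_value.T\n",
"\n",
"# One combined (3 * d_out, d_in) matrix, then split the last dimension into 3 chunks\n",
"W_qkv = torch.cat([W_query, W_key, W_value], dim=0)\n",
"q_comb, k_comb, v_comb = (toy_x @ W_qkv.T).chunk(3, dim=-1)\n",
"\n",
"print(torch.allclose(q_sep, q_comb, atol=1e-6),\n",
"      torch.allclose(k_sep, k_comb, atol=1e-6),\n",
"      torch.allclose(v_sep, v_comb, atol=1e-6))"
]
},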
{
"cell_type": "code",
"execution_count": 4,
"id": "9a6bd0a2-f27c-4602-afa0-c96cd295c1a6",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "9a6bd0a2-f27c-4602-afa0-c96cd295c1a6",
"outputId": "5d948671-176f-4633-bede-97767e36becc"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([8, 1024, 768])\n"
]
}
],
"source": [
"import torch.nn as nn\n",
"\n",
"\n",
"class MultiHeadAttentionCombinedQKV(nn.Module):\n",
" def __init__(self, d_in, d_out, num_heads, block_size, dropout=0.0, qkv_bias=False):\n",
" super().__init__()\n",
"\n",
" assert d_out % num_heads == 0, \"embed_dim is indivisible by num_heads\"\n",
"\n",
" self.num_heads = num_heads\n",
" self.block_size = block_size\n",
" self.head_dim = d_out // num_heads\n",
"\n",
" self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)\n",
" self.proj = nn.Linear(d_in, d_out)\n",
" self.dropout = nn.Dropout(dropout)\n",
"\n",
" self.register_buffer(\n",
" \"mask\", torch.triu(torch.ones(block_size, block_size), diagonal=1)\n",
" )\n",
"\n",
" def forward(self, x):\n",
" batch_size, num_tokens, embed_dim = x.shape\n",
"\n",
" # (b, num_tokens, embed_dim) --> (b, num_tokens, 3 * embed_dim)\n",
" qkv = self.qkv(x)\n",
"\n",
" # (b, num_tokens, 3 * embed_dim) --> (b, num_tokens, 3, num_heads, head_dim)\n",
" qkv = qkv.reshape(batch_size, num_tokens, 3, self.num_heads, self.head_dim)\n",
"\n",
" # (b, num_tokens, 3, num_heads, head_dim) --> (3, b, num_heads, num_tokens, head_dim)\n",
" qkv = qkv.permute(2, 0, 3, 1, 4)\n",
"\n",
" # (3, b, num_heads, num_tokens, head_dim) -> 3 times (b, num_head, num_tokens, head_dim)\n",
" queries, keys, values = qkv.unbind(0)\n",
"\n",
" # (b, num_heads, num_tokens, head_dim) --> (b, num_heads, num_tokens, num_tokens)\n",
" attn_scores = queries @ keys.transpose(-2, -1)\n",
" attn_scores = attn_scores.masked_fill(\n",
" self.mask.bool()[:num_tokens, :num_tokens], -torch.inf\n",
" )\n",
"\n",
" attn_weights = torch.softmax(attn_scores / keys.shape[-1]**-0.5, dim=-1)\n",
" attn_weights = self.dropout(attn_weights)\n",
"\n",
" # (b, num_heads, num_tokens, num_tokens) --> (b, num_heads, num_tokens, head_dim)\n",
" context_vec = attn_weights @ values\n",
"\n",
" # (b, num_heads, num_tokens, head_dim) --> (b, num_tokens, num_heads, head_dim)\n",
" context_vec = context_vec.transpose(1, 2)\n",
"\n",
" # (b, num_tokens, num_heads, head_dim) --> (b, num_tokens, embed_dim)\n",
" context_vec = context_vec.reshape(batch_size, num_tokens, embed_dim)\n",
"\n",
" context_vec = self.proj(context_vec)\n",
"\n",
" return context_vec\n",
"\n",
"\n",
"mha_combined_qkv = MultiHeadAttentionCombinedQKV(\n",
" d_in=embed_dim,\n",
" d_out=embed_dim,\n",
" block_size=context_len,\n",
" dropout=0.0,\n",
" num_heads=12,\n",
" qkv_bias=False\n",
").to(device)\n",
"\n",
"out = mha_combined_qkv(embeddings)\n",
"print(out.shape)"
]
},
{
"cell_type": "markdown",
"id": "48a042d3-ee78-4c29-bf63-d92fe6706632",
"metadata": {
"id": "48a042d3-ee78-4c29-bf63-d92fe6706632"
},
"source": [
"## 4) Multihead attention with PyTorch's scaled dot product attention"
]
},
{
"cell_type": "markdown",
"id": "f78e346f-3b85-44e6-9feb-f01131381148",
"metadata": {
"id": "f78e346f-3b85-44e6-9feb-f01131381148"
},
"source": [
"- The implementation below uses PyTorch's [`scaled_dot_product_attention`](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) function, which implements a memory-optimized version of self-attention calld [flash attention](https://arxiv.org/abs/2205.14135)"
]
},
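{
"cell_type": "markdown",
"id": "sketch-sdpa-md",
"metadata": {},
"source": [
"- The next cell is a small sketch with toy tensor sizes (an assumption for illustration). It checks that `scaled_dot_product_attention` with `is_causal=True` matches an explicit computation with four steps:\n",
"    - scale the scores by `1 / sqrt(head_dim)`, which is the function's default scale\n",
"    - mask future positions with `-inf`\n",
"    - apply the softmax\n",
"    - take the weighted sum of the values"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-sdpa-code",
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"gen = torch.Generator().manual_seed(0)\n",
"toy_b, toy_h, toy_t, toy_d = 2, 3, 5, 8  # batch, heads, tokens, head_dim (toy sizes)\n",
"toy_q = torch.randn(toy_b, toy_h, toy_t, toy_d, generator=gen)\n",
"toy_k = torch.randn(toy_b, toy_h, toy_t, toy_d, generator=gen)\n",
"toy_v = torch.randn(toy_b, toy_h, toy_t, toy_d, generator=gen)\n",
"\n",
"# Explicit version: scale by 1/sqrt(head_dim), hide future positions, softmax, weight the values\n",
"scores = toy_q @ toy_k.transpose(-2, -1) / math.sqrt(toy_d)\n",
"future = torch.triu(torch.ones(toy_t, toy_t, dtype=torch.bool), diagonal=1)\n",
"manual_out = torch.softmax(scores.masked_fill(future, -torch.inf), dim=-1) @ toy_v\n",
"\n",
"# PyTorch's function applies the causal mask itself when is_causal=True\n",
"sdpa_out = F.scaled_dot_product_attention(toy_q, toy_k, toy_v, dropout_p=0.0, is_causal=True)\n",
"print(torch.allclose(manual_out, sdpa_out, atol=1e-5))"
]
},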
{
"cell_type": "code",
"execution_count": 5,
"id": "1b8e5a0d-1f65-4a03-bf6e-723f0cc428f5",
"metadata": {
"id": "1b8e5a0d-1f65-4a03-bf6e-723f0cc428f5"
},
"outputs": [],
"source": [
"class MHAPyTorchScaledDotProduct(nn.Module):\n",
" def __init__(self, d_in, d_out, num_heads, block_size, dropout=0.0, qkv_bias=False):\n",
" super().__init__()\n",
"\n",
" assert d_out % num_heads == 0, \"embed_dim is indivisible by num_heads\"\n",
"\n",
" self.num_heads = num_heads\n",
" self.block_size = block_size\n",
" self.head_dim = d_out // num_heads\n",
" self.d_out = d_out\n",
"\n",
" self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)\n",
" self.proj = nn.Linear(d_in, d_out)\n",
" self.dropout = dropout\n",
"\n",
" self.register_buffer(\n",
" \"mask\", torch.triu(torch.ones(block_size, block_size), diagonal=1)\n",
" )\n",
"\n",
" def forward(self, x):\n",
" batch_size, num_tokens, embed_dim = x.shape\n",
"\n",
" # (b, num_tokens, embed_dim) --> (b, num_tokens, 3 * embed_dim)\n",
" qkv = self.qkv(x)\n",
"\n",
" # (b, num_tokens, 3 * embed_dim) --> (b, num_tokens, 3, num_heads, head_dim)\n",
" qkv = qkv.reshape(batch_size, num_tokens, 3, self.num_heads, self.head_dim)\n",
"\n",
" # (b, num_tokens, 3, num_heads, head_dim) --> (3, b, num_heads, num_tokens, head_dim)\n",
" qkv = qkv.permute(2, 0, 3, 1, 4)\n",
"\n",
" # (3, b, num_heads, num_tokens, head_dim) -> 3 times (b, num_heads, num_tokens, head_dim)\n",
" queries, keys, values = qkv.unbind(0)\n",
"\n",
" use_dropout = 0. if not self.training else self.dropout\n",
" context_vec = nn.functional.scaled_dot_product_attention(\n",
" queries, keys, values, attn_mask=None, dropout_p=use_dropout, is_causal=True)\n",
"\n",
" # Combine heads, where self.d_out = self.num_heads * self.head_dim\n",
" context_vec = context_vec.transpose(1, 2).contiguous().view(batch_size, num_tokens, self.d_out)\n",
"\n",
" return context_vec"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "fbc8ba92-3471-41cb-b1b2-4c0ef5be392b",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "fbc8ba92-3471-41cb-b1b2-4c0ef5be392b",
"outputId": "af9e4855-7f20-4d61-8532-4827df8dfb30"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([8, 1024, 768])\n"
]
}
],
"source": [
"mha_pytorch_scaled = MHAPyTorchScaledDotProduct(\n",
" d_in=embed_dim,\n",
" d_out=embed_dim,\n",
" block_size=context_len,\n",
" dropout=0.0,\n",
" num_heads=12,\n",
" qkv_bias=False\n",
").to(device)\n",
"\n",
"out = mha_pytorch_scaled(embeddings)\n",
"print(out.shape)"
]
},
{
"cell_type": "markdown",
"id": "351c318f-4835-4d74-8d58-a070222447c4",
"metadata": {
"id": "351c318f-4835-4d74-8d58-a070222447c4"
},
"source": [
"## 5) Using PyTorch's torch.nn.MultiheadAttention"
]
},
{
"cell_type": "markdown",
"id": "74a6d060-6324-48fa-a35c-cb09f2a48965",
"metadata": {
"id": "74a6d060-6324-48fa-a35c-cb09f2a48965"
},
"source": [
"- Below, we use PyTorch's [torch.nn.MultiheadAttention](https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html) implementation"
]
},
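{
"cell_type": "markdown",
"id": "sketch-mha-mask-md",
"metadata": {},
"source": [
"- For a boolean `attn_mask`, `torch.nn.MultiheadAttention` treats `True` as a position that may not be attended to. This matches the `torch.triu(..., diagonal=1)` causal mask used in the class below.\n",
"- The next cell is a small sketch with toy sizes and random inputs, chosen for illustration. Changing only the last token should leave the outputs of all earlier tokens unchanged."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-mha-mask-code",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"\n",
"gen = torch.Generator().manual_seed(0)\n",
"toy_mha = nn.MultiheadAttention(embed_dim=8, num_heads=2, batch_first=True).eval()\n",
"\n",
"x_a = torch.randn(1, 5, 8, generator=gen)\n",
"x_b = x_a.clone()\n",
"x_b[:, -1] = torch.randn(8, generator=gen)  # modify only the last token\n",
"\n",
"causal_mask = torch.triu(torch.ones(5, 5, dtype=torch.bool), diagonal=1)  # True = masked\n",
"with torch.no_grad():\n",
"    out_a, _ = toy_mha(x_a, x_a, x_a, attn_mask=causal_mask)\n",
"    out_b, _ = toy_mha(x_b, x_b, x_b, attn_mask=causal_mask)\n",
"\n",
"# Tokens 0-3 cannot see token 4, so their outputs should be identical\n",
"print(torch.allclose(out_a[:, :-1], out_b[:, :-1], atol=1e-6))"
]
},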
{
"cell_type": "code",
"execution_count": 7,
"id": "3799c7ef-3155-42c6-a829-f95656453ae0",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "3799c7ef-3155-42c6-a829-f95656453ae0",
"outputId": "2a085df8-0445-4818-9978-6dc74469f568"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([8, 1024, 768])\n"
]
}
],
"source": [
"import torch.nn as nn\n",
"\n",
"\n",
"class MHAPyTorchClass(nn.Module):\n",
" def __init__(self, d_in, d_out, num_heads, block_size, dropout=0.0, qkv_bias=False, need_weights=True):\n",
" super().__init__()\n",
"\n",
" self.block_size = block_size\n",
" self.multihead_attn = nn.MultiheadAttention(\n",
" embed_dim=d_out,\n",
" num_heads=num_heads,\n",
" dropout=dropout,\n",
" bias=qkv_bias,\n",
" add_bias_kv=qkv_bias,\n",
" batch_first=True,\n",
" )\n",
"\n",
" self.need_weights = need_weights\n",
" self.proj = nn.Linear(d_out, d_out)\n",
" self.register_buffer(\"mask\", torch.triu(torch.ones(block_size, block_size), diagonal=1).bool())\n",
"\n",
" def forward(self, x):\n",
" batch_size, num_tokens, _ = x.shape\n",
"\n",
" # Ensure attn_mask is compatible with expected shape and `batch_first=True`\n",
" # No need to manually adjust for num_heads; ensure it's right for the sequence\n",
" if self.block_size >= num_tokens:\n",
" attn_mask = self.mask[:num_tokens, :num_tokens]\n",
" else:\n",
" attn_mask = self.mask[:self.block_size, :self.block_size]\n",
"\n",
" # attn_mask broadcasting will handle batch_size dimension implicitly\n",
" attn_output, _ = self.multihead_attn(\n",
" x, x, x, attn_mask=attn_mask, need_weights=self.need_weights\n",
" )\n",
"\n",
" output = self.proj(attn_output)\n",
"\n",
" return output\n",
"\n",
"\n",
"mha_pytorch_class_default = MHAPyTorchClass(\n",
" d_in=embed_dim,\n",
" d_out=embed_dim,\n",
" block_size=context_len,\n",
" dropout=0.0,\n",
" num_heads=12,\n",
" qkv_bias=False\n",
").to(device)\n",
"\n",
"out = mha_pytorch_class_default(embeddings)\n",
"print(out.shape)"
]
},
{
"cell_type": "markdown",
"id": "a3953bff-1056-4de2-bfd1-dfccf659eee4",
"metadata": {
"id": "a3953bff-1056-4de2-bfd1-dfccf659eee4"
},
"source": [
"## 6) Using PyTorch's torch.nn.MultiheadAttention with `scaled_dot_product_attention`"
]
},
{
"cell_type": "markdown",
"id": "d2164859-31a0-4537-b4fb-27d57675ba77",
"metadata": {
"id": "d2164859-31a0-4537-b4fb-27d57675ba77"
},
"source": [
"- Set `need_weights` (default `True`) to need_weights=False so that MultiheadAttention uses `scaled_dot_product_attention` [according to the documentation](https://github.com/pytorch/pytorch/blob/71d020262793542974cf13b30f2a9099773f015c/torch/nn/modules/activation.py#L1096)\n",
"\n",
"> need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.\n",
" Set ``need_weights=False`` to use the optimized ``scaled_dot_product_attention``\n",
" and achieve the best performance for MHA.\n",
" Default: ``True``."
]
},
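{
"cell_type": "markdown",
"id": "sketch-need-weights-md",
"metadata": {},
"source": [
"- The next cell is a small sketch of what changes with `need_weights=False`, using toy sizes and random inputs as assumptions. The module returns `None` instead of the attention weights, which are head-averaged by default. The attention output itself should match the default call up to floating-point tolerance."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-need-weights-code",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"\n",
"gen = torch.Generator().manual_seed(0)\n",
"toy_mha2 = nn.MultiheadAttention(embed_dim=8, num_heads=2, batch_first=True).eval()\n",
"toy_x2 = torch.randn(2, 5, 8, generator=gen)\n",
"toy_mask2 = torch.triu(torch.ones(5, 5, dtype=torch.bool), diagonal=1)\n",
"\n",
"with torch.no_grad():\n",
"    out_w, weights_w = toy_mha2(toy_x2, toy_x2, toy_x2, attn_mask=toy_mask2, need_weights=True)\n",
"    out_nw, weights_nw = toy_mha2(toy_x2, toy_x2, toy_x2, attn_mask=toy_mask2, need_weights=False)\n",
"\n",
"print(weights_w.shape)  # head-averaged weights: (batch, num_tokens, num_tokens)\n",
"print(weights_nw)       # None, because the weights were not requested\n",
"print(torch.allclose(out_w, out_nw, atol=1e-5))"
]
},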
{
"cell_type": "code",
"execution_count": 8,
"id": "4a4c2afe-5e1f-4bd7-a118-67031176f147",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "4a4c2afe-5e1f-4bd7-a118-67031176f147",
"outputId": "234771f4-8a53-4478-8a9b-cf19f79a5e07"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([8, 1024, 768])\n"
]
}
],
"source": [
"mha_pytorch_class_noweights = MHAPyTorchClass(\n",
" d_in=embed_dim,\n",
" d_out=embed_dim,\n",
" block_size=context_len,\n",
" dropout=0.0,\n",
" num_heads=12,\n",
" qkv_bias=False,\n",
" need_weights=False # NEW!\n",
").to(device)\n",
"\n",
"out = mha_pytorch_class_noweights(embeddings)\n",
"print(out.shape)"
]
},
{
"cell_type": "markdown",
"id": "8877de71-f84f-4f6d-bc87-7552013b6301",
"metadata": {
"id": "8877de71-f84f-4f6d-bc87-7552013b6301"
},
"source": [
"## Quick speed comparison (M3 Macbook Air CPU)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "a97c0b2e-6593-49d8-98bc-2267b3aa610f",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "a97c0b2e-6593-49d8-98bc-2267b3aa610f",
"outputId": "ebe635b2-5c03-4e9b-da3a-951d308acf7b"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"194 ms ± 2.75 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
"source": [
"## 1) CausalAttention MHA wrapper class from chapter 3\n",
"%timeit mha_ch03_wrapper(embeddings)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "19db9c2c-8e75-431a-8eef-0b4d8284e6e6",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "19db9c2c-8e75-431a-8eef-0b4d8284e6e6",
"outputId": "c6e7bcff-661c-45a6-da82-b1e3f89cf761"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"198 ms ± 4.12 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
"source": [
"## 2) The multi-head attention class from chapter 3\n",
"%timeit mha_ch03(embeddings)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "aa526ee0-7a88-4f34-a49a-f8f97da83779",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "aa526ee0-7a88-4f34-a49a-f8f97da83779",
"outputId": "92b634f8-43f8-468f-87a1-bb774b64c212"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"234 ms ± 4.26 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
"source": [
"## 3) An alternative multi-head attention with combined weights\n",
"%timeit mha_combined_qkv(embeddings)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "cc2b4256-16d8-4c34-9fd0-d4b4af0e60fa",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "cc2b4256-16d8-4c34-9fd0-d4b4af0e60fa",
"outputId": "80c6e314-0771-470e-b090-628984ce2d85"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"71.7 ms ± 3.65 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
]
}
],
"source": [
"## 4) Multihead attention with PyTorch's scaled dot product attention\n",
"%timeit mha_pytorch_scaled(embeddings)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "0f209e70-ebb6-4a1a-b608-1ff42e41c01d",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "0f209e70-ebb6-4a1a-b608-1ff42e41c01d",
"outputId": "3cd37b53-04d4-4dd0-9450-6fc8ebaac083"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"211 ms ± 5.31 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
"source": [
"## 5) Using PyTorch's torch.nn.MultiheadAttention\n",
"%timeit mha_pytorch_class_default(embeddings)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "3f4968c2-8d40-4ab9-8dba-052b4f77d756",
"metadata": {
"id": "3f4968c2-8d40-4ab9-8dba-052b4f77d756",
"outputId": "2e86bdb4-7fa0-4051-b000-4a2b591060a2",
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"207 ms ± 18.3 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
]
}
],
"source": [
"## 6) Using PyTorch's torch.nn.MultiheadAttention disabling `need_weights`\n",
"%timeit mha_pytorch_class_noweights(embeddings)"
]
},
{
"cell_type": "markdown",
"id": "dabc6575-0316-4640-a729-e616d5c17b73",
"metadata": {
"id": "dabc6575-0316-4640-a729-e616d5c17b73"
},
"source": [
"## Speed comparison (Nvidia A100 GPU) with warmup"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "29b63d3d-6d0b-43bb-9c68-d5514dc81000",
"metadata": {
"id": "29b63d3d-6d0b-43bb-9c68-d5514dc81000"
},
"outputs": [],
"source": [
"# CUDA benchmark code shared by Andrei Aksionov\n",
"# and based on code from\n",
"# https://github.com/cuda-mode/lectures/blob/main/lecture1/pytorch_square.py\n",
"\n",
"import time\n",
"\n",
"def time_pytorch_function(func, *input, num_repeats = 100):\n",
" # CUDA IS ASYNC so can't use python time module\n",
" #start = torch.cuda.Event(enable_timing=True)\n",
" #end = torch.cuda.Event(enable_timing=True)\n",
" start = time.time()\n",
" # Warmup\n",
" #for _ in range(5):\n",
" # func(*input)\n",
" #torch.cuda.synchronize()\n",
"\n",
" #start.record()\n",
" for _ in range(num_repeats):\n",
" func(*input)\n",
" #torch.cuda.synchronize()\n",
" #end.record()\n",
" #torch.cuda.synchronize()\n",
" return (time.time()-start) / num_repeats"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "CDJAPZaszaqx",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 489
},
"id": "CDJAPZaszaqx",
"outputId": "f23e9b83-7fd6-4011-9434-0e6934cf762a"
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAnAAAAHWCAYAAAD3vrTNAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy88F64QAAAACXBIWXMAAA9hAAAPYQGoP6dpAADLCUlEQVR4nOzddVhVWdvA4d+hRVAERMEce2xndHQMxMQiDLAFRcEEC8XAwlExUBRjFNuxu1vsGsXA7sDAFgPJ7w8/9ssRMGYOAuNzX5fXO2efffZeZ72bdZ699lrPUpmamiYghBBCCCEyDa30LoAQQgghhPg2EsAJIYQQQmQyEsAJIYQQQmQyEsAJIYQQQmQyEsAJIYQQQmQyEsAJIYQQQmQyEsAJIYQQQmQyEsAJIYQQQmQyOuldgIzG0tKSN2/epHcxhBBCCPGDMjIy4uHDh5/dRwK4JCwtLQkLC0vvYgghhBDiB1e6dOnPBnESwCWR2PNWunRp6YUTQgghxHdnZGREWFjYF+MQCeBS8ObNGyIjI9O7GEIIIYQQKZJJDEIIIYQQmUyGDuDc3NwIDQ0lPDycnTt38ssvv6S6b5MmTdizZw83b97k7t27hISE4Ozs/B1LK4QQQgjxfWTYAM7R0RE/Pz8mTJhA7dq1CQsLY9WqVZibm6e4/4sXLwgICKBBgwZYW1uzdOlSpk2bRq1atb5zyYUQQggh0laGDeC6d+/O4sWLWbp0KVeuXKFfv368f/+etm3bprj/4cOH2bJlC1evXuX27dvMnj2bCxcuUKVKle9cciGEEEKItJUhAzhdXV3KlSvH/v37lW0JCQns37+fSpUqfdUxrK2tKVKkCEePHk2rYgohhBBCpIsMGcCZmZmho6NDRESE2vaIiAgsLCxS/ZyxsTF37tzh0aNHLFu2DB8fH0JCQlLdX09PD2NjY+WfkZGRpr6CEOIzvmV8a/v27dm8eTM3btzgxo0brF27NsX9ixUrxpIlS7h16xZ3795l9+7d5MmTJy2/hhBCpJsMGcD9U2/evMHGxoa6devyxx9/MHr0aKpVq5bq/r179+b27dvKP0niK0Ta+9bxrdWqVWPt2rU4ODjQoEEDwsPDWb16NZaWlso+BQsWZMuWLVy7dg17e3usra2ZOHEiHz58+F5fSwghviuVqalpQnoX4lO6urrcv3+fjh07snXrVmX79OnTyZ49O+3atfuq40yZMoU8efLg5OSU4vt6enro6+srrxOT5xUsWFDywAmRRnbu3EloaCgDBw4EQKVScf78eebMmUNgYOAXP6+lpcXNmzcZOHAgK1asAGDOnDnExsbSrVu3NC27EEKkNWNjY27fvv3FWCRD9sDFxMRw9uxZrK2tlW0qlQpra2tOnjz51cfR0tJCT08v1fejo6OJjIxU/snqC0KkLU2MbzU0NERHR4cXL14AH9uG+vXrc/36dVatWsXly5fZuXMnjRo1SpPvIIQQGUGGDOAAZsyYQfv27WnVqhXFihVj4sSJGBoasnTpUuV9X19fZf/evXtjY2NDgQIFKFasGN27d8fZ2ZlVq1al11cQQnzin45vTWr48OE8evRICQJz5syJkZERXl5e7NmzhxYtWrBlyxYWLlxI1apVNf4dhBAiI8iwS2mtX78ec3NzfHx8sLCwICwsDGdnZ548eQJAnjx5iI+PV/Y3NDRk/PjxWFlZERUVxbVr1+jatSvr169Pp28ghNA0Ly8vmjZtir29vTK+TUvr433otm3bmDVrFgBhYWFUqlQJV1dXjhw5km7lFUKItJJhAziA4OBggoODU3zPwcFB7fWYMWMYM2bM9yiWEOIfevbsGbGxscl62ywsLJL1yn2qR48eeHl50axZMy5evKh2zJiYGK5evaq2/7Vr16hcubLmCi+EEBlIhn2EKoT47/mn41t79epF//79cXZ25syZM8mOGRoaSpEiRdS2Fy5cmHv37mm0/EIIkVFk6B44IcR/z4wZM5
g+fTpnzpzh9OnTeHh4JBvf+vDhQ/z8/ADw9PTEx8cHDw8P7t69q/TevX37lrdv3wIQFBREcHAwR44c4dChQ9SpUwdbW1vs7e3T50sKIUQakwBOCPFdfev41o4dO6Kvr8+CBQvUjuPv78/48eMB2LJlC/369aN3796MHTuW69ev4+rqyvHjx7/b9xJCiO8pQ+aBSy9fm3tFCCGEECItZOo8cEIIIYQQInUSwAkhhBCpSIt1exNNnDiRZ8+e4eHhkRZFF/9xEsAJIYQQKUiLdXsTNW7cmIoVK/Lw4cO0/hriP0oCOCGEECIF3bt3Z/HixSxdupQrV67Qr18/3r9/T9u2bVPcv2vXrsybN4+wsDCuXbuGl5cXWlpaamlzACwtLRk3bhweHh7ExMR8j68i/oM0Ogs1f/78/P777+TNmxdDQ0OePn3K+fPnOXnypJI1XQghhMjoEtftnTJlirLt367bCx/zHs6cOZNp06Zx5coVTRdb/EA0EsC1aNECDw8PypcvT0REBI8ePSIqKoocOXJQsGBBPnz4wOrVqwkMDOT+/fuaOKUQQgiRZj63bm/RokW/6hifrtsLH5eDi42NZfbs2Rotr/jx/OsAbt++fcTExLBs2TJcXFx48OCB2vt6enpUqlSJpk2bsmfPHry9vdm4ceO/Pa0QQgiRYaW0bm+5cuVwd3endu3a6Vw68V/wrwO4UaNGsW/fvlTfj46O5vDhwxw+fJg//viD/Pnz/9tTCiGEEGkqLdbtrVKlCjlz5uTs2bPKNh0dHfz8/OjatSsVKlTQ7JcQ/2n/ehLD54K3T7148ULtwhVCCCEyorRYt3flypXUqFGDmjVrKv8ePnxIUFAQTk5OafVVMiRNp2cZMGAAx44d4+7du8o+v/76a1p/jXSl0VmoZcuW5eeff1ZeN2zYkMWLFzN06FB0dXU1eSohhBAiTc2YMYP27dvTqlUrihUrxsSJE5Ot2+vr66vs7+npyaBBg/D09FTW7bWwsCBr1qzAx06My5cvq/2LiYnh8ePHXL9+PV2+Y3pIi/QsN27cYODAgdSoUYNGjRpx9+5dVq9ejZmZ2ff6Wt+dRgO4gIAAihQpAkCBAgWYM2cO7969w97enhEjRmjyVEIIIUSaWr9+PcOHD8fHx4eQkBDKlCmTbN3eXLlyKfsnXbf30qVLyr8ePXqk11fIkNIiPcuaNWvYv38/d+7c4cqVK/j6+pItWzZKlSr1vb7Wd6fRNCKFCxfm/PnzADg4OHD06FE8PDz47bffCA4OZsiQIZo8nRBpzs3NjZ49e2JhYcGFCxfw8fHh9OnTKe7bvn17WrZsqfRCnz17ltGjRyv76+joMGTIEOrWrUuBAgWIjIxk//79jBo1ikePHn237/QtjLssSu8iZDiRczqkdxHEdxQcHExwcHCK7zk4OKi9/idj2H60cW9plZ7l03N06NCBV69eERYWpoliZ0ga7YFTqVRoaX08ZM2aNdm1axcA4eHhmJqaavJUQqQ5TXfzZ8mShbJlyzJx4kRq166Ni4sLRYoU4a+//vqeX0sIIdLN59KzfDphJDUppWcBqF+/Pnfu3OHBgwd069aN5s2b8/z5c42VPaPRaAB35swZ+vXrh7OzM1WrVlUCuAIFCihdzkJkFpru5o+MjKR58+Zs2LCB69ev8/fffzNw4EDKly9Pnjx5vudXE0KITCkxPUuHDh2SLRBw6NAhbGxsaNiwIXv27GHu3Lmp3nD/F2g0gBs8eDBly5bF39+fgIAAbt26BYC9vT0nTpzQ5KmESFOJ3fxJ7/A03c0PkC1bNuLj43n9+vW/LrMQQmR0mkjP0qJFC7X0LInevXvHrVu3+Pvvv5WEye3atdNo+TMSjY6Bu3jxIjVq1Ei2ffjw4cTFxWnyVEKkqbTKwp6Uvr4+w4YNY82aNURGRv7rMgshREaXND3L1q1bgf+lZ0ltrCF8TM/St29fnJyckqVnSY2WlhZ6enqaKHaGpNEALqmsWbMq4+
ESyY+U+FGklIU9KR0dHebOnYtKpcLb2zsdSiiEEOljxowZTJ8+nTNnznD69Gk8PDySpWd5+PAhfn5+wMf0LD4+Pnh
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"\n",
"import matplotlib.pyplot as plt\n",
"\n",
"\n",
"#embeddings_cuda = embeddings.to(torch.device(\"cuda\"))\n",
"\n",
"functions = {\n",
" \"1) MHA wrapper class\": mha_ch03_wrapper,\n",
" \"2) MHA Ch03\": mha_ch03,\n",
" \"3) MHA with combined QKV weights\": mha_combined_qkv,\n",
" \"4) MHA with PyTorch scaled_dot_product_attention\": mha_pytorch_scaled,\n",
" \"5) PyTorch MHA class defaults\": mha_pytorch_class_default,\n",
" \"6) PyTorch MHA with need_weights=False\": mha_pytorch_class_noweights\n",
"}\n",
"execution_times = [time_pytorch_function(fn, embeddings) for name,fn in functions.items()]\n",
"\n",
"\n",
"# Plotting\n",
"\n",
"# Customize further for dark mode aesthetics\n",
"plt.rcParams['figure.facecolor'] = '#121212' # Dark figure background\n",
"plt.rcParams['axes.facecolor'] = '#121212' # Dark axes background\n",
"plt.rcParams['axes.edgecolor'] = 'white' # White axes border\n",
"plt.rcParams['axes.labelcolor'] = 'white' # White labels\n",
"plt.rcParams['text.color'] = 'white' # White text\n",
"plt.rcParams['xtick.color'] = 'white' # White x ticks\n",
"plt.rcParams['ytick.color'] = 'white' # White y ticks\n",
"plt.rcParams['grid.color'] = '#444444' # Lighter grid lines for contrast\n",
"plt.rcParams['lines.linewidth'] = 2 # Thicker plot lines for visibility\n",
"plt.rcParams['lines.markersize'] = 8 # Larger markers for visibility\n",
"\n",
"fig, ax = plt.subplots()\n",
"bars = plt.bar(functions.keys(), execution_times)\n",
"\n",
"plt.ylabel('Execution time (ms)')\n",
"plt.xticks(rotation=45, ha=\"right\")\n",
"\n",
"# Calculate new ylim with a margin\n",
"max_execution_time = max(execution_times)\n",
"upper_ylim = max_execution_time + 0.2 * max_execution_time # Adding a 20% margin\n",
"\n",
"plt.ylim(0, upper_ylim) # Setting new ylim\n",
"\n",
"# Annotate bars with execution times\n",
"for bar in bars:\n",
" yval = bar.get_height()\n",
" plt.text(bar.get_x() + bar.get_width()/2, yval + (0.05 * upper_ylim), round(yval, 2), ha='center', va='bottom')\n",
"\n",
"\n",
"plt.tight_layout()\n",
"plt.savefig(\"2.pdf\")\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d3e1137b-9acc-4cc5-bcbf-0e8533839f06",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuType": "A100",
"machine_shape": "hm",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}