{
"cells": [
{
"cell_type": "markdown",
"id": "e2e65c03-36d4-413f-9b23-5cdd816729ab",
"metadata": {
"id": "e2e65c03-36d4-413f-9b23-5cdd816729ab"
},
"source": [
"<table style=\"width:100%\">\n",
"<tr>\n",
"<td style=\"vertical-align:middle; text-align:left;\">\n",
"<font size=\"2\">\n",
"Supplementary code for the <a href=\"http://mng.bz/orYv\">Build a Large Language Model From Scratch</a> book by <a href=\"https://sebastianraschka.com\">Sebastian Raschka</a><br>\n",
"<br>Code repository: <a href=\"https://github.com/rasbt/LLMs-from-scratch\">https://github.com/rasbt/LLMs-from-scratch</a>\n",
"</font>\n",
"</td>\n",
"<td style=\"vertical-align:middle; text-align:left;\">\n",
"<a href=\"http://mng.bz/orYv\"><img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/cover-small.webp\" width=\"100px\"></a>\n",
"</td>\n",
"</tr>\n",
"</table>"
]
},
{
"cell_type": "markdown",
"id": "6f678e62-7bcb-4405-86ae-dce94f494303",
"metadata": {
"id": "6f678e62-7bcb-4405-86ae-dce94f494303"
},
"source": [
"# Comparing Efficient Multi-Head Attention Implementations"
]
},
{
"cell_type": "markdown",
"id": "b742938a-4bfc-4527-a1f1-d5963508967d",
"metadata": {
"id": "b742938a-4bfc-4527-a1f1-d5963508967d"
},
"source": [
"This code notebook compares different ways to implement causal multi-head attention used in decoder-style LLMs like GPT, Llama, etc."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "7898551e-f582-48ac-9f66-3632abe2a93f",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "7898551e-f582-48ac-9f66-3632abe2a93f",
"outputId": "3aa27e4f-402c-4adc-f2d1-271bc6e0385d"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"PyTorch version: 2.5.0.dev20240813\n"
]
}
],
"source": [
"import torch\n",
"\n",
"torch.manual_seed(123)\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"print(f\"PyTorch version: {torch.__version__}\")\n",
"\n",
"batch_size = 8\n",
"context_len = 1024\n",
"embed_dim = 768\n",
"embeddings = torch.randn((batch_size, context_len, embed_dim), device=device)"
]
},
{
"cell_type": "markdown",
"id": "2f9bb1b6-a1e5-4e0a-884d-0f31b374a8d6",
"metadata": {
"id": "2f9bb1b6-a1e5-4e0a-884d-0f31b374a8d6"
},
"source": [
"<br>\n",
" \n",
"\n",
"## 1) CausalAttention MHA wrapper class from chapter 3"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "297c93ed-aec0-4896-bb89-42c4b294d3d1",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "297c93ed-aec0-4896-bb89-42c4b294d3d1",
"outputId": "e76a6a62-7a52-4c6b-aa36-e90cea0cd415"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([8, 1024, 768])\n"
]
}
],
"source": [
"from ch03 import MultiHeadAttentionWrapper as Ch03_MHA_Wrapper\n",
"\n",
"mha_ch03_wrapper = Ch03_MHA_Wrapper(\n",
" d_in=embed_dim,\n",
" d_out=embed_dim//12,\n",
" context_length=context_len,\n",
" dropout=0.0,\n",
" num_heads=12,\n",
" qkv_bias=False\n",
").to(device)\n",
"\n",
"out = mha_ch03_wrapper(embeddings)\n",
"print(out.shape)"
]
},
{
"cell_type": "markdown",
"id": "21930804-b327-40b1-8e63-94dcad39ce7b",
"metadata": {
"id": "21930804-b327-40b1-8e63-94dcad39ce7b"
},
"source": [
"<br>\n",
" \n",
"\n",
"## 2) The multi-head attention class from chapter 3"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "4ee6a61b-d25c-4a0c-8a59-f285544e3710",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "4ee6a61b-d25c-4a0c-8a59-f285544e3710",
"outputId": "650c8992-a6c6-4f28-938a-ee9297131d38"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([8, 1024, 768])\n"
]
}
],
"source": [
"from ch03 import MultiHeadAttention as Ch03_MHA\n",
"\n",
"mha_ch03 = Ch03_MHA(\n",
" d_in=embed_dim,\n",
" d_out=embed_dim,\n",
" context_length=context_len,\n",
" dropout=0.0,\n",
" num_heads=12,\n",
" qkv_bias=False\n",
").to(device)\n",
"\n",
"out = mha_ch03(embeddings)\n",
"print(out.shape)"
]
},
{
"cell_type": "markdown",
"id": "73cd11da-ea3b-4081-b483-c4965dfefbc4",
"metadata": {
"id": "73cd11da-ea3b-4081-b483-c4965dfefbc4"
},
"source": [
"<br>\n",
" \n",
"\n",
"## 3) An alternative multi-head attention with combined weights"
]
},
{
"cell_type": "markdown",
"id": "1fa1a5ea-eaff-4d2d-aaf0-b34cdb6fd4dd",
"metadata": {
"id": "1fa1a5ea-eaff-4d2d-aaf0-b34cdb6fd4dd"
},
"source": [
"- The code for the `MultiHeadAttentionCombinedQKV` class below is based on code that was kindly shared by [Rayed Bin Wahed](https://github.com/rasbt/LLMs-from-scratch/discussions/51)\n",
"- The main difference between the `MultiHeadAttentionCombinedQKV` class and the `MultiHeadAttention` class used in chapter 3 is that `MultiHeadAttentionCombinedQKV` uses a single weight matrix, `self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)` instead of separate weight matrices:\n",
"\n",
" - `self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)`\n",
" - `self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)`\n",
" - `self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)`\n",
"\n",
"- Here, `self.qkv` combines all three weight matrices `self.W_query`, `self.W_key`, and `self.W_value` to carry out the query, key, and value computation in a single step\n",
"- Using `q, k, v = qkv.unbind(0)`, we obtain the individual query, key, and value tensors, which are then used similarly to the query, key, and value tensors in the `MultiHeadAttention` class in chapter 3"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "9a6bd0a2-f27c-4602-afa0-c96cd295c1a6",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "9a6bd0a2-f27c-4602-afa0-c96cd295c1a6",
"outputId": "f0bae195-7caf-4aee-efd6-d55a56c1d4d3"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([8, 1024, 768])\n"
]
}
],
"source": [
"import torch.nn as nn\n",
"\n",
"\n",
"class MultiHeadAttentionCombinedQKV(nn.Module):\n",
" def __init__(self, d_in, d_out, num_heads, context_length, dropout=0.0, qkv_bias=False):\n",
" super().__init__()\n",
"\n",
" assert d_out % num_heads == 0, \"embed_dim is indivisible by num_heads\"\n",
"\n",
" self.num_heads = num_heads\n",
" self.context_length = context_length\n",
" self.head_dim = d_out // num_heads\n",
"\n",
" self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)\n",
" self.proj = nn.Linear(d_out, d_out)\n",
" self.dropout = nn.Dropout(dropout)\n",
"\n",
" self.register_buffer(\n",
" \"mask\", torch.triu(torch.ones(context_length, context_length), diagonal=1)\n",
" )\n",
"\n",
" def forward(self, x):\n",
" batch_size, num_tokens, embed_dim = x.shape\n",
"\n",
" # (b, num_tokens, embed_dim) --> (b, num_tokens, 3 * embed_dim)\n",
" qkv = self.qkv(x)\n",
"\n",
" # (b, num_tokens, 3 * embed_dim) --> (b, num_tokens, 3, num_heads, head_dim)\n",
" qkv = qkv.view(batch_size, num_tokens, 3, self.num_heads, self.head_dim)\n",
"\n",
" # (b, num_tokens, 3, num_heads, head_dim) --> (3, b, num_heads, num_tokens, head_dim)\n",
" qkv = qkv.permute(2, 0, 3, 1, 4)\n",
"\n",
" # (3, b, num_heads, num_tokens, head_dim) -> 3 times (b, num_head, num_tokens, head_dim)\n",
" queries, keys, values = qkv.unbind(0)\n",
"\n",
" # (b, num_heads, num_tokens, head_dim) --> (b, num_heads, num_tokens, num_tokens)\n",
" attn_scores = queries @ keys.transpose(-2, -1)\n",
" attn_scores = attn_scores.masked_fill(\n",
" self.mask.bool()[:num_tokens, :num_tokens], -torch.inf\n",
" )\n",
"\n",
"        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n",
" attn_weights = self.dropout(attn_weights)\n",
"\n",
" # (b, num_heads, num_tokens, num_tokens) --> (b, num_heads, num_tokens, head_dim)\n",
" context_vec = attn_weights @ values\n",
"\n",
" # (b, num_heads, num_tokens, head_dim) --> (b, num_tokens, num_heads, head_dim)\n",
" context_vec = context_vec.transpose(1, 2)\n",
"\n",
" # (b, num_tokens, num_heads, head_dim) --> (b, num_tokens, embed_dim)\n",
" context_vec = context_vec.contiguous().view(batch_size, num_tokens, embed_dim)\n",
"\n",
" context_vec = self.proj(context_vec)\n",
"\n",
" return context_vec\n",
"\n",
"\n",
"mha_combined_qkv = MultiHeadAttentionCombinedQKV(\n",
" d_in=embed_dim,\n",
" d_out=embed_dim,\n",
" context_length=context_len,\n",
" dropout=0.0,\n",
" num_heads=12,\n",
" qkv_bias=False\n",
").to(device)\n",
"\n",
"out = mha_combined_qkv(embeddings)\n",
"print(out.shape)"
]
},
{
"cell_type": "markdown",
"id": "48a042d3-ee78-4c29-bf63-d92fe6706632",
"metadata": {
"id": "48a042d3-ee78-4c29-bf63-d92fe6706632"
},
"source": [
"<br>\n",
" \n",
"\n",
"## 4) Multihead attention with PyTorch's scaled dot product attention and FlashAttention"
]
},
{
"cell_type": "markdown",
"id": "f78e346f-3b85-44e6-9feb-f01131381148",
"metadata": {
"id": "f78e346f-3b85-44e6-9feb-f01131381148"
},
"source": [
"- The implementation below uses PyTorch's [`scaled_dot_product_attention`](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) function, which implements a memory-optimized version of self-attention called [FlashAttention](https://arxiv.org/abs/2205.14135)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "1b8e5a0d-1f65-4a03-bf6e-723f0cc428f5",
"metadata": {
"id": "1b8e5a0d-1f65-4a03-bf6e-723f0cc428f5"
},
"outputs": [],
"source": [
"class MHAPyTorchScaledDotProduct(nn.Module):\n",
" def __init__(self, d_in, d_out, num_heads, context_length, dropout=0.0, qkv_bias=False):\n",
" super().__init__()\n",
"\n",
" assert d_out % num_heads == 0, \"embed_dim is indivisible by num_heads\"\n",
"\n",
" self.num_heads = num_heads\n",
" self.context_length = context_length\n",
" self.head_dim = d_out // num_heads\n",
" self.d_out = d_out\n",
"\n",
" self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)\n",
" self.proj = nn.Linear(d_out, d_out)\n",
" self.dropout = dropout\n",
"\n",
" def forward(self, x):\n",
" batch_size, num_tokens, embed_dim = x.shape\n",
"\n",
" # (b, num_tokens, embed_dim) --> (b, num_tokens, 3 * embed_dim)\n",
" qkv = self.qkv(x)\n",
"\n",
" # (b, num_tokens, 3 * embed_dim) --> (b, num_tokens, 3, num_heads, head_dim)\n",
" qkv = qkv.view(batch_size, num_tokens, 3, self.num_heads, self.head_dim)\n",
"\n",
" # (b, num_tokens, 3, num_heads, head_dim) --> (3, b, num_heads, num_tokens, head_dim)\n",
" qkv = qkv.permute(2, 0, 3, 1, 4)\n",
"\n",
" # (3, b, num_heads, num_tokens, head_dim) -> 3 times (b, num_heads, num_tokens, head_dim)\n",
" queries, keys, values = qkv\n",
"\n",
" use_dropout = 0. if not self.training else self.dropout\n",
"\n",
" context_vec = nn.functional.scaled_dot_product_attention(\n",
" queries, keys, values, attn_mask=None, dropout_p=use_dropout, is_causal=True)\n",
"\n",
" # Combine heads, where self.d_out = self.num_heads * self.head_dim\n",
" context_vec = context_vec.transpose(1, 2).contiguous().view(batch_size, num_tokens, self.d_out)\n",
"\n",
" context_vec = self.proj(context_vec)\n",
"\n",
" return context_vec"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "fbc8ba92-3471-41cb-b1b2-4c0ef5be392b",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "fbc8ba92-3471-41cb-b1b2-4c0ef5be392b",
"outputId": "da8b6836-35c6-43f5-a9e2-99217d517101"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([8, 1024, 768])\n"
]
}
],
"source": [
"mha_pytorch_scaled = MHAPyTorchScaledDotProduct(\n",
" d_in=embed_dim,\n",
" d_out=embed_dim,\n",
" context_length=context_len,\n",
" dropout=0.0,\n",
" num_heads=12,\n",
" qkv_bias=False\n",
").to(device)\n",
"\n",
"out = mha_pytorch_scaled(embeddings)\n",
"print(out.shape)"
]
},
{
"cell_type": "markdown",
"id": "51492724-6018-49f6-8bf6-ae9e585229c3",
"metadata": {
"id": "51492724-6018-49f6-8bf6-ae9e585229c3"
},
"source": [
"<br>\n",
" \n",
"\n",
"## 5) PyTorch's scaled dot product attention without FlashAttention\n",
"\n",
"- This is similar to above, except that we disable FlashAttention by passing an explicit causal mask"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "bad53538-e905-4065-ba0c-caacdfec5a0b",
"metadata": {
"id": "bad53538-e905-4065-ba0c-caacdfec5a0b"
},
"outputs": [],
"source": [
"class MHAPyTorchSDPAWithoutFlash(nn.Module):\n",
" def __init__(self, d_in, d_out, num_heads, context_length, dropout=0.0, qkv_bias=False):\n",
" super().__init__()\n",
"\n",
" assert d_out % num_heads == 0, \"embed_dim is indivisible by num_heads\"\n",
"\n",
" self.num_heads = num_heads\n",
" self.context_length = context_length\n",
" self.head_dim = d_out // num_heads\n",
" self.d_out = d_out\n",
"\n",
" self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)\n",
" self.proj = nn.Linear(d_out, d_out)\n",
" self.dropout = dropout\n",
" self.register_buffer(\"mask\", torch.triu(torch.ones(context_length, context_length), diagonal=1).bool())\n",
"\n",
" def forward(self, x):\n",
" batch_size, num_tokens, embed_dim = x.shape\n",
"\n",
" # (b, num_tokens, embed_dim) --> (b, num_tokens, 3 * embed_dim)\n",
" qkv = self.qkv(x)\n",
"\n",
" # (b, num_tokens, 3 * embed_dim) --> (b, num_tokens, 3, num_heads, head_dim)\n",
" qkv = qkv.view(batch_size, num_tokens, 3, self.num_heads, self.head_dim)\n",
"\n",
" # (b, num_tokens, 3, num_heads, head_dim) --> (3, b, num_heads, num_tokens, head_dim)\n",
" qkv = qkv.permute(2, 0, 3, 1, 4)\n",
"\n",
" # (3, b, num_heads, num_tokens, head_dim) -> 3 times (b, num_heads, num_tokens, head_dim)\n",
" queries, keys, values = qkv\n",
"\n",
" use_dropout = 0. if not self.training else self.dropout\n",
"\n",
" # Ensure attn_mask is compatible with expected shape and `batch_first=True`\n",
" # No need to manually adjust for num_heads; ensure it's right for the sequence\n",
" if self.context_length >= num_tokens:\n",
" attn_mask = self.mask[:num_tokens, :num_tokens]\n",
" else:\n",
" attn_mask = self.mask[:self.context_length, :self.context_length]\n",
"\n",
" context_vec = nn.functional.scaled_dot_product_attention(\n",
" queries, keys, values, attn_mask=attn_mask, dropout_p=use_dropout, is_causal=False)\n",
"\n",
" # Combine heads, where self.d_out = self.num_heads * self.head_dim\n",
" context_vec = context_vec.transpose(1, 2).contiguous().view(batch_size, num_tokens, self.d_out)\n",
"\n",
" context_vec = self.proj(context_vec)\n",
"\n",
" return context_vec"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "f3da7850-e772-47d3-bd51-22d077b01412",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "f3da7850-e772-47d3-bd51-22d077b01412",
"outputId": "3c726fe1-6601-4a30-c2d2-e9fd48547c50"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([8, 1024, 768])\n"
]
}
],
"source": [
"mha_pytorch_sdpa_no_flash = MHAPyTorchSDPAWithoutFlash(\n",
" d_in=embed_dim,\n",
" d_out=embed_dim,\n",
" context_length=context_len,\n",
" dropout=0.0,\n",
" num_heads=12,\n",
" qkv_bias=False\n",
").to(device)\n",
"\n",
"out = mha_pytorch_sdpa_no_flash(embeddings)\n",
"print(out.shape)"
]
},
{
"cell_type": "markdown",
"id": "351c318f-4835-4d74-8d58-a070222447c4",
"metadata": {
"id": "351c318f-4835-4d74-8d58-a070222447c4"
},
"source": [
"<br>\n",
" \n",
"\n",
"## 6) Using PyTorch's torch.nn.MultiheadAttention"
]
},
{
"cell_type": "markdown",
"id": "74a6d060-6324-48fa-a35c-cb09f2a48965",
"metadata": {
"id": "74a6d060-6324-48fa-a35c-cb09f2a48965"
},
"source": [
"- Below, we use PyTorch's [torch.nn.MultiheadAttention](https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html) implementation"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "3799c7ef-3155-42c6-a829-f95656453ae0",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "3799c7ef-3155-42c6-a829-f95656453ae0",
"outputId": "dcb11757-c8c2-4909-c0e5-5ba2cf7b1096"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([8, 1024, 768])\n"
]
}
],
"source": [
"import torch.nn as nn\n",
"\n",
"\n",
"class MHAPyTorchClass(nn.Module):\n",
" def __init__(self, d_in, d_out, num_heads, context_length, dropout=0.0, qkv_bias=False, need_weights=True):\n",
" super().__init__()\n",
"\n",
" self.context_length = context_length\n",
" self.multihead_attn = nn.MultiheadAttention(\n",
" embed_dim=d_out,\n",
" num_heads=num_heads,\n",
" dropout=dropout,\n",
" bias=qkv_bias,\n",
" add_bias_kv=qkv_bias,\n",
" batch_first=True,\n",
" )\n",
"\n",
" self.need_weights = need_weights\n",
" self.proj = nn.Linear(d_out, d_out)\n",
" self.register_buffer(\"mask\", torch.triu(torch.ones(context_length, context_length), diagonal=1).bool())\n",
"\n",
" def forward(self, x):\n",
" batch_size, num_tokens, _ = x.shape\n",
"\n",
" # Ensure attn_mask is compatible with expected shape and `batch_first=True`\n",
" # No need to manually adjust for num_heads; ensure it's right for the sequence\n",
" if self.context_length >= num_tokens:\n",
" attn_mask = self.mask[:num_tokens, :num_tokens]\n",
" else:\n",
" attn_mask = self.mask[:self.context_length, :self.context_length]\n",
"\n",
" # attn_mask broadcasting will handle batch_size dimension implicitly\n",
" attn_output, _ = self.multihead_attn(\n",
" x, x, x, attn_mask=attn_mask, need_weights=self.need_weights\n",
" )\n",
"\n",
" output = self.proj(attn_output)\n",
"\n",
" return output\n",
"\n",
"\n",
"mha_pytorch_class_default = MHAPyTorchClass(\n",
" d_in=embed_dim,\n",
" d_out=embed_dim,\n",
" context_length=context_len,\n",
" dropout=0.0,\n",
" num_heads=12,\n",
" qkv_bias=False\n",
").to(device)\n",
"\n",
"out = mha_pytorch_class_default(embeddings)\n",
"print(out.shape)"
]
},
{
"cell_type": "markdown",
"id": "a3953bff-1056-4de2-bfd1-dfccf659eee4",
"metadata": {
"id": "a3953bff-1056-4de2-bfd1-dfccf659eee4"
},
"source": [
"<br>\n",
" \n",
"\n",
"## 7) Using PyTorch's torch.nn.MultiheadAttention with `scaled_dot_product_attention`"
]
},
{
"cell_type": "markdown",
"id": "d2164859-31a0-4537-b4fb-27d57675ba77",
"metadata": {
"id": "d2164859-31a0-4537-b4fb-27d57675ba77"
},
"source": [
"- Set `need_weights` (default `True`) to `False` so that `MultiheadAttention` uses the optimized `scaled_dot_product_attention` [according to the documentation](https://github.com/pytorch/pytorch/blob/71d020262793542974cf13b30f2a9099773f015c/torch/nn/modules/activation.py#L1096)\n",
"\n",
"> need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.\n",
" Set ``need_weights=False`` to use the optimized ``scaled_dot_product_attention``\n",
" and achieve the best performance for MHA.\n",
" Default: ``True``."
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "4a4c2afe-5e1f-4bd7-a118-67031176f147",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "4a4c2afe-5e1f-4bd7-a118-67031176f147",
"outputId": "60a22fd6-fda4-478a-98d9-331623814018"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([8, 1024, 768])\n"
]
}
],
"source": [
"mha_pytorch_class_noweights = MHAPyTorchClass(\n",
" d_in=embed_dim,\n",
" d_out=embed_dim,\n",
" context_length=context_len,\n",
" dropout=0.0,\n",
" num_heads=12,\n",
" qkv_bias=False,\n",
" need_weights=False # NEW!\n",
").to(device)\n",
"\n",
"out = mha_pytorch_class_noweights(embeddings)\n",
"print(out.shape)"
]
},
{
"cell_type": "markdown",
"id": "21f4ff35-651c-4e47-bfa1-016f3de01ecc",
"metadata": {
"id": "21f4ff35-651c-4e47-bfa1-016f3de01ecc"
},
"source": [
"<br>\n",
" \n",
"\n",
"## 8) Using PyTorch's FlexAttention\n",
"\n",
"- See [FlexAttention: The Flexibility of PyTorch with the Performance of FlashAttention](https://pytorch.org/blog/flexattention/) to learn more about FlexAttention\n",
"- This is currently only supported in PyTorch 2.5 (nightly), which you can install on a CPU machine via\n",
"\n",
" ```bash\n",
" pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu -U\n",
" ```\n",
"\n",
"- To install the PyTorch nightly release on a GPU machine, use the following (for more information, also see the installation menu on [pytorch.org](https://pytorch.org/)):\n",
"\n",
" ```bash\n",
" pip install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121 -U\n",
" ```"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "834318c8-4748-4902-99f0-70ee02bef63e",
"metadata": {
"id": "834318c8-4748-4902-99f0-70ee02bef63e"
},
"outputs": [],
"source": [
"from packaging.version import parse as parse_version\n",
"\n",
"def normalize_version(version):\n",
" parsed_version = parse_version(version)\n",
" return parse_version(f\"{parsed_version.major}.{parsed_version.minor}.{parsed_version.micro}\")\n",
"\n",
"current_version = normalize_version(torch.__version__)\n",
"MIN_TORCH_VERSION = \"2.5.0\"\n",
"required_version = parse_version(MIN_TORCH_VERSION)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "WYyFRCXndVH9",
"metadata": {
"id": "WYyFRCXndVH9"
},
"outputs": [],
"source": [
"if current_version >= required_version:\n",
" from torch.nn.attention.flex_attention import flex_attention, create_block_mask\n",
"\n",
"\n",
"def causal(b, h, q_idx, kv_idx):\n",
" return q_idx >= kv_idx\n",
"\n",
"\n",
"class MHAPyTorchFlexAttention(nn.Module):\n",
"\n",
" def __init__(self, d_in, d_out, num_heads, context_length, dropout=0.0, qkv_bias=False):\n",
" super().__init__()\n",
"\n",
" assert d_out % num_heads == 0, \"embed_dim is indivisible by num_heads\"\n",
"\n",
" self.num_heads = num_heads\n",
" self.context_length = context_length\n",
" self.head_dim = d_out // num_heads\n",
" self.d_out = d_out\n",
"\n",
" self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)\n",
" self.proj = nn.Linear(d_out, d_out)\n",
" self.dropout = dropout\n",
" # self.register_buffer(\"block_mask\", create_block_mask(causal, B=None, H=None, Q_LEN=context_length, KV_LEN=context_length))\n",
" # `create_block_mask` function does not support buffers, yet\n",
" self.block_mask = create_block_mask(causal, B=None, H=None, Q_LEN=context_length, KV_LEN=context_length)\n",
"\n",
"\n",
" def forward(self, x):\n",
" batch_size, num_tokens, embed_dim = x.shape\n",
"\n",
" # (b, num_tokens, embed_dim) --> (b, num_tokens, 3 * embed_dim)\n",
" qkv = self.qkv(x)\n",
"\n",
" # (b, num_tokens, 3 * embed_dim) --> (b, num_tokens, 3, num_heads, head_dim)\n",
" qkv = qkv.view(batch_size, num_tokens, 3, self.num_heads, self.head_dim)\n",
"\n",
" # (b, num_tokens, 3, num_heads, head_dim) --> (3, b, num_heads, num_tokens, head_dim)\n",
" qkv = qkv.permute(2, 0, 3, 1, 4)\n",
"\n",
" # (3, b, num_heads, num_tokens, head_dim) -> 3 times (b, num_heads, num_tokens, head_dim)\n",
" queries, keys, values = qkv\n",
"\n",
" use_dropout = 0. if not self.training else self.dropout\n",
"\n",
" # Ensure attn_mask is compatible with expected shape and `batch_first=True`\n",
" # No need to manually adjust for num_heads; ensure it's right for the sequence\n",
" if self.context_length >= num_tokens:\n",
" attn_mask = self.block_mask[:num_tokens, :num_tokens]\n",
" else:\n",
" attn_mask = self.block_mask[:self.context_length, :self.context_length]\n",
"\n",
" context_vec = flex_attention(queries, keys, values, block_mask=attn_mask)\n",
"\n",
" # Combine heads, where self.d_out = self.num_heads * self.head_dim\n",
" context_vec = context_vec.transpose(1, 2).contiguous().view(batch_size, num_tokens, self.d_out)\n",
"\n",
" context_vec = self.proj(context_vec)\n",
"\n",
" return context_vec"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "9cdaaf8a-f956-44bc-932f-4d33448e8aaf",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "9cdaaf8a-f956-44bc-932f-4d33448e8aaf",
"outputId": "0d13771e-46df-4200-e20d-422fcb8144b3"
},
"outputs": [],
"source": [
"if current_version >= required_version and torch.cuda.is_available():\n",
"\n",
" mha_pytorch_flex = MHAPyTorchFlexAttention(\n",
" d_in=embed_dim,\n",
" d_out=embed_dim,\n",
" context_length=context_len,\n",
" dropout=0.0,\n",
" num_heads=12,\n",
" qkv_bias=False\n",
" ).to(device)\n",
"\n",
" out = mha_pytorch_flex(embeddings)\n",
" print(out.shape)"
]
},
{
"cell_type": "markdown",
"id": "8877de71-f84f-4f6d-bc87-7552013b6301",
"metadata": {
"id": "8877de71-f84f-4f6d-bc87-7552013b6301"
},
"source": [
"<br>\n",
" \n",
"\n",
"## Quick speed comparison (M3 MacBook Air CPU)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "219cf93a-078f-434d-888c-2458d0731285",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "219cf93a-078f-434d-888c-2458d0731285",
"outputId": "a10b52d4-b4e6-43c2-9677-113c41edd3b7"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"PyTorch version: 2.5.0.dev20240813\n",
"Running on cpu\n"
]
}
],
"source": [
"torch.manual_seed(123)\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"print(f\"PyTorch version: {torch.__version__}\")\n",
"print(f\"Running on {device}\")"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "a97c0b2e-6593-49d8-98bc-2267b3aa610f",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "a97c0b2e-6593-49d8-98bc-2267b3aa610f",
"outputId": "7bcd7da4-d115-4ba6-efba-377a0bd7d3a8"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"183 ms ± 6.95 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
"source": [
"## 1) CausalAttention MHA wrapper class from chapter 3\n",
"%timeit mha_ch03_wrapper(embeddings)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "19db9c2c-8e75-431a-8eef-0b4d8284e6e6",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "19db9c2c-8e75-431a-8eef-0b4d8284e6e6",
"outputId": "b04b4d0d-71aa-4944-f02b-131bf5a50202"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"179 ms ± 1.22 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
]
}
],
"source": [
"## 2) The multi-head attention class from chapter 3\n",
"%timeit mha_ch03(embeddings)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "aa526ee0-7a88-4f34-a49a-f8f97da83779",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "aa526ee0-7a88-4f34-a49a-f8f97da83779",
"outputId": "5436928a-7b98-4c40-bf51-97973f13327e"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"197 ms ± 1.78 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
"source": [
"## 3) An alternative multi-head attention with combined weights\n",
"%timeit mha_combined_qkv(embeddings)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "cc2b4256-16d8-4c34-9fd0-d4b4af0e60fa",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "cc2b4256-16d8-4c34-9fd0-d4b4af0e60fa",
"outputId": "9e07ce73-a2de-4e2c-8276-64626df9450e"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"65.4 ms ± 315 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
]
}
],
"source": [
"## 4) Multihead attention with PyTorch's scaled dot product attention\n",
"%timeit mha_pytorch_scaled(embeddings)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "c44305ce-9f61-451a-b9ef-30caba222357",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "c44305ce-9f61-451a-b9ef-30caba222357",
"outputId": "6bab4a24-5bb4-4ad6-b260-3b442f598950"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"112 ms ± 20.2 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
2024-08-10 09:44:11 -05:00
]
}
],
"source": [
"## 5) PyTorch's scaled dot product attention without FlashAttention\n",
"%timeit mha_pytorch_sdpa_no_flash(embeddings)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "0f209e70-ebb6-4a1a-b608-1ff42e41c01d",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "0f209e70-ebb6-4a1a-b608-1ff42e41c01d",
"outputId": "630c49d1-8a06-4148-cd96-a7b2467310a0"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"199 ms ± 3.87 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
"source": [
"## 6) Using PyTorch's torch.nn.MultiheadAttention\n",
"%timeit mha_pytorch_class_default(embeddings)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "3f4968c2-8d40-4ab9-8dba-052b4f77d756",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "3f4968c2-8d40-4ab9-8dba-052b4f77d756",
"outputId": "10f6a268-f9cf-446c-aa83-e87b6a0b4f5c"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"143 ms ± 19.7 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
]
}
],
"source": [
"## 7) Using PyTorch's torch.nn.MultiheadAttention disabling `need_weights`\n",
"%timeit mha_pytorch_class_noweights(embeddings)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c143200a-6e26-4185-af52-6cf2379d72ae",
"metadata": {},
"outputs": [],
"source": [
"## 8) Using PyTorch's FlexAttention\n",
"\n",
"# Requires PyTorch 2.5.0 or newer and currently only supports CUDA PyTorch\n",
"%timeit mha_pytorch_flex(embeddings)"
]
},
{
"cell_type": "markdown",
"id": "a78ff594-6cc2-496d-a302-789fa104c3c9",
"metadata": {
"id": "a78ff594-6cc2-496d-a302-789fa104c3c9"
},
"source": [
"<br>\n",
" \n",
"\n",
"## Quick speed comparison (Nvidia A100 GPU)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "RStnI1pEi6Eo",
"metadata": {
"id": "RStnI1pEi6Eo"
},
"outputs": [],
"source": [
"# Enable tensor cores\n",
"torch.set_float32_matmul_precision(\"high\")"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "e8431d75-e1c9-4d9a-b7da-9a1ff391f2bf",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "e8431d75-e1c9-4d9a-b7da-9a1ff391f2bf",
"outputId": "308f8f74-1757-4534-9050-78676bb948a5"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"PyTorch version: 2.5.0.dev20240813+cu121\n",
"Running on cuda\n"
]
}
],
"source": [
"torch.manual_seed(123)\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"print(f\"PyTorch version: {torch.__version__}\")\n",
"print(f\"Running on {device}\")"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "707a2a14-a089-48a8-88aa-d328e1e0a9d0",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "707a2a14-a089-48a8-88aa-d328e1e0a9d0",
"outputId": "38942aa6-8bd7-4dd9-e528-b833231999d0"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"4.28 ms ± 26.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
]
}
],
"source": [
"## 1) CausalAttention MHA wrapper class from chapter 3\n",
"%timeit mha_ch03_wrapper(embeddings)"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "8686dd69-3655-40e4-a57b-a2c55532a010",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "8686dd69-3655-40e4-a57b-a2c55532a010",
"outputId": "0f8adeba-2119-4361-a439-e48b1e870857"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"3.09 ms ± 108 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n"
]
}
],
"source": [
"## 2) The multi-head attention class from chapter 3\n",
"%timeit mha_ch03(embeddings)"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "2209d7df-e54b-4910-ae2b-c78cf684d9bf",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "2209d7df-e54b-4910-ae2b-c78cf684d9bf",
"outputId": "3ce8ddef-df53-4ead-8654-d46f4f343844"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"3.81 ms ± 4.91 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
]
}
],
"source": [
"## 3) An alternative multi-head attention with combined weights\n",
"%timeit mha_combined_qkv(embeddings)"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "1075abe2-4839-4fd6-af3e-c09bb3651e26",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "1075abe2-4839-4fd6-af3e-c09bb3651e26",
"outputId": "e4c72e71-1275-4ce5-9973-2dfde98fc626"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1.24 ms ± 1.14 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n"
]
}
],
"source": [
"## 4) Multi-head attention with PyTorch's scaled dot product attention\n",
"%timeit mha_pytorch_scaled(embeddings)"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "218adbaf-f17f-47d9-81d5-41c758218df7",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "218adbaf-f17f-47d9-81d5-41c758218df7",
"outputId": "9251f2a9-aaa9-4403-87bd-b24b0a6e8be4"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2.01 ms ± 120 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
]
}
],
"source": [
"## 5) PyTorch's scaled dot product attention without FlashAttention\n",
"%timeit mha_pytorch_sdpa_no_flash(embeddings)"
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "868e3670-8edc-47bc-9e06-eb505e44dc9d",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "868e3670-8edc-47bc-9e06-eb505e44dc9d",
"outputId": "d951ed94-1fef-4d63-935e-19ee8e007f71"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"3.06 ms ± 228 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
]
}
],
"source": [
"## 6) Using PyTorch's torch.nn.MultiheadAttention\n",
"%timeit mha_pytorch_class_default(embeddings)"
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "944870e6-de54-4e3b-a455-b8f21f6f92c8",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "944870e6-de54-4e3b-a455-b8f21f6f92c8",
"outputId": "8c69a388-ec32-49d4-a59e-cac4af1aa37a"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2.34 ms ± 6.28 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
]
}
],
"source": [
"## 7) Using PyTorch's torch.nn.MultiheadAttention disabling `need_weights`\n",
"%timeit mha_pytorch_class_noweights(embeddings)"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "evKtpb5QN_2A",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "evKtpb5QN_2A",
"outputId": "38702248-46c0-4dc6-e584-96617f2a9650"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"8.38 ms ± 476 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
"source": [
"## 8) Using PyTorch's FlexAttention\n",
"\n",
"# Requires PyTorch 2.5.0 or newer\n",
"%timeit mha_pytorch_flex(embeddings)"
]
},
{
"cell_type": "markdown",
"id": "dabc6575-0316-4640-a729-e616d5c17b73",
"metadata": {
"id": "dabc6575-0316-4640-a729-e616d5c17b73"
},
"source": [
"<br>\n",
" \n",
"\n",
"\n",
"# Visualizations"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "bbb2f729-d3d8-46d0-b249-9249197ea574",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "bbb2f729-d3d8-46d0-b249-9249197ea574",
"outputId": "fcf22728-5570-4edf-9c9d-4bbf0e1e1413"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"PyTorch version: 2.5.0.dev20240813+cu121\n",
"Running on cuda\n"
]
}
],
"source": [
"torch.manual_seed(123)\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"print(f\"PyTorch version: {torch.__version__}\")\n",
"print(f\"Running on {device}\")"
]
},
{
"cell_type": "code",
"execution_count": 35,
"id": "b0620bf5",
"metadata": {
"id": "b0620bf5"
},
"outputs": [],
"source": [
"functions = {\n",
" \"1) MHA wrapper class\": mha_ch03_wrapper,\n",
" \"2) MHA Ch03\": mha_ch03,\n",
" \"3) MHA with combined QKV weights\": mha_combined_qkv,\n",
" \"4) MHA with PyTorch scaled_dot_product_attention\": mha_pytorch_scaled,\n",
" \"5) PyTorch's SDPA, no FlashAttention\": mha_pytorch_sdpa_no_flash,\n",
" \"6) PyTorch MHA class defaults\": mha_pytorch_class_default,\n",
" \"7) PyTorch MHA with need_weights=False\": mha_pytorch_class_noweights\n",
" }\n",
"\n",
"if current_version >= required_version:\n",
" functions[\"8) PyTorch's FlexAttention\"] = mha_pytorch_flex"
]
},
{
"cell_type": "code",
"execution_count": 36,
"id": "CDJAPZaszaqx",
"metadata": {
"id": "CDJAPZaszaqx"
},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"# Customize further for dark mode aesthetics\n",
"plt.rcParams[\"figure.facecolor\"] = \"#121212\"\n",
"plt.rcParams[\"axes.facecolor\"] = \"#121212\"\n",
"plt.rcParams[\"axes.edgecolor\"] = \"white\"\n",
"plt.rcParams[\"axes.labelcolor\"] = \"white\"\n",
"plt.rcParams[\"text.color\"] = \"white\"\n",
"plt.rcParams[\"xtick.color\"] = \"white\"\n",
"plt.rcParams[\"ytick.color\"] = \"white\"\n",
"plt.rcParams[\"grid.color\"] = \"#444444\"\n",
"plt.rcParams[\"lines.linewidth\"] = 2\n",
"plt.rcParams[\"lines.markersize\"] = 8\n",
"\n",
"def plot_execution_times(functions, execution_means, execution_stds, filename):\n",
"\n",
" # Create plot\n",
" fig, ax = plt.subplots()\n",
" bars = ax.bar(functions.keys(), execution_means, yerr=execution_stds, capsize=5, error_kw={'ecolor': 'grey'})\n",
"\n",
" plt.ylabel(\"Execution time (ms)\")\n",
" plt.xticks(rotation=45, ha=\"right\")\n",
"\n",
" # Calculate new ylim with a margin\n",
" max_execution_time = max(execution_means)\n",
" upper_ylim = max_execution_time + 0.4 * max_execution_time # Adding a 40% margin\n",
" plt.ylim(0, upper_ylim)\n",
"\n",
" # Annotate bars with execution times\n",
" for bar in bars:\n",
" yval = bar.get_height()\n",
" plt.text(bar.get_x() + bar.get_width()/2, yval + (0.05 * upper_ylim), round(yval, 2), ha=\"center\", va=\"bottom\")\n",
"\n",
" plt.tight_layout()\n",
" plt.savefig(filename)\n",
" plt.show()"
]
},
{
"cell_type": "markdown",
"id": "4df834dc",
"metadata": {
"id": "4df834dc"
},
"source": [
"## Speed comparison (Nvidia A100 GPU) with warmup (forward pass only)"
]
},
{
"cell_type": "code",
"execution_count": 37,
"id": "29b63d3d-6d0b-43bb-9c68-d5514dc81000",
"metadata": {
"id": "29b63d3d-6d0b-43bb-9c68-d5514dc81000"
},
"outputs": [],
"source": [
"# CUDA benchmark code shared by Andrei Aksionov\n",
"# and based on code from\n",
"# https://github.com/cuda-mode/lectures/blob/main/lecture1/pytorch_square.py\n",
"\n",
"import numpy as np\n",
"\n",
"def time_pytorch_function(func, *input, num_repeats=1_000):\n",
" start = torch.cuda.Event(enable_timing=True)\n",
" end = torch.cuda.Event(enable_timing=True)\n",
"\n",
" # Warmup\n",
" for _ in range(5):\n",
" func(*input)\n",
" torch.cuda.synchronize()\n",
"\n",
" times = []\n",
" for _ in range(num_repeats):\n",
" start.record()\n",
" func(*input)\n",
" end.record()\n",
" torch.cuda.synchronize()\n",
" times.append(start.elapsed_time(end))\n",
"\n",
" return np.mean(times), np.std(times)"
]
},
{
"cell_type": "code",
"execution_count": 38,
"id": "9dd07a09",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 487
},
"id": "9dd07a09",
"outputId": "b8db6c34-b593-45a6-9713-81107860d4c7"
},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"execution_stats = [time_pytorch_function(fn, embeddings) for fn in functions.values()]\n",
"execution_means = [stat[0] for stat in execution_stats]\n",
"execution_stds = [stat[1] for stat in execution_stats]\n",
"\n",
"\n",
"plot_execution_times(functions, execution_means, execution_stds, filename=\"1_forward-only.pdf\")"
]
},
{
"cell_type": "markdown",
"id": "VQaSerWCOnYB",
"metadata": {
"id": "VQaSerWCOnYB"
},
"source": [
"<br>\n",
" \n",
"\n",
"\n",
"## Speed comparison (Nvidia A100 GPU) with warmup (forward and backward pass)"
]
},
{
"cell_type": "code",
"execution_count": 39,
"id": "69e6377b",
"metadata": {
"id": "69e6377b"
},
"outputs": [],
"source": [
"def forward_backward(func, embeddings):\n",
" if embeddings.grad is not None:\n",
" embeddings.grad.zero_()\n",
"\n",
" output = func(embeddings)\n",
" loss = output.sum()\n",
" loss.backward()\n",
"\n",
"\n",
"def time_pytorch_function_forward_backward(func, *input, num_repeats=1_000):\n",
" # CUDA is asynchronous, so we can't time it with Python's time module\n",
" start = torch.cuda.Event(enable_timing=True)\n",
" end = torch.cuda.Event(enable_timing=True)\n",
"\n",
" # Warmup\n",
" for _ in range(5):\n",
" forward_backward(func, *input)\n",
" torch.cuda.synchronize()\n",
"\n",
" times = []\n",
" for _ in range(num_repeats):\n",
" start.record()\n",
" forward_backward(func, *input)\n",
" end.record()\n",
" torch.cuda.synchronize()\n",
" times.append(start.elapsed_time(end))\n",
"\n",
" return np.mean(times), np.std(times)"
]
},
{
"cell_type": "code",
"execution_count": 40,
"id": "ReCmeRhCOpm8",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 487
},
"id": "ReCmeRhCOpm8",
"outputId": "01159c54-afac-4a0d-a06f-b41cad1b30e6"
},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"execution_stats = [time_pytorch_function_forward_backward(fn, embeddings) for fn in functions.values()]\n",
"execution_means = [stat[0] for stat in execution_stats]\n",
"execution_stds = [stat[1] for stat in execution_stats]\n",
"\n",
"\n",
"plot_execution_times(functions, execution_means, execution_stds, filename=\"2_forward-and-backward.pdf\")"
]
},
{
"cell_type": "markdown",
"id": "1gWX-Ayqia1k",
"metadata": {
"id": "1gWX-Ayqia1k"
},
"source": [
"<br>\n",
" \n",
"\n",
"\n",
"## Speed comparison (Nvidia A100 GPU) with warmup and compilation (forward and backward pass)"
]
},
{
"cell_type": "code",
"execution_count": 41,
"id": "LQDiAPooiYAz",
"metadata": {
"id": "LQDiAPooiYAz"
},
"outputs": [],
"source": [
"import torch._dynamo\n",
"torch._dynamo.config.suppress_errors = True\n",
"\n",
"def prepare_function(fn):\n",
" fn = torch.compile(fn)\n",
" return fn"
]
},
{
"cell_type": "code",
"execution_count": 42,
"id": "aac06ffe",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 488
},
"id": "aac06ffe",
"outputId": "f64e3437-487b-45d2-d080-43c98464c47e"
},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"execution_stats = [time_pytorch_function_forward_backward(prepare_function(fn), embeddings) for fn in functions.values()]\n",
"execution_means = [stat[0] for stat in execution_stats]\n",
"execution_stds = [stat[1] for stat in execution_stats]\n",
"\n",
"\n",
"plot_execution_times(functions, execution_means, execution_stds, filename=\"3_forward-and-backward-compiled.pdf\")"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuType": "A100",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 5
}