{
"cells": [
{
"cell_type": "markdown",
"id": "e2e65c03-36d4-413f-9b23-5cdd816729ab",
"metadata": {
"id": "e2e65c03-36d4-413f-9b23-5cdd816729ab"
},
"source": [
"<table style=\"width:100%\">\n",
"<tr>\n",
"<td style=\"vertical-align:middle; text-align:left;\">\n",
"<font size=\"2\">\n",
"Supplementary code for the <a href=\"http://mng.bz/orYv\">Build a Large Language Model From Scratch</a> book by <a href=\"https://sebastianraschka.com\">Sebastian Raschka</a><br>\n",
"<br>Code repository: <a href=\"https://github.com/rasbt/LLMs-from-scratch\">https://github.com/rasbt/LLMs-from-scratch</a>\n",
"</font>\n",
"</td>\n",
"<td style=\"vertical-align:middle; text-align:left;\">\n",
"<a href=\"http://mng.bz/orYv\"><img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/cover-small.webp\" width=\"100px\"></a>\n",
"</td>\n",
"</tr>\n",
"</table>"
]
},
{
"cell_type": "markdown",
"id": "6f678e62-7bcb-4405-86ae-dce94f494303",
"metadata": {
"id": "6f678e62-7bcb-4405-86ae-dce94f494303"
},
"source": [
"# Comparing Efficient Multi-Head Attention Implementations"
]
},
{
"cell_type": "markdown",
"id": "b742938a-4bfc-4527-a1f1-d5963508967d",
"metadata": {
"id": "b742938a-4bfc-4527-a1f1-d5963508967d"
},
"source": [
"This code notebook compares different ways to implement causal multi-head attention used in decoder-style LLMs like GPT, Llama, etc."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "7898551e-f582-48ac-9f66-3632abe2a93f",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "7898551e-f582-48ac-9f66-3632abe2a93f",
"outputId": "1dcdc621-7d0b-41e3-eac8-0f5a768e1bed"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"PyTorch version: 2.6.0+cu124\n"
]
}
],
"source": [
"import torch\n",
"\n",
"torch.manual_seed(123)\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"print(f\"PyTorch version: {torch.__version__}\")\n",
"\n",
"batch_size = 8\n",
"context_len = 1024\n",
"embed_dim = 768\n",
"embeddings = torch.randn((batch_size, context_len, embed_dim), device=device)"
]
},
{
"cell_type": "markdown",
"id": "LYLcq3403Yq6",
"metadata": {
"id": "LYLcq3403Yq6"
},
"source": [
"- To run all the code in this notebook, please ensure you update to at least PyTorch 2.5 (FlexAttention is not included in earlier PyTorch releases)\n",
"- If the code cell above shows a PyTorch version lower than 2.5, you can upgrade your PyTorch installation by uncommenting and running the following code cell (Please note that PyTorch 2.5 requires Python 3.9 or later)\n",
"- For more specific instructions and CUDA versions, please refer to the official installation guide at https://pytorch.org"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1db27f43-86f4-478f-89df-fbc2182a129b",
"metadata": {
"id": "1db27f43-86f4-478f-89df-fbc2182a129b"
},
"outputs": [],
"source": [
"# pip install --upgrade torch torchvision torchaudio"
]
},
{
"cell_type": "markdown",
"id": "2f9bb1b6-a1e5-4e0a-884d-0f31b374a8d6",
"metadata": {
"id": "2f9bb1b6-a1e5-4e0a-884d-0f31b374a8d6"
},
"source": [
"<br>\n",
"&nbsp;\n",
"\n",
"## 1) CausalAttention MHA wrapper class from chapter 3"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "297c93ed-aec0-4896-bb89-42c4b294d3d1",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "297c93ed-aec0-4896-bb89-42c4b294d3d1",
"outputId": "9d02508e-106d-4a13-9bd6-0941cc7c5d36"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([8, 1024, 768])\n"
]
}
],
"source": [
"import torch.nn as nn\n",
"\n",
"class CausalAttention(nn.Module):\n",
"\n",
" def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):\n",
" super().__init__()\n",
" self.d_out = d_out\n",
" self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
" self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
" self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
" self.dropout = nn.Dropout(dropout) # New\n",
" self.register_buffer(\"mask\", torch.triu(torch.ones(context_length, context_length), diagonal=1)) # New\n",
"\n",
" def forward(self, x):\n",
" b, num_tokens, d_in = x.shape # New batch dimension b\n",
" keys = self.W_key(x)\n",
" queries = self.W_query(x)\n",
" values = self.W_value(x)\n",
"\n",
" attn_scores = queries @ keys.transpose(1, 2) # Changed transpose\n",
" attn_scores.masked_fill_( # New, _ ops are in-place\n",
" self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)\n",
" attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n",
" attn_weights = self.dropout(attn_weights) # New\n",
"\n",
" context_vec = attn_weights @ values\n",
" return context_vec\n",
"\n",
"\n",
"class Ch03_MHA_Wrapper(nn.Module):\n",
"\n",
" def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):\n",
" super().__init__()\n",
" self.heads = nn.ModuleList(\n",
" [CausalAttention(d_in, d_out, context_length, dropout, qkv_bias)\n",
" for _ in range(num_heads)]\n",
" )\n",
" self.out_proj = nn.Linear(d_out*num_heads, d_out*num_heads)\n",
"\n",
" def forward(self, x):\n",
" context_vec = torch.cat([head(x) for head in self.heads], dim=-1)\n",
" return self.out_proj(context_vec)\n",
"\n",
"\n",
"mha_ch03_wrapper = Ch03_MHA_Wrapper(\n",
" d_in=embed_dim,\n",
" d_out=embed_dim//12,\n",
" context_length=context_len,\n",
" dropout=0.0,\n",
" num_heads=12,\n",
" qkv_bias=False\n",
").to(device)\n",
"\n",
"out = mha_ch03_wrapper(embeddings)\n",
"print(out.shape)"
]
},
{
"cell_type": "markdown",
"id": "21930804-b327-40b1-8e63-94dcad39ce7b",
"metadata": {
"id": "21930804-b327-40b1-8e63-94dcad39ce7b"
},
"source": [
"<br>\n",
"&nbsp;\n",
"\n",
"## 2) The multi-head attention class from chapter 3"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "4ee6a61b-d25c-4a0c-8a59-f285544e3710",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "4ee6a61b-d25c-4a0c-8a59-f285544e3710",
"outputId": "7469c10e-58e4-4b98-f5fd-ffdab4a2ef6b"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([8, 1024, 768])\n"
]
}
],
"source": [
"class Ch03_MHA(nn.Module):\n",
" def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):\n",
" super().__init__()\n",
" assert d_out % num_heads == 0, \"d_out must be divisible by num_heads\"\n",
"\n",
" self.d_out = d_out\n",
" self.num_heads = num_heads\n",
" self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim\n",
"\n",
" self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
" self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
" self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
" self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs\n",
" self.dropout = nn.Dropout(dropout)\n",
" self.register_buffer(\"mask\", torch.triu(torch.ones(context_length, context_length), diagonal=1))\n",
"\n",
" def forward(self, x):\n",
" b, num_tokens, d_in = x.shape\n",
"\n",
" keys = self.W_key(x) # Shape: (b, num_tokens, d_out)\n",
" queries = self.W_query(x)\n",
" values = self.W_value(x)\n",
"\n",
" # We implicitly split the matrix by adding a `num_heads` dimension\n",
" # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)\n",
" keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)\n",
" values = values.view(b, num_tokens, self.num_heads, self.head_dim)\n",
" queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)\n",
"\n",
" # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)\n",
" keys = keys.transpose(1, 2)\n",
" queries = queries.transpose(1, 2)\n",
" values = values.transpose(1, 2)\n",
"\n",
" # Compute scaled dot-product attention (aka self-attention) with a causal mask\n",
" attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head\n",
"\n",
" # Original mask truncated to the number of tokens and converted to boolean\n",
" mask_bool = self.mask.bool()[:num_tokens, :num_tokens]\n",
"\n",
" # Use the mask to fill attention scores\n",
" attn_scores.masked_fill_(mask_bool, -torch.inf)\n",
"\n",
" attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n",
" attn_weights = self.dropout(attn_weights)\n",
"\n",
" # Shape: (b, num_tokens, num_heads, head_dim)\n",
" context_vec = (attn_weights @ values).transpose(1, 2)\n",
"\n",
" # Combine heads, where self.d_out = self.num_heads * self.head_dim\n",
" context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)\n",
" context_vec = self.out_proj(context_vec) # optional projection\n",
"\n",
" return context_vec\n",
"\n",
"\n",
"mha_ch03 = Ch03_MHA(\n",
" d_in=embed_dim,\n",
" d_out=embed_dim,\n",
" context_length=context_len,\n",
" dropout=0.0,\n",
" num_heads=12,\n",
" qkv_bias=False\n",
").to(device)\n",
"\n",
"out = mha_ch03(embeddings)\n",
"print(out.shape)"
]
},
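{
"cell_type": "markdown",
"id": "mha-param-count-note",
"metadata": {
"id": "mha-param-count-note"
},
"source": [
"- As a quick sanity check (an added illustration, not part of the chapter code), the cell below counts the trainable parameters of both implementations; with the settings above, the wrapper's twelve single-head modules plus its output projection should add up to the same total as the integrated class"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "mha-param-count-check",
"metadata": {
"id": "mha-param-count-check"
},
"outputs": [],
"source": [
"def count_parameters(model):\n",
"    # Sum the element counts of all trainable parameter tensors\n",
"    return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
"\n",
"print(\"1) MHA wrapper:   \", count_parameters(mha_ch03_wrapper))\n",
"print(\"2) MHA integrated:\", count_parameters(mha_ch03))"
]
},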
{
"cell_type": "markdown",
"id": "73cd11da-ea3b-4081-b483-c4965dfefbc4",
"metadata": {
"id": "73cd11da-ea3b-4081-b483-c4965dfefbc4"
},
"source": [
"<br>\n",
"&nbsp;\n",
"\n",
"## 3) An alternative multi-head attention with combined weights"
]
},
{
"cell_type": "markdown",
"id": "1fa1a5ea-eaff-4d2d-aaf0-b34cdb6fd4dd",
"metadata": {
"id": "1fa1a5ea-eaff-4d2d-aaf0-b34cdb6fd4dd"
},
"source": [
"- The code for the `MultiHeadAttentionCombinedQKV` class below is based on code that was kindly shared by [Rayed Bin Wahed](https://github.com/rasbt/LLMs-from-scratch/discussions/51)\n",
"- The main difference between the `MultiHeadAttentionCombinedQKV` class and the `MultiHeadAttention` class used in chapter 3 is that `MultiHeadAttentionCombinedQKV` uses a single weight matrix, `self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)` instead of separate weight matrices:\n",
"\n",
" - `self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)`\n",
" - `self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)`\n",
" - `self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)`\n",
"\n",
"- Here, `self.qkv` combines all three weight matrices `self.W_query`, `self.W_key`, and `self.W_value` to carry out the query, key, and value computation in a single step\n",
"- Using `q, k, v = qkv.unbind(0)`, we obtain the individual query, key, and value tensors, which are then used similarly to the query, key, and value tensors in the `MultiHeadAttention` class in chapter 3"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "9a6bd0a2-f27c-4602-afa0-c96cd295c1a6",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "9a6bd0a2-f27c-4602-afa0-c96cd295c1a6",
"outputId": "6ced0e41-958e-43af-ae3e-17b62148c1cd"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([8, 1024, 768])\n"
]
}
],
"source": [
"import torch.nn as nn\n",
"\n",
"\n",
"class MultiHeadAttentionCombinedQKV(nn.Module):\n",
" def __init__(self, d_in, d_out, num_heads, context_length, dropout=0.0, qkv_bias=False):\n",
" super().__init__()\n",
"\n",
" assert d_out % num_heads == 0, \"d_out is indivisible by num_heads\"\n",
"\n",
" self.num_heads = num_heads\n",
" self.context_length = context_length\n",
" self.head_dim = d_out // num_heads\n",
"\n",
" self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)\n",
" self.proj = nn.Linear(d_out, d_out)\n",
" self.dropout = nn.Dropout(dropout)\n",
"\n",
" self.register_buffer(\n",
" \"mask\", torch.triu(torch.ones(context_length, context_length), diagonal=1)\n",
" )\n",
"\n",
" def forward(self, x):\n",
" batch_size, num_tokens, embed_dim = x.shape\n",
"\n",
" # (b, num_tokens, embed_dim) --> (b, num_tokens, 3 * embed_dim)\n",
" qkv = self.qkv(x)\n",
"\n",
" # (b, num_tokens, 3 * embed_dim) --> (b, num_tokens, 3, num_heads, head_dim)\n",
" qkv = qkv.view(batch_size, num_tokens, 3, self.num_heads, self.head_dim)\n",
"\n",
" # (b, num_tokens, 3, num_heads, head_dim) --> (3, b, num_heads, num_tokens, head_dim)\n",
" qkv = qkv.permute(2, 0, 3, 1, 4)\n",
"\n",
" # (3, b, num_heads, num_tokens, head_dim) -> 3 times (b, num_head, num_tokens, head_dim)\n",
" queries, keys, values = qkv.unbind(0)\n",
"\n",
" # (b, num_heads, num_tokens, head_dim) --> (b, num_heads, num_tokens, num_tokens)\n",
" attn_scores = queries @ keys.transpose(-2, -1)\n",
" attn_scores = attn_scores.masked_fill(\n",
" self.mask.bool()[:num_tokens, :num_tokens], -torch.inf\n",
" )\n",
"\n",
" attn_weights = torch.softmax(attn_scores / keys.shape[-1]**-0.5, dim=-1)\n",
" attn_weights = self.dropout(attn_weights)\n",
"\n",
" # (b, num_heads, num_tokens, num_tokens) --> (b, num_heads, num_tokens, head_dim)\n",
" context_vec = attn_weights @ values\n",
"\n",
" # (b, num_heads, num_tokens, head_dim) --> (b, num_tokens, num_heads, head_dim)\n",
" context_vec = context_vec.transpose(1, 2)\n",
"\n",
" # (b, num_tokens, num_heads, head_dim) --> (b, num_tokens, embed_dim)\n",
" context_vec = context_vec.contiguous().view(batch_size, num_tokens, embed_dim)\n",
"\n",
" context_vec = self.proj(context_vec)\n",
"\n",
" return context_vec\n",
"\n",
"\n",
"mha_combined_qkv = MultiHeadAttentionCombinedQKV(\n",
" d_in=embed_dim,\n",
" d_out=embed_dim,\n",
" context_length=context_len,\n",
" dropout=0.0,\n",
" num_heads=12,\n",
" qkv_bias=False\n",
").to(device)\n",
"\n",
"out = mha_combined_qkv(embeddings)\n",
"print(out.shape)"
]
},
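{
"cell_type": "markdown",
"id": "combined-qkv-equivalence-note",
"metadata": {
"id": "combined-qkv-equivalence-note"
},
"source": [
"- A minimal sketch (illustration only, with made-up layer sizes): stacking the three separate weight matrices row-wise into one combined layer and then splitting its output reproduces the separate query, key, and value projections"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "combined-qkv-equivalence-check",
"metadata": {
"id": "combined-qkv-equivalence-check"
},
"outputs": [],
"source": [
"torch.manual_seed(123)\n",
"\n",
"d = 16\n",
"x = torch.randn(2, 4, d)\n",
"\n",
"# Three separate projections vs. one combined projection\n",
"W_query = nn.Linear(d, d, bias=False)\n",
"W_key = nn.Linear(d, d, bias=False)\n",
"W_value = nn.Linear(d, d, bias=False)\n",
"qkv = nn.Linear(d, 3 * d, bias=False)\n",
"\n",
"# Stack the separate weights row-wise into the combined weight matrix\n",
"with torch.no_grad():\n",
"    qkv.weight.copy_(torch.cat([W_query.weight, W_key.weight, W_value.weight], dim=0))\n",
"\n",
"# Splitting the combined output recovers the three separate projections\n",
"q, k, v = qkv(x).chunk(3, dim=-1)\n",
"print(torch.allclose(q, W_query(x)), torch.allclose(k, W_key(x)), torch.allclose(v, W_value(x)))"
]
},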
{
"cell_type": "markdown",
"id": "9b14390d-3e21-43fd-87be-43e7029163ee",
"metadata": {
"id": "9b14390d-3e21-43fd-87be-43e7029163ee"
},
"source": [
"<br>\n",
"&nbsp;\n",
"\n",
"## 4) Multi-head attention with Einsum\n",
"\n",
"- Implementing multi-head attention using Einstein summation via [`torch.einsum`](https://pytorch.org/docs/stable/generated/torch.einsum.html)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "92481814-068d-439b-a65c-b1310ebbe0aa",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "92481814-068d-439b-a65c-b1310ebbe0aa",
"outputId": "f46b111d-3563-4e5c-da2a-7be156974f5e"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([8, 1024, 768])\n"
]
}
],
"source": [
"import math\n",
"\n",
"\n",
"class MHAEinsum(nn.Module):\n",
"\n",
" def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):\n",
" super().__init__()\n",
" assert d_out % num_heads == 0, \"d_out must be divisible by num_heads\"\n",
"\n",
" self.d_out = d_out\n",
" self.num_heads = num_heads\n",
" self.head_dim = d_out // num_heads\n",
"\n",
" # Initialize parameters for Q, K, V\n",
" self.W_query = nn.Parameter(torch.randn(d_out, d_in))\n",
" self.W_key = nn.Parameter(torch.randn(d_out, d_in))\n",
" self.W_value = nn.Parameter(torch.randn(d_out, d_in))\n",
"\n",
" if qkv_bias:\n",
" self.bias_q = nn.Parameter(torch.zeros(d_out))\n",
" self.bias_k = nn.Parameter(torch.zeros(d_out))\n",
" self.bias_v = nn.Parameter(torch.zeros(d_out))\n",
" else:\n",
" self.register_parameter(\"bias_q\", None)\n",
" self.register_parameter(\"bias_k\", None)\n",
" self.register_parameter(\"bias_v\", None)\n",
"\n",
" self.out_proj = nn.Linear(d_out, d_out)\n",
" self.dropout = nn.Dropout(dropout)\n",
" self.register_buffer(\"mask\", torch.triu(torch.ones(context_length, context_length), diagonal=1))\n",
"\n",
" # Initialize parameters\n",
" self.reset_parameters()\n",
"\n",
"\n",
" def reset_parameters(self):\n",
" nn.init.kaiming_uniform_(self.W_query, a=math.sqrt(5))\n",
" nn.init.kaiming_uniform_(self.W_key, a=math.sqrt(5))\n",
" nn.init.kaiming_uniform_(self.W_value, a=math.sqrt(5))\n",
" if self.bias_q is not None:\n",
" fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.W_query)\n",
" bound = 1 / math.sqrt(fan_in)\n",
" nn.init.uniform_(self.bias_q, -bound, bound)\n",
" nn.init.uniform_(self.bias_k, -bound, bound)\n",
" nn.init.uniform_(self.bias_v, -bound, bound)\n",
"\n",
" def forward(self, x):\n",
" b, n, _ = x.shape\n",
"\n",
" # Calculate Q, K, V using einsum, first perform linear transformations\n",
" Q = torch.einsum(\"bnd,di->bni\", x, self.W_query)\n",
" K = torch.einsum(\"bnd,di->bni\", x, self.W_key)\n",
" V = torch.einsum(\"bnd,di->bni\", x, self.W_value)\n",
"\n",
" # Add biases if they are used\n",
" if self.bias_q is not None:\n",
" Q += self.bias_q\n",
" K += self.bias_k\n",
" V += self.bias_v\n",
"\n",
" # Reshape for multi-head attention\n",
" Q = Q.view(b, n, self.num_heads, self.head_dim).transpose(1, 2)\n",
" K = K.view(b, n, self.num_heads, self.head_dim).transpose(1, 2)\n",
" V = V.view(b, n, self.num_heads, self.head_dim).transpose(1, 2)\n",
"\n",
" # Scaled dot-product attention\n",
" scores = torch.einsum(\"bhnd,bhmd->bhnm\", Q, K) / (self.head_dim ** 0.5)\n",
"\n",
" # Apply mask\n",
" mask = self.mask[:n, :n].unsqueeze(0).unsqueeze(1).expand(b, self.num_heads, n, n)\n",
" scores = scores.masked_fill(mask.bool(), -torch.inf)\n",
"\n",
" # Softmax and dropout\n",
" attn_weights = torch.softmax(scores, dim=-1)\n",
" attn_weights = self.dropout(attn_weights)\n",
"\n",
" # Aggregate the attended context vectors\n",
" context_vec = torch.einsum(\"bhnm,bhmd->bhnd\", attn_weights, V)\n",
"\n",
" # Combine heads and project the output\n",
" context_vec = context_vec.transpose(1, 2).reshape(b, n, self.d_out)\n",
" context_vec = self.out_proj(context_vec)\n",
"\n",
" return context_vec\n",
"\n",
"\n",
"mha_einsum = MHAEinsum(\n",
" d_in=embed_dim,\n",
" d_out=embed_dim,\n",
" context_length=context_len,\n",
" dropout=0.0,\n",
" num_heads=12,\n",
" qkv_bias=False\n",
").to(device)\n",
"\n",
"out = mha_einsum(embeddings)\n",
"print(out.shape)"
]
},
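{
"cell_type": "markdown",
"id": "einsum-correspondence-note",
"metadata": {
"id": "einsum-correspondence-note"
},
"source": [
"- The einsum strings above are ordinary matrix multiplications written with named indices; the short check below (illustration only, with made-up shapes) shows the correspondence"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "einsum-correspondence-check",
"metadata": {
"id": "einsum-correspondence-check"
},
"outputs": [],
"source": [
"torch.manual_seed(123)\n",
"\n",
"b, h, n, d = 2, 3, 4, 5\n",
"x = torch.randn(b, n, d)\n",
"W = torch.randn(d, d)\n",
"\n",
"# \"bnd,di->bni\" contracts over the shared index d, i.e., a plain x @ W\n",
"print(torch.allclose(torch.einsum(\"bnd,di->bni\", x, W), x @ W))\n",
"\n",
"Q = torch.randn(b, h, n, d)\n",
"K = torch.randn(b, h, n, d)\n",
"\n",
"# \"bhnd,bhmd->bhnm\" computes per-head dot products between all query/key pairs\n",
"print(torch.allclose(torch.einsum(\"bhnd,bhmd->bhnm\", Q, K), Q @ K.transpose(-2, -1)))"
]
},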
{
"cell_type": "markdown",
"id": "48a042d3-ee78-4c29-bf63-d92fe6706632",
"metadata": {
"id": "48a042d3-ee78-4c29-bf63-d92fe6706632"
},
"source": [
"<br>\n",
"&nbsp;\n",
"\n",
"## 5) Multi-head attention with PyTorch's scaled dot product attention and FlashAttention"
2024-03-13 08:37:54 -05:00
]
},
{
"cell_type": "markdown",
"id": "f78e346f-3b85-44e6-9feb-f01131381148",
"metadata": {
"id": "f78e346f-3b85-44e6-9feb-f01131381148"
},
"source": [
"- The implementation below uses PyTorch's [`scaled_dot_product_attention`](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) function, which implements a memory-optimized version of self-attention called [FlashAttention](https://arxiv.org/abs/2205.14135)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "1b8e5a0d-1f65-4a03-bf6e-723f0cc428f5",
"metadata": {
"id": "1b8e5a0d-1f65-4a03-bf6e-723f0cc428f5"
},
"outputs": [],
"source": [
"class MHAPyTorchScaledDotProduct(nn.Module):\n",
" def __init__(self, d_in, d_out, num_heads, context_length, dropout=0.0, qkv_bias=False):\n",
" super().__init__()\n",
"\n",
" assert d_out % num_heads == 0, \"d_out is indivisible by num_heads\"\n",
"\n",
" self.num_heads = num_heads\n",
" self.context_length = context_length\n",
" self.head_dim = d_out // num_heads\n",
" self.d_out = d_out\n",
"\n",
" self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)\n",
" self.proj = nn.Linear(d_out, d_out)\n",
" self.dropout = dropout\n",
"\n",
" def forward(self, x):\n",
" batch_size, num_tokens, embed_dim = x.shape\n",
"\n",
" # (b, num_tokens, embed_dim) --> (b, num_tokens, 3 * embed_dim)\n",
" qkv = self.qkv(x)\n",
"\n",
" # (b, num_tokens, 3 * embed_dim) --> (b, num_tokens, 3, num_heads, head_dim)\n",
" qkv = qkv.view(batch_size, num_tokens, 3, self.num_heads, self.head_dim)\n",
"\n",
" # (b, num_tokens, 3, num_heads, head_dim) --> (3, b, num_heads, num_tokens, head_dim)\n",
" qkv = qkv.permute(2, 0, 3, 1, 4)\n",
"\n",
" # (3, b, num_heads, num_tokens, head_dim) -> 3 times (b, num_heads, num_tokens, head_dim)\n",
" queries, keys, values = qkv\n",
"\n",
" use_dropout = 0. if not self.training else self.dropout\n",
"\n",
" context_vec = nn.functional.scaled_dot_product_attention(\n",
" queries, keys, values, attn_mask=None, dropout_p=use_dropout, is_causal=True)\n",
"\n",
" # Combine heads, where self.d_out = self.num_heads * self.head_dim\n",
" context_vec = context_vec.transpose(1, 2).contiguous().view(batch_size, num_tokens, self.d_out)\n",
"\n",
" context_vec = self.proj(context_vec)\n",
"\n",
" return context_vec"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "fbc8ba92-3471-41cb-b1b2-4c0ef5be392b",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "fbc8ba92-3471-41cb-b1b2-4c0ef5be392b",
"outputId": "c69e79a4-e741-4371-8ecc-a775b8b246bf"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([8, 1024, 768])\n"
]
}
],
"source": [
"mha_pytorch_scaled = MHAPyTorchScaledDotProduct(\n",
" d_in=embed_dim,\n",
" d_out=embed_dim,\n",
" context_length=context_len,\n",
" dropout=0.0,\n",
" num_heads=12,\n",
" qkv_bias=False\n",
").to(device)\n",
"\n",
"out = mha_pytorch_scaled(embeddings)\n",
"print(out.shape)"
]
},
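{
"cell_type": "markdown",
"id": "sdpa-equivalence-note",
"metadata": {
"id": "sdpa-equivalence-note"
},
"source": [
"- To build confidence in the shortcut, the small check below (illustration only, with made-up shapes) compares `scaled_dot_product_attention` with `is_causal=True` against the manual masked-softmax computation used in the earlier sections"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sdpa-equivalence-check",
"metadata": {
"id": "sdpa-equivalence-check"
},
"outputs": [],
"source": [
"torch.manual_seed(123)\n",
"\n",
"# Small made-up shapes: (batch, heads, tokens, head_dim)\n",
"q = torch.randn(2, 4, 6, 8)\n",
"k = torch.randn(2, 4, 6, 8)\n",
"v = torch.randn(2, 4, 6, 8)\n",
"\n",
"# Manual causal attention: scale, mask, softmax, weighted sum\n",
"scores = q @ k.transpose(-2, -1) / k.shape[-1]**0.5\n",
"causal_mask = torch.triu(torch.ones(6, 6, dtype=torch.bool), diagonal=1)\n",
"manual = torch.softmax(scores.masked_fill(causal_mask, -torch.inf), dim=-1) @ v\n",
"\n",
"builtin = nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)\n",
"print(torch.allclose(manual, builtin, atol=1e-6))"
]
},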
{
"cell_type": "markdown",
"id": "51492724-6018-49f6-8bf6-ae9e585229c3",
"metadata": {
"id": "51492724-6018-49f6-8bf6-ae9e585229c3"
},
"source": [
"<br>\n",
"&nbsp;\n",
"\n",
"## 6) PyTorch's scaled dot product attention without FlashAttention\n",
"\n",
"- This is similar to above, except that we disable FlashAttention by passing an explicit causal mask"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "bad53538-e905-4065-ba0c-caacdfec5a0b",
"metadata": {
"id": "bad53538-e905-4065-ba0c-caacdfec5a0b"
},
"outputs": [],
"source": [
"class MHAPyTorchSDPAWithoutFlash(nn.Module):\n",
" def __init__(self, d_in, d_out, num_heads, context_length, dropout=0.0, qkv_bias=False):\n",
" super().__init__()\n",
"\n",
" assert d_out % num_heads == 0, \"embed_dim is indivisible by num_heads\"\n",
"\n",
" self.num_heads = num_heads\n",
" self.context_length = context_length\n",
" self.head_dim = d_out // num_heads\n",
" self.d_out = d_out\n",
"\n",
" self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)\n",
" self.proj = nn.Linear(d_out, d_out)\n",
" self.dropout = dropout\n",
" self.register_buffer(\"mask\", torch.triu(torch.ones(context_length, context_length), diagonal=1).bool())\n",
"\n",
" def forward(self, x):\n",
" batch_size, num_tokens, embed_dim = x.shape\n",
"\n",
" # (b, num_tokens, embed_dim) --> (b, num_tokens, 3 * embed_dim)\n",
" qkv = self.qkv(x)\n",
"\n",
" # (b, num_tokens, 3 * embed_dim) --> (b, num_tokens, 3, num_heads, head_dim)\n",
" qkv = qkv.view(batch_size, num_tokens, 3, self.num_heads, self.head_dim)\n",
"\n",
" # (b, num_tokens, 3, num_heads, head_dim) --> (3, b, num_heads, num_tokens, head_dim)\n",
" qkv = qkv.permute(2, 0, 3, 1, 4)\n",
"\n",
" # (3, b, num_heads, num_tokens, head_dim) -> 3 times (b, num_heads, num_tokens, head_dim)\n",
" queries, keys, values = qkv\n",
"\n",
" use_dropout = 0. if not self.training else self.dropout\n",
"\n",
" # Ensure attn_mask is compatible with expected shape and `batch_first=True`\n",
" # No need to manually adjust for num_heads; ensure it's right for the sequence\n",
" if self.context_length >= num_tokens:\n",
" attn_mask = self.mask[:num_tokens, :num_tokens]\n",
" else:\n",
" attn_mask = self.mask[:self.context_length, :self.context_length]\n",
"\n",
" context_vec = nn.functional.scaled_dot_product_attention(\n",
" queries, keys, values, attn_mask=attn_mask, dropout_p=use_dropout, is_causal=False)\n",
"\n",
" # Combine heads, where self.d_out = self.num_heads * self.head_dim\n",
" context_vec = context_vec.transpose(1, 2).contiguous().view(batch_size, num_tokens, self.d_out)\n",
"\n",
" context_vec = self.proj(context_vec)\n",
"\n",
" return context_vec"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "f3da7850-e772-47d3-bd51-22d077b01412",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "f3da7850-e772-47d3-bd51-22d077b01412",
"outputId": "1d208d5c-0d33-40c5-c473-0bef0bf123a0"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([8, 1024, 768])\n"
]
}
],
"source": [
"mha_pytorch_sdpa_no_flash = MHAPyTorchSDPAWithoutFlash(\n",
" d_in=embed_dim,\n",
" d_out=embed_dim,\n",
" context_length=context_len,\n",
" dropout=0.0,\n",
" num_heads=12,\n",
" qkv_bias=False\n",
").to(device)\n",
"\n",
"out = mha_pytorch_sdpa_no_flash(embeddings)\n",
"print(out.shape)"
]
},
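{
"cell_type": "markdown",
"id": "sdpa-kernel-note",
"metadata": {
"id": "sdpa-kernel-note"
},
"source": [
"- Alternatively (this assumes PyTorch 2.3 or newer), the `sdpa_kernel` context manager restricts which `scaled_dot_product_attention` backend PyTorch may dispatch to, which is another way to compare FlashAttention against the unfused math fallback"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sdpa-kernel-demo",
"metadata": {
"id": "sdpa-kernel-demo"
},
"outputs": [],
"source": [
"from torch.nn.attention import SDPBackend, sdpa_kernel\n",
"\n",
"# Force the unfused math backend for everything inside the context manager\n",
"with sdpa_kernel(SDPBackend.MATH):\n",
"    out = mha_pytorch_scaled(embeddings)\n",
"print(out.shape)"
]
},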
{
"cell_type": "markdown",
"id": "351c318f-4835-4d74-8d58-a070222447c4",
"metadata": {
"id": "351c318f-4835-4d74-8d58-a070222447c4"
},
"source": [
"<br>\n",
"&nbsp;\n",
"\n",
"## 7) Using PyTorch's torch.nn.MultiheadAttention"
2024-03-13 08:37:54 -05:00
]
},
{
"cell_type": "markdown",
"id": "74a6d060-6324-48fa-a35c-cb09f2a48965",
"metadata": {
"id": "74a6d060-6324-48fa-a35c-cb09f2a48965"
},
"source": [
"- Below, we use PyTorch's [torch.nn.MultiheadAttention](https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html) implementation"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "3799c7ef-3155-42c6-a829-f95656453ae0",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "3799c7ef-3155-42c6-a829-f95656453ae0",
"outputId": "dbee238a-a189-4f30-ac45-3b5e51237615"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([8, 1024, 768])\n"
]
}
],
"source": [
"import torch.nn as nn\n",
"\n",
"\n",
"class MHAPyTorchClass(nn.Module):\n",
" def __init__(self, d_in, d_out, num_heads, context_length, dropout=0.0, qkv_bias=False, need_weights=True):\n",
" super().__init__()\n",
"\n",
" self.context_length = context_length\n",
" self.multihead_attn = nn.MultiheadAttention(\n",
" embed_dim=d_out,\n",
" num_heads=num_heads,\n",
" dropout=dropout,\n",
" bias=qkv_bias,\n",
" add_bias_kv=qkv_bias,\n",
" batch_first=True,\n",
" )\n",
"\n",
" self.need_weights = need_weights\n",
" self.proj = nn.Linear(d_out, d_out)\n",
" self.register_buffer(\"mask\", torch.triu(torch.ones(context_length, context_length), diagonal=1).bool())\n",
"\n",
" def forward(self, x):\n",
" batch_size, num_tokens, _ = x.shape\n",
"\n",
" # Ensure attn_mask is compatible with expected shape and `batch_first=True`\n",
" # No need to manually adjust for num_heads; ensure it's right for the sequence\n",
" if self.context_length >= num_tokens:\n",
" attn_mask = self.mask[:num_tokens, :num_tokens]\n",
" else:\n",
" attn_mask = self.mask[:self.context_length, :self.context_length]\n",
"\n",
" # attn_mask broadcasting will handle batch_size dimension implicitly\n",
" attn_output, _ = self.multihead_attn(\n",
" x, x, x, attn_mask=attn_mask, need_weights=self.need_weights\n",
" )\n",
"\n",
" output = self.proj(attn_output)\n",
"\n",
" return output\n",
"\n",
"\n",
"mha_pytorch_class_default = MHAPyTorchClass(\n",
" d_in=embed_dim,\n",
" d_out=embed_dim,\n",
" context_length=context_len,\n",
" dropout=0.0,\n",
" num_heads=12,\n",
" qkv_bias=False\n",
").to(device)\n",
"\n",
"out = mha_pytorch_class_default(embeddings)\n",
"print(out.shape)"
]
},
{
"cell_type": "markdown",
"id": "a3953bff-1056-4de2-bfd1-dfccf659eee4",
"metadata": {
"id": "a3953bff-1056-4de2-bfd1-dfccf659eee4"
},
"source": [
"<br>\n",
"&nbsp;\n",
"\n",
"## 8) Using PyTorch's torch.nn.MultiheadAttention with `scaled_dot_product_attention`"
2024-03-13 08:37:54 -05:00
]
},
{
"cell_type": "markdown",
"id": "d2164859-31a0-4537-b4fb-27d57675ba77",
"metadata": {
"id": "d2164859-31a0-4537-b4fb-27d57675ba77"
},
"source": [
"- Set `need_weights` (default `True`) to `False` so that `MultiheadAttention` uses `scaled_dot_product_attention` [according to the documentation](https://github.com/pytorch/pytorch/blob/71d020262793542974cf13b30f2a9099773f015c/torch/nn/modules/activation.py#L1096)\n",
"\n",
"```markdown\n",
"need_weights: If specified, returns `attn_output_weights` in addition to `attn_outputs`.\n",
" Set `need_weights=False` to use the optimized `scaled_dot_product_attention`\n",
" and achieve the best performance for MHA.\n",
" Default: `True`\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "4a4c2afe-5e1f-4bd7-a118-67031176f147",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "4a4c2afe-5e1f-4bd7-a118-67031176f147",
"outputId": "54ae35e3-6d9e-485f-c59c-6955430382f8"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([8, 1024, 768])\n"
]
}
],
"source": [
"mha_pytorch_class_noweights = MHAPyTorchClass(\n",
" d_in=embed_dim,\n",
" d_out=embed_dim,\n",
" context_length=context_len,\n",
" dropout=0.0,\n",
" num_heads=12,\n",
" qkv_bias=False,\n",
" need_weights=False # NEW!\n",
").to(device)\n",
"\n",
"out = mha_pytorch_class_noweights(embeddings)\n",
"print(out.shape)"
]
},
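{
"cell_type": "markdown",
"id": "need-weights-none-note",
"metadata": {
"id": "need-weights-none-note"
},
"source": [
"- Illustration only: with `need_weights=False`, `nn.MultiheadAttention` returns `None` in place of the attention-weight matrix, which is what frees it to take the optimized `scaled_dot_product_attention` path"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "need-weights-none-check",
"metadata": {
"id": "need-weights-none-check"
},
"outputs": [],
"source": [
"# Calling the wrapped nn.MultiheadAttention module directly shows the effect:\n",
"# the second return value is None instead of the attention-weight matrix\n",
"attn_out, attn_weights = mha_pytorch_class_noweights.multihead_attn(\n",
"    embeddings, embeddings, embeddings, need_weights=False\n",
")\n",
"print(attn_out.shape, attn_weights)"
]
},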
{
"cell_type": "markdown",
"id": "21f4ff35-651c-4e47-bfa1-016f3de01ecc",
"metadata": {
"id": "21f4ff35-651c-4e47-bfa1-016f3de01ecc"
},
"source": [
"<br>\n",
"&nbsp;\n",
"\n",
"## 9) Using PyTorch's FlexAttention\n",
"\n",
"- See [FlexAttention: The Flexibility of PyTorch with the Performance of FlashAttention](https://pytorch.org/blog/flexattention/) to learn more about FlexAttention\n",
"- FlexAttention caveat: It currently doesn't support dropout\n",
"- This is supported starting from PyTorch 2.5, which you can install on a CPU machine via\n",
"\n",
" ```bash\n",
" pip install torch torchvision torchaudio\n",
" ```\n",
"\n",
"- To install PyTorch on a GPU machine, use the following (for more information, also see the installation menu on [pytorch.org](https://pytorch.org/))\n",
"\n",
" ```bash\n",
" pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124\n",
" ```"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "834318c8-4748-4902-99f0-70ee02bef63e",
"metadata": {
"id": "834318c8-4748-4902-99f0-70ee02bef63e"
},
"outputs": [],
"source": [
"from packaging.version import parse as parse_version\n",
"\n",
"def normalize_version(version):\n",
" parsed_version = parse_version(version)\n",
" return parse_version(f\"{parsed_version.major}.{parsed_version.minor}.{parsed_version.micro}\")\n",
"\n",
"current_version = normalize_version(torch.__version__)\n",
"MIN_TORCH_VERSION = \"2.5.0\"\n",
"required_version = parse_version(MIN_TORCH_VERSION)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "WYyFRCXndVH9",
"metadata": {
"id": "WYyFRCXndVH9"
},
"outputs": [],
"source": [
"if current_version >= required_version and torch.cuda.is_available():\n",
" from torch.nn.attention.flex_attention import flex_attention, create_block_mask\n",
"\n",
"\n",
"def causal(b, h, q_idx, kv_idx):\n",
" return q_idx >= kv_idx\n",
"\n",
"\n",
"class MHAPyTorchFlexAttention(nn.Module):\n",
"\n",
" def __init__(self, d_in, d_out, num_heads, context_length, dropout=0.0, qkv_bias=False):\n",
" super().__init__()\n",
"\n",
" assert d_out % num_heads == 0, \"embed_dim is indivisible by num_heads\"\n",
"\n",
" self.num_heads = num_heads\n",
" self.context_length = context_length\n",
" self.head_dim = d_out // num_heads\n",
" self.d_out = d_out\n",
"\n",
" self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)\n",
" self.proj = nn.Linear(d_out, d_out)\n",
" self.dropout = dropout\n",
" # self.register_buffer(\"block_mask\", create_block_mask(causal, B=None, H=None, Q_LEN=context_length, KV_LEN=context_length))\n",
" # `create_block_mask` function does not support buffers, yet\n",
" self.block_mask = create_block_mask(causal, B=None, H=None, Q_LEN=context_length, KV_LEN=context_length)\n",
"\n",
"\n",
" def forward(self, x):\n",
" batch_size, num_tokens, embed_dim = x.shape\n",
"\n",
" # (b, num_tokens, embed_dim) --> (b, num_tokens, 3 * embed_dim)\n",
" qkv = self.qkv(x)\n",
"\n",
" # (b, num_tokens, 3 * embed_dim) --> (b, num_tokens, 3, num_heads, head_dim)\n",
" qkv = qkv.view(batch_size, num_tokens, 3, self.num_heads, self.head_dim)\n",
"\n",
" # (b, num_tokens, 3, num_heads, head_dim) --> (3, b, num_heads, num_tokens, head_dim)\n",
" qkv = qkv.permute(2, 0, 3, 1, 4)\n",
"\n",
" # (3, b, num_heads, num_tokens, head_dim) -> 3 times (b, num_heads, num_tokens, head_dim)\n",
" queries, keys, values = qkv\n",
"\n",
" # use_dropout = 0. if not self.training else self.dropout\n",
"\n",
" # Ensure attn_mask is compatible with expected shape and `batch_first=True`\n",
" # No need to manually adjust for num_heads; ensure it's right for the sequence\n",
" if self.context_length >= num_tokens:\n",
" attn_mask = self.block_mask[:num_tokens, :num_tokens]\n",
" else:\n",
" attn_mask = self.block_mask[:self.context_length, :self.context_length]\n",
"\n",
" context_vec = flex_attention(queries, keys, values, block_mask=attn_mask)\n",
"\n",
" # Combine heads, where self.d_out = self.num_heads * self.head_dim\n",
" context_vec = context_vec.transpose(1, 2).contiguous().view(batch_size, num_tokens, self.d_out)\n",
"\n",
" context_vec = self.proj(context_vec)\n",
"\n",
" return context_vec"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "9cdaaf8a-f956-44bc-932f-4d33448e8aaf",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "9cdaaf8a-f956-44bc-932f-4d33448e8aaf",
"outputId": "c239092a-696e-4573-e933-c337f090d294"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([8, 1024, 768])\n"
]
}
],
"source": [
"if current_version >= required_version and torch.cuda.is_available():\n",
"\n",
" mha_pytorch_flex = MHAPyTorchFlexAttention(\n",
" d_in=embed_dim,\n",
" d_out=embed_dim,\n",
" context_length=context_len,\n",
" dropout=0.0,\n",
" num_heads=12,\n",
" qkv_bias=False\n",
" ).to(device)\n",
"\n",
" out = mha_pytorch_flex(embeddings)\n",
" print(out.shape)"
]
},
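{
"cell_type": "markdown",
"id": "mask-mod-demo-note",
"metadata": {
"id": "mask-mod-demo-note"
},
"source": [
"- Illustration only: a `mask_mod` such as `causal` simply answers, for each (query, key) index pair, whether attention is allowed; broadcasting two small index grids makes the causal pattern visible directly"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "mask-mod-demo",
"metadata": {
"id": "mask-mod-demo"
},
"outputs": [],
"source": [
"# Evaluate the mask_mod on a 4x4 grid of (query, key) index pairs;\n",
"# the batch and head arguments are unused by `causal`, so None is fine here\n",
"q_idx = torch.arange(4).view(-1, 1)\n",
"kv_idx = torch.arange(4).view(1, -1)\n",
"print(causal(None, None, q_idx, kv_idx))"
]
},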
{
"cell_type": "markdown",
"id": "8877de71-f84f-4f6d-bc87-7552013b6301",
"metadata": {
"id": "8877de71-f84f-4f6d-bc87-7552013b6301"
},
"source": [
"<br>\n",
"&nbsp;\n",
"\n",
"## Quick speed comparison (M3 Macbook Air CPU)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "219cf93a-078f-434d-888c-2458d0731285",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "219cf93a-078f-434d-888c-2458d0731285",
"outputId": "a10b52d4-b4e6-43c2-9677-113c41edd3b7"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"PyTorch version: 2.4.0\n",
"Running on cpu\n"
]
}
],
"source": [
"torch.manual_seed(123)\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"print(f\"PyTorch version: {torch.__version__}\")\n",
"print(f\"Running on {device}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a97c0b2e-6593-49d8-98bc-2267b3aa610f",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "a97c0b2e-6593-49d8-98bc-2267b3aa610f",
"outputId": "7bcd7da4-d115-4ba6-efba-377a0bd7d3a8"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"179 ms ± 7.39 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
"source": [
"## 1) CausalAttention MHA wrapper class from chapter 3\n",
"%timeit mha_ch03_wrapper(embeddings)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "19db9c2c-8e75-431a-8eef-0b4d8284e6e6",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "19db9c2c-8e75-431a-8eef-0b4d8284e6e6",
"outputId": "b04b4d0d-71aa-4944-f02b-131bf5a50202"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"166 ms ± 2.62 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
]
}
],
"source": [
"## 2) The multi-head attention class from chapter 3\n",
"%timeit mha_ch03(embeddings)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "aa526ee0-7a88-4f34-a49a-f8f97da83779",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "aa526ee0-7a88-4f34-a49a-f8f97da83779",
"outputId": "5436928a-7b98-4c40-bf51-97973f13327e"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"190 ms ± 2.03 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
]
}
],
"source": [
"## 3) An alternative multi-head attention with combined weights\n",
"%timeit mha_combined_qkv(embeddings)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "131ca826-35bf-47e5-b497-540aba439ef9",
"metadata": {
"id": "131ca826-35bf-47e5-b497-540aba439ef9",
"outputId": "f5848852-f81b-4e5f-a7ff-e37a8445ad91"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"196 ms ± 1.08 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
]
}
],
"source": [
"## 4) Multi-head attention using Einstein summation\n",
"%timeit mha_einsum(embeddings)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cc2b4256-16d8-4c34-9fd0-d4b4af0e60fa",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "cc2b4256-16d8-4c34-9fd0-d4b4af0e60fa",
"outputId": "9e07ce73-a2de-4e2c-8276-64626df9450e"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"110 ms ± 423 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
]
}
],
"source": [
"## 5) Multi-head attention with PyTorch's scaled dot product attention\n",
"%timeit mha_pytorch_scaled(embeddings)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c44305ce-9f61-451a-b9ef-30caba222357",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "c44305ce-9f61-451a-b9ef-30caba222357",
"outputId": "6bab4a24-5bb4-4ad6-b260-3b442f598950"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"99.5 ms ± 790 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
]
}
],
"source": [
"## 6) PyTorch's scaled dot product attention without FlashAttention\n",
"%timeit mha_pytorch_sdpa_no_flash(embeddings)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0f209e70-ebb6-4a1a-b608-1ff42e41c01d",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "0f209e70-ebb6-4a1a-b608-1ff42e41c01d",
"outputId": "630c49d1-8a06-4148-cd96-a7b2467310a0"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"198 ms ± 3.52 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
]
}
],
"source": [
"## 7) Using PyTorch's torch.nn.MultiheadAttention\n",
"%timeit mha_pytorch_class_default(embeddings)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3f4968c2-8d40-4ab9-8dba-052b4f77d756",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "3f4968c2-8d40-4ab9-8dba-052b4f77d756",
"outputId": "10f6a268-f9cf-446c-aa83-e87b6a0b4f5c"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"168 ms ± 2.63 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
]
}
],
"source": [
"## 8) Using PyTorch's torch.nn.MultiheadAttention disabling `need_weights`\n",
"%timeit mha_pytorch_class_noweights(embeddings)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bdd8e0fc-ef24-424c-bccf-c381e73da228",
"metadata": {
"id": "bdd8e0fc-ef24-424c-bccf-c381e73da228"
},
"outputs": [],
"source": [
"## 9) Using PyTorch's FlexAttention\n",
"\n",
"# Requires PyTorch 2.5.0 or newer and currently only supports CUDA PyTorch\n",
"%timeit mha_pytorch_flex(embeddings)"
]
},
{
"cell_type": "markdown",
"id": "a78ff594-6cc2-496d-a302-789fa104c3c9",
"metadata": {
"id": "a78ff594-6cc2-496d-a302-789fa104c3c9"
},
"source": [
"<br>\n",
"&nbsp;\n",
"\n",
"## Quick speed comparison (Nvidia A100 GPU)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "RStnI1pEi6Eo",
"metadata": {
"id": "RStnI1pEi6Eo"
},
"outputs": [],
"source": [
"# Enable tensor cores\n",
"torch.set_float32_matmul_precision(\"high\")"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "e8431d75-e1c9-4d9a-b7da-9a1ff391f2bf",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "e8431d75-e1c9-4d9a-b7da-9a1ff391f2bf",
"outputId": "787933f2-1911-4830-cc3e-c3fc47afd688"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"PyTorch version: 2.6.0+cu124\n",
"Running on cuda\n"
]
}
],
"source": [
"torch.manual_seed(123)\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"print(f\"PyTorch version: {torch.__version__}\")\n",
"print(f\"Running on {device}\")"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "707a2a14-a089-48a8-88aa-d328e1e0a9d0",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "707a2a14-a089-48a8-88aa-d328e1e0a9d0",
"outputId": "f79aa3cf-f860-4d31-85be-63caa513c9a4"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"4.68 ms ± 121 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
]
}
],
"source": [
"## 1) CausalAttention MHA wrapper class from chapter 3\n",
"%timeit mha_ch03_wrapper(embeddings)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "8686dd69-3655-40e4-a57b-a2c55532a010",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "8686dd69-3655-40e4-a57b-a2c55532a010",
"outputId": "9e38912d-8ba4-4906-a9a4-47206297465c"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"3.08 ms ± 195 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n"
]
}
],
"source": [
"## 2) The multi-head attention class from chapter 3\n",
"%timeit mha_ch03(embeddings)"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "2209d7df-e54b-4910-ae2b-c78cf684d9bf",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "2209d7df-e54b-4910-ae2b-c78cf684d9bf",
"outputId": "cb9cda4b-4a35-4718-864c-f8de3cc04853"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"3.81 ms ± 532 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
]
}
],
"source": [
"## 3) An alternative multi-head attention with combined weights\n",
"%timeit mha_combined_qkv(embeddings)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "abee5edf-2585-4f0e-846c-b1c7ca88f545",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "abee5edf-2585-4f0e-846c-b1c7ca88f545",
"outputId": "aadc2f49-02ff-4b10-bc75-8302c7929597"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"4.11 ms ± 170 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
]
}
],
"source": [
"## 4) Multi-head attention using Einstein summation\n",
"%timeit mha_einsum(embeddings)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "1075abe2-4839-4fd6-af3e-c09bb3651e26",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "1075abe2-4839-4fd6-af3e-c09bb3651e26",
"outputId": "56968cdf-158b-41bb-9505-ffde33c4f09c"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1.1 ms ± 800 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n"
]
}
],
"source": [
"## 5) Multi-head attention with PyTorch's scaled dot product attention\n",
"%timeit mha_pytorch_scaled(embeddings)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "218adbaf-f17f-47d9-81d5-41c758218df7",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "218adbaf-f17f-47d9-81d5-41c758218df7",
"outputId": "63e103a7-fade-4a30-d32a-f2cfac09ea6c"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1.8 ms ± 93.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
]
}
],
"source": [
"## 6) PyTorch's scaled dot product attention without FlashAttention\n",
"%timeit mha_pytorch_sdpa_no_flash(embeddings)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "868e3670-8edc-47bc-9e06-eb505e44dc9d",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "868e3670-8edc-47bc-9e06-eb505e44dc9d",
"outputId": "93f1c5e7-6040-44e9-c26c-0995caa86b50"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"3.04 ms ± 394 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
]
}
],
"source": [
"## 7) Using PyTorch's torch.nn.MultiheadAttention\n",
"%timeit mha_pytorch_class_default(embeddings)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "944870e6-de54-4e3b-a455-b8f21f6f92c8",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "944870e6-de54-4e3b-a455-b8f21f6f92c8",
"outputId": "83e36077-80f9-41e1-abbb-2cd1f7645dd0"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2.13 ms ± 4.48 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
]
}
],
"source": [
"## 8) Using PyTorch's torch.nn.MultiheadAttention disabling `need_weights`\n",
"%timeit mha_pytorch_class_noweights(embeddings)"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "evKtpb5QN_2A",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "evKtpb5QN_2A",
"outputId": "af64f756-1aad-4032-a431-76842bf7dafe"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"13.9 ms ± 557 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
"source": [
"## 9) Using PyTorch's FlexAttention\n",
"\n",
"# Requires PyTorch 2.5.0 or newer\n",
"%timeit mha_pytorch_flex(embeddings)"
]
},
{
"cell_type": "markdown",
"id": "dabc6575-0316-4640-a729-e616d5c17b73",
"metadata": {
"id": "dabc6575-0316-4640-a729-e616d5c17b73"
},
"source": [
"<br>\n",
"&nbsp;\n",
"\n",
"\n",
"# Visualizations"
2024-03-13 08:37:54 -05:00
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "bbb2f729-d3d8-46d0-b249-9249197ea574",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "bbb2f729-d3d8-46d0-b249-9249197ea574",
"outputId": "6e6167c4-93e5-4491-a49f-c1d3966fb10a"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"PyTorch version: 2.6.0+cu124\n",
"Running on cuda\n"
]
}
],
"source": [
"torch.manual_seed(123)\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"print(f\"PyTorch version: {torch.__version__}\")\n",
"print(f\"Running on {device}\")"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "b0620bf5",
"metadata": {
"id": "b0620bf5"
},
"outputs": [],
"source": [
"functions = {\n",
" \"1) MHA wrapper class\": mha_ch03_wrapper,\n",
" \"2) MHA Ch03\": mha_ch03,\n",
" \"3) MHA with combined QKV weights\": mha_combined_qkv,\n",
" \"4) MHA with Einsum\": mha_einsum,\n",
" \"5) MHA with PyTorch scaled_dot_product_attention\": mha_pytorch_scaled,\n",
" \"6) PyTorch's SDPA, no FlashAttention\": mha_pytorch_sdpa_no_flash,\n",
" \"7) PyTorch MHA class defaults\": mha_pytorch_class_default,\n",
" \"8) PyTorch MHA with need_weights=False\": mha_pytorch_class_noweights\n",
" }\n",
"\n",
"if current_version >= required_version and torch.cuda.is_available():\n",
" functions[\"9) PyTorch's FlexAttention\"] = mha_pytorch_flex"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "CDJAPZaszaqx",
"metadata": {
"id": "CDJAPZaszaqx"
},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"# Customize further for dark mode aesthetics\n",
"plt.rcParams[\"figure.facecolor\"] = \"#121212\"\n",
"plt.rcParams[\"axes.facecolor\"] = \"#121212\"\n",
"plt.rcParams[\"axes.edgecolor\"] = \"white\"\n",
"plt.rcParams[\"axes.labelcolor\"] = \"white\"\n",
"plt.rcParams[\"text.color\"] = \"white\"\n",
"plt.rcParams[\"xtick.color\"] = \"white\"\n",
"plt.rcParams[\"ytick.color\"] = \"white\"\n",
"plt.rcParams[\"grid.color\"] = \"#444444\"\n",
"plt.rcParams[\"lines.linewidth\"] = 2\n",
"plt.rcParams[\"lines.markersize\"] = 8\n",
"\n",
"def plot_execution_times(functions, execution_means, execution_stds, filename):\n",
"\n",
" # Create plot\n",
" fig, ax = plt.subplots()\n",
" bars = ax.bar(functions.keys(), execution_means, yerr=execution_stds, capsize=5, error_kw={'ecolor': 'grey'})\n",
"\n",
" plt.ylabel(\"Execution time (ms)\")\n",
" plt.xticks(rotation=45, ha=\"right\")\n",
"\n",
" # Calculate new ylim with a margin\n",
" max_execution_time = max(execution_means)\n",
" upper_ylim = max_execution_time + 0.4 * max_execution_time # Adding a 40% margin\n",
" plt.ylim(0, upper_ylim)\n",
"\n",
" # Annotate bars with execution times\n",
" for bar in bars:\n",
" yval = bar.get_height()\n",
" plt.text(bar.get_x() + bar.get_width()/2, yval + (0.05 * upper_ylim), round(yval, 2), ha=\"center\", va=\"bottom\")\n",
"\n",
" plt.tight_layout()\n",
" plt.savefig(filename)\n",
" plt.show()"
]
},
{
"cell_type": "markdown",
"id": "4df834dc",
"metadata": {
"id": "4df834dc"
},
"source": [
"## Speed comparison (Nvidia A100 GPU) with warmup (forward pass only)"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "29b63d3d-6d0b-43bb-9c68-d5514dc81000",
"metadata": {
"id": "29b63d3d-6d0b-43bb-9c68-d5514dc81000"
},
"outputs": [],
"source": [
"# CUDA benchmark code shared by Andrei Aksionov\n",
"# and based on code from\n",
"# https://github.com/cuda-mode/lectures/blob/main/lecture1/pytorch_square.py\n",
"\n",
"import numpy as np\n",
"\n",
"def time_pytorch_function(func, *input, num_repeats=1_000):\n",
" start = torch.cuda.Event(enable_timing=True)\n",
" end = torch.cuda.Event(enable_timing=True)\n",
"\n",
" # Warmup\n",
" for _ in range(5):\n",
" func(*input)\n",
" torch.cuda.synchronize()\n",
"\n",
" times = []\n",
" for _ in range(num_repeats):\n",
" start.record()\n",
" func(*input)\n",
" end.record()\n",
" torch.cuda.synchronize()\n",
" times.append(start.elapsed_time(end))\n",
"\n",
" return np.mean(times), np.std(times)"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "9dd07a09",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 486
2024-03-13 08:37:54 -05:00
},
"id": "9dd07a09",
"outputId": "84d3ed5c-d4e6-47d0-b277-75ecf36a55e8"
2024-03-13 08:37:54 -05:00
},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"execution_stats = [time_pytorch_function(fn, embeddings) for fn in functions.values()]\n",
"execution_means = [stat[0] for stat in execution_stats]\n",
"execution_stds = [stat[1] for stat in execution_stats]\n",
"\n",
"\n",
"plot_execution_times(functions, execution_means, execution_stds, filename=\"1_forward-only.pdf\")"
]
},
{
"cell_type": "markdown",
"id": "VQaSerWCOnYB",
"metadata": {
"id": "VQaSerWCOnYB"
},
"source": [
"<br>\n",
"&nbsp;\n",
"\n",
"\n",
"## Speed comparison (Nvidia A100 GPU) with warmup (forward and backward pass)"
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "69e6377b",
"metadata": {
"id": "69e6377b"
},
"outputs": [],
"source": [
"def forward_backward(func, embeddings):\n",
" if embeddings.grad is not None:\n",
" embeddings.grad.zero_()\n",
"\n",
" output = func(embeddings)\n",
" loss = output.sum()\n",
" loss.backward()\n",
"\n",
"\n",
"def time_pytorch_function_forward_backward(func, *input, num_repeats = 1_000):\n",
" # CUDA IS ASYNC so can't use python time module\n",
" start = torch.cuda.Event(enable_timing=True)\n",
" end = torch.cuda.Event(enable_timing=True)\n",
"\n",
" # Warmup\n",
" for _ in range(5):\n",
" forward_backward(func, *input)\n",
" torch.cuda.synchronize()\n",
"\n",
" times = []\n",
2024-08-10 09:44:11 -05:00
" for _ in range(num_repeats):\n",
" start.record()\n",
2024-08-10 09:44:11 -05:00
" forward_backward(func, *input)\n",
" end.record()\n",
2024-08-10 09:44:11 -05:00
" torch.cuda.synchronize()\n",
" times.append(start.elapsed_time(end))\n",
2024-08-10 09:44:11 -05:00
"\n",
" return np.mean(times), np.std(times)"
]
},
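  {
   "cell_type": "markdown",
   "id": "f3a9c1d2",
   "metadata": {
    "id": "f3a9c1d2"
   },
   "source": [
    "- The helper above relies on CUDA events because GPU kernels launch asynchronously: `torch.cuda.synchronize()` makes sure the GPU has actually finished before the elapsed time is read, and the warmup iterations keep one-time setup costs out of the measurement\n",
    "- As a minimal usage sketch (the `example_layer` below is a hypothetical stand-in, not one of the attention variants compared in this notebook), any callable that accepts the `embeddings` tensor can be timed the same way:\n",
    "\n",
    "```python\n",
    "# Hypothetical example: time a single linear layer as a point of reference\n",
    "example_layer = torch.nn.Linear(embed_dim, embed_dim).to(device)\n",
    "mean_ms, std_ms = time_pytorch_function_forward_backward(example_layer, embeddings)\n",
    "print(f\"{mean_ms:.4f} ms +/- {std_ms:.4f} ms per forward+backward pass\")\n",
    "```"
   ]
  },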
{
"cell_type": "code",
"execution_count": 32,
"id": "ReCmeRhCOpm8",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 486
},
"id": "ReCmeRhCOpm8",
"outputId": "6d0f526e-d044-49b0-d2e7-fb1dac063920"
},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"execution_stats = [time_pytorch_function_forward_backward(fn, embeddings) for fn in functions.values()]\n",
"execution_means = [stat[0] for stat in execution_stats]\n",
"execution_stds = [stat[1] for stat in execution_stats]\n",
"\n",
"\n",
"plot_execution_times(functions, execution_means, execution_stds, filename=\"2_forward-and-backward.pdf\")"
]
},
{
"cell_type": "markdown",
"id": "1gWX-Ayqia1k",
"metadata": {
"id": "1gWX-Ayqia1k"
},
"source": [
"<br>\n",
"&nbsp;\n",
"\n",
"\n",
"## Speed comparison (Nvidia A100 GPU) with warmup and compilation (forward and backward pass)"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "LQDiAPooiYAz",
"metadata": {
"id": "LQDiAPooiYAz"
},
"outputs": [],
"source": [
"import torch._dynamo\n",
"torch._dynamo.config.suppress_errors = True\n",
"\n",
"def prepare_function(fn):\n",
" fn = torch.compile(fn)\n",
" return fn"
]
},
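  {
   "cell_type": "markdown",
   "id": "c8d4e5f6",
   "metadata": {
    "id": "c8d4e5f6"
   },
   "source": [
    "- `torch.compile` traces each function with TorchDynamo and generates optimized kernels, so the first call for a given input shape pays a one-time compilation cost; the warmup iterations in `time_pytorch_function_forward_backward` absorb that cost before timing begins\n",
    "- Setting `torch._dynamo.config.suppress_errors = True` lets functions that cannot be compiled fall back to eager execution instead of raising an error\n",
    "- A minimal sketch of the first-call behavior (a hypothetical standalone example, not part of the benchmark below):\n",
    "\n",
    "```python\n",
    "compiled_gelu = torch.compile(torch.nn.functional.gelu)\n",
    "_ = compiled_gelu(embeddings)  # first call triggers compilation (slow)\n",
    "_ = compiled_gelu(embeddings)  # subsequent calls reuse the compiled kernels (fast)\n",
    "```"
   ]
  },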
{
"cell_type": "code",
"execution_count": 34,
"id": "aac06ffe",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 486
},
"id": "aac06ffe",
"outputId": "d66cf0e8-18ab-40ab-e22f-86e8e82edfdd"
},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"execution_stats = [time_pytorch_function_forward_backward(prepare_function(fn), embeddings) for fn in functions.values()]\n",
"execution_means = [stat[0] for stat in execution_stats]\n",
"execution_stds = [stat[1] for stat in execution_stats]\n",
"\n",
"\n",
"plot_execution_times(functions, execution_means, execution_stds, filename=\"3_forward-and-backward-compiled.pdf\")"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuType": "A100",
"machine_shape": "hm",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}