2024-03-06 08:30:32 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								{
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								 "cells": [
							 
						 
					
						
							
								
									
										
										
										
											2024-03-19 09:26:26 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "e2e65c03-36d4-413f-9b23-5cdd816729ab",
							 
						 
					
						
							
								
									
										
										
										
											2024-08-10 09:44:11 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "e2e65c03-36d4-413f-9b23-5cdd816729ab"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
									
										
										
										
											2024-03-19 09:26:26 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-05-24 07:20:37 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "<table style=\"width:100%\">\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "<tr>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "<td style=\"vertical-align:middle; text-align:left;\">\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "<font size=\"2\">\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "Supplementary code for the <a href=\"http://mng.bz/orYv\">Build a Large Language Model From Scratch</a> book by <a href=\"https://sebastianraschka.com\">Sebastian Raschka</a><br>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "<br>Code repository: <a href=\"https://github.com/rasbt/LLMs-from-scratch\">https://github.com/rasbt/LLMs-from-scratch</a>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "</font>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "</td>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "<td style=\"vertical-align:middle; text-align:left;\">\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "<a href=\"http://mng.bz/orYv\"><img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/cover-small.webp\" width=\"100px\"></a>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "</td>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "</tr>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "</table>"
							 
						 
					
						
							
								
									
										
										
										
											2024-03-19 09:26:26 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "6f678e62-7bcb-4405-86ae-dce94f494303",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "6f678e62-7bcb-4405-86ae-dce94f494303"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-03-23 07:27:43 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "# Comparing Efficient Multi-Head Attention Implementations"
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "b742938a-4bfc-4527-a1f1-d5963508967d",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "b742938a-4bfc-4527-a1f1-d5963508967d"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "This code notebook compares different ways to implement causal multi-head attention used in decoder-style LLMs like GPT, Llama, etc."
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 3,
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "7898551e-f582-48ac-9f66-3632abe2a93f",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "colab": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "base_uri": "https://localhost:8080/"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "7898551e-f582-48ac-9f66-3632abe2a93f",
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "outputId": "1a7d22c1-96d8-4a42-e3ec-ce78abaf18eb"
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "PyTorch version: 2.5.0.dev20240905+cu121\n"
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
									
										
										
										
											2024-03-08 09:30:55 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "import torch\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "torch.manual_seed(123)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(f\"PyTorch version: {torch.__version__}\")\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "batch_size = 8\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "context_len = 1024\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "embed_dim = 768\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "embeddings = torch.randn((batch_size, context_len, embed_dim), device=device)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2024-10-23 03:23:31 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "LYLcq3403Yq6",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "LYLcq3403Yq6"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- To run all the code in this notebook, please ensure you update to at least PyTorch 2.5 (FlexAttention is not included in earlier PyTorch releases)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "If the code cell above shows a PyTorch version lower than 2.5, you can upgrade your PyTorch installation by uncommenting and running the following code cell (Please note that PyTorch 2.5 requires Python 3.9 or later)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- For more specific instructions and CUDA versions, please refer to the official installation guide at https://pytorch.org."
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "execution_count": null,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "1db27f43-86f4-478f-89df-fbc2182a129b",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "# pip install --upgrade torch torchvision torchaudio"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "2f9bb1b6-a1e5-4e0a-884d-0f31b374a8d6",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "2f9bb1b6-a1e5-4e0a-884d-0f31b374a8d6"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-03-23 07:27:43 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "<br>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    " \n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "## 1) CausalAttention MHA wrapper class from chapter 3"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 4,
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "297c93ed-aec0-4896-bb89-42c4b294d3d1",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
									
										
										
										
											2024-03-09 10:20:08 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "colab": {
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     "base_uri": "https://localhost:8080/"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "297c93ed-aec0-4896-bb89-42c4b294d3d1",
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "outputId": "b6f596e4-b778-496c-bea8-3fe83d873c5b"
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "torch.Size([8, 1024, 768])\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "import torch.nn as nn\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "class CausalAttention(nn.Module):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        super().__init__()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.d_out = d_out\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.dropout = nn.Dropout(dropout)  # New\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.register_buffer(\"mask\", torch.triu(torch.ones(context_length, context_length), diagonal=1))  # New\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    def forward(self, x):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        b, num_tokens, d_in = x.shape  # New batch dimension b\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        keys = self.W_key(x)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        queries = self.W_query(x)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        values = self.W_value(x)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        attn_scores = queries @ keys.transpose(1, 2)  # Changed transpose\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        attn_scores.masked_fill_(  # New, _ ops are in-place\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        attn_weights = self.dropout(attn_weights)  # New\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        context_vec = attn_weights @ values\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        return context_vec\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "class Ch03_MHA_Wrapper(nn.Module):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        super().__init__()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.heads = nn.ModuleList(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            [CausalAttention(d_in, d_out, context_length, dropout, qkv_bias)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "             for _ in range(num_heads)]\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        )\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.out_proj = nn.Linear(d_out*num_heads, d_out*num_heads)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    def forward(self, x):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        context_vec = torch.cat([head(x) for head in self.heads], dim=-1)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        return self.out_proj(context_vec)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "mha_ch03_wrapper = Ch03_MHA_Wrapper(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    d_in=embed_dim,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    d_out=embed_dim//12,\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-04 07:27:41 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "    context_length=context_len,\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "    dropout=0.0,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    num_heads=12,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    qkv_bias=False\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    ").to(device)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "out = mha_ch03_wrapper(embeddings)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(out.shape)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "21930804-b327-40b1-8e63-94dcad39ce7b",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "21930804-b327-40b1-8e63-94dcad39ce7b"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-03-23 07:27:43 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "<br>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    " \n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "## 2) The multi-head attention class from chapter 3"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 5,
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "4ee6a61b-d25c-4a0c-8a59-f285544e3710",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "colab": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "base_uri": "https://localhost:8080/"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "4ee6a61b-d25c-4a0c-8a59-f285544e3710",
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "outputId": "4d9ade55-4710-4ae6-9f00-aa87811bfb04"
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "torch.Size([8, 1024, 768])\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "class Ch03_MHA(nn.Module):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        super().__init__()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        assert d_out % num_heads == 0, \"d_out must be divisible by num_heads\"\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.d_out = d_out\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.num_heads = num_heads\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.head_dim = d_out // num_heads  # Reduce the projection dim to match desired output dim\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.dropout = nn.Dropout(dropout)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.register_buffer(\"mask\", torch.triu(torch.ones(context_length, context_length), diagonal=1))\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    def forward(self, x):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        b, num_tokens, d_in = x.shape\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        keys = self.W_key(x)  # Shape: (b, num_tokens, d_out)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        queries = self.W_query(x)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        values = self.W_value(x)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # We implicitly split the matrix by adding a `num_heads` dimension\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        values = values.view(b, num_tokens, self.num_heads, self.head_dim)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        keys = keys.transpose(1, 2)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        queries = queries.transpose(1, 2)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        values = values.transpose(1, 2)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # Compute scaled dot-product attention (aka self-attention) with a causal mask\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # Original mask truncated to the number of tokens and converted to boolean\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # Use the mask to fill attention scores\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        attn_scores.masked_fill_(mask_bool, -torch.inf)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        attn_weights = self.dropout(attn_weights)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # Shape: (b, num_tokens, num_heads, head_dim)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        context_vec = (attn_weights @ values).transpose(1, 2)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # Combine heads, where self.d_out = self.num_heads * self.head_dim\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        context_vec = self.out_proj(context_vec)  # optional projection\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        return context_vec\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "mha_ch03 = Ch03_MHA(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    d_in=embed_dim,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    d_out=embed_dim,\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-04 07:27:41 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "    context_length=context_len,\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "    dropout=0.0,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    num_heads=12,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    qkv_bias=False\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    ").to(device)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "out = mha_ch03(embeddings)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(out.shape)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "73cd11da-ea3b-4081-b483-c4965dfefbc4",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "73cd11da-ea3b-4081-b483-c4965dfefbc4"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-03-23 07:27:43 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "<br>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    " \n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "## 3) An alternative multi-head attention with combined weights"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "1fa1a5ea-eaff-4d2d-aaf0-b34cdb6fd4dd",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "1fa1a5ea-eaff-4d2d-aaf0-b34cdb6fd4dd"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-08-14 03:57:41 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "- The code for the `MultiHeadAttentionCombinedQKV` class below is based on code that was kindly shared by [Rayed Bin Wahed](https://github.com/rasbt/LLMs-from-scratch/discussions/51)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- The main difference between the `MultiHeadAttentionCombinedQKV` class and the `MultiHeadAttention` class used in chapter 3 is that `MultiHeadAttentionCombinedQKV` uses a single weight matrix, `self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)` instead of separate weight matrices:\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "  - `self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)`\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "  - `self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)`\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "  - `self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)`\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Here, `self.qkv` combines all three weight matrices `self.W_query`, `self.W_key`, and `self.W_value` to carry out the query, key, and value computation in a single step\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Using `q, k, v = qkv.unbind(0)`, we obtain the individual query, key, and value tensors, which are then used similarly to the query, key, and value tensors in the `MultiHeadAttention` class in chapter 3"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 6,
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "9a6bd0a2-f27c-4602-afa0-c96cd295c1a6",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "colab": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "base_uri": "https://localhost:8080/"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "9a6bd0a2-f27c-4602-afa0-c96cd295c1a6",
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "outputId": "a0a023ee-3bc7-4a89-cdba-7c97921160ee"
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "torch.Size([8, 1024, 768])\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "import torch.nn as nn\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "class MultiHeadAttentionCombinedQKV(nn.Module):\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-04 07:27:41 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "    def __init__(self, d_in, d_out, num_heads, context_length, dropout=0.0, qkv_bias=False):\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        super().__init__()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        assert d_out % num_heads == 0, \"embed_dim is indivisible by num_heads\"\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.num_heads = num_heads\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-04 07:27:41 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        self.context_length = context_length\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        self.head_dim = d_out // num_heads\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-05-27 07:46:29 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        self.proj = nn.Linear(d_out, d_out)\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        self.dropout = nn.Dropout(dropout)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.register_buffer(\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-04 07:27:41 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "            \"mask\", torch.triu(torch.ones(context_length, context_length), diagonal=1)\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        )\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    def forward(self, x):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        batch_size, num_tokens, embed_dim = x.shape\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # (b, num_tokens, embed_dim) --> (b, num_tokens, 3 * embed_dim)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        qkv = self.qkv(x)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # (b, num_tokens, 3 * embed_dim) --> (b, num_tokens, 3, num_heads, head_dim)\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-26 17:13:08 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        qkv = qkv.view(batch_size, num_tokens, 3, self.num_heads, self.head_dim)\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # (b, num_tokens, 3, num_heads, head_dim) --> (3, b, num_heads, num_tokens, head_dim)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        qkv = qkv.permute(2, 0, 3, 1, 4)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # (3, b, num_heads, num_tokens, head_dim) -> 3 times (b, num_heads, num_tokens, head_dim)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        queries, keys, values = qkv.unbind(0)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # (b, num_heads, num_tokens, head_dim) --> (b, num_heads, num_tokens, num_tokens)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        attn_scores = queries @ keys.transpose(-2, -1)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        attn_scores = attn_scores.masked_fill(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        )\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        attn_weights = self.dropout(attn_weights)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # (b, num_heads, num_tokens, num_tokens) --> (b, num_heads, num_tokens, head_dim)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        context_vec = attn_weights @ values\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # (b, num_heads, num_tokens, head_dim) --> (b, num_tokens, num_heads, head_dim)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        context_vec = context_vec.transpose(1, 2)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # (b, num_tokens, num_heads, head_dim) --> (b, num_tokens, embed_dim)\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-26 17:13:08 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        context_vec = context_vec.contiguous().view(batch_size, num_tokens, self.num_heads * self.head_dim)\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        context_vec = self.proj(context_vec)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        return context_vec\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "mha_combined_qkv = MultiHeadAttentionCombinedQKV(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    d_in=embed_dim,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    d_out=embed_dim,\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-04 07:27:41 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "    context_length=context_len,\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "    dropout=0.0,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    num_heads=12,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    qkv_bias=False\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    ").to(device)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "out = mha_combined_qkv(embeddings)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(out.shape)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "9b14390d-3e21-43fd-87be-43e7029163ee",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "9b14390d-3e21-43fd-87be-43e7029163ee"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "<br>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    " \n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "## 4) Multi-head attention with Einsum\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Implementing multi-head attention using Einstein summation via [`torch.einsum`](https://pytorch.org/docs/stable/generated/torch.einsum.html)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "execution_count": 7,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "92481814-068d-439b-a65c-b1310ebbe0aa",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "colab": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "base_uri": "https://localhost:8080/"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "92481814-068d-439b-a65c-b1310ebbe0aa",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "outputId": "59a75f6e-ef06-418f-8e54-d3b368fbed13"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "torch.Size([8, 1024, 768])\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "import math\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "class MHAEinsum(nn.Module):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        super().__init__()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        assert d_out % num_heads == 0, \"d_out must be divisible by num_heads\"\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.d_out = d_out\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.num_heads = num_heads\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.head_dim = d_out // num_heads\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # Initialize parameters for Q, K, V\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.W_query = nn.Parameter(torch.randn(d_out, d_in))\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.W_key = nn.Parameter(torch.randn(d_out, d_in))\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.W_value = nn.Parameter(torch.randn(d_out, d_in))\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        if qkv_bias:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            self.bias_q = nn.Parameter(torch.zeros(d_out))\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            self.bias_k = nn.Parameter(torch.zeros(d_out))\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            self.bias_v = nn.Parameter(torch.zeros(d_out))\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        else:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            self.register_parameter(\"bias_q\", None)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            self.register_parameter(\"bias_k\", None)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            self.register_parameter(\"bias_v\", None)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.out_proj = nn.Linear(d_out, d_out)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.dropout = nn.Dropout(dropout)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.register_buffer(\"mask\", torch.triu(torch.ones(context_length, context_length), diagonal=1))\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # Initialize parameters\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.reset_parameters()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    def reset_parameters(self):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        nn.init.kaiming_uniform_(self.W_query, a=math.sqrt(5))\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        nn.init.kaiming_uniform_(self.W_key, a=math.sqrt(5))\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        nn.init.kaiming_uniform_(self.W_value, a=math.sqrt(5))\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        if self.bias_q is not None:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.W_query)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            bound = 1 / math.sqrt(fan_in)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            nn.init.uniform_(self.bias_q, -bound, bound)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            nn.init.uniform_(self.bias_k, -bound, bound)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            nn.init.uniform_(self.bias_v, -bound, bound)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    def forward(self, x):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        b, n, _ = x.shape\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # Calculate Q, K, V using einsum, first perform linear transformations\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        Q = torch.einsum(\"bnd,id->bni\", x, self.W_query)\n",
								    "        K = torch.einsum(\"bnd,id->bni\", x, self.W_key)\n",
								    "        V = torch.einsum(\"bnd,id->bni\", x, self.W_value)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # Add biases if they are used\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        if self.bias_q is not None:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            Q += self.bias_q\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            K += self.bias_k\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            V += self.bias_v\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # Reshape for multi-head attention\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        Q = Q.view(b, n, self.num_heads, self.head_dim).transpose(1, 2)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        K = K.view(b, n, self.num_heads, self.head_dim).transpose(1, 2)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        V = V.view(b, n, self.num_heads, self.head_dim).transpose(1, 2)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # Scaled dot-product attention\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        scores = torch.einsum(\"bhnd,bhmd->bhnm\", Q, K) / (self.head_dim ** 0.5)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # Apply mask\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        mask = self.mask[:n, :n].unsqueeze(0).unsqueeze(1).expand(b, self.num_heads, n, n)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        scores = scores.masked_fill(mask.bool(), -torch.inf)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # Softmax and dropout\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        attn_weights = torch.softmax(scores, dim=-1)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        attn_weights = self.dropout(attn_weights)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # Aggregate the attended context vectors\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        context_vec = torch.einsum(\"bhnm,bhmd->bhnd\", attn_weights, V)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # Combine heads and project the output\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        context_vec = context_vec.transpose(1, 2).reshape(b, n, self.d_out)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        context_vec = self.out_proj(context_vec)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        return context_vec\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "mha_einsum = MHAEinsum(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    d_in=embed_dim,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    d_out=embed_dim,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    context_length=context_len,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    dropout=0.0,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    num_heads=12,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    qkv_bias=False\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    ").to(device)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "out = mha_einsum(embeddings)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(out.shape)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "48a042d3-ee78-4c29-bf63-d92fe6706632",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "48a042d3-ee78-4c29-bf63-d92fe6706632"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-03-23 07:27:43 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "<br>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    " \n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "## 5) Multi-head attention with PyTorch's scaled dot product attention and FlashAttention"
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "f78e346f-3b85-44e6-9feb-f01131381148",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "f78e346f-3b85-44e6-9feb-f01131381148"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-08-10 09:44:11 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "- The implementation below uses PyTorch's [`scaled_dot_product_attention`](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) function, which automatically dispatches to an optimized attention kernel; when the hardware and inputs support it (e.g., `is_causal=True` with no explicit `attn_mask`), this is the memory-efficient [FlashAttention](https://arxiv.org/abs/2205.14135) kernel"
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 8,
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "1b8e5a0d-1f65-4a03-bf6e-723f0cc428f5",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "1b8e5a0d-1f65-4a03-bf6e-723f0cc428f5"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "class MHAPyTorchScaledDotProduct(nn.Module):\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-04 07:27:41 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "    def __init__(self, d_in, d_out, num_heads, context_length, dropout=0.0, qkv_bias=False):\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        super().__init__()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        assert d_out % num_heads == 0, \"embed_dim is indivisible by num_heads\"\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.num_heads = num_heads\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-04 07:27:41 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        self.context_length = context_length\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        self.head_dim = d_out // num_heads\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.d_out = d_out\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-05-26 15:38:35 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        self.proj = nn.Linear(d_out, d_out)\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        self.dropout = dropout\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    def forward(self, x):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        batch_size, num_tokens, embed_dim = x.shape\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # (b, num_tokens, embed_dim) --> (b, num_tokens, 3 * embed_dim)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        qkv = self.qkv(x)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # (b, num_tokens, 3 * embed_dim) --> (b, num_tokens, 3, num_heads, head_dim)\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-26 17:13:08 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        qkv = qkv.view(batch_size, num_tokens, 3, self.num_heads, self.head_dim)\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # (b, num_tokens, 3, num_heads, head_dim) --> (3, b, num_heads, num_tokens, head_dim)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        qkv = qkv.permute(2, 0, 3, 1, 4)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # (3, b, num_heads, num_tokens, head_dim) -> 3 times (b, num_heads, num_tokens, head_dim)\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-26 17:13:08 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        queries, keys, values = qkv\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        use_dropout = 0. if not self.training else self.dropout\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-08-10 09:44:11 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        context_vec = nn.functional.scaled_dot_product_attention(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            queries, keys, values, attn_mask=None, dropout_p=use_dropout, is_causal=True)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # Combine heads, where self.d_out = self.num_heads * self.head_dim\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        context_vec = context_vec.transpose(1, 2).contiguous().view(batch_size, num_tokens, self.d_out)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-26 17:13:08 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        context_vec = self.proj(context_vec)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        return context_vec"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 9,
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "fbc8ba92-3471-41cb-b1b2-4c0ef5be392b",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "colab": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "base_uri": "https://localhost:8080/"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "fbc8ba92-3471-41cb-b1b2-4c0ef5be392b",
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "outputId": "087a53e7-86d8-48dc-bf2e-023f0f2104cb"
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "torch.Size([8, 1024, 768])\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "mha_pytorch_scaled = MHAPyTorchScaledDotProduct(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    d_in=embed_dim,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    d_out=embed_dim,\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-04 07:27:41 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "    context_length=context_len,\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "    dropout=0.0,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    num_heads=12,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    qkv_bias=False\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    ").to(device)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "out = mha_pytorch_scaled(embeddings)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(out.shape)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2024-08-10 09:44:11 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "51492724-6018-49f6-8bf6-ae9e585229c3",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "51492724-6018-49f6-8bf6-ae9e585229c3"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "<br>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    " \n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "## 6) PyTorch's scaled dot product attention without FlashAttention\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-08-10 09:44:11 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- This is similar to above, except that we disable FlashAttention by passing an explicit causal mask"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 10,
							 
						 
					
						
							
								
									
										
										
										
											2024-08-10 09:44:11 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "bad53538-e905-4065-ba0c-caacdfec5a0b",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "bad53538-e905-4065-ba0c-caacdfec5a0b"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "class MHAPyTorchSDPAWithoutFlash(nn.Module):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    def __init__(self, d_in, d_out, num_heads, context_length, dropout=0.0, qkv_bias=False):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        super().__init__()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        assert d_out % num_heads == 0, \"embed_dim is indivisible by num_heads\"\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.num_heads = num_heads\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.context_length = context_length\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.head_dim = d_out // num_heads\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.d_out = d_out\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.proj = nn.Linear(d_out, d_out)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.dropout = dropout\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.register_buffer(\"mask\", torch.triu(torch.ones(context_length, context_length), diagonal=1).bool())\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    def forward(self, x):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        batch_size, num_tokens, embed_dim = x.shape\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # (b, num_tokens, embed_dim) --> (b, num_tokens, 3 * embed_dim)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        qkv = self.qkv(x)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # (b, num_tokens, 3 * embed_dim) --> (b, num_tokens, 3, num_heads, head_dim)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        qkv = qkv.view(batch_size, num_tokens, 3, self.num_heads, self.head_dim)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # (b, num_tokens, 3, num_heads, head_dim) --> (3, b, num_heads, num_tokens, head_dim)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        qkv = qkv.permute(2, 0, 3, 1, 4)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # (3, b, num_heads, num_tokens, head_dim) -> 3 times (b, num_heads, num_tokens, head_dim)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        queries, keys, values = qkv\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        use_dropout = 0. if not self.training else self.dropout\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # Ensure attn_mask is compatible with expected shape and `batch_first=True`\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # No need to manually adjust for num_heads; ensure it's right for the sequence\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        if self.context_length >= num_tokens:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            attn_mask = self.mask[:num_tokens, :num_tokens]\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        else:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            attn_mask = self.mask[:self.context_length, :self.context_length]\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        context_vec = nn.functional.scaled_dot_product_attention(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            queries, keys, values, attn_mask=attn_mask, dropout_p=use_dropout, is_causal=False)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # Combine heads, where self.d_out = self.num_heads * self.head_dim\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        context_vec = context_vec.transpose(1, 2).contiguous().view(batch_size, num_tokens, self.d_out)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        context_vec = self.proj(context_vec)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        return context_vec"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 11,
							 
						 
					
						
							
								
									
										
										
										
											2024-08-10 09:44:11 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "f3da7850-e772-47d3-bd51-22d077b01412",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "colab": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "base_uri": "https://localhost:8080/"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "f3da7850-e772-47d3-bd51-22d077b01412",
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "outputId": "cc8fc837-8e06-42fc-bad5-b17816f47fcd"
							 
						 
					
						
							
								
									
										
										
										
											2024-08-10 09:44:11 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "torch.Size([8, 1024, 768])\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "mha_pytorch_sdpa_no_flash = MHAPyTorchSDPAWithoutFlash(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    d_in=embed_dim,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    d_out=embed_dim,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    context_length=context_len,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    dropout=0.0,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    num_heads=12,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    qkv_bias=False\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    ").to(device)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "out = mha_pytorch_sdpa_no_flash(embeddings)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(out.shape)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "351c318f-4835-4d74-8d58-a070222447c4",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "351c318f-4835-4d74-8d58-a070222447c4"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-03-23 07:27:43 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "<br>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    " \n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "## 7) Using PyTorch's torch.nn.MultiheadAttention"
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "74a6d060-6324-48fa-a35c-cb09f2a48965",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "74a6d060-6324-48fa-a35c-cb09f2a48965"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Below, we use PyTorch's [torch.nn.MultiheadAttention](https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html) implementation"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 12,
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "3799c7ef-3155-42c6-a829-f95656453ae0",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "colab": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "base_uri": "https://localhost:8080/"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "3799c7ef-3155-42c6-a829-f95656453ae0",
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "outputId": "78236eea-a0f4-47e4-c846-606e7f8f8768"
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "torch.Size([8, 1024, 768])\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "import torch.nn as nn\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "class MHAPyTorchClass(nn.Module):\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-04 07:27:41 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "    def __init__(self, d_in, d_out, num_heads, context_length, dropout=0.0, qkv_bias=False, need_weights=True):\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        super().__init__()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-04 07:27:41 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        self.context_length = context_length\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        self.multihead_attn = nn.MultiheadAttention(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            embed_dim=d_out,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            num_heads=num_heads,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            dropout=dropout,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            bias=qkv_bias,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            add_bias_kv=qkv_bias,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            batch_first=True,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        )\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.need_weights = need_weights\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.proj = nn.Linear(d_out, d_out)\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-04 07:27:41 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        self.register_buffer(\"mask\", torch.triu(torch.ones(context_length, context_length), diagonal=1).bool())\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    def forward(self, x):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        batch_size, num_tokens, _ = x.shape\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # Ensure attn_mask is compatible with expected shape and `batch_first=True`\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # No need to manually adjust for num_heads; ensure it's right for the sequence\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-04 07:27:41 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        if self.context_length >= num_tokens:\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "            attn_mask = self.mask[:num_tokens, :num_tokens]\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        else:\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-04 07:27:41 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "            attn_mask = self.mask[:self.context_length, :self.context_length]\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # attn_mask broadcasting will handle batch_size dimension implicitly\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        attn_output, _ = self.multihead_attn(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            x, x, x, attn_mask=attn_mask, need_weights=self.need_weights\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        )\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        output = self.proj(attn_output)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        return output\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "mha_pytorch_class_default = MHAPyTorchClass(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    d_in=embed_dim,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    d_out=embed_dim,\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-04 07:27:41 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "    context_length=context_len,\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "    dropout=0.0,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    num_heads=12,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    qkv_bias=False\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    ").to(device)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "out = mha_pytorch_class_default(embeddings)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(out.shape)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "a3953bff-1056-4de2-bfd1-dfccf659eee4",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "a3953bff-1056-4de2-bfd1-dfccf659eee4"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-03-23 07:27:43 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "<br>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    " \n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "## 8) Using PyTorch's torch.nn.MultiheadAttention with `scaled_dot_product_attention`"
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "d2164859-31a0-4537-b4fb-27d57675ba77",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "d2164859-31a0-4537-b4fb-27d57675ba77"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "- Set `need_weights` (default `True`) to need_weights=False so that `MultiheadAttention` uses `scaled_dot_product_attention` [according to the documentation](https://github.com/pytorch/pytorch/blob/71d020262793542974cf13b30f2a9099773f015c/torch/nn/modules/activation.py#L1096)\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    ">  need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``.\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            Set ``need_weights=False`` to use the optimized ``scaled_dot_product_attention``\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            and achieve the best performance for MHA.\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            Default: ``True``."
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 13,
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "4a4c2afe-5e1f-4bd7-a118-67031176f147",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "colab": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "base_uri": "https://localhost:8080/"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "4a4c2afe-5e1f-4bd7-a118-67031176f147",
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "outputId": "6359dcff-ddcf-4cf9-eada-c3f0685cced2"
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "torch.Size([8, 1024, 768])\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "mha_pytorch_class_noweights = MHAPyTorchClass(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    d_in=embed_dim,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    d_out=embed_dim,\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-04 07:27:41 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "    context_length=context_len,\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "    dropout=0.0,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    num_heads=12,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    qkv_bias=False,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    need_weights=False # NEW!\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    ").to(device)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "out = mha_pytorch_class_noweights(embeddings)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(out.shape)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2024-08-10 09:44:11 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "21f4ff35-651c-4e47-bfa1-016f3de01ecc",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "21f4ff35-651c-4e47-bfa1-016f3de01ecc"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "<br>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    " \n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "## 9) Using PyTorch's FlexAttention\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-08-10 09:44:11 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- See [FlexAttention: The Flexibility of PyTorch with the Performance of FlashAttention](https://pytorch.org/blog/flexattention/) to learn more about FlexAttention\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-10-23 03:23:31 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "- This is supported starting from PyTorch 2.5, which you can install on a CPU machine via\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-08-10 09:44:11 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-08-14 03:57:41 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "    ```bash\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-10-23 03:23:31 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "    pip install torch torchvision torchaudio\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-08-14 03:57:41 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "    ```\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-08-10 09:44:11 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-10-23 03:23:31 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "- To install PyTorch on a GPU machine, use the following (for more information, also see the installation menu on [pytorch.org](https://pytorch.org/))\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-08-10 09:44:11 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-08-14 03:57:41 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "    ```bash\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-10-23 03:23:31 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "    pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-08-14 03:57:41 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "    ```"
							 
						 
					
						
							
								
									
										
										
										
											2024-08-10 09:44:11 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 14,
							 
						 
					
						
							
								
									
										
										
										
											2024-08-10 09:44:11 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "834318c8-4748-4902-99f0-70ee02bef63e",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "834318c8-4748-4902-99f0-70ee02bef63e"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "from packaging.version import parse as parse_version\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "def normalize_version(version):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    parsed_version = parse_version(version)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    return parse_version(f\"{parsed_version.major}.{parsed_version.minor}.{parsed_version.micro}\")\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "current_version = normalize_version(torch.__version__)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "MIN_TORCH_VERSION = \"2.5.0\"\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "required_version = parse_version(MIN_TORCH_VERSION)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 15,
							 
						 
					
						
							
								
									
										
										
										
											2024-08-10 09:44:11 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "WYyFRCXndVH9",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "WYyFRCXndVH9"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "if current_version >= required_version:\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-08-14 03:57:41 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "    from torch.nn.attention.flex_attention import flex_attention, create_block_mask\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-08-10 09:44:11 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "def causal(b, h, q_idx, kv_idx):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    return q_idx >= kv_idx\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "class MHAPyTorchFlexAttention(nn.Module):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    def __init__(self, d_in, d_out, num_heads, context_length, dropout=0.0, qkv_bias=False):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        super().__init__()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        assert d_out % num_heads == 0, \"embed_dim is indivisible by num_heads\"\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.num_heads = num_heads\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.context_length = context_length\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.head_dim = d_out // num_heads\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.d_out = d_out\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.qkv = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.proj = nn.Linear(d_out, d_out)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.dropout = dropout\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # self.register_buffer(\"block_mask\", create_block_mask(causal, B=None, H=None, Q_LEN=context_length, KV_LEN=context_length))\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-08-14 03:57:41 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        # `create_block_mask` function does not support buffers, yet\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-08-10 09:44:11 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        self.block_mask = create_block_mask(causal, B=None, H=None, Q_LEN=context_length, KV_LEN=context_length)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    def forward(self, x):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        batch_size, num_tokens, embed_dim = x.shape\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # (b, num_tokens, embed_dim) --> (b, num_tokens, 3 * embed_dim)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        qkv = self.qkv(x)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # (b, num_tokens, 3 * embed_dim) --> (b, num_tokens, 3, num_heads, head_dim)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        qkv = qkv.view(batch_size, num_tokens, 3, self.num_heads, self.head_dim)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # (b, num_tokens, 3, num_heads, head_dim) --> (3, b, num_heads, num_tokens, head_dim)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        qkv = qkv.permute(2, 0, 3, 1, 4)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # (3, b, num_heads, num_tokens, head_dim) -> 3 times (b, num_heads, num_tokens, head_dim)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        queries, keys, values = qkv\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        use_dropout = 0. if not self.training else self.dropout\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # Ensure attn_mask is compatible with expected shape and `batch_first=True`\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # No need to manually adjust for num_heads; ensure it's right for the sequence\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        if self.context_length >= num_tokens:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            attn_mask = self.block_mask[:num_tokens, :num_tokens]\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        else:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            attn_mask = self.block_mask[:self.context_length, :self.context_length]\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        context_vec = flex_attention(queries, keys, values, block_mask=attn_mask)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # Combine heads, where self.d_out = self.num_heads * self.head_dim\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        context_vec = context_vec.transpose(1, 2).contiguous().view(batch_size, num_tokens, self.d_out)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        context_vec = self.proj(context_vec)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        return context_vec"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 16,
							 
						 
					
						
							
								
									
										
										
										
											2024-08-10 09:44:11 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "9cdaaf8a-f956-44bc-932f-4d33448e8aaf",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "colab": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "base_uri": "https://localhost:8080/"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "9cdaaf8a-f956-44bc-932f-4d33448e8aaf",
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "outputId": "a88a7398-159e-401f-d96c-2fc928908e3e"
							 
						 
					
						
							
								
									
										
										
										
											2024-08-10 09:44:11 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "torch.Size([8, 1024, 768])\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
									
										
										
										
											2024-08-10 09:44:11 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-08-14 03:57:41 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "if current_version >= required_version and torch.cuda.is_available():\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-08-10 09:44:11 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    mha_pytorch_flex = MHAPyTorchFlexAttention(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        d_in=embed_dim,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        d_out=embed_dim,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        context_length=context_len,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        dropout=0.0,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        num_heads=12,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        qkv_bias=False\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    ).to(device)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    out = mha_pytorch_flex(embeddings)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    print(out.shape)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "8877de71-f84f-4f6d-bc87-7552013b6301",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "8877de71-f84f-4f6d-bc87-7552013b6301"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-03-23 07:27:43 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "<br>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    " \n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "## Quick speed comparison (M3 Macbook Air CPU)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": null,
							 
						 
					
						
							
								
									
										
										
										
											2024-08-10 09:44:11 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "219cf93a-078f-434d-888c-2458d0731285",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "colab": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "base_uri": "https://localhost:8080/"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "219cf93a-078f-434d-888c-2458d0731285",
							 
						 
					
						
							
								
									
										
										
										
											2024-08-14 03:57:41 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "outputId": "a10b52d4-b4e6-43c2-9677-113c41edd3b7"
							 
						 
					
						
							
								
									
										
										
										
											2024-08-10 09:44:11 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "PyTorch version: 2.4.0\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-08-10 09:44:11 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "Running on cpu\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "torch.manual_seed(123)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(f\"PyTorch version: {torch.__version__}\")\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(f\"Running on {device}\")"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": null,
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "a97c0b2e-6593-49d8-98bc-2267b3aa610f",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "colab": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "base_uri": "https://localhost:8080/"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "a97c0b2e-6593-49d8-98bc-2267b3aa610f",
							 
						 
					
						
							
								
									
										
										
										
											2024-08-14 03:57:41 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "outputId": "7bcd7da4-d115-4ba6-efba-377a0bd7d3a8"
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "179 ms ± 7.39 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "## 1) CausalAttention MHA wrapper class from chapter 3\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "%timeit mha_ch03_wrapper(embeddings)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": null,
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "19db9c2c-8e75-431a-8eef-0b4d8284e6e6",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "colab": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "base_uri": "https://localhost:8080/"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "19db9c2c-8e75-431a-8eef-0b4d8284e6e6",
							 
						 
					
						
							
								
									
										
										
										
											2024-08-14 03:57:41 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "outputId": "b04b4d0d-71aa-4944-f02b-131bf5a50202"
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "166 ms ± 2.62 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "## 2) The multi-head attention class from chapter 3\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "%timeit mha_ch03(embeddings)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": null,
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "aa526ee0-7a88-4f34-a49a-f8f97da83779",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "colab": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "base_uri": "https://localhost:8080/"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "aa526ee0-7a88-4f34-a49a-f8f97da83779",
							 
						 
					
						
							
								
									
										
										
										
											2024-08-14 03:57:41 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "outputId": "5436928a-7b98-4c40-bf51-97973f13327e"
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "190 ms ± 2.03 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "## 3) An alternative multi-head attention with combined weights\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "%timeit mha_combined_qkv(embeddings)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": null,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "131ca826-35bf-47e5-b497-540aba439ef9",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "131ca826-35bf-47e5-b497-540aba439ef9",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "outputId": "f5848852-f81b-4e5f-a7ff-e37a8445ad91"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "196 ms ± 1.08 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "## 4) Multi-head attention using Einstein summation\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "%timeit mha_einsum(embeddings)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "execution_count": null,
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "cc2b4256-16d8-4c34-9fd0-d4b4af0e60fa",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "colab": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "base_uri": "https://localhost:8080/"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "cc2b4256-16d8-4c34-9fd0-d4b4af0e60fa",
							 
						 
					
						
							
								
									
										
										
										
											2024-08-14 03:57:41 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "outputId": "9e07ce73-a2de-4e2c-8276-64626df9450e"
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "110 ms ± 423 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "## 5) Multi-head attention with PyTorch's scaled dot product attention\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "%timeit mha_pytorch_scaled(embeddings)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": null,
							 
						 
					
						
							
								
									
										
										
										
											2024-08-10 09:44:11 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "c44305ce-9f61-451a-b9ef-30caba222357",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "colab": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "base_uri": "https://localhost:8080/"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "c44305ce-9f61-451a-b9ef-30caba222357",
							 
						 
					
						
							
								
									
										
										
										
											2024-08-14 03:57:41 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "outputId": "6bab4a24-5bb4-4ad6-b260-3b442f598950"
							 
						 
					
						
							
								
									
										
										
										
											2024-08-10 09:44:11 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "99.5 ms ± 790 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
							 
						 
					
						
							
								
									
										
										
										
											2024-08-10 09:44:11 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "## 6) PyTorch's scaled dot product attention without FlashAttention\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-08-10 09:44:11 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "%timeit mha_pytorch_sdpa_no_flash(embeddings)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": null,
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "0f209e70-ebb6-4a1a-b608-1ff42e41c01d",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "colab": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "base_uri": "https://localhost:8080/"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "0f209e70-ebb6-4a1a-b608-1ff42e41c01d",
							 
						 
					
						
							
								
									
										
										
										
											2024-08-14 03:57:41 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "outputId": "630c49d1-8a06-4148-cd96-a7b2467310a0"
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "198 ms ± 3.52 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "## 7) Using PyTorch's torch.nn.MultiheadAttention\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "%timeit mha_pytorch_class_default(embeddings)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": null,
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "3f4968c2-8d40-4ab9-8dba-052b4f77d756",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
									
										
										
										
											2024-08-10 09:44:11 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "colab": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "base_uri": "https://localhost:8080/"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "id": "3f4968c2-8d40-4ab9-8dba-052b4f77d756",
							 
						 
					
						
							
								
									
										
										
										
											2024-08-14 03:57:41 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "outputId": "10f6a268-f9cf-446c-aa83-e87b6a0b4f5c"
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "168 ms ± 2.63 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "## 8) Using PyTorch's torch.nn.MultiheadAttention disabling `need_weights`\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "%timeit mha_pytorch_class_noweights(embeddings)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2024-08-10 09:44:11 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "execution_count": null,
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "bdd8e0fc-ef24-424c-bccf-c381e73da228",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "bdd8e0fc-ef24-424c-bccf-c381e73da228"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
									
										
										
										
											2024-08-10 09:44:11 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "## 9) Using PyTorch's FlexAttention\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-08-10 09:44:11 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-08-14 03:57:41 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "# Requires PyTorch 2.5.0 or newer and currently only supports CUDA PyTorch\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-08-10 09:44:11 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "%timeit mha_pytorch_flex(embeddings)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "a78ff594-6cc2-496d-a302-789fa104c3c9",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "a78ff594-6cc2-496d-a302-789fa104c3c9"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-03-23 07:27:43 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "<br>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    " \n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "## Quick speed comparison (Nvidia A100 GPU)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 17,
							 
						 
					
						
							
								
									
										
										
										
											2024-08-10 09:44:11 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "RStnI1pEi6Eo",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "RStnI1pEi6Eo"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-08-12 14:54:12 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "# Enable tensor cores\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-08-10 09:44:11 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "torch.set_float32_matmul_precision(\"high\")"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 18,
							 
						 
					
						
							
								
									
										
										
										
											2024-08-10 09:44:11 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "e8431d75-e1c9-4d9a-b7da-9a1ff391f2bf",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "colab": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "base_uri": "https://localhost:8080/"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "e8431d75-e1c9-4d9a-b7da-9a1ff391f2bf",
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "outputId": "f6356d4c-7a3f-47f5-cf51-5507db3f5748"
							 
						 
					
						
							
								
									
										
										
										
											2024-08-10 09:44:11 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "PyTorch version: 2.5.0.dev20240905+cu121\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-08-10 09:44:11 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "Running on cuda\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "torch.manual_seed(123)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(f\"PyTorch version: {torch.__version__}\")\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(f\"Running on {device}\")"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 19,
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "707a2a14-a089-48a8-88aa-d328e1e0a9d0",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "colab": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "base_uri": "https://localhost:8080/"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "707a2a14-a089-48a8-88aa-d328e1e0a9d0",
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "outputId": "4ea5798b-a590-401b-d049-3fed0716db34"
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "4.33 ms ± 51.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "## 1) CausalAttention MHA wrapper class from chapter 3\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "%timeit mha_ch03_wrapper(embeddings)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 20,
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "8686dd69-3655-40e4-a57b-a2c55532a010",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "colab": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "base_uri": "https://localhost:8080/"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "8686dd69-3655-40e4-a57b-a2c55532a010",
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "outputId": "88094b61-4d87-47bd-8c8b-c9344ab57062"
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "3.09 ms ± 363 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n"
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "## 2) The multi-head attention class from chapter 3\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "%timeit mha_ch03(embeddings)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 21,
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "2209d7df-e54b-4910-ae2b-c78cf684d9bf",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "colab": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "base_uri": "https://localhost:8080/"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "2209d7df-e54b-4910-ae2b-c78cf684d9bf",
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "outputId": "e3d82c53-f75b-425a-ed3e-5e48ea9ef768"
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "3.81 ms ± 656 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "## 3) An alternative multi-head attention with combined weights\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "%timeit mha_combined_qkv(embeddings)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 22,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "abee5edf-2585-4f0e-846c-b1c7ca88f545",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "colab": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "base_uri": "https://localhost:8080/"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "abee5edf-2585-4f0e-846c-b1c7ca88f545",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "outputId": "c9bf17f5-de62-4c39-a328-fe430812b156"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "4.12 ms ± 1.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "## 4) Multi-head attention using Einstein summation\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "%timeit mha_einsum(embeddings)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "execution_count": 23,
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "1075abe2-4839-4fd6-af3e-c09bb3651e26",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "colab": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "base_uri": "https://localhost:8080/"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "1075abe2-4839-4fd6-af3e-c09bb3651e26",
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "outputId": "b63f4769-3be5-44df-b8f2-2ac379be1ff4"
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "1.25 ms ± 1.1 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n"
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
									
										
										
										
											2024-03-09 10:09:17 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "## 5) Multi-head attention with PyTorch's scaled dot product attention\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "%timeit mha_pytorch_scaled(embeddings)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 24,
							 
						 
					
						
							
								
									
										
										
										
											2024-08-10 09:44:11 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "218adbaf-f17f-47d9-81d5-41c758218df7",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "colab": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "base_uri": "https://localhost:8080/"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "218adbaf-f17f-47d9-81d5-41c758218df7",
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "outputId": "a30ab365-865d-4175-f148-dc15abc4e07a"
							 
						 
					
						
							
								
									
										
										
										
											2024-08-10 09:44:11 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "2.03 ms ± 119 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
							 
						 
					
						
							
								
									
										
										
										
											2024-08-10 09:44:11 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "## 6) PyTorch's scaled dot product attention without FlashAttention\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-08-10 09:44:11 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "%timeit mha_pytorch_sdpa_no_flash(embeddings)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 25,
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "868e3670-8edc-47bc-9e06-eb505e44dc9d",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "colab": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "base_uri": "https://localhost:8080/"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "868e3670-8edc-47bc-9e06-eb505e44dc9d",
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "outputId": "e20e77ac-6573-4830-82c7-795bd139af4f"
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "3.05 ms ± 388 ns per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "## 7) Using PyTorch's torch.nn.MultiheadAttention\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "%timeit mha_pytorch_class_default(embeddings)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 26,
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "944870e6-de54-4e3b-a455-b8f21f6f92c8",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "colab": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "base_uri": "https://localhost:8080/"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "944870e6-de54-4e3b-a455-b8f21f6f92c8",
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "outputId": "26df6295-fa5c-4b3f-89be-c7183f079fce"
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "2.37 ms ± 6.53 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "## 8) Using PyTorch's torch.nn.MultiheadAttention disabling `need_weights`\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "%timeit mha_pytorch_class_noweights(embeddings)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2024-08-10 09:44:11 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 27,
							 
						 
					
						
							
								
									
										
										
										
											2024-08-10 09:44:11 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "evKtpb5QN_2A",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "colab": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "base_uri": "https://localhost:8080/"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "evKtpb5QN_2A",
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "outputId": "23bf5398-c8ec-4463-8af9-17de8f920a33"
							 
						 
					
						
							
								
									
										
										
										
											2024-08-10 09:44:11 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "14.6 ms ± 1.34 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
							 
						 
					
						
							
								
									
										
										
										
											2024-08-10 09:44:11 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "## 9) Using PyTorch's FlexAttention\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-08-10 09:44:11 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "# Requires PyTorch 2.5.0 or newer\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "%timeit mha_pytorch_flex(embeddings)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "dabc6575-0316-4640-a729-e616d5c17b73",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "dabc6575-0316-4640-a729-e616d5c17b73"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-03-23 07:27:43 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "<br>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    " \n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-08-14 03:57:41 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "# Visualizations"
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 35,
							 
						 
					
						
							
								
									
										
										
										
											2024-08-10 09:44:11 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "bbb2f729-d3d8-46d0-b249-9249197ea574",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "colab": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "base_uri": "https://localhost:8080/"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "bbb2f729-d3d8-46d0-b249-9249197ea574",
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "outputId": "a45fe256-6416-4f43-87d2-27bbf97239e3"
							 
						 
					
						
							
								
									
										
										
										
											2024-08-10 09:44:11 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "PyTorch version: 2.5.0.dev20240905+cu121\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-08-10 09:44:11 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "Running on cuda\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "torch.manual_seed(123)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(f\"PyTorch version: {torch.__version__}\")\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(f\"Running on {device}\")"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 36,
							 
						 
					
						
							
								
									
										
										
										
											2024-08-14 03:57:41 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "b0620bf5",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "b0620bf5"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "functions = {\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    \"1) MHA wrapper class\": mha_ch03_wrapper,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    \"2) MHA Ch03\": mha_ch03,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    \"3) MHA with combined QKV weights\": mha_combined_qkv,\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "    \"4) MHA with Einsum\": mha_einsum,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    \"5) MHA with PyTorch scaled_dot_product_attention\": mha_pytorch_scaled,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    \"6) PyTorch's SDPA, no FlashAttention\": mha_pytorch_sdpa_no_flash,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    \"7) PyTorch MHA class defaults\": mha_pytorch_class_default,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    \"8) PyTorch MHA with need_weights=False\": mha_pytorch_class_noweights\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-08-14 03:57:41 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "    }\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "if current_version >= required_version:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    functions[\"9) PyTorch's FlexAttention\"] = mha_pytorch_flex"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 37,
							 
						 
					
						
							
								
									
										
										
										
											2024-08-14 03:57:41 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "CDJAPZaszaqx",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "CDJAPZaszaqx"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "import matplotlib.pyplot as plt\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "# Customize further for dark mode aesthetics\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "plt.rcParams[\"figure.facecolor\"] = \"#121212\"\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "plt.rcParams[\"axes.facecolor\"] = \"#121212\"\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "plt.rcParams[\"axes.edgecolor\"] = \"white\"\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "plt.rcParams[\"axes.labelcolor\"] = \"white\"\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "plt.rcParams[\"text.color\"] = \"white\"\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "plt.rcParams[\"xtick.color\"] = \"white\"\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "plt.rcParams[\"ytick.color\"] = \"white\"\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "plt.rcParams[\"grid.color\"] = \"#444444\"\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "plt.rcParams[\"lines.linewidth\"] = 2\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "plt.rcParams[\"lines.markersize\"] = 8\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "def plot_execution_times(functions, execution_means, execution_stds, filename):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    # Create plot\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    fig, ax = plt.subplots()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    bars = ax.bar(functions.keys(), execution_means, yerr=execution_stds, capsize=5, error_kw={'ecolor': 'grey'})\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    plt.ylabel(\"Execution time (ms)\")\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    plt.xticks(rotation=45, ha=\"right\")\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    # Calculate new ylim with a margin\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    max_execution_time = max(execution_means)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    upper_ylim = max_execution_time + 0.4 * max_execution_time  # Adding a 40% margin\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    plt.ylim(0, upper_ylim)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    # Annotate bars with execution times\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    for bar in bars:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        yval = bar.get_height()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        plt.text(bar.get_x() + bar.get_width()/2, yval + (0.05 * upper_ylim), round(yval, 2), ha=\"center\", va=\"bottom\")\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    plt.tight_layout()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    plt.savefig(filename)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    plt.show()"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "4df834dc",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "4df834dc"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "## Speed comparison (Nvidia A100 GPU) with warmup (forward pass only)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 38,
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "29b63d3d-6d0b-43bb-9c68-d5514dc81000",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "29b63d3d-6d0b-43bb-9c68-d5514dc81000"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "# CUDA benchmark code shared by Andrei Aksionov\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "# and based on code from\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "# https://github.com/cuda-mode/lectures/blob/main/lecture1/pytorch_square.py\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-08-14 03:57:41 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "import numpy as np\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "def time_pytorch_function(func, *input, num_repeats=1_000):\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "    start = torch.cuda.Event(enable_timing=True)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    end = torch.cuda.Event(enable_timing=True)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    # Warmup\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    for _ in range(5):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        func(*input)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    torch.cuda.synchronize()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-08-14 03:57:41 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "    times = []\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "    for _ in range(num_repeats):\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-08-14 03:57:41 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        start.record()\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        func(*input)\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-08-14 03:57:41 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        end.record()\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        torch.cuda.synchronize()\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-08-14 03:57:41 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        times.append(start.elapsed_time(end))\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    return np.mean(times), np.std(times)"
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 39,
							 
						 
					
						
							
								
									
										
										
										
											2024-08-14 03:57:41 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "9dd07a09",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "colab": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "base_uri": "https://localhost:8080/",
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     "height": 488
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
									
										
										
										
											2024-08-14 03:57:41 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "id": "9dd07a09",
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "outputId": "491d06f4-a6bc-431a-a1ca-4db38df57e0c"
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "data": {
							 
						 
					
						
							
								
									
										
										
										
											2024-09-07 07:27:28 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAnIAAAHWCAYAAADzS2TwAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzddVhU6dvA8e/QKYgJdneuLXYHiGKgYq8oig22mKiIrSh2d8eu6Ora3Ws3KgYqYgc17x++nB8j5q46Z/T+XJfX7pxzZuZ+mDPn3POkxsHBQYsQQgghhDA4RvoOQAghhBBC/DuSyAkhhBBCGChJ5IQQQgghDJQkckIIIYQQBkoSOSGEEEIIAyWJnBBCCCGEgZJETgghhBDCQEkiJ4QQQghhoEz0HYChc3R05MWLF/oOQwghhBA/GRsbG+7du/fJYySR+w8cHR05e/asvsMQQgghxE8qf/78n0zmftpErnv37tStW5ccOXLw+vVrjh49ytChQ7l69apyjLm5OcOHD6d+/fqYmZmxc+dO/Pz8ePjw4Re9R0JNXP78+aVWTgghhBDfjI2NDWfPnv1sfvHTJnJlypRhzpw5nDhxAhMTEwYOHMjq1aspU6YMr169AiAgIIBq1arRtm1bnj17RmBgIAsWLKB27dpf9V4vXrzg+fPn36MYQgghhBAfpXFwcNDqO4gfIUWKFFy+fJm6dety8OBBbG1tuXz5Ml5eXmzatAmAHDlycOjQIWrUqMGxY8c++5q2traEhYWROXNmSeSEEEII8c18aY7xy4xaTZYsGQBRUVEAFC5cGDMzM3bv3q0cc+XKFW7fvk2xYsU++BpmZmbY2toq/2xsbL5/4EIIIYQQH/HTNq0mptFoCAgI4NChQ1y8eBGA1KlT8/btW549e6Zz7MOHD0mTJs0HX6d79+706dPnu8crhBBCCPElfolELigoiDx58lCnTp3/9DoTJ05k+vTpyuOEjohCCCGEEPrw0ydygYGBVK9enbp163L37l1l+4MHDzA3NydZsmQ6tXKpUqUiIiLig68VHR1NdHT0d49ZCCGEEOJL/NR95AIDA6lTpw5ubm7cunVLZ9+pU6eIjo6mQoUKyrbs2bOTIUOGLxroIIQQQgihbz9tjVxQUBDu7u54enry4sULUqdODcCzZ8948+YNz58/Z8mSJQwfPpyoqCieP3/O6NGjOXLkiCRyQgghhDAIP20i17ZtWwBlapEEPj4+LFu2DIABAwYQHx/P/PnzdSYEFkIIIYQwBL/MPHLfg8wjJ4QQQojvQeaRE0IIIYT4yUkiJ4QQQghhoCSRE0IIIYQwUJLICSGEEEIYKEnkhBBCCCEMlCRyQgghhBAGShI5IYQQQggDJYmcEEIIIYSB+mlXdhBCCCHEr8HKygpra+uvft7Lly959erVd4jox5FETgghhBAGrUCBApQsWfKrn3f48GEOHz78HSL6cSSRE0IIIYRBO3PmDNevX0+yvV69elhZWfHq1Ss2bNiQZP/Lly9/RHjfleoSuYwZM1K6dGnSp0+PlZUVjx494syZMxw9epS3b9/qOzwhhBBCqMyrV68+2EQaHx+v/Pfhw4c/OqwfQjWJXMOGDenQoQOFCxfmwYMH3L9/nzdv3pA8eXIyZ87M27dvWb16NZMmTSI8PFzf4QohhBBC6J0qErmdO3cSExPDsmXLaNWqFXfv3tXZb2ZmRvHixalfvz47duzAz8+PjRs36ilaIYQQQgh1UEUiN2zYMHbu3PnR/dHR0ezfv5/9+/cTEBBAxowZf2B0QgghhBDqpIpE7lNJ3PuioqKIior6jtEIIYQQQhgG1U0IXLBgQfLkyaM8rlWrFosWLWLgwIGYmprqMTIhhBBCCHVRXSI3fvx4smfPDkCmTJmYNWsWr169wtXVlSFDhug3OCGEEEIIFVFdIpctWzbOnDkDvJv/5eDBg3To0AEfHx9cXFz0HJ0QQgghfoTS
pUuzZMkSzp07R2RkJLVr1/7osWPHjiUyMpIOHTp89nUdHR0JCQnhypUrhIeHs3fvXgoXLqzst7a2JjAwkDNnzhAeHs6BAwdo3br1NyjR96GKPnKJaTQajIze5ZcVKlRg69atANy5cwcHBwd9hiaEEEKIH8TKyopz586xdOlSFi5c+NHj6tSpQ7Fixbh3795nX9POzo4///yTffv20aRJEx49ekTWrFl58uSJcszw4cMpV64cHTt25NatW1SqVImgoCDu379PaGjotyjaN6W6RO7UqVP06tWL3bt3U6ZMGXx9fYF3zaw/62R+QgghhNC1Y8cOduzY8cljHB0dGT16NA0bNmT58uWffc1u3bpx584dunTpomy7deuWzjElSpRg+fLl7N+/H4CFCxfSqlUrihYtqspETnVNq/3796dgwYIEBgYyfvx4bty4AYCrqytHjhzRc3RCCCGEUAONRsP06dOZMmUKly5d+qLn1KxZk1OnTjF37lwuXrzIzp07adGihc4xR44coVatWjg6OgLg7OxM9uzZv2qGjR9JdTVy58+fp1y5ckm2Dx48mLi4OD1EJIQQQgi16datG7GxscycOfOLn5MpUybatGnD9OnTmTBhAkWKFGHUqFHExMQoNXp9+/ZlwoQJnD17lpiYGOLj4+nRowcHDx78XkX5T1SXyCVmbW2t9JdL8Pz5cz1FI4QQQgg1KFSoEF5eXlSuXPmrnmdkZMSpU6cYMWIEAGfOnCFPnjy0bt1aSeTat29PsWLFaNasGbdv36ZMmTKMGTOG+/fvs3v37m9elv9KdYlcxowZCQwMpGzZslhYWCjbNRoNWq2W1KlT6zE6IYQQQuhbqVKlSJUqFadPn1a2mZiYMHz4cDp27EiRIkU++LyIiIgkzbCXL19WZsWwsLBg4MCBtGzZkr/++gt411KYP39+OnfuLInclwgJCUGj0dC1a1cePnyIVqvVd0hCCCGEUJGVK1cmSapWr17NypUrWbp06Uefd/jwYWWu2gTZsmXj9u3bAJiammJmZkZ8fLzOMXFxcUlaCNVCdYlcvnz5qFKlClevXtV3KEIIIYTQE2tra7JkyaI8zpgxI/nz5ycqKoo7d+4kWa4zJiaGiIgInfzBy8uLK1euKAMVQkJC2LJlCz169GD9+vUULVqUli1b0rNnT+Bd9619+/YxdOhQ3rx5w+3btylbtixNmjRh0KBBP6DUX091idzJkydJly6dJHJCCCHEL6xw4cJs3LhReRwQEADAsmXL8PHx+aLXSJEiBXfu3FEenzx5kpYtWzJo0CB8fX25desWAwYMYPXq1cox7du3Z9CgQcyYMQN7e3vCw8MJCAhg3rx536hk35bGwcFBVW2XmTNnZty4caxatYoLFy4QExOjs//8+fN6iiwpW1tbwsLCyJw5swzCEEIIIVSmbdu22NjY8OLFC+bOnavvcL7Kl+YYqquRS5kyJZkzZ2bKlCnKNq1WK4MdhBBCCCHeo7pEbvLkyZw5cwYvLy8ePHgggx2EEEIIIT5CdYlc+vTpad68ubKigxBCCCGE+DDVjaXdu3cv+fPn13cYQgghhBCqp7oaua1btzJixAjy5MnzwcEOX7NgbenSpfHx8aFw4cKkTZuWFi1a8Oeffyr7p06dStOmTXWes2PHDho3bvzfCiGEEEII8QOoLpEbN24cAH5+fkn2fe1gBysrK86dO8fSpUtZuHDhB4/Zvn07Xbp0UR6/ffv2KyMWQgghhNAP1SVyqVKl+mavtWPHDnbs2PHJY6Kjo3nw4ME3e08hhBBCiB9FdYncj1a2bFkuXrzI06dP2bt3LwEBAUlmi05gZmaGubm58tjGxuZHhSmEEEIIkYQqBjvUr1//i491cnKiRIkS3+R9d+zYQadOnahfvz5Dhw6lTJkyrFy58qPrqXXv3p2wsDDl39mzZ79JHEIIIYQQ/4YqErk2bdpw8OBBunTpQs6cOZPst7W1pWrVqsyYMYOdO3fi4ODwTd53
3bp1hIaGcuHCBf7880+aNm1K0aJFcXZ2/uDxEydOJHPmzMo/GV0rhBBCCH1SRdOqq6srNWvWVNY3e/XqFQ8ePODt27fY29uTOnVqIiM
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "text/plain": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "<Figure size 640x480 with 1 Axes>"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "display_data"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-08-14 03:57:41 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "execution_stats = [time_pytorch_function(fn, embeddings) for fn in functions.values()]\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "execution_means = [stat[0] for stat in execution_stats]\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "execution_stds = [stat[1] for stat in execution_stats]\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-08-10 09:44:11 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-08-14 03:57:41 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "plot_execution_times(functions, execution_means, execution_stds, filename=\"1_forward-only.pdf\")"
							 
						 
					
						
							
								
									
										
										
										
											2024-08-10 09:44:11 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "VQaSerWCOnYB",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "VQaSerWCOnYB"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "<br>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    " \n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "## Speed comparison (Nvidia A100 GPU) with warmup (forward and backward pass)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 40,
							 
						 
					
						
							
								
									
										
										
										
											2024-08-14 03:57:41 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "69e6377b",
							 
						 
					
						
							
								
									
										
										
										
											2024-08-10 09:44:11 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
									
										
										
										
											2024-08-14 03:57:41 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "id": "69e6377b"
							 
						 
					
						
							
								
									
										
										
										
											2024-08-10 09:44:11 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
									
										
										
										
											2024-08-14 03:57:41 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
									
										
										
										
											2024-08-10 09:44:11 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "def forward_backward(func, embeddings):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    if embeddings.grad is not None:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        embeddings.grad.zero_()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    output = func(embeddings)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    loss = output.sum()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    loss.backward()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "def time_pytorch_function_forward_backward(func, *input, num_repeats = 1_000):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    # CUDA IS ASYNC so can't use python time module\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    start = torch.cuda.Event(enable_timing=True)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    end = torch.cuda.Event(enable_timing=True)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    # Warmup\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    for _ in range(5):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        forward_backward(func, *input)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    torch.cuda.synchronize()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-08-14 03:57:41 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "    times = []\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-08-10 09:44:11 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "    for _ in range(num_repeats):\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-08-14 03:57:41 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        start.record()\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-08-10 09:44:11 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        forward_backward(func, *input)\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-08-14 03:57:41 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        end.record()\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-08-10 09:44:11 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        torch.cuda.synchronize()\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-08-14 03:57:41 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        times.append(start.elapsed_time(end))\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-08-10 09:44:11 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-08-14 03:57:41 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "    return np.mean(times), np.std(times)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 41,
							 
						 
					
						
							
								
									
										
										
										
											2024-08-14 03:57:41 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "ReCmeRhCOpm8",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "colab": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "base_uri": "https://localhost:8080/",
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     "height": 488
							 
						 
					
						
							
								
									
										
										
										
											2024-08-14 03:57:41 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "ReCmeRhCOpm8",
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "outputId": "2bcfa909-ba87-4d31-b926-bc66e63736cc"
							 
						 
					
						
							
								
									
										
										
										
											2024-08-14 03:57:41 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "data": {
							 
						 
					
						
							
								
									
										
										
										
											2024-09-07 07:27:28 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAnIAAAHWCAYAAADzS2TwAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzddVQU6xvA8e+SkgJigGJ77bx2d4MotthXFMUEbMVWxE7s7r52Xbv7WtiF2KKiqAjs7w8O82PFvurO6vM5xyM7OzM8LzM78+w7b2gcHBy0CCGEEEIIg2Ok7wCEEEIIIcS3kUROCCGEEMJASSInhBBCCGGgJJETQgghhDBQksgJIYQQQhgoSeSEEEIIIQyUJHJCCCGEEAZKEjkhhBBCCANlou8ADJ2TkxMvX77UdxhCCCGE+MVYW1tz7969T64jidx/4OTkxLlz5/QdhhBCCCF+Ubly5fpkMieJ3H8QXxOXK1cuqZUTQgghxHdjbW3NuXPnPptfSCL3Hbx8+ZKIiAh9hyGEEEKI34x0dhBCCCGEMFCSyAkhhBBCGChJ5IQQQgghDJQkckIIIYQQBkoSOSGEEEIIAyWJnBBCCCGEgZJETgghhBDCQEkiJ4QQQghhoCSRE0IIIYQwUJLICSGEEEIYKEnkhBBCCCEMlCRyQgghhBAGShI5IYQQQggDJYmcEEIIIYSBkkROCCGEEMJA/bKJXMuWLdm7dy83b97k5s2bbNmyhQoVKijvm5ubM3LkSK5cucKtW7eYO3cuyZMn12PEQgghhBBf55dN5MLCwhg0aBDly5enQoUK7Nu3j4ULF5I1a1YAhg4dSpUqVWjVqhVubm6kSpWKefPm6TlqIYQQQogvp3FwcNDqO4if5erVqwQEBPD3339z+fJlvLy8WL9+PQBZsmTh8OHDVKlShePHj3/R/mxsbLh58ybp06cnIiLiR4YuhBBCiN/Il+YYv2yNXEJGRkbUrl0bS0tLjh8/Tr58+TAzM2PPnj3KOleuXOHOnTsULFjwo/sxMzPDxsZG+Wdtbf0zwhdCCCGE+CATfQfwI2XPnp0tW7aQJEkSXr16RbNmzbh06RK5cuXi7du3vHjxQmf9R48ekTJlyo/ur0uXLvTo0eNHhy2EEEII8UVUl8ilTZuWYsWKkSZNGiwtLXn8+DFnz57l2LFjvH379qv2dfXqVcqWLYutrS1ubm5MnjwZNze3b45t3LhxTJ06VXltbW3NuXPnvnl/QgghhBD/hWoSubp169K2bVvy5cvHw4cPuX//Pm/evMHe3p706dPz9u1bVq5cyfjx4wkNDf2ifb57944bN24AcObMGfLnz4+Xlxdr167F3NwcW1tbnVq55MmT8+DBg4/uLyoqiqioqP9WUCGEEEKI70QVidyuXbt49+4dS5YsoXnz5oSFhem8b2ZmRqFChahduzY7d+7E39+fv//++6t/j5GREebm5pw+fZqoqCjKlCmjdHbInDkzLi4uX9zRQQghhBDqYGlpiZWV1Vdv9+rVKyIjI39ARD+PKhK5QYMGsWvXro++HxUVxYEDBzhw4ABDhw4lbdq0n91nv3792LFjB6GhoVhbW1O3bl1KlChBvXr1iIiIYNGiRQwePJjw8HAiIiIYMWIER48elUROCCGEMDC5c+emSJEiX73dkSNHOHLkyA+I6OdRRSL3qSTufeHh4YSHh392PUdHR6ZMmULKlCl58eIFFy5coF69euzevRuAPn36EBsby9y5czEzM2PXrl34+/t/axGEEEIIoSdnz57l+vXriZbXqlULS0tLIiMjWbduXaL3X7169TPC+6FUkcgllCdPHt69e8fFixcBqFatGo0bN+bSpUsEBgby7t27L9pP586dP/n+27dv6d69O927d//PMQshhBBCfyIjIz/4iDQ2Nlb5/9GjRz87rJ9CdePIjRkzhsyZMwOQLl06ZsyYQWRkJG5ubgwYMEC/wQkhhBBC
qIjqErlMmTJx9uxZIK5K9NChQ7Rt2xYfHx9cXV31HJ0QQgghfrQuXbqwY8cObt26RUhICAsWLFAqeT5k2bJlPHnyhOrVq3/x7xg1ahRPnjyhbdu2yjIXFxfGjx/PyZMnCQ0N5fjx4/To0QNTU9P/VJ4fSXWJnEajwcgoLqwyZcqwfft2AO7evYuDg4M+QxNCCCHET1C8eHFmzZpF5cqV8fDwwMTEhJUrV2JpaZlo3Xbt2qHVft1sozVq1KBgwYLcu3dPZ3mWLFkwMjKiW7dulChRgr59+9KiRQv69u37n8rzI6mujdzp06fx9fVlz549FC9eHD8/PyDuMeuv+nxbCCGEEP9Xv359ndc+Pj5cvnyZvHnzcujQIWV5rly56NChAxUqVFDa1n+Ok5MTI0aMoG7duixdulTnvX/++Yd//vlHeX3r1i0yZ85My5YtCQgI+A8l+nFUVyPXu3dv8uTJQ2BgIGPGjFEG9HVzc+Po0aN6jk4IIYQQP5utrS2AzqgVFhYWTJ8+ne7du/Pw4cMv2o9Go2Hq1KlMnDiRS5cuffHvfvbs2VfH/LOorkbuwoULlCpVKtHygIAAYmJi9BCREEIIIfRFo9EwdOhQDh8+TEhIiLJ8yJAhHD16lM2bN3/xvjp37kx0dDTTp0//ovUzZMhAmzZt6N+//1fH/bOoLpFLyMrKSmkvFy8iIkJP0QghhBDiZwsKCiJ79uzUqFFDWVa1alVKlSpFuXLlvng/efPmxcvLi/Lly3/R+k5OTixfvpx169axYMGCr477Z1FdIpc2bVoCAwMpUaIESZIkUZZrNBq0Wi0pUqTQY3RCCCGE+FkCAwOpXLkyNWvW1Jm+s1SpUmTIkCHRIMBz587l0KFD1KpVK9G+ihYtSvLkyTlz5oyyzMTEhMGDB9OuXTvy58+vLE+VKhVr167l2LFjdO3a9QeU7PtRXSIXHByMRqOhU6dOPHr06Kt7ogghhBDC8AUGBlKjRg3c3Ny4ffu2znvjx49PVEt24MAB+vbty5YtWz64v+XLl7Nnzx6dZStXrmT58uUsXrxYWebk5MTatWs5c+YMPj4+qs9DVJfI5cyZkwoVKnD16lV9hyKEEEIIPQgKCsLDwwNPT09evnypPI178eIFb9684eHDhx/s4BAaGqqT9Pn7+7Nr1y6OHz/+wSk+3717x4MHD5Scw8nJiXXr1hEaGkpAQACOjo7Kul/aoeJnU10id+rUKVKnTi2JnBBCCPGbatWqFQDr16/XWe7j48OSJUu+eD8pUqTA3Nz8i9cvW7YsmTJlIlOmTJw7d07nvWTJkn3xfn4m1SVyXbp0YfTo0Tg5OXHx4sVEc6teuHBBT5EJIYQQ4mf4lqTpQ9v4+/tjbW390W0StosDWLJkyVclimqgukTO0dGR9OnTM3HiRGWZVquVzg5CCCGEEO9RXSI3YcIEzp49i5eXFw8fPlR9I0MhhBBCCH1RXSKXJk0amjRposzoIIQQQgghPkx1U3Tt27ePXLly6TsMIYQQQgjVU12N3NatWxkyZAjZs2f/YGeHj40PI4QQQgjxu1FdIjd69GggrqfJ+6SzgxBCCCHE/6kukUuePLm+QxBCCCGEMAiqayMnhBBCCCG+jCoSudq1a3/xus7OzhQuXPgHRiOEEEIIYRhUkci1bNmSQ4cO0bFjR/74449E79vY2FCxYkWmTZvGrl27cHBw0EOUQgghhBDqooo2cm5ublStWpU2bdrQr18/IiMjefjwIW/fvsXOzo4UKVLw5MkTli5dSsmSJXn06JG+QxZCCCGE0DtVJHIQN6zIli1bcHBwoGjRoqRJkwYLCwuePHnC2bNn+ffff2WWByGEEOIXZdNm/nffpybJGeAdGiv7777/iBnNvuv+vpVqErl4T58+ZdOmTfoOQwghhBBC9VTRRk4IIYQQQnw9SeSEEEIIIQyUJHJCCCGEEAZKEjkhhBBCCAOl2kTO1NSUzJkzY2xsrO9QhBBCCCFUSXWJnIWF
BePHjyc0NJQDBw6QJk0aAEaMGEHnzp31HJ0QQgghhHqoLpHr168fuXLlws3NjTdv3ijL9+zZg7u7u/4CE0IIIYRQGdWNI1e9enX++us
							 
						 
					
						
							
								
									
										
										
										
											2024-08-14 03:57:41 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "text/plain": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "<Figure size 640x480 with 1 Axes>"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "display_data"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "execution_stats = [time_pytorch_function_forward_backward(fn, embeddings) for fn in functions.values()]\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "execution_means = [stat[0] for stat in execution_stats]\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "execution_stds = [stat[1] for stat in execution_stats]\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-08-10 09:44:11 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-08-14 03:57:41 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "plot_execution_times(functions, execution_means, execution_stds, filename=\"2_forward-and-backward.pdf\")"
							 
						 
					
						
							
								
									
										
										
										
											2024-08-10 09:44:11 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "1gWX-Ayqia1k",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "1gWX-Ayqia1k"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "<br>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    " \n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "## Speed comparison (Nvidia A100 GPU) with warmup and compilation (forward and backward pass)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 42,
							 
						 
					
						
							
								
									
										
										
										
											2024-08-10 09:44:11 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "LQDiAPooiYAz",
							 
						 
					
						
							
								
									
										
										
										
											2024-08-14 03:57:41 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "LQDiAPooiYAz"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "import torch._dynamo\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "torch._dynamo.config.suppress_errors = True\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "def prepare_function(fn):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    fn = torch.compile(fn)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    return fn"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 43,
							 
						 
					
						
							
								
									
										
										
										
											2024-08-14 03:57:41 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "aac06ffe",
							 
						 
					
						
							
								
									
										
										
										
											2024-08-10 09:44:11 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "colab": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "base_uri": "https://localhost:8080/",
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     "height": 489
							 
						 
					
						
							
								
									
										
										
										
											2024-08-10 09:44:11 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
									
										
										
										
											2024-08-14 03:57:41 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "id": "aac06ffe",
							 
						 
					
						
							
								
									
										
										
										
											2024-09-05 18:24:33 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "outputId": "098c66b4-1201-4bdd-af23-e634f5ade806"
							 
						 
					
						
							
								
									
										
										
										
											2024-08-10 09:44:11 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "data": {
							 
						 
					
						
							
								
									
										
										
										
											2024-09-07 07:27:28 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAnIAAAHYCAYAAADJQQWAAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzddVhU6dvA8e/QKYIJirrG2t0NdqEiBiq2oihiYaJiF9gBdnestXbr2ootJjYWYmJQ7x9enJcRXXV/6JzR+3NdXrtzau6HOXPOPc95QmNnZxePEEIIIYTQOwa6DkAIIYQQQvw3ksgJIYQQQugpSeSEEEIIIfSUJHJCCCGEEHpKEjkhhBBCCD0liZwQQgghhJ6SRE4IIYQQQk9JIieEEEIIoackkRNCCCGE0FOSyAkhhBBC6KlfOpErXbo0S5cu5eLFi0RERFCrVq0k2/Tr14+LFy9y79491q1bR9asWXUQqRBCCCHE9zPSdQA/koWFBRcvXmTZsmUsWrQoyXofHx88PT3p0qULt2/fZsCAAaxevZoyZcrw/v37b3oPe3t7Xr9+ndyhCyGEEOI3Z2VlRXh4+L9uo7Gzs4v/SfHoVEREBC1atGDLli3KsosXLzJjxgymT58OgLW1NaGhoXh7e/PXX3999Zj29vZcuHDhh8UshBBCiN9bvnz5/jWZ+6Vr5P5N5syZSZ8+Pfv371eWvXr1ilOnTlG8ePHPJnImJiaYmpomWZ4vXz6plRNCCCFEsrGysuLChQtfzS9+20Qubdq0ADx58kRr+ZMnT5R1n+revTt9+/ZNsvz169e8evUq+YMUQgghhPgXv3Rnh+Q2adIksmTJovzLly+frkMSQgghxG/st03kHj9+DECaNGm0lqdJk0ZZ96kPHz7w6tUr5Z88ThVCCCGELv22idzt27d5+PAhFSpUUJZZW1tTtGhRTpw4ocPIhBBCCCG+zS/dRs7S0pI//vhDeZ0pUyby5ctHZGQk9+/fZ+bMmfTq1YubN28qw488fPhQq2erEEIIIYRa/dKJXKFChdi4caPyeuTIkQAsX74cb29vpkyZgoWFBRMmTMDGxoZjx47RuHHjbx5DTgghhBBCl36bceR+BGtra27dukWWLFmk16oQQgghks235hi/bRs5IYQQQgh9J4mcEEIIIYSekkROCCGEEEJPSSInhBBCCKGnJJETQgghhNBTksgJIYQQQugpSeSEEEIIIfSUJHJCCCGEEHpKEjkhhBBCCD0liZwQQgghhJ6SRE4IIYQQQk9JIieEEEIIoackkRNCCCGE0FNGug7gU5kyZaJ06dJkzJgRCwsLnj59yvnz5zlx4gTv37/XdXhCCCGEEKqhmkSuYcOGdOzYkUKFCvH48WMePnzIu3fvsLW1JUuWLLx//541a9YwefJk7t27p+twhRBCCCF0ThWJ3N69e4mOjmb58uW0atWKBw8eaK03MTGhePHiuLq6snv3bnr37s3GjRt1FK0QQgghhDpo7Ozs4nUdhLOzM3v37v2mbW1tbcmUKRNnz579wVF9nbW1Nbdu3SJLliy8evVK1+EIIYQQ4hfxrTmGamrkvlVkZCSRkZE/MBohhBBCCP2gul6rBQoUIHfu3MrrmjVrsnjxYgYOHIixsbEOIxNCCCGEUBfVJXITJkwge/bsAGTOnJnZs2cTFRVF3bp1GTJkiG6DE0IIIYRQEdUlctmyZeP8+fMA1KtXjyNHjtCxY0e8vb1xcXHRcXRCCCGEEOqhukROo9FgYPAxrIoVK7Jz504A7t+/j52dnS5DU62QkBAiIiKS/Bs3btxnt2/RogWbN2/mxo0b3Lhxg3Xr1lGkSBGtbfr06cPRo0e5c+eOsk3RokV/RnGEEEII8Y1Ul8idOXOGXr160bhxY8qUKaMkcpkzZ+bJkyc6jk6dqlSpQu7cuZV/DRo0AGDDhg2f3b5s2bKsW7eOevXqUaNGDe7fv8+aNWuw
t7dXtrlx4wZ9+/alfPny1KpVizt37rBmzRpSpUr1U8okhBBCiK9TxfAjieXJk4eZM2eSMWNGZsyYQUBAAABjxozB1taWjh076jjC/6fW4UdGjhxJtWrVKF68+Ddtb2BgwM2bN+nbty8rV6787DYJZXV1deXAgQPJGa4QQgghPvGtOYbqauQuXbpE+fLl+eOPP5QkDsDf358uXbroMDL9YGxsTKNGjVi2bNk372NhYYGRkdEXh3UxNjamZcuWvHjxggsXLiRXqEIIIZLR9zazAahbty5Hjx7l/v37HDx4kCpVqnxx28DAQCIiIlRVoSJUmMglZmlpibW1NdbW1piYmGBubp6sxzcwMKB///6cPn2ae/fucfLkSXr16pWs7/Gz1apVCxsbG5YvX/7N+/j7+/Pw4UP279+vtbxatWrcvn2bBw8e4OXlhZubG8+ePUvukIUQQiSD721mU7x4cWbPns2SJUtwdnZmy5YtLF68mFy5ciXZtnbt2hQrVozw8PAfWgbx/VQxIHBimTJlYuzYsZQtWxYzMzNluUajIT4+nrRp0ybbe3Xr1o02bdrQpUsXQkNDKVSoENOmTePVq1fMmjUr2d7nZ/Lw8GDXrl08fPjwm7bv1q0brq6u1K1bl/fv32utO3ToEE5OTqRKlYoWLVowd+5cqlWrxtOnT39E6EIIIf4HERERWq+7devGzZs3+eeffz67fceOHdm9ezfTpk0DYPTo0Tg5OdG+fXt8fX2V7ezt7RkzZgwNGzZkxYoVP64A4j9RXSIXHByMRqPBx8eHJ0+eEB//45rwFS9enK1btyodKu7evYubm1uSHpz6ImPGjFSsWJFWrVopyywsLLC0tPzs9m3btqVTp060bduWJ0+ekCZNGmXdmzdviIqKIiwsjLCwME6ePMnx48fx8PBg0qRJP7ooQggh/gcJzWyCgoK+uE3x4sWZMWOG1rI9e/ZQq1Yt5bVGoyEoKIipU6dy5cqVHxav+O9Ul8jlzZuXypUrc/369R/+XidOnKBly5Zky5aNGzdukDdvXkqWLMmgQYM+u72JiQmmpqbKaysrqx8e4/do1qwZT548YceOHcqy/PnzU7JkySTblilThvLly7NkyRLy589P/vz5tdYfO3aMY8eOaS0zMDDAxMTkxwQvhBAi2XxLM5u0adMmGQ3iyZMnWk++unXrRkxMjN4+pfodqC6RCwkJIUOGDD8lkZs0aRLW1tYcPXqU2NhYDA0NGTlyJGvWrPns9t27d6dv374/PK7/QqPR0KxZM1auXElsbKyy/Pz583h6evL48WMmTJgAQPv27XF2dmbdunWEh4crNZJRUVFERUVhbm5Ou3btiI2N5eHDh6RKlYp27dphb2//xbYWQggh1ON7m9l8TsGCBfH09KRSpUrJGJlIbqpL5Lp378748eOxt7fn8uXLREdHa62/dOlSsr1X/fr1adiwIZ6enoSGhpI/f35GjhzJw4cPP9sOYNKkSVrV1FZWVqrpxVmxYkUcHR1ZunSp1vKoqChSp07N27dvlV9ejRs3xsjIiMaNGwMoNZBjx45l3LhxmJqakiVLFhYsWICdnR2RkZGEhIRQp04dqVoXQgiV+1wzm895/PixVpMagDRp0vD48WMASpUqRZo0aTh79qyy3sjIiOHDh9OpUycKFy6c/MGL76a6ceSKFSvGzJkzyZQpk7IsPj7+h3R2OHfuHJMnT2bu3LnKsl69etGoUSNKlSr11f3VOo7ct2jbti1WVla8fv2aefPm6TocIYQQyaRPnz60atWKAgUKaD2h+dScOXMwNzenefPmyrKtW7dy8eJFfH19sbW1JV26dFr7rFmzhlWrVrFs2bKf8uTsd/atOYbqauSmTJmi9TjwR3Z2MDc3Jy4uTmtZbGwsGo3mh72nEEII8aN8qZkNwIwZMwgPD2f48OEAzJw5k02bNtG5c2d27tyJq6srhQoVokePHgBERkYmGV80OjqaR48eSRKnIqpL5DJmzEjz
5s0JCwv74e+1fft2evbsyb179wgNDaVAgQJ4eXl912C6QgghhFp8qZkNQIYMGbQqL06cOIGnpyd+fn4MHDiQmzdv0qJFC0JDQ39myOJ
							 
						 
					
						
							
								
									
										
										
										
											2024-08-10 09:44:11 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "text/plain": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "<Figure size 640x480 with 1 Axes>"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "display_data"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-08-14 03:57:41 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "execution_stats = [time_pytorch_function_forward_backward(prepare_function(fn), embeddings) for fn in functions.values()]\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "execution_means = [stat[0] for stat in execution_stats]\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "execution_stds = [stat[1] for stat in execution_stats]\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-08-14 03:57:41 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "plot_execution_times(functions, execution_means, execution_stds, filename=\"3_forward-and-backward-compiled.pdf\")"
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  "accelerator": "GPU",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  "colab": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "gpuType": "A100",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "provenance": []
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  "kernelspec": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "display_name": "Python 3 (ipykernel)",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "language": "python",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "name": "python3"
							 
						 
					
						
							
								
									
										
										
										
											2024-03-06 08:30:32 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  "language_info": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "codemirror_mode": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "name": "ipython",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "version": 3
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "file_extension": ".py",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "mimetype": "text/x-python",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "name": "python",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "nbconvert_exporter": "python",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "pygments_lexer": "ipython3",
							 
						 
					
						
							
								
									
										
										
										
											2024-10-23 03:23:31 +02:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "version": "3.11.4"
							 
						 
					
						
							
								
									
										
										
										
											2024-03-13 08:37:54 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 "nbformat": 4,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 "nbformat_minor": 5
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								}