2024-07-26 08:45:36 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								{
							 
						 
					
						
							
								
									
										
										
										
											2024-07-28 14:15:32 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  "cells": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "id": "Dlv8N4uWtXcN"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "<table style=\"width:100%\">\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "<tr>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "<td style=\"vertical-align:middle; text-align:left;\">\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "<font size=\"2\">\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "Supplementary code for the <a href=\"http://mng.bz/orYv\">Build a Large Language Model From Scratch</a> book by <a href=\"https://sebastianraschka.com\">Sebastian Raschka</a><br>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "<br>Code repository: <a href=\"https://github.com/rasbt/LLMs-from-scratch\">https://github.com/rasbt/LLMs-from-scratch</a>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "</font>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "</td>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "<td style=\"vertical-align:middle; text-align:left;\">\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "<a href=\"http://mng.bz/orYv\"><img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/cover-small.webp\" width=\"100px\"></a>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "</td>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "</tr>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "</table>"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ]
							 
						 
					
						
							
								
									
										
										
										
											2024-07-26 08:45:36 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
									
										
										
										
											2024-07-28 14:15:32 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "id": "V6BXGeEJ_s-8"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "# Understanding PyTorch Buffers"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "id": "aQt9Ob1Y_8EH"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "In essence, PyTorch buffers are tensor attributes associated with a PyTorch module or model. They are similar to parameters, but unlike parameters, buffers are not updated during training.\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "Buffers in PyTorch are particularly useful when dealing with GPU computations, because they are transferred between devices (like from CPU to GPU) alongside the model's parameters. Unlike parameters, buffers do not require gradient computation, but they still need to be on the correct device to ensure that all computations are performed correctly.\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "In chapter 3, we use PyTorch buffers via `self.register_buffer`, which is only briefly explained in the book. Since the concept and purpose are not immediately clear, this code notebook offers a longer explanation with a hands-on example."
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "id": "dAwGo_gYLY45"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "## An example without buffers"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "id": "0qBQC9IPAJVZ"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "Suppose we have the following code, which is based on code from chapter 3. This version has been modified to exclude buffers. It implements the causal self-attention mechanism used in LLMs:"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "cell_type": "code",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "execution_count": 1,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "id": "7wx-_rokAN04"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "outputs": [],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "import torch\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "import torch.nn as nn\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "class CausalAttentionWithoutBuffers(nn.Module):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "    def __init__(self, d_in, d_out, context_length,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "                 dropout, qkv_bias=False):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "        super().__init__()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "        self.d_out = d_out\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "        self.dropout = nn.Dropout(dropout)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "        self.mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "    def forward(self, x):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "        b, num_tokens, d_in = x.shape\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "        keys = self.W_key(x)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "        queries = self.W_query(x)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "        values = self.W_value(x)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "        attn_scores = queries @ keys.transpose(1, 2)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "        attn_scores.masked_fill_(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "        attn_weights = torch.softmax(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "            attn_scores / keys.shape[-1]**0.5, dim=-1\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "        )\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "        attn_weights = self.dropout(attn_weights)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "        context_vec = attn_weights @ values\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "        return context_vec"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "id": "nNrK-wLaNSi7"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "We can initialize and run the module as follows on some example data:"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "cell_type": "code",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "execution_count": 2,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "colab": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          "base_uri": "https://localhost:8080/"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "id": "e1MZiIsPA0Py",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "outputId": "ce1407c6-c082-4755-b8ad-d9adcc9f153a"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          "text": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            "tensor([[[-0.4519,  0.2216],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            "         [-0.5874,  0.0058],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            "         [-0.6300, -0.0632],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            "         [-0.5675, -0.0843],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            "         [-0.5526, -0.0981],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            "         [-0.5299, -0.1081]],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            "        [[-0.4519,  0.2216],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            "         [-0.5874,  0.0058],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            "         [-0.6300, -0.0632],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            "         [-0.5675, -0.0843],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            "         [-0.5526, -0.0981],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            "         [-0.5299, -0.1081]]])\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "torch.manual_seed(123)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "inputs = torch.tensor(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "  [[0.43, 0.15, 0.89], # Your     (x^1)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "   [0.55, 0.87, 0.66], # journey  (x^2)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "   [0.57, 0.85, 0.64], # starts   (x^3)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "   [0.22, 0.58, 0.33], # with     (x^4)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "   [0.77, 0.25, 0.10], # one      (x^5)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "   [0.05, 0.80, 0.55]] # step     (x^6)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        ")\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "batch = torch.stack((inputs, inputs), dim=0)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "context_length = batch.shape[1]\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "d_in = inputs.shape[1]\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "d_out = 2\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "ca_without_buffer = CausalAttentionWithoutBuffers(d_in, d_out, context_length, 0.0)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "with torch.no_grad():\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "    context_vecs = ca_without_buffer(batch)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "print(context_vecs)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "id": "7_hqz6AgCCc1"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "So far, everything has worked fine.\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "However, when training LLMs, we typically use GPUs to accelerate the process. Therefore, let's transfer the `CausalAttentionWithoutBuffers` module onto a GPU device.\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "Please note that this operation requires the code to be run in an environment equipped with GPUs."
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "cell_type": "code",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "execution_count": 3,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "colab": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          "base_uri": "https://localhost:8080/"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "id": "PYwn44HWCPJS",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "outputId": "d7236e0c-2a43-4770-ccc1-03c9d5d11421"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          "text": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            "Machine has GPU: True\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "print(\"Machine has GPU:\", torch.cuda.is_available())\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "batch = batch.to(\"cuda\")\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "ca_without_buffer.to(\"cuda\");"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "id": "4_lMki2_CoIR"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "Now, let's run the code again:"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "cell_type": "code",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "execution_count": 4,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "colab": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          "base_uri": "https://localhost:8080/",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          "height": 338
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "id": "KE9iLcjGC1V1",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "outputId": "ab6921c7-d7dd-44ea-9b92-1911037e3dcc"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          "output_type": "error",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          "ename": "RuntimeError",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          "evalue": "expected self and mask to be on the same device, but got mask on cpu and self on cuda:0",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          "traceback": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            "\u001b[0;32m<ipython-input-4-1e0d2e6638f6>\u001b[0m in \u001b[0;36m<cell line: 1>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mno_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m     \u001b[0mcontext_vecs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mca_without_buffer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      3\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      4\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcontext_vecs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1530\u001b[0m             \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compiled_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m  \u001b[0;31m# type: ignore[misc]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1531\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1532\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1533\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1534\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            "\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1539\u001b[0m                 \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1540\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1541\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1542\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1543\u001b[0m         \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            "\u001b[0;32m<ipython-input-1-cf1dad0dd611>\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m     21\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     22\u001b[0m         \u001b[0mattn_scores\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mqueries\u001b[0m \u001b[0;34m@\u001b[0m \u001b[0mkeys\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtranspose\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 23\u001b[0;31m         attn_scores.masked_fill_(\n\u001b[0m\u001b[1;32m     24\u001b[0m             self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)\n\u001b[1;32m     25\u001b[0m         attn_weights = torch.softmax(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            "\u001b[0;31mRuntimeError\u001b[0m: expected self and mask to be on the same device, but got mask on cpu and self on cuda:0"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "with torch.no_grad():\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "    context_vecs = ca_without_buffer(batch)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "print(context_vecs)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "id": "I7V26PLrC2gk"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "Running the code resulted in an error. What happened? It seems like we attempted a matrix multiplication between a tensor on a GPU and a tensor on a CPU. But we moved the module to the GPU!?\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "Let's double-check the device locations of some of the tensors:"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "cell_type": "code",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "execution_count": 5,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "colab": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          "base_uri": "https://localhost:8080/"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "id": "vvYDPBRIDHfU",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "outputId": "4b9703a8-7035-4a2d-8643-c64d37b7abd2"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          "text": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            "W_query.device: cuda:0\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            "mask.device: cpu\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "print(\"W_query.device:\", ca_without_buffer.W_query.weight.device)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "print(\"mask.device:\", ca_without_buffer.mask.device)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "cell_type": "code",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "execution_count": 6,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "colab": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          "base_uri": "https://localhost:8080/"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "id": "d11nX-FFOJ3C",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "outputId": "1e92b0e8-dbc6-41f9-e88f-5d06e0726050"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          "output_type": "execute_result",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          "data": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            "text/plain": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								              "torch.Tensor"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          "execution_count": 6
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "type(ca_without_buffer.mask)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "id": "Ojay-KY-DL5M"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "As we can see, the `mask` was not moved onto the GPU. That's because it's not a PyTorch parameter like the weights (e.g., `W_query.weight`).\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "This means we  have to manually move it to the GPU via `.to(\"cuda\")`:"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "cell_type": "code",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "execution_count": 7,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "colab": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          "base_uri": "https://localhost:8080/"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "id": "QYirQ63zDYsW",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "outputId": "304628ac-bc4c-49c2-a0e1-ecf9385ddcd9"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          "text": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            "mask.device: cuda:0\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "ca_without_buffer.mask = ca_without_buffer.mask.to(\"cuda\")\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "print(\"mask.device:\", ca_without_buffer.mask.device)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "id": "4OoTqzkpDfAm"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "Let's try our code again:"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "cell_type": "code",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "execution_count": 8,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "colab": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          "base_uri": "https://localhost:8080/"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "id": "WfF0yBZODdAZ",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "outputId": "291cfb54-86e6-45f9-99d1-fa145319f379"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          "text": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            "tensor([[[-0.4519,  0.2216],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            "         [-0.5874,  0.0058],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            "         [-0.6300, -0.0632],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            "         [-0.5675, -0.0843],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            "         [-0.5526, -0.0981],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            "         [-0.5299, -0.1081]],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            "        [[-0.4519,  0.2216],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            "         [-0.5874,  0.0058],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            "         [-0.6300, -0.0632],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            "         [-0.5675, -0.0843],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            "         [-0.5526, -0.0981],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            "         [-0.5299, -0.1081]]], device='cuda:0')\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "with torch.no_grad():\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "    context_vecs = ca_without_buffer(batch)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "print(context_vecs)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "id": "oUrVgWuuD7UE"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "This time, it worked!\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "However, remembering to move individual tensors to the GPU can be tedious. As we will see in the next section, it's easier to use `register_buffer` to register the `mask` as a buffer."
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "id": "StS2wUrBLeuW"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "## An example with buffers"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "id": "nEqD2NFzPO6l"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "Let's now modify the causal attention class to register the causal `mask` as a buffer:"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "cell_type": "code",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "execution_count": 9,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "id": "ndsYj3Zf6N8U"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "outputs": [],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "import torch\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "import torch.nn as nn\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "class CausalAttentionWithBuffer(nn.Module):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "    def __init__(self, d_in, d_out, context_length,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "                 dropout, qkv_bias=False):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "        super().__init__()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "        self.d_out = d_out\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "        self.dropout = nn.Dropout(dropout)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "        # Old:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "        # self.mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "        # New:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "        self.register_buffer(\"mask\", torch.triu(torch.ones(context_length, context_length), diagonal=1))\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "    def forward(self, x):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "        b, num_tokens, d_in = x.shape\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "        keys = self.W_key(x)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "        queries = self.W_query(x)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "        values = self.W_value(x)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "        attn_scores = queries @ keys.transpose(1, 2)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "        attn_scores.masked_fill_(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "        attn_weights = torch.softmax(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "            attn_scores / keys.shape[-1]**0.5, dim=-1\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "        )\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "        attn_weights = self.dropout(attn_weights)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "        context_vec = attn_weights @ values\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "        return context_vec"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "id": "_AL1X6y3Eb7S"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "Now, conveniently, if we move the module to the GPU, the mask will be located on the GPU as well:"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "cell_type": "code",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "execution_count": 10,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "colab": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          "base_uri": "https://localhost:8080/"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "id": "8_VCxEa76j00",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "outputId": "4d1af501-5a9e-46aa-b1ac-63bf0c68e02a"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          "text": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            "W_query.device: cuda:0\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            "mask.device: cuda:0\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "ca_with_buffer = CausalAttentionWithBuffer(d_in, d_out, context_length, 0.0)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "ca_with_buffer.to(\"cuda\")\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "print(\"W_query.device:\", ca_with_buffer.W_query.weight.device)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "print(\"mask.device:\", ca_with_buffer.mask.device)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "cell_type": "code",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "execution_count": 11,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "colab": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          "base_uri": "https://localhost:8080/"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "id": "TBWvKlMe7bbB",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "outputId": "e43bf8ab-3fb9-417e-d087-560858332d86"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          "text": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            "tensor([[[0.4772, 0.1063],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            "         [0.5891, 0.3257],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            "         [0.6202, 0.3860],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            "         [0.5478, 0.3589],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            "         [0.5321, 0.3428],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            "         [0.5077, 0.3493]],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            "        [[0.4772, 0.1063],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            "         [0.5891, 0.3257],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            "         [0.6202, 0.3860],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            "         [0.5478, 0.3589],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            "         [0.5321, 0.3428],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            "         [0.5077, 0.3493]]], device='cuda:0')\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "with torch.no_grad():\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "    context_vecs = ca_with_buffer(batch)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "print(context_vecs)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "id": "xvOTh4NNPjef"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "As we can see above, registering a tensor as a buffer can make our lives a lot easier: we don't have to remember to manually move tensors to a target device such as a GPU."
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "## Buffers and `state_dict`"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "id": "Q-5YYKmJte3h"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "- Another advantage of PyTorch buffers over regular tensors is that they are included in a model's `state_dict` (unless they are registered with `persistent=False`)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "- For example, consider the `state_dict` of the causal attention object without buffers:"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "id": "YIHHawPbtjfp"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "cell_type": "code",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "ca_without_buffer.state_dict()"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "colab": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          "base_uri": "https://localhost:8080/"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "id": "c217juzqtxsS",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "outputId": "dbae3c3d-f4f8-4c70-a64f-90906561d8d9"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "execution_count": 12,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          "output_type": "execute_result",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          "data": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            "text/plain": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								              "OrderedDict([('W_query.weight',\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								              "              tensor([[-0.2354,  0.0191, -0.2867],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								              "                      [ 0.2177, -0.4919,  0.4232]], device='cuda:0')),\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								              "             ('W_key.weight',\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								              "              tensor([[-0.4196, -0.4590, -0.3648],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								              "                      [ 0.2615, -0.2133,  0.2161]], device='cuda:0')),\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								              "             ('W_value.weight',\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								              "              tensor([[-0.4900, -0.3503, -0.2120],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								              "                      [-0.1135, -0.4404,  0.3780]], device='cuda:0'))])"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          "execution_count": 12
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "- The mask is not included in the `state_dict` above\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "- However, the mask *is* included in the `state_dict` below, thanks to registering it as a buffer"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "id": "NdmZuPaqt6aO"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "cell_type": "code",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "ca_with_buffer.state_dict()"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "colab": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          "base_uri": "https://localhost:8080/"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "id": "uGIGQAwPt1Pl",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "outputId": "00f9bc44-63f9-4ebc-87ea-d4b8cafd81c1"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "execution_count": 13,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          "output_type": "execute_result",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          "data": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            "text/plain": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								              "OrderedDict([('mask',\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								              "              tensor([[0., 1., 1., 1., 1., 1.],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								              "                      [0., 0., 1., 1., 1., 1.],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								              "                      [0., 0., 0., 1., 1., 1.],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								              "                      [0., 0., 0., 0., 1., 1.],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								              "                      [0., 0., 0., 0., 0., 1.],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								              "                      [0., 0., 0., 0., 0., 0.]], device='cuda:0')),\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								              "             ('W_query.weight',\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								              "              tensor([[-0.1362,  0.1853,  0.4083],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								              "                      [ 0.1076,  0.1579,  0.5573]], device='cuda:0')),\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								              "             ('W_key.weight',\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								              "              tensor([[-0.2604,  0.1829, -0.2569],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								              "                      [ 0.4126,  0.4611, -0.5323]], device='cuda:0')),\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								              "             ('W_value.weight',\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								              "              tensor([[ 0.4929,  0.2757,  0.2516],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								              "                      [ 0.2377,  0.4800, -0.0762]], device='cuda:0'))])"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          "execution_count": 13
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "- A `state_dict` is useful when saving and loading trained PyTorch models, for example\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "- In this particular case, saving and loading the `mask` is not especially useful, because it remains unchanged during training; so, for demonstration purposes, let's assume it was modified such that all `1`s were changed to `2`s:"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "id": "ACC-a1Hnt4Zv"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "cell_type": "code",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "ca_with_buffer.mask[ca_with_buffer.mask == 1.] = 2.\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "ca_with_buffer.mask"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "colab": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          "base_uri": "https://localhost:8080/"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "id": "RLm1Sw0cuhvy",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "outputId": "4b2cc70f-1709-44e4-aa17-4e01353b86f8"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "execution_count": 14,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          "output_type": "execute_result",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          "data": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            "text/plain": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								              "tensor([[0., 2., 2., 2., 2., 2.],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								              "        [0., 0., 2., 2., 2., 2.],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								              "        [0., 0., 0., 2., 2., 2.],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								              "        [0., 0., 0., 0., 2., 2.],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								              "        [0., 0., 0., 0., 0., 2.],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								              "        [0., 0., 0., 0., 0., 0.]], device='cuda:0')"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          "execution_count": 14
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "- Then, if we save the `state_dict` and load it into a new model instance, we can see that the mask is restored with the modified values"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "id": "BIkGgGqqvp4S"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "cell_type": "code",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "torch.save(ca_with_buffer.state_dict(), \"model.pth\")\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "new_ca_with_buffer = CausalAttentionWithBuffer(d_in, d_out, context_length, 0.0)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "new_ca_with_buffer.load_state_dict(torch.load(\"model.pth\"))\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "new_ca_with_buffer.mask"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "colab": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          "base_uri": "https://localhost:8080/"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "id": "e8g0QHUhuVBw",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "outputId": "cc7ee348-7f94-4117-e5cc-e0e01a94e906"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "execution_count": 15,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          "output_type": "execute_result",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          "data": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            "text/plain": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								              "tensor([[0., 2., 2., 2., 2., 2.],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								              "        [0., 0., 2., 2., 2., 2.],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								              "        [0., 0., 0., 2., 2., 2.],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								              "        [0., 0., 0., 0., 2., 2.],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								              "        [0., 0., 0., 0., 0., 2.],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								              "        [0., 0., 0., 0., 0., 0.]])"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          "execution_count": 15
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
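								        "- This is not the case if we don't use buffers: the mask is not part of the `state_dict`, so the modification is not saved, and the newly created model keeps its freshly initialized mask:"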
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "id": "0pPaJk7bvBD7"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "cell_type": "code",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "ca_without_buffer.mask[ca_without_buffer.mask == 1.] = 2.\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "torch.save(ca_without_buffer.state_dict(), \"model.pth\")\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "new_ca_without_buffer = CausalAttentionWithoutBuffers(d_in, d_out, context_length, 0.0)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "new_ca_without_buffer.load_state_dict(torch.load(\"model.pth\"))\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "new_ca_without_buffer.mask"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "colab": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          "base_uri": "https://localhost:8080/"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "id": "D03w8vDyvBRS",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "outputId": "28071601-120c-42da-b327-bb293793839f"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "execution_count": 16,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          "output_type": "execute_result",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          "data": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            "text/plain": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								              "tensor([[0., 1., 1., 1., 1., 1.],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								              "        [0., 0., 1., 1., 1., 1.],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								              "        [0., 0., 0., 1., 1., 1.],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								              "        [0., 0., 0., 0., 1., 1.],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								              "        [0., 0., 0., 0., 0., 1.],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								              "        [0., 0., 0., 0., 0., 0.]])"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								            ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								          "execution_count": 16
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        }
							 
						 
					
						
							
								
									
										
										
										
											2024-07-26 08:45:36 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								      ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
									
										
										
										
											2024-07-28 14:15:32 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "accelerator": "GPU",
							 
						 
					
						
							
								
									
										
										
										
											2024-07-26 08:45:36 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "colab": {
							 
						 
					
						
							
								
									
										
										
										
											2024-07-28 14:15:32 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "gpuType": "L4",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "provenance": []
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "kernelspec": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "display_name": "Python 3 (ipykernel)",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "language": "python",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "name": "python3"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "language_info": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "codemirror_mode": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "name": "ipython",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								        "version": 3
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "file_extension": ".py",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "mimetype": "text/x-python",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "name": "python",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "nbconvert_exporter": "python",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "pygments_lexer": "ipython3",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "version": "3.10.6"
							 
						 
					
						
							
								
									
										
										
										
											2024-07-26 08:45:36 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2024-07-28 14:15:32 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  "nbformat": 4,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  "nbformat_minor": 0
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								}