LLMs-from-scratch/ch03/03_understanding-buffers/understanding-buffers.ipynb
2024-07-28 14:15:32 -05:00


{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "Dlv8N4uWtXcN"
},
"source": [
"<table style=\"width:100%\">\n",
"<tr>\n",
"<td style=\"vertical-align:middle; text-align:left;\">\n",
"<font size=\"2\">\n",
"Supplementary code for the <a href=\"http://mng.bz/orYv\">Build a Large Language Model From Scratch</a> book by <a href=\"https://sebastianraschka.com\">Sebastian Raschka</a><br>\n",
"<br>Code repository: <a href=\"https://github.com/rasbt/LLMs-from-scratch\">https://github.com/rasbt/LLMs-from-scratch</a>\n",
"</font>\n",
"</td>\n",
"<td style=\"vertical-align:middle; text-align:left;\">\n",
"<a href=\"http://mng.bz/orYv\"><img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/cover-small.webp\" width=\"100px\"></a>\n",
"</td>\n",
"</tr>\n",
"</table>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "V6BXGeEJ_s-8"
},
"source": [
"# Understanding PyTorch Buffers"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "aQt9Ob1Y_8EH"
},
"source": [
"In essence, PyTorch buffers are tensor attributes associated with a PyTorch module or model. Like parameters, they are part of the module's state, but unlike parameters, buffers are not updated by the optimizer during training.\n",
"\n",
"Buffers are particularly useful for GPU computations: like the model's parameters, they are transferred between devices (for example, from CPU to GPU) when the module is moved. Unlike parameters, buffers do not require gradient computation, but they still need to be on the same device as the other tensors so that all computations run correctly.\n",
"\n",
"In chapter 3, we use PyTorch buffers via `self.register_buffer`, which is only briefly explained in the book. Since the concept and purpose are not immediately clear, this code notebook offers a longer explanation with a hands-on example."
]
},
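{
"cell_type": "markdown",
"metadata": {},
"source": [
"The distinction between parameters and buffers can be inspected directly. The following is a minimal sketch; the `Toy` module and its `scale` buffer are hypothetical and only serve as an illustration. `named_parameters()` yields only the `nn.Linear` weight, which requires gradients and is what an optimizer receives via `model.parameters()`, whereas `named_buffers()` yields the registered tensor, which does not require gradients:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"\n",
"class Toy(nn.Module):  # hypothetical module for illustration\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        # Parameter: has gradients and is updated by an optimizer\n",
"        self.linear = nn.Linear(2, 2, bias=False)\n",
"        # Buffer: part of the module's state, but not trained\n",
"        self.register_buffer('scale', torch.tensor(2.0))\n",
"\n",
"toy = Toy()\n",
"param_names = [name for name, _ in toy.named_parameters()]\n",
"buffer_names = [name for name, _ in toy.named_buffers()]\n",
"\n",
"print('Parameters:', param_names)  # ['linear.weight']\n",
"print('Buffers:', buffer_names)  # ['scale']\n",
"print('Buffer requires grad:', toy.scale.requires_grad)  # False"
]
},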
{
"cell_type": "markdown",
"metadata": {
"id": "dAwGo_gYLY45"
},
"source": [
"## An example without buffers"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0qBQC9IPAJVZ"
},
"source": [
"Suppose we have the following code, which is based on code from chapter 3. This version has been modified to exclude buffers. It implements the causal self-attention mechanism used in LLMs:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "7wx-_rokAN04"
},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"\n",
"class CausalAttentionWithoutBuffers(nn.Module):\n",
"\n",
" def __init__(self, d_in, d_out, context_length,\n",
" dropout, qkv_bias=False):\n",
" super().__init__()\n",
" self.d_out = d_out\n",
" self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
" self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
" self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
" self.dropout = nn.Dropout(dropout)\n",
" self.mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)\n",
"\n",
" def forward(self, x):\n",
" b, num_tokens, d_in = x.shape\n",
" keys = self.W_key(x)\n",
" queries = self.W_query(x)\n",
" values = self.W_value(x)\n",
"\n",
" attn_scores = queries @ keys.transpose(1, 2)\n",
" attn_scores.masked_fill_(\n",
" self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)\n",
" attn_weights = torch.softmax(\n",
" attn_scores / keys.shape[-1]**0.5, dim=-1\n",
" )\n",
" attn_weights = self.dropout(attn_weights)\n",
"\n",
" context_vec = attn_weights @ values\n",
" return context_vec"
]
},
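{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see what the `mask` does inside `forward`, here is a small standalone sketch for 3 tokens, using arbitrary all-ones attention scores as made-up example values. Entries above the diagonal, where the mask is 1, are filled with `-inf`, so the softmax assigns them zero weight and each token can only attend to itself and earlier tokens. (In the class above, the slicing `[:num_tokens, :num_tokens]` lets the same `context_length`-sized mask serve shorter inputs.)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"num_tokens = 3\n",
"# 1s above the diagonal mark future positions that a token must not attend to\n",
"mask = torch.triu(torch.ones(num_tokens, num_tokens), diagonal=1)\n",
"print(mask)\n",
"\n",
"scores = torch.ones(num_tokens, num_tokens)  # arbitrary example scores\n",
"scores.masked_fill_(mask.bool(), -torch.inf)  # hide future positions\n",
"weights = torch.softmax(scores, dim=-1)\n",
"print(weights)  # row i spreads its weight evenly over tokens 0..i"
]
},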
{
"cell_type": "markdown",
"metadata": {
"id": "nNrK-wLaNSi7"
},
"source": [
"We can initialize and run the module as follows on some example data:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "e1MZiIsPA0Py",
"outputId": "ce1407c6-c082-4755-b8ad-d9adcc9f153a"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"tensor([[[-0.4519, 0.2216],\n",
" [-0.5874, 0.0058],\n",
" [-0.6300, -0.0632],\n",
" [-0.5675, -0.0843],\n",
" [-0.5526, -0.0981],\n",
" [-0.5299, -0.1081]],\n",
"\n",
" [[-0.4519, 0.2216],\n",
" [-0.5874, 0.0058],\n",
" [-0.6300, -0.0632],\n",
" [-0.5675, -0.0843],\n",
" [-0.5526, -0.0981],\n",
" [-0.5299, -0.1081]]])\n"
]
}
],
"source": [
"torch.manual_seed(123)\n",
"\n",
"inputs = torch.tensor(\n",
" [[0.43, 0.15, 0.89], # Your (x^1)\n",
" [0.55, 0.87, 0.66], # journey (x^2)\n",
" [0.57, 0.85, 0.64], # starts (x^3)\n",
" [0.22, 0.58, 0.33], # with (x^4)\n",
" [0.77, 0.25, 0.10], # one (x^5)\n",
" [0.05, 0.80, 0.55]] # step (x^6)\n",
")\n",
"\n",
"batch = torch.stack((inputs, inputs), dim=0)\n",
"context_length = batch.shape[1]\n",
"d_in = inputs.shape[1]\n",
"d_out = 2\n",
"\n",
"ca_without_buffer = CausalAttentionWithoutBuffers(d_in, d_out, context_length, 0.0)\n",
"\n",
"with torch.no_grad():\n",
" context_vecs = ca_without_buffer(batch)\n",
"\n",
"print(context_vecs)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7_hqz6AgCCc1"
},
"source": [
"So far, everything has worked fine.\n",
"\n",
"However, when training LLMs, we typically use GPUs to accelerate the process. Therefore, let's transfer the `CausalAttentionWithoutBuffers` module onto a GPU device.\n",
"\n",
"Please note that this operation requires the code to be run in an environment equipped with GPUs."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "PYwn44HWCPJS",
"outputId": "d7236e0c-2a43-4770-ccc1-03c9d5d11421"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Machine has GPU: True\n"
]
}
],
"source": [
"print(\"Machine has GPU:\", torch.cuda.is_available())\n",
"\n",
"batch = batch.to(\"cuda\")\n",
"ca_without_buffer.to(\"cuda\");"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4_lMki2_CoIR"
},
"source": [
"Now, let's run the code again:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 338
},
"id": "KE9iLcjGC1V1",
"outputId": "ab6921c7-d7dd-44ea-9b92-1911037e3dcc"
},
"outputs": [
{
"output_type": "error",
"ename": "RuntimeError",
"evalue": "expected self and mask to be on the same device, but got mask on cpu and self on cuda:0",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-4-1e0d2e6638f6>\u001b[0m in \u001b[0;36m<cell line: 1>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mno_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mcontext_vecs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mca_without_buffer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 3\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcontext_vecs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1530\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compiled_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# type: ignore[misc]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1531\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1532\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1533\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1534\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1539\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1540\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1541\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1542\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1543\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m<ipython-input-1-cf1dad0dd611>\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 21\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 22\u001b[0m \u001b[0mattn_scores\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mqueries\u001b[0m \u001b[0;34m@\u001b[0m \u001b[0mkeys\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtranspose\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 23\u001b[0;31m attn_scores.masked_fill_(\n\u001b[0m\u001b[1;32m 24\u001b[0m self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)\n\u001b[1;32m 25\u001b[0m attn_weights = torch.softmax(\n",
"\u001b[0;31mRuntimeError\u001b[0m: expected self and mask to be on the same device, but got mask on cpu and self on cuda:0"
]
}
],
"source": [
"with torch.no_grad():\n",
" context_vecs = ca_without_buffer(batch)\n",
"\n",
"print(context_vecs)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "I7V26PLrC2gk"
},
"source": [
"Running the code resulted in an error. What happened? The traceback points to `masked_fill_`: we attempted to combine a tensor on the GPU (`attn_scores`) with a tensor on the CPU (`mask`). But we moved the module to the GPU!?\n",
"\n",
"Let's double-check the device locations of some of the tensors:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "vvYDPBRIDHfU",
"outputId": "4b9703a8-7035-4a2d-8643-c64d37b7abd2"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"W_query.device: cuda:0\n",
"mask.device: cpu\n"
]
}
],
"source": [
"print(\"W_query.device:\", ca_without_buffer.W_query.weight.device)\n",
"print(\"mask.device:\", ca_without_buffer.mask.device)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "d11nX-FFOJ3C",
"outputId": "1e92b0e8-dbc6-41f9-e88f-5d06e0726050"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"torch.Tensor"
]
},
"metadata": {},
"execution_count": 6
}
],
"source": [
"type(ca_without_buffer.mask)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Ojay-KY-DL5M"
},
"source": [
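"As we can see, the `mask` was not moved onto the GPU. That's because it's a plain tensor attribute rather than a PyTorch parameter like the weights (e.g., `W_query.weight`), and `.to()` only moves the parameters and buffers registered with the module.\n",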
"\n",
"This means we have to manually move it to the GPU via `.to(\"cuda\")`:"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "QYirQ63zDYsW",
"outputId": "304628ac-bc4c-49c2-a0e1-ecf9385ddcd9"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"mask.device: cuda:0\n"
]
}
],
"source": [
"ca_without_buffer.mask = ca_without_buffer.mask.to(\"cuda\")\n",
"print(\"mask.device:\", ca_without_buffer.mask.device)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4OoTqzkpDfAm"
},
"source": [
"Let's try our code again:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "WfF0yBZODdAZ",
"outputId": "291cfb54-86e6-45f9-99d1-fa145319f379"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"tensor([[[-0.4519, 0.2216],\n",
" [-0.5874, 0.0058],\n",
" [-0.6300, -0.0632],\n",
" [-0.5675, -0.0843],\n",
" [-0.5526, -0.0981],\n",
" [-0.5299, -0.1081]],\n",
"\n",
" [[-0.4519, 0.2216],\n",
" [-0.5874, 0.0058],\n",
" [-0.6300, -0.0632],\n",
" [-0.5675, -0.0843],\n",
" [-0.5526, -0.0981],\n",
" [-0.5299, -0.1081]]], device='cuda:0')\n"
]
}
],
"source": [
"with torch.no_grad():\n",
" context_vecs = ca_without_buffer(batch)\n",
"\n",
"print(context_vecs)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "oUrVgWuuD7UE"
},
"source": [
"This time, it worked!\n",
"\n",
"However, remembering to move individual tensors to the GPU can be tedious. As we will see in the next section, it's easier to use `register_buffer` to register the `mask` as a buffer."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "StS2wUrBLeuW"
},
"source": [
"## An example with buffers"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nEqD2NFzPO6l"
},
"source": [
"Let's now modify the causal attention class to register the causal `mask` as a buffer:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"id": "ndsYj3Zf6N8U"
},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"\n",
"class CausalAttentionWithBuffer(nn.Module):\n",
"\n",
" def __init__(self, d_in, d_out, context_length,\n",
" dropout, qkv_bias=False):\n",
" super().__init__()\n",
" self.d_out = d_out\n",
" self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
" self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
" self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
" self.dropout = nn.Dropout(dropout)\n",
" # Old:\n",
" # self.mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)\n",
"\n",
" # New:\n",
" self.register_buffer(\"mask\", torch.triu(torch.ones(context_length, context_length), diagonal=1))\n",
"\n",
" def forward(self, x):\n",
" b, num_tokens, d_in = x.shape\n",
" keys = self.W_key(x)\n",
" queries = self.W_query(x)\n",
" values = self.W_value(x)\n",
"\n",
" attn_scores = queries @ keys.transpose(1, 2)\n",
" attn_scores.masked_fill_(\n",
" self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)\n",
" attn_weights = torch.softmax(\n",
" attn_scores / keys.shape[-1]**0.5, dim=-1\n",
" )\n",
" attn_weights = self.dropout(attn_weights)\n",
"\n",
" context_vec = attn_weights @ values\n",
" return context_vec"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_AL1X6y3Eb7S"
},
"source": [
"Now, conveniently, if we move the module to the GPU, the mask will be located on the GPU as well:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "8_VCxEa76j00",
"outputId": "4d1af501-5a9e-46aa-b1ac-63bf0c68e02a"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"W_query.device: cuda:0\n",
"mask.device: cuda:0\n"
]
}
],
"source": [
"ca_with_buffer = CausalAttentionWithBuffer(d_in, d_out, context_length, 0.0)\n",
"ca_with_buffer.to(\"cuda\")\n",
"\n",
"print(\"W_query.device:\", ca_with_buffer.W_query.weight.device)\n",
"print(\"mask.device:\", ca_with_buffer.mask.device)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "TBWvKlMe7bbB",
"outputId": "e43bf8ab-3fb9-417e-d087-560858332d86"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"tensor([[[0.4772, 0.1063],\n",
" [0.5891, 0.3257],\n",
" [0.6202, 0.3860],\n",
" [0.5478, 0.3589],\n",
" [0.5321, 0.3428],\n",
" [0.5077, 0.3493]],\n",
"\n",
" [[0.4772, 0.1063],\n",
" [0.5891, 0.3257],\n",
" [0.6202, 0.3860],\n",
" [0.5478, 0.3589],\n",
" [0.5321, 0.3428],\n",
" [0.5077, 0.3493]]], device='cuda:0')\n"
]
}
],
"source": [
"with torch.no_grad():\n",
" context_vecs = ca_with_buffer(batch)\n",
"\n",
"print(context_vecs)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xvOTh4NNPjef"
},
"source": [
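"As we can see above, registering a tensor as a buffer can make our lives a lot easier: we don't have to remember to manually move tensors to a target device like a GPU. (The context vectors differ from the earlier output because `ca_with_buffer` was initialized with different random weights; the random seed was not reset before creating it.)"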
]
},
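{
"cell_type": "markdown",
"metadata": {},
"source": [
"Because `Module.to()` applies dtype conversions to the same registered tensors that it moves between devices, the difference between a plain tensor attribute and a buffer can also be checked on a machine without a GPU. In this minimal sketch (the `Holder` class is hypothetical), `.to(torch.float64)` converts the registered buffer but leaves the plain tensor attribute untouched, just as `.to(\"cuda\")` left the plain `mask` on the CPU earlier:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"\n",
"class Holder(nn.Module):  # hypothetical module for illustration\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        self.plain = torch.zeros(2)  # plain tensor attribute\n",
"        self.register_buffer('buf', torch.zeros(2))  # registered buffer\n",
"\n",
"holder = Holder().to(torch.float64)\n",
"print('plain:', holder.plain.dtype)  # torch.float32 (unchanged)\n",
"print('buf:', holder.buf.dtype)  # torch.float64 (converted)"
]
},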
{
"cell_type": "markdown",
"source": [
"## Buffers and `state_dict`"
],
"metadata": {
"id": "Q-5YYKmJte3h"
}
},
{
"cell_type": "markdown",
"source": [
"- Another advantage of buffers over regular tensor attributes is that they are included in a model's `state_dict`\n",
"- For example, consider the `state_dict` of the causal attention object without buffers"
],
"metadata": {
"id": "YIHHawPbtjfp"
}
},
{
"cell_type": "code",
"source": [
"ca_without_buffer.state_dict()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "c217juzqtxsS",
"outputId": "dbae3c3d-f4f8-4c70-a64f-90906561d8d9"
},
"execution_count": 12,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"OrderedDict([('W_query.weight',\n",
" tensor([[-0.2354, 0.0191, -0.2867],\n",
" [ 0.2177, -0.4919, 0.4232]], device='cuda:0')),\n",
" ('W_key.weight',\n",
" tensor([[-0.4196, -0.4590, -0.3648],\n",
" [ 0.2615, -0.2133, 0.2161]], device='cuda:0')),\n",
" ('W_value.weight',\n",
" tensor([[-0.4900, -0.3503, -0.2120],\n",
" [-0.1135, -0.4404, 0.3780]], device='cuda:0'))])"
]
},
"metadata": {},
"execution_count": 12
}
]
},
{
"cell_type": "markdown",
"source": [
"- The mask is not included in the `state_dict` above\n",
"- However, the mask *is* included in the `state_dict` below, thanks to registering it as a buffer"
],
"metadata": {
"id": "NdmZuPaqt6aO"
}
},
{
"cell_type": "code",
"source": [
"ca_with_buffer.state_dict()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "uGIGQAwPt1Pl",
"outputId": "00f9bc44-63f9-4ebc-87ea-d4b8cafd81c1"
},
"execution_count": 13,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"OrderedDict([('mask',\n",
" tensor([[0., 1., 1., 1., 1., 1.],\n",
" [0., 0., 1., 1., 1., 1.],\n",
" [0., 0., 0., 1., 1., 1.],\n",
" [0., 0., 0., 0., 1., 1.],\n",
" [0., 0., 0., 0., 0., 1.],\n",
" [0., 0., 0., 0., 0., 0.]], device='cuda:0')),\n",
" ('W_query.weight',\n",
" tensor([[-0.1362, 0.1853, 0.4083],\n",
" [ 0.1076, 0.1579, 0.5573]], device='cuda:0')),\n",
" ('W_key.weight',\n",
" tensor([[-0.2604, 0.1829, -0.2569],\n",
" [ 0.4126, 0.4611, -0.5323]], device='cuda:0')),\n",
" ('W_value.weight',\n",
" tensor([[ 0.4929, 0.2757, 0.2516],\n",
" [ 0.2377, 0.4800, -0.0762]], device='cuda:0'))])"
]
},
"metadata": {},
"execution_count": 13
}
]
},
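{
"cell_type": "markdown",
"metadata": {},
"source": [
"Comparing only the keys makes the difference easier to spot. The following self-contained sketch uses two hypothetical toy modules that store the same causal mask, once as a plain attribute and once as a buffer:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"\n",
"class MaskAsAttribute(nn.Module):  # hypothetical: mask stored as a plain attribute\n",
"    def __init__(self, context_length=3):\n",
"        super().__init__()\n",
"        self.mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)\n",
"\n",
"class MaskAsBuffer(nn.Module):  # hypothetical: mask registered as a buffer\n",
"    def __init__(self, context_length=3):\n",
"        super().__init__()\n",
"        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))\n",
"\n",
"keys_attribute = list(MaskAsAttribute().state_dict().keys())\n",
"keys_buffer = list(MaskAsBuffer().state_dict().keys())\n",
"print(keys_attribute)  # []\n",
"print(keys_buffer)  # ['mask']"
]
},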
{
"cell_type": "markdown",
"source": [
"- A `state_dict` is useful, for example, when saving and loading trained PyTorch models\n",
"- In this particular case, saving and loading the `mask` is arguably not very useful because it remains unchanged during training; so, for demonstration purposes, let's modify it by changing all `1`s to `2`s:"
],
"metadata": {
"id": "ACC-a1Hnt4Zv"
}
},
{
"cell_type": "code",
"source": [
"ca_with_buffer.mask[ca_with_buffer.mask == 1.] = 2.\n",
"ca_with_buffer.mask"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "RLm1Sw0cuhvy",
"outputId": "4b2cc70f-1709-44e4-aa17-4e01353b86f8"
},
"execution_count": 14,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor([[0., 2., 2., 2., 2., 2.],\n",
" [0., 0., 2., 2., 2., 2.],\n",
" [0., 0., 0., 2., 2., 2.],\n",
" [0., 0., 0., 0., 2., 2.],\n",
" [0., 0., 0., 0., 0., 2.],\n",
" [0., 0., 0., 0., 0., 0.]], device='cuda:0')"
]
},
"metadata": {},
"execution_count": 14
}
]
},
{
"cell_type": "markdown",
"source": [
"- Then, if we save and load the model, we can see that the mask is restored with the modified values"
],
"metadata": {
"id": "BIkGgGqqvp4S"
}
},
{
"cell_type": "code",
"source": [
"torch.save(ca_with_buffer.state_dict(), \"model.pth\")\n",
"\n",
"new_ca_with_buffer = CausalAttentionWithBuffer(d_in, d_out, context_length, 0.0)\n",
"new_ca_with_buffer.load_state_dict(torch.load(\"model.pth\"))\n",
"\n",
"new_ca_with_buffer.mask"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "e8g0QHUhuVBw",
"outputId": "cc7ee348-7f94-4117-e5cc-e0e01a94e906"
},
"execution_count": 15,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor([[0., 2., 2., 2., 2., 2.],\n",
" [0., 0., 2., 2., 2., 2.],\n",
" [0., 0., 0., 2., 2., 2.],\n",
" [0., 0., 0., 0., 2., 2.],\n",
" [0., 0., 0., 0., 0., 2.],\n",
" [0., 0., 0., 0., 0., 0.]])"
]
},
"metadata": {},
"execution_count": 15
}
]
},
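{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that the restored mask above is printed without `device='cuda:0'`, even though it was saved from a module on the GPU. By default, `load_state_dict` copies the saved values into the tensors that the new module already has, so those tensors keep the new module's device and dtype. The following CPU-only sketch illustrates this with a dtype instead of a device (the `BufferOnly` class is hypothetical):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"\n",
"class BufferOnly(nn.Module):  # hypothetical module for illustration\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        self.register_buffer('buf', torch.zeros(3))\n",
"\n",
"src = BufferOnly().to(torch.float64)  # source buffer is float64\n",
"src.buf.fill_(2.0)  # modify the buffer in place\n",
"\n",
"dst = BufferOnly()  # fresh module, buffer is float32\n",
"dst.load_state_dict(src.state_dict())\n",
"\n",
"print(dst.buf)  # values copied from src: all 2s\n",
"print(dst.buf.dtype)  # torch.float32, the dtype of dst is kept"
]
},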
{
"cell_type": "markdown",
"source": [
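"- Without buffers, however, the modification is lost: since the mask is not part of the `state_dict`, the reloaded model keeps its freshly initialized mask:"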
],
"metadata": {
"id": "0pPaJk7bvBD7"
}
},
{
"cell_type": "code",
"source": [
"ca_without_buffer.mask[ca_without_buffer.mask == 1.] = 2.\n",
"\n",
"torch.save(ca_without_buffer.state_dict(), \"model.pth\")\n",
"\n",
"new_ca_without_buffer = CausalAttentionWithoutBuffers(d_in, d_out, context_length, 0.0)\n",
"new_ca_without_buffer.load_state_dict(torch.load(\"model.pth\"))\n",
"\n",
"new_ca_without_buffer.mask"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "D03w8vDyvBRS",
"outputId": "28071601-120c-42da-b327-bb293793839f"
},
"execution_count": 16,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor([[0., 1., 1., 1., 1., 1.],\n",
" [0., 0., 1., 1., 1., 1.],\n",
" [0., 0., 0., 1., 1., 1.],\n",
" [0., 0., 0., 0., 1., 1.],\n",
" [0., 0., 0., 0., 0., 1.],\n",
" [0., 0., 0., 0., 0., 0.]])"
]
},
"metadata": {},
"execution_count": 16
}
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuType": "L4",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}