add state_dict example

This commit is contained in:
rasbt 2024-07-28 14:15:32 -05:00
parent ce33e706ba
commit 358717870b


@@ -2,7 +2,9 @@
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"metadata": {
"id": "Dlv8N4uWtXcN"
},
"source": [
"<table style=\"width:100%\">\n",
"<tr>\n",
@@ -117,12 +119,12 @@
"base_uri": "https://localhost:8080/"
},
"id": "e1MZiIsPA0Py",
"outputId": "a0746523-3cf3-492f-e996-495c21371837"
"outputId": "ce1407c6-c082-4755-b8ad-d9adcc9f153a"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"name": "stdout",
"text": [
"tensor([[[-0.4519, 0.2216],\n",
" [-0.5874, 0.0058],\n",
@@ -186,12 +188,12 @@
"base_uri": "https://localhost:8080/"
},
"id": "PYwn44HWCPJS",
"outputId": "1aa6bfe9-e9a9-477f-e944-65388820498d"
"outputId": "d7236e0c-2a43-4770-ccc1-03c9d5d11421"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"name": "stdout",
"text": [
"Machine has GPU: True\n"
]
@@ -222,13 +224,13 @@
"height": 338
},
"id": "KE9iLcjGC1V1",
"outputId": "110f444f-f887-4a0a-a156-a263b444941f"
"outputId": "ab6921c7-d7dd-44ea-9b92-1911037e3dcc"
},
"outputs": [
{
"output_type": "error",
"ename": "RuntimeError",
"evalue": "expected self and mask to be on the same device, but got mask on cpu and self on cuda:0",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
@@ -267,12 +269,12 @@
"base_uri": "https://localhost:8080/"
},
"id": "vvYDPBRIDHfU",
"outputId": "5765233b-432d-4078-9064-26260b5ea672"
"outputId": "4b9703a8-7035-4a2d-8643-c64d37b7abd2"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"name": "stdout",
"text": [
"W_query.device: cuda:0\n",
"mask.device: cpu\n"
@@ -292,18 +294,18 @@
"base_uri": "https://localhost:8080/"
},
"id": "d11nX-FFOJ3C",
"outputId": "3f8b2dac-378c-49b7-c544-61b91fe36351"
"outputId": "1e92b0e8-dbc6-41f9-e88f-5d06e0726050"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"torch.Tensor"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
"execution_count": 6
}
],
"source": [
@@ -329,12 +331,12 @@
"base_uri": "https://localhost:8080/"
},
"id": "QYirQ63zDYsW",
"outputId": "46af1038-23fd-400c-f013-f56bc8a0e730"
"outputId": "304628ac-bc4c-49c2-a0e1-ecf9385ddcd9"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"name": "stdout",
"text": [
"mask.device: cuda:0\n"
]
@@ -362,12 +364,12 @@
"base_uri": "https://localhost:8080/"
},
"id": "WfF0yBZODdAZ",
"outputId": "c7425750-c995-43a6-ca2f-f2dfc402a4fb"
"outputId": "291cfb54-86e6-45f9-99d1-fa145319f379"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"name": "stdout",
"text": [
"tensor([[[-0.4519, 0.2216],\n",
" [-0.5874, 0.0058],\n",
@@ -483,12 +485,12 @@
"base_uri": "https://localhost:8080/"
},
"id": "8_VCxEa76j00",
"outputId": "7152e74b-ce7a-44fb-c8d9-46da0908190e"
"outputId": "4d1af501-5a9e-46aa-b1ac-63bf0c68e02a"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"name": "stdout",
"text": [
"W_query.device: cuda:0\n",
"mask.device: cuda:0\n"
@@ -511,12 +513,12 @@
"base_uri": "https://localhost:8080/"
},
"id": "TBWvKlMe7bbB",
"outputId": "63aa3589-4fb9-4b75-b161-458afb7d72e2"
"outputId": "e43bf8ab-3fb9-417e-d087-560858332d86"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"name": "stdout",
"text": [
"tensor([[[0.4772, 0.1063],\n",
" [0.5891, 0.3257],\n",
@@ -549,6 +551,244 @@
"source": [
"As we can see above, registering a tensor as a buffer can make our lives a lot easier: We don't have to remember to move tensors to a target device like a GPU manually."
]
},
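The device behavior above can be sketched with a minimal standalone module (the class and attribute names here are illustrative, not from the notebook): only the registered buffer follows `module.to(device)`.

```python
import torch
import torch.nn as nn

# Minimal sketch (class and attribute names are illustrative, not from
# the notebook): the same causal mask stored once as a plain attribute
# and once as a registered buffer.
class MaskDemo(nn.Module):
    def __init__(self, context_length):
        super().__init__()
        mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
        self.plain_mask = mask                                # ordinary attribute
        self.register_buffer("buffered_mask", mask.clone())  # registered buffer

demo = MaskDemo(4)

# module.to(device) moves parameters and buffers, but not plain attributes
if torch.cuda.is_available():
    demo.to("cuda")
    print(demo.plain_mask.device)     # stays on cpu
    print(demo.buffered_mask.device)  # moved to cuda:0
```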
{
"cell_type": "markdown",
"source": [
"## Buffers and `state_dict`"
],
"metadata": {
"id": "Q-5YYKmJte3h"
}
},
{
"cell_type": "markdown",
"source": [
"- Another advantage of PyTorch buffers, over regular tensors, is that they get included in a model's `state_dict`\n",
"- For example, consider the `state_dict` of the causal attention object without buffers"
],
"metadata": {
"id": "YIHHawPbtjfp"
}
},
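The rule can be sketched with a minimal standalone module (names are illustrative, not from the notebook): only parameters and registered buffers are collected into the `state_dict`; plain tensor attributes are not.

```python
import torch
import torch.nn as nn

# Minimal sketch (names are illustrative): only parameters and registered
# buffers are collected into state_dict; plain tensor attributes are not.
class WithBuffer(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("mask", torch.zeros(4, 4))
        self.linear = nn.Linear(3, 2, bias=False)
        self.plain = torch.zeros(4, 4)  # never registered, so never saved

m = WithBuffer()
print(sorted(m.state_dict().keys()))  # ['linear.weight', 'mask'] -- no 'plain'
```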
{
"cell_type": "code",
"source": [
"ca_without_buffer.state_dict()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "c217juzqtxsS",
"outputId": "dbae3c3d-f4f8-4c70-a64f-90906561d8d9"
},
"execution_count": 12,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"OrderedDict([('W_query.weight',\n",
" tensor([[-0.2354, 0.0191, -0.2867],\n",
" [ 0.2177, -0.4919, 0.4232]], device='cuda:0')),\n",
" ('W_key.weight',\n",
" tensor([[-0.4196, -0.4590, -0.3648],\n",
" [ 0.2615, -0.2133, 0.2161]], device='cuda:0')),\n",
" ('W_value.weight',\n",
" tensor([[-0.4900, -0.3503, -0.2120],\n",
" [-0.1135, -0.4404, 0.3780]], device='cuda:0'))])"
]
},
"metadata": {},
"execution_count": 12
}
]
},
{
"cell_type": "markdown",
"source": [
"- The mask is not included in the `state_dict` above\n",
"- However, the mask *is* included in the `state_dict` below, thanks to registering it as a buffer"
],
"metadata": {
"id": "NdmZuPaqt6aO"
}
},
{
"cell_type": "code",
"source": [
"ca_with_buffer.state_dict()"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "uGIGQAwPt1Pl",
"outputId": "00f9bc44-63f9-4ebc-87ea-d4b8cafd81c1"
},
"execution_count": 13,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"OrderedDict([('mask',\n",
" tensor([[0., 1., 1., 1., 1., 1.],\n",
" [0., 0., 1., 1., 1., 1.],\n",
" [0., 0., 0., 1., 1., 1.],\n",
" [0., 0., 0., 0., 1., 1.],\n",
" [0., 0., 0., 0., 0., 1.],\n",
" [0., 0., 0., 0., 0., 0.]], device='cuda:0')),\n",
" ('W_query.weight',\n",
" tensor([[-0.1362, 0.1853, 0.4083],\n",
" [ 0.1076, 0.1579, 0.5573]], device='cuda:0')),\n",
" ('W_key.weight',\n",
" tensor([[-0.2604, 0.1829, -0.2569],\n",
" [ 0.4126, 0.4611, -0.5323]], device='cuda:0')),\n",
" ('W_value.weight',\n",
" tensor([[ 0.4929, 0.2757, 0.2516],\n",
" [ 0.2377, 0.4800, -0.0762]], device='cuda:0'))])"
]
},
"metadata": {},
"execution_count": 13
}
]
},
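As a side note not used in this notebook: if a buffer should move with the module across devices but stay out of the `state_dict`, `register_buffer` accepts a `persistent=False` flag (available since PyTorch 1.6). A minimal sketch:

```python
import torch
import torch.nn as nn

# Side note (not used in the notebook): a buffer registered with
# persistent=False still moves with the module across devices,
# but is excluded from state_dict (PyTorch 1.6+).
class TwoBuffers(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("saved_mask", torch.ones(3, 3))
        self.register_buffer("temp_mask", torch.ones(3, 3), persistent=False)

m = TwoBuffers()
print(list(m.state_dict().keys()))  # ['saved_mask'] -- temp_mask is omitted
```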
{
"cell_type": "markdown",
"source": [
"- A `state_dict` is useful when saving and loading trained PyTorch models, for example\n",
"- In this particular case, saving and loading the `mask` is maybe not super useful, because it remains unchanged during training; so, for demonstration purposes, let's assume it was modified where all `1`'s were changed to `2`'s:"
],
"metadata": {
"id": "ACC-a1Hnt4Zv"
}
},
{
"cell_type": "code",
"source": [
"ca_with_buffer.mask[ca_with_buffer.mask == 1.] = 2.\n",
"ca_with_buffer.mask"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "RLm1Sw0cuhvy",
"outputId": "4b2cc70f-1709-44e4-aa17-4e01353b86f8"
},
"execution_count": 14,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor([[0., 2., 2., 2., 2., 2.],\n",
" [0., 0., 2., 2., 2., 2.],\n",
" [0., 0., 0., 2., 2., 2.],\n",
" [0., 0., 0., 0., 2., 2.],\n",
" [0., 0., 0., 0., 0., 2.],\n",
" [0., 0., 0., 0., 0., 0.]], device='cuda:0')"
]
},
"metadata": {},
"execution_count": 14
}
]
},
{
"cell_type": "markdown",
"source": [
"- Then, if we save and load the model, we can see that the mask is restored with the modified value"
],
"metadata": {
"id": "BIkGgGqqvp4S"
}
},
{
"cell_type": "code",
"source": [
"torch.save(ca_with_buffer.state_dict(), \"model.pth\")\n",
"\n",
"new_ca_with_buffer = CausalAttentionWithBuffer(d_in, d_out, context_length, 0.0)\n",
"new_ca_with_buffer.load_state_dict(torch.load(\"model.pth\"))\n",
"\n",
"new_ca_with_buffer.mask"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "e8g0QHUhuVBw",
"outputId": "cc7ee348-7f94-4117-e5cc-e0e01a94e906"
},
"execution_count": 15,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor([[0., 2., 2., 2., 2., 2.],\n",
" [0., 0., 2., 2., 2., 2.],\n",
" [0., 0., 0., 2., 2., 2.],\n",
" [0., 0., 0., 0., 2., 2.],\n",
" [0., 0., 0., 0., 0., 2.],\n",
" [0., 0., 0., 0., 0., 0.]])"
]
},
"metadata": {},
"execution_count": 15
}
]
},
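One practical wrinkle with the save/load pattern above: a `state_dict` saved from a CUDA model stores CUDA tensors, so loading it on a CPU-only machine requires `torch.load(..., map_location="cpu")`. A minimal sketch with an illustrative module:

```python
import os
import tempfile

import torch
import torch.nn as nn

# Minimal sketch (module and file names are illustrative): map_location
# remaps saved tensors at load time, so a state_dict written on one
# device can be restored on another.
class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("mask", torch.triu(torch.ones(4, 4), diagonal=1))

path = os.path.join(tempfile.mkdtemp(), "tiny.pth")
torch.save(Tiny().state_dict(), path)

restored = Tiny()
restored.load_state_dict(torch.load(path, map_location="cpu"))
print(restored.mask.device)  # cpu
```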
{
"cell_type": "markdown",
"source": [
"- This is not true if we don't use buffers:"
],
"metadata": {
"id": "0pPaJk7bvBD7"
}
},
{
"cell_type": "code",
"source": [
"ca_without_buffer.mask[ca_without_buffer.mask == 1.] = 2.\n",
"\n",
"torch.save(ca_without_buffer.state_dict(), \"model.pth\")\n",
"\n",
"new_ca_without_buffer = CausalAttentionWithoutBuffers(d_in, d_out, context_length, 0.0)\n",
"new_ca_without_buffer.load_state_dict(torch.load(\"model.pth\"))\n",
"\n",
"new_ca_without_buffer.mask"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "D03w8vDyvBRS",
"outputId": "28071601-120c-42da-b327-bb293793839f"
},
"execution_count": 16,
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"tensor([[0., 1., 1., 1., 1., 1.],\n",
" [0., 0., 1., 1., 1., 1.],\n",
" [0., 0., 0., 1., 1., 1.],\n",
" [0., 0., 0., 0., 1., 1.],\n",
" [0., 0., 0., 0., 0., 1.],\n",
" [0., 0., 0., 0., 0., 0.]])"
]
},
"metadata": {},
"execution_count": 16
}
]
}
],
"metadata": {
@@ -576,5 +816,5 @@
}
},
"nbformat": 4,
"nbformat_minor": 4
"nbformat_minor": 0
}