diff --git a/ch06/01_main-chapter-code/ch06.ipynb b/ch06/01_main-chapter-code/ch06.ipynb index fecff3f..5f9b025 100644 --- a/ch06/01_main-chapter-code/ch06.ipynb +++ b/ch06/01_main-chapter-code/ch06.ipynb @@ -570,7 +570,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 9, "id": "d7791b52-af18-4ac4-afa9-b921068e383e", "metadata": { "id": "d7791b52-af18-4ac4-afa9-b921068e383e" @@ -628,7 +628,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 10, "id": "uzj85f8ou82h", "metadata": { "colab": { @@ -668,7 +668,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 11, "id": "bb0c502d-a75e-4248-8ea0-196e2b00c61e", "metadata": { "id": "bb0c502d-a75e-4248-8ea0-196e2b00c61e" @@ -705,7 +705,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 12, "id": "8681adc0-6f02-4e75-b01a-a6ab75d05542", "metadata": { "colab": { @@ -756,7 +756,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 13, "id": "4dee6882-4c3a-4964-af15-fa31f86ad047", "metadata": {}, "outputs": [ @@ -789,7 +789,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 14, "id": "IZfw-TYD2zTj", "metadata": { "colab": { @@ -837,7 +837,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 15, "id": "2992d779-f9fb-4812-a117-553eb790a5a9", "metadata": { "id": "2992d779-f9fb-4812-a117-553eb790a5a9" @@ -866,7 +866,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 16, "id": "022a649a-44f5-466c-8a8e-326c063384f5", "metadata": { "colab": { @@ -877,16 +877,16 @@ }, "outputs": [ { - "name": "stderr", + "name": "stdout", "output_type": "stream", "text": [ - "checkpoint: 100%|███████████████████████████| 77.0/77.0 [00:00<00:00, 39.7kiB/s]\n", - "encoder.json: 100%|███████████████████████| 1.04M/1.04M [00:00<00:00, 3.25MiB/s]\n", - "hparams.json: 100%|█████████████████████████| 90.0/90.0 [00:00<00:00, 51.4kiB/s]\n", - "model.ckpt.data-00000-of-00001: 100%|███████| 498M/498M [01:00<00:00, 8.20MiB/s]\n", - "model.ckpt.index: 100%|███████████████████| 5.21k/5.21k [00:00<00:00, 2.34MiB/s]\n", - "model.ckpt.meta: 100%|██████████████████████| 471k/471k [00:00<00:00, 2.26MiB/s]\n", - "vocab.bpe: 100%|████████████████████████████| 456k/456k [00:00<00:00, 2.62MiB/s]\n" + "File already exists and is up-to-date: gpt2/124M/checkpoint\n", + "File already exists and is up-to-date: gpt2/124M/encoder.json\n", + "File already exists and is up-to-date: gpt2/124M/hparams.json\n", + "File already exists and is up-to-date: gpt2/124M/model.ckpt.data-00000-of-00001\n", + "File already exists and is up-to-date: gpt2/124M/model.ckpt.index\n", + "File already exists and is up-to-date: gpt2/124M/model.ckpt.meta\n", + "File already exists and is up-to-date: gpt2/124M/vocab.bpe\n" ] } ], @@ -912,7 +912,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 17, "id": "d8ac25ff-74b1-4149-8dc5-4c429d464330", "metadata": {}, "outputs": [ @@ -956,7 +956,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 18, "id": "94224aa9-c95a-4f8a-a420-76d01e3a800c", "metadata": {}, "outputs": [ @@ -1024,7 +1024,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 19, "id": "b23aff91-6bd0-48da-88f6-353657e6c981", "metadata": { "colab": { @@ -1294,7 +1294,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 20, "id": "fkMWFl-0etea", "metadata": { "id": "fkMWFl-0etea" @@ -1317,7 +1317,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 21, "id": "7e759fa0-0f69-41be-b576-17e5f20e04cb", "metadata": {}, "outputs": [], @@ -1348,7 +1348,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 39, "id": "2aedc120-5ee3-48f6-92f2-ad9304ebcdc7", "metadata": { "id": "2aedc120-5ee3-48f6-92f2-ad9304ebcdc7" @@ -1373,7 +1373,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 42, "id": "f645c06a-7df6-451c-ad3f-eafb18224ebc", "metadata": { "colab": { @@ -1409,7 +1409,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 43, "id": "48dc84f1-85cc-4609-9cee-94ff539f00f4", "metadata": { "colab": { @@ -1470,7 +1470,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 44, "id": "49383a8c-41d5-4dab-98f1-238bca0c2ed7", "metadata": { "colab": { @@ -1484,7 +1484,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Last output token: tensor([[-5.7543, 5.3615]])\n" + "Last output token: tensor([[-3.5983, 3.9902]])\n" ] } ], @@ -1516,6 +1516,68 @@ "" ] }, + { + "cell_type": "markdown", + "id": "7a7df4ee-0a34-4a4d-896d-affbbf81e0b3", + "metadata": {}, + "source": [ + "- Before explaining the loss calculation, let's have a brief look at how the model outputs are turned into class labels" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "id": "c77faab1-3461-4118-866a-6171f2b89aa0", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Last output token: tensor([[-3.5983, 3.9902]])\n" + ] + } + ], + "source": [ + "print(\"Last output token:\", outputs[:, -1, :])" + ] + }, + { + "cell_type": "markdown", + "id": "7edd71fa-628a-4d00-b81d-6d8bcb2c341d", + "metadata": {}, + "source": [ + "- Similar to chapter 5, we convert the outputs (logits) into probability scores via the `softmax` function and then obtain the index position of the largest probability value via the `argmax` function:" + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "id": "b81efa92-9be1-4b9e-8790-ce1fc7b17f01", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Class label: 1\n" + ] + } + ], + "source": [ + "probas = torch.softmax(outputs[:, -1, :], dim=-1)\n", + "label = torch.argmax(probas)\n", + "print(\"Class label:\", label.item())" + ] + }, + { + "cell_type": "markdown", + "id": "d5241f47-a1e4-4bba-8064-5d06cffa7941", + "metadata": {}, + "source": [ + "- Note that the softmax function is optional here, as explained in chapter 5, because the largest outputs correspond to the largest probability scores" + ] + }, { "cell_type": "markdown", "id": "4f4a9d15-8fc7-48a2-8734-d92a2f265328",