explain how class labels are obtained

rasbt 2024-05-11 07:42:13 -05:00
parent 6fe8d1a10e
commit 694a57a472

@ -570,7 +570,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 9,
"id": "d7791b52-af18-4ac4-afa9-b921068e383e",
"metadata": {
"id": "d7791b52-af18-4ac4-afa9-b921068e383e"
@ -628,7 +628,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 10,
"id": "uzj85f8ou82h",
"metadata": {
"colab": {
@ -668,7 +668,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 11,
"id": "bb0c502d-a75e-4248-8ea0-196e2b00c61e",
"metadata": {
"id": "bb0c502d-a75e-4248-8ea0-196e2b00c61e"
@ -705,7 +705,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 12,
"id": "8681adc0-6f02-4e75-b01a-a6ab75d05542",
"metadata": {
"colab": {
@ -756,7 +756,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 13,
"id": "4dee6882-4c3a-4964-af15-fa31f86ad047",
"metadata": {},
"outputs": [
@ -789,7 +789,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 14,
"id": "IZfw-TYD2zTj",
"metadata": {
"colab": {
@ -837,7 +837,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 15,
"id": "2992d779-f9fb-4812-a117-553eb790a5a9",
"metadata": {
"id": "2992d779-f9fb-4812-a117-553eb790a5a9"
@ -866,7 +866,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 16,
"id": "022a649a-44f5-466c-8a8e-326c063384f5",
"metadata": {
"colab": {
@ -877,16 +877,16 @@
},
"outputs": [
{
"name": "stderr",
"name": "stdout",
"output_type": "stream",
"text": [
"checkpoint: 100%|███████████████████████████| 77.0/77.0 [00:00<00:00, 39.7kiB/s]\n",
"encoder.json: 100%|███████████████████████| 1.04M/1.04M [00:00<00:00, 3.25MiB/s]\n",
"hparams.json: 100%|█████████████████████████| 90.0/90.0 [00:00<00:00, 51.4kiB/s]\n",
"model.ckpt.data-00000-of-00001: 100%|███████| 498M/498M [01:00<00:00, 8.20MiB/s]\n",
"model.ckpt.index: 100%|███████████████████| 5.21k/5.21k [00:00<00:00, 2.34MiB/s]\n",
"model.ckpt.meta: 100%|██████████████████████| 471k/471k [00:00<00:00, 2.26MiB/s]\n",
"vocab.bpe: 100%|████████████████████████████| 456k/456k [00:00<00:00, 2.62MiB/s]\n"
"File already exists and is up-to-date: gpt2/124M/checkpoint\n",
"File already exists and is up-to-date: gpt2/124M/encoder.json\n",
"File already exists and is up-to-date: gpt2/124M/hparams.json\n",
"File already exists and is up-to-date: gpt2/124M/model.ckpt.data-00000-of-00001\n",
"File already exists and is up-to-date: gpt2/124M/model.ckpt.index\n",
"File already exists and is up-to-date: gpt2/124M/model.ckpt.meta\n",
"File already exists and is up-to-date: gpt2/124M/vocab.bpe\n"
]
}
],
@ -912,7 +912,7 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 17,
"id": "d8ac25ff-74b1-4149-8dc5-4c429d464330",
"metadata": {},
"outputs": [
@ -956,7 +956,7 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 18,
"id": "94224aa9-c95a-4f8a-a420-76d01e3a800c",
"metadata": {},
"outputs": [
@ -1024,7 +1024,7 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 19,
"id": "b23aff91-6bd0-48da-88f6-353657e6c981",
"metadata": {
"colab": {
@ -1294,7 +1294,7 @@
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": 20,
"id": "fkMWFl-0etea",
"metadata": {
"id": "fkMWFl-0etea"
@ -1317,7 +1317,7 @@
},
{
"cell_type": "code",
"execution_count": 22,
"execution_count": 21,
"id": "7e759fa0-0f69-41be-b576-17e5f20e04cb",
"metadata": {},
"outputs": [],
@ -1348,7 +1348,7 @@
},
{
"cell_type": "code",
"execution_count": 23,
"execution_count": 39,
"id": "2aedc120-5ee3-48f6-92f2-ad9304ebcdc7",
"metadata": {
"id": "2aedc120-5ee3-48f6-92f2-ad9304ebcdc7"
@ -1373,7 +1373,7 @@
},
{
"cell_type": "code",
"execution_count": 24,
"execution_count": 42,
"id": "f645c06a-7df6-451c-ad3f-eafb18224ebc",
"metadata": {
"colab": {
@ -1409,7 +1409,7 @@
},
{
"cell_type": "code",
"execution_count": 25,
"execution_count": 43,
"id": "48dc84f1-85cc-4609-9cee-94ff539f00f4",
"metadata": {
"colab": {
@ -1470,7 +1470,7 @@
},
{
"cell_type": "code",
"execution_count": 25,
"execution_count": 44,
"id": "49383a8c-41d5-4dab-98f1-238bca0c2ed7",
"metadata": {
"colab": {
@ -1484,7 +1484,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Last output token: tensor([[-5.7543, 5.3615]])\n"
"Last output token: tensor([[-3.5983, 3.9902]])\n"
]
}
],
@ -1516,6 +1516,68 @@
"<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch06_compressed/overview-3.webp\" width=500px>"
]
},
{
"cell_type": "markdown",
"id": "7a7df4ee-0a34-4a4d-896d-affbbf81e0b3",
"metadata": {},
"source": [
"- Before explaining the loss calculation, let's briefly look at how the model outputs are converted into class labels:"
]
},
{
"cell_type": "code",
"execution_count": 45,
"id": "c77faab1-3461-4118-866a-6171f2b89aa0",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Last output token: tensor([[-3.5983, 3.9902]])\n"
]
}
],
"source": [
"print(\"Last output token:\", outputs[:, -1, :])"
]
},
{
"cell_type": "markdown",
"id": "7edd71fa-628a-4d00-b81d-6d8bcb2c341d",
"metadata": {},
"source": [
"- As in chapter 5, we convert the outputs (logits) of the last token into probability scores via the `softmax` function and then obtain the index position of the largest probability value via the `argmax` function; this index is the predicted class label:"
]
},
{
"cell_type": "code",
"execution_count": 48,
"id": "b81efa92-9be1-4b9e-8790-ce1fc7b17f01",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Class label: 1\n"
]
}
],
"source": [
"probas = torch.softmax(outputs[:, -1, :], dim=-1)\n",
"label = torch.argmax(probas, dim=-1)\n",
"print(\"Class label:\", label.item())"
]
},
{
"cell_type": "markdown",
"id": "d5241f47-a1e4-4bba-8064-5d06cffa7941",
"metadata": {},
"source": [
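"- Note that the `softmax` function is optional here, as explained in chapter 5: softmax preserves the ordering of its inputs, so the largest logit always corresponds to the largest probability score, and applying `argmax` directly to the logits yields the same class label"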
]
},
{
"cell_type": "markdown",
"id": "4f4a9d15-8fc7-48a2-8734-d92a2f265328",