mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-09-02 12:57:41 +00:00
explain how class labels are obtained
This commit is contained in:
parent
6fe8d1a10e
commit
694a57a472
@ -570,7 +570,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"execution_count": 9,
|
||||
"id": "d7791b52-af18-4ac4-afa9-b921068e383e",
|
||||
"metadata": {
|
||||
"id": "d7791b52-af18-4ac4-afa9-b921068e383e"
|
||||
@ -628,7 +628,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"execution_count": 10,
|
||||
"id": "uzj85f8ou82h",
|
||||
"metadata": {
|
||||
"colab": {
|
||||
@ -668,7 +668,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"execution_count": 11,
|
||||
"id": "bb0c502d-a75e-4248-8ea0-196e2b00c61e",
|
||||
"metadata": {
|
||||
"id": "bb0c502d-a75e-4248-8ea0-196e2b00c61e"
|
||||
@ -705,7 +705,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"execution_count": 12,
|
||||
"id": "8681adc0-6f02-4e75-b01a-a6ab75d05542",
|
||||
"metadata": {
|
||||
"colab": {
|
||||
@ -756,7 +756,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"execution_count": 13,
|
||||
"id": "4dee6882-4c3a-4964-af15-fa31f86ad047",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@ -789,7 +789,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"execution_count": 14,
|
||||
"id": "IZfw-TYD2zTj",
|
||||
"metadata": {
|
||||
"colab": {
|
||||
@ -837,7 +837,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"execution_count": 15,
|
||||
"id": "2992d779-f9fb-4812-a117-553eb790a5a9",
|
||||
"metadata": {
|
||||
"id": "2992d779-f9fb-4812-a117-553eb790a5a9"
|
||||
@ -866,7 +866,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 17,
|
||||
"execution_count": 16,
|
||||
"id": "022a649a-44f5-466c-8a8e-326c063384f5",
|
||||
"metadata": {
|
||||
"colab": {
|
||||
@ -877,16 +877,16 @@
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"checkpoint: 100%|███████████████████████████| 77.0/77.0 [00:00<00:00, 39.7kiB/s]\n",
|
||||
"encoder.json: 100%|███████████████████████| 1.04M/1.04M [00:00<00:00, 3.25MiB/s]\n",
|
||||
"hparams.json: 100%|█████████████████████████| 90.0/90.0 [00:00<00:00, 51.4kiB/s]\n",
|
||||
"model.ckpt.data-00000-of-00001: 100%|███████| 498M/498M [01:00<00:00, 8.20MiB/s]\n",
|
||||
"model.ckpt.index: 100%|███████████████████| 5.21k/5.21k [00:00<00:00, 2.34MiB/s]\n",
|
||||
"model.ckpt.meta: 100%|██████████████████████| 471k/471k [00:00<00:00, 2.26MiB/s]\n",
|
||||
"vocab.bpe: 100%|████████████████████████████| 456k/456k [00:00<00:00, 2.62MiB/s]\n"
|
||||
"File already exists and is up-to-date: gpt2/124M/checkpoint\n",
|
||||
"File already exists and is up-to-date: gpt2/124M/encoder.json\n",
|
||||
"File already exists and is up-to-date: gpt2/124M/hparams.json\n",
|
||||
"File already exists and is up-to-date: gpt2/124M/model.ckpt.data-00000-of-00001\n",
|
||||
"File already exists and is up-to-date: gpt2/124M/model.ckpt.index\n",
|
||||
"File already exists and is up-to-date: gpt2/124M/model.ckpt.meta\n",
|
||||
"File already exists and is up-to-date: gpt2/124M/vocab.bpe\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -912,7 +912,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 18,
|
||||
"execution_count": 17,
|
||||
"id": "d8ac25ff-74b1-4149-8dc5-4c429d464330",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@ -956,7 +956,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"execution_count": 18,
|
||||
"id": "94224aa9-c95a-4f8a-a420-76d01e3a800c",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@ -1024,7 +1024,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"execution_count": 19,
|
||||
"id": "b23aff91-6bd0-48da-88f6-353657e6c981",
|
||||
"metadata": {
|
||||
"colab": {
|
||||
@ -1294,7 +1294,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"execution_count": 20,
|
||||
"id": "fkMWFl-0etea",
|
||||
"metadata": {
|
||||
"id": "fkMWFl-0etea"
|
||||
@ -1317,7 +1317,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 22,
|
||||
"execution_count": 21,
|
||||
"id": "7e759fa0-0f69-41be-b576-17e5f20e04cb",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@ -1348,7 +1348,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 23,
|
||||
"execution_count": 39,
|
||||
"id": "2aedc120-5ee3-48f6-92f2-ad9304ebcdc7",
|
||||
"metadata": {
|
||||
"id": "2aedc120-5ee3-48f6-92f2-ad9304ebcdc7"
|
||||
@ -1373,7 +1373,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 24,
|
||||
"execution_count": 42,
|
||||
"id": "f645c06a-7df6-451c-ad3f-eafb18224ebc",
|
||||
"metadata": {
|
||||
"colab": {
|
||||
@ -1409,7 +1409,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 25,
|
||||
"execution_count": 43,
|
||||
"id": "48dc84f1-85cc-4609-9cee-94ff539f00f4",
|
||||
"metadata": {
|
||||
"colab": {
|
||||
@ -1470,7 +1470,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 25,
|
||||
"execution_count": 44,
|
||||
"id": "49383a8c-41d5-4dab-98f1-238bca0c2ed7",
|
||||
"metadata": {
|
||||
"colab": {
|
||||
@ -1484,7 +1484,7 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Last output token: tensor([[-5.7543, 5.3615]])\n"
|
||||
"Last output token: tensor([[-3.5983, 3.9902]])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -1516,6 +1516,68 @@
|
||||
"<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch06_compressed/overview-3.webp\" width=500px>"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "7a7df4ee-0a34-4a4d-896d-affbbf81e0b3",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"- Before explaining the loss calculation, let's have a brief look at how the model outputs are turned into class labels"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 45,
|
||||
"id": "c77faab1-3461-4118-866a-6171f2b89aa0",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Last output token: tensor([[-3.5983, 3.9902]])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(\"Last output token:\", outputs[:, -1, :])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "7edd71fa-628a-4d00-b81d-6d8bcb2c341d",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"- Similar to chapter 5, we convert the outputs (logits) into probability scores via the `softmax` function and then obtain the index position of the largest probability value via the `argmax` function:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 48,
|
||||
"id": "b81efa92-9be1-4b9e-8790-ce1fc7b17f01",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Class label: 1\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"probas = torch.softmax(outputs[:, -1, :], dim=-1)\n",
|
||||
"label = torch.argmax(probas)\n",
|
||||
"print(\"Class label:\", label.item())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "d5241f47-a1e4-4bba-8064-5d06cffa7941",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"- Note that the softmax function is optional here, as explained in chapter 5, because the largest outputs correspond to the largest probability scores"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "4f4a9d15-8fc7-48a2-8734-d92a2f265328",
|
||||
|
Loading…
x
Reference in New Issue
Block a user