mirror of
				https://github.com/rasbt/LLMs-from-scratch.git
				synced 2025-11-04 03:40:21 +00:00 
			
		
		
		
	explain how class labels are obtained
This commit is contained in:
		
							parent
							
								
									6fe8d1a10e
								
							
						
					
					
						commit
						694a57a472
					
				@ -570,7 +570,7 @@
 | 
			
		||||
  },
 | 
			
		||||
  {
 | 
			
		||||
   "cell_type": "code",
 | 
			
		||||
   "execution_count": 10,
 | 
			
		||||
   "execution_count": 9,
 | 
			
		||||
   "id": "d7791b52-af18-4ac4-afa9-b921068e383e",
 | 
			
		||||
   "metadata": {
 | 
			
		||||
    "id": "d7791b52-af18-4ac4-afa9-b921068e383e"
 | 
			
		||||
@ -628,7 +628,7 @@
 | 
			
		||||
  },
 | 
			
		||||
  {
 | 
			
		||||
   "cell_type": "code",
 | 
			
		||||
   "execution_count": 11,
 | 
			
		||||
   "execution_count": 10,
 | 
			
		||||
   "id": "uzj85f8ou82h",
 | 
			
		||||
   "metadata": {
 | 
			
		||||
    "colab": {
 | 
			
		||||
@ -668,7 +668,7 @@
 | 
			
		||||
  },
 | 
			
		||||
  {
 | 
			
		||||
   "cell_type": "code",
 | 
			
		||||
   "execution_count": 12,
 | 
			
		||||
   "execution_count": 11,
 | 
			
		||||
   "id": "bb0c502d-a75e-4248-8ea0-196e2b00c61e",
 | 
			
		||||
   "metadata": {
 | 
			
		||||
    "id": "bb0c502d-a75e-4248-8ea0-196e2b00c61e"
 | 
			
		||||
@ -705,7 +705,7 @@
 | 
			
		||||
  },
 | 
			
		||||
  {
 | 
			
		||||
   "cell_type": "code",
 | 
			
		||||
   "execution_count": 13,
 | 
			
		||||
   "execution_count": 12,
 | 
			
		||||
   "id": "8681adc0-6f02-4e75-b01a-a6ab75d05542",
 | 
			
		||||
   "metadata": {
 | 
			
		||||
    "colab": {
 | 
			
		||||
@ -756,7 +756,7 @@
 | 
			
		||||
  },
 | 
			
		||||
  {
 | 
			
		||||
   "cell_type": "code",
 | 
			
		||||
   "execution_count": 14,
 | 
			
		||||
   "execution_count": 13,
 | 
			
		||||
   "id": "4dee6882-4c3a-4964-af15-fa31f86ad047",
 | 
			
		||||
   "metadata": {},
 | 
			
		||||
   "outputs": [
 | 
			
		||||
@ -789,7 +789,7 @@
 | 
			
		||||
  },
 | 
			
		||||
  {
 | 
			
		||||
   "cell_type": "code",
 | 
			
		||||
   "execution_count": 15,
 | 
			
		||||
   "execution_count": 14,
 | 
			
		||||
   "id": "IZfw-TYD2zTj",
 | 
			
		||||
   "metadata": {
 | 
			
		||||
    "colab": {
 | 
			
		||||
@ -837,7 +837,7 @@
 | 
			
		||||
  },
 | 
			
		||||
  {
 | 
			
		||||
   "cell_type": "code",
 | 
			
		||||
   "execution_count": 16,
 | 
			
		||||
   "execution_count": 15,
 | 
			
		||||
   "id": "2992d779-f9fb-4812-a117-553eb790a5a9",
 | 
			
		||||
   "metadata": {
 | 
			
		||||
    "id": "2992d779-f9fb-4812-a117-553eb790a5a9"
 | 
			
		||||
@ -866,7 +866,7 @@
 | 
			
		||||
  },
 | 
			
		||||
  {
 | 
			
		||||
   "cell_type": "code",
 | 
			
		||||
   "execution_count": 17,
 | 
			
		||||
   "execution_count": 16,
 | 
			
		||||
   "id": "022a649a-44f5-466c-8a8e-326c063384f5",
 | 
			
		||||
   "metadata": {
 | 
			
		||||
    "colab": {
 | 
			
		||||
@ -877,16 +877,16 @@
 | 
			
		||||
   },
 | 
			
		||||
   "outputs": [
 | 
			
		||||
    {
 | 
			
		||||
     "name": "stderr",
 | 
			
		||||
     "name": "stdout",
 | 
			
		||||
     "output_type": "stream",
 | 
			
		||||
     "text": [
 | 
			
		||||
      "checkpoint: 100%|███████████████████████████| 77.0/77.0 [00:00<00:00, 39.7kiB/s]\n",
 | 
			
		||||
      "encoder.json: 100%|███████████████████████| 1.04M/1.04M [00:00<00:00, 3.25MiB/s]\n",
 | 
			
		||||
      "hparams.json: 100%|█████████████████████████| 90.0/90.0 [00:00<00:00, 51.4kiB/s]\n",
 | 
			
		||||
      "model.ckpt.data-00000-of-00001: 100%|███████| 498M/498M [01:00<00:00, 8.20MiB/s]\n",
 | 
			
		||||
      "model.ckpt.index: 100%|███████████████████| 5.21k/5.21k [00:00<00:00, 2.34MiB/s]\n",
 | 
			
		||||
      "model.ckpt.meta: 100%|██████████████████████| 471k/471k [00:00<00:00, 2.26MiB/s]\n",
 | 
			
		||||
      "vocab.bpe: 100%|████████████████████████████| 456k/456k [00:00<00:00, 2.62MiB/s]\n"
 | 
			
		||||
      "File already exists and is up-to-date: gpt2/124M/checkpoint\n",
 | 
			
		||||
      "File already exists and is up-to-date: gpt2/124M/encoder.json\n",
 | 
			
		||||
      "File already exists and is up-to-date: gpt2/124M/hparams.json\n",
 | 
			
		||||
      "File already exists and is up-to-date: gpt2/124M/model.ckpt.data-00000-of-00001\n",
 | 
			
		||||
      "File already exists and is up-to-date: gpt2/124M/model.ckpt.index\n",
 | 
			
		||||
      "File already exists and is up-to-date: gpt2/124M/model.ckpt.meta\n",
 | 
			
		||||
      "File already exists and is up-to-date: gpt2/124M/vocab.bpe\n"
 | 
			
		||||
     ]
 | 
			
		||||
    }
 | 
			
		||||
   ],
 | 
			
		||||
@ -912,7 +912,7 @@
 | 
			
		||||
  },
 | 
			
		||||
  {
 | 
			
		||||
   "cell_type": "code",
 | 
			
		||||
   "execution_count": 18,
 | 
			
		||||
   "execution_count": 17,
 | 
			
		||||
   "id": "d8ac25ff-74b1-4149-8dc5-4c429d464330",
 | 
			
		||||
   "metadata": {},
 | 
			
		||||
   "outputs": [
 | 
			
		||||
@ -956,7 +956,7 @@
 | 
			
		||||
  },
 | 
			
		||||
  {
 | 
			
		||||
   "cell_type": "code",
 | 
			
		||||
   "execution_count": 19,
 | 
			
		||||
   "execution_count": 18,
 | 
			
		||||
   "id": "94224aa9-c95a-4f8a-a420-76d01e3a800c",
 | 
			
		||||
   "metadata": {},
 | 
			
		||||
   "outputs": [
 | 
			
		||||
@ -1024,7 +1024,7 @@
 | 
			
		||||
  },
 | 
			
		||||
  {
 | 
			
		||||
   "cell_type": "code",
 | 
			
		||||
   "execution_count": 20,
 | 
			
		||||
   "execution_count": 19,
 | 
			
		||||
   "id": "b23aff91-6bd0-48da-88f6-353657e6c981",
 | 
			
		||||
   "metadata": {
 | 
			
		||||
    "colab": {
 | 
			
		||||
@ -1294,7 +1294,7 @@
 | 
			
		||||
  },
 | 
			
		||||
  {
 | 
			
		||||
   "cell_type": "code",
 | 
			
		||||
   "execution_count": 21,
 | 
			
		||||
   "execution_count": 20,
 | 
			
		||||
   "id": "fkMWFl-0etea",
 | 
			
		||||
   "metadata": {
 | 
			
		||||
    "id": "fkMWFl-0etea"
 | 
			
		||||
@ -1317,7 +1317,7 @@
 | 
			
		||||
  },
 | 
			
		||||
  {
 | 
			
		||||
   "cell_type": "code",
 | 
			
		||||
   "execution_count": 22,
 | 
			
		||||
   "execution_count": 21,
 | 
			
		||||
   "id": "7e759fa0-0f69-41be-b576-17e5f20e04cb",
 | 
			
		||||
   "metadata": {},
 | 
			
		||||
   "outputs": [],
 | 
			
		||||
@ -1348,7 +1348,7 @@
 | 
			
		||||
  },
 | 
			
		||||
  {
 | 
			
		||||
   "cell_type": "code",
 | 
			
		||||
   "execution_count": 23,
 | 
			
		||||
   "execution_count": 39,
 | 
			
		||||
   "id": "2aedc120-5ee3-48f6-92f2-ad9304ebcdc7",
 | 
			
		||||
   "metadata": {
 | 
			
		||||
    "id": "2aedc120-5ee3-48f6-92f2-ad9304ebcdc7"
 | 
			
		||||
@ -1373,7 +1373,7 @@
 | 
			
		||||
  },
 | 
			
		||||
  {
 | 
			
		||||
   "cell_type": "code",
 | 
			
		||||
   "execution_count": 24,
 | 
			
		||||
   "execution_count": 42,
 | 
			
		||||
   "id": "f645c06a-7df6-451c-ad3f-eafb18224ebc",
 | 
			
		||||
   "metadata": {
 | 
			
		||||
    "colab": {
 | 
			
		||||
@ -1409,7 +1409,7 @@
 | 
			
		||||
  },
 | 
			
		||||
  {
 | 
			
		||||
   "cell_type": "code",
 | 
			
		||||
   "execution_count": 25,
 | 
			
		||||
   "execution_count": 43,
 | 
			
		||||
   "id": "48dc84f1-85cc-4609-9cee-94ff539f00f4",
 | 
			
		||||
   "metadata": {
 | 
			
		||||
    "colab": {
 | 
			
		||||
@ -1470,7 +1470,7 @@
 | 
			
		||||
  },
 | 
			
		||||
  {
 | 
			
		||||
   "cell_type": "code",
 | 
			
		||||
   "execution_count": 25,
 | 
			
		||||
   "execution_count": 44,
 | 
			
		||||
   "id": "49383a8c-41d5-4dab-98f1-238bca0c2ed7",
 | 
			
		||||
   "metadata": {
 | 
			
		||||
    "colab": {
 | 
			
		||||
@ -1484,7 +1484,7 @@
 | 
			
		||||
     "name": "stdout",
 | 
			
		||||
     "output_type": "stream",
 | 
			
		||||
     "text": [
 | 
			
		||||
      "Last output token: tensor([[-5.7543,  5.3615]])\n"
 | 
			
		||||
      "Last output token: tensor([[-3.5983,  3.9902]])\n"
 | 
			
		||||
     ]
 | 
			
		||||
    }
 | 
			
		||||
   ],
 | 
			
		||||
@ -1516,6 +1516,68 @@
 | 
			
		||||
    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch06_compressed/overview-3.webp\" width=500px>"
 | 
			
		||||
   ]
 | 
			
		||||
  },
 | 
			
		||||
  {
 | 
			
		||||
   "cell_type": "markdown",
 | 
			
		||||
   "id": "7a7df4ee-0a34-4a4d-896d-affbbf81e0b3",
 | 
			
		||||
   "metadata": {},
 | 
			
		||||
   "source": [
 | 
			
		||||
    "- Before explaining the loss calculation, let's have a brief look at how the model outputs are turned into class labels"
 | 
			
		||||
   ]
 | 
			
		||||
  },
 | 
			
		||||
  {
 | 
			
		||||
   "cell_type": "code",
 | 
			
		||||
   "execution_count": 45,
 | 
			
		||||
   "id": "c77faab1-3461-4118-866a-6171f2b89aa0",
 | 
			
		||||
   "metadata": {},
 | 
			
		||||
   "outputs": [
 | 
			
		||||
    {
 | 
			
		||||
     "name": "stdout",
 | 
			
		||||
     "output_type": "stream",
 | 
			
		||||
     "text": [
 | 
			
		||||
      "Last output token: tensor([[-3.5983,  3.9902]])\n"
 | 
			
		||||
     ]
 | 
			
		||||
    }
 | 
			
		||||
   ],
 | 
			
		||||
   "source": [
 | 
			
		||||
    "print(\"Last output token:\", outputs[:, -1, :])"
 | 
			
		||||
   ]
 | 
			
		||||
  },
 | 
			
		||||
  {
 | 
			
		||||
   "cell_type": "markdown",
 | 
			
		||||
   "id": "7edd71fa-628a-4d00-b81d-6d8bcb2c341d",
 | 
			
		||||
   "metadata": {},
 | 
			
		||||
   "source": [
 | 
			
		||||
    "- Similar to chapter 5, we convert the outputs (logits) into probability scores via the `softmax` function and then obtain the index position of the largest probability value via the `argmax` function:"
 | 
			
		||||
   ]
 | 
			
		||||
  },
 | 
			
		||||
  {
 | 
			
		||||
   "cell_type": "code",
 | 
			
		||||
   "execution_count": 48,
 | 
			
		||||
   "id": "b81efa92-9be1-4b9e-8790-ce1fc7b17f01",
 | 
			
		||||
   "metadata": {},
 | 
			
		||||
   "outputs": [
 | 
			
		||||
    {
 | 
			
		||||
     "name": "stdout",
 | 
			
		||||
     "output_type": "stream",
 | 
			
		||||
     "text": [
 | 
			
		||||
      "Class label: 1\n"
 | 
			
		||||
     ]
 | 
			
		||||
    }
 | 
			
		||||
   ],
 | 
			
		||||
   "source": [
 | 
			
		||||
    "probas = torch.softmax(outputs[:, -1, :], dim=-1)\n",
 | 
			
		||||
    "label = torch.argmax(probas)\n",
 | 
			
		||||
    "print(\"Class label:\", label.item())"
 | 
			
		||||
   ]
 | 
			
		||||
  },
 | 
			
		||||
  {
 | 
			
		||||
   "cell_type": "markdown",
 | 
			
		||||
   "id": "d5241f47-a1e4-4bba-8064-5d06cffa7941",
 | 
			
		||||
   "metadata": {},
 | 
			
		||||
   "source": [
 | 
			
		||||
    "- Note that the softmax function is optional here, as explained in chapter 5, because the largest outputs correspond to the largest probability scores"
 | 
			
		||||
   ]
 | 
			
		||||
  },
 | 
			
		||||
  {
 | 
			
		||||
   "cell_type": "markdown",
 | 
			
		||||
   "id": "4f4a9d15-8fc7-48a2-8734-d92a2f265328",
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user