{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "45398736-7e89-4263-89c8-92153baff553",
   "metadata": {},
   "source": [
    "<table style=\"width:100%\">\n",
    "<tr>\n",
    "<td style=\"vertical-align:middle; text-align:left;\">\n",
    "<font size=\"2\">\n",
    "Supplementary code for the <a href=\"http://mng.bz/orYv\">Build a Large Language Model From Scratch</a> book by <a href=\"https://sebastianraschka.com\">Sebastian Raschka</a><br>\n",
    "<br>Code repository: <a href=\"https://github.com/rasbt/LLMs-from-scratch\">https://github.com/rasbt/LLMs-from-scratch</a>\n",
    "</font>\n",
    "</td>\n",
    "<td style=\"vertical-align:middle; text-align:left;\">\n",
    "<a href=\"http://mng.bz/orYv\"><img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/cover-small.webp\" width=\"100px\"></a>\n",
    "</td>\n",
    "</tr>\n",
    "</table>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "66dd524e-864c-4012-b0a2-ccfc56e80024",
   "metadata": {
    "id": "66dd524e-864c-4012-b0a2-ccfc56e80024"
   },
   "source": [
    "# Chapter 5: Pretraining on Unlabeled Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "92b989e9-da36-4159-b212-799184764dd9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "matplotlib version: 3.9.0\n",
      "numpy version: 1.25.2\n",
      "tiktoken version: 0.5.1\n",
      "torch version: 2.2.2\n",
      "tensorflow version: 2.15.0\n"
     ]
    }
   ],
   "source": [
    "from importlib.metadata import version\n",
    "\n",
    "pkgs = [\"matplotlib\", \n",
    "        \"numpy\", \n",
    "        \"tiktoken\", \n",
    "        \"torch\",\n",
    "        \"tensorflow\" # For OpenAI's pretrained weights\n",
    "       ]\n",
    "for p in pkgs:\n",
    "    print(f\"{p} version: {version(p)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0a3bdf9e-2ff0-4a57-abab-ede2d955a237",
   "metadata": {},
   "source": [
    "- In this chapter, we implement the training loop and code for basic model evaluation to pretrain an LLM\n",
    "- At the end of this chapter, we also load openly available pretrained weights from OpenAI into our model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "efd27fcc-2886-47cb-b544-046c2c31f02a",
   "metadata": {},
   "source": [
    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch05_compressed/chapter-overview.webp\" width=500px>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0d214765-7a73-42d5-95e9-302154b29db9",
   "metadata": {},
   "source": [
    "- The topics covered in this chapter are shown below"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f67711d4-8391-4fee-aeef-07ea53dd5841",
   "metadata": {},
   "source": [
    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch05_compressed/mental-model--0.webp\" width=400px>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0d824183-145c-4865-89e1-1f0d0a338f19",
   "metadata": {
    "id": "0d824183-145c-4865-89e1-1f0d0a338f19"
   },
   "source": [
    "## 5.1 Evaluating generative text models"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a3350f8c-5181-4f9b-a789-4523105e98f2",
   "metadata": {},
   "source": [
    "- We start this section with a brief recap of initializing a GPT model using the code from the previous chapter\n",
    "- Then, we discuss basic evaluation metrics for LLMs\n",
    "- Lastly, in this section, we apply these evaluation metrics to a training and validation dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bdc1cf3f-82d8-46c7-9ecc-58979ce87cdd",
   "metadata": {
    "id": "bdc1cf3f-82d8-46c7-9ecc-58979ce87cdd"
   },
   "source": [
    "### 5.1.1 Using GPT to generate text"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5b3415fd-9f4a-4548-908e-9dfa56edc9bc",
   "metadata": {},
   "source": [
    "- We initialize a GPT model using the code from the previous chapter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "86000d74-624a-48f0-86da-f41926cb9e04",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "86000d74-624a-48f0-86da-f41926cb9e04",
    "outputId": "ad482cfd-5a62-4f0d-e1e0-008d6457f512"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from previous_chapters import GPTModel\n",
    "\n",
    "GPT_CONFIG_124M = {\n",
    "    \"vocab_size\": 50257,   # Vocabulary size\n",
    "    \"context_length\": 256, # Shortened context length (orig: 1024)\n",
    "    \"emb_dim\": 768,        # Embedding dimension\n",
    "    \"n_heads\": 12,         # Number of attention heads\n",
    "    \"n_layers\": 12,        # Number of layers\n",
    "    \"drop_rate\": 0.1,      # Dropout rate\n",
    "    \"qkv_bias\": False      # Query-key-value bias\n",
    "}\n",
    "\n",
    "torch.manual_seed(123)\n",
    "model = GPTModel(GPT_CONFIG_124M)\n",
    "model.eval();  # Disable dropout during inference"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "09c6cf0f-7458-48a2-97fd-aa5068d65e8c",
   "metadata": {},
   "source": [
    "- We use a dropout rate of 0.1 above, but it's relatively common to train LLMs without dropout nowadays\n",
    "- Modern LLMs also don't use bias vectors in the `nn.Linear` layers for the query, key, and value matrices (unlike earlier GPT models), which is achieved by setting `\"qkv_bias\": False`\n",
    "- We reduce the context length (`context_length`) to only 256 tokens to lower the computational resource requirements for training the model, whereas the original 124-million-parameter GPT-2 model used 1024 tokens\n",
    "  - This is so that more readers will be able to follow and execute the code examples on their laptop computers\n",
    "  - However, please feel free to increase the `context_length` to 1024 tokens (this would not require any code changes; see the optional sketch below)\n",
    "  - We will also load a model with a 1024 `context_length` later from pretrained weights"
   ]
  },
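  {
   "cell_type": "markdown",
   "id": "added-optional-context-length-note",
   "metadata": {},
   "source": [
    "- The next cell is an optional sketch added to this notebook (it is not part of the original chapter code): because `GPT_CONFIG_124M` is a plain Python dictionary, using the full 1024-token context only requires overriding one config entry before constructing the model; the variable names below are illustrative and are not used elsewhere in the chapter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "added-optional-context-length-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sketch (added illustration, not part of the original chapter code):\n",
    "# the same GPTModel class works with the full 1024-token context length;\n",
    "# only the configuration entry changes, no other code changes are required.\n",
    "full_context_cfg = dict(GPT_CONFIG_124M, context_length=1024)  # copy config, override one entry\n",
    "model_1024 = GPTModel(full_context_cfg)\n",
    "model_1024.eval();  # disable dropout, as above"
   ]
  },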
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
  {
   "cell_type": "markdown",
   "id": "59f80895-be35-4bb5-81cb-f357ef7367fe",
   "metadata": {},
   "source": [
    "- Next, we use the `generate_text_simple` function from the previous chapter to generate text\n",
    "- In addition, we define two convenience functions, `text_to_token_ids` and `token_ids_to_text`, for converting between token and text representations that we use throughout this chapter"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "741881f3-cee0-49ad-b11d-b9df3b3ac234",
   "metadata": {},
   "source": [
    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch05_compressed/gpt-process.webp\" width=500px>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "5e062b82-3540-48ce-8eb4-009686d0d16c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Output text:\n",
      " Every effort moves you rentingetic wasnم refres RexMeCHicular stren\n"
     ]
    }
   ],
   "source": [
    "import tiktoken\n",
    "from previous_chapters import generate_text_simple\n",
    "\n",
    "def text_to_token_ids(text, tokenizer):\n",
    "    encoded = tokenizer.encode(text, allowed_special={'<|endoftext|>'})\n",
    "    encoded_tensor = torch.tensor(encoded).unsqueeze(0) # add batch dimension\n",
    "    return encoded_tensor\n",
    "\n",
    "def token_ids_to_text(token_ids, tokenizer):\n",
    "    flat = token_ids.squeeze(0) # remove batch dimension\n",
    "    return tokenizer.decode(flat.tolist())\n",
    "\n",
    "start_context = \"Every effort moves you\"\n",
    "tokenizer = tiktoken.get_encoding(\"gpt2\")\n",
    "\n",
    "token_ids = generate_text_simple(\n",
    "    model=model,\n",
    "    idx=text_to_token_ids(start_context, tokenizer),\n",
    "    max_new_tokens=10,\n",
    "    context_size=GPT_CONFIG_124M[\"context_length\"]\n",
    ")\n",
    "\n",
    "print(\"Output text:\\n\", token_ids_to_text(token_ids, tokenizer))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e4d3249b-b2a0-44c4-b589-ae4b403b8305",
   "metadata": {},
   "source": [
    "- As we can see above, the model does not produce good text because it has not been trained yet\n",
    "- How do we measure or capture what \"good text\" is, in a numeric form, to track it during training?\n",
    "- The next subsection introduces a way to calculate a loss metric for the generated outputs that we can use to measure training progress\n",
    "- The next chapters on finetuning LLMs will also introduce additional ways to measure model quality"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "955f9e1a-7bf7-40d8-b1fa-eacabdee8d8e",
   "metadata": {},
   "source": [
    "<br>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0f3d7ea2-637f-4490-bc76-e361fc81ae98",
   "metadata": {
    "id": "0f3d7ea2-637f-4490-bc76-e361fc81ae98"
   },
   "source": [
    "### 5.1.2 Calculating the text generation loss: cross-entropy and perplexity"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9e1ba8aa-fb03-4d25-957f-fe8778762440",
   "metadata": {},
   "source": [
    "- Suppose we have an `inputs` tensor containing the token IDs for 2 training examples (rows)\n",
    "- Corresponding to the `inputs`, the `targets` contain the desired token IDs that we want the model to generate\n",
    "- Notice that the `targets` are the `inputs` shifted by 1 position, as explained in chapter 2 when we implemented the data loader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "6b5402f8-ec0c-4a44-9892-18a97779ee4f",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "6b5402f8-ec0c-4a44-9892-18a97779ee4f",
    "outputId": "8d6fa0ff-7b37-4634-c3f0-2c050cbe81f0"
   },
   "outputs": [],
   "source": [
    "inputs = torch.tensor([[16833, 3626, 6100],   # [\"every effort moves\",\n",
    "                       [40,    1107, 588]])   #  \"I really like\"]\n",
    "\n",
    "targets = torch.tensor([[3626, 6100, 345  ],  # [\" effort moves you\",\n",
    "                        [1107,  588, 11311]]) #  \" really like chocolate\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "33dc0645-ac2c-4973-9b40-6da40515bede",
   "metadata": {},
   "source": [
    "- Feeding the `inputs` to the model, we obtain the logits vector for the 2 input examples that consist of 3 tokens each\n",
    "- Each of the tokens is a 50,257-dimensional vector corresponding to the size of the vocabulary\n",
    "- Applying the softmax function, we can turn the logits tensor into a tensor of the same dimension containing probability scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "e7b6ec51-6f8c-49bd-a349-95ba38b46fb6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([2, 3, 50257])\n"
     ]
    }
   ],
   "source": [
    "with torch.no_grad():\n",
    "    logits = model(inputs)\n",
    "\n",
    "probas = torch.softmax(logits, dim=-1) # Probability of each token in vocabulary\n",
    "print(probas.shape) # Shape: (batch_size, num_tokens, vocab_size)"
   ]
  },
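  {
   "cell_type": "markdown",
   "id": "added-softmax-sanity-check-note",
   "metadata": {},
   "source": [
    "- As an added sanity check (an illustration that is not part of the original chapter code), the next cell verifies that the softmax scores along the vocabulary dimension sum to 1 for every token position"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "added-softmax-sanity-check-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Added sanity check (not from the original chapter): softmax turns each\n",
    "# 50,257-dimensional logits vector into a probability distribution, so the\n",
    "# scores along the vocabulary dimension sum to 1 for every token position.\n",
    "print(probas.sum(dim=-1))  # expected: a (2, 3) tensor of ones (up to floating-point error)"
   ]
  },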
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
  {
   "cell_type": "markdown",
   "id": "5c36a382-b5e2-4de6-9e65-0b69b685013b",
   "metadata": {},
   "source": [
    "- The figure below, using a very small vocabulary for illustration purposes, outlines how we convert the probability scores back into text, which we discussed at the end of the previous chapter"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "384d86a9-0013-476c-bb6b-274fd5f20b29",
   "metadata": {},
   "source": [
    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch05_compressed/proba-to-text.webp\" width=500px>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e8480efd-d419-4954-9ecc-2876055334bd",
   "metadata": {},
   "source": [
    "- As discussed in the previous chapter, we can apply the `argmax` function to convert the probability scores into predicted token IDs\n",
    "- The softmax function above produced a 50,257-dimensional vector for each token; the `argmax` function returns the position of the highest probability score in this vector, which is the predicted token ID for the given token"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "f3b84c9f-dd08-482e-b903-a86fe44e1144",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Since we have 2 input batches with 3 tokens each, we obtain 2 by 3 predicted token IDs:"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-31 08:45:14 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 6,
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 21:07:19 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "id": "34ebd76a-16ec-4c17-8958-8a135735cc1c",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "colab": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "base_uri": "https://localhost:8080/"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "34ebd76a-16ec-4c17-8958-8a135735cc1c",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "outputId": "ed17da47-c3e7-4775-fd00-4ec5bcda3db2"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "Token IDs:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      " tensor([[[16657],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "         [  339],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "         [42826]],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [[49906],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "         [29669],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "         [41751]]])\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "token_ids = torch.argmax(probas, dim=-1, keepdim=True)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(\"Token IDs:\\n\", token_ids)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "cee4072c-21ed-4df7-8721-dd2535362573",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- If we decode these tokens, we find that these are quite different from the tokens we want the model to predict, namely the target tokens:"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-05-31 07:30:57 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 8,
							 
						 
					
						
							
								
									
										
										
										
											2024-03-25 08:09:31 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "c990ead6-53cd-49a7-a6d1-14d8c1518249",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 21:07:19 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
									
										
										
										
											2024-03-25 08:09:31 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "Targets batch 1:  effort moves you\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "Outputs batch 1:  Armed heNetflix\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 21:07:19 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-03-25 08:09:31 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "print(f\"Targets batch 1: {token_ids_to_text(targets[0], tokenizer)}\")\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(f\"Outputs batch 1: {token_ids_to_text(token_ids[0].flatten(), tokenizer)}\")"
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 21:07:19 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "a53eb8a7-070e-46d6-930c-314ba55a6ff2",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- That's because the model wasn't trained yet\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- To train the model, we need to know how far it is away from the correct predictions (targets)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "ad90592f-0d5d-4ec8-9ff5-e7675beab10e",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch05_compressed/proba-index.webp\" width=500px>"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "c7251bf5-a079-4782-901d-68c9225d3157",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-03-25 08:09:31 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "- The token probabilities corresponding to the target indices are as follows:"
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 21:07:19 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-05-31 07:30:57 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 9,
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 21:07:19 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "id": "54aef09c-d6e3-4238-8653-b3a1b0a1077a",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "colab": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "base_uri": "https://localhost:8080/"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "54aef09c-d6e3-4238-8653-b3a1b0a1077a",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "outputId": "41c946a2-c458-433e-a53d-5e7e89d9dddc"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
									
										
										
										
											2024-04-02 20:46:53 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "Text 1: tensor([7.4541e-05, 3.1061e-05, 1.1563e-05])\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-05-31 07:30:57 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "Text 2: tensor([1.0337e-05, 5.6776e-05, 4.7559e-06])\n"
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 21:07:19 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-04-02 20:46:53 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "text_idx = 0\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "target_probas_1 = probas[text_idx, [0, 1, 2], targets[text_idx]]\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(\"Text 1:\", target_probas_1)\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-25 08:09:31 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-02 20:46:53 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "text_idx = 1\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "target_probas_2 = probas[text_idx, [0, 1, 2], targets[text_idx]]\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(\"Text 2:\", target_probas_2)"
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 21:07:19 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "a0e89a19-73c2-4e49-93b4-861f699f1cbf",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- We want to maximize all these values, bringing them close to a probability of 1\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-30 06:26:23 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "- In mathematical optimization, it is easier to maximize the logarithm of the probability score than the probability score itself; this is out of the scope of this book, but I have recorded a lecture with more details here: [L8.2 Logistic Regression Loss Function](https://www.youtube.com/watch?v=GxJe0DZvydM)"
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 21:07:19 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-05-31 07:30:57 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 10,
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 21:07:19 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "id": "31402a67-a16e-4aeb-977e-70abb9c9949b",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "colab": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "base_uri": "https://localhost:8080/"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "31402a67-a16e-4aeb-977e-70abb9c9949b",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "outputId": "1bf18e79-1246-4eab-efd8-12b328c78678"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
									
										
										
										
											2024-05-31 07:30:57 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "tensor([ -9.5042, -10.3796, -11.3677, -11.4798,  -9.7764, -12.2561])\n"
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 21:07:19 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "# Compute logarithm of all token probabilities\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "log_probas = torch.log(torch.cat((target_probas_1, target_probas_2)))\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(log_probas)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "c4261441-a511-4633-9c4c-67998af31b84",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Next, we compute the average log probability:"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-05-31 07:30:57 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 11,
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 21:07:19 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "id": "9b003797-161b-4d98-81dc-e68320e09fec",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "colab": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "base_uri": "https://localhost:8080/"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "9b003797-161b-4d98-81dc-e68320e09fec",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "outputId": "a447fe9c-7e27-40ed-f1fb-51210e3f7cc9"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
									
										
										
										
											2024-05-31 07:30:57 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "tensor(-10.7940)\n"
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 21:07:19 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "# Calculate the average probability for each token\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "avg_log_probas = torch.mean(log_probas)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(avg_log_probas)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "36d51994-ad17-4ba3-a6ec-f588b4b13585",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- The goal is to make this average log probability as large as possible by optimizing the model weights\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Due to the log, the largest possible value is 0, and we are currently far away from 0"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "3de388a1-8a0a-4c94-8894-9041dc6ad514",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- In deep learning, instead of maximizing the average log-probability, it's a standard convention to minimize the *negative* average log-probability value; in our case, instead of maximizing -10.7722 so that it approaches 0, in deep learning, we would minimize 10.7722 so that it approaches 0\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-25 08:09:31 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "- The value negative of -10.7722, i.e., 10.7722, is also called cross entropy loss in deep learning"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-05-31 07:30:57 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 12,
							 
						 
					
						
							
								
									
										
										
										
											2024-03-25 08:09:31 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "176ddf35-1c5f-4d7c-bf17-70f3e7069bd4",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
									
										
										
										
											2024-05-31 07:30:57 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "tensor(10.7940)\n"
							 
						 
					
						
							
								
									
										
										
										
											2024-03-25 08:09:31 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "neg_avg_log_probas = avg_log_probas * -1\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(neg_avg_log_probas)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "84eeb868-abd8-4028-82db-107546bf7c2c",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 21:07:19 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "- PyTorch already implements a `cross_entropy` function that carries out the previous steps"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "5bd24b7f-b760-47ad-bc84-86d13794aa54",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch05_compressed/cross-entropy.webp\" width=400px>"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "e8aaf9dd-3ee6-42bf-a63f-6e93dbfb989d",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Before we apply the cross entropy function, let's check the shape of the logits and targets"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-05-31 07:30:57 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 13,
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 21:07:19 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "id": "695d6f64-5084-4c23-aea4-105c9e38cfe4",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "colab": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "base_uri": "https://localhost:8080/"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "695d6f64-5084-4c23-aea4-105c9e38cfe4",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "outputId": "43fd802a-8136-4b35-df0d-f61a5d4cb561"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "Logits shape: torch.Size([2, 3, 50257])\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "Targets shape: torch.Size([2, 3])\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-03-25 08:09:31 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "# Logits have shape (batch_size, num_tokens, vocab_size)\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 21:07:19 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "print(\"Logits shape:\", logits.shape)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "# Targets have shape (batch_size, num_tokens)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(\"Targets shape:\", targets.shape)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "1d3d65f0-6566-4865-93e4-0c0bcb10cd06",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- For the cross `entropy_loss` function in PyTorch, we want to flatten these tensors by combining them over the batch dimension:"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "0e17e027-ab9f-4fb5-ac9b-a009b831c122",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "0e17e027-ab9f-4fb5-ac9b-a009b831c122",
    "outputId": "0b2b778b-02fb-43b2-c879-adc59055a7d8"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Flattened logits: torch.Size([6, 50257])\n",
      "Flattened targets: torch.Size([6])\n"
     ]
    }
   ],
   "source": [
    "logits_flat = logits.flatten(0, 1)\n",
    "targets_flat = targets.flatten()\n",
    "\n",
    "print(\"Flattened logits:\", logits_flat.shape)\n",
    "print(\"Flattened targets:\", targets_flat.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4921a57f-3a79-473e-a863-6d63b495010f",
   "metadata": {},
   "source": [
    "- Note that the targets are the token IDs, which also represent the index positions in the logits tensors that we want to maximize\n",
    "- The `cross_entropy` function in PyTorch will automatically take care of applying the softmax and log-probability computation internally over those token indices in the logits that are to be maximized"
   ]
  },
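  {
   "cell_type": "markdown",
   "id": "added-cross-entropy-check-md",
   "metadata": {},
   "source": [
    "- As an optional aside (a minimal sketch added for illustration, not part of the original chapter code), the next cell reproduces the cross entropy computation manually via `log_softmax` and indexing to show what the `cross_entropy` function takes care of internally"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "added-cross-entropy-check-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sketch: cross_entropy is equivalent to taking the log-softmax of the logits,\n",
    "# picking the log-probability at each target token ID, and averaging the negative values\n",
    "log_probs = torch.nn.functional.log_softmax(logits_flat, dim=-1)\n",
    "target_log_probs = log_probs[torch.arange(targets_flat.shape[0]), targets_flat]\n",
    "print(-target_log_probs.mean())  # should match the cross_entropy loss in the next cell"
   ]
  },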
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "62d0816e-b29a-4c8f-a9a5-a167562de978",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "62d0816e-b29a-4c8f-a9a5-a167562de978",
    "outputId": "c0be634a-2c65-4ff7-a73f-1bfc2e406ba4"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(10.7940)\n"
     ]
    }
   ],
   "source": [
    "loss = torch.nn.functional.cross_entropy(logits_flat, targets_flat)\n",
    "print(loss)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0f15ce17-fd7b-4d8e-99da-b237523a7a80",
   "metadata": {},
   "source": [
    "- A concept related to the cross entropy loss is the perplexity of an LLM\n",
    "- The perplexity is simply the exponential of the cross entropy loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "168952a1-b964-4aa7-8e49-966fa26add54",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "168952a1-b964-4aa7-8e49-966fa26add54",
    "outputId": "a0a692c1-6412-4068-8aa5-8858548141eb"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(48725.8203)\n"
     ]
    }
   ],
   "source": [
    "perplexity = torch.exp(loss)\n",
    "print(perplexity)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "71ae26dd-d77e-41fd-b924-6bd103dd4ee7",
   "metadata": {},
   "source": [
    "- The perplexity is often considered more interpretable because it can be understood as the effective vocabulary size that the model is uncertain about at each step (in the example above, that'd be 48,725 words or tokens)\n",
    "- In other words, perplexity provides a measure of how well the probability distribution predicted by the model matches the actual distribution of the words in the dataset\n",
    "- Similar to the loss, a lower perplexity indicates that the model predictions are closer to the actual distribution"
   ]
  },
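  {
   "cell_type": "markdown",
   "id": "added-perplexity-intuition-md",
   "metadata": {},
   "source": [
    "- As an optional aside (a minimal sketch added for illustration, not part of the original chapter code), the next cell shows where the \"effective vocabulary size\" intuition comes from: a model that spreads its probability uniformly over all 50,257 vocabulary entries has a perplexity equal to the vocabulary size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "added-perplexity-intuition-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sketch: uniform logits over the 50,257-token vocabulary\n",
    "# yield a perplexity equal to the vocabulary size\n",
    "uniform_logits = torch.zeros(1, 50257)   # same score for every token\n",
    "uniform_target = torch.tensor([0])       # arbitrary target token ID\n",
    "uniform_loss = torch.nn.functional.cross_entropy(uniform_logits, uniform_target)\n",
    "print(torch.exp(uniform_loss))           # approximately 50257"
   ]
  },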
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
  {
   "cell_type": "markdown",
   "id": "2ec6c217-e429-40c7-ad71-5d0a9da8e487",
   "metadata": {
    "id": "2ec6c217-e429-40c7-ad71-5d0a9da8e487"
   },
   "source": [
    "### 5.1.3 Calculating the training and validation set losses"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "530da89e-2448-436c-8f1b-28e8a31ef85c",
   "metadata": {},
   "source": [
    "- We use a relatively small dataset for training the LLM (in fact, only one short story)\n",
    "- The reasons are:\n",
    "  - You can run the code examples in a few minutes on a laptop computer without a suitable GPU\n",
    "  - The training finishes relatively fast (minutes instead of weeks), which is good for educational purposes\n",
    "  - We use a text from the public domain, which can be included in this GitHub repository without violating any usage rights or bloating the repository size\n",
    "\n",
    "\n",
    "- For example, Llama 2 7B required 184,320 GPU hours on A100 GPUs to be trained on 2 trillion tokens\n",
    "  - At the time of this writing, the hourly cost of an 8xA100 cloud server at AWS is approximately \\\\$30\n",
    "  - So, via a back-of-the-envelope calculation, training this LLM would cost 184,320 / 8 * \\\\$30 = \\\\$690,000\n",
    " \n",
    "- Below, we use the same dataset we used in chapter 2"
   ]
  },
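  {
   "cell_type": "markdown",
   "id": "added-training-cost-estimate-md",
   "metadata": {},
   "source": [
    "- The back-of-the-envelope cost estimate above can be reproduced in a few lines (an optional aside added for illustration; the 184,320 GPU hours and the roughly 30 USD hourly rate for an 8xA100 server are the figures stated above)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "added-training-cost-estimate-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sketch: rough Llama 2 7B training-cost estimate using the figures stated above\n",
    "gpu_hours = 184_320                  # total A100 GPU hours\n",
    "hourly_rate_8xA100 = 30              # approximate USD per hour for an 8xA100 server\n",
    "estimated_cost = gpu_hours / 8 * hourly_rate_8xA100\n",
    "print(f\"Estimated training cost: ${estimated_cost:,.0f}\")  # roughly $690,000"
   ]
  },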
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "654fde37-b2a9-4a20-a8d3-0206c056e2ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import urllib.request\n",
    "\n",
    "file_path = \"the-verdict.txt\"\n",
    "url = \"https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/main/ch02/01_main-chapter-code/the-verdict.txt\"\n",
    "\n",
    "if not os.path.exists(file_path):\n",
    "    with urllib.request.urlopen(url) as response:\n",
    "        text_data = response.read().decode('utf-8')\n",
    "    with open(file_path, \"w\", encoding=\"utf-8\") as file:\n",
    "        file.write(text_data)\n",
    "else:\n",
    "    with open(file_path, \"r\", encoding=\"utf-8\") as file:\n",
    "        text_data = file.read()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "379330f1-80f4-4e34-8724-41d892b04cee",
   "metadata": {},
   "source": [
    "- A quick check that the text loaded correctly by printing the first and last 100 characters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "6kgJbe4ehI4q",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 35
    },
    "id": "6kgJbe4ehI4q",
    "outputId": "9ff31e88-ee37-47e9-ee64-da6eb552f46f"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "I HAD always thought Jack Gisburn rather a cheap genius--though a good fellow enough--so it was no \n"
     ]
    }
   ],
   "source": [
    "# First 100 characters\n",
    "print(text_data[:99])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "j2XPde_ThM_e",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 35
    },
    "id": "j2XPde_ThM_e",
    "outputId": "a900c1b9-9a87-4078-968b-a5721deda5cb"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "it for me! The Strouds stand alone, and happen once--but there's no exterminating our kind of art.\"\n"
     ]
    }
   ],
   "source": [
    "# Last 100 characters\n",
    "print(text_data[-99:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "6b46a952-d50a-4837-af09-4095698f7fd1",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "6b46a952-d50a-4837-af09-4095698f7fd1",
    "outputId": "c2a25334-21ca-486e-8226-0296e5fc6486"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Characters: 20479\n",
      "Tokens: 5145\n"
     ]
    }
   ],
   "source": [
    "total_characters = len(text_data)\n",
    "total_tokens = len(tokenizer.encode(text_data))\n",
    "\n",
    "print(\"Characters:\", total_characters)\n",
    "print(\"Tokens:\", total_tokens)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a8830cb9-90f6-4e7c-8620-beeabc2d39f7",
   "metadata": {},
   "source": [
    "- With 5,145 tokens, the text is very short for training an LLM, but again, it's for educational purposes (we will also load pretrained weights later)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bedcad87-a0e8-4b9d-ac43-4e927ccbb50f",
   "metadata": {},
   "source": [
    "- Next, we divide the dataset into a training and a validation set and use the data loaders from chapter 2 to prepare the batches for LLM training\n",
    "- For visualization purposes, the figure below assumes a `max_length=6`, but for the training loader, we set the `max_length` equal to the context length that the LLM supports\n",
    "- The figure below only shows the input tokens for simplicity\n",
    "    - Since we train the LLM to predict the next word in the text, the targets look the same as these inputs, except that the targets are shifted by one position"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "46bdaa07-ba96-4ac1-9d71-b3cc153910d9",
   "metadata": {},
   "source": [
    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch05_compressed/batching.webp\" width=500px>"
   ]
  },
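  {
   "cell_type": "markdown",
   "id": "added-shifted-targets-md",
   "metadata": {},
   "source": [
    "- As an optional aside (a minimal sketch added for illustration, mirroring the sliding-window idea from chapter 2), the next cell shows the shift-by-one relationship between an input chunk and its target chunk for a small chunk length"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "added-shifted-targets-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sketch of the shift-by-one relationship between inputs and targets\n",
    "# (uses the tokenizer and text_data from above; a chunk length of 6 is used only for display)\n",
    "example_ids = tokenizer.encode(text_data)[:7]\n",
    "example_inputs = example_ids[:6]     # the tokens the model sees\n",
    "example_targets = example_ids[1:7]   # the same tokens shifted by one position\n",
    "print(\"Inputs: \", example_inputs)\n",
    "print(\"Targets:\", example_targets)"
   ]
  },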
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-05-31 07:30:57 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 21,
							 
						 
					
						
							
								
									
										
										
										
											2024-03-26 20:34:50 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "0959c855-f860-4358-8b98-bc654f047578",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 21:07:19 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "from previous_chapters import create_dataloader_v1\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "# Train/validation ratio\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "train_ratio = 0.90\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "split_idx = int(train_ratio * len(text_data))\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-26 20:34:50 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "train_data = text_data[:split_idx]\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "val_data = text_data[split_idx:]\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 21:07:19 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "torch.manual_seed(123)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "train_loader = create_dataloader_v1(\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-26 20:34:50 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "    train_data,\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 21:07:19 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "    batch_size=2,\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-04 07:27:41 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "    max_length=GPT_CONFIG_124M[\"context_length\"],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    stride=GPT_CONFIG_124M[\"context_length\"],\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 21:07:19 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "    drop_last=True,\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-13 14:57:56 -04:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "    shuffle=True,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    num_workers=0\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 21:07:19 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    ")\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "val_loader = create_dataloader_v1(\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-26 20:34:50 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "    val_data,\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 21:07:19 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "    batch_size=2,\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-04 07:27:41 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "    max_length=GPT_CONFIG_124M[\"context_length\"],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    stride=GPT_CONFIG_124M[\"context_length\"],\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 21:07:19 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "    drop_last=False,\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-13 14:57:56 -04:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "    shuffle=False,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    num_workers=0\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 21:07:19 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    ")"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2024-03-19 06:34:52 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-05-31 07:30:57 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 22,
							 
						 
					
						
							
								
									
										
										
										
											2024-03-19 06:34:52 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "f37b3eb0-854e-4895-9898-fa7d1e67566e",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "# Sanity check\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-04 07:27:41 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "if total_tokens * (train_ratio) < GPT_CONFIG_124M[\"context_length\"]:\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-19 06:34:52 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "    print(\"Not enough tokens for the training loader. \"\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-04 07:27:41 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "          \"Try to lower the `GPT_CONFIG_124M['context_length']` or \"\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-19 06:34:52 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "          \"increase the `training_ratio`\")\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-04 07:27:41 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "if total_tokens * (1-train_ratio) < GPT_CONFIG_124M[\"context_length\"]:\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-19 06:34:52 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "    print(\"Not enough tokens for the validation loader. \"\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-04 07:27:41 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "          \"Try to lower the `GPT_CONFIG_124M['context_length']` or \"\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-19 06:34:52 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "          \"decrease the `training_ratio`\")"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 21:07:19 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
  {
   "cell_type": "markdown",
   "id": "e7ac3296-a4d1-4303-9ac5-376518960c33",
   "metadata": {},
   "source": [
    "- We use a relatively small batch size to reduce the computational resource demand, and because the dataset is very small to begin with\n",
    "- Llama 2 7B was trained with a batch size of 1024, for example"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a8e0514d-b990-4dc0-9afb-7721993284a0",
   "metadata": {},
   "source": [
    "- An optional check that the data was loaded correctly:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "ca0116d0-d229-472c-9fbf-ebc229331c3e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train loader:\n",
      "torch.Size([2, 256]) torch.Size([2, 256])\n",
      "torch.Size([2, 256]) torch.Size([2, 256])\n",
      "torch.Size([2, 256]) torch.Size([2, 256])\n",
      "torch.Size([2, 256]) torch.Size([2, 256])\n",
      "torch.Size([2, 256]) torch.Size([2, 256])\n",
      "torch.Size([2, 256]) torch.Size([2, 256])\n",
      "torch.Size([2, 256]) torch.Size([2, 256])\n",
      "torch.Size([2, 256]) torch.Size([2, 256])\n",
      "torch.Size([2, 256]) torch.Size([2, 256])\n",
      "\n",
      "Validation loader:\n",
      "torch.Size([2, 256]) torch.Size([2, 256])\n"
     ]
    }
   ],
   "source": [
    "print(\"Train loader:\")\n",
    "for x, y in train_loader:\n",
    "    print(x.shape, y.shape)\n",
    "\n",
    "print(\"\\nValidation loader:\")\n",
    "for x, y in val_loader:\n",
    "    print(x.shape, y.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f7b9b1a4-863d-456f-a8dd-c07fb5c024ed",
   "metadata": {},
   "source": [
    "- Another optional check that the token sizes are in the expected ballpark:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "eb860488-5453-41d7-9870-23b723f742a0",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "eb860488-5453-41d7-9870-23b723f742a0",
    "outputId": "96b9451a-9557-4126-d1c8-51610a1995ab"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training tokens: 4608\n",
      "Validation tokens: 512\n",
      "All tokens: 5120\n"
     ]
    }
   ],
   "source": [
    "train_tokens = 0\n",
    "for input_batch, target_batch in train_loader:\n",
    "    train_tokens += input_batch.numel()\n",
    "\n",
    "val_tokens = 0\n",
    "for input_batch, target_batch in val_loader:\n",
    "    val_tokens += input_batch.numel()\n",
    "\n",
    "print(\"Training tokens:\", train_tokens)\n",
    "print(\"Validation tokens:\", val_tokens)\n",
    "print(\"All tokens:\", train_tokens + val_tokens)"
   ]
  },
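  {
   "cell_type": "markdown",
   "id": "token-count-arithmetic-note",
   "metadata": {},
   "source": [
    "- These counts are consistent with the batch shapes printed earlier: 9 training batches × 2 rows × 256 tokens = 4,608 tokens, and 1 validation batch × 2 × 256 = 512 tokens\n",
    "- The totals are slightly below `total_tokens * train_ratio` and `total_tokens * (1 - train_ratio)` because only complete 256-token input windows end up in the loaders, so the few leftover tokens at the end of each split are dropped"
   ]
  },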
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
  {
   "cell_type": "markdown",
   "id": "5c3085e8-665e-48eb-bb41-cdde61537e06",
   "metadata": {},
   "source": [
    "- Next, we implement a utility function to calculate the cross entropy loss of a given batch\n",
    "- In addition, we implement a second utility function to compute the loss for a user-specified number of batches in a data loader"
   ]
  },
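  {
   "cell_type": "markdown",
   "id": "cross-entropy-refresher-note",
   "metadata": {},
   "source": [
    "- As a brief refresher, the cross entropy loss used here is the average negative log-probability that the model assigns to the correct next token, $-\\frac{1}{N}\\sum_{i=1}^{N} \\log p_{\\theta}(t_i \\mid t_{<i})$\n",
    "- `torch.nn.functional.cross_entropy` computes this directly from the raw logits (it applies the softmax internally), which is why the utility functions below never call softmax explicitly\n",
    "- The optional sketch below illustrates this equivalence on two small, made-up tensors (they are not outputs of the model above):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cross-entropy-manual-check",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sketch: cross_entropy on raw logits equals the mean negative\n",
    "# log-probability of the correct class, computed manually via log_softmax\n",
    "example_logits = torch.tensor([[2.0, 1.0, 0.1],   # 2 positions, vocabulary of 3 tokens\n",
    "                               [0.5, 2.5, 0.3]])\n",
    "example_targets = torch.tensor([0, 1])            # correct token id for each position\n",
    "\n",
    "log_probs = torch.log_softmax(example_logits, dim=-1)\n",
    "manual_loss = -log_probs[torch.arange(2), example_targets].mean()\n",
    "builtin_loss = torch.nn.functional.cross_entropy(example_logits, example_targets)\n",
    "\n",
    "print(manual_loss, builtin_loss)"
   ]
  },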
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "7b9de31e-4096-47b3-976d-b6d2fdce04bc",
   "metadata": {
    "id": "7b9de31e-4096-47b3-976d-b6d2fdce04bc"
   },
   "outputs": [],
   "source": [
    "def calc_loss_batch(input_batch, target_batch, model, device):\n",
    "    input_batch, target_batch = input_batch.to(device), target_batch.to(device)\n",
    "    logits = model(input_batch)\n",
    "    loss = torch.nn.functional.cross_entropy(logits.flatten(0, 1), target_batch.flatten())\n",
    "    return loss\n",
    "\n",
    "\n",
    "def calc_loss_loader(data_loader, model, device, num_batches=None):\n",
    "    total_loss = 0.\n",
    "    if len(data_loader) == 0:\n",
    "        return float(\"nan\")\n",
    "    elif num_batches is None:\n",
    "        num_batches = len(data_loader)\n",
    "    else:\n",
    "        # Reduce the number of batches to match the total number of batches in the data loader\n",
    "        # if num_batches exceeds the number of batches in the data loader\n",
    "        num_batches = min(num_batches, len(data_loader))\n",
    "    for i, (input_batch, target_batch) in enumerate(data_loader):\n",
    "        if i < num_batches:\n",
    "            loss = calc_loss_batch(input_batch, target_batch, model, device)\n",
    "            total_loss += loss.item()\n",
    "        else:\n",
    "            break\n",
    "    return total_loss / num_batches"
   ]
  },
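  {
   "cell_type": "markdown",
   "id": "calc-loss-loader-note",
   "metadata": {},
   "source": [
    "- Note that `calc_loss_loader` returns the average of the per-batch losses; since each batch loss is itself a mean over all token positions in that batch, the result can be read as an approximate per-token average loss\n",
    "- If the data loader is empty, the function returns `nan` to signal that there was nothing to evaluate"
   ]
  },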
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
  {
   "cell_type": "markdown",
   "id": "f0691332-84d0-48b3-b462-a885ddeb4fca",
   "metadata": {},
   "source": [
    "- If you have a machine with a CUDA-supported GPU, the LLM will train on the GPU without requiring any changes to the code\n",
    "- Via the `device` setting, we ensure that the data is loaded onto the same device as the LLM model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "56f5b0c9-1065-4d67-98b9-010e42fc1e2a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training loss: 10.98758347829183\n",
      "Validation loss: 10.98110580444336\n"
     ]
    }
   ],
   "source": [
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "model.to(device) # no assignment (i.e., model = model.to(device)) necessary for nn.Module classes\n",
    "\n",
    "\n",
    "torch.manual_seed(123) # For reproducibility due to the shuffling in the data loader\n",
    "\n",
    "with torch.no_grad(): # Disable gradient tracking for efficiency because we are not training, yet\n",
    "    train_loss = calc_loss_loader(train_loader, model, device)\n",
    "    val_loss = calc_loss_loader(val_loader, model, device)\n",
    "\n",
    "print(\"Training loss:\", train_loss)\n",
    "print(\"Validation loss:\", val_loss)"
   ]
  },
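  {
   "cell_type": "markdown",
   "id": "initial-loss-note",
   "metadata": {},
   "source": [
    "- Because the model weights are still random, both losses are close to ln(50,257) ≈ 10.8, which is roughly the loss we would expect if the model spread its probability mass uniformly over the 50,257 tokens in the vocabulary"
   ]
  },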
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
  {
   "cell_type": "markdown",
   "id": "43875e95-190f-4b17-8f9a-35034ba649ec",
   "metadata": {},
   "source": [
    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch05_compressed/mental-model-1.webp\" width=400px>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b9339f8d-00cb-4206-af67-58c32bd72055",
   "metadata": {
    "id": "b9339f8d-00cb-4206-af67-58c32bd72055"
   },
   "source": [
    "## 5.2 Training an LLM"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "652a4cf4-e98f-46d9-bdec-60e7ccb8d6bd",
   "metadata": {},
   "source": [
    "- In this section, we finally implement the code for training the LLM\n",
    "- We focus on a simple training function (if you are interested in augmenting this training function with more advanced techniques, such as learning rate warmup, cosine annealing, and gradient clipping, please refer to [Appendix D](../../appendix-D/01_main-chapter-code))\n",
    "\n",
    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch05_compressed/train-steps.webp\" width=300px>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "Mtp4gY0ZO-qq",
   "metadata": {
    "id": "Mtp4gY0ZO-qq"
   },
   "outputs": [],
   "source": [
    "def train_model_simple(model, train_loader, val_loader, optimizer, device, num_epochs,\n",
    "                       eval_freq, eval_iter, start_context, tokenizer):\n",
    "    # Initialize lists to track losses and tokens seen\n",
    "    train_losses, val_losses, track_tokens_seen = [], [], []\n",
    "    tokens_seen, global_step = 0, -1\n",
    "\n",
    "    # Main training loop\n",
    "    for epoch in range(num_epochs):\n",
    "        model.train()  # Set model to training mode\n",
    "        \n",
    "        for input_batch, target_batch in train_loader:\n",
    "            optimizer.zero_grad() # Reset loss gradients from previous batch iteration\n",
    "            loss = calc_loss_batch(input_batch, target_batch, model, device)\n",
    "            loss.backward() # Calculate loss gradients\n",
    "            optimizer.step() # Update model weights using loss gradients\n",
    "            tokens_seen += input_batch.numel()\n",
    "            global_step += 1\n",
    "\n",
    "            # Optional evaluation step\n",
    "            if global_step % eval_freq == 0:\n",
    "                train_loss, val_loss = evaluate_model(\n",
    "                    model, train_loader, val_loader, device, eval_iter)\n",
    "                train_losses.append(train_loss)\n",
    "                val_losses.append(val_loss)\n",
    "                track_tokens_seen.append(tokens_seen)\n",
    "                print(f\"Ep {epoch+1} (Step {global_step:06d}): \"\n",
    "                      f\"Train loss {train_loss:.3f}, Val loss {val_loss:.3f}\")\n",
    "\n",
    "        # Print a sample text after each epoch\n",
    "        generate_and_print_sample(\n",
    "            model, tokenizer, device, start_context\n",
    "        )\n",
    "\n",
    "    return train_losses, val_losses, track_tokens_seen\n",
    "\n",
    "\n",
    "def evaluate_model(model, train_loader, val_loader, device, eval_iter):\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        train_loss = calc_loss_loader(train_loader, model, device, num_batches=eval_iter)\n",
    "        val_loss = calc_loss_loader(val_loader, model, device, num_batches=eval_iter)\n",
    "    model.train()\n",
    "    return train_loss, val_loss\n",
    "\n",
    "\n",
    "def generate_and_print_sample(model, tokenizer, device, start_context):\n",
    "    model.eval()\n",
    "    context_size = model.pos_emb.weight.shape[0]\n",
    "    encoded = text_to_token_ids(start_context, tokenizer).to(device)\n",
    "    with torch.no_grad():\n",
    "        token_ids = generate_text_simple(\n",
    "            model=model, idx=encoded,\n",
    "            max_new_tokens=50, context_size=context_size\n",
    "        )\n",
    "        decoded_text = token_ids_to_text(token_ids, tokenizer)\n",
    "        print(decoded_text.replace(\"\\n\", \" \"))  # Compact print format\n",
    "    model.train()"
   ]
  },
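  {
   "cell_type": "markdown",
   "id": "training-helpers-note",
   "metadata": {},
   "source": [
    "- A few notes on the helper functions above:\n",
    "- `evaluate_model` temporarily switches the model into evaluation mode (`model.eval()`), which disables dropout, and wraps the loss computation in `torch.no_grad()` so that no gradients are tracked during evaluation; afterwards, it switches the model back to training mode\n",
    "- `generate_and_print_sample` provides a quick qualitative impression after each epoch by generating 50 new tokens from the `start_context` with the `generate_text_simple` function used earlier"
   ]
  },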
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
  {
   "cell_type": "markdown",
   "id": "a301b333-b9d4-4eeb-a212-3a9874e3ac47",
   "metadata": {},
   "source": [
    "- Now, let's train the LLM using the training function defined above:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "3422000b-7aa2-485b-92df-99372cd22311",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "3422000b-7aa2-485b-92df-99372cd22311",
    "outputId": "0e046603-908d-4093-8ae5-ef2f632639fb"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
								      "Ep 1 (Step 000000): Train loss 9.781, Val loss 9.933\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "Ep 1 (Step 000005): Train loss 8.111, Val loss 8.339\n",
							 
						 
					
						
							
								
									
										
										
										
								      "Every effort moves you,,,,,,,,,,,,.                                     \n",
							 
						 
					
						
							
								
									
										
										
										
								      "Ep 2 (Step 000010): Train loss 6.661, Val loss 7.048\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "Ep 2 (Step 000015): Train loss 5.961, Val loss 6.616\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "Every effort moves you, and, and, and, and, and, and, and, and, and, and, and, and, and, and, and, and, and, and, and, and, and, and,, and, and,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "Ep 3 (Step 000020): Train loss 5.726, Val loss 6.600\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "Ep 3 (Step 000025): Train loss 5.201, Val loss 6.348\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "Every effort moves you, and I had been.                                            \n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "Ep 4 (Step 000030): Train loss 4.417, Val loss 6.278\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "Ep 4 (Step 000035): Train loss 4.069, Val loss 6.226\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "Every effort moves you know the                          \"I he had the donkey and I had the and I had the donkey and down the room, I had\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "Ep 5 (Step 000040): Train loss 3.732, Val loss 6.160\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "Every effort moves you know it was not that the picture--I had the fact by the last I had been--his, and in the            \"Oh, and he said, and down the room, and in\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "Ep 6 (Step 000045): Train loss 2.850, Val loss 6.179\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "Ep 6 (Step 000050): Train loss 2.427, Val loss 6.141\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "Every effort moves you know,\" was one of the picture. The--I had a little of a little: \"Yes, and in fact, and in the picture was, and I had been at my elbow and as his pictures, and down the room, I had\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "Ep 7 (Step 000055): Train loss 2.104, Val loss 6.134\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "Ep 7 (Step 000060): Train loss 1.882, Val loss 6.233\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "Every effort moves you know,\" was one of the picture for nothing--I told Mrs.  \"I was no--as! The women had been, in the moment--as Jack himself, as once one had been the donkey, and were, and in his\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "Ep 8 (Step 000065): Train loss 1.320, Val loss 6.238\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "Ep 8 (Step 000070): Train loss 0.985, Val loss 6.242\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "Every effort moves you know,\" was one of the axioms he had been the tips of a self-confident moustache, I felt to see a smile behind his close grayish beard--as if he had the donkey. \"strongest,\" as his\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "Ep 9 (Step 000075): Train loss 0.717, Val loss 6.293\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "Ep 9 (Step 000080): Train loss 0.541, Val loss 6.393\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "Every effort moves you?\"  \"Yes--quite insensible to the irony. She wanted him vindicated--and by me!\"  He laughed again, and threw back the window-curtains, I had the donkey. \"There were days when I\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "Ep 10 (Step 000085): Train loss 0.391, Val loss 6.452\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "Every effort moves you know,\" was one of the axioms he laid down across the Sevres and silver of an exquisitely appointed luncheon-table, when, on a later day, I had again run over from Monte Carlo; and Mrs. Gis\n"
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 21:07:19 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "torch.manual_seed(123)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "model = GPTModel(GPT_CONFIG_124M)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "model.to(device)\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-12 19:55:07 -04:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "optimizer = torch.optim.AdamW(model.parameters(), lr=0.0004, weight_decay=0.1)\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 21:07:19 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-19 06:34:52 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "num_epochs = 10\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 21:07:19 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "train_losses, val_losses, tokens_seen = train_model_simple(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    model, train_loader, val_loader, optimizer, device,\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-27 07:11:56 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "    num_epochs=num_epochs, eval_freq=5, eval_iter=5,\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-13 14:57:56 -04:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "    start_context=\"Every effort moves you\", tokenizer=tokenizer\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 21:07:19 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    ")"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-05-31 07:30:57 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 29,
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 21:07:19 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "id": "0WSRu2i0iHJE",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "colab": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "base_uri": "https://localhost:8080/",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "height": 487
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "0WSRu2i0iHJE",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "outputId": "9d36c61b-517d-4f07-a7e8-4563aff78b11"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "data": {
							 
						 
					
						
							
								
									
										
										
										
											2024-05-31 07:30:57 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAeoAAAEiCAYAAAA21pHjAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy80BEi2AAAACXBIWXMAAA9hAAAPYQGoP6dpAABXqUlEQVR4nO3dd3gU5drH8e9uyqb3DiQEElLo3RCwEQmISFFRT1RAlCMdEUVUEGyIIgdBDlhe4VgQGyBSBaSGKhCKhNBCQkmhpZOQZJ/3jyUbliaBhN2E+3Ndc7E788zMvUOS387MMzMapZRCCCGEEBZJa+4ChBBCCHF9EtRCCCGEBZOgFkIIISyYBLUQQghhwSSohRBCCAsmQS2EEEJYMAlqIYQQwoJJUAshhBAWTIJaCCGEsGAS1ELUAMeOHUOj0ZCQkGDuUoQQlUyCWggLodFobjiMHz/e3CUKIczA2twFCCEM0tLSjK9//PFHxo0bR1JSknGck5OTOcoSQpiZ7FELYSH8/PyMg6urKxqNxvjex8eHKVOmULt2bXQ6Hc2aNWP58uXXXVZpaSnPP/884eHhpKamAvDbb7/RokUL7OzsqFevHhMmTKCkpMQ4j0aj4auvvqJnz544ODgQGhrKokWLjNPPnz9PXFwc3t7e2NvbExoayuzZs69bwy+//ELjxo2xt7fH09OTmJgY8vPzjdO/+uorIiIisLOzIzw8nP/+978m8x8/fpzevXvj5uaGh4cH3bt359ixY8bpffv2pUePHkyePBl/f388PT0ZPHgwxcXFN73NhagWlBDC4syePVu5uroa30+ZMkW5uLioH374QR04cEC99tprysbGRh08eFAppVRycrIC1K5du1RhYaHq2bOnat68ucrMzFRKKbV+/Xrl4uKi5syZo44cOaL++OMPVbduXTV+/HjjOgBVu3ZtNXfuXHXo0CE1bNgw5eTkpM6ePauUUmrw4MGqWbNmavv27So5OVmtXLlSLVq06Jr1nzp1SllbW6spU6ao5ORktWfPHjVjxgyVm5urlFLqu+++U/7+/urXX39VR48eVb/++qvy8PBQc+bMUUopdfHiRRUREaGef/55tWfPHrV//371r3/9S4WFhamioiKllFJ9+vRRLi4u6qWXXlKJiYnq999/Vw4ODuqLL76o3P8MIcxMgloIC3RlUAcEBKj333/fpE3r1q3VoEGDlFLlQb1hwwbVsWNH1b59e5WVlWVs27FjR/XBBx+YzP/tt98qf39/43tAvfXWW8b3eXl5ClDLli1TSinVrVs31a9fv5uqf8eOHQpQx44du+b0+vXrq7lz55qMe/fdd1VUVJSxtrCwMKXX643Ti4qKlL29vVqxYoVSyhDUQUFBqqSkxNjmiSeeUE8++eRN1ShEdSHnqIWwcDk5OZw6dYro6GiT8dHR0ezevdtk3NNPP03t2rX5888/sbe3N47fvXs38fHxvP/++8ZxpaWlFBYWUlBQgIODAwBNmjQxTnd0dMTFxYXMzEwABg4cyGOPPcbOnTvp1KkTPXr0oF27dtesuWnTpnTs2JHGjRsTGxtLp06dePzxx3F3dyc/P58jR47Qv39/XnzxReM8JSUluLq6Gus9fPgwzs7OJsstLCzkyJEjxvcNGzbEysrK+N7f35+9e/feYGsKUf1IUAtRgzz88MN89913bN68mQcffNA4Pi8vjwkTJtCrV6+r5rGzszO+trGxMZmm0WjQ6/UAdOnShZSUFJYuXcrKlSvp2LEjgwcPZvLkyVct08rKipUrV7Jp0yb++OMPpk+fzptvvsnWrVuNXwq+/PJL2rZte9V8ZfW2bNmS77///qple3t731S9QtQUEtRCWDgXFxcCAgKIj4/nvvvuM46Pj4+nTZs2Jm0HDhxIo0aNePTRR1myZImxfYsWLUhKSiIkJOS2avH29qZPnz706dOHDh068Oqrr14zqMEQmtHR0URHRzNu3DiCgoJYsGABI0eOJCAggKNHjxIXF3fNeVu0aMGPP/6Ij48PLi4ut1WzENWdBLUQ1cCrr77K22+/Tf369WnWrBmzZ88mISHhmnucQ4cOpbS0lEceeYRly5bRvn17xo0bxyOPPEJgYCCPP/44Wq2W3bt3s2/fPt57772bqmHcuHG0bNmShg0bUlRUxOLFi4mIiLhm261bt7J69Wo6deqEj48PW7du5fTp08b2EyZMYNiwYbi6utK5c2eKior466+/OH/+PCNHjiQuLo6PP/6Y7t27884771C7dm1SUlKYP38+r732GrVr1771jSlENSNBLUQ1MGzYMLKzs3nllVfIzMwkMjKSRYsWERoaes32I0aMQK/X8/DDD7N8+XJiY2NZvHgx77zzDpMmTcLGxobw8HBeeOGFm67B1taWMWPGcOzYMezt7enQoQPz5s27ZlsXFxfWr1/P1KlTycnJISgoiE8++YQuXboA8MILL+Dg4MDHH3/Mq6++iqOjI40bN2bEiBEAODg4sH79ekaPHk2vXr3Izc2lVq1adOzYUfawxV1Ho5RS5i5CCCGEENcmNzwRQgghLJgEtRBCCGHBJKiFEEIICyZBLYQQQlgwCWohhBDCgklQCyGEEBZMgvo6ZsyYQd26dbGzs6Nt27Zs27bN3CVZhPXr19OtWzcCAgLQaDQsXLjQZLpSinHjxuHv74+9vT0xMTEcOnTIpM25c+eIi4vDxcUFNzc3+vfvT15enkmbPXv20KFDB+zs7KhTpw4fffTRVbX8/PPPhIeHY2dnR+PGjVm6dGmlf947aeLEibRu3RpnZ2d8fHzo0aOHyfOowXCv68GDB+Pp6YmTkxOPPfYYGRkZJm1SU1Pp2rUrDg4O+Pj48Oqrr5o8zhJg7dq1tGjRAp1OR0hICHPmzLmqnpr4OzBz5kyaNGmCi4sLLi4uREVFsWzZMuN02b6V68MPP0Sj0RivjwfZxrfEzA8FsUjz5s1Ttra26uuvv1Z///23evHFF5Wbm5vKyMgwd2lmt3TpUvXmm2+q+fPnK0AtWLDAZPqHH36oXF1d1cKFC9Xu3bvVo48+qoKDg9WFCxeMbTp37qyaNm2qtmzZojZs2KBCQkLU008/bZyenZ2tfH19VVxcnNq3b5/64YcflL29vfr888+NbeLj45WVlZX66KOP1P79+9Vbb72lbGxs1N69e6t8G1SV2NhYNXv2bLVv3z6VkJCgHn74YRUYGKjy8vKMbV566SVVp04dtXr1avXXX3+pe+65R7Vr1844vaSkRDVq1EjFxMSoXbt2qaVLlyovLy81ZswYY5ujR48qBwcHNXLkSLV//341ffp0ZWVlpZYvX25sU1N/BxYtWqSWLFmiDh48qJKSktQbb7yhbGxs1L59+5RSsn0r07Zt21TdunVVkyZN1PDhw43jZRtXnAT1NbRp00YNHjzY+L60tFQFBASoiRMnmrEqy3NlUOv1euXn56c+/vhj47isrCyl0+nUDz/8oJRSav/+/QpQ27dvN7ZZtmyZ0mg06uTJk0oppf773/8qd3d343OHlVJ
q9OjRKiwszPi+d+/eqmvXrib1tG3bVv373/+u1M9oTpmZmQpQ69atU0oZtqWNjY36+eefjW0SExMVoDZv3qyUMnyR0mq1Kj093dhm5syZysXFxbg9X3vtNdWwYUOTdT355JMqNjbW+P5u+h1wd3dXX331lWzfSpSbm6tCQ0PVypUr1X333WcMatnGt0YOfV/h4sWL7Nixg5iYGOM4rVZLTEwMmzdvNmNlli85OZn09HSTbefq6krbtm2N227z5s24ubnRqlUrY5uYmBi0Wi1bt241trn33nuxtbU1tomNjSUpKYnz588b21y+nrI2Nen/KDs7GwAPDw8AduzYQXFxscnnDg8PJzAw0GT7Nm7cGF9fX2Ob2NhYcnJy+Pvvv41tbrTt7pbfgdLSUubNm0d+fj5RUVGyfSvR4MGD6dq161XbQbbxrZF7fV/hzJkzlJaWmvyQAPj6+nLgwAEzVVU9pKenA1xz25VNS09Px8fHx2S6tbU1Hh4eJm2Cg4OvWkbZNHd3d9LT02+4nupOr9czYsQIoqOjadSoEWD
							 
						 
					
						
							
								
									
										
										
										
      "text/plain": [
       "<Figure size 500x300 with 2 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "def plot_losses(epochs_seen, tokens_seen, train_losses, val_losses):\n",
    "    fig, ax1 = plt.subplots(figsize=(5, 3))\n",
    "\n",
    "    # Plot training and validation loss against epochs\n",
    "    ax1.plot(epochs_seen, train_losses, label=\"Training loss\")\n",
    "    ax1.plot(epochs_seen, val_losses, linestyle=\"-.\", label=\"Validation loss\")\n",
    "    ax1.set_xlabel(\"Epochs\")\n",
    "    ax1.set_ylabel(\"Loss\")\n",
    "    ax1.legend(loc=\"upper right\")\n",
    "\n",
    "    # Create a second x-axis for tokens seen\n",
    "    ax2 = ax1.twiny()  # Create a second x-axis that shares the same y-axis\n",
    "    ax2.plot(tokens_seen, train_losses, alpha=0)  # Invisible plot for aligning ticks\n",
    "    ax2.set_xlabel(\"Tokens seen\")\n",
    "\n",
    "    fig.tight_layout()  # Adjust layout to make room\n",
    "    plt.savefig(\"loss-plot.pdf\")\n",
    "    plt.show()\n",
    "\n",
    "epochs_tensor = torch.linspace(0, num_epochs, len(train_losses))\n",
    "plot_losses(epochs_tensor, tokens_seen, train_losses, val_losses)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8bc83ded-5f80-4e1c-bf4d-ccb59999d995",
   "metadata": {},
   "source": [
    "- Looking at the results above, we can see that the model starts out generating incomprehensible strings of words, whereas towards the end, it's able to produce grammatically more or less correct sentences\n",
    "- However, based on the training and validation set losses, we can see that the model starts overfitting\n",
    "- If we were to check a few passages it writes towards the end, we would find that they are contained in the training set verbatim -- it simply memorizes the training data (see the optional check in the code cell below)\n",
    "- Later, we will cover decoding strategies that can mitigate this memorization to a certain degree\n",
    "- Note that the overfitting here occurs because we have a very, very small training set, and we iterate over it so many times\n",
    "  - The LLM training here primarily serves educational purposes; we mainly want to see that the model can learn to produce coherent text\n",
    "  - Instead of spending weeks or months on training this model on vast amounts of expensive hardware, we load pretrained weights later"
   ]
  },
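  {
   "cell_type": "markdown",
   "id": "f3a9c1d2-5b6e-4c7d-8e9f-0a1b2c3d4e5f",
   "metadata": {},
   "source": [
    "- The next (optional) cell is a minimal sketch, not part of the original chapter code, of one way to check the memorization claim above: it generates a short continuation and tests whether that continuation occurs verbatim in the raw training text\n",
    "- It assumes the raw text is still available in the `text_data` variable loaded earlier in this chapter, and that `generate_text_simple`, `text_to_token_ids`, and `token_ids_to_text` are still defined"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c2b7e41-3f5a-4d86-b1c0-7d8e9f0a1b2c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sketch: check whether the generated continuation is memorized, i.e.,\n",
    "# whether it appears verbatim in the raw training text (assumes `text_data`,\n",
    "# `model`, `device`, and the helper functions from earlier are still defined)\n",
    "model.eval()  # disable dropout; generation below is greedy and deterministic\n",
    "\n",
    "token_ids = generate_text_simple(\n",
    "    model=model,\n",
    "    idx=text_to_token_ids(\"Every effort moves you\", tokenizer).to(device),\n",
    "    max_new_tokens=25,\n",
    "    context_size=GPT_CONFIG_124M[\"context_length\"]\n",
    ")\n",
    "generated = token_ids_to_text(token_ids, tokenizer)\n",
    "\n",
    "# Drop the prompt and test for a verbatim match in the raw training text\n",
    "continuation = generated[len(\"Every effort moves you\"):].strip()\n",
    "print(continuation)\n",
    "print(\"Appears verbatim in the training text:\", continuation in text_data)"
   ]
  },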
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
  {
   "cell_type": "markdown",
   "id": "eb380c42-b31c-4ee1-b8b9-244094537272",
   "metadata": {},
   "source": [
    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch05_compressed/mental-model-2.webp\" width=350px>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "de713235-1561-467f-bf63-bf11ade383f0",
   "metadata": {},
   "source": [
    "**If you are interested in augmenting this training function with more advanced techniques, such as learning rate warmup, cosine annealing, and gradient clipping, please refer to [Appendix D](../../appendix-D/01_main-chapter-code)**"
   ]
  },
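  {
   "cell_type": "markdown",
   "id": "1d4f6a2b-8c3e-4b5d-9a7f-2e3c4d5a6b7c",
   "metadata": {},
   "source": [
    "- As a rough, optional illustration of the techniques mentioned above (a simplified sketch, not the Appendix D implementation), the next cell shows how one might combine a linear learning rate warmup, cosine annealing via `torch.optim.lr_scheduler`, and gradient clipping via `torch.nn.utils.clip_grad_norm_` in a simple training loop\n",
    "- The function is only defined here, not executed, so it does not change the model weights trained above"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7b8c9d0e-1f2a-4b3c-8d4e-5f6a7b8c9d0e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sketch (not executed): a training loop with learning rate warmup,\n",
    "# cosine annealing, and gradient clipping added to the plain loop used above\n",
    "def train_with_schedule(model, train_loader, optimizer, device,\n",
    "                        num_epochs, warmup_steps):\n",
    "    total_steps = len(train_loader) * num_epochs\n",
    "    # Linear warmup followed by cosine decay, stepped once per batch\n",
    "    scheduler = torch.optim.lr_scheduler.SequentialLR(\n",
    "        optimizer,\n",
    "        schedulers=[\n",
    "            torch.optim.lr_scheduler.LinearLR(\n",
    "                optimizer, start_factor=0.1, total_iters=warmup_steps),\n",
    "            torch.optim.lr_scheduler.CosineAnnealingLR(\n",
    "                optimizer, T_max=total_steps - warmup_steps),\n",
    "        ],\n",
    "        milestones=[warmup_steps],\n",
    "    )\n",
    "    model.train()\n",
    "    for epoch in range(num_epochs):\n",
    "        for input_batch, target_batch in train_loader:\n",
    "            input_batch = input_batch.to(device)\n",
    "            target_batch = target_batch.to(device)\n",
    "            optimizer.zero_grad()\n",
    "            logits = model(input_batch)\n",
    "            loss = torch.nn.functional.cross_entropy(\n",
    "                logits.flatten(0, 1), target_batch.flatten())\n",
    "            loss.backward()\n",
    "            # Clip gradients to a maximum norm of 1.0 before the optimizer step\n",
    "            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)\n",
    "            optimizer.step()\n",
    "            scheduler.step()  # update the learning rate after every batch"
   ]
  },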
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
  {
   "cell_type": "markdown",
   "id": "6d5cdf2f-09a5-4eb0-a20a-d7aac5c14c2c",
   "metadata": {},
   "source": [
    "**If you are interested in a larger training dataset and longer training run, see [../03_bonus_pretraining_on_gutenberg](../03_bonus_pretraining_on_gutenberg)**"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "699f45fc-bf78-42f2-bd24-2355db41b28f",
   "metadata": {
    "id": "699f45fc-bf78-42f2-bd24-2355db41b28f"
   },
   "source": [
    "## 5.3 Decoding strategies to control randomness"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6be9086e-2c27-41da-97d0-49137d0ba3c7",
   "metadata": {},
   "source": [
    "- Inference is relatively cheap with a relatively small LLM such as the GPT model we trained above, so there's no need to use a GPU for inference even if you used a GPU for training above\n",
    "- Using the `generate_text_simple` function (from the previous chapter) that we used earlier inside the simple training function, we can generate new text one word (or token) at a time\n",
    "- As explained in section 5.1.2, the next generated token is the token corresponding to the largest probability score among all tokens in the vocabulary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "2734cee0-f6f9-42d5-b71c-fa7e0ef28b6d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Output text:\n",
      " Every effort moves you know,\" was one of the axioms he laid down across the Sevres and silver of an exquisitely appointed lun\n"
     ]
    }
   ],
   "source": [
    "model.to(\"cpu\")\n",
    "model.eval()\n",
    "\n",
    "tokenizer = tiktoken.get_encoding(\"gpt2\")\n",
    "\n",
    "token_ids = generate_text_simple(\n",
    "    model=model,\n",
    "    idx=text_to_token_ids(\"Every effort moves you\", tokenizer),\n",
    "    max_new_tokens=25,\n",
    "    context_size=GPT_CONFIG_124M[\"context_length\"]\n",
    ")\n",
    "\n",
    "print(\"Output text:\\n\", token_ids_to_text(token_ids, tokenizer))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d25dbe31-bb7c-4893-b25b-47d0492d4aa4",
   "metadata": {},
   "source": [
    "- Even if we execute the `generate_text_simple` function above multiple times, the LLM will always generate the same outputs\n",
    "- We now introduce two concepts, so-called decoding strategies, to modify `generate_text_simple`: *temperature scaling* and *top-k* sampling\n",
    "- These will allow us to control the randomness and diversity of the generated text (the optional cell below gives a brief standalone preview before we work through the details)"
   ]
  },
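  {
   "cell_type": "markdown",
   "id": "5e6f7a8b-9c0d-4e1f-a2b3-c4d5e6f7a8b9",
   "metadata": {},
   "source": [
    "- The next (optional) cell is a small standalone preview, not part of the original chapter code, using the same toy logits as the recap cell further below rather than the GPT model\n",
    "- It only illustrates the two ideas: dividing the logits by a temperature before the softmax flattens (temperature > 1) or sharpens (temperature < 1) the distribution, and top-k sampling masks all but the k largest logits before sampling"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a1b2c3d-4e5f-4a6b-8c7d-9e0f1a2b3c4d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standalone preview of temperature scaling and top-k masking on toy logits\n",
    "# (the following sections apply these ideas to the GPT model step by step)\n",
    "preview_logits = torch.tensor(\n",
    "    [4.51, 0.89, -1.90, 6.75, 1.63, -1.62, -1.89, 6.28, 1.79]\n",
    ")\n",
    "\n",
    "# Temperature scaling: divide the logits by a temperature before the softmax\n",
    "for temperature in [0.5, 1.0, 2.0]:\n",
    "    probs = torch.softmax(preview_logits / temperature, dim=0)\n",
    "    print(f\"Temperature {temperature}: highest probability {probs.max().item():.3f}\")\n",
    "\n",
    "# Top-k: keep the k largest logits and set all others to -inf so that\n",
    "# they receive zero probability after the softmax\n",
    "top_k = 3\n",
    "top_logits, top_positions = torch.topk(preview_logits, top_k)\n",
    "masked_logits = torch.where(\n",
    "    preview_logits < top_logits[-1],\n",
    "    torch.tensor(float(\"-inf\")),\n",
    "    preview_logits\n",
    ")\n",
    "print(\"Probabilities after top-3 masking:\", torch.softmax(masked_logits, dim=0))"
   ]
  },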
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
  {
   "cell_type": "markdown",
   "id": "4bb6f380-a798-4fd9-825c-17b7cd29a994",
   "metadata": {},
   "source": [
    "### 5.3.1 Temperature scaling"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a7f4f53c-0612-43d3-aa82-52447eac50fa",
   "metadata": {},
   "source": [
    "- Previously, we always sampled the token with the highest probability as the next token using `torch.argmax`\n",
    "- To add variety, we can sample the next token using `torch.multinomial(probs, num_samples=1)`, sampling from a probability distribution\n",
    "- Here, each index's chance of being picked corresponds to its probability in the input tensor"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e7531bae-d5de-44c0-bc78-78fed077e22a",
   "metadata": {},
   "source": [
    "- Here's a little recap of generating the next token, assuming a very small vocabulary for illustration purposes:"
   ]
  },
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-05-31 07:30:57 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 31,
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 21:07:19 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "id": "01a5ce39-3dc8-4c35-96bc-6410a1e42412",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "forward\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "vocab = { \n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    \"closer\": 0,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    \"every\": 1, \n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    \"effort\": 2, \n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    \"forward\": 3,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    \"inches\": 4,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    \"moves\": 5, \n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    \"pizza\": 6,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    \"toward\": 7,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    \"you\": 8,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "} \n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "inverse_vocab = {v: k for k, v in vocab.items()}\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-22 09:15:40 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "# Suppose input is \"every effort moves you\", and the LLM\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 21:07:19 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "# returns the following logits for the next token:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "next_token_logits = torch.tensor(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    [4.51, 0.89, -1.90, 6.75, 1.63, -1.62, -1.89, 6.28, 1.79]\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    ")\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "probas = torch.softmax(next_token_logits, dim=0)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "next_token_id = torch.argmax(probas).item()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-22 09:15:40 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "# The next generated token is then as follows:\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 21:07:19 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "print(inverse_vocab[next_token_id])"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2024-03-28 08:02:05 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-05-31 07:30:57 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 32,
							 
						 
					
						
							
								
									
										
										
										
											2024-03-28 08:02:05 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "6400572f-b3c8-49e2-95bc-433e55c5b3a1",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "forward\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "torch.manual_seed(123)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "next_token_id = torch.multinomial(probas, num_samples=1).item()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(inverse_vocab[next_token_id])"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-05-31 07:30:57 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 33,
							 
						 
					
						
							
								
									
										
										
										
   "id": "b23b863e-252a-403c-b5b1-62bc0a42319f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "73 x closer\n",
      "0 x every\n",
      "0 x effort\n",
      "582 x forward\n",
      "2 x inches\n",
      "0 x moves\n",
      "0 x pizza\n",
      "343 x toward\n"
     ]
    }
   ],
   "source": [
    "def print_sampled_tokens(probas):\n",
    "    torch.manual_seed(123) # Manual seed for reproducibility\n",
    "    sample = [torch.multinomial(probas, num_samples=1).item() for i in range(1_000)]\n",
    "    sampled_ids = torch.bincount(torch.tensor(sample))\n",
    "    for i, freq in enumerate(sampled_ids):\n",
    "        print(f\"{freq} x {inverse_vocab[i]}\")\n",
    "\n",
    "print_sampled_tokens(probas)"
   ]
  },
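  {
   "cell_type": "markdown",
   "id": "ed-note-expected-counts",
   "metadata": {},
   "source": [
    "- As a quick cross-check (an editor's sketch, not part of the original chapter code), the counts printed above should roughly match the expected counts implied by the probabilities, i.e., `probas[i] * 1000` for each token id; this assumes the `vocab` and `probas` objects defined earlier in this chapter:\n",
    "\n",
    "```python\n",
    "# Compare the sampled counts with the expected counts implied by probas\n",
    "for token, token_id in vocab.items():\n",
    "    print(f'{token}: expected ~{1000 * probas[token_id].item():.0f} per 1,000 draws')\n",
    "```"
   ]
  },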
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
  {
   "cell_type": "markdown",
   "id": "c63d0a27-830b-42b5-9986-6d1a7de04dd9",
   "metadata": {},
   "source": [
    "- Instead of determining the most likely token via `torch.argmax`, we use `torch.multinomial(probas, num_samples=1)` to determine the next token by sampling from the softmax distribution\n",
    "- For illustration purposes, the preceding code cell samples the next token 1,000 times using the original softmax probabilities"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "32e7d9cf-a26d-4d9a-8664-4af1efa73832",
   "metadata": {},
   "source": [
    "- We can control the distribution and selection process via a concept called temperature scaling\n",
    "- \"Temperature scaling\" is just a fancy term for dividing the logits by a number greater than 0 (see the formula in the next cell)\n",
    "- Temperatures greater than 1 will result in more uniformly distributed token probabilities after applying the softmax\n",
    "- Temperatures smaller than 1 will result in more confident (sharper or more peaky) distributions after applying the softmax"
   ]
  },
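  {
   "cell_type": "markdown",
   "id": "ed-note-temperature-formula",
   "metadata": {},
   "source": [
    "- For reference (editor's addition), dividing the logits $z_1, \\dots, z_V$ by a temperature $T > 0$ before applying the softmax corresponds to\n",
    "\n",
    "$$p_i = \\frac{\\exp(z_i / T)}{\\sum_{j=1}^{V} \\exp(z_j / T)}$$\n",
    "\n",
    "- $T = 1$ recovers the ordinary softmax, $T \\to 0$ approaches the `torch.argmax` behavior, and large $T$ flattens the distribution; this is exactly what the `softmax_with_temperature` function below implements"
   ]
  },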
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "0759e4c8-5362-467c-bec6-b0a19d1ba43d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def softmax_with_temperature(logits, temperature):\n",
    "    scaled_logits = logits / temperature\n",
    "    return torch.softmax(scaled_logits, dim=0)\n",
    "\n",
    "# Temperature values\n",
    "temperatures = [1, 0.1, 5]  # Original, higher confidence, and lower confidence\n",
    "\n",
    "# Calculate scaled probabilities\n",
    "scaled_probas = [softmax_with_temperature(next_token_logits, T) for T in temperatures]"
   ]
  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "2e66e613-4aca-4296-a984-ddd0d80c6578",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<Figure size 500x300 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Plotting\n",
    "x = torch.arange(len(vocab))\n",
    "bar_width = 0.15\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(5, 3))\n",
    "for i, T in enumerate(temperatures):\n",
    "    rects = ax.bar(x + i * bar_width, scaled_probas[i], bar_width, label=f'Temperature = {T}')\n",
    "\n",
    "ax.set_ylabel('Probability')\n",
    "ax.set_xticks(x)\n",
    "ax.set_xticklabels(vocab.keys(), rotation=90)\n",
    "ax.legend()\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"temperature-plot.pdf\")\n",
    "plt.show()"
   ]
  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
  {
   "cell_type": "markdown",
   "id": "d750e989-842a-4cfa-a44b-cf44d6e49163",
   "metadata": {},
   "source": [
    "- We can see that the rescaling via temperature 0.1 results in a sharper distribution, approaching `torch.argmax`, such that the most likely word is almost always selected:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "e4600713-c51e-4f53-bf58-040a6eb362b8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 x closer\n",
      "0 x every\n",
      "0 x effort\n",
      "985 x forward\n",
      "0 x inches\n",
      "0 x moves\n",
      "0 x pizza\n",
      "15 x toward\n"
     ]
    }
   ],
   "source": [
    "print_sampled_tokens(scaled_probas[1])"
   ]
  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
  {
   "cell_type": "markdown",
   "id": "526e93cb-8e2a-42a1-b1ba-4fd5fe64c26b",
   "metadata": {},
   "source": [
    "- The rescaled probabilities via temperature 5 are more uniformly distributed:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "9dfb48f0-bc3f-46a5-9844-33b6c9b0f4df",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "165 x closer\n",
      "75 x every\n",
      "42 x effort\n",
      "239 x forward\n",
      "71 x inches\n",
      "46 x moves\n",
      "32 x pizza\n",
      "227 x toward\n",
      "103 x you\n"
     ]
    }
   ],
   "source": [
    "print_sampled_tokens(scaled_probas[2])"
   ]
  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
  {
   "cell_type": "markdown",
   "id": "0c83f0c4-3774-4375-ad7f-96440ba5fef7",
   "metadata": {},
   "source": [
    "- Assuming an LLM input \"every effort moves you\", the sampling approach above can occasionally produce nonsensical text such as \"every effort moves you pizza\", which happens about 3.2% of the time here (32 out of 1,000 samples)"
   ]
  },
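  {
   "cell_type": "markdown",
   "id": "ed-note-pizza-probability",
   "metadata": {},
   "source": [
    "- That empirical rate can be compared against the probability the temperature-5 distribution assigns to \"pizza\" (an editor's sketch that reuses the `vocab` and `scaled_probas` objects from the cells above):\n",
    "\n",
    "```python\n",
    "# Probability mass on 'pizza' at temperature 5; compare with the empirical 32/1,000\n",
    "print(scaled_probas[2][vocab['pizza']])\n",
    "```"
   ]
  },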
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
  {
   "cell_type": "markdown",
   "id": "c6e4873e-07e4-4abb-85df-bdaedcc1a6f7",
   "metadata": {},
   "source": [
    "### 5.3.2 Top-k sampling"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6d4da95a-8bb2-4f69-a9b0-a643531db5df",
   "metadata": {},
   "source": [
    "- To be able to use higher temperatures to increase output diversity and to reduce the probability of nonsensical sentences, we can restrict the sampled tokens to the top-k most likely tokens:"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7ae6fffd-2730-4abe-a2d3-781fc4836f17",
   "metadata": {},
   "source": [
    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch05_compressed/topk.webp\" width=500px>\n",
    "\n",
    "- (Please note that the numbers in this figure are truncated to two\n",
    "digits after the decimal point to reduce visual clutter. The values in the Softmax row should add up to 1.0.)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0ba12da5-6ff1-4008-91b8-d2d537cbc14c",
   "metadata": {},
   "source": [
    "- In code, we can implement this as follows:"
   ]
  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "2a7f908a-e9ec-446a-b407-fb6dbf05c806",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Top logits: tensor([6.7500, 6.2800, 4.5100])\n",
      "Top positions: tensor([3, 7, 0])\n"
     ]
    }
   ],
   "source": [
    "top_k = 3\n",
    "top_logits, top_pos = torch.topk(next_token_logits, top_k)\n",
    "\n",
    "print(\"Top logits:\", top_logits)\n",
    "print(\"Top positions:\", top_pos)"
   ]
  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "753865ed-79c5-48b1-b9f2-ccb132ff1d2f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([4.5100,   -inf,   -inf, 6.7500,   -inf,   -inf,   -inf, 6.2800,   -inf])\n"
     ]
    }
   ],
   "source": [
    "new_logits = torch.where(\n",
    "    condition=next_token_logits < top_logits[-1],\n",
    "    input=torch.tensor(float('-inf')),\n",
    "    other=next_token_logits\n",
    ")\n",
    "\n",
    "print(new_logits)"
   ]
  },
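  {
   "cell_type": "markdown",
   "id": "ed-note-topk-masking-alternative",
   "metadata": {},
   "source": [
    "- An equivalent way to build the masked logits (an editor's sketch, not from the book) is to start from a tensor filled with `-inf` and copy back only the entries selected by `top_pos`:\n",
    "\n",
    "```python\n",
    "# Same masking via index assignment instead of torch.where\n",
    "new_logits_alt = torch.full_like(next_token_logits, float('-inf'))\n",
    "new_logits_alt[top_pos] = next_token_logits[top_pos]\n",
    "print(new_logits_alt)  # matches new_logits above\n",
    "```"
   ]
  },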
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "4844f000-c329-4e7e-aa89-16a2c4ebee43",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([0.0615, 0.0000, 0.0000, 0.5775, 0.0000, 0.0000, 0.0000, 0.3610, 0.0000])\n"
     ]
    }
   ],
   "source": [
    "topk_probas = torch.softmax(new_logits, dim=0)\n",
    "print(topk_probas)"
   ]
  },
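  {
   "cell_type": "markdown",
   "id": "ed-note-zero-probability",
   "metadata": {},
   "source": [
    "- Note (editor's addition) that the masked positions receive exactly zero probability, since $\\exp(-\\infty) = 0$, so the softmax renormalizes only over the top-k logits"
   ]
  },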
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
  {
   "cell_type": "markdown",
   "id": "56056503-a15d-4315-a3ff-46647a4c7c45",
   "metadata": {},
   "source": [
    "### 5.3.3 Modifying the text generation function"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "34770423-473d-46f6-a5fa-6b2979564d26",
   "metadata": {},
   "source": [
    "- The previous two subsections introduced temperature sampling and top-k sampling\n",
    "- Let's use these two concepts to modify the `generate_simple` function we used to generate text via the LLM earlier, creating a new `generate` function:"
   ]
  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-05-31 07:30:57 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 41,
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 21:07:19 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "id": "8e318891-bcc0-4d71-b147-33ce55febfa3",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-05-19 09:04:49 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=None, eos_id=None):\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 21:07:19 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    # For-loop is the same as before: Get logits, and only focus on last time step\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    for _ in range(max_new_tokens):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        idx_cond = idx[:, -context_size:]\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        with torch.no_grad():\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            logits = model(idx_cond)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        logits = logits[:, -1, :]\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # New: Filter logits with top_k sampling\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        if top_k is not None:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            # Keep only top_k values\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            top_logits, _ = torch.topk(logits, top_k)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            min_val = top_logits[:, -1]\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            logits = torch.where(logits < min_val, torch.tensor(float('-inf')).to(logits.device), logits)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # New: Apply temperature scaling\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        if temperature > 0.0:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            logits = logits / temperature\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            # Apply softmax to get probabilities\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            probs = torch.softmax(logits, dim=-1)  # (batch_size, context_len)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            # Sample from the distribution\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            idx_next = torch.multinomial(probs, num_samples=1)  # (batch_size, 1)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # Otherwise same as before: get idx of the vocab entry with the highest logits value\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        else:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (batch_size, 1)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-05-18 12:35:40 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        if idx_next == eos_id:  # Stop generating early if end-of-sequence token is encountered and eos_id is specified\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            break\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 21:07:19 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "        # Same as before: append sampled index to the running sequence\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        idx = torch.cat((idx, idx_next), dim=1)  # (batch_size, num_tokens+1)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    return idx"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
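  {
   "cell_type": "markdown",
   "id": "topk-temperature-demo-note",
   "metadata": {},
   "source": [
    "- As an optional illustration of the two new steps above, the following cell applies the same top-k masking and temperature scaling to a small made-up logits row (the numbers and the choice of k are arbitrary and only meant to show the mechanics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "topk-temperature-demo-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional illustration with a made-up logits row (values are arbitrary)\n",
    "example_logits = torch.tensor([[1.0, 3.0, 0.5, 2.0, -1.0]])\n",
    "\n",
    "# Top-k masking: keep the k largest logits and set the rest to -inf\n",
    "k = 2\n",
    "top_logits, _ = torch.topk(example_logits, k)\n",
    "masked = torch.where(example_logits < top_logits[:, -1], torch.tensor(float(\"-inf\")), example_logits)\n",
    "print(\"Masked logits:\", masked)\n",
    "\n",
    "# Temperature scaling: divide by the temperature before applying softmax;\n",
    "# temperatures > 1 flatten the distribution, temperatures < 1 sharpen it\n",
    "for temp in (0.5, 1.0, 2.0):\n",
    "    print(f\"Probabilities at temperature {temp}:\", torch.softmax(masked / temp, dim=-1))"
   ]
  },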
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "aa2a0d7d-0457-42d1-ab9d-bd67683e7ed8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Output text:\n",
      " Every effort moves you stand to work on surprise, a one of us had gone with random-\n"
     ]
    }
   ],
   "source": [
    "torch.manual_seed(123)\n",
    "\n",
    "token_ids = generate(\n",
    "    model=model,\n",
    "    idx=text_to_token_ids(\"Every effort moves you\", tokenizer),\n",
    "    max_new_tokens=15,\n",
    "    context_size=GPT_CONFIG_124M[\"context_length\"],\n",
    "    top_k=25,\n",
    "    temperature=1.4\n",
    ")\n",
    "\n",
    "print(\"Output text:\\n\", token_ids_to_text(token_ids, tokenizer))"
   ]
  },
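  {
   "cell_type": "markdown",
   "id": "greedy-fallback-note",
   "metadata": {},
   "source": [
    "- Note that in the `generate` function above, `temperature=0.0` skips both the scaling and the sampling branch, so the function falls back to greedy (argmax) decoding; the optional cell below is a small sanity check of that behavior using the same prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "greedy-fallback-check",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sanity check: with top_k=None and temperature=0.0, generate decodes greedily,\n",
    "# so repeated calls produce the same deterministic output\n",
    "token_ids = generate(\n",
    "    model=model,\n",
    "    idx=text_to_token_ids(\"Every effort moves you\", tokenizer),\n",
    "    max_new_tokens=15,\n",
    "    context_size=GPT_CONFIG_124M[\"context_length\"],\n",
    "    top_k=None,\n",
    "    temperature=0.0\n",
    ")\n",
    "\n",
    "print(\"Output text:\\n\", token_ids_to_text(token_ids, tokenizer))"
   ]
  },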
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
  {
   "cell_type": "markdown",
   "id": "4e2002ca-f4c1-48af-9e0a-88bfc163ba0b",
   "metadata": {},
   "source": [
    "## 5.4 Loading and saving model weights in PyTorch"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0fc52676-f026-4566-a226-2a90269f9d53",
   "metadata": {},
   "source": [
    "- Training LLMs is computationally expensive, so it's crucial to be able to save and load LLM weights\n",
    "\n",
    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch05_compressed/mental-model-3.webp\" width=400px>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "10e4c7f9-592f-43d6-a00e-598fa01dfb82",
   "metadata": {},
   "source": [
    "- The recommended way in PyTorch is to save the model weights, the so-called `state_dict`, by applying the `torch.save` function to the `.state_dict()` method:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "3d67d869-ac04-4382-bcfb-c96d1ca80d47",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(model.state_dict(), \"model.pth\")"
   ]
  },
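  {
   "cell_type": "markdown",
   "id": "state-dict-peek-note",
   "metadata": {},
   "source": [
    "- A `state_dict` is simply a Python dictionary that maps each parameter (and buffer) name to its tensor; the optional cell below peeks at a few of its entries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "state-dict-peek-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional: a state_dict maps parameter names to tensors\n",
    "print(\"Number of entries:\", len(model.state_dict()))\n",
    "print(\"First few keys:\", list(model.state_dict().keys())[:5])"
   ]
  },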
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
  {
   "cell_type": "markdown",
   "id": "90e889e0-07bf-43e5-8f92-5c5c7aeaad9e",
   "metadata": {},
   "source": [
    "- Then we can load the model weights into a new `GPTModel` model instance as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "9d57d914-60a3-47f1-b499-5352f4c457cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = GPTModel(GPT_CONFIG_124M)\n",
    "model.load_state_dict(torch.load(\"model.pth\"))\n",
    "model.eval();"
   ]
  },
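  {
   "cell_type": "markdown",
   "id": "torch-load-options-note",
   "metadata": {},
   "source": [
    "- Depending on your PyTorch version and hardware, you may want to pass additional arguments to `torch.load`: `map_location` places the loaded tensors on a chosen device (useful when the weights were saved on a GPU), and `weights_only=True` (available in newer PyTorch versions) restricts unpickling to plain tensor data; the optional cell below is a variant of the loading code that assumes the `device` variable defined earlier in this chapter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "torch-load-options-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional variant of the loading code above\n",
    "# (assumes the device variable defined earlier in this chapter)\n",
    "model = GPTModel(GPT_CONFIG_124M)\n",
    "model.load_state_dict(torch.load(\"model.pth\", map_location=device, weights_only=True))\n",
    "model.eval();"
   ]
  },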
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
  {
   "cell_type": "markdown",
   "id": "caa81aec-9c72-4f46-8ae2-4a4fde3edbc1",
   "metadata": {},
   "source": [
    "- It's common to train LLMs with adaptive optimizers like Adam or AdamW instead of regular SGD\n",
    "- These adaptive optimizers store additional parameters for each model weight, so it makes sense to save them as well in case we plan to continue the pretraining later:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "bbd175bb-edf4-450e-a6de-d3e8913c6532",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save({\n",
    "    \"model_state_dict\": model.state_dict(),\n",
    "    \"optimizer_state_dict\": optimizer.state_dict(),\n",
    "    }, \n",
    "    \"model_and_optimizer.pth\"\n",
    ")"
   ]
  },
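  {
   "cell_type": "markdown",
   "id": "checkpoint-extras-note",
   "metadata": {},
   "source": [
    "- Since the checkpoint is just a dictionary, it can also hold additional bookkeeping values, for example the number of training epochs completed so far; the optional cell below sketches this (the epoch value is only a placeholder, and the extra entry is simply ignored by the loading code that follows)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "checkpoint-extras-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sketch: checkpoints can hold arbitrary extra entries alongside the state_dicts\n",
    "torch.save({\n",
    "    \"model_state_dict\": model.state_dict(),\n",
    "    \"optimizer_state_dict\": optimizer.state_dict(),\n",
    "    \"epochs_completed\": 10,  # placeholder value for illustration\n",
    "    },\n",
    "    \"model_and_optimizer.pth\"\n",
    ")"
   ]
  },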
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "8a0c7295-c822-43bf-9286-c45abc542868",
   "metadata": {},
   "outputs": [],
   "source": [
    "checkpoint = torch.load(\"model_and_optimizer.pth\")\n",
    "\n",
    "model = GPTModel(GPT_CONFIG_124M)\n",
    "model.load_state_dict(checkpoint[\"model_state_dict\"])\n",
    "\n",
    "optimizer = torch.optim.AdamW(model.parameters(), lr=0.0005, weight_decay=0.1)\n",
    "optimizer.load_state_dict(checkpoint[\"optimizer_state_dict\"])\n",
    "model.train();"
   ]
  },
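  {
   "cell_type": "markdown",
   "id": "resume-training-note",
   "metadata": {},
   "source": [
    "- With the optimizer state restored, training can simply continue where it left off; the optional cell below sketches a single parameter update on a random dummy batch (purely illustrative, not meaningful training; in practice you would iterate over the real training data loader as before)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "resume-training-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional, purely illustrative: one training step on a random dummy batch\n",
    "dummy_inputs = torch.randint(0, GPT_CONFIG_124M[\"vocab_size\"], (2, 4))\n",
    "dummy_targets = torch.randint(0, GPT_CONFIG_124M[\"vocab_size\"], (2, 4))\n",
    "\n",
    "optimizer.zero_grad()\n",
    "logits = model(dummy_inputs)\n",
    "loss = torch.nn.functional.cross_entropy(logits.flatten(0, 1), dummy_targets.flatten())\n",
    "loss.backward()\n",
    "optimizer.step()\n",
    "print(\"Dummy batch loss:\", loss.item())"
   ]
  },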
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
  {
   "cell_type": "markdown",
   "id": "4194350e-0409-4a63-8ffd-d3a896509032",
   "metadata": {},
   "source": [
    "## 5.5 Loading pretrained weights from OpenAI"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "83eb6c38-7278-40e0-bd9f-8a2b1feac3ec",
   "metadata": {},
   "source": [
    "- Previously, we only trained a small GPT-2 model using a very small short-story book for educational purposes\n",
    "- Interested readers can also find a longer pretraining run on the complete Project Gutenberg book corpus in [../03_bonus_pretraining_on_gutenberg](../03_bonus_pretraining_on_gutenberg)\n",
    "- Fortunately, we don't have to spend tens to hundreds of thousands of dollars to pretrain the model on a large pretraining corpus but can load the pretrained weights provided by OpenAI"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "127ddbdb-3878-4669-9a39-d231fbdfb834",
   "metadata": {},
   "source": [
    "- For an alternative way to load the weights from the Hugging Face Hub, see [../02_alternative_weight_loading](../02_alternative_weight_loading)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "75cab892-a165-4f43-9601-f517bc212ab6",
   "metadata": {},
   "source": [
    "- First, some boilerplate code to download the files from OpenAI and load the weights into Python\n",
    "- Since OpenAI used [TensorFlow](https://www.tensorflow.org/), we will have to install and use TensorFlow for loading the weights; [tqdm](https://github.com/tqdm/tqdm) is a progress bar library\n",
    "- Uncomment and run the next cell to install the required libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "fb9fdf02-972a-444e-bf65-8ffcaaf30ce8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# pip install tensorflow tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "id": "a0747edc-559c-44ef-a93f-079d60227e3f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TensorFlow version: 2.15.0\n",
      "tqdm version: 4.66.2\n"
     ]
    }
   ],
   "source": [
    "print(\"TensorFlow version:\", version(\"tensorflow\"))\n",
    "print(\"tqdm version:\", version(\"tqdm\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "id": "c5bc89eb-4d39-4287-9b0c-e459ebe7f5ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Relative import from the gpt_download.py contained in this folder\n",
    "from gpt_download import download_and_load_gpt2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ff76a736-6f9f-4328-872e-f89a7b70a2cc",
   "metadata": {},
   "source": [
    "- We can then download the model weights for the 124 million parameter model as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "id": "76271dd7-108d-4f5b-9c01-6ae0aac4b395",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
							 
						 
					
						
							
								
									
										
										
										
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "checkpoint: 100%|███████████████████████████| 77.0/77.0 [00:00<00:00, 58.8kiB/s]\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "encoder.json: 100%|███████████████████████| 1.04M/1.04M [00:00<00:00, 2.70MiB/s]\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "hparams.json: 100%|█████████████████████████| 90.0/90.0 [00:00<00:00, 27.8kiB/s]\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "model.ckpt.data-00000-of-00001: 100%|███████| 498M/498M [00:30<00:00, 16.1MiB/s]\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "model.ckpt.index: 100%|███████████████████| 5.21k/5.21k [00:00<00:00, 1.18MiB/s]\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "model.ckpt.meta: 100%|██████████████████████| 471k/471k [00:00<00:00, 2.22MiB/s]\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "vocab.bpe: 100%|████████████████████████████| 456k/456k [00:00<00:00, 2.04MiB/s]\n"
							 
						 
					
						
							
								
									
										
										
										
										
									 
								 
							 
							
								
							 
							
								 
							
							
     ]
    }
   ],
   "source": [
    "settings, params = download_and_load_gpt2(model_size=\"124M\", models_dir=\"gpt2\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "id": "b1a31951-d971-4a6e-9c43-11ee1168ec6a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Settings: {'n_vocab': 50257, 'n_ctx': 1024, 'n_embd': 768, 'n_head': 12, 'n_layer': 12}\n"
     ]
    }
   ],
   "source": [
    "print(\"Settings:\", settings)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "id": "857c8331-130e-46ba-921d-fa35d7a73cfe",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Parameter dictionary keys: dict_keys(['blocks', 'b', 'g', 'wpe', 'wte'])\n"
     ]
    }
   ],
   "source": [
    "print(\"Parameter dictionary keys:\", params.keys())"
   ]
  },
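  {
   "cell_type": "markdown",
   "id": "params-structure-note",
   "metadata": {},
   "source": [
    "- The token embedding (`wte`) and positional embedding (`wpe`) weights sit at the top level of `params`, whereas the per-layer transformer weights are nested under `params[\"blocks\"]`; the optional cell below inspects this nesting (assuming the dictionary layout produced by the `download_and_load_gpt2` helper used above)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "params-structure-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional: inspect how the loader nests the per-layer weights\n",
    "print(\"Number of transformer blocks:\", len(params[\"blocks\"]))\n",
    "print(\"Keys of the first block:\", params[\"blocks\"][0].keys())\n",
    "print(\"Positional embedding shape:\", params[\"wpe\"].shape)"
   ]
  },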
							 
						 
					
						
							
								
									
										
										
										
  {
   "cell_type": "code",
   "execution_count": 53,
   "id": "c48dac94-8562-4a66-84ef-46c613cdc4cd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "[[-0.11010301 -0.03926672  0.03310751 ... -0.1363697   0.01506208\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "   0.04531523]\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      " [ 0.04034033 -0.04861503  0.04624869 ...  0.08605453  0.00253983\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "   0.04318958]\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      " [-0.12746179  0.04793796  0.18410145 ...  0.08991534 -0.12972379\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "  -0.08785918]\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      " ...\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      " [-0.04453601 -0.05483596  0.01225674 ...  0.10435229  0.09783269\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "  -0.06952604]\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      " [ 0.1860082   0.01665728  0.04611587 ... -0.09625227  0.07847701\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "  -0.02245961]\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      " [ 0.05135201 -0.02768905  0.0499369  ...  0.00704835  0.15519823\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "   0.12067825]]\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "Token embedding weight tensor dimensions: (50257, 768)\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(params[\"wte\"])\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(\"Token embedding weight tensor dimensions:\", params[\"wte\"].shape)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 21:07:19 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "466e100c-294e-4afc-a70a-2f398ac4c104",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
    "- Alternatively, \"355M\", \"774M\", and \"1558M\" are also supported `model_size` arguments (see the short sketch after the figure below)\n",
    "- The differences between these differently sized models are summarized in the figure below:"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "20f19d32-5aae-4176-9f86-f391672c8f0d",
   "metadata": {},
   "source": [
    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch05_compressed/gpt-sizes.webp?timestamp=123\" width=500px>"
   ]
  },
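  {
   "cell_type": "markdown",
   "id": "e1a0c9a2-1f5b-4f1e-9a3d-2b6c8d4e0f11",
   "metadata": {},
   "source": [
    "- For example, loading one of the larger models only requires changing the `model_size` string; a minimal sketch, assuming the `download_and_load_gpt2` helper from `gpt_download.py` that was used earlier in this chapter (note that the larger checkpoints are noticeably bigger downloads):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1a0c9a2-1f5b-4f1e-9a3d-2b6c8d4e0f12",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch only; assumes the download_and_load_gpt2 helper imported earlier in this chapter\n",
    "# settings, params = download_and_load_gpt2(model_size=\"355M\", models_dir=\"gpt2\")"
   ]
  },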
  {
   "cell_type": "markdown",
   "id": "ea6e5076-f08d-41fc-bd8b-1cfe53538f41",
   "metadata": {},
   "source": [
    "- Above, we loaded the 124M GPT-2 model weights into Python; however, we still need to transfer them into our `GPTModel` instance\n",
    "- First, we initialize a new `GPTModel` instance\n",
    "- Note that the original GPT model initialized the linear layers for the query, key, and value matrices in the multi-head attention module with bias vectors, which is not required or recommended; however, to be able to load the weights correctly, we have to enable these bias vectors as well by setting `qkv_bias` to `True` in our implementation\n",
    "- We are also using the `1024` token context length that was used by the original GPT-2 model(s)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "id": "9fef90dd-0654-4667-844f-08e28339ef7d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define model configurations in a dictionary for compactness\n",
    "model_configs = {\n",
    "    \"gpt2-small (124M)\": {\"emb_dim\": 768, \"n_layers\": 12, \"n_heads\": 12},\n",
    "    \"gpt2-medium (355M)\": {\"emb_dim\": 1024, \"n_layers\": 24, \"n_heads\": 16},\n",
    "    \"gpt2-large (774M)\": {\"emb_dim\": 1280, \"n_layers\": 36, \"n_heads\": 20},\n",
    "    \"gpt2-xl (1558M)\": {\"emb_dim\": 1600, \"n_layers\": 48, \"n_heads\": 25},\n",
    "}\n",
    "\n",
    "# Copy the base configuration and update with specific model settings\n",
    "model_name = \"gpt2-small (124M)\"  # Example model name\n",
    "NEW_CONFIG = GPT_CONFIG_124M.copy()\n",
    "NEW_CONFIG.update(model_configs[model_name])\n",
    "NEW_CONFIG.update({\"context_length\": 1024, \"qkv_bias\": True})\n",
    "\n",
    "gpt = GPTModel(NEW_CONFIG)\n",
    "gpt.eval();"
   ]
  },
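  {
   "cell_type": "markdown",
   "id": "e1a0c9a2-1f5b-4f1e-9a3d-2b6c8d4e0f21",
   "metadata": {},
   "source": [
    "- To double-check which settings the model was created with, we can inspect the merged configuration dictionary (the values come from the `GPT_CONFIG_124M` base configuration defined earlier plus the updates above):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1a0c9a2-1f5b-4f1e-9a3d-2b6c8d4e0f22",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the merged configuration that was used to instantiate the model above\n",
    "print(NEW_CONFIG)"
   ]
  },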
  {
   "cell_type": "markdown",
   "id": "272f29ac-8342-4b3d-a57d-9b0166ced314",
   "metadata": {},
   "source": [
    "- The next task is to assign the OpenAI weights to the corresponding weight tensors in our `GPTModel` instance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "id": "f9a92229-c002-49a6-8cfb-248297ad8296",
   "metadata": {},
   "outputs": [],
   "source": [
    "def assign(left, right):\n",
    "    if left.shape != right.shape:\n",
    "        raise ValueError(f\"Shape mismatch. Left: {left.shape}, Right: {right.shape}\")\n",
    "    return torch.nn.Parameter(torch.tensor(right))"
   ]
  },
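  {
   "cell_type": "markdown",
   "id": "e1a0c9a2-1f5b-4f1e-9a3d-2b6c8d4e0f31",
   "metadata": {},
   "source": [
    "- As a small toy example of how the `assign` helper above behaves: it returns a new trainable `torch.nn.Parameter` when the shapes match and raises a `ValueError` otherwise:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1a0c9a2-1f5b-4f1e-9a3d-2b6c8d4e0f32",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Toy example: matching shapes return a new Parameter, mismatched shapes raise an error\n",
    "left = torch.zeros(2, 3)\n",
    "print(assign(left, np.ones((2, 3))).shape)\n",
    "\n",
    "try:\n",
    "    assign(left, np.ones((3, 2)))\n",
    "except ValueError as e:\n",
    "    print(e)"
   ]
  },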
  {
   "cell_type": "code",
   "execution_count": 56,
   "id": "f22d5d95-ca5a-425c-a9ec-fc432a12d4e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def load_weights_into_gpt(gpt, params):\n",
    "    gpt.pos_emb.weight = assign(gpt.pos_emb.weight, params['wpe'])\n",
    "    gpt.tok_emb.weight = assign(gpt.tok_emb.weight, params['wte'])\n",
    "    \n",
    "    for b in range(len(params[\"blocks\"])):\n",
    "        q_w, k_w, v_w = np.split(\n",
    "            (params[\"blocks\"][b][\"attn\"][\"c_attn\"])[\"w\"], 3, axis=-1)\n",
    "        gpt.trf_blocks[b].att.W_query.weight = assign(\n",
    "            gpt.trf_blocks[b].att.W_query.weight, q_w.T)\n",
    "        gpt.trf_blocks[b].att.W_key.weight = assign(\n",
    "            gpt.trf_blocks[b].att.W_key.weight, k_w.T)\n",
    "        gpt.trf_blocks[b].att.W_value.weight = assign(\n",
    "            gpt.trf_blocks[b].att.W_value.weight, v_w.T)\n",
    "\n",
    "        q_b, k_b, v_b = np.split(\n",
    "            (params[\"blocks\"][b][\"attn\"][\"c_attn\"])[\"b\"], 3, axis=-1)\n",
    "        gpt.trf_blocks[b].att.W_query.bias = assign(\n",
    "            gpt.trf_blocks[b].att.W_query.bias, q_b)\n",
    "        gpt.trf_blocks[b].att.W_key.bias = assign(\n",
    "            gpt.trf_blocks[b].att.W_key.bias, k_b)\n",
    "        gpt.trf_blocks[b].att.W_value.bias = assign(\n",
    "            gpt.trf_blocks[b].att.W_value.bias, v_b)\n",
    "\n",
    "        gpt.trf_blocks[b].att.out_proj.weight = assign(\n",
    "            gpt.trf_blocks[b].att.out_proj.weight, \n",
    "            params[\"blocks\"][b][\"attn\"][\"c_proj\"][\"w\"].T)\n",
    "        gpt.trf_blocks[b].att.out_proj.bias = assign(\n",
    "            gpt.trf_blocks[b].att.out_proj.bias, \n",
    "            params[\"blocks\"][b][\"attn\"][\"c_proj\"][\"b\"])\n",
    "\n",
    "        gpt.trf_blocks[b].ff.layers[0].weight = assign(\n",
    "            gpt.trf_blocks[b].ff.layers[0].weight, \n",
    "            params[\"blocks\"][b][\"mlp\"][\"c_fc\"][\"w\"].T)\n",
    "        gpt.trf_blocks[b].ff.layers[0].bias = assign(\n",
    "            gpt.trf_blocks[b].ff.layers[0].bias, \n",
    "            params[\"blocks\"][b][\"mlp\"][\"c_fc\"][\"b\"])\n",
    "        gpt.trf_blocks[b].ff.layers[2].weight = assign(\n",
    "            gpt.trf_blocks[b].ff.layers[2].weight, \n",
    "            params[\"blocks\"][b][\"mlp\"][\"c_proj\"][\"w\"].T)\n",
    "        gpt.trf_blocks[b].ff.layers[2].bias = assign(\n",
    "            gpt.trf_blocks[b].ff.layers[2].bias, \n",
    "            params[\"blocks\"][b][\"mlp\"][\"c_proj\"][\"b\"])\n",
    "\n",
    "        gpt.trf_blocks[b].norm1.scale = assign(\n",
    "            gpt.trf_blocks[b].norm1.scale, \n",
    "            params[\"blocks\"][b][\"ln_1\"][\"g\"])\n",
    "        gpt.trf_blocks[b].norm1.shift = assign(\n",
    "            gpt.trf_blocks[b].norm1.shift, \n",
    "            params[\"blocks\"][b][\"ln_1\"][\"b\"])\n",
    "        gpt.trf_blocks[b].norm2.scale = assign(\n",
    "            gpt.trf_blocks[b].norm2.scale, \n",
    "            params[\"blocks\"][b][\"ln_2\"][\"g\"])\n",
    "        gpt.trf_blocks[b].norm2.shift = assign(\n",
    "            gpt.trf_blocks[b].norm2.shift, \n",
    "            params[\"blocks\"][b][\"ln_2\"][\"b\"])\n",
    "\n",
    "    gpt.final_norm.scale = assign(gpt.final_norm.scale, params[\"g\"])\n",
    "    gpt.final_norm.shift = assign(gpt.final_norm.shift, params[\"b\"])\n",
    "    gpt.out_head.weight = assign(gpt.out_head.weight, params[\"wte\"])\n",
    "    \n",
    "    \n",
    "load_weights_into_gpt(gpt, params)\n",
    "gpt.to(device);"
   ]
  },
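  {
   "cell_type": "markdown",
   "id": "e1a0c9a2-1f5b-4f1e-9a3d-2b6c8d4e0f41",
   "metadata": {},
   "source": [
    "- Note that the last assignment above reuses the token-embedding matrix (`params[\"wte\"]`) for the output layer; the original GPT-2 model ties these two weight matrices (weight tying), which is why the checkpoint does not ship a separate output-head matrix; a quick optional check:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1a0c9a2-1f5b-4f1e-9a3d-2b6c8d4e0f42",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional check: both layers were loaded from params[\"wte\"], so their values match\n",
    "print(torch.equal(gpt.out_head.weight, gpt.tok_emb.weight))"
   ]
  },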
  {
   "cell_type": "markdown",
   "id": "4f7472cb-54dc-4311-96d8-b2694f885cee",
   "metadata": {},
   "source": [
    "- If the model is loaded correctly, we can use it to generate new text using our previous `generate` function:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "id": "1f690253-f845-4347-b7b6-43fabbd2affa",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Output text:\n",
      " Every effort moves you toward finding an ideal new way to practice something!\n",
      "\n",
      "What makes us want to be on top of that?\n",
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "torch.manual_seed(123)\n",
    "\n",
    "token_ids = generate(\n",
    "    model=gpt,\n",
    "    idx=text_to_token_ids(\"Every effort moves you\", tokenizer).to(device),\n",
    "    max_new_tokens=25,\n",
    "    context_size=NEW_CONFIG[\"context_length\"],\n",
    "    top_k=50,\n",
    "    temperature=1.5\n",
    ")\n",
    "\n",
    "print(\"Output text:\\n\", token_ids_to_text(token_ids, tokenizer))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6d079f98-a7c4-462e-8416-5a64f670861c",
   "metadata": {},
   "source": [
    "- We know that we loaded the model weights correctly because the model can generate coherent text; if we made even a small mistake, the model would not be able to do that"
   ]
  },
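  {
   "cell_type": "markdown",
   "id": "e1a0c9a2-1f5b-4f1e-9a3d-2b6c8d4e0f51",
   "metadata": {},
   "source": [
    "- Besides eyeballing the generated text, we can also compare individual tensors directly against the original OpenAI parameters, for example the token embedding (a minimal optional check):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1a0c9a2-1f5b-4f1e-9a3d-2b6c8d4e0f52",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional check: the loaded token embedding should match the original params entry\n",
    "print(torch.allclose(\n",
    "    gpt.tok_emb.weight.detach().cpu(),\n",
    "    torch.tensor(params[\"wte\"])\n",
    "))"
   ]
  },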
  {
   "cell_type": "markdown",
   "id": "28493b9b-a1ae-4f31-87bc-c10ee4447f44",
   "metadata": {},
   "source": [
    "- For an alternative way to load the weights from the Hugging Face Hub, see [../02_alternative_weight_loading](../02_alternative_weight_loading)"
   ]
  },
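  {
   "cell_type": "markdown",
   "id": "e1a0c9a2-1f5b-4f1e-9a3d-2b6c8d4e0f61",
   "metadata": {},
   "source": [
    "- As a rough sketch of that alternative (not used in this notebook), the same pretrained weights are available via the third-party `transformers` library and would then need their own mapping onto our `GPTModel`, since the parameter names differ from the TensorFlow checkpoint:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e1a0c9a2-1f5b-4f1e-9a3d-2b6c8d4e0f62",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch only; requires the third-party transformers package (pip install transformers)\n",
    "# from transformers import GPT2Model\n",
    "# gpt_hf = GPT2Model.from_pretrained(\"gpt2\")\n",
    "# print(gpt_hf.wte.weight.shape)  # expected: torch.Size([50257, 768])"
   ]
  },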
  {
   "cell_type": "markdown",
   "id": "f2a66474-230d-4180-a8ff-843e04f1f1c4",
   "metadata": {},
   "source": [
    "## Summary and takeaways"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fc7ed189-a633-458c-bf12-4f70b42684b8",
   "metadata": {},
   "source": [
    "- See [gpt_train.py](gpt_train.py) for a self-contained training script\n",
    "- The [gpt_generate.py](gpt_generate.py) script loads pretrained weights from OpenAI and generates text based on a prompt\n",
    "- You can find the exercise solutions in [exercise-solutions.ipynb](exercise-solutions.ipynb)"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "A100",
   "machine_shape": "hm",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}