2024-03-11 07:07:36 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								{
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 "cells": [
							 
						 
					
						
							
								
									
										
										
										
											2024-03-19 09:26:26 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "9a5936bd-af17-4a7e-a4d2-e910411708ea",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-05-24 07:20:37 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "<table style=\"width:100%\">\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "<tr>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "<td style=\"vertical-align:middle; text-align:left;\">\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "<font size=\"2\">\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "Supplementary code for the <a href=\"http://mng.bz/orYv\">Build a Large Language Model From Scratch</a> book by <a href=\"https://sebastianraschka.com\">Sebastian Raschka</a><br>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "<br>Code repository: <a href=\"https://github.com/rasbt/LLMs-from-scratch\">https://github.com/rasbt/LLMs-from-scratch</a>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "</font>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "</td>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "<td style=\"vertical-align:middle; text-align:left;\">\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "<a href=\"http://mng.bz/orYv\"><img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/cover-small.webp\" width=\"100px\"></a>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "</td>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "</tr>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "</table>\n"
							 
						 
					
						
							
								
									
										
										
										
											2024-03-19 09:26:26 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2024-03-11 07:07:36 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "af53bcb1-ff9d-49c7-a0bc-5b8d32ff975b",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "## Appendix D: Adding Bells and Whistles to the Training Loop"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "4f58c142-9434-49af-b33a-356b80a45b86",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- In this appendix, we add a few more advanced features to the training function, which are used in typical pretraining and finetuning; finetuning is covered in chapters 6 and 7\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- The next three sections below discuss learning rate warmup, cosine decay, and gradient clipping\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- The final section adds these techniques to the training function"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "744def4f-c03f-42ee-97bb-5d7d5b89b723",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- We start by initializing a model reusing the code from chapter 5:"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "execution_count": 1,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "8755bd5e-bc06-4e6e-9e63-c7c82b816cbe",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
									
										
										
										
											2024-04-13 14:57:56 -04:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "torch version: 2.2.2\n"
							 
						 
					
						
							
								
									
										
										
										
											2024-03-11 07:07:36 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "from importlib.metadata import version\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "import torch\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(\"torch version:\", version(\"torch\"))\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "from previous_chapters import GPTModel\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "GPT_CONFIG_124M = {\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-04 07:27:41 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "    \"vocab_size\": 50257,   # Vocabulary size\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    \"context_length\": 256, # Shortened context length (orig: 1024)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    \"emb_dim\": 768,        # Embedding dimension\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    \"n_heads\": 12,         # Number of attention heads\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    \"n_layers\": 12,        # Number of layers\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    \"drop_rate\": 0.1,      # Dropout rate\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    \"qkv_bias\": False      # Query-key-value bias\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-11 07:07:36 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "}\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "torch.manual_seed(123)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "model = GPTModel(GPT_CONFIG_124M)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "model.eval();  # Disable dropout during inference"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "51574e57-a098-412c-83e8-66dafa5a0b99",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Next, using the same code we used in chapter 5, we initialize the data loaders:"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "execution_count": 2,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "386ca110-2bb4-42f1-bd54-8836df80acaa",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "import os\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "import urllib.request\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "file_path = \"the-verdict.txt\"\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "url = \"https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/main/ch02/01_main-chapter-code/the-verdict.txt\"\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "if not os.path.exists(file_path):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    with urllib.request.urlopen(url) as response:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        text_data = response.read().decode('utf-8')\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    with open(file_path, \"w\", encoding=\"utf-8\") as file:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        file.write(text_data)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "else:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    with open(file_path, \"r\", encoding=\"utf-8\") as file:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        text_data = file.read()"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "execution_count": 3,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "ae96992b-536a-4684-a924-658b9ffb7e9c",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "from previous_chapters import create_dataloader_v1\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "# Train/validation ratio\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "train_ratio = 0.90\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "split_idx = int(train_ratio * len(text_data))\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "torch.manual_seed(123)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "train_loader = create_dataloader_v1(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    text_data[:split_idx],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    batch_size=2,\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-04 07:27:41 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "    max_length=GPT_CONFIG_124M[\"context_length\"],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    stride=GPT_CONFIG_124M[\"context_length\"],\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-11 07:07:36 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "    drop_last=True,\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-13 14:57:56 -04:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "    shuffle=True,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    num_workers=0\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-11 07:07:36 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    ")\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "val_loader = create_dataloader_v1(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    text_data[split_idx:],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    batch_size=2,\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-04 07:27:41 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "    max_length=GPT_CONFIG_124M[\"context_length\"],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    stride=GPT_CONFIG_124M[\"context_length\"],\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-11 07:07:36 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "    drop_last=False,\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-13 14:57:56 -04:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "    shuffle=False,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    num_workers=0\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-11 07:07:36 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    ")"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "939c08d8-257a-41c6-b842-019f7897ac74",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "## D.1 Learning rate warmup"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "7fafcd30-ddf7-4a9f-bcf4-b13c052b3133",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- When training complex models like LLMs, implementing learning rate warmup can help stabilize the training\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- In learning rate warmup, we gradually increase the learning rate from a very low value (`initial_lr`) to a user-specified maximum (`peak_lr`)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- This way, the model will start the training with small weight updates, which helps decrease the risk of large destabilizing updates during the training"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "execution_count": 4,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "2bb4790b-b8b6-4e9e-adf4-704a04b31ddf",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
									
										
										
										
											2024-04-01 08:05:01 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
									
										
										
										
											2024-03-11 07:07:36 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "n_epochs = 15\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "initial_lr = 0.0001\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-01 08:05:01 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "peak_lr = 0.01"
							 
						 
					
						
							
								
									
										
										
										
											2024-03-11 07:07:36 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "5bf3a8da-abc4-4b80-a5d8-f1cc1c7cc5f3",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-04-17 20:30:42 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "- Typically, the number of warmup steps is between 0.1% to 10% of the total number of steps\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-11 07:07:36 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "- We can compute the increment as the difference between the `peak_lr` and `initial_lr` divided by the number of warmup steps"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "execution_count": 5,
							 
						 
					
						
							
								
									
										
										
										
											2024-04-17 20:30:42 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "5f6d083f-1b25-4c23-b46d-ef7783446690",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "27\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-05-13 07:45:59 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "total_steps = len(train_loader) * n_epochs\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "warmup_steps = int(0.2 * total_steps) # 20% warmup\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-17 20:30:42 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "print(warmup_steps)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "execution_count": 6,
							 
						 
					
						
							
								
									
										
										
										
											2024-03-11 07:07:36 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "id": "e075f80e-a398-4809-be1d-8019e1d31c90",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "lr_increment = (peak_lr - initial_lr) / warmup_steps\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "global_step = -1\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "track_lrs = []\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-01 08:05:01 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "optimizer = torch.optim.AdamW(model.parameters(), weight_decay=0.1)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-11 07:07:36 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "for epoch in range(n_epochs):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    for input_batch, target_batch in train_loader:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        optimizer.zero_grad()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        global_step += 1\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    \n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        if global_step < warmup_steps:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            lr = initial_lr + global_step * lr_increment\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        else:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            lr = peak_lr\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        \n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # Apply the calculated learning rate to the optimizer\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        for param_group in optimizer.param_groups:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            param_group[\"lr\"] = lr\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-05-23 17:19:01 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        track_lrs.append(optimizer.defaults[\"lr\"])\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-11 07:07:36 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "    \n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-31 21:22:49 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        # Calculate loss and update weights\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # ..."
							 
						 
					
						
							
								
									
										
										
										
											2024-03-11 07:07:36 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-17 20:30:42 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 7,
							 
						 
					
						
							
								
									
										
										
										
											2024-03-11 07:07:36 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "id": "cb6da121-eeed-4023-bdd8-3666c594b4ed",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "data": {
							 
						 
					
						
							
								
									
										
										
										
											2024-04-17 20:30:42 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAekAAAEiCAYAAADd4SrgAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAAA6LElEQVR4nO3dfViUVd4H8O8MLzOIMCgkA4oyGkUKokIQxq5rUmSokZbm8qiZq9VSafRk+QbVYw+FuevaWupuT7bX5su6W6Su0hKWZiIoIIrvJYqJAwgxgyhvM+f5g7i3WVEZHLhn4Pu5rrlw7vt3z/yO0vw65z5zjkIIIUBERER2Ryl3AkRERNQ2FmkiIiI7xSJNRERkp1ikiYiI7BSLNBERkZ1ikSYiIrJTLNJERER2ikWaiIjITjnLnYCjMpvNKCsrg4eHBxQKhdzpEBGRzIQQqK2thb+/P5RK2/SBWaQ7qKysDAEBAXKnQUREdubChQsYMGCATV6LRbqDPDw8ALT8Y3h6esqcDRERyc1oNCIgIECqD7bAIt1BrUPcnp6eLNJERCSx5S1QThwjIiKyUyzSREREdopFmoiIyE7JXqTXrFmDwMBAqNVqREVFIS8v76bxW7duRXBwMNRqNUJDQ7Fz506L859++ikeeugheHt7Q6FQ4PDhw9e9Rn19PZKSkuDt7Y3evXtjypQpKC8vt2WziIiIbpusRXrLli1ITk5GamoqCgoKEBYWhri4OFRUVLQZv3//fkyfPh1z5sxBYWEhEhISkJCQgOLiYimmrq4OMTExeOedd274vi+99BK2b9+OrVu3Ys+ePSgrK8PkyZNt3j4iIqLboRBCCLnePCoqCvfeey/++Mc/AmhZICQgIAAvvPACXnvttevip02bhrq6OuzYsUM6dt9992HEiBFYu3atRey5c+eg0+lQWFiIESNGSMcNBgPuuOMObNy4EY8//jgA4OTJk7jnnnuQk5OD++67r125G41GaDQaGAwGzu4mIqJOqQuyfQWrsbER+fn5WLRokXRMqVQiNjYWOTk5bV6Tk5OD5ORki2NxcXHIyMho9/vm5+ejqakJsbGx0rHg4GAMHDjwpkW6oaEBDQ0N0nOj0dju9yTrmc0Cb+08gZN6/j0TUedzVirx8dORcqdxHdmK9OXLl2EymeDr62tx3NfXFydPnmzzGr1e32a8Xq9v9/vq9Xq4urrCy8vLqtdJS0vDG2+80e73odtz4GwVPtxXIncaRNRDuDrJPkWrTVzMpJ0WLVpk0YtvXVmGOse2ojIAwAPB/fDoCH+ZsyGi7k5pp3swyFakfXx84OTkdN2s6vLycmi12jav0Wq1VsXf6DUaGxtRU1Nj0Zu+1euoVCqoVKp2vw91XGOzGbuKW0Y1fhOjw+g7fWTOiIhIHrL1711dXREeHo7s7GzpmNlsRnZ2NqKjo9u8Jjo62iIeALKysm4Y35bw8HC4uLhYvM6pU6dQWlpq1etQ5/nmTCUM15rQz0OFqMHecqdDRCQbWYe7k5OTMWvWLERERCAyMhKrVq1CXV0dZs+eDQCYOXMm+vfvj7S0NADA/PnzMWbMGKxcuRLx8fHYvHkzDh06hPXr10uvWV1djdLSUpSVtQyXnjp1CkBLD1qr1UKj0WDOnDlITk5G37594enpiRdeeAHR0dHtntlNnat1qDt+uB+clPY5BEVE1BVkLdLTpk1DZWUlUlJSoNfrMWLECGRmZkqTw0pLSy325Bw9ejQ2btyIpUuXYvHixQgKCkJGRgZCQkKkmG3btklFHgCefPJJAEBqaipef/11AMDvf/97KJVKTJkyBQ0NDYiLi8P777/fBS2mW7nWaELW8ZZbGhPDeC+aiHo2Wb8n7cj4PenOsb2oDC9sKkRAXzfsfWWsTXeTISLqTJ1RF+xzzjn1WNt/GuqeONyfBZqIejwWabIbhmtN+PpUJQBgEr92RUTEIk3244tjejSazAjq1xt3+3rInQ4RkexYpMlutA51TwrjUDcREcAiTXbi8pUG7P++CgBndRMRtWKRJruw8+glmMwCwwdoEOjjLnc6RER2gUWa7MLPh7qJiKgFizTJ7mLNNRw89yMUCmDCcBZpIqJWLNIkux0/9aLvDewLrUYtczZERPaDRZpkt41D3UREbWKRJll9X3kFx8qMcFYq8Eion9zpEBHZFRZpklXrhLGYIB/0dXeVORsiIvvCIk2yEUJIQ90TOWGMiOg6LNIkm+OXjDhbWQeVsxIPDfOVOx0iIrvDIk2yae1FPxDcDx5qF5mzISKyPyzSJAuzWWBH0SUAXAaUiOhGWKRJFoUXfsTFmmvorXLGA8H95E6HiMgusUiTLLYdbhnqfmioL9QuTjJnQ0Rkn1ikqcs1m8z451EOdRMR3QqLNHW5A2ercflKI/r0ckFMkI/c6RAR2S0Waepy24ouAgDGh/rBxYm/gkREN8JPSOpSDc0m7CrWA+Ba3UREt8IiTV1qz6lK1NY3w9dThXsD+8qdDhGRXWORpi61/UjLhLEJw/3hpFTInA0RkX1jkaYuc7WxGV8eLwfAoW4iovZgkaYuk3W8HNeaTBjk3QvDB2jkToeIyO6xSFOX2d66DOhwfygUHOomIroVFmnqEoarTdhzugIAMGkEh7qJiNqDRZq6ROaxS2gyCdzt64G7fD3kToeIyCGwSFOXaB3qZi+aiKj9ZC/Sa9asQWBgINRqNaKiopCXl3fT+K1btyI4OBhqtRqhoaHYuXOnxXkhBFJSUuDn5wc3NzfExsbizJkzFjGnT5/Go48+Ch8fH3h6eiImJgZfffWVzdtGLSpq67H/+8sAWu5HExFR+8hapLds2YLk5GSkpqaioKAAYWFhiIuLQ0VFRZvx+/fvx/Tp0zFnzhwUFhYiISEBCQkJKC4ulmLS09OxevVqrF27Frm5uXB3d0dcXBzq6+ulmAkTJqC5uRm7d+9Gfn4+wsLCMGHCBOj1+k5vc0+088glmAUwIsALA717yZ0OEZHjEDKKjIwUSUlJ0nOTyST8/f1FWlpam/FTp04V8fHxFseioqLEM888I4QQwmw2C61WK1asWCGdr6mpESqVSmzatEkIIURlZaUAIPbu3SvFGI1GAUBkZWW1O3eDwSAACIPB0O5reqrH1uwTg17dIf78zVm5UyEi6jSdURdk60k3NjYiPz8fsbGx0jGlUonY2Fjk5OS0eU1OTo5FPADExcVJ8SUlJdDr9RYxGo0GUVFRUoy3tzfuvvtu/OUvf0FdXR2am5uxbt069OvXD+Hh4bZuZo93ofoqCkproFAAE4b7yZ0OEZFDcZbrjS9fvgyTyQRfX1+L476+vjh58mSb1+j1+jbjW4epW3/eLEahUODLL79EQkICPDw8oFQq0a9fP2RmZqJPnz43zLehoQENDQ3Sc6PR2M6W9mw7floG9D6dN3w91TJnQ0TkWGSfONbVhBBISkpCv3798M033yAvLw8JCQmYOHEiLl26dMPr0tLSoNFopEdAQEAXZu24thWVAQAmchlQIiKryVakfXx84OT
khPLycovj5eXl0Gq1bV6j1WpvGt/682Yxu3fvxo4dO7B582bcf//9GDVqFN5//324ubnh448/vmG+ixYtgsFgkB4XLlywrsE90HcVtThxyQhnpQLjQ9r+NyUiohuTrUi7uroiPDwc2dnZ0jGz2Yzs7GxER0e3eU10dLRFPABkZWVJ8TqdDlqt1iLGaDQiNzdXirl69SqAlvvfP6dUKmE2m2+Yr0qlgqenp8WDbm7bT9+N/uVdd6CPu6vM2RAROR7Z7kkDQHJyMmbNmoWIiAhERkZi1apVqKurw+zZswEAM2fORP/+/ZGWlgYAmD9/PsaMGYOVK1ciPj4emzdvxqFDh7B+/XoALfebFyxYgOXLlyMoKAg6nQ7Lli2Dv78/EhISALQU+j59+mDWrFlISUmBm5sb/vSnP6GkpATx8fGy/D10R0IIbJeGujlhjIioI2Qt0tOmTUNlZSVSUlKg1+sxYsQIZGZmShO/SktLLXq8o0ePxsaNG7F06VIsXrwYQUFByMjIQEhIiBSzcOFC1NXVYd68eaipqUFMTAw
							 
						 
					
						
							
								
									
										
										
										
											2024-03-31 21:22:49 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "text/plain": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "<Figure size 500x300 with 1 Axes>"
							 
						 
					
						
							
								
									
										
										
										
											2024-03-11 07:07:36 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								      ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "display_data"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "import matplotlib.pyplot as plt\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-31 21:22:49 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "plt.figure(figsize=(5, 3))\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-11 07:07:36 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "plt.ylabel(\"Learning rate\")\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "plt.xlabel(\"Step\")\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-01 08:05:01 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "total_training_steps = len(train_loader) * n_epochs\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-31 21:22:49 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "plt.plot(range(total_training_steps), track_lrs)\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-03 20:19:08 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "plt.tight_layout(); plt.savefig(\"1.pdf\")\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-01 08:05:01 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "plt.show()"
							 
						 
					
						
							
								
									
										
										
										
											2024-03-11 07:07:36 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "7b3996b6-3f7a-420a-8584-c5760249f3d8",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "## D.2 Cosine decay"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "c5216214-de79-40cf-a733-b1049a73023c",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Another popular technique for training complex deep neural networks is cosine decay, which also adjusts the learning rate across training epochs\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- In cosine decay, the learning rate follows a cosine curve, decreasing from its initial value to near zero following a half-cosine cycle\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- This gradual reduction is designed to slow the pace of learning as the model begins to improve its weights; it reduces the risk of overshooting minima as the training progresses,  which is crucial for stabilizing the training in its later stages\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Cosine decay is often preferred over linear decay for its smoother transition in learning rate adjustments, but linear decay is also used in practice (for example, [OLMo: Accelerating the Science of Language Models](https://arxiv.org/abs/2402.00838))"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-17 20:30:42 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 8,
							 
						 
					
						
							
								
									
										
										
										
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "id": "4e8d2068-a057-4abf-b478-f02cc37191f6",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "import math\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "min_lr = 0.1 * initial_lr\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "track_lrs = []\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "lr_increment = (peak_lr - initial_lr) / warmup_steps\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "global_step = -1\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "for epoch in range(n_epochs):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    for input_batch, target_batch in train_loader:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        optimizer.zero_grad()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        global_step += 1\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    \n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # Adjust the learning rate based on the current phase (warmup or cosine annealing)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        if global_step < warmup_steps:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            # Linear warmup\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            lr = initial_lr + global_step * lr_increment  \n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        else:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            # Cosine annealing after warmup\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            progress = ((global_step - warmup_steps) / \n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                        (total_training_steps - warmup_steps))\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            lr = min_lr + (peak_lr - min_lr) * 0.5 * (1 + math.cos(math.pi * progress))\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        \n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # Apply the calculated learning rate to the optimizer\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        for param_group in optimizer.param_groups:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            param_group[\"lr\"] = lr\n",
							 
						 
					
						
							
								
									
										
										
										
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        track_lrs.append(optimizer.defaults[\"lr\"])\n",
							 
						 
					
						
							
								
									
										
										
										
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "    \n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # Calculate loss and update weights"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
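  {
   "cell_type": "markdown",
   "id": "8d4b2e91-6f3a-4c57-b0d2-9e1a7c5f3b02",
   "metadata": {},
   "source": [
    "- As a side note, a similar warmup-plus-cosine schedule can also be assembled from PyTorch's built-in schedulers instead of assigning `param_group[\"lr\"]` manually\n",
    "- The snippet below is only a sketch of that alternative: it assumes the `optimizer`, `warmup_steps`, `total_training_steps`, and `min_lr` variables defined earlier, the `start_factor=0.1` value is an arbitrary choice, and this variant is not used in the rest of this appendix\n",
    "\n",
    "```python\n",
    "from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, SequentialLR\n",
    "\n",
    "# Linear warmup toward the optimizer's base learning rate,\n",
    "# followed by cosine annealing down to min_lr\n",
    "warmup_scheduler = LinearLR(optimizer, start_factor=0.1, total_iters=warmup_steps)\n",
    "cosine_scheduler = CosineAnnealingLR(\n",
    "    optimizer, T_max=total_training_steps - warmup_steps, eta_min=min_lr\n",
    ")\n",
    "scheduler = SequentialLR(\n",
    "    optimizer, schedulers=[warmup_scheduler, cosine_scheduler], milestones=[warmup_steps]\n",
    ")\n",
    "\n",
    "# Inside the training loop, scheduler.step() would then be called\n",
    "# once after each optimizer.step()\n",
    "```"
   ]
  },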
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 9,
							 
						 
					
						
							
								
									
										
										
										
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "id": "0e779e33-8a44-4984-bb23-be0603dc4158",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "data": {
							 
						 
					
						
							
								
									
										
										
										
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAekAAAEiCAYAAADd4SrgAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAABQLklEQVR4nO3deVxU9f7H8dcMywwqDAjKgIqSWi4gKCiilpUUmaaYN5dMzSzNrDQqS0utbv24Wt7KMpduV20xzW6pmVGGWyqC4r7gkgsqDqDIqmwz5/cHMV2uaKDAmYHP8/GYB8053zO8v2p8OOd8z/erURRFQQghhBA2R6t2ACGEEEJUTIq0EEIIYaOkSAshhBA2Soq0EEIIYaOkSAshhBA2Soq0EEIIYaOkSAshhBA2Soq0EEIIYaMc1Q5grywWC6mpqbi6uqLRaNSOI4QQQmWKopCbm4uvry9abfWcA0uRvkmpqam0aNFC7RhCCCFszNmzZ2nevHm1fJYU6Zvk6uoKlP5luLm5qZxGCCGE2nJycmjRooW1PlQHKdI3qewSt5ubmxRpIYQQVtV5C1QGjgkhhBA2Soq0EEIIYaOkSAshhBA2SvUiPW/ePFq1aoVerycsLIzExMQbtl+5ciXt2rVDr9cTGBjIunXryu3/7rvvuP/++/H09ESj0bB3795rPqOgoICJEyfi6elJo0aNGDx4MGlpadXZLSGEEOKWqVqkV6xYQXR0NDNnzmT37t0EBQURGRlJenp6he23b9/O8OHDGTt2LHv27CEqKoqoqCgOHjxobZOfn0+vXr2YNWvWdb/vCy+8wA8//MDKlSvZvHkzqampPPzww9XePyGEEOJWaBRFUdT65mFhYXTt2pWPP/4YKJ0gpEWLFjz33HO8+uqr17QfOnQo+fn5rF271rqte/fuBAcHs2DBgnJtT58+jb+/P3v27CE4ONi6PTs7myZNmrBs2TL+9re/AZCcnEz79u2Jj4+ne/fulcqek5ODwWAgOztbRncLIYSokbqg2iNYRUVFJCUlMXXqVOs2rVZLREQE8fHxFR4THx9PdHR0uW2RkZGsWrWq0t83KSmJ4uJiIiIirNvatWuHn5/fDYt0YWEhhYWF1vc5OTmV/p6i6iwWhXfWHeF4eh46Ry16J4c/vmppqHOkSSMdTd30NHXV4e2mx9tNRwNneaJQCFG3qPZT7eLFi5jNZry9vctt9/b2Jjk5ucJjTCZThe1NJlOlv6/JZMLZ2Rl3d/cqfU5MTAxvvvlmpb+PuDXxJy/x2dZTVTrG6KanrXcj2jRtRNumrrT1bkR7Hzca6aR4CyHsk/z0qqSpU6eWO4svm1lG1Iw1e1MBuPuOJtzfwUhhiZmCYguFJWZyC0pIzy0kLaeAjNxC0nMKyC8yY8opwJRTwG/HL1o/R6uBdkY3Qlt5ENKy9NXM3UXmWxdC2AXVirSXlxcODg7XjKpOS0vDaDRWeIzRaKxS++t9RlFREVlZWeXOpv/qc3Q6HTqdrtLfR9y8whIz6w5eAGD8Xa0Jb+35l8dkXy3mRHoeJ9JzOZ6Wx/H0PI6l5XIhu4DDF3I4fCGHz+PPANDM3YW772jCve2a0qO1Fy7ODjXaHyGEuFmqFWlnZ2dCQkKIi4sjKioKKB04FhcXx7PPPlvhMeHh4cTFxTF58mTrtvXr1xMeHl7p7xsSEoKTkxNxcXEMHjwYgKNHj5KSklKlzxE1Z/PRDHILSjC66enm37hSxxhcnKxnyv/NlF3ArjOZJJ25TNKZyxxKzeF81lW+Skjhq4QUdI5aerT25N723vQNMOLVSH4RE0LYDlUvd0dHRzN69GhCQ0Pp1q0bH3zwAfn5+YwZMwaAUaNG0axZM2JiYgCYNGkSvXv3Zs6cOfTr14/ly5eza9cuFi1aZP3MzMxMUlJSSE0tvVx69OhRoPQM2mg0YjAYGDt2LNHR0TRu3Bg3Nzeee+45wsPDKz2yW9Ss1ftK/+76d/LBQXtrl6WNBj39O/nSv5MvAFeKSoj//RIbktPZmJxOanYBG49msPFoBm+sOcSdbb2ICm7G/R29ZSCaEEJ1qv4UGjp0KBkZGcyYMQOTyURwcDCxsbHWwWEpKSnl1uTs0aMHy5Yt4/XXX2fatGm0bduWVatWERAQYG2zZs0aa5EHGDZsGAAzZ87kjTfeAOD9999Hq9UyePBgCgsLiYyM5JNPPqmFHou/kldYQtyR0lsaA4ObVfvnN3B2pE97b/q090ZRFI6m5bIhOZ3Ygyb2n8tm09EMNh3NoIGzA/d38GZI1xaE3+Yp97CFEKpQ9TlpeybPSdeM7/ec44UV+/D3asiGF3vXanH8PSOP1XtTWb33PGcuXbFub9u0ESPDWzKoczNc9U61lkcIYV9qoi5Ikb5JUqRrxpjFiWw8msGkPm154b7bVcmgKAp7z2bxbdI5vt9znitFZgAaOjvwcJfmjO7RkjZNq2+9WCFE3SBF2oZIka5+mflFdHvnV0osCnEv9qZ1k0ZqRyK3oJjvdp/n8/jT/J6RD4BGA5EdjEy8pw2BzQ0qJxRC2Io6NeOYEP9r3YELlFgUApq52USBBnDVOzG6RytGhbdk+++XWLL9NOsPpxF7yETsIRN33d6EiXe3Juy2v35MTAghqkqKtLAZZROYDAjyVTnJtTQaDT3beNGzjRfH0nKZv+l31uxLZcuxDLYcy6Bbq8a8FHlHpR8ZE0KIylB9qUohAFKzrpJ4OhONBh6ywSL93273duX9ocFsfPFuHg3zw9lBS+LpTIYsjGfskp0cNeWqHVEIUUdIkRY24Yc/no3u2qoxPgYXldNUjp9nA/5vUCBbptzDo2F+OGg1xCWn88CHW3hp5T7OZ11VO6IQws5JkRY2Yc0fRXpgsG2fRVfEaNDzf4MC+eWFu+gbYERR4Nukc9zz3ibe/TmZK0UlakcUQtgpKdJCdSfS8ziUmoOjVsODAT5qx7lprZs0Yv5jIXz3TA/C/BtTVGJh3sbf6TNnMz/uv4A8SCGEqCop0kJ1ZWfRd93eBI+GziqnuXVd/DxYPq47Cx4LoZm7CxeyC5i4bDePfZbAiXS5Xy2EqDwp0kJViqKwZu95wDZHdd8sjUbDAwFGfo3uzfN92uLsqGXbiUs88MFvxPx0hIJis9oRhRB2QIq0UNWB89mcvnQFvZOW+zp4qx2n2rk4OxB93+38+kJvItp7U2JRWLj5JH0//I2Ek5fUjieEsHFSpIWqVv/xbHREe28a6uruY/t+ng341+hQPh0VirebjlMX8xm6aAfTVx0kt6BY7XhCCBslRVqoxmxRWLu/bFR39a94ZYvu6+DNLy/0Zni3FgB8seMMke9vYePRdJWTCSFskRRpoZqEU5dIyynETe/IXbd7qR2n1hhcnIh5uBPLngzDr3EDUrMLGLN4J699f0Ae1xJClCNFWqimbAKTvgE+6BwdVE5T+3q08SJ28p2M6dkKgK8SUug/dyv7zmapmksIYTukSAtVFJVYWHfABNjnBCbVpYGzIzMf6shXT4ZhdNNz8mI+g+d
v56O445SYLWrHE0KoTIq0UMWWYxlkXy2mqatOVpACev5xVt2vkw8lFoU5648xdNEOzmZeUTuaEEJFUqSFKlb/cam7fydfHLQaldPYBvcGznw8vDPvDw3CVedI0pnL9Jv7G+sPp6kdTQihEinSotblF5bw6x+FZ0A9vtRdEY1Gw6DOzVk36U6CWriTU1DCU5/v4v/WHaFYLn8LUe9IkRa17tcjaVwtNtPSswFBzQ1qx7FJLRo3YOX4cJ7o6Q/Aoi0nGbownlRZWUuIekWKtKh1a/6YwGRAkC8ajVzqvh5nRy0zHurAgsdCcNU7sjsliwfn/sYmeaZaiHpDirSoVZfzi9h8LAOo36O6q+KBACM/Pncngc0MZF0pZsySnczbeEJW1RKiHpAiLWrVTwdNlFgU2vu40aapq9px7IafZwO+nRDO8G5+KAq8+/NRJi7bTX6hTH4iRF0mRVrUqtV/rHglZ9FVp3N0IObhQP5vUCBODhrWHTDx8CfbOXMpX+1oQogaIkVa1JoL2VdJPJ0JwEN1aFnK2vZomB/Lx3W
							 
						 
					
						
							
								
									
										
										
										
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "text/plain": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "<Figure size 500x300 with 1 Axes>"
							 
						 
					
						
							
								
									
										
										
										
										
									 
								 
							 
							
								
							 
							
								 
							
							
								      ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "display_data"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "plt.figure(figsize=(5, 3))\n",
							 
						 
					
						
							
								
									
										
										
										
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "plt.ylabel(\"Learning rate\")\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "plt.xlabel(\"Step\")\n",
							 
						 
					
						
							
								
									
										
										
										
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "plt.plot(range(total_training_steps), track_lrs)\n",
							 
						 
					
						
							
								
									
										
										
										
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "plt.tight_layout(); plt.savefig(\"2.pdf\")\n",
							 
						 
					
						
							
								
									
										
										
										
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "plt.show()"
							 
						 
					
						
							
								
									
										
										
										
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
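  {
   "cell_type": "markdown",
   "id": "6d2a7f84-3e5b-4c19-a0d7-8b4f1c6e9a09",
   "metadata": {},
   "source": [
    "- As an optional sanity check on the recorded schedule, the first value in `track_lrs` should be close to the initial learning rate, the maximum should be close to the peak learning rate, and the final value should be close to the minimum learning rate:\n",
    "\n",
    "```python\n",
    "print(f\"first: {track_lrs[0]:.6f}, max: {max(track_lrs):.6f}, last: {track_lrs[-1]:.6f}\")\n",
    "```"
   ]
  },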
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "e7512808-b48d-4146-86a1-5931b1e3aec1",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "## D.3 Gradient clipping"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "c0a74f76-8d2b-4974-a03c-d645445cdc21",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Gradient clipping is yet another technique used to stabilize the training when training LLMs\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- By setting a threshold, gradients exceeding this limit are scaled down to a maximum magnitude to ensure that the updates to the model's parameters during backpropagation remain within a manageable range\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- For instance, using the `max_norm=1.0` setting in PyTorch's `clip_grad_norm_` method means that the norm of the gradients is clipped such that their maximum norm does not exceed 1.0\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- the \"norm\" refers to a measure of the gradient vector's length (or magnitude) in the parameter space of the model\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Specifically, it's the L2 norm, also known as the Euclidean norm\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Mathematically, for a vector $\\mathbf{v}$ with components $\\mathbf{v} = [v_1, v_2, \\ldots, v_n]$, the L2 norm is defined as:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "$$\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\\| \\mathbf{v} \\|_2 = \\sqrt{v_1^2 + v_2^2 + \\ldots + v_n^2}\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "$$"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
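  {
   "cell_type": "markdown",
   "id": "5a7e9c31-2d46-4b8f-a1c0-3e6d8f2b7c04",
   "metadata": {},
   "source": [
    "- The following small function illustrates what clipping by a global norm does conceptually; it is a simplified sketch for illustration purposes, not the actual implementation behind PyTorch's `clip_grad_norm_` (which additionally supports other norm types, error checking, and more)\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "\n",
    "def clip_gradients_(parameters, max_norm):\n",
    "    # Collect all existing gradients\n",
    "    grads = [p.grad for p in parameters if p.grad is not None]\n",
    "    if not grads:\n",
    "        return torch.tensor(0.0)\n",
    "\n",
    "    # Compute the total (global) L2 norm over all gradients\n",
    "    total_norm = torch.sqrt(sum((g**2).sum() for g in grads))\n",
    "\n",
    "    # If the total norm exceeds max_norm, scale every gradient down by the same factor\n",
    "    if total_norm > max_norm:\n",
    "        scale = max_norm / total_norm\n",
    "        for g in grads:\n",
    "            g.mul_(scale)\n",
    "\n",
    "    return total_norm\n",
    "```"
   ]
  },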
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "d44838a6-4322-47b2-a935-c00d3a88355f",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- The L2 norm is calculated similarly for matrices.\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Let's assume our gradient matrix is:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "$$\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "G = \\begin{bmatrix}\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "1 & 2 \\\\\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "2 & 4\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\\end{bmatrix}\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "$$\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- And we want to clip these gradients with a `max_norm` of 1.\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- First, we calculate the L2 norm of these gradients:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "$$\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\\|G\\|_2 = \\sqrt{1^2 + 2^2 + 2^2 + 4^2} = \\sqrt{25} = 5\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "$$\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Since $\\|G\\|_2 = 5$ is greater than our `max_norm` of 1, we need to scale down the gradients so that their norm is exactly 1. The scaling factor is calculated as $\\frac{max\\_norm}{\\|G\\|_2} = \\frac{1}{5}$.\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Therefore, the scaled gradient matrix $G'$ will be as follows:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "$$\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "G' = \\frac{1}{5} \\times G = \\begin{bmatrix}\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\\frac{1}{5} & \\frac{2}{5} \\\\\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\\frac{2}{5} & \\frac{4}{5}\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\\end{bmatrix}\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "$$"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
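  {
   "cell_type": "markdown",
   "id": "9b3d5f17-8c2a-4e64-b7a9-0d1e6c4f2a05",
   "metadata": {},
   "source": [
    "- We can also verify this small calculation numerically; the snippet below merely re-checks the arithmetic of the example above and is independent of the model\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "G = torch.tensor([[1., 2.],\n",
    "                  [2., 4.]])\n",
    "\n",
    "norm = torch.linalg.norm(G)   # Frobenius (L2) norm of the gradient matrix -> 5.0\n",
    "G_scaled = G * (1.0 / norm)   # scale so that the norm becomes 1\n",
    "\n",
    "print(norm)\n",
    "print(torch.linalg.norm(G_scaled))  # -> 1.0 (up to floating-point rounding)\n",
    "```"
   ]
  },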
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "eeb0c3c1-2cff-46f5-8127-24412184428c",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Let's see this in action\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- First, we initialize a new model and calculate the loss for a training batch like we would do in the regular training loop"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 10,
							 
						 
					
						
							
								
									
										
										
										
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "id": "e199e1ff-58c4-413a-855e-5edbe9292649",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "from previous_chapters import calc_loss_batch\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "torch.manual_seed(123)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "model = GPTModel(GPT_CONFIG_124M)\n",
							 
						 
					
						
							
								
									
										
										
										
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "model.to(device)\n",
							 
						 
					
						
							
								
									
										
										
										
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "loss = calc_loss_batch(input_batch, target_batch, model, device)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "loss.backward()"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "76b60f3a-15ec-4846-838d-fdef3df99899",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- If we call `.backward()`, PyTorch will calculate the gradients and store them in a `.grad` attribute for each weight (parameter) matrix\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Let's define a utility function to calculate the highest gradient based on all model weights"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
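  {
   "cell_type": "markdown",
   "id": "4e6a8c20-1b3d-4f5e-9a7c-2d0b8e6f4a06",
   "metadata": {},
   "source": [
    "- For illustration, these `.grad` attributes can be inspected directly; the short loop below (a sketch, not part of the training code) just prints the names and gradient shapes of the first few parameters\n",
    "\n",
    "```python\n",
    "for i, (name, param) in enumerate(model.named_parameters()):\n",
    "    if param.grad is not None:\n",
    "        print(name, tuple(param.grad.shape))\n",
    "    if i >= 2:  # only show the first few parameters\n",
    "        break\n",
    "```"
   ]
  },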
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 11,
							 
						 
					
						
							
								
									
										
										
										
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "id": "e70729a3-24d1-411d-a002-2529cd3a8a9e",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
									
										
										
										
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "tensor(0.0411)\n"
							 
						 
					
						
							
								
									
										
										
										
										
									 
								 
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "def find_highest_gradient(model):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    max_grad = None\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    for param in model.parameters():\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        if param.grad is not None:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            grad_values = param.grad.data.flatten()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            max_grad_param = grad_values.max()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            if max_grad is None or max_grad_param > max_grad:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                max_grad = max_grad_param\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    return max_grad\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(find_highest_gradient(model))"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
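  {
   "cell_type": "markdown",
   "id": "7c1e3a59-4d8b-4a26-b9e0-5f2c7d9a1b07",
   "metadata": {},
   "source": [
    "- Note that `find_highest_gradient` returns the largest individual gradient entry, whereas `clip_grad_norm_` constrains the global L2 norm computed over all gradients combined\n",
    "- To look at the quantity that is actually being clipped, one could compute it as in the following sketch (not part of the original code):\n",
    "\n",
    "```python\n",
    "def global_grad_norm(model):\n",
    "    # Global L2 norm over all parameter gradients\n",
    "    # (this is the quantity that clip_grad_norm_ constrains)\n",
    "    grads = [p.grad.flatten() for p in model.parameters() if p.grad is not None]\n",
    "    return torch.norm(torch.cat(grads))\n",
    "\n",
    "print(global_grad_norm(model))\n",
    "```"
   ]
  },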
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "734f30e6-6b24-4d4b-ae91-e9a4b871113f",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Applying gradient clipping, we can see that the largest gradient is now substantially smaller:"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 12,
							 
						 
					
						
							
								
									
										
										
										
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "id": "fa81ef8b-4280-400f-a93e-5210f3e62ff0",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
									
										
										
										
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "tensor(0.0185)\n"
							 
						 
					
						
							
								
									
										
										
										
										
									 
								 
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(find_highest_gradient(model))"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
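  {
   "cell_type": "markdown",
   "id": "2f8b4d60-9a1c-4e73-8b5d-6c0e3a9f2d08",
   "metadata": {},
   "source": [
    "- As a usage note, `clip_grad_norm_` also returns the total gradient norm computed before clipping, which can be convenient for logging; for example:\n",
    "\n",
    "```python\n",
    "total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)\n",
    "print(f\"Gradient norm before clipping: {total_norm:.4f}\")\n",
    "```"
   ]
  },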
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "b62c2af0-dac3-4742-be4b-4292c6753099",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "## D.4 The modified training function"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "76715332-94ec-4185-922a-75cb420819d5",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Now let's add the three concepts covered above (learning rate warmup, cosine decay, and gradient clipping) to the `train_model_simple` function covered in chapter 5 to create the more sophisticated `train_model` function below:"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 13,
							 
						 
					
						
							
								
									
										
										
										
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "id": "46eb9c84-a293-4016-a523-7ad726e171e9",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "from previous_chapters import evaluate_model, generate_and_print_sample\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "def train_model(model, train_loader, val_loader, optimizer, device, n_epochs,\n",
							 
						 
					
						
							
								
									
										
										
										
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "                eval_freq, eval_iter, start_context, tokenizer, warmup_steps,\n",
							 
						 
					
						
							
								
									
										
										
										
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "                initial_lr=3e-05, min_lr=1e-6):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    train_losses, val_losses, track_tokens_seen, track_lrs = [], [], [], []\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    tokens_seen, global_step = 0, -1\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    # Retrieve the maximum learning rate from the optimizer\n",
							 
						 
					
						
							
								
									
										
										
										
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "    peak_lr = optimizer.defaults[\"lr\"]\n",
							 
						 
					
						
							
								
									
										
										
										
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    # Calculate the total number of iterations in the training process\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    total_training_steps = len(train_loader) * n_epochs\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    # Calculate the learning rate increment during the warmup phase\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    lr_increment = (peak_lr - initial_lr) / warmup_steps\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    for epoch in range(n_epochs):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        model.train()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        for input_batch, target_batch in train_loader:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            optimizer.zero_grad()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            global_step += 1\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            # Adjust the learning rate based on the current phase (warmup or cosine annealing)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            if global_step < warmup_steps:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                # Linear warmup\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                lr = initial_lr + global_step * lr_increment  \n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            else:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                # Cosine annealing after warmup\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                progress = ((global_step - warmup_steps) / \n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                            (total_training_steps - warmup_steps))\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                lr = min_lr + (peak_lr - min_lr) * 0.5 * (1 + math.cos(math.pi * progress))\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            # Apply the calculated learning rate to the optimizer\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            for param_group in optimizer.param_groups:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                param_group[\"lr\"] = lr\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            track_lrs.append(lr)  # Store the current learning rate\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            # Calculate and backpropagate the loss\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            loss = calc_loss_batch(input_batch, target_batch, model, device)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            loss.backward()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            # Apply gradient clipping after the warmup phase to avoid exploding gradients\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            if global_step > warmup_steps:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            \n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            optimizer.step()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            tokens_seen += input_batch.numel()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            # Periodically evaluate the model on the training and validation sets\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            if global_step % eval_freq == 0:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                train_loss, val_loss = evaluate_model(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                    model, train_loader, val_loader,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                    device, eval_iter\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                )\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                train_losses.append(train_loss)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                val_losses.append(val_loss)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                track_tokens_seen.append(tokens_seen)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                # Print the current losses\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                print(f\"Ep {epoch+1} (Iter {global_step:06d}): \"\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                      f\"Train loss {train_loss:.3f}, Val loss {val_loss:.3f}\")\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # Generate and print a sample from the model to monitor progress\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        generate_and_print_sample(\n",
							 
						 
					
						
							
								
									
										
										
										
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "            model, tokenizer, device, start_context\n",
							 
						 
					
						
							
								
									
										
										
										
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "        )\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    return train_losses, val_losses, track_tokens_seen, track_lrs"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "55fcd247-ba9d-4b93-a757-0f7ce04fee41",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
								      "Ep 1 (Iter 000000): Train loss 10.934, Val loss 10.939\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-17 20:30:42 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "Ep 1 (Iter 000005): Train loss 9.151, Val loss 9.461\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-11 07:07:36 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								      "Every effort moves you,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,\n",
      "Ep 2 (Iter 000010): Train loss 7.949, Val loss 8.184\n",
      "Ep 2 (Iter 000015): Train loss 6.362, Val loss 6.876\n",
								      "Every effort moves you,,,,,,,,,,,,,,,,,,, the,,,,,,,,, the,,,,,,,,,,, the,,,,,,,,\n",
      "Ep 3 (Iter 000020): Train loss 5.851, Val loss 6.607\n",
      "Ep 3 (Iter 000025): Train loss 5.750, Val loss 6.634\n",
								      "Every effort moves you. \"I\"I and I had to the to the to the and the of the to the of the to Gisburn, and the of the the of the of the to the to the of the of the of the to the of\n",
      "Ep 4 (Iter 000030): Train loss 5.225, Val loss 6.944\n",
      "Ep 4 (Iter 000035): Train loss 4.304, Val loss 6.512\n",
								      "Every effort moves you know   \"--and--and--I                 \", and, and, and, and I had been, and, and \" it.   \n",
      "Ep 5 (Iter 000040): Train loss 3.736, Val loss 6.383\n",
								      "Every effort moves you know the picture to have the picture--his--his, the donkey of a little: \"strong, with a little of the donkey, in the picture--as, with a little of his painting, the donkey, the donkey, with a little\n",
      "Ep 6 (Iter 000045): Train loss 2.395, Val loss 6.244\n",
      "Ep 6 (Iter 000050): Train loss 2.948, Val loss 6.279\n",
								      "Every effort moves you?\"     I, and he had a little the in a flash that he was a little the fact, and in the picture. Gisburn's my unexpected discovery; and as I had the picture--the. He was his\n",
      "Ep 7 (Iter 000055): Train loss 2.316, Val loss 6.169\n",
      "Ep 7 (Iter 000060): Train loss 1.003, Val loss 6.343\n",
								      "Every effort moves you?\"  \"Yes--I glanced after him, so inevitably the last word. Gisburn's past! The women had been his pictures I remember getting off a prodigious phrase about the honour being _mine_--because he didn't say\n",
      "Ep 8 (Iter 000065): Train loss 0.860, Val loss 6.348\n",
      "Ep 8 (Iter 000070): Train loss 1.117, Val loss 6.375\n",
								      "Every effort moves you?\" \"I that my hostess was \"interesting\": on that point I could have given Miss Croft the fact, and Mrs. \"I must have Jack himself, I had again run over from Monte Carlo; and Mrs. Gis\n",
      "Ep 9 (Iter 000075): Train loss 0.367, Val loss 6.498\n",
      "Ep 9 (Iter 000080): Train loss 0.289, Val loss 6.612\n",
								      "Every effort moves you?\" \" on--forming, as it were, so inevitably the background of the house.\"  \" went on groping and muddling; then I looked at the donkey again. I may be pardoned the bull--that I found\n",
      "Ep 10 (Iter 000085): Train loss 0.263, Val loss 6.700\n",
								      "Every effort moves you?\"  \"Yes--quite insensible to the irony. She wanted him vindicated--and by me!\"  He laughed again, and threw back his head to look up at the sketch of the donkey. \"There were days when I\n",
      "Ep 11 (Iter 000090): Train loss 0.151, Val loss 6.788\n",
      "Ep 11 (Iter 000095): Train loss 0.097, Val loss 6.805\n",
								      "Every effort moves you?\"  \"Yes--quite insensible to the irony. She wanted him vindicated--and by me!\"  He laughed again, and threw back his head to look up at the sketch of the donkey. \"There were days when I\n",
      "Ep 12 (Iter 000100): Train loss 0.081, Val loss 6.832\n",
      "Ep 12 (Iter 000105): Train loss 0.089, Val loss 6.900\n",
								      "Every effort moves you?\"  \"Yes--quite insensible to the irony. She wanted him vindicated--and by me!\"  He laughed again, and threw back his head to look up at the sketch of the donkey. \"There were days when I\n",
      "Ep 13 (Iter 000110): Train loss 0.045, Val loss 6.911\n",
      "Ep 13 (Iter 000115): Train loss 0.047, Val loss 6.903\n",
								      "Every effort moves you?\"  \"Yes--quite insensible to the irony. She wanted him vindicated--and by me!\"  He laughed again, and threw back his head to look up at the sketch of the donkey. \"There were days when I\n",
      "Ep 14 (Iter 000120): Train loss 0.038, Val loss 6.907\n",
      "Ep 14 (Iter 000125): Train loss 0.040, Val loss 6.912\n",
								      "Every effort moves you?\"  \"Yes--quite insensible to the irony. She wanted him vindicated--and by me!\"  He laughed again, and threw back his head to look up at the sketch of the donkey. \"There were days when I\n",
      "Ep 15 (Iter 000130): Train loss 0.041, Val loss 6.915\n",
								      "Every effort moves you?\"  \"Yes--quite insensible to the irony. She wanted him vindicated--and by me!\"  He laughed again, and threw back his head to look up at the sketch of the donkey. \"There were days when I\n"
     ]
    }
   ],
   "source": [
    "import tiktoken\n",
    "\n",
    "torch.manual_seed(123)\n",
    "model = GPTModel(GPT_CONFIG_124M)\n",
    "model.to(device)\n",
    "\n",
    "peak_lr = 5e-4\n",
								    "optimizer = torch.optim.AdamW(model.parameters(), weight_decay=0.1)\n",
							 
						 
					
						
							
								
									
										
										
										
    "tokenizer = tiktoken.get_encoding(\"gpt2\")\n",
    "\n",
    "n_epochs = 15\n",
    "train_losses, val_losses, tokens_seen, lrs = train_model(\n",
    "    model, train_loader, val_loader, optimizer, device, n_epochs=n_epochs,\n",
    "    eval_freq=5, eval_iter=1, start_context=\"Every effort moves you\",\n",
    "    tokenizer=tokenizer, warmup_steps=warmup_steps, \n",
    "    initial_lr=1e-5, min_lr=1e-5\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "827e8d5e-0872-4b90-98ac-200c80ee2d53",
   "metadata": {},
   "source": [
    "- Looking at the results above, we can see that the model starts out generating incomprehensible strings of words, whereas, towards the end, it's able to produce grammatically more or less correct sentences\n",
								    "- If we were to check a few passages it writes towards the end, we would find that they are contained in the training set verbatim -- it simply memorizes the training data\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Note that the overfitting here occurs because we have a very, very small training set, and we iterate over it so many times\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "  - The LLM training here primarily serves educational purposes; we mainly want to see that the model can learn to produce coherent text\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "  - Instead of spending weeks or months on training this model on vast amounts of expensive hardware, we load the pretrained weights"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
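  {
   "cell_type": "markdown",
   "id": "memorization-check-note",
   "metadata": {},
   "source": [
    "- As a quick check of the memorization point above, we can test whether a phrase from the last generated sample occurs verbatim in the raw training text -- a minimal sketch that assumes the raw text is still available in a variable such as `text_data` from the earlier data-loading step (adjust the name if it differs in your setup)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "memorization-check-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Minimal memorization check (sketch): `text_data` is assumed to hold the raw training\n",
    "# text loaded earlier; the phrase below is taken from the last sample printed above\n",
    "sample_phrase = \"She wanted him vindicated--and by me!\"\n",
    "print(sample_phrase in text_data)"
   ]
  },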
  {
   "cell_type": "markdown",
   "id": "9decec45-4fdf-4ff6-85a7-1806613f8af7",
   "metadata": {},
   "source": [
    "- A quick check that the learning rate behaves as intended"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "d8ebb8d2-8308-4a83-a2a6-730c3bf84452",
   "metadata": {},
   "outputs": [
    {
     "data": {
								      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAfAAAAEmCAYAAACdy8LUAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAABUvUlEQVR4nO3deVxUVf8H8M8Mw8ywzQyIzIDKYu6IyiIjWppBYVJupUmkiCZm9mjZolbqr3p6XOspfcylTKzMrXIJNSM0V3YUFxAtcUMHBGSGfZk5vz+QqUlURme4M/B9v173RXPv9858D9p8veeeew6PMcZACCGEEKvC5zoBQgghhBiPCjghhBBihaiAE0IIIVaICjghhBBihaiAE0IIIVaICjghhBBihaiAE0IIIVaICjghhBBihQRcJ9Ca6XQ6XL9+HU5OTuDxeFynQwghhAOMMZSVlcHDwwN8vumum6mAm9H169fRqVMnrtMghBBiAa5evYqOHTua7P2ogJuRk5MTgIY/NIlEwnE2hBBCuKDRaNCpUyd9TTAVKuBm1NhtLpFIqIATQkgbZ+pbqTSIjRBCCLFCVMAJIYQQK0QFnBBCCLFCnBfwVatWwdvbG2KxGEqlEqmpqfeM3759O3r06AGxWAw/Pz/s3bvX4DhjDAsWLIC7uzvs7OwQFhaGCxcuGMR8/PHHGDhwIOzt7SGTyZr8nCtXriAiIgL29vZwc3PD22+/jfr6+odqKyGEEGIqnBbwrVu3Yvbs2Vi4cCEyMzPRt29fhIeHo7CwsMn448ePIzIyElOmTMGJEycwatQojBo1CmfOnNHHLF26FCtWrMCaNWuQkpICBwcHhIeHo7q6Wh9TW1uLsWPHYvr06U1+jlarRUREBGpra3H8+HFs3LgRcXFxWLBggWl/AYQQQsiDYhwKDg5mM2bM0L/WarXMw8ODLVq0qMn4cePGsYiICIN9SqWSTZs2jTHGmE6nYwqFgi1btkx/vLS0lIlEIrZ58+Y73m/Dhg1MKpXesX/v3r2Mz+czlUql37d69WomkUhYTU1Ns9unVqsZAKZWq5t9DiGEkNbFXLWAs8fIamtrkZGRgXnz5un38fl8hIWFISkpqclzkpKSMHv2bIN94eHh2LlzJwAgLy8PKpUKYWFh+uNSqRRKpRJJSUkYP358s3JLSkqCn58f5HK5wedMnz4dZ8+ehb+/f5Pn1dTUoKamRv9ao9E06/PIg/nfgQvYe1oFkS0fYoGN/qe90AbtJSLIncRwk4ggl4ghdxLDQyaGwIbzu0aEEGISnBXwoqIiaLVagyIJAHK5HOfOnWvyHJVK1WS8SqXSH2/cd7eY5rjb5/z9M5qyaNEifPDBB83+HPLgispr8GnCeehY888RCvh4pL0juskd0U3uhK5ujvDrKIW71M58iRJCiJnQRC4mNG/ePIMegsbZd4jp7Tujgo4BPRROmP1kN1TX61Bdp0VNnRblNVrcLKtBQVk1CjXVKNDUQKWpRm29Djk3NMi5Ydgz0kFmh/7ezgjydkF/bxd0dXMEn09z1xNCLBtnBdzV1RU2NjYoKCgw2F9QUACFQtHkOQqF4p7xjT8LCgrg7u5uENOvX79m56ZQKO4YDd/4uXfLDQBEIhFEIlGzP4c8uPis6wCA5wI64infu/+ZNNLpGK7dqsL5gjKcLyzDeVUZcgvKcb6gDPmlVcg/WYWdJxve09neFkO7uyG0pxyDu7nCSWxr1rYQQsiD4KyAC4VCBAYGIjExEaNGjQLQsHpXYmIiXnvttSbPCQkJQWJiIl5//XX9voSEBISEhAAAfHx8oFAokJiYqC/YGo0GKSkpdx1xfrfP+fjjj1FYWAg3Nzf950gkEvTq1cv4xhKTKtRUI/VSCQBgeB/3+0Q34PN58GxnD8929gjr9dftkYqaepy8WorUvBKkXy7BiSuluFVZh59O5OOnE/mwteFB6dMOYT3dMLyPO9ycxGZpEyGEGIvTLvTZs2cjOjoaQUFBCA4OxmeffYaKigrExMQAACZOnIgOHTpg0aJFAIBZs2ZhyJAh+OSTTxAREYEtW7YgPT0d69atA9Awz+zrr7+Of//73+jatSt8fHwwf/58eHh46P+RADQ8411SUoIrV65Aq9Xi5MmTAIAuXbrA0dERTz31FHr16oUJEyZg6dKlUKlUeP/99zFjxgy6wrYAe0/fAGNAgKcMHWQPd//aQSTAoC6uGNTFFQBQp9Uh4/ItJOYUIDGnEBeLKnD0jyIc/aMIH8Zn47Gu7TEmoAOe6qWAndDGFM0hhJAHY9Ix7Q9g5cqVzNPTkwmFQhYcHMySk5P1x4YMGcKio6MN4rdt28a6devGhEIh8/X1ZXv27DE4rtPp2Pz585lcLmcikYiFhoay3Nxcg5jo6GgG4I7t4MGD+phLly6xp59+mtnZ2TFXV1f25ptvsrq6OqPaRo+Rmcfzq48xrznxbP2Ri2b/rD8Ly9iXh/9ko1YdZV5z4vWb74Jf2JvbTrL0S8VMp9OZPQ9CiPUyVy3gMcaMGMdLjKHRaCCVSqFWq2k1MhO5oa5CyKID4PGApLmhUEhbrks7r6gCO07kY8eJa7haUqXf79dBikkDvfFMX3eIBHRVTggxZK5aQAXcjKiAm976o3n4KD4bwd4u2PZKCCc5MMaQfvkWtqZdxe6s66it1wEAXB2FeDHYEy8N8IKbhO6VE0IamKsW0KwWxKrEn2oYKR7RzMFr5sDj8dDf2wXLx/ZF8rxQvB3eHQqJGEXltVhx4A88uvQg5u88g/zSqvu/GSGEPCC6AjcjugI3rWu3KvHokoPg8YCUd0MtakR4nVaHX88W4Otjeci4fAsAYGvDw/OBnfDq44+gk4s9xxkSQrhCV+Ckzdt7+gYAQOnjYlHFGwBsbfiI6OOOH14JweapAxDSuR3qtAybU6/g8eW/450fsnCdrsgJISZEBZxYjT2nGgr4M308OM7k7ng8HkIeaYfNsQOwbVoIHuvqCq2OYVv6NQxd/juW/nIOmuo6rtMkhLQCVMCJVbhSXImsa2rwecCw3vefec0SBPu44NspSvw4fSCCfVxQU6/DF7//iceX/Y6Nxy+hTqvjOkVCiBWjAk6swp7b3ecDH3GFq6N1TaYT6OWMrbEDsG5CIDq3d0BJRS0W7j6Lp/57GAfPFXKdHiHESlEBJ1bBEkafPwwej4enfBXY//pgfDSqN1wdhcgrqkBMXBqmfZtOI9YJIUajAk4sXl5RBc5e18CGz8OwZixcYslsbfiYMMALB996HFMf84ENn4f9ZwsQ9skhrD30J3WrE0KajQo4sXh7bl99D+riCmcHIcfZmIaT2BbvRfTCnpmPIsjLGVV1Wizadw4RK44g/fZCLYQQci9UwInFi9ePPrfO7vN76aGQYNu0ECx7vg9cHIQ4X1COsWuT8OHP2aiq1XKdHiHEglEBJxbtj8IynFOVwdaGh/Be1t19fjd8Pg9jgzrhwJtDMDawIxgDvj6Wh+ErjiCNrsYJIXdBBZxYtMar78e6tofU3pbjbMxLZi/
EsrF9sSGmPxQSMfKKKjCOrsYJIXdBBZxYtD2tuPv8boZ2d8P+NwZjXNBfV+MRK47g9DU116kRQiwIFXBisXJVZbhQWA6hDR9hveRcp9OipHa2WPr8X1fjF4sqMGb1Maw7/Cd0Olq+gBBCBZxYsMbR50O6t4dE3Lq7z+9maHc3/PL6Yxjmq0CdluE/e88hekMqCjXVXKdGCOEYFXBikRhjrXr0uTFk9kKsfikAi8b4QWzLx5ELRRj2+REk5hRwnRohhENUwIlFyrlRhotFFRAJ+Ajt2ba6z5vC4/EQGeyJ+H89hl7uEpRU1GLKxnQs2puDepr8hZA2iQo4sUiNU6cO7e4GR5GA42wsRxc3R+yYMRCTB/kAANYevogXv0pBYRl1qRPS1lABJxaHMaZfvOSZvm27+7wpIoENFjzbC19EBcBRJEBqXgkiVhxFah49M05IW0IFnFicM/kaXC6uhJ2tDZ7o4cZ1OhZruJ87dr02CN3kjrhZVoPIL5Px5eGLYIxGqRPSFlABJxYn/nRD9/kTPd1gL6Tu83t5pL0
      "text/plain": [
       "<Figure size 500x300 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.figure(figsize=(5, 3))\n",
    "plt.plot(range(len(lrs)), lrs)\n",
    "plt.ylabel(\"Learning rate\")\n",
    "plt.xlabel(\"Steps\")\n",
    "plt.show()"
   ]
  },
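  {
   "cell_type": "markdown",
   "id": "lr-schedule-check-note",
   "metadata": {},
   "source": [
    "- In addition to the plot, a small numeric sanity check of the schedule -- a minimal sketch that assumes the `lrs`, `n_epochs`, and `train_loader` variables from the cells above; the scheduler records one learning rate per training step, and the values should stay within the configured range"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "lr-schedule-check-sketch",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Numeric companion to the plot above (sketch): assumes `lrs`, `n_epochs`, and\n",
    "# `train_loader` from the cells above; one learning rate is recorded per step\n",
    "total_steps = n_epochs * len(train_loader)\n",
    "print(\"Recorded learning rates:\", len(lrs), \"expected steps:\", total_steps)\n",
    "print(\"Min learning rate:\", min(lrs))\n",
    "print(\"Max learning rate:\", max(lrs))"
   ]
  },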
  {
   "cell_type": "markdown",
   "id": "a2f85b01-859b-4454-a3a3-c7ef593735a6",
   "metadata": {},
   "source": [
    "- And a quick look at the loss curves"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "445d8155-6eae-4b50-a381-d0820ebc27cc",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/var/folders/jg/tpqyh1fd5js5wsr1d138k3n40000gn/T/ipykernel_36856/3589549395.py:5: UserWarning: The figure layout has changed to tight\n",
      "  plt.tight_layout(); plt.savefig(\"3.pdf\")\n"
     ]
    },
    {
     "data": {
								      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAekAAAEiCAYAAADd4SrgAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAABZdUlEQVR4nO3deXxM1/vA8c9Mlsm+iMiCxBYiQexKqrTyFapaVKlfqnRTO9WqamvpglqqipbSb/m21dKNqrWhtqp9V8QWu4g1q2wz5/fHMDH2Jcmdief9et3XzD13e06Geeaee+49OqWUQgghhBA2R691AEIIIYS4OUnSQgghhI2SJC2EEELYKEnSQgghhI2SJC2EEELYKEnSQgghhI2SJC2EEELYKEnSQgghhI2SJC2EEELYKEnSQtiwI0eOoNPp2L59u9ahCCE0IElaiEKm0+luOw0fPlzrEIUQNspR6wCEKO5Onz5teT9nzhyGDh1KQkKCpczDw0OLsIQQdkDOpIUoZIGBgZbJ29sbnU5nmS9VqhTjx4+nTJkyGAwGatasyZIlS265L6PRyMsvv0x4eDjHjh0D4Pfff6d27dq4uLhQoUIFPvjgA/Ly8izb6HQ6vv76a9q2bYubmxthYWHMnz/fsvzixYvExcXh7++Pq6srYWFhzJgx45Yx/PLLL1SvXh1XV1f8/PyIiYkhIyPDsvzrr7+matWquLi4EB4ezpdffmm1/fHjx+nQoQM+Pj6UKFGCZ555hiNHjliWd+3alTZt2jBu3DiCgoLw8/OjV69e5Obm3vXfXIhiQwkhisyMGTOUt7e3ZX78+PHKy8tL/fjjj2rfvn3q7bffVk5OTmr//v1KKaUSExMVoLZt26aysrJU27ZtVa1atVRycrJSSqnVq1crLy8vNXPmTHXo0CH1559/qnLlyqnhw4dbjgGoMmXKqB9++EEdOHBA9e3bV3l4eKjz588rpZTq1auXqlmzptq0aZNKTExU8fHxav78+TeN/9SpU8rR0VGNHz9eJSYmqp07d6ovvvhCpaWlKaWU+v7771VQUJD69ddf1eHDh9Wvv/6qSpQooWbOnKmUUionJ0dVrVpVvfzyy2rnzp1qz5496v/+7/9UlSpVVHZ2tlJKqS5duigvLy/VvXt3tXfvXvXHH38oNzc3NW3atIL9MISwA5KkhShC1yfp4OBgNWLECKt16tWrp3r27KmUyk/Sa9asUc2aNVOPPvqounTpkmXdZs2aqZEjR1pt/91336mgoCDLPKDef/99y3x6eroC1OLFi5VSSrVu3Vq99NJLdxX/li1bFKCOHDly0+UVK1ZUP/zwg1XZRx99pBo2bGiJrUqVKspkMlmWZ2dnK1dXV7V06VKllDlJh4aGqry8PMs6zz33nOrYseNdxShEcSLXpIXQSGpqKqdOnSI6OtqqPDo6mh07dliVderUiTJlyvDXX3/h6upqKd+xYwdr165lxIgRljKj0UhWVhaZmZm4ubkBUKNGDctyd3d3vLy8SE5OBqBHjx48++yzbN26lebNm9OmTRsaNWp005ijoqJo1qwZ1atXJzY2lubNm9O+fXt8fX3JyMjg0KFDvPLKK7z22muWbfLy8vD29rbEe/DgQTw9Pa32m5WVxaFDhyzzkZGRODg4WOaDgoLYtWvXbf6aQhRPkqSFsANPPvkk33//PevWreOJJ56wlKenp/PBBx/Qrl27G7ZxcXGxvHdycrJaptPpMJlMALRs2ZKjR4+yaNEi4uPjadasGb169WLcuHE37NPBwYH4+Hj++ecf/vzzTyZNmsR7773Hhg0bLD8Ipk+fToMGDW7Y7mq8derUYdasWTfs29/f/67iFeJhIklaCI14eXkRHBzM2rVradKkiaV87dq11K9f32rdHj16UK1aNZ5++mkWLlxoWb927dokJCRQqVKlB4rF39+fLl260KVLFxo3bszAgQNvmqTBnDCjo6OJjo5m6NChhIaGMnfuXAYMGEBwcDCHDx8mLi7uptvWrl2bOXPmUKpUKby8vB4oZiEeBpKkhdDQwIEDGTZsGBUrVqRmzZrMmDGD7du33/RMs0+fPhiNRp566ikWL17Mo48+ytChQ3nqqacICQmhffv26PV6duzYwe7du/n444/vKoahQ4dSp04dIiMjyc7OZsGCBVStWvWm627YsIHly5fTvHlzSpUqxYYNGzh79qxl/Q8++IC+ffvi7e1NixYtyM7OZvPmzVy8eJEBAwYQFxfH2LFjeeaZZ/jwww8pU6YMR48e5bfffuPtt9+mTJky9//HFKIYkiQthIb69u1LSkoKb775JsnJyURERDB//nzCwsJuun7//v0xmUw8+eSTLFmyhNjYWBYsWMCHH37I6NGjcXJyIjw8nFdfffWuY3B2dmbw4MEcOXIEV1dXGjduzOzZs2+6rpeXF6tXr2bChAmkpqYSGhrKp59+SsuWLQF49dVXcXNzY+zYsQwcOBB3d3eqV69O//79AXBzc2P16tUMGjSIdu3akZaWRunSpWnWrJmcWQtxEzqllNI6CCGEEELcSB5mIoQQQtgoSdJCCCGEjZIkLYQQQtgoSdJCCCGEjZIkLYQQQtgoSdJCCCGEjZIkfRtffPEF5cqVw8XFhQYNGrBx40ZN41m9ejWtW7cmODgYnU7HvHnzrJYrpRg6dChBQUG4uroSExPDgQMHrNa5cOECcXFxeHl54ePjwyuvvEJ6errVOjt37qRx48a4uLhQtmxZxowZc0MsP//8M+Hh4bi4uFC9enUWLVr0QHUbNWoU9erVw9PTk1KlStGmTRurMZfB/HznXr164efnh4eHB88++yxnzpyxWufYsWO0atUKNzc3SpUqxcCBA62GbQRYuXIltWvXxmAwUKlSJWbOnHlDPAX52U+ZMoUaNWrg5eWFl5cXDRs2ZPHixXZfr5v55JNP0Ol0lvui7bl+w4cPR6fTWU3h4eF2X6+rTp48yQsvvICfnx+urq5Ur16dzZs3W5bb6/dJuXLlbvjcdDodvXr1Auzwc9N2fA/bNXv2bOXs7Ky++eYb9e+//6rXXntN+fj4qDNnzmgW06JFi9R7772nfvvtNwWouXPnWi3/5JNPlLe3t5o3b57asWOHevrpp1X58uXV5cuXLeu0aNFCRUVFqfXr16s1a9aoSpUqqU6dOlmWp6SkqICAABUXF6d2796tfvzxR+Xq6qq++uoryzpr165VDg4OasyYMWrPnj3q/fffV05OTmrXrl33XbfY2Fg1Y8YMtXv3brV9+3b15JNPqpCQEJWenm5Zp3v37qps2bJq+fLlavPmzeqRRx5RjRo1sizPy8tT1apVUzExMWrbtm1q0aJFqmTJkmrw4MGWdQ4fPqzc3NzUgAED1J49e9SkSZOUg4ODWrJkiWWdgv7s58+frxYuXKj279+vEhIS1LvvvqucnJzU7t277bpe19u4caMqV66cqlGjhurXr5+l3F7rN2zYMBUZGalOnz5tmc6ePWv39VJKqQsXLqjQ0FDVtWtXtWHDBnX48GG1dOl
SdfDgQcs69vp9kpycbPWZxcfHK0CtWLFCKWV/n5sk6VuoX7++6tWrl2XeaDSq4OBgNWrUKA2jynd9kjaZTCowMFCNHTvWUnbp0iVlMBjUjz/+qJRSas+ePQpQmzZtsqyzePFipdPp1MmTJ5VSSn355ZfK19fXMravUkoNGjRIValSxTLfoUMH1apVK6t4GjRooF5//fUCq19ycrIC1KpVqyx1cXJyUj///LNlnb179ypArVu3Till/hGj1+tVUlKSZZ0pU6YoLy8vS33efvttFRkZaXWsjh07qtjYWMt8UXz2vr6+6uuvvy429UpLS1NhYWEqPj5eNWnSxJKk7bl+w4YNU1FRUTddZs/1Usr8f/rRRx+95fLi9H3Sr18/VbFiRWUymezyc5Pm7pvIyclhy5YtxMTEWMr0ej0xMTGsW7dOw8huLTExkaSkJKuYvb29adCggSXmdevW4ePjQ926dS3rxMTEoNfr2bBhg2Wdxx57DGdnZ8s6sbGxJCQkcPHiRcs61x7n6joF+bdJSUkBoESJEgBs2bKF3Nx
      "text/plain": [
       "<Figure size 500x300 with 2 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from previous_chapters import plot_losses\n",
    "\n",
    "epochs_tensor = torch.linspace(1, n_epochs, len(train_losses))\n",
    "plot_losses(epochs_tensor, tokens_seen, train_losses, val_losses)\n",
    "plt.tight_layout(); plt.savefig(\"3.pdf\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c16fa614-67e1-4254-8b7e-c3e2f690c29c",
   "metadata": {},
   "source": [
    "- Note that the model is overfitting here because the dataset is kept very small for educational purposes (so that the code can be executed on a laptop computer)\n",
    "- For a longer pretraining run on a much larger dataset, see [../../ch05/03_bonus_pretraining_on_gutenberg](../../ch05/03_bonus_pretraining_on_gutenberg)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}