{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "9a5936bd-af17-4a7e-a4d2-e910411708ea",
   "metadata": {},
   "source": [
    "<font size=\"1\">\n",
    "Supplementary code for \"Build a Large Language Model From Scratch\": <a href=\"https://www.manning.com/books/build-a-large-language-model-from-scratch\">https://www.manning.com/books/build-a-large-language-model-from-scratch</a> by <a href=\"https://sebastianraschka.com\">Sebastian Raschka</a><br>\n",
    "Code repository: <a href=\"https://github.com/rasbt/LLMs-from-scratch\">https://github.com/rasbt/LLMs-from-scratch</a>\n",
    "</font>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "af53bcb1-ff9d-49c7-a0bc-5b8d32ff975b",
   "metadata": {},
   "source": [
    "## Appendix D: Adding Bells and Whistles to the Training Loop"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4f58c142-9434-49af-b33a-356b80a45b86",
   "metadata": {},
   "source": [
    "- In this appendix, we add a few more advanced features to the training function that are commonly used in pretraining and finetuning (finetuning is covered in chapters 6 and 7)\n",
    "- The next three sections discuss learning rate warmup, cosine decay, and gradient clipping\n",
    "- The final section adds these techniques to the training function"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "744def4f-c03f-42ee-97bb-5d7d5b89b723",
   "metadata": {},
   "source": [
    "- We start by initializing a model, reusing the code from chapter 5:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "8755bd5e-bc06-4e6e-9e63-c7c82b816cbe",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch version: 2.2.1\n",
      "tiktoken version: 0.5.1\n"
     ]
    }
   ],
   "source": [
    "from importlib.metadata import version\n",
    "import torch\n",
    "import tiktoken\n",
    "\n",
    "print(\"torch version:\", version(\"torch\"))\n",
    "print(\"tiktoken version:\", version(\"tiktoken\"))\n",
    "\n",
    "\n",
    "from previous_chapters import GPTModel\n",
    "\n",
    "GPT_CONFIG_124M = {\n",
    "    \"vocab_size\": 50257,  # Vocabulary size\n",
    "    \"ctx_len\": 256,       # Shortened context length (orig: 1024)\n",
    "    \"emb_dim\": 768,       # Embedding dimension\n",
    "    \"n_heads\": 12,        # Number of attention heads\n",
    "    \"n_layers\": 12,       # Number of layers\n",
    "    \"drop_rate\": 0.1,     # Dropout rate\n",
    "    \"qkv_bias\": False     # Query-key-value bias\n",
    "}\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "torch.manual_seed(123)\n",
    "model = GPTModel(GPT_CONFIG_124M)\n",
    "model.eval();  # Disable dropout during inference"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "51574e57-a098-412c-83e8-66dafa5a0b99",
   "metadata": {},
   "source": [
    "- Next, using the same code we used in chapter 5, we initialize the data loaders:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "386ca110-2bb4-42f1-bd54-8836df80acaa",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import urllib.request\n",
    "\n",
    "file_path = \"the-verdict.txt\"\n",
    "url = \"https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/main/ch02/01_main-chapter-code/the-verdict.txt\"\n",
    "\n",
    "if not os.path.exists(file_path):\n",
    "    with urllib.request.urlopen(url) as response:\n",
    "        text_data = response.read().decode('utf-8')\n",
    "    with open(file_path, \"w\", encoding=\"utf-8\") as file:\n",
    "        file.write(text_data)\n",
    "else:\n",
    "    with open(file_path, \"r\", encoding=\"utf-8\") as file:\n",
    "        text_data = file.read()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ae96992b-536a-4684-a924-658b9ffb7e9c",
   "metadata": {},
   "outputs": [],
   "source": [
    "from previous_chapters import create_dataloader_v1\n",
    "\n",
    "# Train/validation ratio\n",
    "train_ratio = 0.90\n",
    "split_idx = int(train_ratio * len(text_data))\n",
    "\n",
    "\n",
    "torch.manual_seed(123)\n",
    "\n",
    "train_loader = create_dataloader_v1(\n",
    "    text_data[:split_idx],\n",
    "    batch_size=2,\n",
    "    max_length=GPT_CONFIG_124M[\"ctx_len\"],\n",
    "    stride=GPT_CONFIG_124M[\"ctx_len\"],\n",
    "    drop_last=True,\n",
    "    shuffle=True\n",
    ")\n",
    "\n",
    "val_loader = create_dataloader_v1(\n",
    "    text_data[split_idx:],\n",
    "    batch_size=2,\n",
    "    max_length=GPT_CONFIG_124M[\"ctx_len\"],\n",
    "    stride=GPT_CONFIG_124M[\"ctx_len\"],\n",
    "    drop_last=False,\n",
    "    shuffle=False\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "939c08d8-257a-41c6-b842-019f7897ac74",
   "metadata": {},
   "source": [
    "## D.1 Learning rate warmup"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7fafcd30-ddf7-4a9f-bcf4-b13c052b3133",
   "metadata": {},
   "source": [
    "- When training complex models like LLMs, implementing learning rate warmup can help stabilize the training\n",
    "- In learning rate warmup, we gradually increase the learning rate from a very low value (`initial_lr`) to a user-specified maximum (`peak_lr`)\n",
    "- This way, the model starts training with small weight updates, which helps decrease the risk of large, destabilizing updates during training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "2bb4790b-b8b6-4e9e-adf4-704a04b31ddf",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "135\n"
     ]
    }
   ],
   "source": [
    "n_epochs = 15\n",
    "peak_lr = 0.01\n",
    "initial_lr = 0.0001\n",
    "\n",
    "optimizer = torch.optim.AdamW(model.parameters(), lr=peak_lr, weight_decay=0.1)\n",
    "total_training_steps = len(train_loader) * n_epochs\n",
    "\n",
    "print(total_training_steps)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5bf3a8da-abc4-4b80-a5d8-f1cc1c7cc5f3",
   "metadata": {},
   "source": [
    "- Typically, the number of warmup steps is between 10% and 20% of the total number of steps\n",
    "- We can compute the increment as the difference between `peak_lr` and `initial_lr`, divided by the number of warmup steps (see the formula below)"
   ]
  },
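  {
   "cell_type": "markdown",
   "id": "1c2d3e4f-5a6b-4c7d-8e9f-0a1b2c3d4e5f",
   "metadata": {},
   "source": [
    "- Concretely, the warmup schedule implemented in the next code cell can be written as\n",
    "\n",
    "$$\\mathrm{lr}_t = \\mathrm{initial\\_lr} + t \\cdot \\frac{\\mathrm{peak\\_lr} - \\mathrm{initial\\_lr}}{\\mathrm{warmup\\_steps}} \\quad \\text{for } t = 0, 1, \\ldots, \\mathrm{warmup\\_steps} - 1$$"
   ]
  },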
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "e075f80e-a398-4809-be1d-8019e1d31c90",
   "metadata": {},
   "outputs": [],
   "source": [
    "warmup_steps = 20\n",
    "lr_increment = (peak_lr - initial_lr) / warmup_steps\n",
    "\n",
    "global_step = -1\n",
    "track_lrs = []\n",
    "\n",
    "for epoch in range(n_epochs):\n",
    "    for input_batch, target_batch in train_loader:\n",
    "        optimizer.zero_grad()\n",
    "        global_step += 1\n",
    "\n",
    "        if global_step < warmup_steps:\n",
    "            lr = initial_lr + global_step * lr_increment\n",
    "        else:\n",
    "            lr = peak_lr\n",
    "\n",
    "        # Apply the calculated learning rate to the optimizer\n",
    "        for param_group in optimizer.param_groups:\n",
    "            param_group[\"lr\"] = lr\n",
    "        track_lrs.append(optimizer.param_groups[0][\"lr\"])\n",
    "\n",
    "        # Calculate loss and update weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "cb6da121-eeed-4023-bdd8-3666c594b4ed",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "plt.ylabel(\"Learning rate\")\n",
    "plt.xlabel(\"Step\")\n",
    "plt.plot(range(total_training_steps), track_lrs);"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7b3996b6-3f7a-420a-8584-c5760249f3d8",
   "metadata": {},
   "source": [
    "## D.2 Cosine decay"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c5216214-de79-40cf-a733-b1049a73023c",
   "metadata": {},
   "source": [
    "- Another popular technique for training complex deep neural networks is cosine decay, which also adjusts the learning rate across training epochs\n",
    "- In cosine decay, the learning rate follows a cosine curve, decreasing from the peak value to near zero following a half-cosine cycle (see the formula below)\n",
    "- This gradual reduction is designed to slow the pace of learning as the model begins to improve its weights; it reduces the risk of overshooting minima as training progresses, which is crucial for stabilizing the training in its later stages\n",
    "- Cosine decay is often preferred over linear decay for its smoother transition in learning rate adjustments, but linear decay is also used in practice (for example, [OLMo: Accelerating the Science of Language Models](https://arxiv.org/abs/2402.00838))"
   ]
  },
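  {
   "cell_type": "markdown",
   "id": "2d3e4f5a-6b7c-4d8e-9f0a-1b2c3d4e5f6a",
   "metadata": {},
   "source": [
    "- After the warmup phase, the code below lowers the learning rate at global step $t$ according to\n",
    "\n",
    "$$\\mathrm{lr}_t = \\mathrm{min\\_lr} + \\frac{1}{2}\\left(\\mathrm{peak\\_lr} - \\mathrm{min\\_lr}\\right)\\left(1 + \\cos\\left(\\pi \\cdot \\frac{t - \\mathrm{warmup\\_steps}}{\\mathrm{total\\_training\\_steps} - \\mathrm{warmup\\_steps}}\\right)\\right)$$\n",
    "\n",
    "- so the cosine argument sweeps from $0$ to $\\pi$, and the learning rate falls smoothly from `peak_lr` to `min_lr`"
   ]
  },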
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "4e8d2068-a057-4abf-b478-f02cc37191f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "min_lr = 0.1 * initial_lr\n",
    "track_lrs = []\n",
    "\n",
    "lr_increment = (peak_lr - initial_lr) / warmup_steps\n",
    "global_step = -1\n",
    "\n",
    "for epoch in range(n_epochs):\n",
    "    for input_batch, target_batch in train_loader:\n",
    "        optimizer.zero_grad()\n",
    "        global_step += 1\n",
    "\n",
    "        # Adjust the learning rate based on the current phase (warmup or cosine annealing)\n",
    "        if global_step < warmup_steps:\n",
    "            # Linear warmup\n",
    "            lr = initial_lr + global_step * lr_increment\n",
    "        else:\n",
    "            # Cosine annealing after warmup\n",
    "            progress = ((global_step - warmup_steps) / \n",
    "                        (total_training_steps - warmup_steps))\n",
    "            lr = min_lr + (peak_lr - min_lr) * 0.5 * (1 + math.cos(math.pi * progress))\n",
    "\n",
    "        # Apply the calculated learning rate to the optimizer\n",
    "        for param_group in optimizer.param_groups:\n",
    "            param_group[\"lr\"] = lr\n",
    "        track_lrs.append(optimizer.param_groups[0][\"lr\"])\n",
    "\n",
    "        # Calculate loss and update weights"
   ]
  },
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "execution_count": 8,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "0e779e33-8a44-4984-bb23-be0603dc4158",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "data": {
							 
						 
					
						
							
								
									
										
										
										
											2024-03-16 08:10:58 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAlMAAAGwCAYAAACNeeBZAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAABiTUlEQVR4nO3deVhUZf8G8HsWZoZ12GRHwaVcAFFQxDRbKDJbSDM1SzPTMtewLM20fO1nWpaZmlqWVppmb5mZWYaVG6IIIoi4JIqCwyoM+zJzfn8go7yigSxnlvtzXXP5cuaZ4TvnNbh9zvc8j0QQBAFEREREdFukYhdAREREZMoYpoiIiIiagWGKiIiIqBkYpoiIiIiagWGKiIiIqBkYpoiIiIiagWGKiIiIqBnkYhdgqvR6PbKysmBvbw+JRCJ2OURERNQIgiCguLgYXl5ekEpbZk6JYeo2ZWVlwdfXV+wyiIiI6DZcvHgRPj4+LfJeDFO3yd7eHkDt/xkODg4iV0NERESNodVq4evra/g93hIYpm5T3aU9BwcHhikiIiIT05ItOmxAJyIiImoGhikiIiKiZmCYIiIiImoGhikiIiKiZmCYIiIiImoGhikiIiKiZmCYIiIiImoGhikiIiKiZmCYIiIiImoGhikiIiKiZhA9TK1cuRJ+fn5QqVQICwvD4cOHbzl+69at6Nq1K1QqFQIDA7Fz5856z//www948MEH4eLiAolEgmPHjt3wHhUVFZg8eTJcXFxgZ2eHYcOGITs7uyU/FhEREVkIUcPUli1bEB0djfnz5yMhIQE9e/ZEZGQkcnJyGhx/8OBBjBo1CuPHj0diYiKioqIQFRWFlJQUw5jS0lIMGDAAixcvvun3feWVV/Dzzz9j69at+Pvvv5GVlYWhQ4e2+OcjIiIi8ycRBEEQ65uHhYWhT58+WLFiBQBAr9fD19cXU6dOxRtvvHHD+BEjRqC0tBQ7duwwHOvXrx+Cg4OxevXqemPPnz8Pf39/JCYmIjg42HC8qKgI7dq1w6ZNm/Dkk08CANLS0tCtWzfExsaiX79+japdq9VCrVajqKiIGx23oqKyapRV10AulcJKJoHKSgalXNqiG1QSEZHlaI3f3/IWeZfbUFVVhaNHj2L27NmGY1KpFBEREYiNjW3wNbGxsYiOjq53LDIyEtu2bWv09z169Ciqq6sRERFhONa1a1e0b9/+lmGqsrISlZWVhq+1Wm2jvyfdnoP/5OGZz+Og/5+4byWTwE4ph4O1FVztlGhnp0Q7eyW8HK3h62yN9s426OBiC7W1lTiFExGRRREtTOXl5UGn08Hd3b3ecXd3d6SlpTX4Go1G0+B4jUbT6O+r0WigUCjg6OjYpPdZtGgR3nnnnUZ/H2q+P1JzoBcAiQS4fv60WifgSlk1rpRV40J+2U1f72avxB3u9rjD3R4B3g4I8lHD39UOMilntYiIqOWIFqZMzezZs+vNimm1Wvj6+opYkflLziwEAHzwZE8M7e2NGr2AimodSiprUFxRg6LyauQVVyK3pBLZ2gpkXinHxSvluFhQhpziSsNj/9k8w3vaKmQIbu+Ivn4u6OPvhN7tnaCykon0CYmIyByIFqZcXV0hk8luuIsuOzsbHh4eDb7Gw8OjSeNv9h5VVVUoLCysNzv1b++jVCqhVCob/X2oeXR6ASmZtZdSg3zUkEgksJJJYCWTwl5lBU/1rV9fXFGNszklOJNdgpMaLZIvFeFElhalVTocOJuPA2fzAQAKuRR9/Zxx9x2uuPuOdrjT3Z79WERE1CSihSmFQoGQkBDExMQgKioKQG0DekxMDKZMmdLga8LDwxETE4MZM2YYju3evRvh4eGN/r4hISGwsrJCTEwMhg0bBgA4deoUMjIymvQ+1LrO5ZagvFoHG4UMHdvZNfn19ior9GrvhF7tnQzHanR6nMkpQfyFKziSXoDD6QXQaCuw/2we9p/Nw//tTIOPkzUe7O6BB3u4o4+fMy8JEhHRvxL1Ml90dDTGjh2L0NBQ9O3bF8uWLUNpaSnGjRsHABgzZgy8vb2xaNEiAMD06dMxaNAgLF26FEOGDMHmzZsRHx+PtWvXGt6zoKAAGRkZyMrKAlAblIDaGSkPDw+o1WqMHz8e0dHRcHZ2hoODA6ZOnYrw8PBG38lHre/4pSIAQA8vhxYLNHKZFN08HdDN0wHP9usAQRBwLq8Uf5/Kxd4zuTh0Lh+XrpTjiwPp+OJAOlztFHgkyAuP9vRC7/aOnLEiIqIGiRqmRowYgdzcXMybNw8ajQbBwcHYtWuXock8IyMDUum1pbD69++PTZs2Ye7cuZgzZw66dOmCbdu2ISAgwDBm+/bthjAGACNHjgQAzJ8/H2+//TYA4KOPPoJUKsWwYcNQWVmJyMhIrFq1qg0+MTVWcmZtmAr0dmy17yGRSNCpnR06tbPD8wP8UV6lw94zufj9RDZi0rKRV1KF9QfPY/3B8/B1tsaTvX3xZKgPvB2tW60mIiIyPaKuM2XKuM5U6xq66gASMgqxbEQwonp5t/n3r9bpsf9MHrYnZeG3ExqUVekA1N5ZOLBLOzwT1h73d3PnZUAiIhNjVutMEd1MjU6P1Mu1zeeBPv/Sad5KrGRS3NvVDfd2dUN5lQ67TlzGd0cuIfZcPvaezsXe07no4GKD5/r7YXioL+yU/E+JiMhScWbqNnFmqvWkabR4aNk+2CnlOD7/QUiNaPbnQn4pvj18Ed8ezkBReTUAwF4px1N9fPFcfz/4OtuIXCEREd1Ka/z+Fn2jY6L/dX3zuTEFKQDo4GKLNwZ3Rezs+7AwKgAd29miuLIG6/anY9D7f+Klr48i5Wq/FxERWQZemyCjk3w1TAWJdImvMWwUcjzTrwOe7tsef5/JxRf707HvTB52ndBg1wkNInu445UH7kBXD85aEhGZO4YpMjqGO/l8HMUtpBGkUgnuvdMN997phtPZxVj151n8lJSF305k47cT2RgS5IlXIrqgs5u92KUSEVEr4WU+MirV1zefexvvzFRD7nC3x7KRvfD7jLsxJMgTAPDL8ct44KO9mLE5EedyS0SukIiIWgPDFBmV09nFqKrRw14lRwcTbebu4m6PlU/3xq/TByKyhzsEAdh2LAsPfrQX7/x8AkVl1WKXSERELYhhioxKimGxTrXRNZ83VTdPB6x5NhQ7pg7AfV3dUKMX8OWB87jngz/xzaEL0Ol5Iy0RkTlgmCKjUncnn1jrS7WGAG81vniuD74e3xdd3Oxwpawac7elYMjyfTj4T57Y5RERUTMxTJFRSb5uZsrcDOzSDr9OH4h3HusBtbUV0jTFePqzOLz09VFcLCgTuzwiIrpNDFNkNKpq9Ei7XAwACGrFPfnEJJdJMba/H/569R6MCe8AqQTYdUKDBz76G5/vO8dLf0REJohhiozG6exiVOn0UFtbwdfZvDcTdrJVYMHjAfh1+t0I83dGRbUeC385iaGrDuDk1bsZiYjINDBMkdEw9Et5qyGRmHbzeWPd6WGPzRP7YdHQQNir5Ei6VIR
HP9mPpb+fQmWNTuzyiIioERimyGgkZxYCMK/m88aQSCQY1bc9/ogehAe7u6NGL+CTPWfx8Mf7EH++QOzyiIjoXzBMkdGom5kKMsPm88Zwd1BhzbMh+HR0b7jaKfFPbimGr4nFgp9TUVHNWSoiImPFMEVGoaJah9PZtc3nljYzdT2JRILBgZ6IiR6Ep0J9IAjAFwfSEbXyAE5pisUuj4iIGsAwRUbhlKYY1ToBTjZW8HY07+bzxlDbWGHJkz2xbmwoXGwVSNMU49EV+/HlgXQIAu/4IyIyJgxTZBSOX7e5saU0nzfG/d3csWvG3bj3znaoqtHjnZ9TMfbLI8jRVohdGhERXcUwRUYh+VIhAMvtl7qVdvZKfPFcHyx4vAeUcin2ns7FQx/vw+8nNGKXRkREYJgiI5GcWbu2kiX3S92KRCLBmHA/7Jg6AN08HVBQWoWJXx/Fgp9TUa3Ti10eEZFFY5gi0dVrPufM1C11cbfHtsn9MWGgP4Da5vSRaw9BU8TLfkREYmGYItGlXtZCpxfgaqeAp1oldjlGTym
							 
						 
					
						
							
								
									
										
										
										
											2024-03-11 07:07:36 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								      "text/plain": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "<Figure size 640x480 with 1 Axes>"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "display_data"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "plt.ylabel(\"Learning rate\")\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "plt.xlabel(\"Step\")\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "plt.plot(range(total_training_steps), track_lrs);"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "e7512808-b48d-4146-86a1-5931b1e3aec1",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "## D.3 Gradient clipping"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "c0a74f76-8d2b-4974-a03c-d645445cdc21",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Gradient clipping is yet another technique used to stabilize the training when training LLMs\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- By setting a threshold, gradients exceeding this limit are scaled down to a maximum magnitude to ensure that the updates to the model's parameters during backpropagation remain within a manageable range\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- For instance, using the `max_norm=1.0` setting in PyTorch's `clip_grad_norm_` method means that the norm of the gradients is clipped such that their maximum norm does not exceed 1.0\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- the \"norm\" refers to a measure of the gradient vector's length (or magnitude) in the parameter space of the model\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Specifically, it's the L2 norm, also known as the Euclidean norm\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Mathematically, for a vector $\\mathbf{v}$ with components $\\mathbf{v} = [v_1, v_2, \\ldots, v_n]$, the L2 norm is defined as:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "$$\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\\| \\mathbf{v} \\|_2 = \\sqrt{v_1^2 + v_2^2 + \\ldots + v_n^2}\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "$$"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
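  {
   "cell_type": "markdown",
   "id": "l2-norm-sanity-check-md",
   "metadata": {},
   "source": [
    "- As a quick numerical sanity check of the formula above, we can compare a manual calculation against PyTorch's built-in `torch.linalg.norm`, which defaults to the L2 norm for vectors (a minimal illustration using an arbitrary example vector):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "l2-norm-sanity-check-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# An arbitrary example vector, chosen so the L2 norm is easy to verify by hand\n",
    "v = torch.tensor([3.0, 4.0])\n",
    "\n",
    "# Manual calculation: sqrt(3^2 + 4^2) = sqrt(25) = 5\n",
    "print(torch.sqrt((v**2).sum()))\n",
    "\n",
    "# Built-in equivalent\n",
    "print(torch.linalg.norm(v))"
   ]
  },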
  {
   "cell_type": "markdown",
   "id": "d44838a6-4322-47b2-a935-c00d3a88355f",
   "metadata": {},
   "source": [
								    "- The L2 norm is calculated similarly for matrices.\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Let's assume our gradient matrix is:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "$$\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "G = \\begin{bmatrix}\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "1 & 2 \\\\\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "2 & 4\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\\end{bmatrix}\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "$$\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- And we want to clip these gradients with a `max_norm` of 1.\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- First, we calculate the L2 norm of these gradients:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "$$\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\\|G\\|_2 = \\sqrt{1^2 + 2^2 + 2^2 + 4^2} = \\sqrt{25} = 5\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "$$\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Since $\\|G\\|_2 = 5$ is greater than our `max_norm` of 1, we need to scale down the gradients so that their norm is exactly 1. The scaling factor is calculated as $\\frac{max\\_norm}{\\|G\\|_2} = \\frac{1}{5}$.\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Therefore, the scaled gradient matrix $G'$ will be as follows:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "$$\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "G' = \\frac{1}{5} \\times G = \\begin{bmatrix}\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\\frac{1}{5} & \\frac{2}{5} \\\\\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\\frac{2}{5} & \\frac{4}{5}\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\\end{bmatrix}\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "$$"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
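  {
   "cell_type": "markdown",
   "id": "clipping-by-hand-md",
   "metadata": {},
   "source": [
    "- We can verify this hand calculation numerically; for matrices, `torch.linalg.norm` computes the same element-wise L2 (Frobenius) norm by default (a minimal illustration):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "clipping-by-hand-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# The example gradient matrix from above\n",
    "G = torch.tensor([[1.0, 2.0], [2.0, 4.0]])\n",
    "\n",
    "max_norm = 1.0\n",
    "norm = torch.linalg.norm(G)  # sqrt(1 + 4 + 4 + 16) = 5\n",
    "print(norm)\n",
    "\n",
    "# Scale the matrix so that its norm equals max_norm\n",
    "G_clipped = G * (max_norm / norm)\n",
    "print(G_clipped)  # [[0.2, 0.4], [0.4, 0.8]]\n",
    "print(torch.linalg.norm(G_clipped))  # 1.0"
   ]
  },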
  {
   "cell_type": "markdown",
   "id": "eeb0c3c1-2cff-46f5-8127-24412184428c",
   "metadata": {},
   "source": [
    "- Let's see this in action\n",
    "- First, we initialize a new model and calculate the loss for a training batch like we would do in the regular training loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "e199e1ff-58c4-413a-855e-5edbe9292649",
   "metadata": {},
   "outputs": [],
   "source": [
    "from previous_chapters import calc_loss_batch\n",
    "\n",
    "torch.manual_seed(123)\n",
    "model = GPTModel(GPT_CONFIG_124M)\n",
    "\n",
    "loss = calc_loss_batch(input_batch, target_batch, model, device)\n",
    "loss.backward()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "76b60f3a-15ec-4846-838d-fdef3df99899",
   "metadata": {},
   "source": [
								    "- If we call `.backward()`, PyTorch will calculate the gradients and store them in a `.grad` attribute for each weight (parameter) matrix\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Let's define a utility function to calculate the highest gradient based on all model weights"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "e70729a3-24d1-411d-a002-2529cd3a8a9e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(0.0373)\n"
     ]
    }
   ],
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "def find_highest_gradient(model):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    max_grad = None\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    for param in model.parameters():\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        if param.grad is not None:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            grad_values = param.grad.data.flatten()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            max_grad_param = grad_values.max()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            if max_grad is None or max_grad_param > max_grad:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                max_grad = max_grad_param\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    return max_grad\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(find_highest_gradient(model))"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
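  {
   "cell_type": "markdown",
   "id": "total-grad-norm-md",
   "metadata": {},
   "source": [
    "- Note that `find_highest_gradient` reports the single largest gradient entry, whereas `clip_grad_norm_` (used below) decides whether to clip based on the total L2 norm over all parameter gradients combined; the small helper below (`compute_total_grad_norm`, added here for illustration) computes that total norm:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "total-grad-norm-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_total_grad_norm(model):\n",
    "    # Total L2 norm over all parameter gradients combined; this is the\n",
    "    # quantity that torch.nn.utils.clip_grad_norm_ compares against max_norm\n",
    "    total = 0.0\n",
    "    for param in model.parameters():\n",
    "        if param.grad is not None:\n",
    "            total += (param.grad**2).sum().item()\n",
    "    return total**0.5\n",
    "\n",
    "print(compute_total_grad_norm(model))"
   ]
  },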
  {
   "cell_type": "markdown",
   "id": "734f30e6-6b24-4d4b-ae91-e9a4b871113f",
   "metadata": {},
   "source": [
    "- After applying gradient clipping, we can see that the largest gradient is now substantially smaller:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "fa81ef8b-4280-400f-a93e-5210f3e62ff0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(0.0166)\n"
     ]
    }
   ],
   "source": [
    "torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)\n",
    "print(find_highest_gradient(model))"
   ]
  },
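  {
   "cell_type": "markdown",
   "id": "clip-grad-norm-sketch-md",
   "metadata": {},
   "source": [
    "- Under the hood, `clip_grad_norm_` performs essentially the same scaling we computed by hand earlier; the function below is a simplified sketch of that idea (L2 norm only, not the actual PyTorch implementation):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "clip-grad-norm-sketch-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "def simple_clip_grad_norm_(parameters, max_norm):\n",
    "    # Compute the total L2 norm over all gradients ...\n",
    "    grads = [p.grad for p in parameters if p.grad is not None]\n",
    "    total_norm = torch.sqrt(sum((g**2).sum() for g in grads))\n",
    "    # ... and scale every gradient down in place if it exceeds max_norm\n",
    "    if total_norm > max_norm:\n",
    "        for g in grads:\n",
    "            g.mul_(max_norm / total_norm)\n",
    "    return total_norm\n",
    "\n",
    "\n",
    "# Equivalent in effect to the built-in call above:\n",
    "# simple_clip_grad_norm_(list(model.parameters()), max_norm=1.0)"
   ]
  },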
  {
   "cell_type": "markdown",
   "id": "b62c2af0-dac3-4742-be4b-4292c6753099",
   "metadata": {},
   "source": [
    "## D.4 The modified training function"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "76715332-94ec-4185-922a-75cb420819d5",
   "metadata": {},
   "source": [
    "- Now let's add the three concepts covered above (learning rate warmup, cosine decay, and gradient clipping) to the `train_model_simple` function covered in chapter 5 to create the more sophisticated `train_model` function below:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "46eb9c84-a293-4016-a523-7ad726e171e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "from previous_chapters import evaluate_model, generate_and_print_sample\n",
    "\n",
    "\n",
    "def train_model(model, train_loader, val_loader, optimizer, device, n_epochs,\n",
    "                eval_freq, eval_iter, start_context, warmup_steps=10,\n",
    "                initial_lr=3e-05, min_lr=1e-6):\n",
    "\n",
    "    train_losses, val_losses, track_tokens_seen, track_lrs = [], [], [], []\n",
    "    tokens_seen, global_step = 0, -1\n",
    "\n",
    "    # Retrieve the maximum learning rate from the optimizer\n",
    "    peak_lr = optimizer.param_groups[0][\"lr\"]\n",
    "\n",
    "    # Calculate the total number of iterations in the training process\n",
    "    total_training_steps = len(train_loader) * n_epochs\n",
    "\n",
    "    # Calculate the learning rate increment during the warmup phase\n",
    "    lr_increment = (peak_lr - initial_lr) / warmup_steps\n",
    "\n",
    "    for epoch in range(n_epochs):\n",
    "        model.train()\n",
    "        for input_batch, target_batch in train_loader:\n",
    "            optimizer.zero_grad()\n",
    "            global_step += 1\n",
    "\n",
    "            # Adjust the learning rate based on the current phase (warmup or cosine annealing)\n",
    "            if global_step < warmup_steps:\n",
    "                # Linear warmup\n",
    "                lr = initial_lr + global_step * lr_increment\n",
    "            else:\n",
    "                # Cosine annealing after warmup\n",
    "                progress = ((global_step - warmup_steps) /\n",
    "                            (total_training_steps - warmup_steps))\n",
    "                lr = min_lr + (peak_lr - min_lr) * 0.5 * (1 + math.cos(math.pi * progress))\n",
    "\n",
    "            # Apply the calculated learning rate to the optimizer\n",
    "            for param_group in optimizer.param_groups:\n",
    "                param_group[\"lr\"] = lr\n",
    "            track_lrs.append(lr)  # Store the current learning rate\n",
    "\n",
    "            # Calculate and backpropagate the loss\n",
    "            loss = calc_loss_batch(input_batch, target_batch, model, device)\n",
    "            loss.backward()\n",
    "\n",
    "            # Apply gradient clipping after the warmup phase to avoid exploding gradients\n",
    "            if global_step >= warmup_steps:\n",
    "                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)\n",
    "\n",
    "            optimizer.step()\n",
    "            tokens_seen += input_batch.numel()\n",
    "\n",
    "            # Periodically evaluate the model on the training and validation sets\n",
    "            if global_step % eval_freq == 0:\n",
    "                train_loss, val_loss = evaluate_model(\n",
    "                    model, train_loader, val_loader,\n",
    "                    device, eval_iter\n",
    "                )\n",
    "                train_losses.append(train_loss)\n",
    "                val_losses.append(val_loss)\n",
    "                track_tokens_seen.append(tokens_seen)\n",
    "                # Print the current losses\n",
    "                print(f\"Ep {epoch+1} (Iter {global_step:06d}): \"\n",
    "                      f\"Train loss {train_loss:.3f}, Val loss {val_loss:.3f}\")\n",
    "\n",
    "        # Generate and print a sample from the model to monitor progress\n",
    "        generate_and_print_sample(\n",
    "            model, train_loader.dataset.tokenizer,\n",
    "            device, start_context\n",
    "        )\n",
    "\n",
    "    return train_losses, val_losses, track_tokens_seen, track_lrs"
   ]
  },
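  {
   "cell_type": "markdown",
   "id": "lambdalr-alternative-md",
   "metadata": {},
   "source": [
    "- As an aside, the same warmup-plus-cosine schedule can also be expressed with PyTorch's built-in `torch.optim.lr_scheduler.LambdaLR`, which avoids setting `param_group[\"lr\"]` by hand; the sketch below is one possible way to do this (note that, unlike `train_model` above, it decays to 0 rather than to `min_lr`):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "lambdalr-alternative-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "import torch\n",
    "\n",
    "\n",
    "def make_warmup_cosine_scheduler(optimizer, warmup_steps, total_steps):\n",
    "    # The returned factor multiplies the optimizer's base (peak) learning\n",
    "    # rate; call scheduler.step() once per training batch\n",
    "    def lr_lambda(step):\n",
    "        if step < warmup_steps:\n",
    "            return (step + 1) / warmup_steps  # linear warmup\n",
    "        progress = (step - warmup_steps) / (total_steps - warmup_steps)\n",
    "        return 0.5 * (1 + math.cos(math.pi * progress))  # cosine decay to 0\n",
    "    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)\n",
    "\n",
    "\n",
    "# Example usage:\n",
    "# scheduler = make_warmup_cosine_scheduler(\n",
    "#     optimizer, warmup_steps=10, total_steps=len(train_loader) * n_epochs)"
   ]
  },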
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "55fcd247-ba9d-4b93-a757-0f7ce04fee41",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Ep 1 (Iter 000000): Train loss 10.914, Val loss 10.940\n",
      "Ep 1 (Iter 000005): Train loss 8.903, Val loss 9.313\n",
      "Every effort moves you,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,\n",
      "Ep 2 (Iter 000010): Train loss 7.362, Val loss 7.789\n",
      "Ep 2 (Iter 000015): Train loss 6.273, Val loss 6.814\n",
      "Every effort moves you,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,\n",
      "Ep 3 (Iter 000020): Train loss 5.958, Val loss 6.609\n",
      "Ep 3 (Iter 000025): Train loss 5.675, Val loss 6.592\n",
      "Every effort moves you.                                                 \n",
      "Ep 4 (Iter 000030): Train loss 5.607, Val loss 6.565\n",
      "Ep 4 (Iter 000035): Train loss 5.063, Val loss 6.483\n",
      "Every effort moves you, and, and the to the to the to the to the to the, and, and the, and the, and, and, and the, and the, and, and the, and, and, and the, and, and the\n",
      "Ep 5 (Iter 000040): Train loss 4.384, Val loss 6.379\n",
      "Every effort moves you, I was, and I had been.                   \"I, I had the picture, as a little's his pictures, I had been, I was his\n",
      "Ep 6 (Iter 000045): Train loss 4.638, Val loss 6.306\n",
      "Ep 6 (Iter 000050): Train loss 3.690, Val loss 6.196\n",
      "Every effort moves you know the to me a little of his pictures--I had been.  \"I was the's--and, I felt to see a little of his pictures--I had been. \"I of Jack's \"strong. \"I\n",
      "Ep 7 (Iter 000055): Train loss 3.157, Val loss 6.148\n",
      "Ep 7 (Iter 000060): Train loss 2.498, Val loss 6.157\n",
      "Every effort moves you know it was not that, and he was to the fact of the of a and he was--his, the fact of the donkey, in the of the his head to have.   \"I had been his pictures--and by his\n",
      "Ep 8 (Iter 000065): Train loss 2.182, Val loss 6.178\n",
      "Ep 8 (Iter 000070): Train loss 1.998, Val loss 6.193\n",
      "Every effort moves you know,\" was not that my dear, his pictures--so handsome, in a so that he was a year after Jack's resolve had been his painting.     \"Oh, I had the donkey. \"There were, with his\n",
      "Ep 9 (Iter 000075): Train loss 1.824, Val loss 6.211\n",
      "Ep 9 (Iter 000080): Train loss 1.742, Val loss 6.201\n",
      "Every effort moves you know,\" was not that my hostess was \"interesting\": on that point I could have given Miss Croft the fact, and.         \"Oh, as I turned, and down the room, in his\n",
      "Ep 10 (Iter 000085): Train loss 1.285, Val loss 6.234\n",
      "Every effort moves you?\"  \"Yes--quite insensible to the fact with a little: \"Yes--and by me to me to have to see a smile behind his close grayish beard--as if he had the donkey. \"There were days when I\n",
      "Ep 11 (Iter 000090): Train loss 1.256, Val loss 6.236\n",
      "Ep 11 (Iter 000095): Train loss 0.803, Val loss 6.255\n",
      "Every effort moves you?\"  \"Yes--quite insensible to the irony. She wanted him vindicated--and by me!\"  He laughed again, and threw back his head to look up at the sketch of the donkey. \"There were days when I\n",
      "Ep 12 (Iter 000100): Train loss 0.731, Val loss 6.284\n",
      "Ep 12 (Iter 000105): Train loss 0.889, Val loss 6.299\n",
      "Every effort moves you?\"  \"Yes--quite insensible to the irony. She wanted him vindicated--and by me!\"  He laughed again, and threw back his head to look up at the sketch of the donkey. \"There were days when I\n",
      "Ep 13 (Iter 000110): Train loss 0.703, Val loss 6.316\n",
      "Ep 13 (Iter 000115): Train loss 0.517, Val loss 6.315\n",
      "Every effort moves you?\"  \"Yes--quite insensible to the irony. She wanted him vindicated--and by me!\"  He laughed again, and threw back his head to look up at the sketch of the donkey. \"There were days when I\n",
      "Ep 14 (Iter 000120): Train loss 0.594, Val loss 6.324\n",
      "Ep 14 (Iter 000125): Train loss 0.481, Val loss 6.325\n",
      "Every effort moves you?\"  \"Yes--quite insensible to the irony. She wanted him vindicated--and by me!\"  He laughed again, and threw back his head to look up at the sketch of the donkey. \"There were days when I\n",
      "Ep 15 (Iter 000130): Train loss 0.529, Val loss 6.324\n",
      "Every effort moves you?\"  \"Yes--quite insensible to the irony. She wanted him vindicated--and by me!\"  He laughed again, and threw back his head to look up at the sketch of the donkey. \"There were days when I\n"
     ]
    }
   ],
   "source": [
    "torch.manual_seed(123)\n",
    "model = GPTModel(GPT_CONFIG_124M)\n",
    "model.to(device)\n",
    "\n",
    "peak_lr = 5e-4\n",
    "optimizer = torch.optim.AdamW(model.parameters(), lr=peak_lr, weight_decay=0.1)\n",
    "\n",
    "n_epochs = 15\n",
    "train_losses, val_losses, tokens_seen, lrs = train_model(\n",
    "    model, train_loader, val_loader, optimizer, device, n_epochs=n_epochs,\n",
    "    eval_freq=5, eval_iter=1, start_context=\"Every effort moves you\",\n",
    "    warmup_steps=10, initial_lr=1e-5, min_lr=1e-5\n",
    ")"
   ]
  },
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "827e8d5e-0872-4b90-98ac-200c80ee2d53",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Looking at the results above, we can see that the model starts out generating incomprehensible strings of words, whereas, towards the end, it's able to produce grammatically more or less correct sentences\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- If we were to check a few passages it writes towards the end, we would find that they are contained in the training set verbatim -- it simply memorizes the training data\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Note that the overfitting here occurs because we have a very, very small training set, and we iterate over it so many times\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "  - The LLM training here primarily serves educational purposes; we mainly want to see that the model can learn to produce coherent text\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "  - Instead of spending weeks or months on training this model on vast amounts of expensive hardware, we load the pretrained weights"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
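  {
   "cell_type": "markdown",
   "id": "5d7e1f2a-3b4c-4d6e-8f9a-0b1c2d3e4f5a",
   "metadata": {},
   "source": [
    "- A minimal sketch of such a memorization check: test whether a generated passage occurs verbatim in the raw training text (this assumes the raw text is still available as the `text_data` string used to build the data loaders; adjust the variable name if your setup differs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e8f2a3b-4c5d-4e7f-9a0b-1c2d3e4f5a6b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative sketch: substring search over the raw training text\n",
    "generated_passage = \"Every effort moves you\"  # stand-in for a passage the model generated\n",
    "print(generated_passage in text_data)  # True means the passage appears verbatim in the training data"
   ]
  },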
  {
   "cell_type": "markdown",
   "id": "9decec45-4fdf-4ff6-85a7-1806613f8af7",
   "metadata": {},
   "source": [
    "- A quick check that the learning rate behaves as intended"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "d8ebb8d2-8308-4a83-a2a6-730c3bf84452",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.plot(range(len(lrs)), lrs)\n",
    "plt.ylabel(\"Learning rate\")\n",
    "plt.xlabel(\"Steps\")\n",
    "plt.show()"
   ]
  },
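  {
   "cell_type": "markdown",
   "id": "7f9a3b4c-5d6e-4f8a-a1b2-3c4d5e6f7a8b",
   "metadata": {},
   "source": [
    "- Beyond eyeballing the plot, the recorded `lrs` values can also be checked numerically; the optional cell below prints the endpoints of the schedule, which should match the `initial_lr`, `peak_lr`, and `min_lr` settings used above"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a1b2c3d-4e5f-4a6b-8c7d-9e0f1a2b3c4d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional numeric sanity check of the recorded learning rates\n",
    "print(f\"First LR: {lrs[0]:.2e}\")   # should be close to initial_lr (1e-5)\n",
    "print(f\"Peak LR:  {max(lrs):.2e}\")  # should be close to peak_lr (5e-4)\n",
    "print(f\"Last LR:  {lrs[-1]:.2e}\")   # should approach min_lr (1e-5)"
   ]
  },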
  {
   "cell_type": "markdown",
   "id": "a2f85b01-859b-4454-a3a3-c7ef593735a6",
   "metadata": {},
   "source": [
    "- And a quick look at the loss curves"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "445d8155-6eae-4b50-a381-d0820ebc27cc",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<Figure size 640x480 with 2 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from previous_chapters import plot_losses\n",
    "\n",
    "epochs_tensor = torch.linspace(1, n_epochs, len(train_losses))\n",
    "plot_losses(epochs_tensor, tokens_seen, train_losses, val_losses)"
   ]
  },
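  {
   "cell_type": "markdown",
   "id": "1b3c5d7e-9f0a-4b2c-8d4e-6f7a8b9c0d1e",
   "metadata": {},
   "source": [
    "- The gap between the two curves can also be read off numerically; the optional cell below prints the final training and validation losses, and a large difference between them is the signature of the overfitting discussed next"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c4d6e8f-0a1b-4c3d-9e5f-7a8b9c0d1e2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional: quantify the train/validation gap from the recorded losses\n",
    "print(f\"Final training loss:   {train_losses[-1]:.3f}\")\n",
    "print(f\"Final validation loss: {val_losses[-1]:.3f}\")"
   ]
  },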
  {
   "cell_type": "markdown",
   "id": "c16fa614-67e1-4254-8b7e-c3e2f690c29c",
   "metadata": {},
   "source": [
    "- Note that the model is overfitting here because the dataset is kept very small for educational purposes (so that the code can be executed on a laptop computer)\n",
    "- For a longer pretraining run on a much larger dataset, see [../../ch05/03_bonus_pretraining_on_gutenberg](../../ch05/03_bonus_pretraining_on_gutenberg)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}