{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "ba450fb1-8a26-4894-ab7a-5d7bfefe90ce",
   "metadata": {},
   "source": [
    "<table style=\"width:100%\">\n",
    "<tr>\n",
    "<td style=\"vertical-align:middle; text-align:left;\">\n",
    "<font size=\"2\">\n",
    "Supplementary code for the <a href=\"http://mng.bz/orYv\">Build a Large Language Model From Scratch</a> book by <a href=\"https://sebastianraschka.com\">Sebastian Raschka</a><br>\n",
    "<br>Code repository: <a href=\"https://github.com/rasbt/LLMs-from-scratch\">https://github.com/rasbt/LLMs-from-scratch</a>\n",
    "</font>\n",
    "</td>\n",
    "<td style=\"vertical-align:middle; text-align:left;\">\n",
    "<a href=\"http://mng.bz/orYv\"><img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/cover-small.webp\" width=\"100px\"></a>\n",
    "</td>\n",
    "</tr>\n",
    "</table>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "51c9672d-8d0c-470d-ac2d-1271f8ec3f14",
   "metadata": {},
   "source": [
    "# Chapter 5 Exercise solutions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "37aa4692-2357-4d88-b072-6d2d988d7f4f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "numpy version: 1.26.4\n",
      "tiktoken version: 0.7.0\n",
      "torch version: 2.4.0\n",
      "tensorflow version: 2.16.1\n"
     ]
    }
   ],
   "source": [
    "from importlib.metadata import version\n",
    "\n",
    "pkgs = [\"numpy\", \n",
    "        \"tiktoken\", \n",
    "        \"torch\",\n",
    "        \"tensorflow\" # For OpenAI's pretrained weights\n",
    "       ]\n",
    "for p in pkgs:\n",
    "    print(f\"{p} version: {version(p)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5fea8be3-30a1-4623-a6d7-b095c6c1092e",
   "metadata": {},
   "source": [
    "# Exercise 5.1: Temperature-scaled softmax scores and sampling probabilities"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5860ba9f-2db3-4480-b96b-4be1c68981eb",
   "metadata": {},
   "source": [
    "- We can print the number of times the word \"pizza\" is sampled using the `print_sampled_tokens` function we defined in this section\n",
    "- Let's start with the code we defined in section 5.3.1\n",
    "\n",
    "- It is sampled 0 times if the temperature is 1 or 0.1, and it is sampled 32 times if the temperature is scaled up to 5. The estimated probability is 32/1000 * 100% = 3.2%\n",
    "\n",
    "- The actual probability is 4.3% and is contained in the rescaled softmax probability tensor (`scaled_probas[2][6]`)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9cba59c2-a8a3-4af3-add4-70230795225e",
   "metadata": {},
   "source": [
    "- Below is a self-contained example using code from chapter 5:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "42dda298-3014-4c36-8d63-97c210bcf4e8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "vocab = { \n",
    "    \"closer\": 0,\n",
    "    \"every\": 1, \n",
    "    \"effort\": 2, \n",
    "    \"forward\": 3,\n",
    "    \"inches\": 4,\n",
    "    \"moves\": 5, \n",
    "    \"pizza\": 6,\n",
    "    \"toward\": 7,\n",
    "    \"you\": 8,\n",
    "} \n",
    "inverse_vocab = {v: k for k, v in vocab.items()}\n",
    "\n",
    "next_token_logits = torch.tensor(\n",
    "    [4.51, 0.89, -1.90, 6.75, 1.63, -1.62, -1.89, 6.28, 1.79]\n",
    ")\n",
    "\n",
    "def print_sampled_tokens(probas):\n",
    "    torch.manual_seed(123)\n",
    "    sample = [torch.multinomial(probas, num_samples=1).item() for i in range(1_000)]\n",
    "    sampled_ids = torch.bincount(torch.tensor(sample))\n",
    "    for i, freq in enumerate(sampled_ids):\n",
    "        print(f\"{freq} x {inverse_vocab[i]}\")\n",
    "\n",
    "\n",
    "def softmax_with_temperature(logits, temperature):\n",
    "    scaled_logits = logits / temperature\n",
    "    return torch.softmax(scaled_logits, dim=0)\n",
    "\n",
    "\n",
    "temperatures = [1, 0.1, 5]  # Original, lower, and higher temperature\n",
    "scaled_probas = [softmax_with_temperature(next_token_logits, T) for T in temperatures]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1ee0f9f3-4132-42c7-8324-252fd8f59145",
   "metadata": {},
   "source": [
    "- Now, we can iterate over the `scaled_probas` and print the sampling frequencies in each case:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "b5605236-e300-4844-aea7-509d868efbdd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "Temperature: 1\n",
      "73 x closer\n",
      "0 x every\n",
      "0 x effort\n",
      "582 x forward\n",
      "2 x inches\n",
      "0 x moves\n",
      "0 x pizza\n",
      "343 x toward\n",
      "\n",
      "\n",
      "Temperature: 0.1\n",
      "0 x closer\n",
      "0 x every\n",
      "0 x effort\n",
      "985 x forward\n",
      "0 x inches\n",
      "0 x moves\n",
      "0 x pizza\n",
      "15 x toward\n",
      "\n",
      "\n",
      "Temperature: 5\n",
      "165 x closer\n",
      "75 x every\n",
      "42 x effort\n",
      "239 x forward\n",
      "71 x inches\n",
      "46 x moves\n",
      "32 x pizza\n",
      "227 x toward\n",
      "103 x you\n"
     ]
    }
   ],
   "source": [
    "for i, probas in enumerate(scaled_probas):\n",
    "    print(\"\\n\\nTemperature:\", temperatures[i])\n",
    "    print_sampled_tokens(probas)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fbf88c97-19c4-462c-924a-411c8c765d2c",
   "metadata": {},
   "source": [
    "- Note that sampling only provides an approximation of the actual probability that the word \"pizza\" is sampled\n",
    "- E.g., if it is sampled 32/1000 times, the estimated probability is 3.2%\n",
    "- To obtain the actual probability, we can check the probabilities directly by accessing the corresponding entry in `scaled_probas`\n",
    "\n",
    "- Since \"pizza\" is the 7th entry in the vocabulary (index 6), for the temperature of 5 we can obtain its probability as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "1d4163c0-22ad-4f5b-8e20-b7420e9dbfc6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.0430)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "temp5_idx = 2\n",
    "pizza_idx = 6\n",
    "\n",
    "scaled_probas[temp5_idx][pizza_idx]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d3dcb438-5f18-4332-9627-66009f30a1a4",
   "metadata": {},
   "source": [
    "There is a 4.3% probability that the word \"pizza\" is sampled if the temperature is set to 5"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b510ffb0-adca-4d64-8a12-38c4646fd736",
   "metadata": {},
   "source": [
    "# Exercise 5.2: Different temperature and top-k settings"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "884990db-d1a6-4c4e-8e36-2c1e4c1e67c7",
   "metadata": {},
   "source": [
    "- Both temperature and top-k settings have to be adjusted based on the individual LLM (a kind of trial-and-error process until it generates desirable outputs)\n",
    "- The desirable outcomes are also application-specific, though\n",
    "  - Lower top-k and temperatures result in less random outcomes, which is desired when creating educational content, technical writing or question answering, data analyses, code generation, and so forth\n",
    "  - Higher top-k and temperatures result in more diverse and random outputs, which is more desirable for brainstorming tasks, creative writing, and so forth (see the sketch in the next cell)"
   ]
  },
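  {
   "cell_type": "markdown",
   "id": "a7c91f2e-6b3d-4c8a-9f10-5e2d7b4c8a31",
   "metadata": {},
   "source": [
    "- As an illustrative sketch, the two regimes above can be contrasted by calling the `generate` function imported in the Exercise 5.3 cells below with different settings; this assumes that `generate` accepts `temperature` and `top_k` keyword arguments and that `model`, `tokenizer`, `text_to_token_ids`, `token_ids_to_text`, and `GPT_CONFIG_124M` are set up as in those cells:\n",
    "\n",
    "```python\n",
    "# Sketch only; assumes the model/tokenizer setup and the `generate` import\n",
    "# from the Exercise 5.3 cells further below in this notebook\n",
    "prompt_ids = text_to_token_ids(\"Every effort moves you\", tokenizer)\n",
    "\n",
    "# More focused, less random output (e.g., technical writing, code generation)\n",
    "focused_ids = generate(\n",
    "    model=model,\n",
    "    idx=prompt_ids,\n",
    "    max_new_tokens=25,\n",
    "    context_size=GPT_CONFIG_124M[\"context_length\"],\n",
    "    top_k=5,\n",
    "    temperature=0.5\n",
    ")\n",
    "\n",
    "# More diverse, more random output (e.g., brainstorming, creative writing)\n",
    "creative_ids = generate(\n",
    "    model=model,\n",
    "    idx=prompt_ids,\n",
    "    max_new_tokens=25,\n",
    "    context_size=GPT_CONFIG_124M[\"context_length\"],\n",
    "    top_k=50,\n",
    "    temperature=1.5\n",
    ")\n",
    "\n",
    "print(token_ids_to_text(focused_ids, tokenizer))\n",
    "print(token_ids_to_text(creative_ids, tokenizer))\n",
    "```"
   ]
  },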
  {
   "cell_type": "markdown",
   "id": "3f35425d-529d-4179-a1c4-63cb8b25b156",
   "metadata": {},
   "source": [
    "# Exercise 5.3: Deterministic behavior in the decoding functions"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d12229a2-1d52-46ff-b1e8-198f2e58a7d2",
   "metadata": {},
   "source": [
    "There are multiple ways to force deterministic behavior with the `generate` function:\n",
    "\n",
    "1. Setting `temperature=0.0`;\n",
    "2. Setting `top_k=1`."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "391c5dc8-8dd7-4a0a-90bd-519b72f528c7",
   "metadata": {},
   "source": [
    "Below is a self-contained example using code from chapter 5:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "a61a4034-797a-4635-bf42-ddfff1b07125",
   "metadata": {},
   "outputs": [],
   "source": [
    "import tiktoken\n",
    "import torch\n",
    "from previous_chapters import GPTModel\n",
    "\n",
    "\n",
    "GPT_CONFIG_124M = {\n",
    "    \"vocab_size\": 50257,  # Vocabulary size\n",
    "    \"context_length\": 256,       # Shortened context length (orig: 1024)\n",
    "    \"emb_dim\": 768,       # Embedding dimension\n",
    "    \"n_heads\": 12,        # Number of attention heads\n",
    "    \"n_layers\": 12,       # Number of layers\n",
    "    \"drop_rate\": 0.1,     # Dropout rate\n",
    "    \"qkv_bias\": False     # Query-key-value bias\n",
    "}\n",
    "\n",
    "\n",
    "torch.manual_seed(123)\n",
    "\n",
    "tokenizer = tiktoken.get_encoding(\"gpt2\")\n",
    "model = GPTModel(GPT_CONFIG_124M)\n",
    "model.load_state_dict(torch.load(\"model.pth\", weights_only=True))\n",
    "model.eval();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "ee95a272-b852-43b4-9827-ea7e1dbd5724",
   "metadata": {},
   "outputs": [],
   "source": [
    "from gpt_generate import generate, text_to_token_ids, token_ids_to_text\n",
    "from previous_chapters import generate_text_simple"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "4ab43658-3240-484a-9072-a40a0ed85be6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Output text:\n",
      " Every effort moves you know,\" was one of the axioms he laid down across the Sevres and silver of an exquisitely appointed lun\n"
     ]
    }
   ],
   "source": [
    "# Deterministic function that uses torch.argmax\n",
    "\n",
    "start_context = \"Every effort moves you\"\n",
    "\n",
    "token_ids = generate_text_simple(\n",
    "    model=model,\n",
    "    idx=text_to_token_ids(start_context, tokenizer),\n",
    "    max_new_tokens=25,\n",
    "    context_size=GPT_CONFIG_124M[\"context_length\"]\n",
    ")\n",
    "\n",
    "print(\"Output text:\\n\", token_ids_to_text(token_ids, tokenizer))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "ebb22d06-393a-42d3-ab64-66646d33b39b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Output text:\n",
      " Every effort moves you know,\" was one of the axioms he laid down across the Sevres and silver of an exquisitely appointed lun\n"
     ]
    }
   ],
   "source": [
    "# Deterministic behavior: No top_k, no temperature scaling\n",
    "\n",
    "token_ids = generate(\n",
    "    model=model,\n",
    "    idx=text_to_token_ids(\"Every effort moves you\", tokenizer),\n",
    "    max_new_tokens=25,\n",
    "    context_size=GPT_CONFIG_124M[\"context_length\"],\n",
    "    top_k=None,\n",
    "    temperature=0.0\n",
    ")\n",
    "\n",
    "print(\"Output text:\\n\", token_ids_to_text(token_ids, tokenizer))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c85b1f11-37a5-477d-9c2d-170a6865e669",
   "metadata": {},
   "source": [
    "- Note that re-executing the previous code cell will produce the exact same generated text:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "75469f24-47cc-458d-a200-fe64c648131d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Output text:\n",
      " Every effort moves you know,\" was one of the axioms he laid down across the Sevres and silver of an exquisitely appointed lun\n"
     ]
    }
   ],
   "source": [
    "# Deterministic behavior: No top_k, no temperature scaling\n",
    "\n",
    "token_ids = generate(\n",
    "    model=model,\n",
    "    idx=text_to_token_ids(\"Every effort moves you\", tokenizer),\n",
    "    max_new_tokens=25,\n",
    "    context_size=GPT_CONFIG_124M[\"context_length\"],\n",
    "    top_k=None,\n",
    "    temperature=0.0\n",
    ")\n",
    "\n",
    "print(\"Output text:\\n\", token_ids_to_text(token_ids, tokenizer))"
   ]
  },
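  {
   "cell_type": "markdown",
   "id": "added-determinism-note",
   "metadata": {},
   "source": [
    "- A brief illustration of why the outputs above match (added here for clarity; not part of the original solution): with `top_k=None` and `temperature=0.0`, `generate` reduces to picking the highest-logit token, just like `generate_text_simple`, and `torch.argmax` is fully deterministic, so repeated runs yield identical token ids:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "added-determinism-demo",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Minimal sketch (added): greedy argmax selection over some hypothetical next-token logits\n",
    "import torch\n",
    "\n",
    "example_logits = torch.tensor([[1.2, 3.4, 0.7, 2.9]])  # hypothetical logits, not model output\n",
    "first_pick = torch.argmax(example_logits, dim=-1)\n",
    "second_pick = torch.argmax(example_logits, dim=-1)\n",
    "print(first_pick, second_pick)                  # both select index 1\n",
    "print(bool((first_pick == second_pick).all()))  # True: greedy decoding is deterministic"
   ]
  },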
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
  {
   "cell_type": "markdown",
   "id": "6d0480e5-fb4e-41f8-a161-7ac980d71d47",
   "metadata": {},
   "source": [
    "# Exercise 5.4: Continued pretraining"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f40044e8-a0f5-476c-99fd-489b999fd80a",
   "metadata": {},
   "source": [
    "- If we were still in the Python session where we first trained the model in chapter 5, continuing the pretraining for one more epoch would only require loading the model and optimizer that we saved in the main chapter and calling the `train_model_simple` function again\n",
    "\n",
    "- It takes a couple of extra steps to make this reproducible in this new code environment\n",
    "- First, we load the tokenizer, model, and optimizer:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "94eae6ba-d9fd-417a-8e31-fc39e9299870",
   "metadata": {},
   "outputs": [],
   "source": [
    "import tiktoken\n",
    "import torch\n",
    "from previous_chapters import GPTModel\n",
    "\n",
    "\n",
    "GPT_CONFIG_124M = {\n",
    "    \"vocab_size\": 50257,   # Vocabulary size\n",
    "    \"context_length\": 256, # Shortened context length (orig: 1024)\n",
    "    \"emb_dim\": 768,        # Embedding dimension\n",
    "    \"n_heads\": 12,         # Number of attention heads\n",
    "    \"n_layers\": 12,        # Number of layers\n",
    "    \"drop_rate\": 0.1,      # Dropout rate\n",
    "    \"qkv_bias\": False      # Query-key-value bias\n",
    "}\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "tokenizer = tiktoken.get_encoding(\"gpt2\")\n",
    "\n",
    "checkpoint = torch.load(\"model_and_optimizer.pth\", weights_only=True)\n",
    "model = GPTModel(GPT_CONFIG_124M)\n",
    "model.load_state_dict(checkpoint[\"model_state_dict\"])\n",
    "model.to(device)\n",
    "\n",
    "optimizer = torch.optim.AdamW(model.parameters(), lr=0.0004, weight_decay=0.1)\n",
    "optimizer.load_state_dict(checkpoint[\"optimizer_state_dict\"])\n",
    "model.train();"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "688fce4a-9ab2-4d97-a95c-fef02c32b4f3",
   "metadata": {},
   "source": [
    "- Next, we initialize the data loader:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "b5a78470-0652-4abd-875a-664e23c07c36",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import urllib.request\n",
    "from previous_chapters import create_dataloader_v1\n",
    "\n",
    "\n",
    "file_path = \"the-verdict.txt\"\n",
    "url = \"https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/main/ch02/01_main-chapter-code/the-verdict.txt\"\n",
    "\n",
    "if not os.path.exists(file_path):\n",
    "    with urllib.request.urlopen(url) as response:\n",
    "        text_data = response.read().decode('utf-8')\n",
    "    with open(file_path, \"w\", encoding=\"utf-8\") as file:\n",
    "        file.write(text_data)\n",
    "else:\n",
    "    with open(file_path, \"r\", encoding=\"utf-8\") as file:\n",
    "        text_data = file.read()\n",
    "\n",
    "\n",
    "# Train/validation ratio\n",
    "train_ratio = 0.90\n",
    "split_idx = int(train_ratio * len(text_data))\n",
    "train_data = text_data[:split_idx]\n",
    "val_data = text_data[split_idx:]\n",
    "\n",
    "\n",
    "torch.manual_seed(123)\n",
    "\n",
    "train_loader = create_dataloader_v1(\n",
    "    train_data,\n",
    "    batch_size=2,\n",
    "    max_length=GPT_CONFIG_124M[\"context_length\"],\n",
    "    stride=GPT_CONFIG_124M[\"context_length\"],\n",
    "    drop_last=True,\n",
    "    shuffle=True,\n",
    "    num_workers=0\n",
    ")\n",
    "\n",
    "val_loader = create_dataloader_v1(\n",
    "    val_data,\n",
    "    batch_size=2,\n",
    "    max_length=GPT_CONFIG_124M[\"context_length\"],\n",
    "    stride=GPT_CONFIG_124M[\"context_length\"],\n",
    "    drop_last=False,\n",
    "    shuffle=False,\n",
    "    num_workers=0\n",
    ")"
   ]
  },
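  {
   "cell_type": "markdown",
   "id": "added-loader-check-note",
   "metadata": {},
   "source": [
    "- As an optional sanity check (added here for clarity; not part of the original solution), we can peek at one batch to confirm that the loaders defined above yield input/target tensors of shape `(batch_size, context_length)`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "added-loader-check-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sanity check (added); assumes train_loader and val_loader from the cell above\n",
    "for x, y in train_loader:\n",
    "    print(\"Input batch shape:\", x.shape)   # expected: (2, 256)\n",
    "    print(\"Target batch shape:\", y.shape)  # expected: (2, 256)\n",
    "    break\n",
    "\n",
    "print(\"Training batches:\", len(train_loader))\n",
    "print(\"Validation batches:\", len(val_loader))"
   ]
  },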
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
  {
   "cell_type": "markdown",
   "id": "76598ef8-165c-4bcc-af5e-b6fe72398365",
   "metadata": {},
   "source": [
    "- Lastly, we use the `train_model_simple` function to train the model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "ab4693dc-1359-47a7-8110-1e90f514a49e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Ep 1 (Step 000000): Train loss 0.271, Val loss 6.545\n",
      "Ep 1 (Step 000005): Train loss 0.244, Val loss 6.614\n",
      "Every effort moves you?\"  \"Yes--quite insensible to the irony. She wanted him vindicated--and by me!\"  He laughed again, and threw back his head to look up at the sketch of the donkey. \"There were days when I\n"
     ]
    }
   ],
   "source": [
    "from gpt_train import train_model_simple\n",
    "\n",
    "num_epochs = 1\n",
    "train_losses, val_losses, tokens_seen = train_model_simple(\n",
    "    model, train_loader, val_loader, optimizer, device,\n",
    "    num_epochs=num_epochs, eval_freq=5, eval_iter=5,\n",
    "    start_context=\"Every effort moves you\", tokenizer=tokenizer\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3384e788-f5a1-407c-8dd1-87959b75026d",
   "metadata": {},
   "source": [
    "# Exercise 5.5: Training and validation set losses of the pretrained model"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7cb1140b-2027-4156-8d19-600ac849edbe",
   "metadata": {},
   "source": [
    "- We can use the following code to calculate the training and validation set losses of the GPT model:\n",
    "\n",
    "```python\n",
    "train_loss = calc_loss_loader(train_loader, gpt, device)\n",
    "val_loss = calc_loss_loader(val_loader, gpt, device)\n",
    "```\n",
    "\n",
    "- The resulting losses for the 124M parameter model are as follows:\n",
    "\n",
    "```\n",
    "Training loss: 3.754748503367106\n",
    "Validation loss: 3.559617757797241\n",
    "```\n",
    "\n",
    "- The main observation is that the training and validation set performances are in the same ballpark\n",
    "- This can have multiple explanations:\n",
    "\n",
    "1. The Verdict was not part of the pretraining dataset when OpenAI trained GPT-2. Hence, the model is not explicitly overfitting to the training set and performs similarly well on The Verdict's training and validation set portions. (The validation set loss is slightly lower than the training set loss, which is unusual in deep learning. However, it is likely due to random noise, since the dataset is relatively small. In practice, if there is no overfitting, the training and validation set performances are expected to be roughly identical.)\n",
    "\n",
    "2. The Verdict was part of GPT-2's training dataset. In this case, we can't tell whether the model is overfitting the training data because the validation set would have been used for training as well. To evaluate the degree of overfitting, we'd need a new dataset generated after OpenAI finished training GPT-2 to make sure that it couldn't have been part of the pretraining."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "66bb4316-a57c-437f-9a01-fe99b1678524",
   "metadata": {},
   "source": [
    "The code below is a reproducible standalone example for this new notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "68d162d6-bbb9-4d6d-82ee-1c410694f872",
   "metadata": {},
   "outputs": [],
   "source": [
    "import tiktoken\n",
    "import torch\n",
    "from previous_chapters import GPTModel\n",
    "\n",
    "\n",
    "GPT_CONFIG_124M = {\n",
    "    \"vocab_size\": 50257,   # Vocabulary size\n",
    "    \"context_length\": 256, # Shortened context length (orig: 1024)\n",
    "    \"emb_dim\": 768,        # Embedding dimension\n",
    "    \"n_heads\": 12,         # Number of attention heads\n",
    "    \"n_layers\": 12,        # Number of layers\n",
    "    \"drop_rate\": 0.1,      # Dropout rate\n",
    "    \"qkv_bias\": False      # Query-key-value bias\n",
    "}\n",
    "\n",
    "\n",
    "torch.manual_seed(123)\n",
    "\n",
    "tokenizer = tiktoken.get_encoding(\"gpt2\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "d8373461-7dad-47da-a489-3e23f0799b23",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "File already exists and is up-to-date: gpt2/124M/checkpoint\n",
      "File already exists and is up-to-date: gpt2/124M/encoder.json\n",
      "File already exists and is up-to-date: gpt2/124M/hparams.json\n",
      "File already exists and is up-to-date: gpt2/124M/model.ckpt.data-00000-of-00001\n",
      "File already exists and is up-to-date: gpt2/124M/model.ckpt.index\n",
      "File already exists and is up-to-date: gpt2/124M/model.ckpt.meta\n",
      "File already exists and is up-to-date: gpt2/124M/vocab.bpe\n"
     ]
    }
   ],
   "source": [
    "from gpt_download import download_and_load_gpt2\n",
    "\n",
    "settings, params = download_and_load_gpt2(model_size=\"124M\", models_dir=\"gpt2\")"
   ]
  },
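  {
   "cell_type": "markdown",
   "id": "added-settings-params-check",
   "metadata": {},
   "source": [
    "As a quick sanity check, we can peek at the returned `settings` and at the top-level keys of `params` (both are assumed to behave like plain Python dictionaries, as in the main chapter code):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "added-settings-params-check-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sanity check; assumes settings and params are plain dictionaries\n",
    "print(\"Settings:\", settings)\n",
    "print(\"Parameter keys:\", params.keys())"
   ]
  },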
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "cdd44873-d6c2-4471-a20f-f639b09fdcd3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define model configurations in a dictionary for compactness\n",
    "model_configs = {\n",
    "    \"gpt2-small (124M)\": {\"emb_dim\": 768, \"n_layers\": 12, \"n_heads\": 12},\n",
    "    \"gpt2-medium (355M)\": {\"emb_dim\": 1024, \"n_layers\": 24, \"n_heads\": 16},\n",
    "    \"gpt2-large (774M)\": {\"emb_dim\": 1280, \"n_layers\": 36, \"n_heads\": 20},\n",
    "    \"gpt2-xl (1558M)\": {\"emb_dim\": 1600, \"n_layers\": 48, \"n_heads\": 25},\n",
    "}\n",
    "\n",
    "# Copy the base configuration and update with specific model settings\n",
    "model_name = \"gpt2-small (124M)\"  # Example model name\n",
    "NEW_CONFIG = GPT_CONFIG_124M.copy()\n",
    "NEW_CONFIG.update(model_configs[model_name])\n",
    "NEW_CONFIG.update({\"context_length\": 1024, \"qkv_bias\": True})\n",
    "\n",
    "gpt = GPTModel(NEW_CONFIG)\n",
    "gpt.eval();"
   ]
  },
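  {
   "cell_type": "markdown",
   "id": "added-new-config-check",
   "metadata": {},
   "source": [
    "Before loading the pretrained weights, it can be helpful to confirm that `copy()` and `update()` produced the intended merged configuration, i.e., that `context_length` is back at 1024 and `qkv_bias` is set to `True`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "added-new-config-check-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the merged configuration produced by copy() + update()\n",
    "print(NEW_CONFIG)"
   ]
  },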
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "c7d562e4-33f6-4611-9b75-6ad1cb441d3b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from gpt_generate import load_weights_into_gpt\n",
    "\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "load_weights_into_gpt(gpt, params)\n",
    "gpt.to(device);"
   ]
  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "46eda9ea-ccb0-46ee-931b-3c07502b2544",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import urllib.request\n",
    "from previous_chapters import create_dataloader_v1\n",
    "\n",
    "\n",
    "file_path = \"the-verdict.txt\"\n",
    "url = \"https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/main/ch02/01_main-chapter-code/the-verdict.txt\"\n",
    "\n",
    "if not os.path.exists(file_path):\n",
    "    with urllib.request.urlopen(url) as response:\n",
    "        text_data = response.read().decode('utf-8')\n",
    "    with open(file_path, \"w\", encoding=\"utf-8\") as file:\n",
    "        file.write(text_data)\n",
    "else:\n",
    "    with open(file_path, \"r\", encoding=\"utf-8\") as file:\n",
    "        text_data = file.read()\n",
    "\n",
    "\n",
    "# Train/validation ratio\n",
    "train_ratio = 0.90\n",
    "split_idx = int(train_ratio * len(text_data))\n",
    "train_data = text_data[:split_idx]\n",
    "val_data = text_data[split_idx:]\n",
    "\n",
    "\n",
    "torch.manual_seed(123)\n",
    "\n",
    "train_loader = create_dataloader_v1(\n",
    "    train_data,\n",
    "    batch_size=2,\n",
    "    max_length=GPT_CONFIG_124M[\"context_length\"],\n",
    "    stride=GPT_CONFIG_124M[\"context_length\"],\n",
    "    drop_last=True,\n",
    "    shuffle=True,\n",
    "    num_workers=0\n",
    ")\n",
    "\n",
    "val_loader = create_dataloader_v1(\n",
    "    val_data,\n",
    "    batch_size=2,\n",
    "    max_length=GPT_CONFIG_124M[\"context_length\"],\n",
    "    stride=GPT_CONFIG_124M[\"context_length\"],\n",
    "    drop_last=False,\n",
    "    shuffle=False,\n",
    "    num_workers=0\n",
    ")"
   ]
  },
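  {
   "cell_type": "markdown",
   "id": "added-loader-size-check",
   "metadata": {},
   "source": [
    "Optionally, we can check how many batches each loader yields; because `max_length` and `stride` both equal the shortened 256-token context length and the text is a short story, both numbers are expected to be small (this assumes the usual PyTorch `DataLoader` behavior, where `len()` returns the number of batches):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "added-loader-size-check-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of batches produced by each data loader\n",
    "print(\"Training batches:\", len(train_loader))\n",
    "print(\"Validation batches:\", len(val_loader))"
   ]
  },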
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "4e3574a2-687d-47a2-a2f6-457fe9d595f1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training loss: 3.7547486888037787\n",
      "Validation loss: 3.5596182346343994\n"
     ]
    }
   ],
   "source": [
    "from gpt_train import calc_loss_loader\n",
    "\n",
    "torch.manual_seed(123) # For reproducibility due to the shuffling in the data loader\n",
    "train_loss = calc_loss_loader(train_loader, gpt, device)\n",
    "val_loss = calc_loss_loader(val_loader, gpt, device)\n",
    "\n",
    "print(\"Training loss:\", train_loss)\n",
    "print(\"Validation loss:\", val_loss)"
   ]
  },
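  {
   "cell_type": "markdown",
   "id": "added-perplexity-note",
   "metadata": {},
   "source": [
    "Since cross-entropy loss and perplexity are directly related, we can also express the scores above as perplexities via `exp(loss)` (this assumes `calc_loss_loader` returns a plain Python float, as in the chapter code):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "added-perplexity-note-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "# Perplexity is the exponential of the cross-entropy loss\n",
    "print(\"Training perplexity:\", math.exp(train_loss))\n",
    "print(\"Validation perplexity:\", math.exp(val_loss))"
   ]
  },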
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
  {
   "cell_type": "markdown",
   "id": "96485d6b-bf1f-4bc0-a53f-73b08d85726e",
   "metadata": {},
   "source": [
    "We can also repeat this for the largest GPT-2 model, but don't forget to update the context length:"
   ]
  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "1a79a4b6-fe8f-40c2-a018-e731dcf391b3",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "checkpoint: 100%|███████████████████████████| 77.0/77.0 [00:00<00:00, 43.5kiB/s]\n",
      "encoder.json: 100%|███████████████████████| 1.04M/1.04M [00:00<00:00, 2.75MiB/s]\n",
      "hparams.json: 100%|█████████████████████████| 91.0/91.0 [00:00<00:00, 60.2kiB/s]\n",
      "model.ckpt.data-00000-of-00001: 100%|█████| 6.23G/6.23G [06:02<00:00, 17.2MiB/s]\n",
      "model.ckpt.index: 100%|████████████████████| 20.7k/20.7k [00:00<00:00, 171kiB/s]\n",
      "model.ckpt.meta: 100%|████████████████████| 1.84M/1.84M [00:00<00:00, 4.27MiB/s]\n",
      "vocab.bpe: 100%|████████████████████████████| 456k/456k [00:00<00:00, 1.73MiB/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training loss: 3.3046312861972384\n",
      "Validation loss: 3.1195147037506104\n"
     ]
    }
   ],
   "source": [
    "settings, params = download_and_load_gpt2(model_size=\"1558M\", models_dir=\"gpt2\")\n",
    "\n",
    "model_name = \"gpt2-xl (1558M)\"\n",
    "NEW_CONFIG = GPT_CONFIG_124M.copy()\n",
    "NEW_CONFIG.update(model_configs[model_name])\n",
    "NEW_CONFIG.update({\"context_length\": 1024, \"qkv_bias\": True})\n",
    "\n",
    "gpt = GPTModel(NEW_CONFIG)\n",
    "gpt.eval()\n",
    "\n",
    "load_weights_into_gpt(gpt, params)\n",
    "gpt.to(device)\n",
    "\n",
    "torch.manual_seed(123)\n",
    "train_loss = calc_loss_loader(train_loader, gpt, device)\n",
    "val_loss = calc_loss_loader(val_loader, gpt, device)\n",
    "\n",
    "print(\"Training loss:\", train_loss)\n",
    "print(\"Validation loss:\", val_loss)"
   ]
  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
  {
   "cell_type": "markdown",
   "id": "3a76a1e0-9635-480a-9391-3bda7aea402d",
   "metadata": {},
   "source": [
    "# Exercise 5.6: Trying larger models"
   ]
  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
  {
   "cell_type": "markdown",
   "id": "b3d313f4-0038-4bc9-a340-84b3b55dc0e3",
   "metadata": {},
   "source": [
    "- In the main chapter, we experimented with the smallest GPT-2 model, which has only 124M parameters\n",
    "- The reason was to keep the resource requirements as low as possible\n",
    "- However, you can easily experiment with larger models with minimal code changes\n",
    "- For example, to load the 1558M model instead of the 124M model in chapter 5, the only 2 lines of code that we have to change are\n",
    "\n",
    "```python\n",
    "settings, params = download_and_load_gpt2(model_size=\"124M\", models_dir=\"gpt2\")\n",
    "model_name = \"gpt2-small (124M)\"\n",
    "```\n",
    "\n",
    "- The updated code becomes\n",
    "\n",
    "\n",
    "```python\n",
    "settings, params = download_and_load_gpt2(model_size=\"1558M\", models_dir=\"gpt2\")\n",
    "model_name = \"gpt2-xl (1558M)\"\n",
    "```"
   ]
  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "31e0972b-e85e-4904-a0f5-24c3eacd5fa2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import tiktoken\n",
    "import torch\n",
    "from previous_chapters import GPTModel\n",
    "\n",
    "\n",
    "GPT_CONFIG_124M = {\n",
    "    \"vocab_size\": 50257,   # Vocabulary size\n",
    "    \"context_length\": 256, # Shortened context length (orig: 1024)\n",
    "    \"emb_dim\": 768,        # Embedding dimension\n",
    "    \"n_heads\": 12,         # Number of attention heads\n",
    "    \"n_layers\": 12,        # Number of layers\n",
    "    \"drop_rate\": 0.1,      # Dropout rate\n",
    "    \"qkv_bias\": False      # Query-key-value bias\n",
    "}\n",
    "\n",
    "\n",
    "tokenizer = tiktoken.get_encoding(\"gpt2\")"
   ]
  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-07-24 21:53:41 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 21,
							 
						 
					
						
							
								
									
										
										
										
											2024-03-31 20:28:51 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "id": "b641ee88-f9d4-43ec-a787-e34199eed356",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "File already exists and is up-to-date: gpt2/1558M/checkpoint\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "File already exists and is up-to-date: gpt2/1558M/encoder.json\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-28 14:31:27 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "File already exists and is up-to-date: gpt2/1558M/hparams.json\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-31 20:28:51 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								      "File already exists and is up-to-date: gpt2/1558M/model.ckpt.data-00000-of-00001\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "File already exists and is up-to-date: gpt2/1558M/model.ckpt.index\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "File already exists and is up-to-date: gpt2/1558M/model.ckpt.meta\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "File already exists and is up-to-date: gpt2/1558M/vocab.bpe\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "from gpt_download import download_and_load_gpt2\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "from gpt_generate import load_weights_into_gpt\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "model_configs = {\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    \"gpt2-small (124M)\": {\"emb_dim\": 768, \"n_layers\": 12, \"n_heads\": 12},\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    \"gpt2-medium (355M)\": {\"emb_dim\": 1024, \"n_layers\": 24, \"n_heads\": 16},\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    \"gpt2-large (774M)\": {\"emb_dim\": 1280, \"n_layers\": 36, \"n_heads\": 20},\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    \"gpt2-xl (1558M)\": {\"emb_dim\": 1600, \"n_layers\": 48, \"n_heads\": 25},\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "}\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "model_name = \"gpt2-xl (1558M)\"\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "NEW_CONFIG = GPT_CONFIG_124M.copy()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "NEW_CONFIG.update(model_configs[model_name])\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-04 07:27:41 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "NEW_CONFIG.update({\"context_length\": 1024, \"qkv_bias\": True})\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-31 20:28:51 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "\n",
								    "gpt = GPTModel(NEW_CONFIG)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "gpt.eval()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
    "settings, params = download_and_load_gpt2(model_size=\"1558M\", models_dir=\"gpt2\")\n",
    "load_weights_into_gpt(gpt, params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "c98f56f4-98fc-43b4-9ee5-726e9d17c73f",
   "metadata": {},
   "outputs": [],
   "source": [
    "from gpt_generate import generate, text_to_token_ids, token_ids_to_text"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "b1f7853c-6e81-4f1f-a1d0-61e2c7d33a20",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Output text:\n",
      " Every effort moves you toward finding an ideal life. You don't have to accept your current one at once, because if you do you'll never\n"
     ]
    }
   ],
   "source": [
								    "torch.manual_seed(123)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
								    "token_ids = generate(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    model=gpt,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    idx=text_to_token_ids(\"Every effort moves you\", tokenizer),\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    max_new_tokens=25,\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-04 07:27:41 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "    context_size=NEW_CONFIG[\"context_length\"],\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-31 20:28:51 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "    top_k=50,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    temperature=1.5\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    ")\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(\"Output text:\\n\", token_ids_to_text(token_ids, tokenizer))"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  "kernelspec": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "display_name": "Python 3 (ipykernel)",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "language": "python",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "name": "python3"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  "language_info": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "codemirror_mode": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "name": "ipython",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "version": 3
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "file_extension": ".py",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "mimetype": "text/x-python",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "name": "python",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "nbconvert_exporter": "python",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "pygments_lexer": "ipython3",
							 
						 
					
						
							
								
									
										
										
										
											2025-04-13 11:06:57 -07:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "version": "3.10.16"
							 
						 
					
						
							
								
									
										
										
										
											2024-03-31 20:28:51 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 "nbformat": 4,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 "nbformat_minor": 5
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								}