2024-01-29 08:13:52 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								{
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 "cells": [
							 
						 
					
						
							
								
									
										
										
										
											2024-03-19 09:26:26 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "08f4321d-d32a-4a90-bfc7-e923f316b2f8",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-05-24 07:20:37 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "<table style=\"width:100%\">\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "<tr>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "<td style=\"vertical-align:middle; text-align:left;\">\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "<font size=\"2\">\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "Supplementary code for the <a href=\"http://mng.bz/orYv\">Build a Large Language Model From Scratch</a> book by <a href=\"https://sebastianraschka.com\">Sebastian Raschka</a><br>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "<br>Code repository: <a href=\"https://github.com/rasbt/LLMs-from-scratch\">https://github.com/rasbt/LLMs-from-scratch</a>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "</font>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "</td>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "<td style=\"vertical-align:middle; text-align:left;\">\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "<a href=\"http://mng.bz/orYv\"><img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/cover-small.webp\" width=\"100px\"></a>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "</td>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "</tr>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "</table>"
							 
						 
					
						
							
								
									
										
										
										
											2024-03-19 09:26:26 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2024-01-29 08:13:52 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "ce9295b2-182b-490b-8325-83a67c4a001d",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "# Chapter 4: Implementing a GPT model from Scratch To Generate Text "
							 
						 
					
						
							
								
									
										
										
										
											2024-01-29 08:13:52 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2024-03-16 09:50:00 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "execution_count": 1,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "f9eac223-a125-40f7-bacc-bd0d890450c7",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "matplotlib version: 3.7.2\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "torch version: 2.2.1\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "tiktoken version: 0.5.1\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "from importlib.metadata import version\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "import matplotlib\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "import tiktoken\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "import torch\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(\"matplotlib version:\", version(\"matplotlib\"))\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(\"torch version:\", version(\"torch\"))\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(\"tiktoken version:\", version(\"tiktoken\"))"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2024-01-29 08:13:52 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "e7da97ed-e02f-4d7f-b68e-a0eba3716e02",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "- In this chapter, we implement a GPT-like LLM architecture; the next chapter will focus on training this LLM"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "7d4f11e0-4434-4979-9dee-e1207df0eb01",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:29:06 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch04_compressed/01.webp\" width=\"500px\">"
							 
						 
					
						
							
								
									
										
										
										
											2024-01-29 08:13:52 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "53fe99ab-0bcf-4778-a6b5-6db81fb826ef",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-02-08 20:16:54 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "## 4.1 Coding an LLM architecture"
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "ad72d1ff-d82d-4e33-a88e-3c1a8831797b",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Chapter 1 discussed models like GPT and Llama, which generate words sequentially and are based on the decoder part of the original transformer architecture\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Therefore, these LLMs are often referred to as \"decoder-like\" LLMs\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Compared to conventional deep learning models, LLMs are larger, mainly due to their vast number of parameters, not the amount of code\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- We'll see that many elements are repeated in an LLM's architecture"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "5c5213e9-bd1c-437e-aee8-f5e8fb717251",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:08:38 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch04_compressed/02.webp\" width=\"400px\">"
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "0d43f5e2-fb51-434a-b9be-abeef6b98d99",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- In previous chapters, we used small embedding dimensions for token inputs and outputs for ease of illustration, ensuring they fit on a single page\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- In this chapter, we consider embedding and model sizes akin to a small GPT-2 model\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- We'll specifically code the architecture of the smallest GPT-2 model (124 million parameters), as outlined in Radford et al.'s [Language Models are Unsupervised Multitask Learners](https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf) (note that the initial report lists it as 117M parameters, but this was later corrected in the model weight repository)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Chapter 6 will show how to load pretrained weights into our implementation, which will be compatible with model sizes of 345, 762, and 1542 million parameters"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "21baa14d-24b8-4820-8191-a2808f7fbabc",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Configuration details for the 124 million parameter GPT-2 model include:"
							 
						 
					
						
							
								
									
										
										
										
											2024-01-29 08:13:52 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-03 20:19:08 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 2,
							 
						 
					
						
							
								
									
										
										
										
											2024-01-29 08:13:52 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "id": "5ed66875-1f24-445d-add6-006aae3c5707",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-01-31 08:00:19 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "GPT_CONFIG_124M = {\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-04 07:27:41 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "    \"vocab_size\": 50257,    # Vocabulary size\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    \"context_length\": 1024, # Context length\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    \"emb_dim\": 768,         # Embedding dimension\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    \"n_heads\": 12,          # Number of attention heads\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    \"n_layers\": 12,         # Number of layers\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    \"drop_rate\": 0.1,       # Dropout rate\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    \"qkv_bias\": False       # Query-Key-Value bias\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-29 08:13:52 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "}"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "c12fcd28-d210-4c57-8be6-06cfcd5d73a4",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- We use short variable names to avoid long lines of code later\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- `\"vocab_size\"` indicates a vocabulary size of 50,257 words, supported by the BPE tokenizer discussed in Chapter 2\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-04 07:27:41 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "- `\"context_length\"` represents the model's maximum input token count, as enabled by positional embeddings covered in Chapter 2\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "- `\"emb_dim\"` is the embedding size for token inputs, converting each input token into a 768-dimensional vector\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- `\"n_heads\"` is the number of attention heads in the multi-head attention mechanism implemented in Chapter 3\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- `\"n_layers\"` is the number of transformer blocks within the model, which we'll implement in upcoming sections\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- `\"drop_rate\"` is the dropout mechanism's intensity, discussed in Chapter 3; 0.1 means dropping 10% of hidden units during training to mitigate overfitting\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- `\"qkv_bias\"` decides if the `Linear` layers in the multi-head attention mechanism (from Chapter 3) should include a bias vector when computing query (Q), key (K), and value (V) tensors; we'll disable this option, which is standard practice in modern LLMs; however, we'll revisit this later when loading pretrained GPT-2 weights from OpenAI into our reimplementation in Chapter 6"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "4adce779-857b-4418-9501-12a7f3818d88",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:29:06 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch04_compressed/03.webp\" width=\"500px\">"
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2024-01-29 08:13:52 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-03 20:19:08 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 3,
							 
						 
					
						
							
								
									
										
										
										
											2024-01-29 08:13:52 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "id": "619c2eed-f8ea-4ff5-92c3-feda0f29b227",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-02-27 23:05:36 +08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "import torch\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-29 08:13:52 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "import torch.nn as nn\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "class DummyGPTModel(nn.Module):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    def __init__(self, cfg):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        super().__init__()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.tok_emb = nn.Embedding(cfg[\"vocab_size\"], cfg[\"emb_dim\"])\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-04 07:58:41 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        self.pos_emb = nn.Embedding(cfg[\"context_length\"], cfg[\"emb_dim\"])\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-29 08:13:52 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "        self.drop_emb = nn.Dropout(cfg[\"drop_rate\"])\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        \n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # Use a placeholder for TransformerBlock\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.trf_blocks = nn.Sequential(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            *[DummyTransformerBlock(cfg) for _ in range(cfg[\"n_layers\"])])\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        \n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # Use a placeholder for LayerNorm\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.final_norm = DummyLayerNorm(cfg[\"emb_dim\"])\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-02-05 06:51:58 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        self.out_head = nn.Linear(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        )\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-29 08:13:52 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    def forward(self, in_idx):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        batch_size, seq_len = in_idx.shape\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        tok_embeds = self.tok_emb(in_idx)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        x = tok_embeds + pos_embeds\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        x = self.drop_emb(x)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        x = self.trf_blocks(x)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        x = self.final_norm(x)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        logits = self.out_head(x)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        return logits\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "class DummyTransformerBlock(nn.Module):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    def __init__(self, cfg):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        super().__init__()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # A simple placeholder\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    def forward(self, x):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # This block does nothing and just returns its input.\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        return x\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "class DummyLayerNorm(nn.Module):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    def __init__(self, normalized_shape, eps=1e-5):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        super().__init__()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # The parameters here are just to mimic the LayerNorm interface.\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    def forward(self, x):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # This layer does nothing and just returns its input.\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        return x"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "9665e8ab-20ca-4100-b9b9-50d9bdee33be",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:29:06 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch04_compressed/04.webp\" width=\"500px\">"
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2024-01-29 08:13:52 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-03 20:19:08 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 4,
							 
						 
					
						
							
								
									
										
										
										
											2024-01-29 08:13:52 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "id": "794b6b6c-d36f-411e-a7db-8ac566a87fee",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
									
										
										
										
											2024-02-10 17:53:54 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
									
										
										
										
											2024-02-15 20:09:20 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "tensor([[6109, 3626, 6100,  345],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [6109, 1110, 6622,  257]])\n"
							 
						 
					
						
							
								
									
										
										
										
											2024-02-10 17:53:54 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
									
										
										
										
											2024-01-29 08:13:52 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "import tiktoken\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "tokenizer = tiktoken.get_encoding(\"gpt2\")\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "batch = []\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-02-15 20:09:20 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "txt1 = \"Every effort moves you\"\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "txt2 = \"Every day holds a\"\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-29 08:13:52 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "batch.append(torch.tensor(tokenizer.encode(txt1)))\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "batch.append(torch.tensor(tokenizer.encode(txt2)))\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "batch = torch.stack(batch, dim=0)\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-02-05 06:51:58 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "print(batch)"
							 
						 
					
						
							
								
									
										
										
										
											2024-01-29 08:13:52 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-03 20:19:08 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 5,
							 
						 
					
						
							
								
									
										
										
										
											2024-01-29 08:13:52 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "id": "009238cd-0160-4834-979c-309710986bb0",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
									
										
										
										
											2024-02-15 20:09:20 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "Output shape: torch.Size([2, 4, 50257])\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-29 08:13:52 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								      "tensor([[[-1.2034,  0.3201, -0.7130,  ..., -1.5548, -0.2390, -0.4667],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "         [-0.1192,  0.4539, -0.4432,  ...,  0.2392,  1.3469,  1.2430],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "         [ 0.5307,  1.6720, -0.4695,  ...,  1.1966,  0.0111,  0.5835],\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-03 20:19:08 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "         [ 0.0139,  1.6754, -0.3388,  ...,  1.1586, -0.0435, -1.0400]],\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-29 08:13:52 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								      "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-02-15 20:09:20 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "        [[-1.0908,  0.1798, -0.9484,  ..., -1.6047,  0.2439, -0.4530],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "         [-0.7860,  0.5581, -0.0610,  ...,  0.4835, -0.0077,  1.6621],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "         [ 0.3567,  1.2698, -0.6398,  ..., -0.0162, -0.1296,  0.3717],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "         [-0.2407, -0.7349, -0.5102,  ...,  2.0057, -0.3694,  0.1814]]],\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-29 08:13:52 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								      "       grad_fn=<UnsafeViewBackward0>)\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "torch.manual_seed(123)\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-31 08:00:19 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "model = DummyGPTModel(GPT_CONFIG_124M)\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-29 08:13:52 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "logits = model(batch)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(\"Output shape:\", logits.shape)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(logits)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "f8332a00-98da-4eb4-b882-922776a89917",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "## 4.2 Normalizing activations with layer normalization"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "066cfb81-d59b-4d95-afe3-e43cf095f292",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Layer normalization, also known as LayerNorm ([Ba et al. 2016](https://arxiv.org/abs/1607.06450)), centers the activations of a neural network layer around a mean of 0 and normalizes their variance to 1\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- This stabilizes training and enables faster convergence to effective weights\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
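								    "- Layer normalization is applied both before and after the multi-head attention module within the transformer block (that is, once before the attention module and once more before the feed-forward module that follows it), which we will implement later; it's also applied before the final output layer"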
							 
						 
					
						
							
								
									
										
										
										
											2024-01-29 08:13:52 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "314ac47a-69cc-4597-beeb-65bed3b5910f",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-29 08:13:52 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:08:38 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch04_compressed/05.webp\" width=\"400px\">"
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "5ab49940-6b35-4397-a80e-df8d092770a7",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Let's see how layer normalization works by passing a small input sample through a simple neural network layer:"
							 
						 
					
						
							
								
									
										
										
										
											2024-01-29 08:13:52 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-03 20:19:08 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 6,
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "79e1b463-dc3f-44ac-9cdb-9d5b6f64eb9d",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "tensor([[0.2260, 0.3470, 0.0000, 0.2216, 0.0000, 0.0000],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [0.2133, 0.2394, 0.0000, 0.5198, 0.3297, 0.0000]],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "       grad_fn=<ReluBackward0>)\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "torch.manual_seed(123)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "# create 2 training examples with 5 dimensions (features) each\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-02-13 07:05:37 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "batch_example = torch.randn(2, 5) \n",
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "layer = nn.Sequential(nn.Linear(5, 6), nn.ReLU())\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-02-13 07:05:37 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "out = layer(batch_example)\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "print(out)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "8fccc29e-71fc-4c16-898c-6137c6ea5d2e",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Let's compute the mean and variance for each of the 2 inputs above:"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-03 20:19:08 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 7,
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "9888f79e-8e69-44aa-8a19-cd34292adbf5",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "Mean:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      " tensor([[0.1324],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [0.2170]], grad_fn=<MeanBackward1>)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "Variance:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      " tensor([[0.0231],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [0.0398]], grad_fn=<VarBackward0>)\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "mean = out.mean(dim=-1, keepdim=True)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "var = out.var(dim=-1, keepdim=True)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(\"Mean:\\n\", mean)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(\"Variance:\\n\", var)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "052eda3e-b395-48c4-acd4-eb8083bab958",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- The normalization is applied to each of the two inputs (rows) independently; using dim=-1 applies the calculation across the last dimension (in this case, the feature dimension) instead of the row dimension"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "570db83a-205c-4f6f-b219-1f6195dde1a7",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:08:38 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch04_compressed/06.webp\" width=\"400px\">"
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "9f8ecbc7-eb14-4fa1-b5d0-7e1ff9694f99",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Subtracting the mean and dividing by the square-root of the variance (standard deviation) centers the inputs to have a mean of 0 and a variance of 1 across the column (feature) dimension:"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-03 20:19:08 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 8,
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "9a1d1bb9-3341-4c9a-bc2a-d2489bf89cda",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "Normalized layer outputs:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      " tensor([[ 0.6159,  1.4126, -0.8719,  0.5872, -0.8719, -0.8719],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [-0.0189,  0.1121, -1.0876,  1.5173,  0.5647, -1.0876]],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "       grad_fn=<DivBackward0>)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "Mean:\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-03 20:19:08 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      " tensor([[-5.9605e-08],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [ 1.9868e-08]], grad_fn=<MeanBackward1>)\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "Variance:\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-03 20:19:08 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      " tensor([[1.0000],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [1.0000]], grad_fn=<VarBackward0>)\n"
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "out_norm = (out - mean) / torch.sqrt(var)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(\"Normalized layer outputs:\\n\", out_norm)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "mean = out_norm.mean(dim=-1, keepdim=True)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "var = out_norm.var(dim=-1, keepdim=True)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(\"Mean:\\n\", mean)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(\"Variance:\\n\", var)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "ac62b90c-7156-4979-9a79-ce1fb92969c1",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Each input is centered at 0 and has a unit variance of 1; to improve readability, we can disable PyTorch's scientific notation:"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-03 20:19:08 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 9,
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "3e06c34b-c68a-4b36-afbe-b30eda4eca39",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "Mean:\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-03 20:19:08 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      " tensor([[    -0.0000],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [     0.0000]], grad_fn=<MeanBackward1>)\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "Variance:\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-03 20:19:08 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      " tensor([[1.0000],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [1.0000]], grad_fn=<VarBackward0>)\n"
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "torch.set_printoptions(sci_mode=False)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(\"Mean:\\n\", mean)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(\"Variance:\\n\", var)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "944fb958-d4ed-43cc-858d-00052bb6b31a",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Above, we normalized the features of each input\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Now, using the same idea, we can implement a `LayerNorm` class:"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-03 20:19:08 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 10,
							 
						 
					
						
							
								
									
										
										
										
											2024-01-29 08:13:52 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "id": "3333a305-aa3d-460a-bcce-b80662d464d9",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "class LayerNorm(nn.Module):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    def __init__(self, emb_dim):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        super().__init__()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.eps = 1e-5\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.scale = nn.Parameter(torch.ones(emb_dim))\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.shift = nn.Parameter(torch.zeros(emb_dim))\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    def forward(self, x):\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        mean = x.mean(dim=-1, keepdim=True)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        var = x.var(dim=-1, keepdim=True, unbiased=False)\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-29 08:13:52 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "        norm_x = (x - mean) / torch.sqrt(var + self.eps)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        return self.scale * norm_x + self.shift"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "e56c3908-7544-4808-b8cb-5d0a55bcca72",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-29 08:13:52 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "**Scale and shift**\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Note that in addition to performing the normalization by subtracting the mean and dividing by the variance, we added two trainable parameters, a `scale` and a `shift` parameter\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- The initial `scale` (multiplying by 1) and `shift` (adding 0) values don't have any effect; however, `scale` and `shift` are trainable parameters that the LLM automatically adjusts during training if it is determined that doing so would improve the model's performance on its training task\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- This allows the model to learn appropriate scaling and shifting that best suit the data it is processing\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Note that we also add a smaller value (`eps`) before computing the square root of the variance; this is to avoid division-by-zero errors if the variance is 0\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "**Biased variance**\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-02-08 20:16:54 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "- In the variance calculation above, setting `unbiased=False` means using the formula $\\frac{\\sum_i (x_i - \\bar{x})^2}{n}$ to compute the variance where n is the sample size (here, the number of features or columns); this formula does not include Bessel's correction (which uses `n-1` in the denominator), thus providing a biased estimate of the  variance \n",
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "- For LLMs, where the embedding dimension `n` is very large, the difference between using n and `n-1`\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    " is negligible\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- However, GPT-2 was trained with a biased variance in the normalization layers, which is why we also adopted this setting for compatibility reasons with the pretrained weights that we will load in later chapters\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Let's now try out `LayerNorm` in practice:"
							 
						 
					
						
							
								
									
										
										
										
											2024-01-29 08:13:52 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-03 20:19:08 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 11,
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "23b1000a-e613-4b43-bd90-e54deed8d292",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "ln = LayerNorm(emb_dim=5)\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-02-13 07:05:37 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "out_ln = ln(batch_example)"
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-03 20:19:08 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 12,
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "94c12de2-1cab-46e0-a099-e2e470353bff",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "Mean:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      " tensor([[    -0.0000],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [     0.0000]], grad_fn=<MeanBackward1>)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "Variance:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      " tensor([[1.0000],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [1.0000]], grad_fn=<VarBackward0>)\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "mean = out_ln.mean(dim=-1, keepdim=True)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "var = out_ln.var(dim=-1, unbiased=False, keepdim=True)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(\"Mean:\\n\", mean)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(\"Variance:\\n\", var)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "e136cfc4-7c89-492e-b120-758c272bca8c",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:08:38 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch04_compressed/07.webp\" width=\"400px\">"
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "11190e7d-8c29-4115-824a-e03702f9dd54",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-02-10 10:10:14 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "## 4.3 Implementing a feed forward network with GELU activations"
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "b0585dfb-f21e-40e5-973f-2f63ad5cb169",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- In this section, we implement a small neural network submodule that is used as part of the transformer block in LLMs\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- We start with the activation function\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- In deep learning, ReLU (Rectified Linear Unit) activation functions are commonly used due to their simplicity and effectiveness in various neural network architectures\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-05-01 20:26:17 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "- In LLMs, various other types of activation functions are used beyond the traditional ReLU; two notable examples are GELU (Gaussian Error Linear Unit) and SwiGLU (Swish-Gated Linear Unit)\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "- GELU and SwiGLU are more complex, smooth activation functions incorporating Gaussian and sigmoid-gated linear units, respectively, offering better performance for deep learning models, unlike the simpler, piecewise linear function of ReLU"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "7d482ce7-e493-4bfc-a820-3ea99f564ebc",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-02-08 20:16:54 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "- GELU ([Hendrycks and Gimpel 2016](https://arxiv.org/abs/1606.08415)) can be implemented in several ways; the exact version is defined as GELU(x)=x⋅Φ(x), where Φ(x) is the cumulative distribution function of the standard Gaussian distribution.\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "- In practice, it's common to implement a computationally cheaper approximation: $\\text{GELU}(x) \\approx 0.5 \\cdot x \\cdot \\left(1 + \\tanh\\left[\\sqrt{\\frac{2}{\\pi}} \\cdot \\left(x + 0.044715 \\cdot x^3\\right)\\right]\\right)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "$ (the original GPT-2 model was also trained with this approximation)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-03 20:19:08 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 13,
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "f84694b7-95f3-4323-b6d6-0a73df278e82",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-29 08:13:52 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "class GELU(nn.Module):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    def __init__(self):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        super().__init__()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    def forward(self, x):\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        return 0.5 * x * (1 + torch.tanh(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            torch.sqrt(torch.tensor(2.0 / torch.pi)) * \n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            (x + 0.044715 * torch.pow(x, 3))\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        ))"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-03 20:19:08 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 14,
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "fc5487d2-2576-4118-80a7-56c4caac2e71",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "data": {
							 
						 
					
						
							
								
									
										
										
										
											2024-04-03 20:19:08 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAxYAAAEiCAYAAABkykQ1AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAABn3klEQVR4nO3deVhUZfsH8O8My7AJiiAoICoqigsipKG5lYpbRSnZoqJmqWHlkiX+SjPfpDK33K2UJM19KTMVTVJzB1HRJBcQFzZllWUYZs7vD2QSAWXYzpnh+7muud53zpzlvmdyHu55zvM8MkEQBBAREREREVWBXOwAiIiIiIhI/7GwICIiIiKiKmNhQUREREREVcbCgoiIiIiIqoyFBRERERERVRkLCyIiIiIiqjIWFkREREREVGUsLIiIiIiIqMpYWBARERERUZWxsCAqw+effw6ZTCbKtUNDQyGTyRAfH1/r1y4sLMTHH38MFxcXyOVy+Pv713oMFSHme0REddvo0aPRrFkzUa4tZtv04MEDjBs3Do6OjpDJZJg8ebIocTyNmO8RsbCok+Li4jBp0iS0bt0aFhYWsLCwgIeHB4KCgnDhwoUS+xb/Ay3vkZSUBACIj4+HTCbDt99+W+51mzVrhiFDhpT52tmzZyGTyRAaGlpteT5Nbm4uPv/8c0RERNTaNR81b9487Nq1S5Rrl2ft2rWYP38+hg0bhp9++glTpkwRNR4pvkdEhqy4aC9+GBsbw8nJCaNHj8adO3cqdc6IiAjIZDJs27at3H1kMhkmTZpU5mvbtm2DTCar1e/qu3fv4vPPP0d0dHStXbOY2G1TeebNm4fQ0FBMnDgRYWFhGDlypGixSPU9IsBY7ACodu3ZswfDhw+HsbEx3nrrLXh6ekIul+PKlSvYsWMHVq5cibi4OLi6upY4buXKlbCysip1vvr169dS5NUvNzcXc+bMAQD07t27xGuffvopZsyYUaPXnzdvHoYNG1aqV2DkyJF4/fXXoVAoavT6Zfnzzz/h5OSERYsW1fq1yyLF94ioLvjiiy/QvHlz5Ofn4+TJkwgNDcWxY8cQExMDMzMzscOrcXfv3sWcOXPQrFkzdOrUqcRr33//PTQaTY1dW+y2qTx//vknnn32WcyePVuU6z9Kqu8RsbCoU65fv47XX38drq6uOHToEBo3blzi9a+//horVqyAXF66I2vYsGGws7OrrVBFZ2xsDGNjcf55GBkZwcjISJRrp6Sk6EWxKOZ7RFQXDBw4ED4+PgCAcePGwc7ODl9//TV+/fVXvPbaayJHJy4TExPRri1m25SSkgIPDw9Rrq0LMd8j4q1Qdco333yDnJwcrFu3rlRRART9Y/zggw/g4uIiQnQVk5aWho8++ggdOnSAlZUVrK2tMXDgQJw/f77Uvvn5+fj888/RunVrmJmZoXHjxnj11Vdx/fp1xMfHw97eHgAwZ84cbbf/559/DqD0PZrt27dHnz59Sl1Do9HAyckJw4YN02779ttv0a1bNzRs2BDm5ubw9vYudQuATCZDTk4OfvrpJ+21R48eDaD88QMrVqxAu3btoFAo0KRJEwQFBSEjI6PEPr1790b79u1x+fJl9OnTBxYWFnBycsI333zzxPe1+Fa2w4cP49KlS9qYIiIitLcxPN7lXHzMo7evjR49GlZWVrhz5w78/f1hZWUFe3t7fPTRR1Cr1aXeuyVLlqBDhw4wMzODvb09BgwYgLNnz0ryPSKqy3r06AGg6AeqR125cgXDhg2Dra0tzMzM4OPjg19//VWMEHHz5k289957cHd3h7m5ORo2bIiAgIAyx2JlZGRgypQpaNasGRQKBZydnTFq1Cjcu3cPEREReOaZZwAAY8aM0X7/FH/XPTrGQqVSwdbWFmPGjCl1jaysLJiZmeGjjz4CABQUFGDWrFnw9vaGjY0NLC0t0aNHDxw+fFh7jK5tE1A0Nm7u3Llwc3OD
QqFAs2bNMHPmTCiVyhL7Fd+OfOzYMXTp0gVmZmZo0aIF1q9f/8T3tbgNiIuLw++//66NKT4+vtzv4rLaDV2+e6uz/a6N94j+w8KiDtmzZw9atmyJrl276nxsWloa7t27V+Lx+B9steHGjRvYtWsXhgwZgoULF2L69Om4ePEievXqhbt372r3U6vVGDJkCObMmQNvb28sWLAAH374ITIzMxETEwN7e3usXLkSAPDKK68gLCwMYWFhePXVV8u87vDhw3HkyBHtmJJix44dw927d/H6669rty1ZsgReXl744osvMG/ePBgbGyMgIAC///67dp+wsDAoFAr06NFDe+3x48eXm/fnn3+OoKAgNGnSBAsWLMDQoUOxevVq9O/fHyqVqsS+6enpGDBgADw9PbFgwQK0adMGn3zyCf74449yz29vb4+wsDC0adMGzs7O2pjatm1b7jHlUavV8PPzQ8OGDfHtt9+iV69eWLBgAdasWVNiv7fffhuTJ0+Gi4sLvv76a8yYMQNmZmY4efKkJN8jorqs+A/HBg0aaLddunQJzz77LP755x/MmDEDCxYsgKWlJfz9/bFz585aj/HMmTM4fvw4Xn/9dXz33XeYMGECDh06hN69eyM3N1e734MHD9CjRw8sXboU/fv3x5IlSzBhwgRcuXIFt2/fRtu2bfHFF18AAN59913t90/Pnj1LXdPExASvvPIKdu3ahYKCghKv7dq1C0qlUts+ZGVl4YcffkDv3r3x9ddf4/PPP0dqair8/Py0Yzl0bZuAoh6lWbNmoXPnzli0aBF69eqFkJCQEu1SsWvXrmHYsGHo168fFixYgAYNGmD06NG4dOlSuedv27YtwsLCYGdnh06dOmljKv7jXhcV+e6t7va7Nt4jeoRAdUJmZqYAQPD39y/1Wnp6upCamqp95Obmal+bPXu2AKDMh7u7u3a/uLg4AYAwf/78cmNwdXUVBg8eXOZrZ86cEQAI69ate2Ie+fn5glqtLrEtLi5OUCgUwhdffKHdtnbtWgGAsHDhwlLn0Gg0giAIQmpqqgBAmD17dql9ivMuFhsbKwAQli5dWmK/9957T7Cysirxnj36/wVBEAoKCoT27dsLzz//fIntlpaWQmBgYKlrr1u3TgAgxMXFCYIgCCkpKYKpqanQv3//ErkvW7ZMACCsXbtWu61Xr14CAGH9+vXabUqlUnB0dBSGDh1a6lqP69Wrl9CuXbsS2w4fPiwAEA4fPlxie/Fn/uhnFhgYKAAo8VkIgiB4eXkJ3t7e2ud//vmnAED44IMPSsVQ/PkIgjTfIyJDVvxv6+DBg0Jqaqpw69YtYdu2bYK9vb2gUCiEW7duafd94YUXhA4dOgj5+fnabRqNRujWrZvQqlUr7bbi75CtW7eWe10AQlBQUJmvbd26tczvoMc9/t0rCIJw4sSJUv/eZ82aJQAQduzYUWr/4u+fJ7VJgYGBgqurq/b5/v37BQDCb7/9VmK/QYMGCS1atNA+LywsFJRKZYl90tPTBQcHB2Hs2LHabbq0TdHR0QIAYdy4cSX2++ijjwQAwp9//qnd5urqKgAQjhw5ot2WkpIiKBQKYdq0aaWu9biy2vDHv4uLldVuVPS7t7rb79p8j0gQ2GNRR2RlZQFAmQOwe/fuDXt7e+1j+fLlpfbZvn07wsPDSzzWrVtX43E/TqFQaMeAqNVq3L9/H1ZWVnB3d0dUVFSJeO3s7PD++++XOkdlpqFr3bo1OnXqhM2bN2u3qdVqbNu2DS+++CLMzc212x/9/+np6cjMzESPHj1KxKeLgwcPoqCgAJMnTy4x/uWdd96BtbV1iZ4QoOgzHjFihPa5qakpunTpghs3blTq+pUxYcKEEs979OhR4vrbt2+HTCYrcxBgZT4ffXyPiKSsb9++sLe3h4uLC4YNGwZLS0v8+uuvcHZ2BlDUi/3nn3/itddeQ3Z2trYn+/79+/Dz88PVq1crPYtUZT363atSqXD//n20bNkS9evX
L9U+eHp64pVXXil1jsp8/zz//POws7Mr0T6kp6cjPDwcw4cP124zMjKCqakpgKJbQdPS0lBYWAgfH59Ktw979+4FAEydOrXE9mnTpgF
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "text/plain": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "<Figure size 800x300 with 2 Axes>"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "display_data"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "import matplotlib.pyplot as plt\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "gelu, relu = GELU(), nn.ReLU()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "# Some sample data\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "x = torch.linspace(-3, 3, 100)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "y_gelu, y_relu = gelu(x), relu(x)\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-29 08:13:52 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "plt.figure(figsize=(8, 3))\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "for i, (y, label) in enumerate(zip([y_gelu, y_relu], [\"GELU\", \"ReLU\"]), 1):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    plt.subplot(1, 2, i)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    plt.plot(x, y)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    plt.title(f\"{label} activation function\")\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    plt.xlabel(\"x\")\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    plt.ylabel(f\"{label}(x)\")\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    plt.grid(True)\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-29 08:13:52 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "plt.tight_layout()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "plt.show()"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "1cd01662-14cb-43fd-bffd-2d702813de2d",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- As we can see, ReLU is a piecewise linear function that outputs the input directly if it is positive; otherwise, it outputs zero\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- GELU is a smooth, non-linear function that approximates ReLU but with a non-zero gradient for negative values\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Next, let's implement the small neural network module, `FeedForward`, that we will be using in the LLM's transformer block later:"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-03 20:19:08 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 15,
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "9275c879-b148-4579-a107-86827ca14d4d",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-01-29 08:13:52 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "class FeedForward(nn.Module):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    def __init__(self, cfg):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        super().__init__()\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        self.layers = nn.Sequential(\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-29 08:13:52 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "            nn.Linear(cfg[\"emb_dim\"], 4 * cfg[\"emb_dim\"]),\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            GELU(),\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            nn.Linear(4 * cfg[\"emb_dim\"], cfg[\"emb_dim\"]),\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        )\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    def forward(self, x):\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        return self.layers(x)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-03 20:19:08 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 16,
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "7c4976e2-0261-418e-b042-c5be98c2ccaf",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "768\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(GPT_CONFIG_124M[\"emb_dim\"])"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "fdcaacfa-3cfc-4c9e-b668-b71a2753145a",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-04-03 20:19:08 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch04_compressed/09.webp?12\" width=\"400px\">"
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-03 20:19:08 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 17,
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "928e7f7c-d0b1-499f-8d07-4cadb428a6f9",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
									
										
										
										
											2024-02-08 20:16:54 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "torch.Size([2, 3, 768])\n"
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "ffn = FeedForward(GPT_CONFIG_124M)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-02-08 20:16:54 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "# input shape: [batch_size, num_token, emb_size]\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "x = torch.rand(2, 3, 768) \n",
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "out = ffn(x)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(out.shape)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "8f8756c5-6b04-443b-93d0-e555a316c377",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:08:38 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch04_compressed/10.webp\" width=\"400px\">"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "e5da2a50-04f4-4388-af23-ad32e405a972",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch04_compressed/11.webp\" width=\"400px\">"
							 
						 
					
						
							
								
									
										
										
										
											2024-01-29 08:13:52 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "4ffcb905-53c7-4886-87d2-4464c5fecf89",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "## 4.4 Adding shortcut connections"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "ffae416c-821e-4bfa-a741-8af4ba5db00e",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Next, let's talk about the concept behind shortcut connections, also called skip or residual connections\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Originally, shortcut connections were proposed in deep networks for computer vision (residual networks) to mitigate vanishing gradient problems\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- A shortcut connection creates an alternative shorter path for the gradient to flow through the network\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- This is achieved by adding the output of one layer to the output of a later layer, usually skipping one or more layers in between\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Let's illustrate this idea with a small example network:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:08:38 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch04_compressed/12.webp\" width=\"400px\">"
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "14cfd241-a32e-4601-8790-784b82f2f23e",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- In code, it looks like this:"
							 
						 
					
						
							
								
									
										
										
										
											2024-01-29 08:13:52 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-03 20:19:08 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 18,
							 
						 
					
						
							
								
									
										
										
										
											2024-01-29 08:13:52 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "id": "05473938-799c-49fd-86d4-8ed65f94fee6",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
									
										
										
										
											2024-02-15 19:34:12 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "class ExampleDeepNeuralNetwork(nn.Module):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    def __init__(self, layer_sizes, use_shortcut):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        super().__init__()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.use_shortcut = use_shortcut\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.layers = nn.ModuleList([\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            nn.Sequential(nn.Linear(layer_sizes[0], layer_sizes[1]), GELU()),\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            nn.Sequential(nn.Linear(layer_sizes[1], layer_sizes[2]), GELU()),\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            nn.Sequential(nn.Linear(layer_sizes[2], layer_sizes[3]), GELU()),\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            nn.Sequential(nn.Linear(layer_sizes[3], layer_sizes[4]), GELU()),\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            nn.Sequential(nn.Linear(layer_sizes[4], layer_sizes[5]), GELU())\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        ])\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    def forward(self, x):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        for layer in self.layers:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            # Compute the output of the current layer\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            layer_output = layer(x)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            # Check if shortcut can be applied\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-02-25 08:47:25 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "            if self.use_shortcut and x.shape == layer_output.shape:\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-02-15 19:34:12 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "                x = x + layer_output\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            else:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                x = layer_output\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        return x\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "def print_gradients(model, x):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    # Forward pass\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    output = model(x)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    target = torch.tensor([[0.]])\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    # Calculate loss based on how close the target\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    # and output are\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    loss = nn.MSELoss()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    loss = loss(output, target)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    \n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    # Backward pass to calculate the gradients\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    loss.backward()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    for name, param in model.named_parameters():\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        if 'weight' in name:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            # Print the mean absolute gradient of the weights\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            print(f\"{name} has gradient mean of {param.grad.abs().mean().item()}\")"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "b39bf277-b3db-4bb1-84ce-7a20caff1011",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Let's print the gradient values first **without** shortcut connections:"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-03 20:19:08 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 19,
							 
						 
					
						
							
								
									
										
										
										
											2024-02-15 19:34:12 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "c75f43cc-6923-4018-b980-26023086572c",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
									
										
										
										
											2024-01-29 08:13:52 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
									
										
										
										
											2024-02-10 10:10:14 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
									
										
										
										
											2024-02-15 19:34:12 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "layers.0.0.weight has gradient mean of 0.00020173587836325169\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-03 20:19:08 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "layers.1.0.weight has gradient mean of 0.00012011159560643137\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "layers.2.0.weight has gradient mean of 0.0007152039906941354\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "layers.3.0.weight has gradient mean of 0.0013988736318424344\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "layers.4.0.weight has gradient mean of 0.005049645435065031\n"
							 
						 
					
						
							
								
									
										
										
										
											2024-02-10 10:10:14 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
									
										
										
										
											2024-01-29 08:13:52 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-02-15 19:34:12 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "layer_sizes = [3, 3, 3, 3, 3, 1]  \n",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-29 08:13:52 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-02-15 19:34:12 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "sample_input = torch.tensor([[1., 0., -1.]])\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-29 08:13:52 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "torch.manual_seed(123)\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-02-15 19:34:12 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "model_without_shortcut = ExampleDeepNeuralNetwork(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    layer_sizes, use_shortcut=False\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    ")\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print_gradients(model_without_shortcut, sample_input)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "837fd5d4-7345-4663-97f5-38f19dfde621",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Next, let's print the gradient values **with** shortcut connections:"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-03 20:19:08 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 20,
							 
						 
					
						
							
								
									
										
										
										
											2024-02-15 19:34:12 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "11b7c0c2-f9dd-4dd5-b096-a05c48c5f6d6",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "layers.0.0.weight has gradient mean of 0.22169792652130127\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-03 20:19:08 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "layers.1.0.weight has gradient mean of 0.20694106817245483\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-02-15 19:34:12 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "layers.2.0.weight has gradient mean of 0.32896995544433594\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-03 20:19:08 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "layers.3.0.weight has gradient mean of 0.2665732204914093\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "layers.4.0.weight has gradient mean of 1.3258540630340576\n"
							 
						 
					
						
							
								
									
										
										
										
											2024-02-15 19:34:12 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "torch.manual_seed(123)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "model_with_shortcut = ExampleDeepNeuralNetwork(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    layer_sizes, use_shortcut=True\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    ")\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print_gradients(model_with_shortcut, sample_input)"
							 
						 
					
						
							
								
									
										
										
										
											2024-01-29 08:13:52 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "79ff783a-46f0-49c5-a7a9-26a525764b6e",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-02-15 19:34:12 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "- As we can see based on the output above, shortcut connections prevent the gradients from vanishing in the early layers (towards `layers.0`)\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "- We will use this concept of a shortcut connection next when we implement a transformer block"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2024-01-29 08:13:52 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "cae578ca-e564-42cf-8635-a2267047cdff",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "## 4.5 Connecting attention and linear layers in a transformer block"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "a3daac6f-6545-4258-8f2d-f45a7394f429",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- In this section, we now combine the previous concepts into a so-called transformer block\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- A transformer block combines the causal multi-head attention module from the previous chapter with the linear layers (the feed forward neural network we implemented in an earlier section)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- In addition, the transformer block also uses dropout and shortcut connections"
							 
						 
					
						
							
								
									
										
										
										
											2024-01-29 08:13:52 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-03 20:19:08 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 21,
							 
						 
					
						
							
								
									
										
										
										
											2024-01-29 08:13:52 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "id": "0e1e8176-e5e3-4152-b1aa-0bbd7891dfd9",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "from previous_chapters import MultiHeadAttention\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "class TransformerBlock(nn.Module):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    def __init__(self, cfg):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        super().__init__()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.att = MultiHeadAttention(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            d_in=cfg[\"emb_dim\"],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            d_out=cfg[\"emb_dim\"],\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-04 07:58:41 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "            context_length=cfg[\"context_length\"],\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-29 08:13:52 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "            num_heads=cfg[\"n_heads\"], \n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            dropout=cfg[\"drop_rate\"],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            qkv_bias=cfg[\"qkv_bias\"])\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.ff = FeedForward(cfg)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.norm1 = LayerNorm(cfg[\"emb_dim\"])\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.norm2 = LayerNorm(cfg[\"emb_dim\"])\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-28 14:31:27 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        self.drop_shortcut = nn.Dropout(cfg[\"drop_rate\"])\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-29 08:13:52 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    def forward(self, x):\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        # Shortcut connection for attention block\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        shortcut = x\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        x = self.norm1(x)\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-02-08 20:16:54 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        x = self.att(x)  # Shape [batch_size, num_tokens, emb_size]\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-28 14:31:27 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        x = self.drop_shortcut(x)\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        x = x + shortcut  # Add the original input back\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-02-14 20:23:59 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        # Shortcut connection for feed forward block\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        shortcut = x\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        x = self.norm2(x)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        x = self.ff(x)\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-28 14:31:27 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        x = self.drop_shortcut(x)\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        x = x + shortcut  # Add the original input back\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-29 08:13:52 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "        return x"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "36b64d16-94a6-4d13-8c85-9494c50478a9",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-04-20 11:42:03 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch04_compressed/13.webp?1\" width=\"400px\">"
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "54d2d375-87bd-4153-9040-63a1e6a2b7cb",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Suppose we have 2 input samples with 4 tokens each, where each token is a 768-dimensional embedding vector; then this transformer block applies self-attention, followed by linear layers, to produce an output of the same size\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- You can think of the output as an augmented version of the context vectors we discussed in the previous chapter"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2024-01-29 08:13:52 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-03 20:19:08 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 22,
							 
						 
					
						
							
								
									
										
										
										
											2024-02-15 20:09:20 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "3fb45a63-b1f3-4b08-b525-dafbc8228405",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "Input shape: torch.Size([2, 4, 768])\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "Output shape: torch.Size([2, 4, 768])\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "torch.manual_seed(123)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "x = torch.rand(2, 4, 768)  # Shape: [batch_size, num_tokens, emb_dim]\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "block = TransformerBlock(GPT_CONFIG_124M)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "output = block(x)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(\"Input shape:\", x.shape)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(\"Output shape:\", output.shape)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 15:57:03 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-20 11:42:03 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "8f9e4ee4-cf23-4583-b1fd-317abb4fcd13",
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 15:57:03 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-04-20 11:42:03 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch04_compressed/14.webp?1\" width=\"400px\">"
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 15:57:03 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "46618527-15ac-4c32-ad85-6cfea83e006e",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "## 4.6 Coding the GPT model"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "dec7d03d-9ff3-4ca3-ad67-01b67c2f5457",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- We are almost there: now let's plug in the transformer block into the architecture we coded at the very beginning of this chapter so that we obtain a useable GPT architecture\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Note that the transformer block is repeated multiple times; in the case of the smallest 124M GPT-2 model, we repeat it 12 times:"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "9b7b362d-f8c5-48d2-8ebd-722480ac5073",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-04-20 11:42:03 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch04_compressed/15.webp\" width=\"400px\">"
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "324e4b5d-ed89-4fdf-9a52-67deee0593bc",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- The corresponding code implementation, where `cfg[\"n_layers\"] = 12`:"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-04 07:05:06 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 23,
							 
						 
					
						
							
								
									
										
										
										
											2024-01-29 08:13:52 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "id": "c61de39c-d03c-4a32-8b57-f49ac3834857",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "class GPTModel(nn.Module):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    def __init__(self, cfg):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        super().__init__()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.tok_emb = nn.Embedding(cfg[\"vocab_size\"], cfg[\"emb_dim\"])\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-04 07:58:41 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        self.pos_emb = nn.Embedding(cfg[\"context_length\"], cfg[\"emb_dim\"])\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-02-11 11:51:39 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        self.drop_emb = nn.Dropout(cfg[\"drop_rate\"])\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-29 08:13:52 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "        \n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.trf_blocks = nn.Sequential(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            *[TransformerBlock(cfg) for _ in range(cfg[\"n_layers\"])])\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        \n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.final_norm = LayerNorm(cfg[\"emb_dim\"])\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-02-05 06:51:58 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        self.out_head = nn.Linear(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        )\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-29 08:13:52 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    def forward(self, in_idx):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        batch_size, seq_len = in_idx.shape\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        tok_embeds = self.tok_emb(in_idx)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-02-08 20:16:54 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        x = tok_embeds + pos_embeds  # Shape [batch_size, num_tokens, emb_size]\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-04 07:05:06 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        x = self.drop_emb(x)\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-29 08:13:52 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "        x = self.trf_blocks(x)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        x = self.final_norm(x)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        logits = self.out_head(x)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        return logits"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "2750270f-c45d-4410-8767-a6adbd05d5c3",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Using the configuration of the 124M parameter model, we can now instantiate this GPT model with random initial weights as follows:"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2024-03-04 07:05:06 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-03 20:19:08 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 25,
							 
						 
					
						
							
								
									
										
										
										
											2024-03-04 07:05:06 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "ef94fd9c-4e9d-470d-8f8e-dd23d1bb1f64",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "Input batch:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      " tensor([[6109, 3626, 6100,  345],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [6109, 1110, 6622,  257]])\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "Output shape: torch.Size([2, 4, 50257])\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-03 20:19:08 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "tensor([[[ 0.3613,  0.4222, -0.0711,  ...,  0.3483,  0.4661, -0.2838],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "         [-0.1792, -0.5660, -0.9485,  ...,  0.0477,  0.5181, -0.3168],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "         [ 0.7120,  0.0332,  0.1085,  ...,  0.1018, -0.4327, -0.2553],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "         [-1.0076,  0.3418, -0.1190,  ...,  0.7195,  0.4023,  0.0532]],\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-04 07:05:06 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-03 20:19:08 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "        [[-0.2564,  0.0900,  0.0335,  ...,  0.2659,  0.4454, -0.6806],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "         [ 0.1230,  0.3653, -0.2074,  ...,  0.7705,  0.2710,  0.2246],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "         [ 1.0558,  1.0318, -0.2800,  ...,  0.6936,  0.3205, -0.3178],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "         [-0.1565,  0.3926,  0.3288,  ...,  1.2630, -0.1858,  0.0388]]],\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-04 07:05:06 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "       grad_fn=<UnsafeViewBackward0>)\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "torch.manual_seed(123)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "model = GPTModel(GPT_CONFIG_124M)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "out = model(batch)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(\"Input batch:\\n\", batch)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(\"\\nOutput shape:\", out.shape)\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-29 08:13:52 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "print(out)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "6d616e7a-568b-4921-af29-bd3f4683cd2e",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- We will train this model in the next chapter\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- However, a quick note about its size: we previously referred to it as a 124M parameter model; we can double check this number as follows:"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2024-01-29 08:13:52 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-03 20:19:08 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 26,
							 
						 
					
						
							
								
									
										
										
										
											2024-01-29 08:13:52 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "id": "84fb8be4-9d3b-402b-b3da-86b663aac33a",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "Total number of parameters: 163,009,536\n"
							 
						 
					
						
							
								
									
										
										
										
											2024-01-29 08:13:52 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "total_params = sum(p.numel() for p in model.parameters())\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "print(f\"Total number of parameters: {total_params:,}\")"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "b67d13dd-dd01-4ba6-a2ad-31ca8a9fd660",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- As we see above, this model has 163M, not 124M parameters; why?\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- In the original GPT-2 paper, the researchers applied weight tying, which means that they reused the token embedding layer (`tok_emb`) as the output layer, which means setting `self.out_head.weight = self.tok_emb.weight`\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- The token embedding layer projects the 50,257-dimensional one-hot encoded input tokens to a 768-dimensional embedding representation\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- The output layer projects 768-dimensional embeddings back into a 50,257-dimensional representation so that we can convert these back into words (more about that in the next section)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- So, the embedding and output layer have the same number of weight parameters, as we can see based on the shape of their weight matrices\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Since we previously referred to this model as a 124M parameter model, we can double check that the two layers indeed have identically shaped weight matrices as follows:"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-03 20:19:08 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 27,
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "e3b43233-e9b8-4f5a-b72b-a263ec686982",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "Token embedding layer shape: torch.Size([50257, 768])\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "Output layer shape: torch.Size([50257, 768])\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(\"Token embedding layer shape:\", model.tok_emb.weight.shape)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(\"Output layer shape:\", model.out_head.weight.shape)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "f02259f6-6f79-4c89-a866-4ebeae1c3289",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- In the original GPT-2 paper, the researchers reused the token embedding matrix as an output matrix\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Correspondingly, if we subtracted the number of parameters of the output layer, we'd get a 124M parameter model:"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-03 20:19:08 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 28,
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "95a22e02-50d3-48b3-a4e0-d9863343c164",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "Number of trainable parameters considering weight tying: 124,412,160\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "total_params_gpt2 =  total_params - sum(p.numel() for p in model.out_head.parameters())\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-29 08:13:52 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "print(f\"Number of trainable parameters considering weight tying: {total_params_gpt2:,}\")"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "40b03f80-b94c-46e7-9d42-d0df399ff3db",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- In practice, I found it easier to train the model without weight-tying, which is why we didn't implement it here\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- However, we will revisit and apply this weight-tying idea later when we load the pretrained weights in chapter 5\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Lastly, we can compute the memory requirements of the model as follows, which can be a helpful reference point:"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2024-01-29 08:13:52 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-03 20:19:08 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 29,
							 
						 
					
						
							
								
									
										
										
										
											2024-01-29 08:13:52 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "id": "5131a752-fab8-4d70-a600-e29870b33528",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "Total size of the model: 621.83 MB\n"
							 
						 
					
						
							
								
									
										
										
										
											2024-01-29 08:13:52 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "# Calculate the total size in bytes (assuming float32, 4 bytes per parameter)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "total_size_bytes = total_params * 4\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "# Convert to megabytes\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "total_size_mb = total_size_bytes / (1024 * 1024)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(f\"Total size of the model: {total_size_mb:.2f} MB\")"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "309a3be4-c20a-4657-b4e0-77c97510b47c",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-02-10 17:47:56 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "- Exercise: you can try the following other configurations, which are referenced in the [GPT-2 paper](https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf), as well.\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-02-10 17:47:56 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "    - **GPT2-small** (the 124M configuration we already implemented):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        - \"emb_dim\" = 768\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        - \"n_layers\" = 12\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        - \"n_heads\" = 12\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-02-10 17:47:56 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "    - **GPT2-medium:**\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        - \"emb_dim\" = 1024\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        - \"n_layers\" = 24\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        - \"n_heads\" = 16\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    \n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    - **GPT2-large:**\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        - \"emb_dim\" = 1280\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        - \"n_layers\" = 36\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        - \"n_heads\" = 20\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    \n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    - **GPT2-XL:**\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        - \"emb_dim\" = 1600\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        - \"n_layers\" = 48\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        - \"n_heads\" = 25"
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2024-01-29 08:13:52 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "da5d9bc0-95ab-45d4-9378-417628d86e35",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-02-11 07:35:07 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "## 4.7 Generating text"
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "48da5deb-6ee0-4b9b-8dd2-abed7ed65172",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- LLMs like the GPT model we implemented above are used to generate one word at a time"
							 
						 
					
						
							
								
									
										
										
										
											2024-01-29 08:13:52 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2024-01-31 08:00:19 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "caade12a-fe97-480f-939c-87d24044edff",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:08:38 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch04_compressed/16.webp\" width=\"400px\">"
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "a7061524-a3bd-4803-ade6-2e3b7b79ac13",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-31 08:00:19 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- The following `generate_text_simple` function implements greedy decoding, which is a simple and fast method to generate text\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-02-11 08:41:45 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "- In greedy decoding, at each step, the model chooses the word (or token) with the highest probability as its next output (the highest logit corresponds to the highest probability, so we technically wouldn't even have to compute the softmax function explicitly)\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "- In the next chapter, we will implement a more advanced `generate_text` function\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- The figure below depicts how the GPT model, given an input context, generates the next word token"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "7ee0f32c-c18c-445e-b294-a879de2aa187",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:08:38 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch04_compressed/17.webp\" width=\"600px\">"
							 
						 
					
						
							
								
									
										
										
										
											2024-01-31 08:00:19 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2024-01-29 08:13:52 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-03 20:19:08 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 30,
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "c9b428a9-8764-4b36-80cd-7d4e00595ba6",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-29 08:13:52 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
									
										
										
										
											2024-01-31 08:00:19 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "def generate_text_simple(model, idx, max_new_tokens, context_size):\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-02-27 06:40:35 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "    # idx is (batch, n_tokens) array of indices in the current context\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-31 08:00:19 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "    for _ in range(max_new_tokens):\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        \n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # Crop current context if it exceeds the supported context size\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # E.g., if the LLM supports only 5 tokens, and the current context has 10 tokens,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # then only the last 5 tokens are used as context\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-31 08:00:19 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        idx_cond = idx[:, -context_size:]\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        \n",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-31 08:00:19 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        # Get the predictions\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        with torch.no_grad():\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            logits = model(idx_cond)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        \n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # Focus only on the last time step\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-02-27 06:40:35 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        # (batch, n_tokens, vocab_size) becomes (batch, vocab_size)\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        logits = logits[:, -1, :]  \n",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-31 08:00:19 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-02-11 08:41:45 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        # Apply softmax to get probabilities\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        probas = torch.softmax(logits, dim=-1)  # (batch, vocab_size)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-02-11 08:42:21 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        # Get the idx of the vocab entry with the highest probability value\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-02-11 08:41:45 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        idx_next = torch.argmax(probas, dim=-1, keepdim=True)  # (batch, 1)\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-31 08:00:19 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # Append the selected (highest-probability) index to the running sequence\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        idx = torch.cat((idx, idx_next), dim=1)  # (batch, n_tokens+1)\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-31 08:00:19 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    return idx"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "6515f2c1-3cc7-421c-8d58-cc2f563b7030",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- The `generate_text_simple` function above implements an iterative process in which it generates one token at a time\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:08:38 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch04_compressed/18.webp\" width=\"600px\">"
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "f682eac4-f9bd-438b-9dec-6b1cc7bc05ce",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Let's prepare an input example:"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2024-01-31 08:00:19 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-03 20:19:08 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 31,
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "3d7e3e94-df0f-4c0f-a6a1-423f500ac1d3",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-31 08:00:19 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "encoded: [15496, 11, 314, 716]\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-31 08:00:19 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "encoded_tensor.shape: torch.Size([1, 4])\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "start_context = \"Hello, I am\"\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-31 08:00:19 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "encoded = tokenizer.encode(start_context)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(\"encoded:\", encoded)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "encoded_tensor = torch.tensor(encoded).unsqueeze(0)\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-31 08:00:19 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "print(\"encoded_tensor.shape:\", encoded_tensor.shape)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-03 20:19:08 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 32,
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "a72a9b60-de66-44cf-b2f9-1e638934ada4",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-31 08:00:19 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 15:57:03 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
									
										
										
										
											2024-02-11 11:51:39 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "Output: tensor([[15496,    11,   314,   716, 27018, 24086, 47843, 30961, 42348,  7267]])\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "Output length: 10\n"
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 15:57:03 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
									
										
										
										
											2024-01-31 08:00:19 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "model.eval() # disable dropout\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "out = generate_text_simple(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    model=model,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    idx=encoded_tensor, \n",
							 
						 
					
						
							
								
									
										
										
										
											2024-02-11 11:51:39 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "    max_new_tokens=6, \n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-04 07:58:41 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "    context_size=GPT_CONFIG_124M[\"context_length\"]\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-31 08:00:19 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    ")\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "print(\"Output:\", out)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(\"Output length:\", len(out[0]))"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "1d131c00-1787-44ba-bec3-7c145497b2c3",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Remove batch dimension and convert back into text:"
							 
						 
					
						
							
								
									
										
										
										
											2024-01-31 08:00:19 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-03 20:19:08 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 33,
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "053d99f6-5710-4446-8d52-117fb34ea9f6",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-31 08:00:19 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 15:57:03 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
									
										
										
										
											2024-02-11 11:51:39 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "Hello, I am Featureiman Byeswickattribute argue\n"
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 15:57:03 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
									
										
										
										
											2024-01-31 08:00:19 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "decoded_text = tokenizer.decode(out.squeeze(0).tolist())\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(decoded_text)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "9a894003-51f6-4ccc-996f-3b9c7d5a1d70",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-31 08:00:19 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-02-04 10:02:05 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "- Note that the model is untrained; hence the random output texts above\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-31 08:00:19 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "- We will train the model in the next chapter"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
									
										
										
										
											2024-01-29 08:13:52 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  "kernelspec": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "display_name": "Python 3 (ipykernel)",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "language": "python",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "name": "python3"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  "language_info": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "codemirror_mode": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "name": "ipython",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "version": 3
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "file_extension": ".py",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "mimetype": "text/x-python",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "name": "python",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "nbconvert_exporter": "python",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "pygments_lexer": "ipython3",
							 
						 
					
						
							
								
									
										
										
										
											2024-05-23 20:35:41 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "version": "3.11.4"
							 
						 
					
						
							
								
									
										
										
										
											2024-01-29 08:13:52 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 "nbformat": 4,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 "nbformat_minor": 5
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								}