2024-01-13 14:49:02 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								{
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 "cells": [
							 
						 
					
						
							
								
									
										
										
										
											2024-03-19 09:26:26 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "78224549-3637-44b0-aed1-8ff889c65192",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-05-24 07:20:37 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "<table style=\"width:100%\">\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "<tr>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "<td style=\"vertical-align:middle; text-align:left;\">\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "<font size=\"2\">\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "Supplementary code for the <a href=\"http://mng.bz/orYv\">Build a Large Language Model From Scratch</a> book by <a href=\"https://sebastianraschka.com\">Sebastian Raschka</a><br>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "<br>Code repository: <a href=\"https://github.com/rasbt/LLMs-from-scratch\">https://github.com/rasbt/LLMs-from-scratch</a>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "</font>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "</td>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "<td style=\"vertical-align:middle; text-align:left;\">\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "<a href=\"http://mng.bz/orYv\"><img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/cover-small.webp\" width=\"100px\"></a>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "</td>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "</tr>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "</table>\n"
							 
						 
					
						
							
								
									
										
										
										
											2024-03-19 09:26:26 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2024-01-13 14:49:02 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "51c9672d-8d0c-470d-ac2d-1271f8ec3f14",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "# Chapter 3 Exercise solutions"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "33dfa199-9aee-41d4-a64b-7e3811b9a616",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "# Exercise 3.1"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "execution_count": 5,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "5fee2cf5-61c3-4167-81b5-44ea155bbaf2",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "import torch\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "inputs = torch.tensor(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "  [[0.43, 0.15, 0.89], # Your     (x^1)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "   [0.55, 0.87, 0.66], # journey  (x^2)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "   [0.57, 0.85, 0.64], # starts   (x^3)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "   [0.22, 0.58, 0.33], # with     (x^4)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "   [0.77, 0.25, 0.10], # one      (x^5)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "   [0.05, 0.80, 0.55]] # step     (x^6)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    ")\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "d_in, d_out = 3, 2"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "execution_count": 58,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "62ea289c-41cd-4416-89dd-dde6383a6f70",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "import torch.nn as nn\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "class SelfAttention_v1(nn.Module):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    def __init__(self, d_in, d_out):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        super().__init__()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.d_out = d_out\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.W_query = nn.Parameter(torch.rand(d_in, d_out))\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.W_key   = nn.Parameter(torch.rand(d_in, d_out))\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.W_value = nn.Parameter(torch.rand(d_in, d_out))\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    def forward(self, x):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        keys = x @ self.W_key\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        queries = x @ self.W_query\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        values = x @ self.W_value\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        \n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        attn_scores = queries @ keys.T # omega\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-14 11:58:42 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-13 14:49:02 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        context_vec = attn_weights @ values\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        return context_vec\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "torch.manual_seed(123)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "sa_v1 = SelfAttention_v1(d_in, d_out)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "execution_count": 59,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "7b035143-f4e8-45fb-b398-dec1bd5153d4",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "class SelfAttention_v2(nn.Module):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    def __init__(self, d_in, d_out):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        super().__init__()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.d_out = d_out\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.W_query = nn.Linear(d_in, d_out, bias=False)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.W_key   = nn.Linear(d_in, d_out, bias=False)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.W_value = nn.Linear(d_in, d_out, bias=False)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    def forward(self, x):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        keys = self.W_key(x)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        queries = self.W_query(x)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        values = self.W_value(x)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        \n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        attn_scores = queries @ keys.T\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-14 11:58:42 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=1)\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-13 14:49:02 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        context_vec = attn_weights @ values\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        return context_vec\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "torch.manual_seed(123)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "sa_v2 = SelfAttention_v2(d_in, d_out)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "execution_count": 60,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "7591d79c-c30e-406d-adfd-20c12eb448f6",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "sa_v1.W_query = torch.nn.Parameter(sa_v2.W_query.weight.T)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "sa_v1.W_key = torch.nn.Parameter(sa_v2.W_key.weight.T)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "sa_v1.W_value = torch.nn.Parameter(sa_v2.W_value.weight.T)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "execution_count": 61,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "ddd0f54f-6bce-46cc-a428-17c2a56557d0",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "data": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "text/plain": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "tensor([[-0.5337, -0.1051],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "        [-0.5323, -0.1080],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "        [-0.5323, -0.1079],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "        [-0.5297, -0.1076],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "        [-0.5311, -0.1066],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "        [-0.5299, -0.1081]], grad_fn=<MmBackward0>)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "execution_count": 61,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "execute_result"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "sa_v1(inputs)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "execution_count": 62,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "340908f8-1144-4ddd-a9e1-a1c5c3d592f5",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "data": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "text/plain": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "tensor([[-0.5337, -0.1051],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "        [-0.5323, -0.1080],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "        [-0.5323, -0.1079],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "        [-0.5297, -0.1076],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "        [-0.5311, -0.1066],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								       "        [-0.5299, -0.1081]], grad_fn=<MmBackward0>)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "execution_count": 62,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "execute_result"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "sa_v2(inputs)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "33543edb-46b5-4b01-8704-f7f101230544",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "# Exercise 3.2"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "0588e209-1644-496a-8dae-7630b4ef9083",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "If we want to have an output dimension of 2, as earlier in single-head attention, we have to change the projection dimension `d_out` to 1:"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "18e748ef-3106-4e11-a781-b230b74a0cef",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "```python\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "torch.manual_seed(123)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "d_out = 1\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-04 07:27:41 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "mha = MultiHeadAttentionWrapper(d_in, d_out, context_length, 0.0, num_heads=2)\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-13 14:49:02 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "context_vecs = mha(batch)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(context_vecs)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(\"context_vecs.shape:\", context_vecs.shape)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "```"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "78234544-d989-4f71-ac28-85a7ec1e6b7b",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "```\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "tensor([[[-9.1476e-02,  3.4164e-02],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "         [-2.6796e-01, -1.3427e-03],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "         [-4.8421e-01, -4.8909e-02],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "         [-6.4808e-01, -1.0625e-01],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "         [-8.8380e-01, -1.7140e-01],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "         [-1.4744e+00, -3.4327e-01]],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        [[-9.1476e-02,  3.4164e-02],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "         [-2.6796e-01, -1.3427e-03],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "         [-4.8421e-01, -4.8909e-02],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "         [-6.4808e-01, -1.0625e-01],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "         [-8.8380e-01, -1.7140e-01],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "         [-1.4744e+00, -3.4327e-01]]], grad_fn=<CatBackward0>)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "context_vecs.shape: torch.Size([2, 6, 2])\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "```"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "92bdabcb-06cf-4576-b810-d883bbd313ba",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "# Exercise 3.3"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "84c9b963-d01f-46e6-96bf-8eb2a54c5e42",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "```python\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-04 07:27:41 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "context_length = 1024\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-13 14:49:02 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "d_in, d_out = 768, 768\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "num_heads = 12\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-04 07:27:41 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads)\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-13 14:49:02 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "```"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "375d5290-8e8b-4149-958e-1efb58a69191",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "Optionally, the number of parameters can be computed as follows:"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "6d7e603c-1658-4da9-9c0b-ef4bc72832b4",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "```python\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "def count_parameters(model):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "count_parameters(mha)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "```"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "51ba00bd-feb0-4424-84cb-7c2b1f908779",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "```\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "2360064  # (2.36 M)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "```"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "a56c1d47-9b95-4bd1-a517-580a6f779c52",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "The GPT-2 model has 117M parameters in total, but as we can see, most of its parameters are not in the multi-head attention module itself."
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  "kernelspec": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "display_name": "Python 3 (ipykernel)",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "language": "python",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "name": "python3"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  "language_info": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "codemirror_mode": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "name": "ipython",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "version": 3
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "file_extension": ".py",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "mimetype": "text/x-python",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "name": "python",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "nbconvert_exporter": "python",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "pygments_lexer": "ipython3",
							 
						 
					
						
							
								
									
										
										
										
											2024-05-27 07:32:45 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "version": "3.10.6"
							 
						 
					
						
							
								
									
										
										
										
											2024-01-13 14:49:02 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 "nbformat": 4,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 "nbformat_minor": 5
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								}