2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								{
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 "cells": [
							 
						 
					
						
							
								
									
										
										
										
											2024-03-19 09:26:26 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "1ae38945-39dd-45dc-ad4f-da7a4404241f",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "<font size=\"1\">\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "Supplementary code for \"Build a Large Language Model From Scratch\": <a href=\"https://www.manning.com/books/build-a-large-language-model-from-scratch\">https://www.manning.com/books/build-a-large-language-model-from-scratch</a> by <a href=\"https://sebastianraschka.com\">Sebastian Raschka</a><br>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "Code repository: <a href=\"https://github.com/rasbt/LLMs-from-scratch\">https://github.com/rasbt/LLMs-from-scratch</a>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "</font>"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "8bfa70ec-5c4c-40e8-b923-16f8167e3181",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-01-13 14:51:39 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "# Chapter 3: Coding Attention Mechanisms"
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2024-01-01 19:41:18 +01:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "c29bcbe8-a034-43a2-b557-997b03c9882d",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "Packages that are being used in this notebook:"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "execution_count": 1,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "e58f33e8-5dc9-4dd5-ab84-5a011fa11d92",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
									
										
										
										
											2024-04-18 05:56:23 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "torch version: 2.2.2\n"
							 
						 
					
						
							
								
									
										
										
										
											2024-01-01 19:41:18 +01:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "from importlib.metadata import version\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(\"torch version:\", version(\"torch\"))"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:29:06 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "a2a4474d-7c68-4846-8702-37906cf08197",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- This chapter covers attention mechanisms, the engine of LLMs:"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:08:38 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "02a11208-d9d3-44b1-8e0d-0c8414110b93",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch03_compressed/01.webp\" width=\"500px\">"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "50e020fd-9690-4343-80df-da96678bef5e",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch03_compressed/02.webp\" width=\"600px\">"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "ecc4dcee-34ea-4c05-9085-2f8887f70363",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "## 3.1 The problem with modeling long sequences"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "a55aa49c-36c2-48da-b1d9-70f416e46a6a",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:29:06 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "- No code in this section\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Translating a text word by word isn't feasible due to the differences in grammatical structures between the source and target languages:"
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:08:38 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "55c0c433-aa4b-491e-848a-54905ebb05ad",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch03_compressed/03.webp\" width=\"400px\">"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:29:06 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "db03c48a-3429-48ea-9d4a-2e53b0e516b1",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Prior to the introduction of transformer models, encoder-decoder RNNs were commonly used for machine translation tasks\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- In this setup, the encoder processes a sequence of tokens from the source language, using a hidden state—a kind of intermediate layer within the neural network—to generate a condensed representation of the entire input sequence:"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:08:38 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "03d8df2c-c1c2-4df0-9977-ade9713088b2",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:29:06 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch03_compressed/04.webp\" width=\"500px\">"
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:08:38 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "3602c585-b87a-41c7-a324-c5e8298849df",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "## 3.2 Capturing data dependencies with attention mechanisms"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "b6fde64c-6034-421d-81d9-8244932086ea",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:29:06 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "- No code in this section\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Through an attention mechanism, the text-generating decoder segment of the network is capable of selectively accessing all input tokens, implying that certain input tokens hold more significance than others in the generation of a specific output token:"
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:08:38 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "bc4f6293-8ab5-4aeb-a04c-50ee158485b1",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:29:06 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch03_compressed/05.webp\" width=\"500px\">"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "8044be1f-e6a2-4a1f-a6dd-e325d3bad05e",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Self-attention in transformers is a technique designed to enhance input representations by enabling each position in a sequence to engage with and determine the relevance of every other position within the same sequence"
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:08:38 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "6565dc9f-b1be-4c78-b503-42ccc743296c",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:29:06 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch03_compressed/06.webp\" width=\"300px\">"
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:08:38 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "5efe05ff-b441-408e-8d66-cde4eb3397e3",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "## 3.3 Attending to different parts of the input with self-attention"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "6d9af516-7c37-4400-ab53-34936d5495a9",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "### 3.3.1 A simple self-attention mechanism without trainable weights"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "d269e9f1-df11-4644-b575-df338cf46cdf",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:29:06 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "- This section explains a very simplified variant of self-attention, which does not contain any trainable weights\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- This is purely for illustration purposes and NOT the attention mechanism that is used in transformers\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- The next section, section 3.3.2, will extend this simple attention mechanism to implement the real self-attention mechanism\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Suppose we are given an input sequence $x^{(1)}$ to $x^{(T)}$\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "  - The input is a text (for example, a sentence like \"Your journey starts with one step\") that has already been converted into token embeddings as described in chapter 2\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "  - For instance, $x^{(1)}$ is a d-dimensional vector representing the word \"Your\", and so forth\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- **Goal:** compute context vectors $z^{(i)}$ for each input sequence element $x^{(i)}$ in $x^{(1)}$ to $x^{(T)}$ (where $z$ and $x$ have the same dimension)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    - A context vector $z^{(i)}$ is a weighted sum over the inputs $x^{(1)}$ to $x^{(T)}$\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    - The context vector is \"context\"-specific to a certain input\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "      - Instead of $x^{(i)}$ as a placeholder for an arbitrary input token, let's consider the second input, $x^{(2)}$\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "      - And to continue with a concrete example, instead of the placeholder $z^{(i)}$, we consider the second output context vector, $z^{(2)}$\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "      - The second context vector, $z^{(2)}$, is a weighted sum over all inputs $x^{(1)}$ to $x^{(T)}$ weighted with respect to the second input element, $x^{(2)}$\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "      - The attention weights are the weights that determine how much each of the input elements contributes to the weighted sum when computing $z^{(2)}$\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "      - In short, think of $z^{(2)}$ as a modified version of $x^{(2)}$ that also incorporates information about all other input elements that are relevant to a given task at hand"
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:08:38 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "fcc7c7a2-b6ab-478f-ae37-faa8eaa8049a",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch03_compressed/07.webp\" width=\"400px\">"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "ff856c58-8382-44c7-827f-798040e6e697",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-04-02 18:27:13 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "- By convention, the unnormalized attention weights are referred to as **\"attention scores\"** whereas the normalized attention scores, which sum to 1, are referred to as **\"attention weights\"**\n"
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "01b10344-128d-462a-823f-2178dff5fd58",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:29:06 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "- The code below walks through the figure above step by step\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "<br>\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:29:06 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "- **Step 1:** compute unnormalized attention scores $\\omega$\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "- Suppose we use the second input token as the query, that is, $q^{(2)} = x^{(2)}$; we then compute the unnormalized attention scores via dot products:\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-02-11 14:17:55 +08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "    - $\\omega_{21} = x^{(1)} q^{(2)\\top}$\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "    - $\\omega_{22} = x^{(2)} q^{(2)\\top}$\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    - $\\omega_{23} = x^{(3)} q^{(2)\\top}$\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    - ...\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    - $\\omega_{2T} = x^{(T)} q^{(2)\\top}$\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:29:06 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "- Above, $\\omega$ is the Greek letter \"omega\" used to symbolize the unnormalized attention scores\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    - The subscript \"21\" in $\\omega_{21}$ means that input sequence element 2 was used as a query against input sequence element 1"
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "35e55f7a-f2d0-4f24-858b-228e4fe88fb3",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Suppose we have the following input sentence that is already embedded in 3-dimensional vectors as described in chapter 2 (we use a very small embedding dimension here for illustration purposes, so that it fits onto the page without line breaks):"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-06 19:24:01 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 2,
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "id": "22b9556a-aaf8-4ab4-a5b4-973372b0b2c3",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "import torch\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "inputs = torch.tensor(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "  [[0.43, 0.15, 0.89], # Your     (x^1)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "   [0.55, 0.87, 0.66], # journey  (x^2)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "   [0.57, 0.85, 0.64], # starts   (x^3)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "   [0.22, 0.58, 0.33], # with     (x^4)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "   [0.77, 0.25, 0.10], # one      (x^5)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "   [0.05, 0.80, 0.55]] # step     (x^6)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    ")"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:29:06 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "299baef3-b1a8-49ba-bad4-f62c8a416d83",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- The primary objective of this section is to demonstrate how the context vector $z^{(2)}$\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "  is calculated using the second input element, $x^{(2)}$, as a query\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- The figure depicts the initial step in this process, which involves calculating the attention scores $\\omega$ between $x^{(2)}$\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "  and all other input elements through a dot product operation."
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:08:38 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "5cb3453a-58fa-42c4-b225-86850bc856f8",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch03_compressed/08.webp\" width=\"400px\">"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "77be52fb-82fd-4886-a4c8-f24a9c87af22",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- We use input sequence element 2, $x^{(2)}$, as an example to compute context vector $z^{(2)}$; later in this section, we will generalize this to compute all context vectors.\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- The first step is to compute the unnormalized attention scores by computing the dot product between the query $x^{(2)}$ and every input token (including $x^{(2)}$ itself):"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-06 19:24:01 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 3,
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "id": "6fb5b2f8-dd2c-4a6d-94ef-a0e9ad163951",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "query = inputs[1]  # 2nd input token is the query\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "attn_scores_2 = torch.empty(inputs.shape[0])\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "for i, x_i in enumerate(inputs):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    attn_scores_2[i] = torch.dot(x_i, query) # dot product (transpose not necessary here since they are 1-dim vectors)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(attn_scores_2)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "8df09ae0-199f-4b6f-81a0-2f70546684b8",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Side note: a dot product is essentially a shorthand for multiplying two vectors element-wise and summing the resulting products:"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-06 19:24:01 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 4,
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "id": "9842f39b-1654-410e-88bf-d1b899bf0241",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "tensor(0.9544)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "tensor(0.9544)\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "res = 0.\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "for idx, element in enumerate(inputs[0]):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    res += inputs[0][idx] * query[idx]\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(res)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(torch.dot(inputs[0], query))"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:08:38 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:29:06 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "7d444d76-e19e-4e9a-a268-f315d966609b",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:08:38 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:29:06 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "- **Step 2:** normalize the unnormalized attention scores (\"omegas\", $\\omega$) so that they sum up to 1\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Here is a simple way to normalize the unnormalized attention scores to sum up to 1 (a convention, useful for interpretation, and important for training stability):"
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:08:38 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:29:06 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "dfd965d6-980c-476a-93d8-9efe603b1b3b",
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:29:06 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch03_compressed/09.webp\" width=\"500px\">"
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-06 19:24:01 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 5,
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "id": "e3ccc99c-33ce-4f11-b7f2-353cf1cbdaba",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "Attention weights: tensor([0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656])\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "Sum: tensor(1.0000)\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "attn_weights_2_tmp = attn_scores_2 / attn_scores_2.sum()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(\"Attention weights:\", attn_weights_2_tmp)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(\"Sum:\", attn_weights_2_tmp.sum())"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "75dc0a57-f53e-41bf-8793-daa77a819431",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- However, in practice, using the softmax function for normalization, which is better at handling extreme values and has more desirable gradient properties during training, is common and recommended.\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Here's a naive implementation of a softmax function for scaling, which also normalizes the vector elements such that they sum up to 1:"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-06 19:24:01 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 6,
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "id": "07b2e58d-a6ed-49f0-a1cd-2463e8d53a20",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "Sum: tensor(1.)\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "def softmax_naive(x):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    return torch.exp(x) / torch.exp(x).sum(dim=0)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "attn_weights_2_naive = softmax_naive(attn_scores_2)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(\"Attention weights:\", attn_weights_2_naive)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(\"Sum:\", attn_weights_2_naive.sum())"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "f0a1cbbb-4744-41cb-8910-f5c1355555fb",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:29:06 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "- The naive implementation above can suffer from numerical instability for large or small input values due to overflow and underflow\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "- Hence, in practice, it's recommended to use the PyTorch implementation of softmax instead, which is numerically stable and has been highly optimized for performance:"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-06 19:24:01 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 7,
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "id": "2d99cac4-45ea-46b3-b3c1-e000ad16e158",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "Sum: tensor(1.)\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "attn_weights_2 = torch.softmax(attn_scores_2, dim=0)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(\"Attention weights:\", attn_weights_2)\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-12-26 22:05:21 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "print(\"Sum:\", attn_weights_2.sum())"
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "e43e36c7-90b2-427f-94f6-bb9d31b2ab3f",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- **Step 3**: compute the context vector $z^{(2)}$ by multiplying the embedded input tokens, $x^{(i)}$, with the attention weights and summing the resulting vectors:"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:08:38 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "f1c9f5ac-8d3d-4847-94e3-fd783b7d4d3d",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:29:06 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch03_compressed/10.webp\" width=\"500px\">"
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:08:38 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-06 19:24:01 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 8,
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "id": "8fcb96f0-14e5-4973-a50e-79ea7c6af99f",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "tensor([0.4419, 0.6515, 0.5683])\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "query = inputs[1] # 2nd input token is the query\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "context_vec_2 = torch.zeros(query.shape)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "for i,x_i in enumerate(inputs):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    context_vec_2 += attn_weights_2[i]*x_i\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(context_vec_2)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "5a454262-40eb-430e-9ca4-e43fb8d6cd89",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "### 3.3.2 Computing attention weights for all input tokens"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "6a02bb73-fc19-4c88-b155-8314de5d63a8",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "#### Generalize to all input sequence tokens:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:29:06 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "- Above, we computed the attention weights and context vector for input 2 (as illustrated in the highlighted row in the figure below)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Next, we are generalizing this computation to compute all attention weights and context vectors"
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "11c0fb55-394f-42f4-ba07-d01ae5c98ab4",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:08:38 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch03_compressed/11.webp\" width=\"400px\">"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:29:06 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "b789b990-fb51-4beb-9212-bf58876b5983",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- In self-attention, the process starts with the calculation of attention scores, which are subsequently normalized to derive attention weights that total 1\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- These attention weights are then utilized to generate the context vectors through a weighted summation of the inputs"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:08:38 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "d9bffe4b-56fe-4c37-9762-24bd924b7d3c",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch03_compressed/12.webp\" width=\"400px\">"
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "aa652506-f2c8-473c-a905-85c389c842cc",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Apply the previous **step 1** to all pairwise elements to compute the unnormalized attention score matrix:"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-06 19:24:01 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 9,
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "id": "04004be8-07a1-468b-ab33-32e16a551b45",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "attn_scores = torch.empty(6, 6)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "for i, x_i in enumerate(inputs):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    for j, x_j in enumerate(inputs):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        attn_scores[i, j] = torch.dot(x_i, x_j)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(attn_scores)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "1539187f-1ece-47b7-bc9b-65a97115f1d4",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- We can achieve the same as above more efficiently via matrix multiplication:"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-06 19:24:01 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 10,
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "id": "2cea69d0-9a47-45da-8d5a-47ceef2df673",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "attn_scores = inputs @ inputs.T\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(attn_scores)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "02c4bac4-acfd-427f-9b11-c436ac71748d",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Similar to **step 2** previously, we normalize each row so that the values in each row sum to 1:"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-06 19:24:01 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 11,
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "id": "fa4ef062-de81-47ee-8415-bfe1708c81b8",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-04-18 05:56:23 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "attn_weights = torch.softmax(attn_scores, dim=-1)\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "print(attn_weights)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "3fa6d02b-7f15-4eb4-83a7-0b8a819e7a0c",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Quick verification that the values in each row indeed sum to 1:"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-06 19:24:01 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 12,
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "id": "112b492c-fb6f-4e6d-8df5-518ae83363d5",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "Row 2 sum: 1.0\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "All row sums: tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "row_2_sum = sum([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(\"Row 2 sum:\", row_2_sum)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-18 05:56:23 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "print(\"All row sums:\", attn_weights.sum(dim=-1))"
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "138b0b5c-d813-44c7-b373-fde9540ddfd1",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Apply previous **step 3** to compute all context vectors:"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-06 19:24:01 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 13,
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "id": "ba8eafcf-f7f7-4989-b8dc-61b50c4f81dc",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "tensor([[0.4421, 0.5931, 0.5790],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [0.4419, 0.6515, 0.5683],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [0.4431, 0.6496, 0.5671],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [0.4304, 0.6298, 0.5510],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [0.4671, 0.5910, 0.5266],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [0.4177, 0.6503, 0.5645]])\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2023-12-26 22:41:54 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "all_context_vecs = attn_weights @ inputs\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "print(all_context_vecs)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "25b245b8-7732-4fab-aa1c-e3d333195605",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- As a sanity check, the previously computed context vector $z^{(2)} = [0.4419, 0.6515, 0.5683]$ can be found in the 2nd row above:"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-06 19:24:01 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 14,
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "id": "2570eb7d-aee1-457a-a61e-7544478219fa",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "Previous 2nd context vector: tensor([0.4419, 0.6515, 0.5683])\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(\"Previous 2nd context vector:\", context_vec_2)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "a303b6fb-9f7e-42bb-9fdb-2adabf0a6525",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "## 3.4 Implementing self-attention with trainable weights"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:29:06 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "88363117-93d8-41fb-8240-f7cfe08b14a3",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- A conceptual framework illustrating how the self-attention mechanism developed in this section integrates into the overall narrative and structure of this book and chapter"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:08:38 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "ac9492ba-6f66-4f65-bd1d-87cf16d59928",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch03_compressed/13.webp\" width=\"400px\">"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "2b90a77e-d746-4704-9354-1ddad86e6298",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "### 3.4.1 Computing the attention weights step by step"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "46e95a46-1f67-4b71-9e84-8e2db84ab036",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:29:06 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "- In this section, we are implementing the self-attention mechanism that is used in the original transformer architecture, the GPT models, and most other popular LLMs\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- This self-attention mechanism is also called \"scaled dot-product attention\"\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "- The overall idea is similar to before:\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:29:06 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "  - We want to compute context vectors as weighted sums over the input vectors specific to a certain input element\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "  - For the above, we need attention weights\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "- As you will see, there are only slight differences compared to the basic attention mechanism introduced earlier:\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:29:06 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "  - The most notable difference is the introduction of weight matrices that are updated during model training\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "  - These trainable weight matrices are crucial so that the model (specifically, the attention module inside the model) can learn to produce \"good\" context vectors"
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:08:38 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "59db4093-93e8-4bee-be8f-c8fac8a08cdd",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:29:06 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch03_compressed/14.webp\" width=\"600px\">"
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:08:38 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "4d996671-87aa-45c9-b2e0-07a7bcc9060a",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:29:06 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "- Implementing the self-attention mechanism step by step, we will start by introducing the three training weight matrices $W_q$, $W_k$, and $W_v$\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-10 08:01:19 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "- These three matrices are used to project the embedded input tokens, $x^{(i)}$, into query, key, and value vectors via matrix multiplication:\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "  - Query vector: $q^{(i)} = W_q \\,x^{(i)}$\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "  - Key vector: $k^{(i)} = W_k \\,x^{(i)}$\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "  - Value vector: $v^{(i)} = W_v \\,x^{(i)}$\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "9f334313-5fd0-477b-8728-04080a427049",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:29:06 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "- The embedding dimensions of the input $x$ and the query vector $q$ can be the same or different, depending on the model's design and specific implementation\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-10 08:01:19 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "- In GPT models, the input and output dimensions are usually the same, but for illustration purposes, to better follow the computation, we choose different input and output dimensions here:"
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-06 19:24:01 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 15,
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "id": "8250fdc6-6cd6-4c5b-b9c0-8c643aadb7db",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-01-10 08:01:19 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "x_2 = inputs[1] # second input element\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "d_in = inputs.shape[1] # the input embedding size, d=3\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "d_out = 2 # the output embedding size, d=2"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "f528cfb3-e226-47dd-b363-cc2caaeba4bf",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:29:06 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "- Below, we initialize the three weight matrices; note that we are setting `requires_grad=False` to reduce clutter in the outputs for illustration purposes, but if we were to use the weight matrices for model training, we would set `requires_grad=True` to update these matrices during model training"
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-06 19:24:01 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 16,
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "id": "bfd7259a-f26c-4cea-b8fc-282b5cae1e00",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "torch.manual_seed(123)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "W_key   = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "abfd0b50-7701-4adb-821c-e5433622d9c4",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Next we compute the query, key, and value vectors:"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-06 19:24:01 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 17,
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "id": "73cedd62-01e1-4196-a575-baecc6095601",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "tensor([0.4306, 1.4551])\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-01-10 08:01:19 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "query_2 = x_2 @ W_query # _2 because it's with respect to the 2nd input element\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "key_2 = x_2 @ W_key \n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "value_2 = x_2 @ W_value\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(query_2)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "9be308b3-aca3-421b-b182-19c3a03b71c7",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- As we can see below, we successfully projected the 6 input tokens from a 3D onto a 2D embedding space:"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-06 19:24:01 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 18,
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "id": "8c1c3949-fc08-4d19-a41e-1c235b4e631b",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "keys.shape: torch.Size([6, 2])\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "values.shape: torch.Size([6, 2])\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "keys = inputs @ W_key \n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "values = inputs @ W_value\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(\"keys.shape:\", keys.shape)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(\"values.shape:\", values.shape)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "bac5dfd6-ade8-4e7b-b0c1-bed40aa24481",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-02-12 18:32:28 +08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "- In the next step, **step 2**, we compute the unnormalized attention scores by computing the dot product between the query and each key vector:"
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "8ed0a2b7-5c50-4ede-90cf-7ad74412b3aa",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:29:06 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch03_compressed/15.webp\" width=\"600px\">"
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-06 19:24:01 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 19,
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "id": "64cbc253-a182-4490-a765-246979ea0a28",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "tensor(1.8524)\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "keys_2 = keys[1] # Python starts index at 0\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "attn_score_22 = query_2.dot(keys_2)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(attn_score_22)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "9e9d15c0-c24e-4e6f-a160-6349b418f935",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Since we have 6 inputs, we have 6 attention scores for the given query vector:"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-06 19:24:01 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 20,
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "id": "b14e44b5-d170-40f9-8847-8990804af26d",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "attn_scores_2 = query_2 @ keys.T # All attention scores for given query\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(attn_scores_2)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "8622cf39-155f-4eb5-a0c0-82a03ce9b999",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:29:06 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch03_compressed/16.webp\" width=\"600px\">"
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "e1609edb-f089-461a-8de2-c20c1bb29836",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:29:06 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "- Next, in **step 3**, we compute the attention weights (normalized attention scores that sum up to 1) using the softmax function we used earlier\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-14 11:58:42 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "- The difference from earlier is that we now scale the attention scores by dividing them by the square root of the embedding dimension of the keys, $\\sqrt{d_k}$ (i.e., `d_k**0.5`):"
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-06 19:24:01 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 21,
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "id": "146f5587-c845-4e30-9894-c7ed3a248153",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820])\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-01-14 11:58:42 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "d_k = keys.shape[1]\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "attn_weights_2 = torch.softmax(attn_scores_2 / d_k**0.5, dim=-1)\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "print(attn_weights_2)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "b8f61a28-b103-434a-aee1-ae7cbd821126",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:29:06 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch03_compressed/17.webp\" width=\"600px\">"
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "1890e3f9-db86-4ab8-9f3b-53113504a61f",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- In **step 4**, we now compute the context vector for input query vector 2:"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-06 19:24:01 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 22,
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "id": "e138f033-fa7e-4e3a-8764-b53a96b26397",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "tensor([0.3061, 0.8210])\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "context_vec_2 = attn_weights_2 @ values\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(context_vec_2)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "9d7b2907-e448-473e-b46c-77735a7281d8",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "### 3.4.2 Implementing a compact SelfAttention class"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "04313410-3155-4d90-a7a3-2f3386e73677",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Putting it all together, we can implement the self-attention mechanism as follows:"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-06 19:24:01 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 23,
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "id": "51590326-cdbe-4e62-93b1-17df71c11ee4",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "tensor([[0.2996, 0.8053],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [0.3061, 0.8210],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [0.3058, 0.8203],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [0.2948, 0.7939],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [0.2927, 0.7891],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [0.2990, 0.8040]], grad_fn=<MmBackward0>)\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "import torch.nn as nn\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "class SelfAttention_v1(nn.Module):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    def __init__(self, d_in, d_out):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        super().__init__()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.W_query = nn.Parameter(torch.rand(d_in, d_out))\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.W_key   = nn.Parameter(torch.rand(d_in, d_out))\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.W_value = nn.Parameter(torch.rand(d_in, d_out))\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    def forward(self, x):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        keys = x @ self.W_key\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        queries = x @ self.W_query\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        values = x @ self.W_value\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        \n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        attn_scores = queries @ keys.T # omega\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-05-18 12:08:39 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        attn_weights = torch.softmax(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            attn_scores / keys.shape[-1]**0.5, dim=-1\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        )\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        context_vec = attn_weights @ values\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        return context_vec\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "torch.manual_seed(123)\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-13 14:49:02 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "sa_v1 = SelfAttention_v1(d_in, d_out)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(sa_v1(inputs))"
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:08:38 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "7ee1a024-84a5-425a-9567-54ab4e4ed445",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch03_compressed/18.webp\" width=\"400px\">"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "048e0c16-d911-4ec8-b0bc-45ceec75c081",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:29:06 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "- We can streamline the implementation above using PyTorch's Linear layers, which are equivalent to a matrix multiplication if we disable the bias units\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Another big advantage of using `nn.Linear` over our manual `nn.Parameter(torch.rand(...)` approach is that `nn.Linear` has a preferred weight initialization scheme, which leads to more stable model training"
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-06 19:24:01 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 24,
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "id": "73f411e3-e231-464a-89fe-0a9035e5f839",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "tensor([[-0.0739,  0.0713],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [-0.0748,  0.0703],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [-0.0749,  0.0702],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [-0.0760,  0.0685],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [-0.0763,  0.0679],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "class SelfAttention_v2(nn.Module):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-17 07:50:57 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "    def __init__(self, d_in, d_out, qkv_bias=False):\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "        super().__init__()\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-17 07:50:57 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    def forward(self, x):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        keys = self.W_key(x)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        queries = self.W_query(x)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        values = self.W_value(x)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        \n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        attn_scores = queries @ keys.T\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-18 05:56:23 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        context_vec = attn_weights @ values\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        return context_vec\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "torch.manual_seed(789)\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-13 14:49:02 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "sa_v2 = SelfAttention_v2(d_in, d_out)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(sa_v2(inputs))"
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "915cd8a5-a895-42c9-8b8e-06b5ae19ffce",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:29:06 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "- Note that `SelfAttention_v1` and `SelfAttention_v2` give different outputs because they use different initial weights for the weight matrices"
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "c5025b37-0f2c-4a67-a7cb-1286af7026ab",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-01-13 14:49:02 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "## 3.5 Hiding future words with causal attention"
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:29:06 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "aef0a6b8-205a-45bf-9d26-8fd77a8a03c3",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-04-03 07:07:21 +08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "- In causal attention, the attention weights above the diagonal are masked, ensuring that for any given input, the LLM is unable to utilize future tokens while calculating the context vectors with the attention weight"
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:29:06 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:08:38 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "71e91bb5-5aae-4f05-8a95-973b3f988a35",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch03_compressed/19.webp\" width=\"400px\">"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "82f405de-cd86-4e72-8f3c-9ea0354946ba",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "### 3.5.1 Applying a causal attention mask"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "014f28d0-8218-48e4-8b9c-bdc5ce489218",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:29:06 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "- In this section, we are converting the previous self-attention mechanism into a causal self-attention mechanism\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Causal self-attention ensures that the model's prediction for a certain position in a sequence is only dependent on the known outputs at previous positions, not on future positions\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- In simpler words, this ensures that each next word prediction should only depend on the preceding words\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "- To achieve this, for each given token, we mask out the future tokens (the ones that come after the current token in the input text):"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:08:38 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "57f99af3-32bc-48f5-8eb4-63504670ca0a",
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:29:06 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch03_compressed/20.webp\" width=\"600px\">"
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "cbfaec7a-68f2-4157-a4b5-2aeceed199d9",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- To illustrate and implement causal self-attention, let's work with the attention scores and weights from the previous section: "
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-06 19:24:01 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 25,
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "id": "1933940d-0fa5-4b17-a3ce-388e5314a1bb",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
									
										
										
										
											2024-02-14 20:23:59 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "tensor([[0.1921, 0.1646, 0.1652, 0.1550, 0.1721, 0.1510],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [0.2041, 0.1659, 0.1662, 0.1496, 0.1665, 0.1477],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [0.2036, 0.1659, 0.1662, 0.1498, 0.1664, 0.1480],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [0.1869, 0.1667, 0.1668, 0.1571, 0.1661, 0.1564],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.1585],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "       grad_fn=<SoftmaxBackward0>)\n"
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-02-14 20:23:59 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "# Reuse the query and key weight matrices of the\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "# SelfAttention_v2 object from the previous section for convenience\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "queries = sa_v2.W_query(inputs)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "keys = sa_v2.W_key(inputs) \n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "attn_scores = queries @ keys.T\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-18 05:56:23 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "print(attn_weights)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "89020a96-b34d-41f8-9349-98c3e23fd5d6",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- The simplest way to mask out future attention weights is by creating a mask via PyTorch's tril function with elements below the main diagonal (including the diagonal itself) set to 1 and above the main diagonal set to 0:"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-13 14:49:02 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 26,
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "id": "43f3d2e3-185b-4184-9f98-edde5e6df746",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "tensor([[1., 0., 0., 0., 0., 0.],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [1., 1., 0., 0., 0., 0.],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [1., 1., 1., 0., 0., 0.],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [1., 1., 1., 1., 0., 0.],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [1., 1., 1., 1., 1., 0.],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [1., 1., 1., 1., 1., 1.]])\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-04-04 07:27:41 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "context_length = attn_scores.shape[0]\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "mask_simple = torch.tril(torch.ones(context_length, context_length))\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-06 19:24:01 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "print(mask_simple)"
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "efce2b08-3583-44da-b3fc-cabdd38761f6",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Then, we can multiply the attention weights with this mask to zero out the attention weights above the diagonal:"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-13 14:49:02 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 27,
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "id": "9f531e2e-f4d2-4fea-a87f-4c132e48b9e7",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
									
										
										
										
											2024-02-14 20:23:59 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "tensor([[0.1921, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [0.2041, 0.1659, 0.0000, 0.0000, 0.0000, 0.0000],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [0.2036, 0.1659, 0.1662, 0.0000, 0.0000, 0.0000],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [0.1869, 0.1667, 0.1668, 0.1571, 0.0000, 0.0000],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.0000],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "       grad_fn=<MulBackward0>)\n"
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-01-06 19:24:01 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "masked_simple = attn_weights*mask_simple\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(masked_simple)"
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "3eb35787-cf12-4024-b66d-e7215e175500",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:29:06 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "- However, if the mask were applied after softmax, like above, it would disrupt the probability distribution created by softmax\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Softmax ensures that all output values sum to 1\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Masking after softmax would require re-normalizing the outputs to sum to 1 again, which complicates the process and might lead to unintended effects"
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2024-01-06 19:24:01 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "94db92d7-c397-4e42-bd8a-6a2b3e237e0f",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- To make sure that the rows sum to 1, we can normalize the attention weights as follows:"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-13 14:49:02 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 28,
							 
						 
					
						
							
								
									
										
										
										
											2024-01-06 19:24:01 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "id": "6d392083-fd81-4f70-9bdf-8db985e673d6",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-02-14 20:23:59 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "       grad_fn=<DivBackward0>)\n"
							 
						 
					
						
							
								
									
										
										
										
											2024-01-06 19:24:01 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "row_sums = masked_simple.sum(dim=1, keepdim=True)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "masked_simple_norm = masked_simple / row_sums\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(masked_simple_norm)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "512e7cf4-dc0e-4cec-948e-c7a3c4eb6877",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:29:06 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "- While we are technically done with coding the causal attention mechanism now, let's briefly look at a more efficient approach to achieve the same result as above\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-06 19:24:01 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "- So, instead of zeroing out attention weights above the diagonal and renormalizing the results, we can mask the unnormalized attention scores above the diagonal with negative infinity before they enter the softmax function:"
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:08:38 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "eb682900-8df2-4767-946c-a82bee260188",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:29:06 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch03_compressed/21.webp\" width=\"450px\">"
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:08:38 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-13 14:49:02 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 29,
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "id": "a2be2f43-9cf0-44f6-8d8b-68ef2fb3cc39",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
									
										
										
										
											2024-02-14 20:23:59 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "tensor([[0.2899,   -inf,   -inf,   -inf,   -inf,   -inf],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [0.4656, 0.1723,   -inf,   -inf,   -inf,   -inf],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [0.4594, 0.1703, 0.1731,   -inf,   -inf,   -inf],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [0.2642, 0.1024, 0.1036, 0.0186,   -inf,   -inf],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [0.2183, 0.0874, 0.0882, 0.0177, 0.0786,   -inf],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [0.3408, 0.1270, 0.1290, 0.0198, 0.1290, 0.0078]],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "       grad_fn=<MaskedFillBackward0>)\n"
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-04-04 07:27:41 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "masked = attn_scores.masked_fill(mask.bool(), -torch.inf)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(masked)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "91d5f803-d735-4543-b9da-00ac10fb9c50",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- As we can see below, now the attention weights in each row correctly sum to 1 again:"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-13 14:49:02 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 30,
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "id": "b1cd6d7f-16f2-43c1-915e-0824f1a4bc52",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-02-14 20:23:59 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "       grad_fn=<SoftmaxBackward0>)\n"
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-04-18 05:56:23 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "attn_weights = torch.softmax(masked / keys.shape[-1]**0.5, dim=-1)\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "print(attn_weights)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "7636fc5f-6bc6-461e-ac6a-99ec8e3c0912",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "### 3.5.2 Masking additional attention weights with dropout"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "ec3dc7ee-6539-4fab-804a-8f31a890c85a",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:29:06 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "- In addition, we also apply dropout to reduce overfitting during training\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "- Dropout can be applied in several places:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "  - for example, after computing the attention weights;\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:29:06 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "  - or after multiplying the attention weights with the value vectors\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Here, we will apply the dropout mask after computing the attention weights because it's more common\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:29:06 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "- Furthermore, in this specific example, we use a dropout rate of 50%, which means randomly masking out half of the attention weights. (When we train the GPT model later, we will use a lower dropout rate, such as 0.1 or 0.2.)"
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "ee799cf6-6175-45f2-827e-c174afedb722",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:29:06 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch03_compressed/22.webp\" width=\"400px\">"
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "5a575458-a6da-4e54-8688-83e155f2de06",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- If we apply a dropout rate of 0.5 (50%), the non-dropped values will be scaled accordingly by a factor of 1/(1 - 0.5) = 2."
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-13 14:49:02 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 31,
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "id": "0de578db-8289-41d6-b377-ef645751e33f",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "tensor([[2., 2., 0., 2., 2., 0.],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [0., 0., 0., 2., 0., 2.],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [2., 2., 2., 2., 0., 2.],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [0., 2., 2., 0., 0., 2.],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [0., 2., 0., 2., 0., 2.],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [0., 2., 2., 2., 2., 0.]])\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "torch.manual_seed(123)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "dropout = torch.nn.Dropout(0.5) # dropout rate of 50%\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "example = torch.ones(6, 6) # create a matrix of ones\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(dropout(example))"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-13 14:49:02 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 32,
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "id": "b16c5edb-942b-458c-8e95-25e4e355381e",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-02-14 20:23:59 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "        [0.7599, 0.6194, 0.6206, 0.0000, 0.0000, 0.0000],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [0.0000, 0.4921, 0.4925, 0.0000, 0.0000, 0.0000],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [0.0000, 0.3966, 0.0000, 0.3775, 0.0000, 0.0000],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [0.0000, 0.3327, 0.3331, 0.3084, 0.3331, 0.0000]],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "       grad_fn=<MulBackward0>)\n"
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "torch.manual_seed(123)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(dropout(attn_weights))"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2024-03-10 07:58:10 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "269df5c8-3e25-49d0-95d3-bb232287404f",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Note that the resulting dropout outputs may look different depending on your operating system; you can read more about this inconsistency [here on the PyTorch issue tracker](https://github.com/pytorch/pytorch/issues/121595)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2024-01-13 14:49:02 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "cdc14639-5f0f-4840-aa9d-8eb36ea90fb7",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "### 3.5.3 Implementing a compact causal self-attention class"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "09c41d29-1933-43dc-ada6-2dbb56287204",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:29:06 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "- Now, we are ready to build a working implementation of self-attention, including the causal and dropout masks\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- One remaining step is to handle batches consisting of more than one input, so that our `CausalAttention` class supports the batch outputs produced by the data loader we implemented in chapter 2\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "- For simplicity, to simulate such batch input, we duplicate the input text example:"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-13 14:49:02 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 33,
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "id": "977a5fa7-a9d5-4e2e-8a32-8e0331ccfe28",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "torch.Size([2, 6, 3])\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "batch = torch.stack((inputs, inputs), dim=0)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(batch.shape) # 2 inputs with 6 tokens each, and each token has embedding dimension 3"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-13 14:49:02 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 34,
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "id": "60d8c2eb-2d8e-4d2c-99bc-9eef8cc53ca0",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
									
										
										
										
											2024-03-04 18:54:43 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "tensor([[[-0.4519,  0.2216],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "         [-0.5874,  0.0058],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "         [-0.6300, -0.0632],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "         [-0.5675, -0.0843],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "         [-0.5526, -0.0981],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "         [-0.5299, -0.1081]],\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								      "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-04 18:54:43 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "        [[-0.4519,  0.2216],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "         [-0.5874,  0.0058],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "         [-0.6300, -0.0632],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "         [-0.5675, -0.0843],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "         [-0.5526, -0.0981],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "         [-0.5299, -0.1081]]], grad_fn=<UnsafeViewBackward0>)\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								      "context_vecs.shape: torch.Size([2, 6, 2])\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-01-13 14:49:02 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "class CausalAttention(nn.Module):\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-05-18 12:08:39 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "    def __init__(self, d_in, d_out, context_length,\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                 dropout, qkv_bias=False):\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "        super().__init__()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.d_out = d_out\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-17 07:50:57 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "        self.dropout = nn.Dropout(dropout) # New\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-04 07:27:41 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1)) # New\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    def forward(self, x):\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-14 11:58:42 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        b, num_tokens, d_in = x.shape # New batch dimension b\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "        keys = self.W_key(x)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        queries = self.W_query(x)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        values = self.W_value(x)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        attn_scores = queries @ keys.transpose(1, 2) # Changed transpose\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        attn_scores.masked_fill_(  # New, _ ops are in-place\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-14 11:58:42 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf) \n",
							 
						 
					
						
							
								
									
										
										
										
											2024-05-18 12:08:39 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        attn_weights = torch.softmax(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            attn_scores / keys.shape[-1]**0.5, dim=-1\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        )\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "        attn_weights = self.dropout(attn_weights) # New\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        context_vec = attn_weights @ values\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        return context_vec\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "torch.manual_seed(123)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-04 07:27:41 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "context_length = batch.shape[1]\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "ca = CausalAttention(d_in, d_out, context_length, 0.0)\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-13 14:49:02 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "context_vecs = ca(batch)\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(context_vecs)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(\"context_vecs.shape:\", context_vecs.shape)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "c4333d12-17e4-4bb5-9d83-54b3a32618cd",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:29:06 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "- Note that dropout is only applied during training, not during inference"
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:08:38 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "a554cf47-558c-4f45-84cd-bf9b839a8d50",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:29:06 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch03_compressed/23.webp\" width=\"500px\">"
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:08:38 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "c8bef90f-cfd4-4289-b0e8-6a00dc9be44c",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "## 3.6 Extending single-head attention to multi-head attention"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "11697757-9198-4a1c-9cee-f450d8bbd3b9",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "### 3.6.1 Stacking multiple single-head attention layers"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "70766faf-cd53-41d9-8a17-f1b229756a5a",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:29:06 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "- Below is a summary of the self-attention implemented previously (causal and dropout masks not shown for simplicity)\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- This is also called single-head attention:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:08:38 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch03_compressed/24.webp\" width=\"400px\">\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- We simply stack multiple single-head attention modules to obtain a multi-head attention module:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:08:38 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch03_compressed/25.webp\" width=\"400px\">\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- The main idea behind multi-head attention is to run the attention mechanism multiple times (in parallel) with different, learned linear projections. This allows the model to jointly attend to information from different representation subspaces at different positions."
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-14 11:58:42 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 35,
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "id": "b9a66e11-7105-4bb4-be84-041f1a1f3bd2",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
									
										
										
										
											2024-03-04 18:54:43 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "tensor([[[-0.4519,  0.2216,  0.4772,  0.1063],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "         [-0.5874,  0.0058,  0.5891,  0.3257],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "         [-0.6300, -0.0632,  0.6202,  0.3860],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "         [-0.5675, -0.0843,  0.5478,  0.3589],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "         [-0.5526, -0.0981,  0.5321,  0.3428],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "         [-0.5299, -0.1081,  0.5077,  0.3493]],\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								      "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-04 18:54:43 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								      "        [[-0.4519,  0.2216,  0.4772,  0.1063],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "         [-0.5874,  0.0058,  0.5891,  0.3257],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "         [-0.6300, -0.0632,  0.6202,  0.3860],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "         [-0.5675, -0.0843,  0.5478,  0.3589],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "         [-0.5526, -0.0981,  0.5321,  0.3428],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "         [-0.5299, -0.1081,  0.5077,  0.3493]]], grad_fn=<CatBackward0>)\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								      "context_vecs.shape: torch.Size([2, 6, 4])\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "class MultiHeadAttentionWrapper(nn.Module):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-04 07:27:41 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "        super().__init__()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.heads = nn.ModuleList(\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-04 07:27:41 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "            [CausalAttention(d_in, d_out, context_length, dropout, qkv_bias) \n",
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "             for _ in range(num_heads)]\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        )\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    def forward(self, x):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        return torch.cat([head(x) for head in self.heads], dim=-1)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "torch.manual_seed(123)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-04 07:27:41 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "context_length = batch.shape[1] # This is the number of tokens\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-13 14:49:02 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "d_in, d_out = 3, 2\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-05-18 12:08:39 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "mha = MultiHeadAttentionWrapper(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    d_in, d_out, context_length, 0.0, num_heads=2\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    ")\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "context_vecs = mha(batch)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(context_vecs)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(\"context_vecs.shape:\", context_vecs.shape)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "193d3d2b-2578-40ba-b791-ea2d49328e48",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:29:06 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "- In the implementation above, the embedding dimension is 4, because we `d_out=2` as the embedding dimension for the key, query, and value vectors as well as the context vector. And since we have 2 attention heads, we have the output embedding dimension 2*2=4"
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "6836b5da-ef82-4b4c-bda1-72a462e48d4e",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "### 3.6.2 Implementing multi-head attention with weight splits"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "f4b48d0d-71ba-4fa0-b714-ca80cabcb6f7",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:29:06 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "- While the above is an intuitive and fully functional implementation of multi-head attention (wrapping the single-head attention `CausalAttention` implementation from earlier), we can write a stand-alone class called `MultiHeadAttention` to achieve the same\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:29:06 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "- We don't concatenate single attention heads for this stand-alone `MultiHeadAttention` class\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Instead, we create single W_query, W_key, and W_value weight matrices and then split those into individual matrices for each attention head:"
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-18 05:56:23 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 36,
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "id": "110b0188-6e9e-4e56-a988-10523c6c8538",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "tensor([[[0.3190, 0.4858],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "         [0.2943, 0.3897],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "         [0.2856, 0.3593],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "         [0.2693, 0.3873],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "         [0.2639, 0.3928],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "         [0.2575, 0.4028]],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [[0.3190, 0.4858],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "         [0.2943, 0.3897],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "         [0.2856, 0.3593],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "         [0.2693, 0.3873],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "         [0.2639, 0.3928],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "         [0.2575, 0.4028]]], grad_fn=<ViewBackward0>)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "context_vecs.shape: torch.Size([2, 6, 2])\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "class MultiHeadAttention(nn.Module):\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-04 07:27:41 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "        super().__init__()\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-05-18 12:08:39 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        assert (d_out % num_heads == 0), \\\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            \"d_out must be divisible by num_heads\"\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.d_out = d_out\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.num_heads = num_heads\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-17 07:50:57 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "        self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        self.dropout = nn.Dropout(dropout)\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-05-18 12:08:39 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        self.register_buffer(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            \"mask\",\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            torch.triu(torch.ones(context_length, context_length),\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                       diagonal=1)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        )\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    def forward(self, x):\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-13 14:49:02 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        b, num_tokens, d_in = x.shape\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-13 14:49:02 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        keys = self.W_key(x) # Shape: (b, num_tokens, d_out)\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "        queries = self.W_query(x)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        values = self.W_value(x)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-13 14:49:02 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        # We implicitly split the matrix by adding a `num_heads` dimension\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim) \n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        values = values.view(b, num_tokens, self.num_heads, self.head_dim)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-13 14:49:02 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "        keys = keys.transpose(1, 2)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        queries = queries.transpose(1, 2)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        values = values.transpose(1, 2)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-13 14:49:02 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        # Compute scaled dot-product attention (aka self-attention) with a causal mask\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-09 17:42:25 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-15 07:36:19 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        # Original mask truncated to the number of tokens and converted to boolean\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-09 17:42:25 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # Use the mask to fill attention scores\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        attn_scores.masked_fill_(mask_bool, -torch.inf)\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-15 07:36:19 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        \n",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-14 11:58:42 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "        attn_weights = self.dropout(attn_weights)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-15 07:36:19 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        # Shape: (b, num_tokens, num_heads, head_dim)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        context_vec = (attn_weights @ values).transpose(1, 2) \n",
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "        \n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        # Combine heads, where self.d_out = self.num_heads * self.head_dim\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-01-13 14:49:02 -06:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "        context_vec = self.out_proj(context_vec) # optional projection\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        return context_vec\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "torch.manual_seed(123)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-04 07:27:41 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "batch_size, context_length, d_in = batch.shape\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "d_out = 2\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-04 07:27:41 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2)\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "context_vecs = mha(batch)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(context_vecs)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(\"context_vecs.shape:\", context_vecs.shape)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "d334dfb5-2b6c-4c33-82d5-b4e9db5867bb",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:29:06 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "- Note that the above is essentially a rewritten version of `MultiHeadAttentionWrapper` that is more efficient\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- The resulting output looks a bit different since the random weight initializations differ, but both are fully functional implementations that can be used in the GPT class we will implement in the upcoming chapters\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "- Note that in addition, we added a linear projection layer (`self.out_proj `) to the `MultiHeadAttention` class above. This is simply a linear transformation that doesn't change the dimensions. It's a standard convention to use such a projection layer in LLM implementation, but it's not strictly necessary (recent research has shown that it can be removed without affecting the modeling performance; see the further reading section at the end of this chapter)\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:08:38 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "dbe5d396-c990-45dc-9908-2c621461f851",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch03_compressed/26.webp\" width=\"400px\">"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "8b0ed78c-e8ac-4f8f-a479-a98242ae8f65",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:29:06 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "- Note that if you are interested in a compact and efficient implementation of the above, you can also consider the [`torch.nn.MultiheadAttention`](https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html) class in PyTorch"
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "363701ad-2022-46c8-9972-390d2a2b9911",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Since the above implementation may look a bit complex at first glance, let's look at what happens when executing `attn_scores = queries @ keys.transpose(2, 3)`:"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-18 05:56:23 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 37,
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "id": "e8cfc1ae-78ab-4faa-bc73-98bd054806c9",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "tensor([[[[1.3208, 1.1631, 1.2879],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "          [1.1631, 2.2150, 1.8424],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "          [1.2879, 1.8424, 2.0402]],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "         [[0.4391, 0.7003, 0.5903],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "          [0.7003, 1.3737, 1.0620],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "          [0.5903, 1.0620, 0.9912]]]])\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "# (b, num_heads, num_tokens, head_dim) = (1, 2, 3, 4)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "a = torch.tensor([[[[0.2745, 0.6584, 0.2775, 0.8573],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                    [0.8993, 0.0390, 0.9268, 0.7388],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                    [0.7179, 0.7058, 0.9156, 0.4340]],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                   [[0.0772, 0.3565, 0.1479, 0.5331],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                    [0.4066, 0.2318, 0.4545, 0.9737],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                    [0.4606, 0.5159, 0.4220, 0.5786]]]])\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(a @ a.transpose(2, 3))"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "0587b946-c8f2-4888-adbf-5a5032fbfd7b",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:29:06 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "- In this case, the matrix multiplication implementation in PyTorch handles the 4-dimensional input tensor so that the matrix multiplication is carried out between the last 2 dimensions (num_tokens, head_dim) and then repeated for the individual heads\n",
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-03-15 05:00:28 +08:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "- For instance, the above becomes a more compact way to compute the matrix multiplication that the following code carries out for each head separately:"
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "code",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-18 05:56:23 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "execution_count": 38,
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   "id": "053760f1-1a02-42f0-b3bf-3d939e407039",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "outputs": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "name": "stdout",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "output_type": "stream",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     "text": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "First head:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      " tensor([[1.3208, 1.1631, 1.2879],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [1.1631, 2.2150, 1.8424],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [1.2879, 1.8424, 2.0402]])\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "Second head:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      " tensor([[0.4391, 0.7003, 0.5903],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [0.7003, 1.3737, 1.0620],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "        [0.5903, 1.0620, 0.9912]])\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "first_head = a[0, 0, :, :]\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "first_res = first_head @ first_head.T\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(\"First head:\\n\", first_res)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "second_head = a[0, 1, :, :]\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "second_res = second_head @ second_head.T\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(\"\\nSecond head:\\n\", second_res)"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "dec671bf-7938-4304-ad1e-75d9920e7f43",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "# Summary and takeaways"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "id": "fa3e4113-ffca-432c-b3ec-7a50bd15da25",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {},
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
									
										
										
										
											2024-03-17 09:29:06 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "- See the [./multihead-attention.ipynb](./multihead-attention.ipynb) code notebook, which is a concise version of the data loader (chapter 2) plus the multi-head attention class that we implemented in this chapter and will need for training the GPT model in upcoming chapters"
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  "kernelspec": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "display_name": "Python 3 (ipykernel)",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "language": "python",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "name": "python3"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  "language_info": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "codemirror_mode": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "name": "ipython",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "version": 3
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "file_extension": ".py",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "mimetype": "text/x-python",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "name": "python",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "nbconvert_exporter": "python",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "pygments_lexer": "ipython3",
							 
						 
					
						
							
								
									
										
										
										
											2024-04-18 05:56:23 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "version": "3.11.4"
							 
						 
					
						
							
								
									
										
										
										
											2023-12-09 17:13:56 -06:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 "nbformat": 4,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 "nbformat_minor": 5
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								}