mirror of
				https://github.com/rasbt/LLMs-from-scratch.git
				synced 2025-10-25 23:11:23 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			2036 lines
		
	
	
		
			68 KiB
		
	
	
	
		
			Plaintext
		
	
	
	
	
	
			
		
		
	
	
		
	
	
	
	
	
| {
 | |
|  "cells": [
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "1ae38945-39dd-45dc-ad4f-da7a4404241f",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "<table style=\"width:100%\">\n",
 | |
|     "<tr>\n",
 | |
|     "<td style=\"vertical-align:middle; text-align:left;\">\n",
 | |
|     "<font size=\"2\">\n",
 | |
|     "Supplementary code for the <a href=\"http://mng.bz/orYv\">Build a Large Language Model From Scratch</a> book by <a href=\"https://sebastianraschka.com\">Sebastian Raschka</a><br>\n",
 | |
|     "<br>Code repository: <a href=\"https://github.com/rasbt/LLMs-from-scratch\">https://github.com/rasbt/LLMs-from-scratch</a>\n",
 | |
|     "</font>\n",
 | |
|     "</td>\n",
 | |
|     "<td style=\"vertical-align:middle; text-align:left;\">\n",
 | |
|     "<a href=\"http://mng.bz/orYv\"><img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/cover-small.webp\" width=\"100px\"></a>\n",
 | |
|     "</td>\n",
 | |
|     "</tr>\n",
 | |
|     "</table>\n"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "8bfa70ec-5c4c-40e8-b923-16f8167e3181",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "# Chapter 3: Coding Attention Mechanisms"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "c29bcbe8-a034-43a2-b557-997b03c9882d",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "Packages that are being used in this notebook:"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "code",
 | |
|    "execution_count": 1,
 | |
|    "id": "e58f33e8-5dc9-4dd5-ab84-5a011fa11d92",
 | |
|    "metadata": {},
 | |
|    "outputs": [
 | |
|     {
 | |
|      "name": "stdout",
 | |
|      "output_type": "stream",
 | |
|      "text": [
 | |
|       "torch version: 2.2.2\n"
 | |
|      ]
 | |
|     }
 | |
|    ],
 | |
|    "source": [
 | |
|     "from importlib.metadata import version\n",
 | |
|     "\n",
 | |
|     "print(\"torch version:\", version(\"torch\"))"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "a2a4474d-7c68-4846-8702-37906cf08197",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "- This chapter covers attention mechanisms, the engine of LLMs:"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "02a11208-d9d3-44b1-8e0d-0c8414110b93",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch03_compressed/01.webp?123\" width=\"500px\">"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "50e020fd-9690-4343-80df-da96678bef5e",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch03_compressed/02.webp\" width=\"600px\">"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "ecc4dcee-34ea-4c05-9085-2f8887f70363",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "## 3.1 The problem with modeling long sequences"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "a55aa49c-36c2-48da-b1d9-70f416e46a6a",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "- No code in this section\n",
 | |
|     "- Translating a text word by word isn't feasible due to the differences in grammatical structures between the source and target languages:"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "55c0c433-aa4b-491e-848a-54905ebb05ad",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch03_compressed/03.webp\" width=\"400px\">"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "db03c48a-3429-48ea-9d4a-2e53b0e516b1",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "- Prior to the introduction of transformer models, encoder-decoder RNNs were commonly used for machine translation tasks\n",
 | |
|     "- In this setup, the encoder processes a sequence of tokens from the source language and uses its hidden state (an internal vector that the RNN updates after reading each token) to build a condensed representation of the entire input sequence:"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "03d8df2c-c1c2-4df0-9977-ade9713088b2",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch03_compressed/04.webp\" width=\"500px\">"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "3602c585-b87a-41c7-a324-c5e8298849df",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "## 3.2 Capturing data dependencies with attention mechanisms"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "b6fde64c-6034-421d-81d9-8244932086ea",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "- No code in this section\n",
 | |
|     "- Through an attention mechanism, the text-generating decoder part of the network can selectively access all input tokens, which means that some input tokens can be given more weight than others when generating a specific output token:"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "bc4f6293-8ab5-4aeb-a04c-50ee158485b1",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch03_compressed/05.webp\" width=\"500px\">"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "8044be1f-e6a2-4a1f-a6dd-e325d3bad05e",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "- Self-attention in transformers is a technique for enhancing input representations by letting each position in a sequence attend to, and weigh the relevance of, every other position in the same sequence"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "6565dc9f-b1be-4c78-b503-42ccc743296c",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch03_compressed/06.webp\" width=\"300px\">"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "5efe05ff-b441-408e-8d66-cde4eb3397e3",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "## 3.3 Attending to different parts of the input with self-attention"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "6d9af516-7c37-4400-ab53-34936d5495a9",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "### 3.3.1 A simple self-attention mechanism without trainable weights"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "d269e9f1-df11-4644-b575-df338cf46cdf",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "- This section explains a very simplified variant of self-attention, which does not contain any trainable weights\n",
 | |
|     "- This is purely for illustration purposes and NOT the attention mechanism that is used in transformers\n",
 | |
|     "- Section 3.4 will extend this simple attention mechanism to implement the real self-attention mechanism with trainable weights\n",
 | |
|     "- Suppose we are given an input sequence $x^{(1)}$ to $x^{(T)}$\n",
 | |
|     "  - The input is a text (for example, a sentence like \"Your journey starts with one step\") that has already been converted into token embeddings as described in chapter 2\n",
 | |
|     "  - For instance, $x^{(1)}$ is a d-dimensional vector representing the word \"Your\", and so forth\n",
 | |
|     "- **Goal:** compute context vectors $z^{(i)}$ for each input sequence element $x^{(i)}$ in $x^{(1)}$ to $x^{(T)}$ (where $z$ and $x$ have the same dimension)\n",
 | |
|     "    - A context vector $z^{(i)}$ is a weighted sum over the inputs $x^{(1)}$ to $x^{(T)}$\n",
 | |
|     "    - Each context vector is specific to one input element\n",
 | |
|     "      - Instead of $x^{(i)}$ as a placeholder for an arbitrary input token, let's consider the second input, $x^{(2)}$\n",
 | |
|     "      - And to continue with a concrete example, instead of the placeholder $z^{(i)}$, we consider the second output context vector, $z^{(2)}$\n",
 | |
|     "      - The second context vector, $z^{(2)}$, is a weighted sum over all inputs $x^{(1)}$ to $x^{(T)}$ weighted with respect to the second input element, $x^{(2)}$\n",
 | |
|     "      - The attention weights are the weights that determine how much each of the input elements contributes to the weighted sum when computing $z^{(2)}$\n",
 | |
|     "      - In short, think of $z^{(2)}$ as a modified version of $x^{(2)}$ that also incorporates information from all other input elements that are relevant to the task at hand"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "fcc7c7a2-b6ab-478f-ae37-faa8eaa8049a",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch03_compressed/07.webp\" width=\"400px\">\n",
 | |
|     "\n",
 | |
|     "- (Please note that the numbers in this figure are truncated to one\n",
 | |
|     "digit after the decimal point to reduce visual clutter; similarly, other figures may also contain truncated values)"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "ff856c58-8382-44c7-827f-798040e6e697",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "- By convention, the unnormalized attention weights are referred to as **\"attention scores\"** whereas the normalized attention scores, which sum to 1, are referred to as **\"attention weights\"**\n"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "01b10344-128d-462a-823f-2178dff5fd58",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "- The code below walks through the figure above step by step\n",
 | |
|     "\n",
 | |
|     "<br>\n",
 | |
|     "\n",
 | |
|     "- **Step 1:** compute unnormalized attention scores $\\omega$\n",
 | |
|     "- Suppose we use the second input token as the query, that is, $q^{(2)} = x^{(2)}$; we then compute the unnormalized attention scores via dot products:\n",
 | |
|     "    - $\\omega_{21} = x^{(1)} q^{(2)\\top}$\n",
 | |
|     "    - $\\omega_{22} = x^{(2)} q^{(2)\\top}$\n",
 | |
|     "    - $\\omega_{23} = x^{(3)} q^{(2)\\top}$\n",
 | |
|     "    - ...\n",
 | |
|     "    - $\\omega_{2T} = x^{(T)} q^{(2)\\top}$\n",
 | |
|     "- Above, $\\omega$ is the Greek letter \"omega\" used to symbolize the unnormalized attention scores\n",
 | |
|     "    - The subscript \"21\" in $\\omega_{21}$ means that input sequence element 2 was used as a query against input sequence element 1"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "35e55f7a-f2d0-4f24-858b-228e4fe88fb3",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "- Suppose we have the following input sentence that is already embedded in 3-dimensional vectors as described in chapter 2 (we use a very small embedding dimension here for illustration purposes, so that it fits onto the page without line breaks):"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "code",
 | |
|    "execution_count": 2,
 | |
|    "id": "22b9556a-aaf8-4ab4-a5b4-973372b0b2c3",
 | |
|    "metadata": {},
 | |
|    "outputs": [],
 | |
|    "source": [
 | |
|     "import torch\n",
 | |
|     "\n",
 | |
|     "inputs = torch.tensor(\n",
 | |
|     "  [[0.43, 0.15, 0.89], # Your     (x^1)\n",
 | |
|     "   [0.55, 0.87, 0.66], # journey  (x^2)\n",
 | |
|     "   [0.57, 0.85, 0.64], # starts   (x^3)\n",
 | |
|     "   [0.22, 0.58, 0.33], # with     (x^4)\n",
 | |
|     "   [0.77, 0.25, 0.10], # one      (x^5)\n",
 | |
|     "   [0.05, 0.80, 0.55]] # step     (x^6)\n",
 | |
|     ")"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "299baef3-b1a8-49ba-bad4-f62c8a416d83",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "- (In this book, we follow the common machine learning and deep learning convention where training examples are represented as rows and feature values as columns; in the case of the tensor shown above, each row represents a word, and each column represents an embedding dimension)\n",
 | |
|     "\n",
 | |
|     "- The primary objective of this section is to demonstrate how the context vector $z^{(2)}$\n",
 | |
|     "  is calculated using the second input element, $x^{(2)}$, as a query\n",
 | |
|     "\n",
 | |
|     "- The figure below depicts the first step in this process, which involves calculating the attention scores $\\omega$ between $x^{(2)}$\n",
 | |
|     "  and every input element (including $x^{(2)}$ itself) through a dot product operation"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "5cb3453a-58fa-42c4-b225-86850bc856f8",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch03_compressed/08.webp\" width=\"400px\">"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "77be52fb-82fd-4886-a4c8-f24a9c87af22",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "- We use input sequence element 2, $x^{(2)}$, as an example to compute context vector $z^{(2)}$; later in this section, we will generalize this to compute all context vectors.\n",
 | |
|     "- The first step is to compute the unnormalized attention scores by computing the dot product between the query $x^{(2)}$ and every input token (including $x^{(2)}$ itself):"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "code",
 | |
|    "execution_count": 3,
 | |
|    "id": "6fb5b2f8-dd2c-4a6d-94ef-a0e9ad163951",
 | |
|    "metadata": {},
 | |
|    "outputs": [
 | |
|     {
 | |
|      "name": "stdout",
 | |
|      "output_type": "stream",
 | |
|      "text": [
 | |
|       "tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])\n"
 | |
|      ]
 | |
|     }
 | |
|    ],
 | |
|    "source": [
 | |
|     "query = inputs[1]  # 2nd input token is the query\n",
 | |
|     "\n",
 | |
|     "attn_scores_2 = torch.empty(inputs.shape[0])\n",
 | |
|     "for i, x_i in enumerate(inputs):\n",
 | |
|     "    attn_scores_2[i] = torch.dot(x_i, query) # dot product (transpose not necessary here since they are 1-dim vectors)\n",
 | |
|     "\n",
 | |
|     "print(attn_scores_2)"
 | |
|    ]
 | |
|   },
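 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "7f3c2a10-5b4e-4d8a-9c61-0a1b2c3d4e01",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "- As an optional sketch (not part of the original walkthrough), the loop above can be replaced by a single matrix-vector product, because row $i$ of `inputs @ query` is exactly the dot product between $x^{(i)}$ and the query. The cell below recreates the toy inputs under the hypothetical name `demo_inputs`, so that it runs on its own and does not overwrite variables used later, and compares both versions with `torch.allclose`:"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "code",
 | |
|    "execution_count": null,
 | |
|    "id": "7f3c2a10-5b4e-4d8a-9c61-0a1b2c3d4e02",
 | |
|    "metadata": {},
 | |
|    "outputs": [],
 | |
|    "source": [
 | |
|     "import torch\n",
 | |
|     "\n",
 | |
|     "# Copy of the 6x3 toy embeddings from above (hypothetical name, so `inputs` is untouched)\n",
 | |
|     "demo_inputs = torch.tensor([[0.43, 0.15, 0.89], [0.55, 0.87, 0.66],\n",
 | |
|     "                            [0.57, 0.85, 0.64], [0.22, 0.58, 0.33],\n",
 | |
|     "                            [0.77, 0.25, 0.10], [0.05, 0.80, 0.55]])\n",
 | |
|     "demo_query = demo_inputs[1]  # 2nd input token as the query\n",
 | |
|     "\n",
 | |
|     "# (6, 3) @ (3,) -> (6,): one dot product per input row\n",
 | |
|     "scores_vectorized = demo_inputs @ demo_query\n",
 | |
|     "\n",
 | |
|     "# Loop version, as in the cell above\n",
 | |
|     "scores_loop = torch.stack([torch.dot(x_i, demo_query) for x_i in demo_inputs])\n",
 | |
|     "\n",
 | |
|     "print(scores_vectorized)\n",
 | |
|     "print(torch.allclose(scores_vectorized, scores_loop))"
 | |
|    ]
 | |
|   },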
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "8df09ae0-199f-4b6f-81a0-2f70546684b8",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "- Side note: a dot product is essentially a shorthand for multiplying two vectors element-wise and summing the resulting products:"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "code",
 | |
|    "execution_count": 4,
 | |
|    "id": "9842f39b-1654-410e-88bf-d1b899bf0241",
 | |
|    "metadata": {},
 | |
|    "outputs": [
 | |
|     {
 | |
|      "name": "stdout",
 | |
|      "output_type": "stream",
 | |
|      "text": [
 | |
|       "tensor(0.9544)\n",
 | |
|       "tensor(0.9544)\n"
 | |
|      ]
 | |
|     }
 | |
|    ],
 | |
|    "source": [
 | |
|     "res = 0.\n",
 | |
|     "\n",
 | |
|     "for idx, element in enumerate(inputs[0]):\n",
 | |
|     "    res += element * query[idx]\n",
 | |
|     "\n",
 | |
|     "print(res)\n",
 | |
|     "print(torch.dot(inputs[0], query))"
 | |
|    ]
 | |
|   },
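 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "7f3c2a10-5b4e-4d8a-9c61-0a1b2c3d4e03",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "- Geometrically, the dot product of two vectors $a$ and $b$ equals $\\lVert a \\rVert \\, \\lVert b \\rVert \\cos\\theta$, where $\\theta$ is the angle between them. This is why it works as a rough similarity measure here: it is larger when two embeddings point in similar directions (and when they are longer). The sketch below checks this identity numerically for the first input and the query, using `torch.nn.functional.cosine_similarity` and `torch.linalg.vector_norm`; `a` and `b` are hypothetical names:"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "code",
 | |
|    "execution_count": null,
 | |
|    "id": "7f3c2a10-5b4e-4d8a-9c61-0a1b2c3d4e04",
 | |
|    "metadata": {},
 | |
|    "outputs": [],
 | |
|    "source": [
 | |
|     "import torch\n",
 | |
|     "import torch.nn.functional as F\n",
 | |
|     "\n",
 | |
|     "a = torch.tensor([0.43, 0.15, 0.89])  # embedding of the 1st token (Your)\n",
 | |
|     "b = torch.tensor([0.55, 0.87, 0.66])  # embedding of the 2nd token (journey), the query\n",
 | |
|     "\n",
 | |
|     "dot = torch.dot(a, b)\n",
 | |
|     "\n",
 | |
|     "# Cosine of the angle between a and b, and the product of their lengths\n",
 | |
|     "cos_theta = F.cosine_similarity(a, b, dim=0)\n",
 | |
|     "length_product = torch.linalg.vector_norm(a) * torch.linalg.vector_norm(b)\n",
 | |
|     "\n",
 | |
|     "print(dot)\n",
 | |
|     "print(length_product * cos_theta)  # same value as the dot product, up to rounding"
 | |
|    ]
 | |
|   },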
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "7d444d76-e19e-4e9a-a268-f315d966609b",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "- **Step 2:** normalize the unnormalized attention scores (\"omegas\", $\\omega$) so that they sum up to 1\n",
 | |
|     "- Here is a simple way to normalize the unnormalized attention scores to sum up to 1 (a convention, useful for interpretation, and important for training stability):"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "dfd965d6-980c-476a-93d8-9efe603b1b3b",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch03_compressed/09.webp\" width=\"500px\">"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "code",
 | |
|    "execution_count": 5,
 | |
|    "id": "e3ccc99c-33ce-4f11-b7f2-353cf1cbdaba",
 | |
|    "metadata": {},
 | |
|    "outputs": [
 | |
|     {
 | |
|      "name": "stdout",
 | |
|      "output_type": "stream",
 | |
|      "text": [
 | |
|       "Attention weights: tensor([0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656])\n",
 | |
|       "Sum: tensor(1.0000)\n"
 | |
|      ]
 | |
|     }
 | |
|    ],
 | |
|    "source": [
 | |
|     "attn_weights_2_tmp = attn_scores_2 / attn_scores_2.sum()\n",
 | |
|     "\n",
 | |
|     "print(\"Attention weights:\", attn_weights_2_tmp)\n",
 | |
|     "print(\"Sum:\", attn_weights_2_tmp.sum())"
 | |
|    ]
 | |
|   },
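 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "7f3c2a10-5b4e-4d8a-9c61-0a1b2c3d4e05",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "- A caveat of this simple sum-normalization, illustrated with hypothetical scores below: dot products can be negative, and dividing by their sum can then yield negative \"weights\" or weights larger than 1 (and it breaks down entirely when the scores sum to zero). Softmax, used next, avoids this because $e^{\\omega}$ is always positive:"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "code",
 | |
|    "execution_count": null,
 | |
|    "id": "7f3c2a10-5b4e-4d8a-9c61-0a1b2c3d4e06",
 | |
|    "metadata": {},
 | |
|    "outputs": [],
 | |
|    "source": [
 | |
|     "import torch\n",
 | |
|     "\n",
 | |
|     "# Hypothetical attention scores; dot products can be negative\n",
 | |
|     "demo_scores = torch.tensor([2.0, -1.5, 0.5])\n",
 | |
|     "\n",
 | |
|     "sum_normalized = demo_scores / demo_scores.sum()\n",
 | |
|     "print(sum_normalized)  # sums to 1, but one entry is negative and one exceeds 1\n",
 | |
|     "print(sum_normalized.sum())\n",
 | |
|     "\n",
 | |
|     "softmax_normalized = torch.softmax(demo_scores, dim=0)\n",
 | |
|     "print(softmax_normalized)  # every entry lies between 0 and 1\n",
 | |
|     "print(softmax_normalized.sum())"
 | |
|    ]
 | |
|   },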
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "75dc0a57-f53e-41bf-8793-daa77a819431",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "- However, in practice it is common and recommended to use the softmax function for normalization: it keeps all weights positive, handles extreme values better, and has more favorable gradient properties during training\n",
 | |
|     "- Here's a naive implementation of a softmax function for scaling, which also normalizes the vector elements such that they sum up to 1:"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "code",
 | |
|    "execution_count": 6,
 | |
|    "id": "07b2e58d-a6ed-49f0-a1cd-2463e8d53a20",
 | |
|    "metadata": {},
 | |
|    "outputs": [
 | |
|     {
 | |
|      "name": "stdout",
 | |
|      "output_type": "stream",
 | |
|      "text": [
 | |
|       "Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])\n",
 | |
|       "Sum: tensor(1.)\n"
 | |
|      ]
 | |
|     }
 | |
|    ],
 | |
|    "source": [
 | |
|     "def softmax_naive(x):\n",
 | |
|     "    return torch.exp(x) / torch.exp(x).sum(dim=0)\n",
 | |
|     "\n",
 | |
|     "attn_weights_2_naive = softmax_naive(attn_scores_2)\n",
 | |
|     "\n",
 | |
|     "print(\"Attention weights:\", attn_weights_2_naive)\n",
 | |
|     "print(\"Sum:\", attn_weights_2_naive.sum())"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "f0a1cbbb-4744-41cb-8910-f5c1355555fb",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "- The naive implementation above can be numerically unstable for large or small input values because `torch.exp` can overflow or underflow\n",
 | |
|     "- Hence, in practice, it's recommended to use the PyTorch implementation of softmax instead, which is numerically stable and has been extensively optimized for performance:"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "code",
 | |
|    "execution_count": 7,
 | |
|    "id": "2d99cac4-45ea-46b3-b3c1-e000ad16e158",
 | |
|    "metadata": {},
 | |
|    "outputs": [
 | |
|     {
 | |
|      "name": "stdout",
 | |
|      "output_type": "stream",
 | |
|      "text": [
 | |
|       "Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])\n",
 | |
|       "Sum: tensor(1.)\n"
 | |
|      ]
 | |
|     }
 | |
|    ],
 | |
|    "source": [
 | |
|     "attn_weights_2 = torch.softmax(attn_scores_2, dim=0)\n",
 | |
|     "\n",
 | |
|     "print(\"Attention weights:\", attn_weights_2)\n",
 | |
|     "print(\"Sum:\", attn_weights_2.sum())"
 | |
|    ]
 | |
|   },
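 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "7f3c2a10-5b4e-4d8a-9c61-0a1b2c3d4e07",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "- To make the overflow problem concrete, the sketch below compares `softmax_naive` (redefined so the cell is self-contained) with a hypothetical `softmax_stable` helper that subtracts the maximum score first. The shift does not change the result mathematically, since $e^{\\omega_i - m} / \\sum_j e^{\\omega_j - m} = e^{\\omega_i} / \\sum_j e^{\\omega_j}$ for any constant $m$, but it keeps the largest exponent at $e^0 = 1$. This is a common stabilization trick, shown for illustration rather than as a description of PyTorch's internals:"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "code",
 | |
|    "execution_count": null,
 | |
|    "id": "7f3c2a10-5b4e-4d8a-9c61-0a1b2c3d4e08",
 | |
|    "metadata": {},
 | |
|    "outputs": [],
 | |
|    "source": [
 | |
|     "import torch\n",
 | |
|     "\n",
 | |
|     "def softmax_naive(x):\n",
 | |
|     "    return torch.exp(x) / torch.exp(x).sum(dim=0)\n",
 | |
|     "\n",
 | |
|     "def softmax_stable(x):\n",
 | |
|     "    # Subtract the maximum so the largest exponent is exp(0) = 1\n",
 | |
|     "    shifted = x - x.max()\n",
 | |
|     "    return torch.exp(shifted) / torch.exp(shifted).sum(dim=0)\n",
 | |
|     "\n",
 | |
|     "large_scores = torch.tensor([1000.0, 1001.0, 1002.0])\n",
 | |
|     "\n",
 | |
|     "print(softmax_naive(large_scores))  # exp(1000.) overflows to inf in float32, giving nan\n",
 | |
|     "print(softmax_stable(large_scores))\n",
 | |
|     "print(torch.softmax(large_scores, dim=0))"
 | |
|    ]
 | |
|   },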
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "e43e36c7-90b2-427f-94f6-bb9d31b2ab3f",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "- **Step 3**: compute the context vector $z^{(2)}$ by multiplying each embedded input token, $x^{(i)}$, by its corresponding attention weight and summing the resulting vectors:"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "f1c9f5ac-8d3d-4847-94e3-fd783b7d4d3d",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch03_compressed/10.webp\" width=\"500px\">"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "code",
 | |
|    "execution_count": 8,
 | |
|    "id": "8fcb96f0-14e5-4973-a50e-79ea7c6af99f",
 | |
|    "metadata": {},
 | |
|    "outputs": [
 | |
|     {
 | |
|      "name": "stdout",
 | |
|      "output_type": "stream",
 | |
|      "text": [
 | |
|       "tensor([0.4419, 0.6515, 0.5683])\n"
 | |
|      ]
 | |
|     }
 | |
|    ],
 | |
|    "source": [
 | |
|     "query = inputs[1] # 2nd input token is the query\n",
 | |
|     "\n",
 | |
|     "context_vec_2 = torch.zeros(query.shape)\n",
 | |
|     "for i, x_i in enumerate(inputs):\n",
 | |
|     "    context_vec_2 += attn_weights_2[i] * x_i\n",
 | |
|     "\n",
 | |
|     "print(context_vec_2)"
 | |
|    ]
 | |
|   },
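 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "7f3c2a10-5b4e-4d8a-9c61-0a1b2c3d4e09",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "- Since step 3 is a weighted sum of the rows of the input matrix, $z^{(2)} = \\sum_{i} \\alpha_{2i}\\, x^{(i)}$ (where $\\alpha_{2i}$ here denotes the attention weights from step 2), the loop can also be written as one vector-matrix product. A minimal, self-contained sketch with hypothetical `demo_` names that recomputes steps 1 and 2:"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "code",
 | |
|    "execution_count": null,
 | |
|    "id": "7f3c2a10-5b4e-4d8a-9c61-0a1b2c3d4e10",
 | |
|    "metadata": {},
 | |
|    "outputs": [],
 | |
|    "source": [
 | |
|     "import torch\n",
 | |
|     "\n",
 | |
|     "demo_inputs = torch.tensor([[0.43, 0.15, 0.89], [0.55, 0.87, 0.66],\n",
 | |
|     "                            [0.57, 0.85, 0.64], [0.22, 0.58, 0.33],\n",
 | |
|     "                            [0.77, 0.25, 0.10], [0.05, 0.80, 0.55]])\n",
 | |
|     "demo_query = demo_inputs[1]\n",
 | |
|     "\n",
 | |
|     "# Steps 1 and 2: scores for the 2nd query, then softmax normalization\n",
 | |
|     "demo_weights = torch.softmax(demo_inputs @ demo_query, dim=0)\n",
 | |
|     "\n",
 | |
|     "# Step 3 as a single product: (6,) @ (6, 3) -> (3,)\n",
 | |
|     "demo_context = demo_weights @ demo_inputs\n",
 | |
|     "\n",
 | |
|     "# Explicit weighted sum for comparison\n",
 | |
|     "demo_context_loop = sum(w * x_i for w, x_i in zip(demo_weights, demo_inputs))\n",
 | |
|     "\n",
 | |
|     "print(demo_context)\n",
 | |
|     "print(torch.allclose(demo_context, demo_context_loop))"
 | |
|    ]
 | |
|   },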
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "5a454262-40eb-430e-9ca4-e43fb8d6cd89",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "### 3.3.2 Computing attention weights for all input tokens"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "6a02bb73-fc19-4c88-b155-8314de5d63a8",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "#### Generalize to all input sequence tokens:\n",
 | |
|     "\n",
 | |
|     "- Above, we computed the attention weights and context vector for input 2 (as illustrated in the highlighted row in the figure below)\n",
 | |
|     "- Next, we generalize this computation to obtain all attention weights and context vectors"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "11c0fb55-394f-42f4-ba07-d01ae5c98ab4",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch03_compressed/11.webp\" width=\"400px\">\n",
 | |
|     "\n",
 | |
|     "- (Please note that the numbers in this figure are truncated to two\n",
 | |
|     "digits after the decimal point to reduce visual clutter; the values in each row should add up to 1.0 or 100%; similarly, digits in other figures are truncated)"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "b789b990-fb51-4beb-9212-bf58876b5983",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "- In self-attention, the process starts with the calculation of attention scores, which are then normalized to obtain attention weights that sum to 1\n",
 | |
|     "- These attention weights are then used to generate the context vectors through a weighted sum of the inputs"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "d9bffe4b-56fe-4c37-9762-24bd924b7d3c",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch03_compressed/12.webp\" width=\"400px\">"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "aa652506-f2c8-473c-a905-85c389c842cc",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "- Apply the previous **step 1** to all pairs of input elements to compute the unnormalized attention score matrix:"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "code",
 | |
|    "execution_count": 9,
 | |
|    "id": "04004be8-07a1-468b-ab33-32e16a551b45",
 | |
|    "metadata": {},
 | |
|    "outputs": [
 | |
|     {
 | |
|      "name": "stdout",
 | |
|      "output_type": "stream",
 | |
|      "text": [
 | |
|       "tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],\n",
 | |
|       "        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],\n",
 | |
|       "        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],\n",
 | |
|       "        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],\n",
 | |
|       "        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],\n",
 | |
|       "        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])\n"
 | |
|      ]
 | |
|     }
 | |
|    ],
 | |
|    "source": [
 | |
|     "attn_scores = torch.empty(inputs.shape[0], inputs.shape[0])\n",
 | |
|     "\n",
 | |
|     "for i, x_i in enumerate(inputs):\n",
 | |
|     "    for j, x_j in enumerate(inputs):\n",
 | |
|     "        attn_scores[i, j] = torch.dot(x_i, x_j)\n",
 | |
|     "\n",
 | |
|     "print(attn_scores)"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "1539187f-1ece-47b7-bc9b-65a97115f1d4",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "- We can achieve the same as above more efficiently via matrix multiplication:"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "code",
 | |
|    "execution_count": 10,
 | |
|    "id": "2cea69d0-9a47-45da-8d5a-47ceef2df673",
 | |
|    "metadata": {},
 | |
|    "outputs": [
 | |
|     {
 | |
|      "name": "stdout",
 | |
|      "output_type": "stream",
 | |
|      "text": [
 | |
|       "tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],\n",
 | |
|       "        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],\n",
 | |
|       "        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],\n",
 | |
|       "        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],\n",
 | |
|       "        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],\n",
 | |
|       "        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])\n"
 | |
|      ]
 | |
|     }
 | |
|    ],
 | |
|    "source": [
 | |
|     "attn_scores = inputs @ inputs.T\n",
 | |
|     "print(attn_scores)"
 | |
|    ]
 | |
|   },
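 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "7f3c2a10-5b4e-4d8a-9c61-0a1b2c3d4e11",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "- Because the dot product is symmetric, $x^{(i)} \\cdot x^{(j)} = x^{(j)} \\cdot x^{(i)}$, the score matrix of this weight-free variant is symmetric, which is visible in the printed values above. This is specific to using the raw inputs as both queries and keys; once separate query and key projections are introduced in section 3.4, the scores are generally no longer symmetric. A quick check with hypothetical names:"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "code",
 | |
|    "execution_count": null,
 | |
|    "id": "7f3c2a10-5b4e-4d8a-9c61-0a1b2c3d4e12",
 | |
|    "metadata": {},
 | |
|    "outputs": [],
 | |
|    "source": [
 | |
|     "import torch\n",
 | |
|     "\n",
 | |
|     "demo_inputs = torch.tensor([[0.43, 0.15, 0.89], [0.55, 0.87, 0.66],\n",
 | |
|     "                            [0.57, 0.85, 0.64], [0.22, 0.58, 0.33],\n",
 | |
|     "                            [0.77, 0.25, 0.10], [0.05, 0.80, 0.55]])\n",
 | |
|     "\n",
 | |
|     "# (6, 3) @ (3, 6) -> (6, 6): entry [i, j] is the dot product of rows i and j\n",
 | |
|     "demo_scores = demo_inputs @ demo_inputs.T\n",
 | |
|     "\n",
 | |
|     "print(demo_scores.shape)  # torch.Size([6, 6])\n",
 | |
|     "print(torch.allclose(demo_scores, demo_scores.T))  # symmetric up to floating-point tolerance"
 | |
|    ]
 | |
|   },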
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "02c4bac4-acfd-427f-9b11-c436ac71748d",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "- Similar to **step 2** previously, we normalize each row so that the values in each row sum to 1:"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "code",
 | |
|    "execution_count": 11,
 | |
|    "id": "fa4ef062-de81-47ee-8415-bfe1708c81b8",
 | |
|    "metadata": {},
 | |
|    "outputs": [
 | |
|     {
 | |
|      "name": "stdout",
 | |
|      "output_type": "stream",
 | |
|      "text": [
 | |
|       "tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],\n",
 | |
|       "        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],\n",
 | |
|       "        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],\n",
 | |
|       "        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],\n",
 | |
|       "        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],\n",
 | |
|       "        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])\n"
 | |
|      ]
 | |
|     }
 | |
|    ],
 | |
|    "source": [
 | |
|     "attn_weights = torch.softmax(attn_scores, dim=-1)\n",
 | |
|     "print(attn_weights)"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "3fa6d02b-7f15-4eb4-83a7-0b8a819e7a0c",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "- Quick verification that the values in each row indeed sum to 1:"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "code",
 | |
|    "execution_count": 12,
 | |
|    "id": "112b492c-fb6f-4e6d-8df5-518ae83363d5",
 | |
|    "metadata": {},
 | |
|    "outputs": [
 | |
|     {
 | |
|      "name": "stdout",
 | |
|      "output_type": "stream",
 | |
|      "text": [
 | |
|       "Row 2 sum: 1.0\n",
 | |
|       "All row sums: tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])\n"
 | |
|      ]
 | |
|     }
 | |
|    ],
 | |
|    "source": [
 | |
|     "row_2_sum = sum([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])\n",
 | |
|     "print(\"Row 2 sum:\", row_2_sum)\n",
 | |
|     "\n",
 | |
|     "print(\"All row sums:\", attn_weights.sum(dim=-1))"
 | |
|    ]
 | |
|   },
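 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "7f3c2a10-5b4e-4d8a-9c61-0a1b2c3d4e13",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "- Two details in the cells above are worth making explicit. First, `dim=-1` normalizes along the last dimension, so each row (one query's scores over all inputs) sums to 1, whereas `dim=0` would normalize each column instead. Second, because floating-point sums are rarely exactly 1, a tolerance-based comparison such as `torch.allclose` is a more robust check than testing for exact equality. A sketch with hypothetical names:"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "code",
 | |
|    "execution_count": null,
 | |
|    "id": "7f3c2a10-5b4e-4d8a-9c61-0a1b2c3d4e14",
 | |
|    "metadata": {},
 | |
|    "outputs": [],
 | |
|    "source": [
 | |
|     "import torch\n",
 | |
|     "\n",
 | |
|     "demo_inputs = torch.tensor([[0.43, 0.15, 0.89], [0.55, 0.87, 0.66],\n",
 | |
|     "                            [0.57, 0.85, 0.64], [0.22, 0.58, 0.33],\n",
 | |
|     "                            [0.77, 0.25, 0.10], [0.05, 0.80, 0.55]])\n",
 | |
|     "demo_scores = demo_inputs @ demo_inputs.T\n",
 | |
|     "\n",
 | |
|     "weights_rows = torch.softmax(demo_scores, dim=-1)  # each row sums to 1\n",
 | |
|     "weights_cols = torch.softmax(demo_scores, dim=0)   # each column sums to 1\n",
 | |
|     "\n",
 | |
|     "ones = torch.ones(demo_scores.shape[0])\n",
 | |
|     "print(torch.allclose(weights_rows.sum(dim=-1), ones))\n",
 | |
|     "print(torch.allclose(weights_cols.sum(dim=0), ones))\n",
 | |
|     "\n",
 | |
|     "# For this symmetric score matrix, normalizing columns gives the transpose of normalizing rows\n",
 | |
|     "print(torch.allclose(weights_cols, weights_rows.T))"
 | |
|    ]
 | |
|   },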
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "138b0b5c-d813-44c7-b373-fde9540ddfd1",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "- Apply previous **step 3** to compute all context vectors:"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "code",
 | |
|    "execution_count": 13,
 | |
|    "id": "ba8eafcf-f7f7-4989-b8dc-61b50c4f81dc",
 | |
|    "metadata": {},
 | |
|    "outputs": [
 | |
|     {
 | |
|      "name": "stdout",
 | |
|      "output_type": "stream",
 | |
|      "text": [
 | |
|       "tensor([[0.4421, 0.5931, 0.5790],\n",
 | |
|       "        [0.4419, 0.6515, 0.5683],\n",
 | |
|       "        [0.4431, 0.6496, 0.5671],\n",
 | |
|       "        [0.4304, 0.6298, 0.5510],\n",
 | |
|       "        [0.4671, 0.5910, 0.5266],\n",
 | |
|       "        [0.4177, 0.6503, 0.5645]])\n"
 | |
|      ]
 | |
|     }
 | |
|    ],
 | |
|    "source": [
 | |
|     "all_context_vecs = attn_weights @ inputs\n",
 | |
|     "print(all_context_vecs)"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "25b245b8-7732-4fab-aa1c-e3d333195605",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "- As a sanity check, the previously computed context vector $z^{(2)} = [0.4419, 0.6515, 0.5683]$ matches the 2nd row above:"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "code",
 | |
|    "execution_count": 14,
 | |
|    "id": "2570eb7d-aee1-457a-a61e-7544478219fa",
 | |
|    "metadata": {},
 | |
|    "outputs": [
 | |
|     {
 | |
|      "name": "stdout",
 | |
|      "output_type": "stream",
 | |
|      "text": [
 | |
|       "Previous 2nd context vector: tensor([0.4419, 0.6515, 0.5683])\n"
 | |
|      ]
 | |
|     }
 | |
|    ],
 | |
|    "source": [
 | |
|     "print(\"Previous 2nd context vector:\", context_vec_2)"
 | |
|    ]
 | |
|   },
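 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "7f3c2a10-5b4e-4d8a-9c61-0a1b2c3d4e15",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "- The three steps of this weight-free self-attention can be collected into a single function. The sketch below uses a hypothetical helper name, `simple_self_attention`, and checks that the row for token 2 matches the single-query computation from section 3.3.1:"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "code",
 | |
|    "execution_count": null,
 | |
|    "id": "7f3c2a10-5b4e-4d8a-9c61-0a1b2c3d4e16",
 | |
|    "metadata": {},
 | |
|    "outputs": [],
 | |
|    "source": [
 | |
|     "import torch\n",
 | |
|     "\n",
 | |
|     "def simple_self_attention(x):\n",
 | |
|     "    # x: (num_tokens, d), one token embedding per row\n",
 | |
|     "    scores = x @ x.T                         # step 1: all pairwise dot products\n",
 | |
|     "    weights = torch.softmax(scores, dim=-1)  # step 2: normalize each row\n",
 | |
|     "    return weights @ x                       # step 3: weighted sums of the rows\n",
 | |
|     "\n",
 | |
|     "demo_inputs = torch.tensor([[0.43, 0.15, 0.89], [0.55, 0.87, 0.66],\n",
 | |
|     "                            [0.57, 0.85, 0.64], [0.22, 0.58, 0.33],\n",
 | |
|     "                            [0.77, 0.25, 0.10], [0.05, 0.80, 0.55]])\n",
 | |
|     "demo_context_vecs = simple_self_attention(demo_inputs)\n",
 | |
|     "\n",
 | |
|     "# Single-query version for token 2, as in section 3.3.1\n",
 | |
|     "single_context_2 = torch.softmax(demo_inputs @ demo_inputs[1], dim=0) @ demo_inputs\n",
 | |
|     "\n",
 | |
|     "print(demo_context_vecs.shape)  # torch.Size([6, 3])\n",
 | |
|     "print(torch.allclose(demo_context_vecs[1], single_context_2))"
 | |
|    ]
 | |
|   },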
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "a303b6fb-9f7e-42bb-9fdb-2adabf0a6525",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "## 3.4 Implementing self-attention with trainable weights"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "88363117-93d8-41fb-8240-f7cfe08b14a3",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "- The figure below shows how the self-attention mechanism developed in this section fits into the overall structure of this chapter and book"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "ac9492ba-6f66-4f65-bd1d-87cf16d59928",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch03_compressed/13.webp\" width=\"400px\">"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "2b90a77e-d746-4704-9354-1ddad86e6298",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "### 3.4.1 Computing the attention weights step by step"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "46e95a46-1f67-4b71-9e84-8e2db84ab036",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "- In this section, we implement the self-attention mechanism that is used in the original transformer architecture, the GPT models, and most other popular LLMs\n",
 | |
|     "- This self-attention mechanism is also called \"scaled dot-product attention\"\n",
 | |
|     "- The overall idea is similar to before:\n",
 | |
|     "  - We want to compute context vectors as weighted sums over the input vectors specific to a certain input element\n",
 | |
|     "  - For this, we again need attention weights\n",
 | |
|     "- As you will see, there are only slight differences compared to the basic attention mechanism introduced earlier:\n",
 | |
|     "  - The most notable difference is the introduction of weight matrices that are updated during model training\n",
 | |
|     "  - These trainable weight matrices are crucial so that the model (specifically, the attention module inside the model) can learn to produce \"good\" context vectors"
 | |
|    ]
 | |
|   },
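The following summarizes the computation developed step by step in this section in matrix form. It is a summary sketch that uses the row-vector convention of this notebook's code. The matrix names $X$, $Q$, $K$, and $V$ are introduced here only for illustration. The rows of $X$ are the input embeddings $x^{(i)}$, and $d_k$ is the dimension of the key vectors:

$$
Q = X W_q, \quad K = X W_k, \quad V = X W_v, \qquad
\text{context vectors} = \operatorname{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V
$$

The softmax is applied to each row separately, so row $i$ of the result is the context vector for input $x^{(i)}$.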
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "59db4093-93e8-4bee-be8f-c8fac8a08cdd",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch03_compressed/14.webp\" width=\"600px\">"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "4d996671-87aa-45c9-b2e0-07a7bcc9060a",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "- Implementing the self-attention mechanism step by step, we will start by introducing the three trainable weight matrices $W_q$, $W_k$, and $W_v$\n",
 | |
|     "- These three matrices are used to project the embedded input tokens, $x^{(i)}$, into query, key, and value vectors via matrix multiplication (treating each $x^{(i)}$ as a row vector, as in the code below):\n",
 | |
|     "\n",
 | |
|     "  - Query vector: $q^{(i)} = x^{(i)}\\,W_q$\n",
 | |
|     "  - Key vector: $k^{(i)} = x^{(i)}\\,W_k$\n",
 | |
|     "  - Value vector: $v^{(i)} = x^{(i)}\\,W_v$\n"
 | |
|    ]
 | |
|   },
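The following sketch is a quick sanity check of the projection step. The code in this notebook multiplies row vectors from the left, i.e., $x^{(i)} W_q$ with $W_q$ of shape `(d_in, d_out)`. The check shows that projecting one input at a time gives the same result as projecting all inputs at once with a single matrix multiplication. It uses a random stand-in for the `inputs` tensor and a hypothetical weight matrix `W_q`.

```python
import torch

torch.manual_seed(0)
d_in, d_out = 3, 2
inputs = torch.rand(6, d_in)   # random stand-in for the 6 token embeddings
W_q = torch.rand(d_in, d_out)  # hypothetical query weight matrix, shape (d_in, d_out)

# One input at a time: q^(i) = x^(i) W_q, a d_out-dimensional vector
q_2 = inputs[1] @ W_q

# All inputs at once: row i of Q is the query vector of input i
Q = inputs @ W_q

print(q_2.shape)                  # torch.Size([2])
print(Q.shape)                    # torch.Size([6, 2])
print(torch.allclose(Q[1], q_2))  # True
```

The same pattern applies to the key and value projections.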
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "9f334313-5fd0-477b-8728-04080a427049",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "- The embedding dimensions of the input $x$ and the query vector $q$ can be the same or different, depending on the model's design and specific implementation\n",
 | |
|     "- In GPT models, the input and output dimensions are usually the same, but for illustration purposes, to better follow the computation, we choose different input and output dimensions here:"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "code",
 | |
|    "execution_count": 15,
 | |
|    "id": "8250fdc6-6cd6-4c5b-b9c0-8c643aadb7db",
 | |
|    "metadata": {},
 | |
|    "outputs": [],
 | |
|    "source": [
 | |
|     "x_2 = inputs[1] # second input element\n",
 | |
|     "d_in = inputs.shape[1] # the input embedding size, d_in=3\n",
 | |
|     "d_out = 2 # the output embedding size, d_out=2"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "f528cfb3-e226-47dd-b363-cc2caaeba4bf",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "- Below, we initialize the three weight matrices; note that we are setting `requires_grad=False` to reduce clutter in the outputs for illustration purposes, but if we were to use the weight matrices for model training, we would set `requires_grad=True` to update these matrices during model training"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "code",
 | |
|    "execution_count": 16,
 | |
|    "id": "bfd7259a-f26c-4cea-b8fc-282b5cae1e00",
 | |
|    "metadata": {},
 | |
|    "outputs": [],
 | |
|    "source": [
 | |
|     "torch.manual_seed(123)\n",
 | |
|     "\n",
 | |
|     "W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)\n",
 | |
|     "W_key   = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)\n",
 | |
|     "W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)"
 | |
|    ]
 | |
|   },
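The sketch below runs a backward pass to show what `requires_grad` changes in practice. Only the parameter created with the default `requires_grad=True` receives a gradient. The frozen one keeps `.grad` set to `None`. The names `W_frozen` and `W_train` and the toy loss are hypothetical and serve only for illustration.

```python
import torch

torch.manual_seed(123)
x = torch.rand(3)  # random stand-in for one input embedding

W_frozen = torch.nn.Parameter(torch.rand(3, 2), requires_grad=False)
W_train = torch.nn.Parameter(torch.rand(3, 2))  # requires_grad=True by default

# A toy scalar "loss" that depends on both matrices
loss = (x @ W_frozen).sum() + (x @ W_train).sum()
loss.backward()

print(W_frozen.grad)       # None: no gradient is computed for frozen weights
print(W_train.grad.shape)  # torch.Size([3, 2])
```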
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "abfd0b50-7701-4adb-821c-e5433622d9c4",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "- Next, we compute the query, key, and value vectors:"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "code",
 | |
|    "execution_count": 17,
 | |
|    "id": "73cedd62-01e1-4196-a575-baecc6095601",
 | |
|    "metadata": {},
 | |
|    "outputs": [
 | |
|     {
 | |
|      "name": "stdout",
 | |
|      "output_type": "stream",
 | |
|      "text": [
 | |
|       "tensor([0.4306, 1.4551])\n"
 | |
|      ]
 | |
|     }
 | |
|    ],
 | |
|    "source": [
 | |
|     "query_2 = x_2 @ W_query # _2 because it's with respect to the 2nd input element\n",
 | |
|     "key_2 = x_2 @ W_key \n",
 | |
|     "value_2 = x_2 @ W_value\n",
 | |
|     "\n",
 | |
|     "print(query_2)"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "9be308b3-aca3-421b-b182-19c3a03b71c7",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "- As we can see below, we successfully projected the 6 input tokens from a 3D onto a 2D embedding space:"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "code",
 | |
|    "execution_count": 18,
 | |
|    "id": "8c1c3949-fc08-4d19-a41e-1c235b4e631b",
 | |
|    "metadata": {},
 | |
|    "outputs": [
 | |
|     {
 | |
|      "name": "stdout",
 | |
|      "output_type": "stream",
 | |
|      "text": [
 | |
|       "keys.shape: torch.Size([6, 2])\n",
 | |
|       "values.shape: torch.Size([6, 2])\n"
 | |
|      ]
 | |
|     }
 | |
|    ],
 | |
|    "source": [
 | |
|     "keys = inputs @ W_key \n",
 | |
|     "values = inputs @ W_value\n",
 | |
|     "\n",
 | |
|     "print(\"keys.shape:\", keys.shape)\n",
 | |
|     "print(\"values.shape:\", values.shape)"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "bac5dfd6-ade8-4e7b-b0c1-bed40aa24481",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "- In **step 2**, we compute the unnormalized attention scores as the dot products between the query vector and each key vector:"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "8ed0a2b7-5c50-4ede-90cf-7ad74412b3aa",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch03_compressed/15.webp\" width=\"600px\">"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "code",
 | |
|    "execution_count": 19,
 | |
|    "id": "64cbc253-a182-4490-a765-246979ea0a28",
 | |
|    "metadata": {},
 | |
|    "outputs": [
 | |
|     {
 | |
|      "name": "stdout",
 | |
|      "output_type": "stream",
 | |
|      "text": [
 | |
|       "tensor(1.8524)\n"
 | |
|      ]
 | |
|     }
 | |
|    ],
 | |
|    "source": [
 | |
|     "keys_2 = keys[1] # Python uses 0-based indexing, so index 1 is the 2nd key\n",
 | |
|     "attn_score_22 = query_2.dot(keys_2)\n",
 | |
|     "print(attn_score_22)"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "9e9d15c0-c24e-4e6f-a160-6349b418f935",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "- Since we have 6 inputs, we have 6 attention scores for the given query vector:"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "code",
 | |
|    "execution_count": 20,
 | |
|    "id": "b14e44b5-d170-40f9-8847-8990804af26d",
 | |
|    "metadata": {},
 | |
|    "outputs": [
 | |
|     {
 | |
|      "name": "stdout",
 | |
|      "output_type": "stream",
 | |
|      "text": [
 | |
|       "tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])\n"
 | |
|      ]
 | |
|     }
 | |
|    ],
 | |
|    "source": [
 | |
|     "attn_scores_2 = query_2 @ keys.T # All attention scores for given query\n",
 | |
|     "print(attn_scores_2)"
 | |
|    ]
 | |
|   },
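The vectorized expression `query_2 @ keys.T` is a compact way of computing one dot product per key. The sketch below checks this equivalence on random stand-in tensors. The values are illustrative, not the ones from this notebook.

```python
import torch

torch.manual_seed(0)
query_2 = torch.rand(2)  # stand-in query vector (d_out = 2)
keys = torch.rand(6, 2)  # stand-in key vectors, one row per input token

# Loop version: one dot product per key vector
scores_loop = torch.stack([query_2.dot(key) for key in keys])

# Vectorized version: a single vector-matrix product
scores_vec = query_2 @ keys.T

print(scores_vec.shape)                         # torch.Size([6])
print(torch.allclose(scores_loop, scores_vec))  # True
```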
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "8622cf39-155f-4eb5-a0c0-82a03ce9b999",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch03_compressed/16.webp\" width=\"600px\">"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "e1609edb-f089-461a-8de2-c20c1bb29836",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "- Next, in **step 3**, we compute the attention weights (normalized attention scores that sum up to 1) using the softmax function we used earlier\n",
 | |
|     "- The difference from earlier is that we now scale the attention scores by dividing them by the square root of the embedding dimension of the keys, $\\sqrt{d_k}$ (i.e., `d_k**0.5`):"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "code",
 | |
|    "execution_count": 21,
 | |
|    "id": "146f5587-c845-4e30-9894-c7ed3a248153",
 | |
|    "metadata": {},
 | |
|    "outputs": [
 | |
|     {
 | |
|      "name": "stdout",
 | |
|      "output_type": "stream",
 | |
|      "text": [
 | |
|       "tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820])\n"
 | |
|      ]
 | |
|     }
 | |
|    ],
 | |
|    "source": [
 | |
|     "d_k = keys.shape[1]\n",
 | |
|     "attn_weights_2 = torch.softmax(attn_scores_2 / d_k**0.5, dim=-1)\n",
 | |
|     "print(attn_weights_2)"
 | |
|    ]
 | |
|   },
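Here is one way to see why the scaling by $\sqrt{d_k}$ helps. Suppose the query and key entries are independent with zero mean and unit variance. Their dot product then has variance $d_k$, so unscaled scores grow with the embedding size. Large scores push the softmax toward a nearly one-hot distribution with small gradients. The sketch below estimates the standard deviation of such dot products empirically, using random standard-normal vectors. This is an idealized assumption, since real queries and keys are not independent standard normals. The helper `score_std` is hypothetical.

```python
import torch

def score_std(d_k, n=10_000, scale=False):
    # Standard deviation of n dot products between random query/key pairs
    q = torch.randn(n, d_k)
    k = torch.randn(n, d_k)
    scores = (q * k).sum(dim=-1)
    if scale:
        scores = scores / d_k**0.5
    return scores.std().item()

torch.manual_seed(0)
for d_k in (4, 64, 512):
    # Unscaled std grows roughly like sqrt(d_k); scaled std stays near 1
    print(d_k, round(score_std(d_k), 2), round(score_std(d_k, scale=True), 2))
```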
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "b8f61a28-b103-434a-aee1-ae7cbd821126",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch03_compressed/17.webp\" width=\"600px\">"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "1890e3f9-db86-4ab8-9f3b-53113504a61f",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "- In **step 4**, we compute the context vector for the second input element (query vector 2) as the attention-weighted sum of the value vectors:"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "code",
 | |
|    "execution_count": 22,
 | |
|    "id": "e138f033-fa7e-4e3a-8764-b53a96b26397",
 | |
|    "metadata": {},
 | |
|    "outputs": [
 | |
|     {
 | |
|      "name": "stdout",
 | |
|      "output_type": "stream",
 | |
|      "text": [
 | |
|       "tensor([0.3061, 0.8210])\n"
 | |
|      ]
 | |
|     }
 | |
|    ],
 | |
|    "source": [
 | |
|     "context_vec_2 = attn_weights_2 @ values\n",
 | |
|     "print(context_vec_2)"
 | |
|    ]
 | |
|   },
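The matrix product `attn_weights_2 @ values` computes the context vector as a weighted sum of the value vectors. The sketch below spells out that sum explicitly with random stand-in tensors and checks that it matches the one-line version.

```python
import torch

torch.manual_seed(0)
attn_weights_2 = torch.softmax(torch.rand(6), dim=-1)  # stand-in weights that sum to 1
values = torch.rand(6, 2)                              # stand-in value vectors, one per token

# Explicit weighted sum: sum over j of alpha_2j * v^(j)
context_loop = torch.zeros(2)
for weight, value in zip(attn_weights_2, values):
    context_loop += weight * value

context_vec_2 = attn_weights_2 @ values
print(torch.allclose(context_loop, context_vec_2))  # True
```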
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "9d7b2907-e448-473e-b46c-77735a7281d8",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "### 3.4.2 Implementing a compact SelfAttention class"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "04313410-3155-4d90-a7a3-2f3386e73677",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "- Putting it all together, we can implement the self-attention mechanism as follows:"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "code",
 | |
|    "execution_count": 23,
 | |
|    "id": "51590326-cdbe-4e62-93b1-17df71c11ee4",
 | |
|    "metadata": {},
 | |
|    "outputs": [
 | |
|     {
 | |
|      "name": "stdout",
 | |
|      "output_type": "stream",
 | |
|      "text": [
 | |
|       "tensor([[0.2996, 0.8053],\n",
 | |
|       "        [0.3061, 0.8210],\n",
 | |
|       "        [0.3058, 0.8203],\n",
 | |
|       "        [0.2948, 0.7939],\n",
 | |
|       "        [0.2927, 0.7891],\n",
 | |
|       "        [0.2990, 0.8040]], grad_fn=<MmBackward0>)\n"
 | |
|      ]
 | |
|     }
 | |
|    ],
 | |
|    "source": [
 | |
|     "import torch.nn as nn\n",
 | |
|     "\n",
 | |
|     "class SelfAttention_v1(nn.Module):\n",
 | |
|     "\n",
 | |
|     "    def __init__(self, d_in, d_out):\n",
 | |
|     "        super().__init__()\n",
 | |
|     "        self.W_query = nn.Parameter(torch.rand(d_in, d_out))\n",
 | |
|     "        self.W_key   = nn.Parameter(torch.rand(d_in, d_out))\n",
 | |
|     "        self.W_value = nn.Parameter(torch.rand(d_in, d_out))\n",
 | |
|     "\n",
 | |
|     "    def forward(self, x):\n",
 | |
|     "        keys = x @ self.W_key\n",
 | |
|     "        queries = x @ self.W_query\n",
 | |
|     "        values = x @ self.W_value\n",
 | |
|     "        \n",
 | |
|     "        attn_scores = queries @ keys.T # omega\n",
 | |
|     "        attn_weights = torch.softmax(\n",
 | |
|     "            attn_scores / keys.shape[-1]**0.5, dim=-1\n",
 | |
|     "        )\n",
 | |
|     "\n",
 | |
|     "        context_vec = attn_weights @ values\n",
 | |
|     "        return context_vec\n",
 | |
|     "\n",
 | |
|     "torch.manual_seed(123)\n",
 | |
|     "sa_v1 = SelfAttention_v1(d_in, d_out)\n",
 | |
|     "print(sa_v1(inputs))"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "7ee1a024-84a5-425a-9567-54ab4e4ed445",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch03_compressed/18.webp\" width=\"400px\">"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "048e0c16-d911-4ec8-b0bc-45ceec75c081",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "- We can streamline the implementation above using PyTorch's `nn.Linear` layers, which are equivalent to a matrix multiplication when the bias units are disabled\n",
 | |
|     "- Another big advantage of using `nn.Linear` over our manual `nn.Parameter(torch.rand(...))` approach is that `nn.Linear` uses a better-suited default weight initialization scheme, which leads to more stable model training"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "code",
 | |
|    "execution_count": 24,
 | |
|    "id": "73f411e3-e231-464a-89fe-0a9035e5f839",
 | |
|    "metadata": {},
 | |
|    "outputs": [
 | |
|     {
 | |
|      "name": "stdout",
 | |
|      "output_type": "stream",
 | |
|      "text": [
 | |
|       "tensor([[-0.0739,  0.0713],\n",
 | |
|       "        [-0.0748,  0.0703],\n",
 | |
|       "        [-0.0749,  0.0702],\n",
 | |
|       "        [-0.0760,  0.0685],\n",
 | |
|       "        [-0.0763,  0.0679],\n",
 | |
|       "        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)\n"
 | |
|      ]
 | |
|     }
 | |
|    ],
 | |
|    "source": [
 | |
|     "class SelfAttention_v2(nn.Module):\n",
 | |
|     "\n",
 | |
|     "    def __init__(self, d_in, d_out, qkv_bias=False):\n",
 | |
|     "        super().__init__()\n",
 | |
|     "        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
 | |
|     "        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
 | |
|     "        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
 | |
|     "\n",
 | |
|     "    def forward(self, x):\n",
 | |
|     "        keys = self.W_key(x)\n",
 | |
|     "        queries = self.W_query(x)\n",
 | |
|     "        values = self.W_value(x)\n",
 | |
|     "        \n",
 | |
|     "        attn_scores = queries @ keys.T\n",
 | |
|     "        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n",
 | |
|     "\n",
 | |
|     "        context_vec = attn_weights @ values\n",
 | |
|     "        return context_vec\n",
 | |
|     "\n",
 | |
|     "torch.manual_seed(789)\n",
 | |
|     "sa_v2 = SelfAttention_v2(d_in, d_out)\n",
 | |
|     "print(sa_v2(inputs))"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "915cd8a5-a895-42c9-8b8e-06b5ae19ffce",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
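|     "- Note that `SelfAttention_v1` and `SelfAttention_v2` give different outputs because they use different initial weights for the weight matrices (`nn.Linear` uses a different initialization scheme than `torch.rand`, and we also used a different random seed)"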
 | |
|    ]
 | |
|   },
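The sketch below makes the relationship between the two classes concrete, using a random stand-in input. It first checks that a bias-free `nn.Linear` layer computes `x @ W.T`. The layer stores its weight with shape `(d_out, d_in)`, which is transposed relative to the `(d_in, d_out)` matrices in `SelfAttention_v1`. It then checks that, with PyTorch's default `nn.Linear` initialization in current versions, the weights lie within $[-1/\sqrt{d_{in}}, 1/\sqrt{d_{in}}]$, whereas `torch.rand` draws from $[0, 1)$.

```python
import torch
import torch.nn as nn

torch.manual_seed(123)
d_in, d_out = 3, 2
x = torch.rand(6, d_in)  # random stand-in for the 6 input embeddings

linear = nn.Linear(d_in, d_out, bias=False)
print(linear.weight.shape)  # torch.Size([2, 3]), i.e., (d_out, d_in)

# Without a bias, the layer computes x @ W.T
print(torch.allclose(linear(x), x @ linear.weight.T))  # True

# Default init keeps the weights within +-1/sqrt(d_in) (about +-0.577 here)
bound = 1 / d_in**0.5
print(bool(linear.weight.abs().max() <= bound))  # True
```

This transposed storage is why copying the weights of a `SelfAttention_v2` instance into a `SelfAttention_v1` instance requires transposing them first.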
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "c5025b37-0f2c-4a67-a7cb-1286af7026ab",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "## 3.5 Hiding future words with causal attention"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "aef0a6b8-205a-45bf-9d26-8fd77a8a03c3",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "- In causal attention, the attention weights above the diagonal are masked, ensuring that, for any given input, the LLM cannot use future tokens when computing the context vectors via the attention weights"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "71e91bb5-5aae-4f05-8a95-973b3f988a35",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch03_compressed/19.webp\" width=\"400px\">"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "82f405de-cd86-4e72-8f3c-9ea0354946ba",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "### 3.5.1 Applying a causal attention mask"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "014f28d0-8218-48e4-8b9c-bdc5ce489218",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "- In this section, we are converting the previous self-attention mechanism into a causal self-attention mechanism\n",
 | |
|     "- Causal self-attention ensures that the model's prediction for a certain position in a sequence depends only on the tokens at the current and previous positions, not on future positions\n",
 | |
|     "- In simpler terms, each next-word prediction depends only on the preceding words\n",
 | |
|     "- To achieve this, for each given token, we mask out the future tokens (the ones that come after the current token in the input text):"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "57f99af3-32bc-48f5-8eb4-63504670ca0a",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch03_compressed/20.webp\" width=\"600px\">"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "cbfaec7a-68f2-4157-a4b5-2aeceed199d9",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "- To illustrate and implement causal self-attention, let's work with the attention scores and weights from the previous section: "
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "code",
 | |
|    "execution_count": 25,
 | |
|    "id": "1933940d-0fa5-4b17-a3ce-388e5314a1bb",
 | |
|    "metadata": {},
 | |
|    "outputs": [
 | |
|     {
 | |
|      "name": "stdout",
 | |
|      "output_type": "stream",
 | |
|      "text": [
 | |
|       "tensor([[0.1921, 0.1646, 0.1652, 0.1550, 0.1721, 0.1510],\n",
 | |
|       "        [0.2041, 0.1659, 0.1662, 0.1496, 0.1665, 0.1477],\n",
 | |
|       "        [0.2036, 0.1659, 0.1662, 0.1498, 0.1664, 0.1480],\n",
 | |
|       "        [0.1869, 0.1667, 0.1668, 0.1571, 0.1661, 0.1564],\n",
 | |
|       "        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.1585],\n",
 | |
|       "        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],\n",
 | |
|       "       grad_fn=<SoftmaxBackward0>)\n"
 | |
|      ]
 | |
|     }
 | |
|    ],
 | |
|    "source": [
 | |
|     "# Reuse the query and key weight matrices of the\n",
 | |
|     "# SelfAttention_v2 object from the previous section for convenience\n",
 | |
|     "queries = sa_v2.W_query(inputs)\n",
 | |
|     "keys = sa_v2.W_key(inputs) \n",
 | |
|     "attn_scores = queries @ keys.T\n",
 | |
|     "\n",
 | |
|     "attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n",
 | |
|     "print(attn_weights)"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "89020a96-b34d-41f8-9349-98c3e23fd5d6",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "- The simplest way to mask out future attention weights is to create a mask via PyTorch's `tril` function, with the elements on and below the main diagonal set to 1 and the elements above it set to 0:"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "code",
 | |
|    "execution_count": 26,
 | |
|    "id": "43f3d2e3-185b-4184-9f98-edde5e6df746",
 | |
|    "metadata": {},
 | |
|    "outputs": [
 | |
|     {
 | |
|      "name": "stdout",
 | |
|      "output_type": "stream",
 | |
|      "text": [
 | |
|       "tensor([[1., 0., 0., 0., 0., 0.],\n",
 | |
|       "        [1., 1., 0., 0., 0., 0.],\n",
 | |
|       "        [1., 1., 1., 0., 0., 0.],\n",
 | |
|       "        [1., 1., 1., 1., 0., 0.],\n",
 | |
|       "        [1., 1., 1., 1., 1., 0.],\n",
 | |
|       "        [1., 1., 1., 1., 1., 1.]])\n"
 | |
|      ]
 | |
|     }
 | |
|    ],
 | |
|    "source": [
 | |
|     "context_length = attn_scores.shape[0]\n",
 | |
|     "mask_simple = torch.tril(torch.ones(context_length, context_length))\n",
 | |
|     "print(mask_simple)"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "efce2b08-3583-44da-b3fc-cabdd38761f6",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "- Then, we can multiply the attention weights with this mask to zero out the attention weights above the diagonal:"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "code",
 | |
|    "execution_count": 27,
 | |
|    "id": "9f531e2e-f4d2-4fea-a87f-4c132e48b9e7",
 | |
|    "metadata": {},
 | |
|    "outputs": [
 | |
|     {
 | |
|      "name": "stdout",
 | |
|      "output_type": "stream",
 | |
|      "text": [
 | |
|       "tensor([[0.1921, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
 | |
|       "        [0.2041, 0.1659, 0.0000, 0.0000, 0.0000, 0.0000],\n",
 | |
|       "        [0.2036, 0.1659, 0.1662, 0.0000, 0.0000, 0.0000],\n",
 | |
|       "        [0.1869, 0.1667, 0.1668, 0.1571, 0.0000, 0.0000],\n",
 | |
|       "        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.0000],\n",
 | |
|       "        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],\n",
 | |
|       "       grad_fn=<MulBackward0>)\n"
 | |
|      ]
 | |
|     }
 | |
|    ],
 | |
|    "source": [
 | |
|     "masked_simple = attn_weights*mask_simple\n",
 | |
|     "print(masked_simple)"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "3eb35787-cf12-4024-b66d-e7215e175500",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "- However, applying the mask after softmax, as above, disrupts the probability distribution created by softmax\n",
 | |
|     "- Softmax ensures that all output values sum to 1\n",
 | |
|     "- Masking after softmax would require renormalizing the outputs to sum to 1 again, which complicates the process and might lead to unintended effects"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "94db92d7-c397-4e42-bd8a-6a2b3e237e0f",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "- To make sure that the rows sum to 1, we can normalize the attention weights as follows:"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "code",
 | |
|    "execution_count": 28,
 | |
|    "id": "6d392083-fd81-4f70-9bdf-8db985e673d6",
 | |
|    "metadata": {},
 | |
|    "outputs": [
 | |
|     {
 | |
|      "name": "stdout",
 | |
|      "output_type": "stream",
 | |
|      "text": [
 | |
|       "tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
 | |
|       "        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],\n",
 | |
|       "        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],\n",
 | |
|       "        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],\n",
 | |
|       "        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],\n",
 | |
|       "        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],\n",
 | |
|       "       grad_fn=<DivBackward0>)\n"
 | |
|      ]
 | |
|     }
 | |
|    ],
 | |
|    "source": [
 | |
|     "row_sums = masked_simple.sum(dim=-1, keepdim=True)\n",
 | |
|     "masked_simple_norm = masked_simple / row_sums\n",
 | |
|     "print(masked_simple_norm)"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "512e7cf4-dc0e-4cec-948e-c7a3c4eb6877",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "- While we are technically done with coding the causal attention mechanism now, let's briefly look at a more efficient approach that achieves the same result\n",
 | |
|     "- So, instead of zeroing out attention weights above the diagonal and renormalizing the results, we can mask the unnormalized attention scores above the diagonal with negative infinity before they enter the softmax function:"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "eb682900-8df2-4767-946c-a82bee260188",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch03_compressed/21.webp\" width=\"450px\">"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "code",
 | |
|    "execution_count": 29,
 | |
|    "id": "a2be2f43-9cf0-44f6-8d8b-68ef2fb3cc39",
 | |
|    "metadata": {},
 | |
|    "outputs": [
 | |
|     {
 | |
|      "name": "stdout",
 | |
|      "output_type": "stream",
 | |
|      "text": [
 | |
|       "tensor([[0.2899,   -inf,   -inf,   -inf,   -inf,   -inf],\n",
 | |
|       "        [0.4656, 0.1723,   -inf,   -inf,   -inf,   -inf],\n",
 | |
|       "        [0.4594, 0.1703, 0.1731,   -inf,   -inf,   -inf],\n",
 | |
|       "        [0.2642, 0.1024, 0.1036, 0.0186,   -inf,   -inf],\n",
 | |
|       "        [0.2183, 0.0874, 0.0882, 0.0177, 0.0786,   -inf],\n",
 | |
|       "        [0.3408, 0.1270, 0.1290, 0.0198, 0.1290, 0.0078]],\n",
 | |
|       "       grad_fn=<MaskedFillBackward0>)\n"
 | |
|      ]
 | |
|     }
 | |
|    ],
 | |
|    "source": [
 | |
|     "mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)\n",
 | |
|     "masked = attn_scores.masked_fill(mask.bool(), -torch.inf)\n",
 | |
|     "print(masked)"
 | |
|    ]
 | |
|   },
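The `triu(..., diagonal=1)` mask used with `masked_fill` marks exactly the positions that the earlier `tril` mask zeroed out, so the two masks are complements of each other. Here is a short check, assuming a context length of 6 as in this example.

```python
import torch

context_length = 6
# 1 = allowed (past and current tokens)
keep = torch.tril(torch.ones(context_length, context_length))
# 1 = future token that must be masked
future = torch.triu(torch.ones(context_length, context_length), diagonal=1)

# Every position is either kept or masked, never both
print(torch.equal(keep + future, torch.ones(context_length, context_length)))  # True
print(int(future.sum()))  # 15 masked positions above the diagonal
```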
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "91d5f803-d735-4543-b9da-00ac10fb9c50",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "- As we can see below, now the attention weights in each row correctly sum to 1 again:"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "code",
 | |
|    "execution_count": 30,
 | |
|    "id": "b1cd6d7f-16f2-43c1-915e-0824f1a4bc52",
 | |
|    "metadata": {},
 | |
|    "outputs": [
 | |
|     {
 | |
|      "name": "stdout",
 | |
|      "output_type": "stream",
 | |
|      "text": [
 | |
|       "tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
 | |
|       "        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],\n",
 | |
|       "        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],\n",
 | |
|       "        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],\n",
 | |
|       "        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],\n",
 | |
|       "        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],\n",
 | |
|       "       grad_fn=<SoftmaxBackward0>)\n"
 | |
|      ]
 | |
|     }
 | |
|    ],
 | |
|    "source": [
 | |
|     "attn_weights = torch.softmax(masked / keys.shape[-1]**0.5, dim=-1)\n",
 | |
|     "print(attn_weights)"
 | |
|    ]
 | |
|   },
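The two approaches produce identical attention weights for a reason. Applying softmax to a row with some entries set to $-\infty$ gives those entries a weight of $e^{-\infty} = 0$ and normalizes over the remaining ones. Renormalizing the masked softmax output does exactly the same thing. The sketch below checks this on random stand-in scores. The equivalence holds for any scores, so the $\sqrt{d_k}$ scaling is omitted here.

```python
import torch

torch.manual_seed(0)
n = 6
scores = torch.randn(n, n)  # stand-in unnormalized attention scores

# Approach 1: softmax, zero out future positions, renormalize each row
weights = torch.softmax(scores, dim=-1)
keep = torch.tril(torch.ones(n, n))
approach_1 = weights * keep
approach_1 = approach_1 / approach_1.sum(dim=-1, keepdim=True)

# Approach 2: fill future positions with -inf, then apply softmax once
future = torch.triu(torch.ones(n, n), diagonal=1).bool()
approach_2 = torch.softmax(scores.masked_fill(future, -torch.inf), dim=-1)

print(torch.allclose(approach_1, approach_2))  # True
```

The second approach needs only one normalization pass, which is why it is the one used in the class implementation.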
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "7636fc5f-6bc6-461e-ac6a-99ec8e3c0912",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "### 3.5.2 Masking additional attention weights with dropout"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "ec3dc7ee-6539-4fab-804a-8f31a890c85a",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "- In addition, we also apply dropout to reduce overfitting during training\n",
 | |
|     "- Dropout can be applied in several places:\n",
 | |
|     "  - for example, after computing the attention weights;\n",
 | |
|     "  - or after multiplying the attention weights with the value vectors\n",
 | |
|     "- Here, we will apply the dropout mask after computing the attention weights because it's more common\n",
 | |
|     "\n",
 | |
|     "- Furthermore, in this specific example, we use a dropout rate of 50%, which means that each attention weight is dropped with a probability of 0.5, so about half of them are masked out (when we train the GPT model later, we will use a lower dropout rate, such as 0.1 or 0.2)"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "ee799cf6-6175-45f2-827e-c174afedb722",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch03_compressed/22.webp\" width=\"400px\">"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "5a575458-a6da-4e54-8688-83e155f2de06",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "- If we apply a dropout rate of 0.5 (50%), the non-dropped values are scaled up by a factor of 1/(1 - 0.5) = 2 during training, which keeps the expected value of each element unchanged."
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "code",
 | |
|    "execution_count": 31,
 | |
|    "id": "0de578db-8289-41d6-b377-ef645751e33f",
 | |
|    "metadata": {},
 | |
|    "outputs": [
 | |
|     {
 | |
|      "name": "stdout",
 | |
|      "output_type": "stream",
 | |
|      "text": [
 | |
|       "tensor([[2., 2., 0., 2., 2., 0.],\n",
 | |
|       "        [0., 0., 0., 2., 0., 2.],\n",
 | |
|       "        [2., 2., 2., 2., 0., 2.],\n",
 | |
|       "        [0., 2., 2., 0., 0., 2.],\n",
 | |
|       "        [0., 2., 0., 2., 0., 2.],\n",
 | |
|       "        [0., 2., 2., 2., 2., 0.]])\n"
 | |
|      ]
 | |
|     }
 | |
|    ],
 | |
|    "source": [
 | |
|     "torch.manual_seed(123)\n",
 | |
|     "dropout = torch.nn.Dropout(0.5) # dropout rate of 50%\n",
 | |
|     "example = torch.ones(6, 6) # create a matrix of ones\n",
 | |
|     "\n",
 | |
|     "print(dropout(example))"
 | |
|    ]
 | |
|   },
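`nn.Dropout` uses inverted dropout. During training, each element is zeroed with probability `p`, and the surviving elements are multiplied by $1/(1-p)$, so the expected value of each element stays the same. In evaluation mode, the module passes its input through unchanged. The sketch below checks both properties with a hypothetical rate of `p = 0.2`, so that the scale factor $1/(1-0.2) = 1.25$ is clearly different from $1/p$.

```python
import torch

torch.manual_seed(123)
p = 0.2
dropout = torch.nn.Dropout(p)
x = torch.ones(1000, 1000)

out_train = dropout(x)  # a newly created module starts in training mode
kept = out_train[out_train != 0]
print(torch.allclose(kept, torch.full_like(kept, 1 / (1 - p))))  # True: kept values are 1.25
print(round(out_train.mean().item(), 2))  # close to 1.0: the expected value is preserved

dropout.eval()  # switch to evaluation (inference) mode
print(torch.equal(dropout(x), x))  # True: dropout is disabled
```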
 | |
|   {
 | |
|    "cell_type": "code",
 | |
|    "execution_count": 32,
 | |
|    "id": "b16c5edb-942b-458c-8e95-25e4e355381e",
 | |
|    "metadata": {},
 | |
|    "outputs": [
 | |
|     {
 | |
|      "name": "stdout",
 | |
|      "output_type": "stream",
 | |
|      "text": [
 | |
|       "tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
 | |
|       "        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n",
 | |
|       "        [0.7599, 0.6194, 0.6206, 0.0000, 0.0000, 0.0000],\n",
 | |
|       "        [0.0000, 0.4921, 0.4925, 0.0000, 0.0000, 0.0000],\n",
 | |
|       "        [0.0000, 0.3966, 0.0000, 0.3775, 0.0000, 0.0000],\n",
 | |
|       "        [0.0000, 0.3327, 0.3331, 0.3084, 0.3331, 0.0000]],\n",
 | |
|       "       grad_fn=<MulBackward0>)\n"
 | |
|      ]
 | |
|     }
 | |
|    ],
 | |
|    "source": [
 | |
|     "torch.manual_seed(123)\n",
 | |
|     "print(dropout(attn_weights))"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "269df5c8-3e25-49d0-95d3-bb232287404f",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "- Note that the resulting dropout outputs may look different depending on your operating system; you can read more about this inconsistency [here on the PyTorch issue tracker](https://github.com/pytorch/pytorch/issues/121595)"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "cdc14639-5f0f-4840-aa9d-8eb36ea90fb7",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "### 3.5.3 Implementing a compact causal self-attention class"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "09c41d29-1933-43dc-ada6-2dbb56287204",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "- Now, we are ready to write a complete implementation of self-attention, including the causal and dropout masks\n",
 | |
|     "- One more step is to make the code handle batches consisting of more than one input so that our `CausalAttention` class supports the batched outputs produced by the data loader we implemented in chapter 2\n",
 | |
|     "- For simplicity, to simulate such batch input, we duplicate the input text example:"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "code",
 | |
|    "execution_count": 33,
 | |
|    "id": "977a5fa7-a9d5-4e2e-8a32-8e0331ccfe28",
 | |
|    "metadata": {},
 | |
|    "outputs": [
 | |
|     {
 | |
|      "name": "stdout",
 | |
|      "output_type": "stream",
 | |
|      "text": [
 | |
|       "torch.Size([2, 6, 3])\n"
 | |
|      ]
 | |
|     }
 | |
|    ],
 | |
|    "source": [
 | |
|     "batch = torch.stack((inputs, inputs), dim=0)\n",
 | |
|     "print(batch.shape) # 2 inputs with 6 tokens each, and each token has embedding dimension 3"
 | |
|    ]
 | |
|   },
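With a batch dimension, the key tensor has shape `(b, num_tokens, d_out)`, so the 2D transpose `keys.T` used earlier no longer does what we want. The `CausalAttention` class below uses `keys.transpose(1, 2)` instead, which swaps only the token and feature dimensions. The sketch below uses random stand-in tensors. It shows that this batched matrix multiplication produces one `num_tokens`-by-`num_tokens` score matrix per batch element, matching a per-example loop.

```python
import torch

torch.manual_seed(0)
b, num_tokens, d_out = 2, 6, 2
queries = torch.rand(b, num_tokens, d_out)  # stand-in batched queries
keys = torch.rand(b, num_tokens, d_out)     # stand-in batched keys

# Swap only the last two dims: (b, num_tokens, d_out) -> (b, d_out, num_tokens)
scores = queries @ keys.transpose(1, 2)
print(scores.shape)  # torch.Size([2, 6, 6])

# Same result as computing each batch element separately
per_example = torch.stack([queries[i] @ keys[i].T for i in range(b)])
print(torch.allclose(scores, per_example))  # True
```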
 | |
|   {
 | |
|    "cell_type": "code",
 | |
|    "execution_count": 34,
 | |
|    "id": "60d8c2eb-2d8e-4d2c-99bc-9eef8cc53ca0",
 | |
|    "metadata": {},
 | |
|    "outputs": [
 | |
|     {
 | |
|      "name": "stdout",
 | |
|      "output_type": "stream",
 | |
|      "text": [
 | |
|       "tensor([[[-0.4519,  0.2216],\n",
 | |
|       "         [-0.5874,  0.0058],\n",
 | |
|       "         [-0.6300, -0.0632],\n",
 | |
|       "         [-0.5675, -0.0843],\n",
 | |
|       "         [-0.5526, -0.0981],\n",
 | |
|       "         [-0.5299, -0.1081]],\n",
 | |
|       "\n",
 | |
|       "        [[-0.4519,  0.2216],\n",
 | |
|       "         [-0.5874,  0.0058],\n",
 | |
|       "         [-0.6300, -0.0632],\n",
 | |
|       "         [-0.5675, -0.0843],\n",
 | |
|       "         [-0.5526, -0.0981],\n",
 | |
|       "         [-0.5299, -0.1081]]], grad_fn=<UnsafeViewBackward0>)\n",
 | |
|       "context_vecs.shape: torch.Size([2, 6, 2])\n"
 | |
|      ]
 | |
|     }
 | |
|    ],
 | |
|    "source": [
 | |
|     "class CausalAttention(nn.Module):\n",
 | |
|     "\n",
 | |
|     "    def __init__(self, d_in, d_out, context_length,\n",
 | |
|     "                 dropout, qkv_bias=False):\n",
 | |
|     "        super().__init__()\n",
 | |
|     "        self.d_out = d_out\n",
 | |
|     "        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
 | |
|     "        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
 | |
|     "        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
 | |
|     "        self.dropout = nn.Dropout(dropout) # New\n",
 | |
|     "        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1)) # New\n",
 | |
|     "\n",
 | |
|     "    def forward(self, x):\n",
 | |
|     "        b, num_tokens, d_in = x.shape # New batch dimension b\n",
 | |
|     "        keys = self.W_key(x)\n",
 | |
|     "        queries = self.W_query(x)\n",
 | |
|     "        values = self.W_value(x)\n",
 | |
|     "\n",
 | |
|     "        attn_scores = queries @ keys.transpose(1, 2) # Changed transpose: swap token and embedding dims, keep the batch dim\n",
 | |
|     "        attn_scores.masked_fill_(  # New, _ ops are in-place\n",
 | |
|     "            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf) \n",
 | |
|     "        attn_weights = torch.softmax(\n",
 | |
|     "            attn_scores / keys.shape[-1]**0.5, dim=-1\n",
 | |
|     "        )\n",
 | |
|     "        attn_weights = self.dropout(attn_weights) # New\n",
 | |
|     "\n",
 | |
|     "        context_vec = attn_weights @ values\n",
 | |
|     "        return context_vec\n",
 | |
|     "\n",
 | |
|     "torch.manual_seed(123)\n",
 | |
|     "\n",
 | |
|     "context_length = batch.shape[1]\n",
 | |
|     "ca = CausalAttention(d_in, d_out, context_length, 0.0)\n",
 | |
|     "\n",
 | |
|     "context_vecs = ca(batch)\n",
 | |
|     "\n",
 | |
|     "print(context_vecs)\n",
 | |
|     "print(\"context_vecs.shape:\", context_vecs.shape)"
 | |
|    ]
 | |
|   },
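|   {
|    "cell_type": "markdown",
|    "id": "5c2e8a41-9d3b-4f6e-a1c7-3b8d0e9f2a10",
|    "metadata": {},
|    "source": [
|     "- As an optional sanity check (a minimal sketch, not part of the original chapter code), we can verify that the causal mask works as intended: if we change only the last token of a random input, the context vectors of all earlier tokens should stay the same\n",
|     "- The code below assumes the `CausalAttention` class defined above; the names `ca_check`, `x`, and `x_perturbed` and the input sizes are hypothetical"
|    ]
|   },
|   {
|    "cell_type": "code",
|    "execution_count": null,
|    "id": "5c2e8a41-9d3b-4f6e-a1c7-3b8d0e9f2a11",
|    "metadata": {},
|    "outputs": [],
|    "source": [
|     "import torch\n",
|     "\n",
|     "# Assumes the CausalAttention class defined above; all names here are illustrative\n",
|     "torch.manual_seed(123)\n",
|     "ca_check = CausalAttention(d_in=3, d_out=2, context_length=6, dropout=0.0)\n",
|     "\n",
|     "x = torch.rand(1, 6, 3)         # hypothetical input: 1 sequence, 6 tokens, 3 dims\n",
|     "x_perturbed = x.clone()\n",
|     "x_perturbed[:, -1, :] += 1.0    # modify only the last token\n",
|     "\n",
|     "with torch.no_grad():\n",
|     "    out = ca_check(x)\n",
|     "    out_perturbed = ca_check(x_perturbed)\n",
|     "\n",
|     "# Tokens 0-4 cannot attend to token 5, so their context vectors are unchanged\n",
|     "print(torch.allclose(out[:, :-1], out_perturbed[:, :-1]))  # expected: True\n",
|     "# The last token attends to itself, so its context vector typically changes\n",
|     "print(torch.allclose(out[:, -1], out_perturbed[:, -1]))"
|    ]
|   },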
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "c4333d12-17e4-4bb5-9d83-54b3a32618cd",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
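|     "- Note that dropout is only applied during training, not during inference: PyTorch modules start in training mode, and calling `.eval()` on a model turns its `nn.Dropout` layers into no-ops (calling `.train()` switches them back on)"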
 | |
|    ]
 | |
|   },
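|   {
|    "cell_type": "markdown",
|    "id": "7e4b1d92-2a6c-4b8f-9d05-6c1e3f7a8b20",
|    "metadata": {},
|    "source": [
|     "- To see this in practice, here is a small sketch that uses a dropout rate of 0.5. In training mode, which is the default for PyTorch modules, repeated calls sample new dropout masks and produce different outputs. After calling `.eval()`, dropout is disabled and the outputs are deterministic\n",
|     "- The code assumes the `CausalAttention` class defined above; `ca_drop` and the random input `x` are hypothetical"
|    ]
|   },
|   {
|    "cell_type": "code",
|    "execution_count": null,
|    "id": "7e4b1d92-2a6c-4b8f-9d05-6c1e3f7a8b21",
|    "metadata": {},
|    "outputs": [],
|    "source": [
|     "import torch\n",
|     "\n",
|     "# Assumes the CausalAttention class defined above, here with 50% dropout\n",
|     "torch.manual_seed(123)\n",
|     "ca_drop = CausalAttention(d_in=3, d_out=2, context_length=6, dropout=0.5)\n",
|     "x = torch.rand(2, 6, 3)  # hypothetical batch: 2 sequences, 6 tokens, 3 dims\n",
|     "\n",
|     "ca_drop.train()  # training mode (the default): dropout is active\n",
|     "with torch.no_grad():\n",
|     "    train_1, train_2 = ca_drop(x), ca_drop(x)\n",
|     "# Each call samples a new dropout mask, so the outputs almost surely differ\n",
|     "print('train mode, identical outputs:', torch.equal(train_1, train_2))\n",
|     "\n",
|     "ca_drop.eval()  # evaluation mode: nn.Dropout passes its input through unchanged\n",
|     "with torch.no_grad():\n",
|     "    eval_1, eval_2 = ca_drop(x), ca_drop(x)\n",
|     "print('eval mode, identical outputs:', torch.equal(eval_1, eval_2))  # expected: True"
|    ]
|   },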
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "a554cf47-558c-4f45-84cd-bf9b839a8d50",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch03_compressed/23.webp\" width=\"500px\">"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "c8bef90f-cfd4-4289-b0e8-6a00dc9be44c",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "## 3.6 Extending single-head attention to multi-head attention"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "11697757-9198-4a1c-9cee-f450d8bbd3b9",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "### 3.6.1 Stacking multiple single-head attention layers"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "70766faf-cd53-41d9-8a17-f1b229756a5a",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "- Below is a summary of the self-attention implemented previously (causal and dropout masks not shown for simplicity)\n",
 | |
|     "\n",
 | |
|     "- This is also called single-head attention:\n",
 | |
|     "\n",
 | |
|     "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch03_compressed/24.webp\" width=\"400px\">\n",
 | |
|     "\n",
 | |
|     "- We simply stack multiple single-head attention modules to obtain a multi-head attention module:\n",
 | |
|     "\n",
 | |
|     "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch03_compressed/25.webp\" width=\"400px\">\n",
 | |
|     "\n",
 | |
|     "- The main idea behind multi-head attention is to run the attention mechanism multiple times (in parallel) with different, learned linear projections. This allows the model to jointly attend to information from different representation subspaces at different positions."
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "code",
 | |
|    "execution_count": 35,
 | |
|    "id": "b9a66e11-7105-4bb4-be84-041f1a1f3bd2",
 | |
|    "metadata": {},
 | |
|    "outputs": [
 | |
|     {
 | |
|      "name": "stdout",
 | |
|      "output_type": "stream",
 | |
|      "text": [
 | |
|       "tensor([[[-0.4519,  0.2216,  0.4772,  0.1063],\n",
 | |
|       "         [-0.5874,  0.0058,  0.5891,  0.3257],\n",
 | |
|       "         [-0.6300, -0.0632,  0.6202,  0.3860],\n",
 | |
|       "         [-0.5675, -0.0843,  0.5478,  0.3589],\n",
 | |
|       "         [-0.5526, -0.0981,  0.5321,  0.3428],\n",
 | |
|       "         [-0.5299, -0.1081,  0.5077,  0.3493]],\n",
 | |
|       "\n",
 | |
|       "        [[-0.4519,  0.2216,  0.4772,  0.1063],\n",
 | |
|       "         [-0.5874,  0.0058,  0.5891,  0.3257],\n",
 | |
|       "         [-0.6300, -0.0632,  0.6202,  0.3860],\n",
 | |
|       "         [-0.5675, -0.0843,  0.5478,  0.3589],\n",
 | |
|       "         [-0.5526, -0.0981,  0.5321,  0.3428],\n",
 | |
|       "         [-0.5299, -0.1081,  0.5077,  0.3493]]], grad_fn=<CatBackward0>)\n",
 | |
|       "context_vecs.shape: torch.Size([2, 6, 4])\n"
 | |
|      ]
 | |
|     }
 | |
|    ],
 | |
|    "source": [
 | |
|     "class MultiHeadAttentionWrapper(nn.Module):\n",
 | |
|     "\n",
 | |
|     "    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):\n",
 | |
|     "        super().__init__()\n",
 | |
|     "        self.heads = nn.ModuleList(\n",
 | |
|     "            [CausalAttention(d_in, d_out, context_length, dropout, qkv_bias) \n",
 | |
|     "             for _ in range(num_heads)]\n",
 | |
|     "        )\n",
 | |
|     "\n",
 | |
|     "    def forward(self, x):\n",
 | |
|     "        return torch.cat([head(x) for head in self.heads], dim=-1)\n",
 | |
|     "\n",
 | |
|     "\n",
 | |
|     "torch.manual_seed(123)\n",
 | |
|     "\n",
 | |
|     "context_length = batch.shape[1] # This is the number of tokens\n",
 | |
|     "d_in, d_out = 3, 2\n",
 | |
|     "mha = MultiHeadAttentionWrapper(\n",
 | |
|     "    d_in, d_out, context_length, 0.0, num_heads=2\n",
 | |
|     ")\n",
 | |
|     "\n",
 | |
|     "context_vecs = mha(batch)\n",
 | |
|     "\n",
 | |
|     "print(context_vecs)\n",
 | |
|     "print(\"context_vecs.shape:\", context_vecs.shape)"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "193d3d2b-2578-40ba-b791-ea2d49328e48",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "- In the implementation above, the output embedding dimension is 4 because we use `d_out=2` as the embedding dimension for the key, query, and value vectors, as well as for the context vector of each head. Since we have 2 attention heads whose outputs are concatenated, the output embedding dimension is 2*2=4"
 | |
|    ]
 | |
|   },
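|   {
|    "cell_type": "markdown",
|    "id": "9a3d6f15-4c8e-4e2b-b7a9-1f5c2d8e0c30",
|    "metadata": {},
|    "source": [
|     "- The following sketch checks two things. First, the output dimension is `num_heads * d_out`. Second, each slice of `d_out` columns is exactly the output of the corresponding head\n",
|     "- It assumes the `MultiHeadAttentionWrapper` class defined above, uses 3 heads instead of 2 for illustration, and feeds in a hypothetical random input `x`"
|    ]
|   },
|   {
|    "cell_type": "code",
|    "execution_count": null,
|    "id": "9a3d6f15-4c8e-4e2b-b7a9-1f5c2d8e0c31",
|    "metadata": {},
|    "outputs": [],
|    "source": [
|     "import torch\n",
|     "\n",
|     "# Assumes MultiHeadAttentionWrapper from above; the sizes are illustrative\n",
|     "torch.manual_seed(123)\n",
|     "num_heads, d_out = 3, 2\n",
|     "wrapper = MultiHeadAttentionWrapper(\n",
|     "    d_in=3, d_out=d_out, context_length=6, dropout=0.0, num_heads=num_heads\n",
|     ")\n",
|     "x = torch.rand(2, 6, 3)  # hypothetical batch: 2 sequences, 6 tokens, 3 dims\n",
|     "\n",
|     "with torch.no_grad():\n",
|     "    out = wrapper(x)\n",
|     "    print(out.shape)  # expected: torch.Size([2, 6, 6]), since num_heads * d_out = 3 * 2\n",
|     "\n",
|     "    # The wrapper concatenates head outputs along the last dimension\n",
|     "    for i, head in enumerate(wrapper.heads):\n",
|     "        head_slice = out[..., i * d_out:(i + 1) * d_out]\n",
|     "        print(f'head {i} matches its slice:', torch.allclose(head_slice, head(x)))  # expected: True"
|    ]
|   },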
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "6836b5da-ef82-4b4c-bda1-72a462e48d4e",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "### 3.6.2 Implementing multi-head attention with weight splits"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "f4b48d0d-71ba-4fa0-b714-ca80cabcb6f7",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "- While the above is an intuitive and fully functional implementation of multi-head attention (wrapping the single-head attention `CausalAttention` implementation from earlier), we can write a stand-alone class called `MultiHeadAttention` to achieve the same\n",
 | |
|     "\n",
 | |
|     "- Unlike the wrapper, this stand-alone `MultiHeadAttention` class does not concatenate separate single-head attention modules\n",
 | |
|     "- Instead, we create one `W_query`, one `W_key`, and one `W_value` weight matrix and then split them into individual matrices for each attention head:"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "code",
 | |
|    "execution_count": 36,
 | |
|    "id": "110b0188-6e9e-4e56-a988-10523c6c8538",
 | |
|    "metadata": {},
 | |
|    "outputs": [
 | |
|     {
 | |
|      "name": "stdout",
 | |
|      "output_type": "stream",
 | |
|      "text": [
 | |
|       "tensor([[[0.3190, 0.4858],\n",
 | |
|       "         [0.2943, 0.3897],\n",
 | |
|       "         [0.2856, 0.3593],\n",
 | |
|       "         [0.2693, 0.3873],\n",
 | |
|       "         [0.2639, 0.3928],\n",
 | |
|       "         [0.2575, 0.4028]],\n",
 | |
|       "\n",
 | |
|       "        [[0.3190, 0.4858],\n",
 | |
|       "         [0.2943, 0.3897],\n",
 | |
|       "         [0.2856, 0.3593],\n",
 | |
|       "         [0.2693, 0.3873],\n",
 | |
|       "         [0.2639, 0.3928],\n",
 | |
|       "         [0.2575, 0.4028]]], grad_fn=<ViewBackward0>)\n",
 | |
|       "context_vecs.shape: torch.Size([2, 6, 2])\n"
 | |
|      ]
 | |
|     }
 | |
|    ],
 | |
|    "source": [
 | |
|     "class MultiHeadAttention(nn.Module):\n",
 | |
|     "    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):\n",
 | |
|     "        super().__init__()\n",
 | |
|     "        assert (d_out % num_heads == 0), \\\n",
 | |
|     "            \"d_out must be divisible by num_heads\"\n",
 | |
|     "\n",
 | |
|     "        self.d_out = d_out\n",
 | |
|     "        self.num_heads = num_heads\n",
 | |
|     "        self.head_dim = d_out // num_heads # Each head gets d_out // num_heads dims, so the concatenated heads add up to d_out\n",
 | |
|     "\n",
 | |
|     "        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
 | |
|     "        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
 | |
|     "        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)\n",
 | |
|     "        self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs\n",
 | |
|     "        self.dropout = nn.Dropout(dropout)\n",
 | |
|     "        self.register_buffer(\n",
 | |
|     "            \"mask\",\n",
 | |
|     "            torch.triu(torch.ones(context_length, context_length),\n",
 | |
|     "                       diagonal=1)\n",
 | |
|     "        )\n",
 | |
|     "\n",
 | |
|     "    def forward(self, x):\n",
 | |
|     "        b, num_tokens, d_in = x.shape\n",
 | |
|     "\n",
 | |
|     "        keys = self.W_key(x) # Shape: (b, num_tokens, d_out)\n",
 | |
|     "        queries = self.W_query(x)\n",
 | |
|     "        values = self.W_value(x)\n",
 | |
|     "\n",
 | |
|     "        # We implicitly split the matrix by adding a `num_heads` dimension\n",
 | |
|     "        # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)\n",
 | |
|     "        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim) \n",
 | |
|     "        values = values.view(b, num_tokens, self.num_heads, self.head_dim)\n",
 | |
|     "        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)\n",
 | |
|     "\n",
 | |
|     "        # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)\n",
 | |
|     "        keys = keys.transpose(1, 2)\n",
 | |
|     "        queries = queries.transpose(1, 2)\n",
 | |
|     "        values = values.transpose(1, 2)\n",
 | |
|     "\n",
 | |
|     "        # Compute scaled dot-product attention (aka self-attention) with a causal mask\n",
 | |
|     "        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head\n",
 | |
|     "\n",
 | |
|     "        # Original mask truncated to the number of tokens and converted to boolean\n",
 | |
|     "        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]\n",
 | |
|     "\n",
 | |
|     "        # Use the mask to fill attention scores\n",
 | |
|     "        attn_scores.masked_fill_(mask_bool, -torch.inf)\n",
 | |
|     "        \n",
 | |
|     "        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n",
 | |
|     "        attn_weights = self.dropout(attn_weights)\n",
 | |
|     "\n",
 | |
|     "        # Shape: (b, num_tokens, num_heads, head_dim)\n",
 | |
|     "        context_vec = (attn_weights @ values).transpose(1, 2) \n",
 | |
|     "        \n",
 | |
|     "        # Combine heads, where self.d_out = self.num_heads * self.head_dim\n",
 | |
|     "        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)\n",
 | |
|     "        context_vec = self.out_proj(context_vec) # optional projection\n",
 | |
|     "\n",
 | |
|     "        return context_vec\n",
 | |
|     "\n",
 | |
|     "torch.manual_seed(123)\n",
 | |
|     "\n",
 | |
|     "batch_size, context_length, d_in = batch.shape\n",
 | |
|     "d_out = 2\n",
 | |
|     "mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2)\n",
 | |
|     "\n",
 | |
|     "context_vecs = mha(batch)\n",
 | |
|     "\n",
 | |
|     "print(context_vecs)\n",
 | |
|     "print(\"context_vecs.shape:\", context_vecs.shape)"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "d334dfb5-2b6c-4c33-82d5-b4e9db5867bb",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "- Note that the `MultiHeadAttention` class above is essentially a rewritten version of `MultiHeadAttentionWrapper` that is more efficient, since it computes the queries, keys, and values for all heads with one matrix multiplication each instead of running separate head modules\n",
 | |
|     "- The resulting output looks different because the random weight initializations differ (and, in this example, `d_out=2` is the total output dimension shared by both heads rather than the per-head dimension), but both are fully functional implementations that can be used in the GPT class we will implement in the upcoming chapters\n",
 | |
|     "- In addition, we added a linear projection layer (`self.out_proj`) to the `MultiHeadAttention` class above. This is simply a linear transformation that doesn't change the dimensions. Using such a projection layer is a standard convention in LLM implementations, but it's not strictly necessary (recent research has shown that it can be removed without affecting the modeling performance; see the further reading section at the end of this chapter)\n"
 | |
|    ]
 | |
|   },
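|   {
|    "cell_type": "markdown",
|    "id": "b2c7e0a8-6f1d-4a3e-8c52-4d9b1e6f3d40",
|    "metadata": {},
|    "source": [
|     "- To make the cost of this extra projection concrete, the sketch below counts the trainable parameters of both classes when each produces 4-dimensional context vectors with 2 heads. The helper `count_params` and these sizes are illustrative choices\n",
|     "- The query, key, and value weights have the same total size in both classes, so `out_proj` accounts for the whole difference (the causal `mask` is a buffer, not a parameter)"
|    ]
|   },
|   {
|    "cell_type": "code",
|    "execution_count": null,
|    "id": "b2c7e0a8-6f1d-4a3e-8c52-4d9b1e6f3d41",
|    "metadata": {},
|    "outputs": [],
|    "source": [
|     "# Assumes both classes defined above; count_params is a hypothetical helper\n",
|     "def count_params(module):\n",
|     "    return sum(p.numel() for p in module.parameters())\n",
|     "\n",
|     "# 2 heads x 2 dims each = 4 output dims\n",
|     "wrapper = MultiHeadAttentionWrapper(d_in=3, d_out=2, context_length=6, dropout=0.0, num_heads=2)\n",
|     "# 4 output dims split across 2 heads\n",
|     "mha_split = MultiHeadAttention(d_in=3, d_out=4, context_length=6, dropout=0.0, num_heads=2)\n",
|     "\n",
|     "print('wrapper:', count_params(wrapper))              # expected: 36 = 2 heads x 3 matrices x (3 x 2)\n",
|     "print('weight splits:', count_params(mha_split))      # expected: 56 = 3 x (3 x 4) + (4 x 4) + 4\n",
|     "print('out_proj only:', count_params(mha_split.out_proj))  # expected: 20 = (4 x 4) weights + 4 biases"
|    ]
|   },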
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "dbe5d396-c990-45dc-9908-2c621461f851",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch03_compressed/26.webp\" width=\"400px\">"
 | |
|    ]
 | |
|   },
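|   {
|    "cell_type": "markdown",
|    "id": "c8e1f4b3-0d7a-4c9f-a6e3-5e2a7f0b4e50",
|    "metadata": {},
|    "source": [
|     "- As a small numerical sketch of the weight-split idea, reshaping the output of one large projection with `.view()` gives the same per-head queries as multiplying the input by the corresponding block of rows of the weight matrix\n",
|     "- The sizes are hypothetical, and a standalone `nn.Linear` layer stands in for the `W_query` layer of `MultiHeadAttention`"
|    ]
|   },
|   {
|    "cell_type": "code",
|    "execution_count": null,
|    "id": "c8e1f4b3-0d7a-4c9f-a6e3-5e2a7f0b4e51",
|    "metadata": {},
|    "outputs": [],
|    "source": [
|     "import torch\n",
|     "import torch.nn as nn\n",
|     "\n",
|     "# Hypothetical sizes chosen for illustration\n",
|     "torch.manual_seed(123)\n",
|     "b, num_tokens, d_in, d_out, num_heads = 2, 6, 3, 4, 2\n",
|     "head_dim = d_out // num_heads\n",
|     "\n",
|     "W_query = nn.Linear(d_in, d_out, bias=False)\n",
|     "x = torch.rand(b, num_tokens, d_in)\n",
|     "\n",
|     "with torch.no_grad():\n",
|     "    # One projection with the full matrix, then split the last dim into heads\n",
|     "    queries = W_query(x).view(b, num_tokens, num_heads, head_dim)\n",
|     "\n",
|     "    for h in range(num_heads):\n",
|     "        # nn.Linear stores its weight as (d_out, d_in) and computes x @ weight.T,\n",
|     "        # so rows h*head_dim:(h+1)*head_dim act as the query matrix of head h\n",
|     "        W_h = W_query.weight[h * head_dim:(h + 1) * head_dim, :]\n",
|     "        print(f'head {h} matches:', torch.allclose(queries[:, :, h, :], x @ W_h.T))  # expected: True"
|    ]
|   },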
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "8b0ed78c-e8ac-4f8f-a479-a98242ae8f65",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "- If you are interested in a compact and efficient implementation of the multi-head attention class above, you can also consider the [`torch.nn.MultiheadAttention`](https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html) class in PyTorch"
 | |
|    ]
 | |
|   },
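|   {
|    "cell_type": "markdown",
|    "id": "d4a9b2c6-8e3f-4d1a-b0c7-6f3b8a1c5f60",
|    "metadata": {},
|    "source": [
|     "- For comparison, here is a minimal sketch of causal self-attention with PyTorch's built-in `nn.MultiheadAttention`, using hypothetical sizes. Note that this module uses `embed_dim` as both its input and its output size, so `embed_dim` must be divisible by `num_heads`. It also includes biases and an output projection by default\n",
|     "- Its boolean `attn_mask` uses `True` to mark positions that may *not* be attended to, which is the same convention as the `mask` buffer in our classes"
|    ]
|   },
|   {
|    "cell_type": "code",
|    "execution_count": null,
|    "id": "d4a9b2c6-8e3f-4d1a-b0c7-6f3b8a1c5f61",
|    "metadata": {},
|    "outputs": [],
|    "source": [
|     "import torch\n",
|     "import torch.nn as nn\n",
|     "\n",
|     "torch.manual_seed(123)\n",
|     "b, num_tokens, embed_dim, num_heads = 2, 6, 4, 2  # hypothetical sizes\n",
|     "\n",
|     "# batch_first=True makes the module accept (batch, tokens, embed_dim) inputs\n",
|     "torch_mha = nn.MultiheadAttention(embed_dim, num_heads, dropout=0.0, batch_first=True)\n",
|     "x = torch.rand(b, num_tokens, embed_dim)\n",
|     "\n",
|     "# Boolean causal mask: True marks positions a query is NOT allowed to attend to\n",
|     "causal_mask = torch.triu(torch.ones(num_tokens, num_tokens, dtype=torch.bool), diagonal=1)\n",
|     "\n",
|     "with torch.no_grad():\n",
|     "    # Self-attention: the same tensor serves as query, key, and value\n",
|     "    context_vecs, attn_weights = torch_mha(x, x, x, attn_mask=causal_mask)\n",
|     "\n",
|     "print(context_vecs.shape)  # expected: torch.Size([2, 6, 4])\n",
|     "print(attn_weights.shape)  # expected: torch.Size([2, 6, 6]), averaged over heads by default\n",
|     "print(torch.all(attn_weights.triu(diagonal=1) == 0))  # expected: tensor(True)"
|    ]
|   },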
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "363701ad-2022-46c8-9972-390d2a2b9911",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "- Since the `MultiHeadAttention` implementation above may look a bit complex at first glance, let's look at what happens when executing `attn_scores = queries @ keys.transpose(2, 3)` on a small example tensor:"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "code",
 | |
|    "execution_count": 37,
 | |
|    "id": "e8cfc1ae-78ab-4faa-bc73-98bd054806c9",
 | |
|    "metadata": {},
 | |
|    "outputs": [
 | |
|     {
 | |
|      "name": "stdout",
 | |
|      "output_type": "stream",
 | |
|      "text": [
 | |
|       "tensor([[[[1.3208, 1.1631, 1.2879],\n",
 | |
|       "          [1.1631, 2.2150, 1.8424],\n",
 | |
|       "          [1.2879, 1.8424, 2.0402]],\n",
 | |
|       "\n",
 | |
|       "         [[0.4391, 0.7003, 0.5903],\n",
 | |
|       "          [0.7003, 1.3737, 1.0620],\n",
 | |
|       "          [0.5903, 1.0620, 0.9912]]]])\n"
 | |
|      ]
 | |
|     }
 | |
|    ],
 | |
|    "source": [
 | |
|     "# (b, num_heads, num_tokens, head_dim) = (1, 2, 3, 4)\n",
 | |
|     "a = torch.tensor([[[[0.2745, 0.6584, 0.2775, 0.8573],\n",
 | |
|     "                    [0.8993, 0.0390, 0.9268, 0.7388],\n",
 | |
|     "                    [0.7179, 0.7058, 0.9156, 0.4340]],\n",
 | |
|     "\n",
 | |
|     "                   [[0.0772, 0.3565, 0.1479, 0.5331],\n",
 | |
|     "                    [0.4066, 0.2318, 0.4545, 0.9737],\n",
 | |
|     "                    [0.4606, 0.5159, 0.4220, 0.5786]]]])\n",
 | |
|     "\n",
 | |
|     "print(a @ a.transpose(2, 3))"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "0587b946-c8f2-4888-adbf-5a5032fbfd7b",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "- In this case, PyTorch's matrix multiplication handles the 4-dimensional input tensor by carrying out the matrix multiplication over the last two dimensions (num_tokens, head_dim) and repeating it for each individual head (and, more generally, for each batch element)\n",
 | |
|     "\n",
 | |
|     "- For instance, the batched multiplication above is a more compact way of writing the following, which computes the matrix multiplication for each head separately:"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "code",
 | |
|    "execution_count": 38,
 | |
|    "id": "053760f1-1a02-42f0-b3bf-3d939e407039",
 | |
|    "metadata": {},
 | |
|    "outputs": [
 | |
|     {
 | |
|      "name": "stdout",
 | |
|      "output_type": "stream",
 | |
|      "text": [
 | |
|       "First head:\n",
 | |
|       " tensor([[1.3208, 1.1631, 1.2879],\n",
 | |
|       "        [1.1631, 2.2150, 1.8424],\n",
 | |
|       "        [1.2879, 1.8424, 2.0402]])\n",
 | |
|       "\n",
 | |
|       "Second head:\n",
 | |
|       " tensor([[0.4391, 0.7003, 0.5903],\n",
 | |
|       "        [0.7003, 1.3737, 1.0620],\n",
 | |
|       "        [0.5903, 1.0620, 0.9912]])\n"
 | |
|      ]
 | |
|     }
 | |
|    ],
 | |
|    "source": [
 | |
|     "first_head = a[0, 0, :, :]\n",
 | |
|     "first_res = first_head @ first_head.T\n",
 | |
|     "print(\"First head:\\n\", first_res)\n",
 | |
|     "\n",
 | |
|     "second_head = a[0, 1, :, :]\n",
 | |
|     "second_res = second_head @ second_head.T\n",
 | |
|     "print(\"\\nSecond head:\\n\", second_res)"
 | |
|    ]
 | |
|   },
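|   {
|    "cell_type": "markdown",
|    "id": "e6b0c3d7-1f4a-4e8b-9c2d-7a4c9b2d6a70",
|    "metadata": {},
|    "source": [
|     "- To connect this back to the `MultiHeadAttention` class, the sketch below traces the tensor shapes through its `.view()`, `.transpose()`, and matrix-multiplication steps\n",
|     "- It also checks the batched result against an explicit loop over batch items and heads, and shows why `.contiguous()` is needed before the final `.view()`\n",
|     "- It uses random stand-in tensors with hypothetical sizes and skips the softmax, since only the shapes matter here"
|    ]
|   },
|   {
|    "cell_type": "code",
|    "execution_count": null,
|    "id": "e6b0c3d7-1f4a-4e8b-9c2d-7a4c9b2d6a71",
|    "metadata": {},
|    "outputs": [],
|    "source": [
|     "import torch\n",
|     "\n",
|     "# Hypothetical sizes; a random tensor stands in for the projected queries/keys/values\n",
|     "torch.manual_seed(123)\n",
|     "b, num_tokens, d_out, num_heads = 2, 6, 4, 2\n",
|     "head_dim = d_out // num_heads\n",
|     "\n",
|     "q = torch.rand(b, num_tokens, d_out)\n",
|     "q = q.view(b, num_tokens, num_heads, head_dim)\n",
|     "print(q.shape)  # expected: torch.Size([2, 6, 2, 2])\n",
|     "q = q.transpose(1, 2)\n",
|     "print(q.shape)  # expected: torch.Size([2, 2, 6, 2]) -> (b, num_heads, num_tokens, head_dim)\n",
|     "\n",
|     "scores = q @ q.transpose(2, 3)\n",
|     "print(scores.shape)  # expected: torch.Size([2, 2, 6, 6]) -> one score matrix per head\n",
|     "\n",
|     "# The batched matmul equals an explicit loop over batch items and heads\n",
|     "looped = torch.stack([\n",
|     "    torch.stack([q[i, h] @ q[i, h].T for h in range(num_heads)])\n",
|     "    for i in range(b)\n",
|     "])\n",
|     "print(torch.allclose(scores, looped))  # expected: True\n",
|     "\n",
|     "# Mirrors (attn_weights @ values).transpose(1, 2) in MultiHeadAttention\n",
|     "ctx = (scores @ q).transpose(1, 2)\n",
|     "print(ctx.shape, ctx.is_contiguous())  # expected: torch.Size([2, 6, 2, 2]) False\n",
|     "\n",
|     "# .view() needs a contiguous memory layout here, hence .contiguous()\n",
|     "ctx = ctx.contiguous().view(b, num_tokens, d_out)\n",
|     "print(ctx.shape)  # expected: torch.Size([2, 6, 4])"
|    ]
|   },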
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "dec671bf-7938-4304-ad1e-75d9920e7f43",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "# Summary and takeaways"
 | |
|    ]
 | |
|   },
 | |
|   {
 | |
|    "cell_type": "markdown",
 | |
|    "id": "fa3e4113-ffca-432c-b3ec-7a50bd15da25",
 | |
|    "metadata": {},
 | |
|    "source": [
 | |
|     "- See the [./multihead-attention.ipynb](./multihead-attention.ipynb) code notebook, which contains a concise version of the data loader (chapter 2) plus the multi-head attention class that we implemented in this chapter and will need for training the GPT model in upcoming chapters"
 | |
|    ]
 | |
|   }
 | |
|  ],
 | |
|  "metadata": {
 | |
|   "kernelspec": {
 | |
|    "display_name": "Python 3 (ipykernel)",
 | |
|    "language": "python",
 | |
|    "name": "python3"
 | |
|   },
 | |
|   "language_info": {
 | |
|    "codemirror_mode": {
 | |
|     "name": "ipython",
 | |
|     "version": 3
 | |
|    },
 | |
|    "file_extension": ".py",
 | |
|    "mimetype": "text/x-python",
 | |
|    "name": "python",
 | |
|    "nbconvert_exporter": "python",
 | |
|    "pygments_lexer": "ipython3",
 | |
|    "version": "3.10.6"
 | |
|   }
 | |
|  },
 | |
|  "nbformat": 4,
 | |
|  "nbformat_minor": 5
 | |
| }
 | 
