| {
 | ||
|  "cells": [
 | ||
|   {
 | ||
|    "cell_type": "markdown",
 | ||
|    "id": "0_xya1nyDHfY",
 | ||
|    "metadata": {
 | ||
|     "id": "0_xya1nyDHfY"
 | ||
|    },
 | ||
|    "source": [
 | ||
|     "<table style=\"width:100%\">\n",
 | ||
|     "<tr>\n",
 | ||
|     "<td style=\"vertical-align:middle; text-align:left;\">\n",
 | ||
|     "<font size=\"2\">\n",
 | ||
|     "Supplementary code for the <a href=\"http://mng.bz/orYv\">Build a Large Language Model From Scratch</a> book by <a href=\"https://sebastianraschka.com\">Sebastian Raschka</a><br>\n",
 | ||
|     "<br>Code repository: <a href=\"https://github.com/rasbt/LLMs-from-scratch\">https://github.com/rasbt/LLMs-from-scratch</a>\n",
 | ||
|     "</font>\n",
 | ||
|     "</td>\n",
 | ||
|     "<td style=\"vertical-align:middle; text-align:left;\">\n",
 | ||
|     "<a href=\"http://mng.bz/orYv\"><img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/cover-small.webp\" width=\"100px\"></a>\n",
 | ||
|     "</td>\n",
 | ||
|     "</tr>\n",
 | ||
|     "</table>"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "markdown",
 | ||
|    "id": "l62zIRRSBy_R",
 | ||
|    "metadata": {
 | ||
|     "id": "l62zIRRSBy_R"
 | ||
|    },
 | ||
|    "source": [
 | ||
|     "# Converting a From-Scratch GPT Architecture to Llama 2"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "markdown",
 | ||
|    "id": "aFmxTQbwCUMl",
 | ||
|    "metadata": {
 | ||
|     "id": "aFmxTQbwCUMl"
 | ||
|    },
 | ||
|    "source": [
 | ||
|     "- In this notebook, we convert the original GPT architecture into a Llama 2 model step by step (note the GPT and GPT-2 share the same architecture)\n",
 | ||
|     "- Why not Llama 1 or Llama 3?\n",
 | ||
|     "   - The Llama 1 architecture is similar to Llama 2, except that Llama 2 has a larger context window (which is nice); the Llama 1 weights are not readily available and have more usage restrictions, so it makes more sense to focus on Llama 2\n",
 | ||
|     "   - Regarding Llama 3, I will share a separate notebook to convert Llama 2 to Llama 3 (there are only a few small additional changes)\n",
 | ||
|     "- The explanations are purposefully kept minimal in this notebook not to bloat it unnecessarily and focus on the main code\n",
 | ||
|     "- For more information, please see the Llama 2 paper: [Llama 2: Open Foundation and Fine-Tuned Chat Models (2023)](https://arxiv.org/abs/2307.09288)"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "markdown",
 | ||
|    "id": "ohhMKUWvGm9z",
 | ||
|    "metadata": {
 | ||
|     "id": "ohhMKUWvGm9z"
 | ||
|    },
 | ||
|    "source": [
 | ||
|     "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/gpt-to-llama/gpt2-to-llama2-llama3.webp?1\">"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "markdown",
 | ||
|    "id": "JBpQwU89ETA1",
 | ||
|    "metadata": {
 | ||
|     "id": "JBpQwU89ETA1"
 | ||
|    },
 | ||
|    "source": [
 | ||
|     "- Packages that are being used in this notebook:"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": 1,
 | ||
|    "id": "34a9a440-84c2-42cc-808b-38677cb6af8a",
 | ||
|    "metadata": {
 | ||
|     "colab": {
 | ||
|      "base_uri": "https://localhost:8080/"
 | ||
|     },
 | ||
|     "id": "34a9a440-84c2-42cc-808b-38677cb6af8a",
 | ||
|     "outputId": "8118963b-3c72-43af-874b-439ffebdc94c"
 | ||
|    },
 | ||
|    "outputs": [
 | ||
|     {
 | ||
|      "name": "stdout",
 | ||
|      "output_type": "stream",
 | ||
|      "text": [
 | ||
|       "huggingface_hub version: 0.24.7\n",
 | ||
|       "sentencepiece version: 0.2.0\n",
 | ||
|       "torch version: 2.4.1+cu121\n"
 | ||
|      ]
 | ||
|     }
 | ||
|    ],
 | ||
|    "source": [
 | ||
|     "from importlib.metadata import version\n",
 | ||
|     "\n",
 | ||
|     "pkgs = [\n",
 | ||
|     "    \"huggingface_hub\",  # to download pretrained weights\n",
 | ||
|     "    \"sentencepiece\",    # to implement the tokenizer\n",
 | ||
|     "    \"torch\",            # to implement the model\n",
 | ||
|     "]\n",
 | ||
|     "for p in pkgs:\n",
 | ||
|     "    print(f\"{p} version: {version(p)}\")"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "markdown",
 | ||
|    "id": "UJJneXpTEg4W",
 | ||
|    "metadata": {
 | ||
|     "id": "UJJneXpTEg4W"
 | ||
|    },
 | ||
|    "source": [
 | ||
|     " \n",
 | ||
|     "# 1. Convert the GPT model implementation step by step"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "markdown",
 | ||
|    "id": "v1zpfX2GHBKa",
 | ||
|    "metadata": {
 | ||
|     "id": "v1zpfX2GHBKa"
 | ||
|    },
 | ||
|    "source": [
 | ||
|     "- In this section, we go through the GPT model code from [chapter 4](../../ch04/01_main-chapter-code/ch04.ipynb) and modify it step by step to implement the Llama 2 architecture\n",
 | ||
|     "- Later, we load the original Llama 2 weights shared by Meta AI"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "markdown",
 | ||
|    "id": "979c7b6d-1370-4da1-8bfb-a2b27537bf2f",
 | ||
|    "metadata": {
 | ||
|     "id": "979c7b6d-1370-4da1-8bfb-a2b27537bf2f"
 | ||
|    },
 | ||
|    "source": [
 | ||
|     " \n",
 | ||
|     "## 1.1 Replace LayerNorm with RMSNorm layer"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "markdown",
 | ||
|    "id": "f8b27fc8-23a1-4e0e-a1ea-792e0428e5e6",
 | ||
|    "metadata": {
 | ||
|     "id": "f8b27fc8-23a1-4e0e-a1ea-792e0428e5e6"
 | ||
|    },
 | ||
|    "source": [
 | ||
|     "- First, we replace LayerNorm by Root Mean Square Layer Normalization (RMSNorm)\n",
 | ||
|     "- LayerNorm normalizes inputs using mean and variance, while RMSNorm uses only the root mean square, which improves computational efficiency\n",
 | ||
|     "- The RMSNorm operation is as follows, where $x$ is the input $\\gamma$ is a trainable parameter (vector), and $\\epsilon$ is a small constant to avoid zero-division errors:\n",
 | ||
|     "\n",
 | ||
|     "$$y_i = \\frac{x_i}{\\text{RMS}(x)} \\gamma_i, \\quad \\text{where} \\quad \\text{RMS}(x) = \\sqrt{\\epsilon + \\frac{1}{n} \\sum x_i^2}$$\n",
 | ||
|     "\n",
 | ||
|     "- For more details, please see the paper [Root Mean Square Layer Normalization (2019)](https://arxiv.org/abs/1910.07467)"
 | ||
|    ]
 | ||
|   },
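|   {
|    "cell_type": "markdown",
|    "metadata": {},
|    "source": [
|     "- As a quick numeric sanity check of this formula (a minimal sketch; with $\\gamma$ left at its all-ones initialization, the output is simply $x_i / \\text{RMS}(x)$):\n",
|     "\n",
|     "```python\n",
|     "import torch\n",
|     "\n",
|     "x = torch.tensor([1.0, 2.0, 3.0, 4.0])\n",
|     "rms = torch.sqrt(x.pow(2).mean() + 1e-5)  # root mean square (eps inside the square root)\n",
|     "print(x / rms)  # tensor([0.3651, 0.7303, 1.0954, 1.4606])\n",
|     "```"
|    ]
|   },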
 | ||
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": 2,
 | ||
|    "id": "d7094381-9499-4e9e-93f9-b79470da3771",
 | ||
|    "metadata": {
 | ||
|     "id": "d7094381-9499-4e9e-93f9-b79470da3771"
 | ||
|    },
 | ||
|    "outputs": [],
 | ||
|    "source": [
 | ||
|     "import torch\n",
 | ||
|     "import torch.nn as nn\n",
 | ||
|     "\n",
 | ||
|     "\n",
 | ||
|     "#####################################\n",
 | ||
|     "# Chapter 4\n",
 | ||
|     "#####################################\n",
 | ||
|     "\n",
 | ||
|     "# class LayerNorm(nn.Module):\n",
 | ||
|     "#     def __init__(self, emb_dim):\n",
 | ||
|     "#         super().__init__()\n",
 | ||
|     "#         self.eps = 1e-5\n",
 | ||
|     "#         self.scale = nn.Parameter(torch.ones(emb_dim))\n",
 | ||
|     "#         self.shift = nn.Parameter(torch.zeros(emb_dim))\n",
 | ||
|     "\n",
 | ||
|     "#     def forward(self, x):\n",
 | ||
|     "#         mean = x.mean(dim=-1, keepdim=True)\n",
 | ||
|     "#         var = x.var(dim=-1, keepdim=True, unbiased=False)\n",
 | ||
|     "#         norm_x = (x - mean) / torch.sqrt(var + self.eps)\n",
 | ||
|     "#         return self.scale * norm_x + self.shift\n",
 | ||
|     "\n",
 | ||
|     "\n",
 | ||
|     "class RMSNorm(nn.Module):\n",
 | ||
|     "    def __init__(self, emb_dim, eps=1e-5):\n",
 | ||
|     "        super().__init__()\n",
 | ||
|     "        self.eps = eps\n",
 | ||
|     "        self.emb_dim = emb_dim\n",
 | ||
|     "        self.weight = nn.Parameter(torch.ones(emb_dim)).float()\n",
 | ||
|     "\n",
 | ||
|     "    def forward(self, x):\n",
 | ||
|     "        means = x.pow(2).mean(dim=-1, keepdim=True)\n",
 | ||
|     "        x_normed = x * torch.rsqrt(means + self.eps)\n",
 | ||
|     "        return (x_normed * self.weight).to(dtype=x.dtype)"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "markdown",
 | ||
|    "id": "mtWC8DOmIu0F",
 | ||
|    "metadata": {
 | ||
|     "id": "mtWC8DOmIu0F"
 | ||
|    },
 | ||
|    "source": [
 | ||
|     "- The following code cell checks that this implementation works the same as PyTorch's built-in implementation:"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": 3,
 | ||
|    "id": "e41ade7a-bf06-48b1-8b7e-0e4037d5753f",
 | ||
|    "metadata": {
 | ||
|     "id": "e41ade7a-bf06-48b1-8b7e-0e4037d5753f"
 | ||
|    },
 | ||
|    "outputs": [],
 | ||
|    "source": [
 | ||
|     "torch.manual_seed(123)\n",
 | ||
|     "\n",
 | ||
|     "example_batch = torch.randn(2, 3, 4)\n",
 | ||
|     "\n",
 | ||
|     "rms_norm = RMSNorm(emb_dim=example_batch.shape[-1])\n",
 | ||
|     "rmsnorm_pytorch = torch.nn.RMSNorm(example_batch.shape[-1], eps=1e-5)\n",
 | ||
|     "\n",
 | ||
|     "assert torch.allclose(rms_norm(example_batch), rmsnorm_pytorch(example_batch))"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "markdown",
 | ||
|    "id": "5eb81f83-c38c-46a4-b763-aa630a32e357",
 | ||
|    "metadata": {
 | ||
|     "id": "5eb81f83-c38c-46a4-b763-aa630a32e357"
 | ||
|    },
 | ||
|    "source": [
 | ||
|     " \n",
 | ||
|     "## 1.2 Replace GELU with SiLU activation"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "markdown",
 | ||
|    "id": "0b8aa702-f118-4ff6-9135-90725ec8756c",
 | ||
|    "metadata": {
 | ||
|     "id": "0b8aa702-f118-4ff6-9135-90725ec8756c"
 | ||
|    },
 | ||
|    "source": [
 | ||
|     "- Llama uses the SiLU activation function (instead of GELU), which is also known as the Swish function:\n",
 | ||
|     "\n",
 | ||
|     "$$\n",
 | ||
|     "\\text{silu}(x) = x \\cdot \\sigma(x), \\quad \\text{where} \\quad \\sigma(x) \\text{ is the logistic sigmoid.}\n",
 | ||
|     "$$\n",
 | ||
|     "\n",
 | ||
|     "- For more information, see the SiLU paper: [Sigmoid-Weighted Linear Units for Neural Network Function Approximation in Reinforcement Learning (2017)](https://arxiv.org/abs/1702.03118)"
 | ||
|    ]
 | ||
|   },
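|   {
|    "cell_type": "markdown",
|    "metadata": {},
|    "source": [
|     "- As a minimal sketch of this definition, we can evaluate SiLU at a single point and compare it against PyTorch's built-in function:\n",
|     "\n",
|     "```python\n",
|     "import torch\n",
|     "\n",
|     "x = torch.tensor(1.0)\n",
|     "print(x * torch.sigmoid(x))         # manual: x * sigma(x) -> tensor(0.7311)\n",
|     "print(torch.nn.functional.silu(x))  # PyTorch reference   -> tensor(0.7311)\n",
|     "```"
|    ]
|   },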
 | ||
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": 4,
 | ||
|    "id": "a74f3757-c634-4a3a-a8f3-6334cde454fe",
 | ||
|    "metadata": {
 | ||
|     "id": "a74f3757-c634-4a3a-a8f3-6334cde454fe"
 | ||
|    },
 | ||
|    "outputs": [],
 | ||
|    "source": [
 | ||
|     "#####################################\n",
 | ||
|     "# Chapter 4\n",
 | ||
|     "#####################################\n",
 | ||
|     "\n",
 | ||
|     "# class GELU(nn.Module):\n",
 | ||
|     "#     def __init__(self):\n",
 | ||
|     "#         super().__init__()\n",
 | ||
|     "\n",
 | ||
|     "#     def forward(self, x):\n",
 | ||
|     "#         return 0.5 * x * (1 + torch.tanh(\n",
 | ||
|     "#             torch.sqrt(torch.tensor(2.0 / torch.pi)) *\n",
 | ||
|     "#             (x + 0.044715 * torch.pow(x, 3))\n",
 | ||
|     "#         ))\n",
 | ||
|     "\n",
 | ||
|     "\n",
 | ||
|     "class SiLU(nn.Module):\n",
 | ||
|     "    def __init__(self):\n",
 | ||
|     "        super(SiLU, self).__init__()\n",
 | ||
|     "\n",
 | ||
|     "    def forward(self, x):\n",
 | ||
|     "        return x * torch.sigmoid(x)"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": 5,
 | ||
|    "id": "72ecbe2e-b6b7-4319-972b-1a7fefa3368c",
 | ||
|    "metadata": {
 | ||
|     "id": "72ecbe2e-b6b7-4319-972b-1a7fefa3368c"
 | ||
|    },
 | ||
|    "outputs": [],
 | ||
|    "source": [
 | ||
|     "silu = SiLU()\n",
 | ||
|     "\n",
 | ||
|     "assert torch.allclose(silu(example_batch), torch.nn.functional.silu(example_batch))"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "markdown",
 | ||
|    "id": "4f9b5167-1da9-46c8-9964-8036b3b1deb9",
 | ||
|    "metadata": {
 | ||
|     "id": "4f9b5167-1da9-46c8-9964-8036b3b1deb9"
 | ||
|    },
 | ||
|    "source": [
 | ||
|     " \n",
 | ||
|     "## 1.3 Update the FeedForward module"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "markdown",
 | ||
|    "id": "3a381e7a-b807-472e-91c9-3e4e3fc5ad91",
 | ||
|    "metadata": {
 | ||
|     "id": "3a381e7a-b807-472e-91c9-3e4e3fc5ad91"
 | ||
|    },
 | ||
|    "source": [
 | ||
|     "- In fact, Llama uses a \"Gates Linear Unit\" (GLU) variant of SiLU called SwiGLU, which essentially results in a slightly differently structured `FeedForward` module\n",
 | ||
|     "- SwiGLU uses a gating mechanism in the feedforward layer, with the formula:\n",
 | ||
|     "\n",
 | ||
|     "$$\\text{SwiGLU}(x) = \\text{SiLU}(\\text{Linear}_1(x)) * (\\text{Linear}_2(x))$$\n",
 | ||
|     "\n",
 | ||
|     "- Here, $\\text{Linear}_1$ and $\\text{Linear}_2$ are two linear layers, and $*$ denotes element-wise multiplication\n",
 | ||
|     "- The third linear layer, $\\text{Linear}_3$, is applied after this gated activation\n",
 | ||
|     "\n",
 | ||
|     "- For more information, see SwiGLU paper: [GLU Variants Improve Transformer (2020)](https://arxiv.org/abs/2002.05202)"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": 6,
 | ||
|    "id": "d25fbe3d-b7c9-4772-ad67-bc0527e4e20a",
 | ||
|    "metadata": {
 | ||
|     "id": "d25fbe3d-b7c9-4772-ad67-bc0527e4e20a"
 | ||
|    },
 | ||
|    "outputs": [],
 | ||
|    "source": [
 | ||
|     "#####################################\n",
 | ||
|     "# Chapter 4\n",
 | ||
|     "#####################################\n",
 | ||
|     "# class FeedForward(nn.Module):\n",
 | ||
|     "#     def __init__(self, cfg):\n",
 | ||
|     "#         super().__init__()\n",
 | ||
|     "#         self.layers = nn.Sequential(\n",
 | ||
|     "#             nn.Linear(cfg[\"emb_dim\"], 4 * cfg[\"emb_dim\"]),\n",
 | ||
|     "#             GELU(),\n",
 | ||
|     "#             nn.Linear(4 * cfg[\"emb_dim\"], cfg[\"emb_dim\"]),\n",
 | ||
|     "#         )\n",
 | ||
|     "\n",
 | ||
|     "#     def forward(self, x):\n",
 | ||
|     "#         return self.layers(x)"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": 7,
 | ||
|    "id": "477568cb-03cd-4510-b663-a42ce3ec64a2",
 | ||
|    "metadata": {
 | ||
|     "id": "477568cb-03cd-4510-b663-a42ce3ec64a2"
 | ||
|    },
 | ||
|    "outputs": [],
 | ||
|    "source": [
 | ||
|     "class FeedForward(nn.Module):\n",
 | ||
|     "    def __init__(self, cfg):\n",
 | ||
|     "        super().__init__()\n",
 | ||
|     "        self.fc1 = nn.Linear(cfg[\"emb_dim\"], cfg[\"hidden_dim\"], dtype=cfg[\"dtype\"], bias=False)\n",
 | ||
|     "        self.fc2 = nn.Linear(cfg[\"emb_dim\"], cfg[\"hidden_dim\"], dtype=cfg[\"dtype\"], bias=False)\n",
 | ||
|     "        self.fc3 = nn.Linear(cfg[\"hidden_dim\"], cfg[\"emb_dim\"], dtype=cfg[\"dtype\"], bias=False)\n",
 | ||
|     "        self.silu = SiLU()\n",
 | ||
|     "\n",
 | ||
|     "    def forward(self, x):\n",
 | ||
|     "        x_fc1 = self.fc1(x)\n",
 | ||
|     "        x_fc2 = self.fc2(x)\n",
 | ||
|     "        x = self.silu(x_fc1) * x_fc2\n",
 | ||
|     "        return self.fc3(x)"
 | ||
|    ]
 | ||
|   },
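|   {
|    "cell_type": "markdown",
|    "metadata": {},
|    "source": [
|     "- A quick usage sketch of the `FeedForward` module above (the configuration values here are made up for illustration and much smaller than the actual Llama 2 sizes):\n",
|     "\n",
|     "```python\n",
|     "ff = FeedForward({\"emb_dim\": 8, \"hidden_dim\": 16, \"dtype\": torch.float32})\n",
|     "out = ff(torch.randn(2, 4, 8))  # (batch_size, num_tokens, emb_dim)\n",
|     "print(out.shape)                # torch.Size([2, 4, 8])\n",
|     "```"
|    ]
|   },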
 | ||
|   {
 | ||
|    "cell_type": "markdown",
 | ||
|    "id": "qcD8LSHNhBRW",
 | ||
|    "metadata": {
 | ||
|     "id": "qcD8LSHNhBRW"
 | ||
|    },
 | ||
|    "source": [
 | ||
|     "- Note that we also added a `dtype=cfg[\"dtype\"]` setting above, which will allow us to load the model directly in lower precision formats later to reduce memory usage (versus instantiating it in the original 32-bit precision format and then converting it)\n",
 | ||
|     "- We also set `bias=False` since Llama doesn't use any bias units"
 | ||
|    ]
 | ||
|   },
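|   {
|    "cell_type": "markdown",
|    "metadata": {},
|    "source": [
|     "- As a rough illustration of why the dtype setting matters, the sketch below compares the parameter memory of a single Llama-2-sized linear layer instantiated in float32 vs. bfloat16 (the layer dimensions are taken from the 7B config used later):\n",
|     "\n",
|     "```python\n",
|     "layer_fp32 = nn.Linear(4096, 11008, bias=False)\n",
|     "layer_bf16 = nn.Linear(4096, 11008, bias=False, dtype=torch.bfloat16)\n",
|     "\n",
|     "bytes_fp32 = sum(p.numel() * p.element_size() for p in layer_fp32.parameters())\n",
|     "bytes_bf16 = sum(p.numel() * p.element_size() for p in layer_bf16.parameters())\n",
|     "print(f\"float32:  {bytes_fp32 / 1024**2:.1f} MB\")   # ~172 MB\n",
|     "print(f\"bfloat16: {bytes_bf16 / 1024**2:.1f} MB\")   # ~86 MB\n",
|     "```"
|    ]
|   },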
 | ||
|   {
 | ||
|    "cell_type": "markdown",
 | ||
|    "id": "f6b7bf4f-99d0-42c1-807c-5074d2cc1949",
 | ||
|    "metadata": {
 | ||
|     "id": "f6b7bf4f-99d0-42c1-807c-5074d2cc1949"
 | ||
|    },
 | ||
|    "source": [
 | ||
|     " \n",
 | ||
|     "## 1.4 Implement RoPE"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "markdown",
 | ||
|    "id": "d3487a6f-0373-49d8-b2eb-f8ee05d42884",
 | ||
|    "metadata": {
 | ||
|     "id": "d3487a6f-0373-49d8-b2eb-f8ee05d42884"
 | ||
|    },
 | ||
|    "source": [
 | ||
|     "- In the GPT model, the positional embeddings are implemented as follows:\n",
 | ||
|     "\n",
 | ||
|     "```python\n",
 | ||
|     "self.pos_emb = nn.Embedding(cfg[\"context_length\"], cfg[\"emb_dim\"])\n",
 | ||
|     "```\n",
 | ||
|     "\n",
 | ||
|     "- Unlike traditional absolute positional embeddings, Llama uses rotary position embeddings (RoPE), which enable it to capture both absolute and relative positional information simultaneously\n",
 | ||
|     "- The reference paper for RoPE is [RoFormer: Enhanced Transformer with Rotary Position Embedding (2021)](https://arxiv.org/abs/2104.09864)"
 | ||
|    ]
 | ||
|   },
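|   {
|    "cell_type": "markdown",
|    "metadata": {},
|    "source": [
|     "- Concretely, in the implementation below, each position $p$ rotates the dimension pair $(x_i, x_{i+d/2})$ of a query or key vector (with head dimension $d$) by a position-dependent angle:\n",
|     "\n",
|     "$$\\begin{pmatrix} x_i' \\\\ x_{i+d/2}' \\end{pmatrix} = \\begin{pmatrix} \\cos(p\\,\\theta_i) & -\\sin(p\\,\\theta_i) \\\\ \\sin(p\\,\\theta_i) & \\cos(p\\,\\theta_i) \\end{pmatrix} \\begin{pmatrix} x_i \\\\ x_{i+d/2} \\end{pmatrix}, \\quad \\text{where} \\quad \\theta_i = 10000^{-2i/d}$$\n",
|     "\n",
|     "- Since only the angle difference between two positions matters for the resulting query-key dot products, this encodes relative positions as well"
|    ]
|   },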
 | ||
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": 8,
 | ||
|    "id": "a34180fb-448f-44e9-a244-0c736051687b",
 | ||
|    "metadata": {
 | ||
|     "id": "a34180fb-448f-44e9-a244-0c736051687b"
 | ||
|    },
 | ||
|    "outputs": [],
 | ||
|    "source": [
 | ||
|     "def precompute_rope_params(head_dim, theta_base=10_000, context_length=4096):\n",
 | ||
|     "    assert head_dim % 2 == 0, \"Embedding dimension must be even\"\n",
 | ||
|     "\n",
 | ||
|     "    # Compute the inverse frequencies\n",
 | ||
|     "    inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2)[: (head_dim // 2)].float() / head_dim))\n",
 | ||
|     "\n",
 | ||
|     "    # Generate position indices\n",
 | ||
|     "    positions = torch.arange(context_length)\n",
 | ||
|     "\n",
 | ||
|     "    # Compute the angles\n",
 | ||
|     "    angles = positions[:, None] * inv_freq[None, :]  # Shape: (context_length, head_dim // 2)\n",
 | ||
|     "\n",
 | ||
|     "    # Expand angles to match the head_dim\n",
 | ||
|     "    angles = torch.cat([angles, angles], dim=1)  # Shape: (context_length, head_dim)\n",
 | ||
|     "\n",
 | ||
|     "    # Precompute sine and cosine\n",
 | ||
|     "    cos = torch.cos(angles)\n",
 | ||
|     "    sin = torch.sin(angles)\n",
 | ||
|     "\n",
 | ||
|     "    return cos, sin\n",
 | ||
|     "\n",
 | ||
|     "def compute_rope(x, cos, sin):\n",
 | ||
|     "    # x: (batch_size, num_heads, seq_len, head_dim)\n",
 | ||
|     "    batch_size, num_heads, seq_len, head_dim = x.shape\n",
 | ||
|     "    assert head_dim % 2 == 0, \"Head dimension must be even\"\n",
 | ||
|     "\n",
 | ||
|     "    # Split x into first half and second half\n",
 | ||
|     "    x1 = x[..., : head_dim // 2]  # First half\n",
 | ||
|     "    x2 = x[..., head_dim // 2 :]  # Second half\n",
 | ||
|     "\n",
 | ||
|     "    # Adjust sin and cos shapes\n",
 | ||
|     "    cos = cos[:seq_len, :].unsqueeze(0).unsqueeze(0)  # Shape: (1, 1, seq_len, head_dim)\n",
 | ||
|     "    sin = sin[:seq_len, :].unsqueeze(0).unsqueeze(0)\n",
 | ||
|     "\n",
 | ||
|     "    # Apply the rotary transformation\n",
 | ||
|     "    rotated = torch.cat((-x2, x1), dim=-1)\n",
 | ||
|     "    x_rotated = (x * cos) + (rotated * sin)\n",
 | ||
|     "\n",
 | ||
|     "    return x_rotated.to(dtype=x.dtype)"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "markdown",
 | ||
|    "id": "8e841b8e-75aa-49db-b1a7-d5c2116dc299",
 | ||
|    "metadata": {
 | ||
|     "id": "8e841b8e-75aa-49db-b1a7-d5c2116dc299"
 | ||
|    },
 | ||
|    "source": [
 | ||
|     "- The following is an example of applying RoPE to the `q` and `k` tensors:"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": 9,
 | ||
|    "id": "8c89f022-7167-4001-8c21-8e012878733f",
 | ||
|    "metadata": {
 | ||
|     "id": "8c89f022-7167-4001-8c21-8e012878733f"
 | ||
|    },
 | ||
|    "outputs": [],
 | ||
|    "source": [
 | ||
|     "# Settings\n",
 | ||
|     "batch_size = 2\n",
 | ||
|     "context_len = 5\n",
 | ||
|     "num_heads = 4\n",
 | ||
|     "head_dim = 16\n",
 | ||
|     "\n",
 | ||
|     "# Instantiate RoPE parameters\n",
 | ||
|     "cos, sin = precompute_rope_params(head_dim=head_dim, context_length=context_len)\n",
 | ||
|     "\n",
 | ||
|     "# Dummy query and key tensors\n",
 | ||
|     "torch.manual_seed(123)\n",
 | ||
|     "queries = torch.randn(batch_size, num_heads, context_len, head_dim)\n",
 | ||
|     "keys = torch.randn(batch_size, num_heads, context_len, head_dim)\n",
 | ||
|     "\n",
 | ||
|     "# Apply rotary position embeddings\n",
 | ||
|     "queries_rot = compute_rope(queries, cos, sin)\n",
 | ||
|     "keys_rot = compute_rope(keys, cos, sin)"
 | ||
|    ]
 | ||
|   },
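|   {
|    "cell_type": "markdown",
|    "metadata": {},
|    "source": [
|     "- Since RoPE only rotates pairs of dimensions, it should leave the per-token, per-head vector norms unchanged; the small check below (a sketch reusing the tensors from the previous cell) confirms this:\n",
|     "\n",
|     "```python\n",
|     "print(torch.allclose(queries.norm(dim=-1), queries_rot.norm(dim=-1), atol=1e-5))  # True\n",
|     "print(torch.allclose(keys.norm(dim=-1), keys_rot.norm(dim=-1), atol=1e-5))        # True\n",
|     "```"
|    ]
|   },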
 | ||
|   {
 | ||
|    "cell_type": "markdown",
 | ||
|    "id": "f78127b0-dda2-4c5a-98dd-bae8f5fe8297",
 | ||
|    "metadata": {
 | ||
|     "id": "f78127b0-dda2-4c5a-98dd-bae8f5fe8297"
 | ||
|    },
 | ||
|    "source": [
 | ||
|     " \n",
 | ||
|     "## 1.5 Add RoPE to MultiHeadAttention module"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "markdown",
 | ||
|    "id": "RnmSHROLhhR3",
 | ||
|    "metadata": {
 | ||
|     "id": "RnmSHROLhhR3"
 | ||
|    },
 | ||
|    "source": [
 | ||
|     "- It's important to note that GPT applies the positional embeddings to the inputs, whereas Llama applies rotations to the query and key vectors in the self-attention mechanism itself\n",
 | ||
|     "- Here, we modify the `MultiHeadAttention` class with the appropriate RoPE code\n",
 | ||
|     "- In addition, we remove the `qkv_bias` option and hardcode the `bias=False` setting\n",
 | ||
|     "- Also, we add a dtype setting to be able to instantiate the model with a lower precision later\n",
 | ||
|     " - Tip: since the `TransformerBlock`s (in the next section) are repeated exactly, we could simplify the code and only initialize the buffers once instead for each `MultiHeadAttention` module; however, we add the precomputed RoPE parameters to the `MultiHeadAttention` class so that it can function as a standalone module"
 | ||
|    ]
 | ||
|   },
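|   {
|    "cell_type": "markdown",
|    "metadata": {},
|    "source": [
|     "- The following is one possible way to implement such sharing (a sketch only, not used in the rest of this notebook; the cache dict and helper name are made up for illustration):\n",
|     "\n",
|     "```python\n",
|     "_rope_cache = {}\n",
|     "\n",
|     "def get_rope_params(head_dim, context_length):\n",
|     "    # Reuse the same cos/sin buffers for all attention modules with identical settings\n",
|     "    key = (head_dim, context_length)\n",
|     "    if key not in _rope_cache:\n",
|     "        _rope_cache[key] = precompute_rope_params(head_dim=head_dim, context_length=context_length)\n",
|     "    return _rope_cache[key]\n",
|     "```"
|    ]
|   },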
 | ||
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": 10,
 | ||
|    "id": "d81a441e-0b79-4a8b-8291-ea7f55d58c84",
 | ||
|    "metadata": {
 | ||
|     "id": "d81a441e-0b79-4a8b-8291-ea7f55d58c84"
 | ||
|    },
 | ||
|    "outputs": [],
 | ||
|    "source": [
 | ||
|     "#####################################\n",
 | ||
|     "# Chapter 3\n",
 | ||
|     "#####################################\n",
 | ||
|     "class MultiHeadAttention(nn.Module):\n",
 | ||
|     "    def __init__(self, d_in, d_out, context_length, num_heads, dtype=None):  # ,dropout, num_heads, qkv_bias=False):\n",
 | ||
|     "        super().__init__()\n",
 | ||
|     "        assert d_out % num_heads == 0, \"d_out must be divisible by n_heads\"\n",
 | ||
|     "\n",
 | ||
|     "        self.d_out = d_out\n",
 | ||
|     "        self.num_heads = num_heads\n",
 | ||
|     "        self.head_dim = d_out // num_heads  # Reduce the projection dim to match desired output dim\n",
 | ||
|     "\n",
 | ||
|     "        ################################### NEW ###################################\n",
 | ||
|     "        # Set bias=False and dtype=dtype for all linear layers below\n",
 | ||
|     "        ###########################################################################\n",
 | ||
|     "        self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype)\n",
 | ||
|     "        self.W_key = nn.Linear(d_in, d_out, bias=False, dtype=dtype)\n",
 | ||
|     "        self.W_value = nn.Linear(d_in, d_out, bias=False, dtype=dtype)\n",
 | ||
|     "        self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype)  # Linear layer to combine head outputs\n",
 | ||
|     "        # self.dropout = nn.Dropout(dropout)\n",
 | ||
|     "        self.register_buffer(\"mask\", torch.triu(torch.ones(context_length, context_length), diagonal=1))\n",
 | ||
|     "\n",
 | ||
|     "        ################################### NEW ###################################\n",
 | ||
|     "        cos, sin = precompute_rope_params(head_dim=self.head_dim, context_length=context_length)\n",
 | ||
|     "        self.register_buffer(\"cos\", cos)\n",
 | ||
|     "        self.register_buffer(\"sin\", sin)\n",
 | ||
|     "        ###########################################################################\n",
 | ||
|     "\n",
 | ||
|     "\n",
 | ||
|     "    def forward(self, x):\n",
 | ||
|     "\n",
 | ||
|     "        b, num_tokens, d_in = x.shape\n",
 | ||
|     "\n",
 | ||
|     "        keys = self.W_key(x)  # Shape: (b, num_tokens, d_out)\n",
 | ||
|     "        queries = self.W_query(x)\n",
 | ||
|     "        values = self.W_value(x)\n",
 | ||
|     "\n",
 | ||
|     "        # We implicitly split the matrix by adding a `num_heads` dimension\n",
 | ||
|     "        # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)\n",
 | ||
|     "        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)\n",
 | ||
|     "        values = values.view(b, num_tokens, self.num_heads, self.head_dim)\n",
 | ||
|     "        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)\n",
 | ||
|     "\n",
 | ||
|     "        # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)\n",
 | ||
|     "        keys = keys.transpose(1, 2)\n",
 | ||
|     "        queries = queries.transpose(1, 2)\n",
 | ||
|     "        values = values.transpose(1, 2)\n",
 | ||
|     "\n",
 | ||
|     "        ################################### NEW ###################################\n",
 | ||
|     "        keys = compute_rope(keys, self.cos, self.sin)\n",
 | ||
|     "        queries = compute_rope(queries, self.cos, self.sin)\n",
 | ||
|     "        ###########################################################################\n",
 | ||
|     "\n",
 | ||
|     "        # Compute scaled dot-product attention (aka self-attention) with a causal mask\n",
 | ||
|     "        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head\n",
 | ||
|     "\n",
 | ||
|     "        # Original mask truncated to the number of tokens and converted to boolean\n",
 | ||
|     "        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]\n",
 | ||
|     "\n",
 | ||
|     "        # Use the mask to fill attention scores\n",
 | ||
|     "        attn_scores.masked_fill_(mask_bool, -torch.inf)\n",
 | ||
|     "\n",
 | ||
|     "        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n",
 | ||
|     "        # attn_weights = self.dropout(attn_weights)\n",
 | ||
|     "\n",
 | ||
|     "        # Shape: (b, num_tokens, num_heads, head_dim)\n",
 | ||
|     "        context_vec = (attn_weights @ values).transpose(1, 2)\n",
 | ||
|     "\n",
 | ||
|     "        # Combine heads, where self.d_out = self.num_heads * self.head_dim\n",
 | ||
|     "        context_vec = context_vec.reshape(b, num_tokens, self.d_out)\n",
 | ||
|     "        context_vec = self.out_proj(context_vec)  # optional projection\n",
 | ||
|     "\n",
 | ||
|     "        return context_vec"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "markdown",
 | ||
|    "id": "-lt9SfnVioB3",
 | ||
|    "metadata": {
 | ||
|     "id": "-lt9SfnVioB3"
 | ||
|    },
 | ||
|    "source": [
 | ||
|     "- Below is an example using the `MultiHeadAttention` module on an example input:"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": 11,
 | ||
|    "id": "03f15755-0083-483f-963b-99b599651638",
 | ||
|    "metadata": {
 | ||
|     "id": "03f15755-0083-483f-963b-99b599651638"
 | ||
|    },
 | ||
|    "outputs": [],
 | ||
|    "source": [
 | ||
|     "# Settings\n",
 | ||
|     "batch_size = 1\n",
 | ||
|     "context_len = 100\n",
 | ||
|     "max_context_len = 4096\n",
 | ||
|     "embed_dim = 128\n",
 | ||
|     "num_heads = 4\n",
 | ||
|     "\n",
 | ||
|     "\n",
 | ||
|     "example_batch = torch.randn((batch_size, context_len, embed_dim))\n",
 | ||
|     "\n",
 | ||
|     "mha = MultiHeadAttention(\n",
 | ||
|     "    d_in=embed_dim,\n",
 | ||
|     "    d_out=embed_dim,\n",
 | ||
|     "    context_length=max_context_len,\n",
 | ||
|     "    num_heads=num_heads\n",
 | ||
|     ")\n",
 | ||
|     "\n",
 | ||
|     "mha(example_batch)\n",
 | ||
|     "\n",
 | ||
|     "del mha  # delete to free up memory"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "markdown",
 | ||
|    "id": "e5a1a272-a038-4b8f-aaaa-f4b241e7f23f",
 | ||
|    "metadata": {
 | ||
|     "id": "e5a1a272-a038-4b8f-aaaa-f4b241e7f23f"
 | ||
|    },
 | ||
|    "source": [
 | ||
|     " \n",
 | ||
|     "## 1.6 Update the TransformerBlock module"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "markdown",
 | ||
|    "id": "255f70ac-9c2e-4328-8af7-1c298b8d4a18",
 | ||
|    "metadata": {
 | ||
|     "id": "255f70ac-9c2e-4328-8af7-1c298b8d4a18"
 | ||
|    },
 | ||
|    "source": [
 | ||
|     "- At this stage, most of the hard work is already done; we can now update the `TransformerBlock` to use the code we implemented above\n",
 | ||
|     "- This means we\n",
 | ||
|     " - replace LayerNorm with RMSNorm\n",
 | ||
|     " - remove dropout\n",
 | ||
|     " - remove the `qkv_bias` setting\n",
 | ||
|     " - add the `dtype` setting"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": 12,
 | ||
|    "id": "2e110721-bf2b-42b3-989a-1635b1658af0",
 | ||
|    "metadata": {
 | ||
|     "id": "2e110721-bf2b-42b3-989a-1635b1658af0"
 | ||
|    },
 | ||
|    "outputs": [],
 | ||
|    "source": [
 | ||
|     "class TransformerBlock(nn.Module):\n",
 | ||
|     "    def __init__(self, cfg):\n",
 | ||
|     "        super().__init__()\n",
 | ||
|     "        self.att = MultiHeadAttention(\n",
 | ||
|     "            d_in=cfg[\"emb_dim\"],\n",
 | ||
|     "            d_out=cfg[\"emb_dim\"],\n",
 | ||
|     "            context_length=cfg[\"context_length\"],\n",
 | ||
|     "            num_heads=cfg[\"n_heads\"],\n",
 | ||
|     "            dtype=cfg[\"dtype\"]  # NEW\n",
 | ||
|     "            # dropout=cfg[\"drop_rate\"],\n",
 | ||
|     "            # qkv_bias=cfg[\"qkv_bias\"]\n",
 | ||
|     "        )\n",
 | ||
|     "        self.ff = FeedForward(cfg)\n",
 | ||
|     "\n",
 | ||
|     "        ################################### NEW ###################################\n",
 | ||
|     "        # self.norm1 = LayerNorm(cfg[\"emb_dim\"])\n",
 | ||
|     "        # self.norm2 = LayerNorm(cfg[\"emb_dim\"])\n",
 | ||
|     "        self.norm1 = RMSNorm(cfg[\"emb_dim\"])\n",
 | ||
|     "        self.norm2 = RMSNorm(cfg[\"emb_dim\"])\n",
 | ||
|     "        ###########################################################################\n",
 | ||
|     "\n",
 | ||
|     "        # self.drop_shortcut = nn.Dropout(cfg[\"drop_rate\"])\n",
 | ||
|     "\n",
 | ||
|     "    def forward(self, x):\n",
 | ||
|     "        # Shortcut connection for attention block\n",
 | ||
|     "        shortcut = x\n",
 | ||
|     "        x = self.norm1(x)\n",
 | ||
|     "        x = self.att(x)   # Shape [batch_size, num_tokens, emb_size]\n",
 | ||
|     "        # x = self.drop_shortcut(x)\n",
 | ||
|     "        x = x + shortcut  # Add the original input back\n",
 | ||
|     "\n",
 | ||
|     "        # Shortcut connection for feed-forward block\n",
 | ||
|     "        shortcut = x\n",
 | ||
|     "        x = self.norm2(x)\n",
 | ||
|     "        x = self.ff(x)\n",
 | ||
|     "        # x = self.drop_shortcut(x)\n",
 | ||
|     "        x = x + shortcut  # Add the original input back\n",
 | ||
|     "\n",
 | ||
|     "        return x"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "markdown",
 | ||
|    "id": "ada953bc-e2c0-4432-a32d-3f7efa3f6e0f",
 | ||
|    "metadata": {
 | ||
|     "id": "ada953bc-e2c0-4432-a32d-3f7efa3f6e0f"
 | ||
|    },
 | ||
|    "source": [
 | ||
|     " \n",
 | ||
|     "## 1.7 Update the model class"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "markdown",
 | ||
|    "id": "ba5d991a-559b-47be-96f4-31b881ab2da8",
 | ||
|    "metadata": {
 | ||
|     "id": "ba5d991a-559b-47be-96f4-31b881ab2da8"
 | ||
|    },
 | ||
|    "source": [
 | ||
|     "- As you may recall from [chapter 5](../01_main-chapter-code/ch05.ipynb), the `TransformerBlock` is a repeated block within the main model\n",
 | ||
|     "- Our Llama model is almost complete; we just have to update the model code surrounding the `TransformerBlock`\n",
 | ||
|     "- This means we\n",
 | ||
|     "  - remove absolute positional embeddings since we have RoPE embeddings now\n",
 | ||
|     "  - replace LayerNorm with RMSNorm\n",
 | ||
|     "  - remove dropout\n",
 | ||
|     "  - add the dtype setting"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": 13,
 | ||
|    "id": "cf8240fe-5d7f-4e7e-b1ac-e0755aab5e79",
 | ||
|    "metadata": {
 | ||
|     "id": "cf8240fe-5d7f-4e7e-b1ac-e0755aab5e79"
 | ||
|    },
 | ||
|    "outputs": [],
 | ||
|    "source": [
 | ||
|     "# class GPTModel(nn.Module):\n",
 | ||
|     "class Llama2Model(nn.Module):\n",
 | ||
|     "    def __init__(self, cfg):\n",
 | ||
|     "        super().__init__()\n",
 | ||
|     "        self.tok_emb = nn.Embedding(cfg[\"vocab_size\"], cfg[\"emb_dim\"], dtype=cfg[\"dtype\"])\n",
 | ||
|     "        # self.pos_emb = nn.Embedding(cfg[\"context_length\"], cfg[\"emb_dim\"])\n",
 | ||
|     "        # self.drop_emb = nn.Dropout(cfg[\"drop_rate\"])\n",
 | ||
|     "\n",
 | ||
|     "        self.trf_blocks = nn.Sequential(\n",
 | ||
|     "            *[TransformerBlock(cfg) for _ in range(cfg[\"n_layers\"])])\n",
 | ||
|     "\n",
 | ||
|     "        ################################### NEW ###################################\n",
 | ||
|     "        # self.final_norm = LayerNorm(cfg[\"emb_dim\"])\n",
 | ||
|     "        self.final_norm = RMSNorm(cfg[\"emb_dim\"])\n",
 | ||
|     "        ###########################################################################\n",
 | ||
|     "        self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n",
 | ||
|     "\n",
 | ||
|     "    def forward(self, in_idx):\n",
 | ||
|     "        # batch_size, seq_len = in_idx.shape\n",
 | ||
|     "        tok_embeds = self.tok_emb(in_idx)\n",
 | ||
|     "        # pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))\n",
 | ||
|     "        x = tok_embeds  # + pos_embeds  # Shape [batch_size, num_tokens, emb_size]\n",
 | ||
|     "        # x = self.drop_emb(x)\n",
 | ||
|     "        x = self.trf_blocks(x)\n",
 | ||
|     "        x = self.final_norm(x)\n",
 | ||
|     "        logits = self.out_head(x)\n",
 | ||
|     "        return logits"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "markdown",
 | ||
|    "id": "4bc94940-aaeb-45b9-9399-3a69b8043e60",
 | ||
|    "metadata": {
 | ||
|     "id": "4bc94940-aaeb-45b9-9399-3a69b8043e60"
 | ||
|    },
 | ||
|    "source": [
 | ||
|     " \n",
 | ||
|     "## 2. Initialize model"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "markdown",
 | ||
|    "id": "bG--zY-Ljj1f",
 | ||
|    "metadata": {
 | ||
|     "id": "bG--zY-Ljj1f"
 | ||
|    },
 | ||
|    "source": [
 | ||
|     "- The model code is now complete, and we are ready to initialize it\n",
 | ||
|     "- In [chapter 5](../01_main-chapter-code/ch05.ipynb), we used the following config file to specify the 124M-parameter GPT model:"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": 14,
 | ||
|    "id": "4b7428df-3d02-4ccd-97b5-a629bdabbe8f",
 | ||
|    "metadata": {
 | ||
|     "id": "4b7428df-3d02-4ccd-97b5-a629bdabbe8f"
 | ||
|    },
 | ||
|    "outputs": [],
 | ||
|    "source": [
 | ||
|     "GPT_CONFIG_124M = {\n",
 | ||
|     "    \"vocab_size\": 50257,     # Vocabulary size\n",
 | ||
|     "    \"context_length\": 1024,  # Context length\n",
 | ||
|     "    \"emb_dim\": 768,          # Embedding dimension\n",
 | ||
|     "    \"n_heads\": 12,           # Number of attention heads\n",
 | ||
|     "    \"n_layers\": 12,          # Number of layers\n",
 | ||
|     "    \"drop_rate\": 0.1,        # Dropout rate\n",
 | ||
|     "    \"qkv_bias\": False        # Query-Key-Value bias\n",
 | ||
|     "}"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "markdown",
 | ||
|    "id": "8bVi8uiBjw2T",
 | ||
|    "metadata": {
 | ||
|     "id": "8bVi8uiBjw2T"
 | ||
|    },
 | ||
|    "source": [
 | ||
|     "- For reference, the 1.5B parameter GPT model config is shown below as well:"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": 15,
 | ||
|    "id": "tAOojV_mkEnd",
 | ||
|    "metadata": {
 | ||
|     "id": "tAOojV_mkEnd"
 | ||
|    },
 | ||
|    "outputs": [],
 | ||
|    "source": [
 | ||
|     "GPT_CONFIG_1558M = {\n",
 | ||
|     "    \"vocab_size\": 50257,     # Vocabulary size\n",
 | ||
|     "    \"context_length\": 1024,  # Context length\n",
 | ||
|     "    \"emb_dim\": 1600,         # Embedding dimension\n",
 | ||
|     "    \"n_heads\": 25,           # Number of attention heads\n",
 | ||
|     "    \"n_layers\": 48,          # Number of layers\n",
 | ||
|     "    \"drop_rate\": 0.1,        # Dropout rate\n",
 | ||
|     "    \"qkv_bias\": False        # Query-Key-Value bias\n",
 | ||
|     "}"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "markdown",
 | ||
|    "id": "HoGGRAGykQTE",
 | ||
|    "metadata": {
 | ||
|     "id": "HoGGRAGykQTE"
 | ||
|    },
 | ||
|    "source": [
 | ||
|     "- Similarly, we can define a Llama 2 config file for the 7B model (we ignore the other larger models for simplicity here):"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": 16,
 | ||
|    "id": "e0564727-2d35-4f0c-b0fc-cde1e9134a18",
 | ||
|    "metadata": {
 | ||
|     "id": "e0564727-2d35-4f0c-b0fc-cde1e9134a18"
 | ||
|    },
 | ||
|    "outputs": [],
 | ||
|    "source": [
 | ||
|     "LLAMA2_CONFIG_7B = {\n",
 | ||
|     "    \"vocab_size\": 32000,     # Vocabulary size\n",
 | ||
|     "    \"context_length\": 4096,  # Context length\n",
 | ||
|     "    \"emb_dim\": 4096,         # Embedding dimension\n",
 | ||
|     "    \"n_heads\": 32,           # Number of attention heads\n",
 | ||
|     "    \"n_layers\": 32,          # Number of layers\n",
 | ||
|     "    \"hidden_dim\": 11008,     # NEW: Size of the intermediate dimension in FeedForward\n",
 | ||
|     "    \"dtype\": torch.bfloat16  # NEW: Lower-precision dtype to reduce memory usage\n",
 | ||
|     "}"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "markdown",
 | ||
|    "id": "FAP7fiBzkaBz",
 | ||
|    "metadata": {
 | ||
|     "id": "FAP7fiBzkaBz"
 | ||
|    },
 | ||
|    "source": [
 | ||
|     "- Using these settings, we can now initialize a Llama 2 7B model (note that this requires ~26 GB of memory)"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": 17,
 | ||
|    "id": "7004d785-ac9a-4df5-8760-6807fc604686",
 | ||
|    "metadata": {
 | ||
|     "id": "7004d785-ac9a-4df5-8760-6807fc604686"
 | ||
|    },
 | ||
|    "outputs": [],
 | ||
|    "source": [
 | ||
|     "model = Llama2Model(LLAMA2_CONFIG_7B)"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": 18,
 | ||
|    "id": "6079f747-8f20-4c6b-8d38-7156f1101729",
 | ||
|    "metadata": {
 | ||
|     "colab": {
 | ||
|      "base_uri": "https://localhost:8080/"
 | ||
|     },
 | ||
|     "id": "6079f747-8f20-4c6b-8d38-7156f1101729",
 | ||
|     "outputId": "0a0eb34b-1a21-4c11-804f-b40007bda5a3"
 | ||
|    },
 | ||
|    "outputs": [
 | ||
|     {
 | ||
|      "name": "stdout",
 | ||
|      "output_type": "stream",
 | ||
|      "text": [
 | ||
|       "Total number of parameters: 6,738,415,616\n"
 | ||
|      ]
 | ||
|     }
 | ||
|    ],
 | ||
|    "source": [
 | ||
|     "total_params = sum(p.numel() for p in model.parameters())\n",
 | ||
|     "print(f\"Total number of parameters: {total_params:,}\")"
 | ||
|    ]
 | ||
|   },
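|   {
|    "cell_type": "markdown",
|    "metadata": {},
|    "source": [
|     "- As a quick back-of-the-envelope check (a sketch based on the `LLAMA2_CONFIG_7B` values), the parameter count above can be reproduced by hand:\n",
|     "\n",
|     "```python\n",
|     "emb = 32000 * 4096                                          # token embedding\n",
|     "per_layer = 4 * 4096 * 4096 + 3 * 4096 * 11008 + 2 * 4096   # attention (W_q/k/v/out) + FeedForward (fc1-fc3) + 2 RMSNorm weights\n",
|     "out_head = 32000 * 4096                                     # output head\n",
|     "print(emb + 32 * per_layer + 4096 + out_head)                # 6738415616 (the extra 4096 is the final RMSNorm)\n",
|     "```"
|    ]
|   },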
 | ||
|   {
 | ||
|    "cell_type": "markdown",
 | ||
|    "id": "Bx14NtzWk2wj",
 | ||
|    "metadata": {
 | ||
|     "id": "Bx14NtzWk2wj"
 | ||
|    },
 | ||
|    "source": [
 | ||
|     "- As shown above, the model contains 6.7 billion parameters (commonly rounded and referred to as a 7B model)\n",
 | ||
|     "- Additionally, we can calculate the memory requirements for this model using the code below:"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": 19,
 | ||
|    "id": "0df1c79e-27a7-4b0f-ba4e-167fe107125a",
 | ||
|    "metadata": {
 | ||
|     "colab": {
 | ||
|      "base_uri": "https://localhost:8080/"
 | ||
|     },
 | ||
|     "id": "0df1c79e-27a7-4b0f-ba4e-167fe107125a",
 | ||
|     "outputId": "11ced939-556d-4511-d5c0-10a94ed3df32"
 | ||
|    },
 | ||
|    "outputs": [
 | ||
|     {
 | ||
|      "name": "stdout",
 | ||
|      "output_type": "stream",
 | ||
|      "text": [
 | ||
|       "float32 (PyTorch default): 52.33 GB\n",
 | ||
|       "bfloat16: 26.17 GB\n"
 | ||
|      ]
 | ||
|     }
 | ||
|    ],
 | ||
|    "source": [
 | ||
|     "def model_memory_size(model, input_dtype=torch.float32):\n",
 | ||
|     "    total_params = 0\n",
 | ||
|     "    total_grads = 0\n",
 | ||
|     "    for param in model.parameters():\n",
 | ||
|     "        # Calculate total number of elements per parameter\n",
 | ||
|     "        param_size = param.numel()\n",
 | ||
|     "        total_params += param_size\n",
 | ||
|     "        # Check if gradients are stored for this parameter\n",
 | ||
|     "        if param.requires_grad:\n",
 | ||
|     "            total_grads += param_size\n",
 | ||
|     "\n",
 | ||
|     "    # Calculate buffer size (non-parameters that require memory)\n",
 | ||
|     "    total_buffers = sum(buf.numel() for buf in model.buffers())\n",
 | ||
|     "\n",
 | ||
|     "    # Size in bytes = (Number of elements) * (Size of each element in bytes)\n",
 | ||
|     "    # We assume parameters and gradients are stored in the same type as input dtype\n",
 | ||
|     "    element_size = torch.tensor(0, dtype=input_dtype).element_size()\n",
 | ||
|     "    total_memory_bytes = (total_params + total_grads + total_buffers) * element_size\n",
 | ||
|     "\n",
 | ||
|     "    # Convert bytes to gigabytes\n",
 | ||
|     "    total_memory_gb = total_memory_bytes / (1024**3)\n",
 | ||
|     "\n",
 | ||
|     "    return total_memory_gb\n",
 | ||
|     "\n",
 | ||
|     "print(f\"float32 (PyTorch default): {model_memory_size(model, input_dtype=torch.float32):.2f} GB\")\n",
 | ||
|     "print(f\"bfloat16: {model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB\")"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "markdown",
 | ||
|    "id": "zudd-5PulKFL",
 | ||
|    "metadata": {
 | ||
|     "id": "zudd-5PulKFL"
 | ||
|    },
 | ||
|    "source": [
 | ||
|     "- Lastly, we can also transfer the model to an NVIDIA or Apple Silicon GPU if applicable:"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": 20,
 | ||
|    "id": "a4c50e19-1402-45b6-8ccd-9077b2ba836d",
 | ||
|    "metadata": {
 | ||
|     "id": "a4c50e19-1402-45b6-8ccd-9077b2ba836d"
 | ||
|    },
 | ||
|    "outputs": [],
 | ||
|    "source": [
 | ||
|     "if torch.cuda.is_available():\n",
 | ||
|     "    device = torch.device(\"cuda\")\n",
 | ||
|     "elif torch.backends.mps.is_available():\n",
 | ||
|     "    device = torch.device(\"mps\")\n",
 | ||
|     "else:\n",
 | ||
|     "    device = torch.device(\"cpu\")\n",
 | ||
|     "\n",
 | ||
|     "model.to(device);"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "markdown",
 | ||
|    "id": "5dc64a06-27dc-46ec-9e6d-1700a8227d34",
 | ||
|    "metadata": {
 | ||
|     "id": "5dc64a06-27dc-46ec-9e6d-1700a8227d34"
 | ||
|    },
 | ||
|    "source": [
 | ||
|     " \n",
 | ||
|     "## 3. Load tokenizer"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "markdown",
 | ||
|    "id": "0eb30f0c-6144-4bed-87d9-6b2bac377005",
 | ||
|    "metadata": {
 | ||
|     "id": "0eb30f0c-6144-4bed-87d9-6b2bac377005"
 | ||
|    },
 | ||
|    "source": [
 | ||
|     "- In this section, we are going to load the tokenizer for the model\n",
 | ||
|     "- Llama 2 uses Google's [SentencePiece](https://github.com/google/sentencepiece) tokenizer instead of OpenAI's [Tiktoken](https://github.com/openai/tiktoken) (but Llama 3 uses Tiktoken)\n",
 | ||
|     "- Meta AI shared the original Llama 2 model weights and tokenizer vocabulary on the Hugging Face Hub\n",
 | ||
|     "- We will download the tokenizer vocabulary from the Hub and load it into SentencePiece\n",
 | ||
|     "- Uncomment and run the following code to install the required libraries:"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": 21,
 | ||
|    "id": "768989ea-dc60-4dc8-ae84-cbb3fd224422",
 | ||
|    "metadata": {
 | ||
|     "id": "768989ea-dc60-4dc8-ae84-cbb3fd224422"
 | ||
|    },
 | ||
|    "outputs": [],
 | ||
|    "source": [
 | ||
|     "# !pip install huggingface_hub sentencepiece"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "markdown",
 | ||
|    "id": "KbnlzsbYmJU6",
 | ||
|    "metadata": {
 | ||
|     "id": "KbnlzsbYmJU6"
 | ||
|    },
 | ||
|    "source": [
 | ||
|     "- Please note that Meta AI requires that you accept the Llama 2 licensing terms before you can download the files; to do this, you have to create a Hugging Face Hub account and visit the [meta-llama/Llama-2-7b](https://huggingface.co/meta-llama/Llama-2-7b) repository to accept the terms\n",
 | ||
|     "- Next, you will need to create an access token; to generate an access token with READ permissions, click on the profile picture in the upper right and click on \"Settings\"\n",
 | ||
|     "\n",
 | ||
|     "\n",
 | ||
|     "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/gpt-to-llama/settings.webp?1\" width=\"300px\">\n",
 | ||
|     "\n",
 | ||
|     "- Then, create and copy the access token so you can copy & paste it into the next code cell\n",
 | ||
|     "\n",
 | ||
|     "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/gpt-to-llama/access-token.webp?1\" width=\"600px\">"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": 22,
 | ||
|    "id": "3357a230-b678-4691-a238-257ee4e80185",
 | ||
|    "metadata": {
 | ||
|     "colab": {
 | ||
|      "base_uri": "https://localhost:8080/"
 | ||
|     },
 | ||
|     "id": "3357a230-b678-4691-a238-257ee4e80185",
 | ||
|     "outputId": "768ed6af-ce14-40bc-ca18-117b4b448269"
 | ||
|    },
 | ||
|    "outputs": [
 | ||
|     {
 | ||
|      "name": "stdout",
 | ||
|      "output_type": "stream",
 | ||
|      "text": [
 | ||
|       "The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.\n",
 | ||
|       "Token is valid (permission: read).\n",
 | ||
|       "Your token has been saved to /root/.cache/huggingface/token\n",
 | ||
|       "Login successful\n"
 | ||
|      ]
 | ||
|     }
 | ||
|    ],
 | ||
|    "source": [
 | ||
|     "from huggingface_hub import login\n",
 | ||
|     "import json\n",
 | ||
|     "\n",
 | ||
|     "with open(\"config.json\", \"r\") as config_file:\n",
 | ||
|     "    config = json.load(config_file)\n",
 | ||
|     "    access_token = config[\"HF_ACCESS_TOKEN\"]\n",
 | ||
|     "\n",
 | ||
|     "login(token=access_token)"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "markdown",
 | ||
|    "id": "IxGh6ZYQo0VN",
 | ||
|    "metadata": {
 | ||
|     "id": "IxGh6ZYQo0VN"
 | ||
|    },
 | ||
|    "source": [
 | ||
|     "- After login via the access token, which is necessary to verify that we accepted the Llama 2 licensing terms, we can now download the tokenizer vocabulary:"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": 23,
 | ||
|    "id": "69714ea8-b9b8-4687-8392-f3abb8f93a32",
 | ||
|    "metadata": {
 | ||
|     "colab": {
 | ||
|      "base_uri": "https://localhost:8080/",
 | ||
|      "height": 153,
 | ||
|      "referenced_widgets": [
 | ||
|       "e6c75a6aa7b942fe84160e286e3acb3d",
 | ||
|       "08f0bf9459bd425498a5cb236f9d4a72",
 | ||
|       "10251d6f724e43788c41d4b7879cbfd3",
 | ||
|       "53a973c0853b44418698136bd04df039",
 | ||
|       "bdb071e7145a4007ae01599333e72612",
 | ||
|       "6b1821a7f4574e3aba09c1e410cc81e4",
 | ||
|       "8c2873eaec3445888ad3d54ad7387950",
 | ||
|       "0c8f7044966e4207b12352503c67dcbb",
 | ||
|       "8b5951213c9e4798a258146d61d02d11",
 | ||
|       "2c05df3f91e64df7b33905b1065a76f7",
 | ||
|       "742ae5487f2648fcae7ca8e22c7f8db9"
 | ||
|      ]
 | ||
|     },
 | ||
|     "id": "69714ea8-b9b8-4687-8392-f3abb8f93a32",
 | ||
|     "outputId": "c230fec9-5c71-4a41-90ab-8a34d114ea01"
 | ||
|    },
 | ||
|    "outputs": [
 | ||
|     {
 | ||
|      "name": "stderr",
 | ||
|      "output_type": "stream",
 | ||
|      "text": [
 | ||
|       "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_token.py:89: UserWarning: \n",
 | ||
|       "The secret `HF_TOKEN` does not exist in your Colab secrets.\n",
 | ||
|       "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n",
 | ||
|       "You will be able to reuse this secret in all of your notebooks.\n",
 | ||
|       "Please note that authentication is recommended but still optional to access public models or datasets.\n",
 | ||
|       "  warnings.warn(\n"
 | ||
|      ]
 | ||
|     },
 | ||
|     {
 | ||
|      "data": {
 | ||
|       "application/vnd.jupyter.widget-view+json": {
 | ||
|        "model_id": "e6c75a6aa7b942fe84160e286e3acb3d",
 | ||
|        "version_major": 2,
 | ||
|        "version_minor": 0
 | ||
|       },
 | ||
|       "text/plain": [
 | ||
|        "tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]"
 | ||
|       ]
 | ||
|      },
 | ||
|      "metadata": {},
 | ||
|      "output_type": "display_data"
 | ||
|     }
 | ||
|    ],
 | ||
|    "source": [
 | ||
|     "from huggingface_hub import hf_hub_download\n",
 | ||
|     "\n",
 | ||
|     "tokenizer_file = hf_hub_download(\n",
 | ||
|     "    repo_id=\"meta-llama/Llama-2-7b\",\n",
 | ||
|     "    filename=\"tokenizer.model\",\n",
 | ||
|     "    local_dir=\"Llama-2-7b\"\n",
 | ||
|     ")"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "markdown",
 | ||
|    "id": "gp7iQ8cXAJLv",
 | ||
|    "metadata": {
 | ||
|     "id": "gp7iQ8cXAJLv"
 | ||
|    },
 | ||
|    "source": [
 | ||
|     "- To provide a more familiar interface for the tokenizer, we define a small `LlamaTokenizer` wrapper class:"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": 24,
 | ||
|    "id": "Ef4WxhjOBOOc",
 | ||
|    "metadata": {
 | ||
|     "id": "Ef4WxhjOBOOc"
 | ||
|    },
 | ||
|    "outputs": [],
 | ||
|    "source": [
 | ||
|     "import sentencepiece as spm\n",
 | ||
|     "\n",
 | ||
|     "\n",
 | ||
|     "class LlamaTokenizer:\n",
 | ||
|     "    def __init__(self, tokenizer_file):\n",
 | ||
|     "        sp = spm.SentencePieceProcessor()\n",
 | ||
|     "        sp.load(tokenizer_file)\n",
 | ||
|     "        self.tokenizer = sp\n",
 | ||
|     "\n",
 | ||
|     "    def encode(self, text):\n",
 | ||
|     "        return self.tokenizer.encode_as_ids(text)\n",
 | ||
|     "\n",
 | ||
|     "    def decode(self, ids):\n",
 | ||
|     "        return self.tokenizer.decode_pieces(ids)\n",
 | ||
|     "\n",
 | ||
|     "\n",
 | ||
|     "tokenizer = LlamaTokenizer(tokenizer_file)"
 | ||
|    ]
 | ||
|   },
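|   {
|    "cell_type": "markdown",
|    "metadata": {},
|    "source": [
|     "- A quick usage sketch of the wrapper (the exact token IDs depend on the Llama 2 vocabulary; the example text is arbitrary):\n",
|     "\n",
|     "```python\n",
|     "ids = tokenizer.encode(\"Every effort moves you\")\n",
|     "print(ids)\n",
|     "print(tokenizer.decode(ids))  # should reproduce the input text\n",
|     "```"
|    ]
|   },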
 | ||
|   {
 | ||
|    "cell_type": "markdown",
 | ||
|    "id": "NVhmFeX3pT_M",
 | ||
|    "metadata": {
 | ||
|     "id": "NVhmFeX3pT_M"
 | ||
|    },
 | ||
|    "source": [
 | ||
|     "- We can now use the `generate` function to have the Llama 2 model generate new text:"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": 25,
 | ||
|    "id": "e0a2b5cd-6cba-4d72-b8ff-04d8315d483e",
 | ||
|    "metadata": {
 | ||
|     "colab": {
 | ||
|      "base_uri": "https://localhost:8080/"
 | ||
|     },
 | ||
|     "id": "e0a2b5cd-6cba-4d72-b8ff-04d8315d483e",
 | ||
|     "outputId": "acd5065d-8900-4ba8-ef85-968365f3a0cb"
 | ||
|    },
 | ||
|    "outputs": [
 | ||
|     {
 | ||
|      "name": "stdout",
 | ||
|      "output_type": "stream",
 | ||
|      "text": [
 | ||
|       "Output text:\n",
 | ||
|       " Every effort movesαllRadius deletingpretcc否']; future eer napulate lackус während inter DES издаSchéon로жа Bass differencespadxsnu ;; ctx始\n"
 | ||
|      ]
 | ||
|     }
 | ||
|    ],
 | ||
|    "source": [
 | ||
|     "from previous_chapters import generate, text_to_token_ids, token_ids_to_text\n",
 | ||
|     "# If the `previous_chapters.py` file is not available locally,\n",
 | ||
|     "# you can import it from the `llms-from-scratch` PyPI package.\n",
 | ||
|     "# For details, see: https://github.com/rasbt/LLMs-from-scratch/tree/main/pkg\n",
 | ||
|     "# E.g.,\n",
 | ||
|     "# from llms_from_scratch.ch05 import generate, text_to_token_ids, token_ids_to_text\n",
 | ||
|     "\n",
 | ||
|     "\n",
 | ||
|     "\n",
 | ||
|     "torch.manual_seed(123)\n",
 | ||
|     "\n",
 | ||
|     "token_ids = generate(\n",
 | ||
|     "    model=model,\n",
 | ||
|     "    idx=text_to_token_ids(\"Every effort moves\", tokenizer).to(device),\n",
 | ||
|     "    max_new_tokens=30,\n",
 | ||
|     "    context_size=LLAMA2_CONFIG_7B[\"context_length\"],\n",
 | ||
|     "    top_k=1,\n",
 | ||
|     "    temperature=0.\n",
 | ||
|     ")\n",
 | ||
|     "\n",
 | ||
|     "print(\"Output text:\\n\", token_ids_to_text(token_ids, tokenizer))"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "markdown",
 | ||
|    "id": "93WTtAA5paYV",
 | ||
|    "metadata": {
 | ||
|     "id": "93WTtAA5paYV"
 | ||
|    },
 | ||
|    "source": [
 | ||
|     "- Of course, as we can see above, the text is nonsensical since we haven't trained the Llama 2 model yet\n",
 | ||
|     "- In the next section, instead of training it ourselves, which would cost tens to hundreds of thousands of dollars, we load the pretrained weights from Meta AI"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "markdown",
 | ||
|    "id": "f63cc248-1d27-4eb6-aa50-173b436652f8",
 | ||
|    "metadata": {
 | ||
|     "id": "f63cc248-1d27-4eb6-aa50-173b436652f8"
 | ||
|    },
 | ||
|    "source": [
 | ||
|     " \n",
 | ||
|     "## 4. Load pretrained weights"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "markdown",
 | ||
|    "id": "aKeN7rUfqZMI",
 | ||
|    "metadata": {
 | ||
|     "id": "aKeN7rUfqZMI"
 | ||
|    },
 | ||
|    "source": [
 | ||
|     "- We are loading the [\"meta-llama/Llama-2-7b\"](https://huggingface.co/meta-llama/Llama-2-7b) base model below, which is a simple text completion model before finetuning\n",
 | ||
|     "- Alternatively, you can load the instruction-finetuned and aligned [\"meta-llama/Llama-2-7b-chat\"](https://huggingface.co/meta-llama/Llama-2-7b-chat) model by modifying the string in the next code cell accordingly"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": 26,
 | ||
|    "id": "5fa9c06c-7a53-4b4d-9ce4-acc027322ee4",
 | ||
|    "metadata": {
 | ||
|     "colab": {
 | ||
|      "base_uri": "https://localhost:8080/",
 | ||
|      "height": 49,
 | ||
|      "referenced_widgets": [
 | ||
|       "66e777955e8748df878f118f07f38dab",
 | ||
|       "da89ae3ea4d2474e98f64ada608f3cea",
 | ||
|       "93e6da39c25f4edfaa72056c89df1f7f",
 | ||
|       "b628603e4cb0405398c916587ee96756",
 | ||
|       "93bedcb9245e44a0a1eb7e4155070f66",
 | ||
|       "0723f467d37b4904819a8bb33ebda10f",
 | ||
|       "e54928776bc649339002adced63738b0",
 | ||
|       "d8e0f42068af4cb094e2f115f76e06e0",
 | ||
|       "0a939565b6e94f08bee0a66e0f9827d4",
 | ||
|       "a5fedbb7ec2e43d99711bb4cd84b9486",
 | ||
|       "0c186f6539714d8eab023969ce47c500"
 | ||
|      ]
 | ||
|     },
 | ||
|     "id": "5fa9c06c-7a53-4b4d-9ce4-acc027322ee4",
 | ||
|     "outputId": "0d8942cc-e5e2-4e77-ec41-1ac7bec7d94f"
 | ||
|    },
 | ||
|    "outputs": [
 | ||
|     {
 | ||
|      "data": {
 | ||
|       "application/vnd.jupyter.widget-view+json": {
 | ||
|        "model_id": "66e777955e8748df878f118f07f38dab",
 | ||
|        "version_major": 2,
 | ||
|        "version_minor": 0
 | ||
|       },
 | ||
|       "text/plain": [
 | ||
|        "consolidated.00.pth:   0%|          | 0.00/13.5G [00:00<?, ?B/s]"
 | ||
|       ]
 | ||
|      },
 | ||
|      "metadata": {},
 | ||
|      "output_type": "display_data"
 | ||
|     }
 | ||
|    ],
 | ||
|    "source": [
 | ||
|     "weights_file = hf_hub_download(\n",
 | ||
|     "   repo_id=\"meta-llama/Llama-2-7b\",\n",
 | ||
|     "   filename=\"consolidated.00.pth\",\n",
 | ||
|     "   local_dir=\"Llama-2-7b\"\n",
 | ||
|     ")"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": 27,
 | ||
|    "id": "e67cca5c-ba4b-4be5-85c7-fdceae8a5701",
 | ||
|    "metadata": {
 | ||
|     "id": "e67cca5c-ba4b-4be5-85c7-fdceae8a5701"
 | ||
|    },
 | ||
|    "outputs": [],
 | ||
|    "source": [
 | ||
|     "weights = torch.load(weights_file, weights_only=True)"
 | ||
|    ]
 | ||
|   },
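|   {
|    "cell_type": "markdown",
|    "id": "weights-memory-note",
|    "metadata": {},
|    "source": [
|     "- To get a rough feel for how much memory the loaded checkpoint occupies, the optional cell below sums the sizes of all tensors in the \`weights\` dictionary (a minimal sketch; the exact number depends on the checkpoint's dtype)"
|    ]
|   },
|   {
|    "cell_type": "code",
|    "execution_count": null,
|    "id": "weights-memory-code",
|    "metadata": {},
|    "outputs": [],
|    "source": [
|     "# Optional: total size of all checkpoint tensors in gigabytes\n",
|     "total_bytes = sum(t.numel() * t.element_size() for t in weights.values())\n",
|     "print(f\"Checkpoint size: {total_bytes / 1e9:.2f} GB\")"
|    ]
|   },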
 | ||
|   {
 | ||
|    "cell_type": "markdown",
 | ||
|    "id": "-15SJ7btq2zE",
 | ||
|    "metadata": {
 | ||
|     "id": "-15SJ7btq2zE"
 | ||
|    },
 | ||
|    "source": [
 | ||
|     "- The `weights` contains the following tensors (only the first 15 are shown for simplicity):"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": 28,
 | ||
|    "id": "ee26bd0b-fea9-4924-97f7-409c14f28e49",
 | ||
|    "metadata": {
 | ||
|     "colab": {
 | ||
|      "base_uri": "https://localhost:8080/"
 | ||
|     },
 | ||
|     "id": "ee26bd0b-fea9-4924-97f7-409c14f28e49",
 | ||
|     "outputId": "fa83d38a-bb41-4cb2-d3c7-c573bfe1f8a4"
 | ||
|    },
 | ||
|    "outputs": [
 | ||
|     {
 | ||
|      "data": {
 | ||
|       "text/plain": [
 | ||
|        "['tok_embeddings.weight',\n",
 | ||
|        " 'norm.weight',\n",
 | ||
|        " 'output.weight',\n",
 | ||
|        " 'layers.0.attention.wq.weight',\n",
 | ||
|        " 'layers.0.attention.wk.weight',\n",
 | ||
|        " 'layers.0.attention.wv.weight',\n",
 | ||
|        " 'layers.0.attention.wo.weight',\n",
 | ||
|        " 'layers.0.feed_forward.w1.weight',\n",
 | ||
|        " 'layers.0.feed_forward.w2.weight',\n",
 | ||
|        " 'layers.0.feed_forward.w3.weight',\n",
 | ||
|        " 'layers.0.attention_norm.weight',\n",
 | ||
|        " 'layers.0.ffn_norm.weight',\n",
 | ||
|        " 'layers.1.attention.wq.weight',\n",
 | ||
|        " 'layers.1.attention.wk.weight',\n",
 | ||
|        " 'layers.1.attention.wv.weight']"
 | ||
|       ]
 | ||
|      },
 | ||
|      "execution_count": 28,
 | ||
|      "metadata": {},
 | ||
|      "output_type": "execute_result"
 | ||
|     }
 | ||
|    ],
 | ||
|    "source": [
 | ||
|     "list(weights.keys())[:15]"
 | ||
|    ]
 | ||
|   },
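|   {
|    "cell_type": "markdown",
|    "id": "weights-inspection-note",
|    "metadata": {},
|    "source": [
|     "- If you want to inspect the checkpoint further, the optional cell below prints the total number of entries as well as the shape and dtype of a few of the tensors listed above"
|    ]
|   },
|   {
|    "cell_type": "code",
|    "execution_count": null,
|    "id": "weights-inspection-code",
|    "metadata": {},
|    "outputs": [],
|    "source": [
|     "# Optional: number of tensors and a few example shapes/dtypes\n",
|     "print(\"Number of entries:\", len(weights))\n",
|     "\n",
|     "for name in [\"tok_embeddings.weight\", \"norm.weight\", \"layers.0.attention.wq.weight\"]:\n",
|     "    print(name, tuple(weights[name].shape), weights[name].dtype)"
|    ]
|   },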
 | ||
|   {
 | ||
|    "cell_type": "markdown",
 | ||
|    "id": "UeeSpnunrDFB",
 | ||
|    "metadata": {
 | ||
|     "id": "UeeSpnunrDFB"
 | ||
|    },
 | ||
|    "source": [
 | ||
|     "- The following function, modeled after the `load_weights_into_gpt` function in [chapter 5](../01_main-chapter-code/ch05.ipynb), loads the pretrained weights into our Llama 2 model:"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": 29,
 | ||
|    "id": "3820e2a7-4f26-41bc-953b-f3879b0aff65",
 | ||
|    "metadata": {
 | ||
|     "id": "3820e2a7-4f26-41bc-953b-f3879b0aff65"
 | ||
|    },
 | ||
|    "outputs": [],
 | ||
|    "source": [
 | ||
|     "def assign(left, right):\n",
 | ||
|     "    if left.shape != right.shape:\n",
 | ||
|     "        raise ValueError(f\"Shape mismatch. Left: {left.shape}, Right: {right.shape}\")\n",
 | ||
|     "\n",
 | ||
|     "    if isinstance(right, torch.Tensor):\n",
 | ||
|     "        return torch.nn.Parameter(right.clone().detach())\n",
 | ||
|     "    else:\n",
 | ||
|     "        return torch.nn.Parameter(torch.tensor(right))\n",
 | ||
|     "\n",
 | ||
|     "\n",
 | ||
|     "def load_weights_into_llama(model, param_config, params):\n",
 | ||
|     "    model.tok_emb.weight = assign(model.tok_emb.weight, params[\"tok_embeddings.weight\"])\n",
 | ||
|     "\n",
 | ||
|     "    for l in range(param_config[\"n_layers\"]):\n",
 | ||
|     "\n",
 | ||
|     "        # Load attention weights\n",
 | ||
|     "        model.trf_blocks[l].att.W_query.weight = assign(\n",
 | ||
|     "            model.trf_blocks[l].att.W_query.weight,\n",
 | ||
|     "            params[f\"layers.{l}.attention.wq.weight\"]\n",
 | ||
|     "        )\n",
 | ||
|     "        model.trf_blocks[l].att.W_key.weight = assign(\n",
 | ||
|     "            model.trf_blocks[l].att.W_key.weight,\n",
 | ||
|     "            params[f\"layers.{l}.attention.wk.weight\"]\n",
 | ||
|     "        )\n",
 | ||
|     "        model.trf_blocks[l].att.W_value.weight = assign(\n",
 | ||
|     "            model.trf_blocks[l].att.W_value.weight,\n",
 | ||
|     "            params[f\"layers.{l}.attention.wv.weight\"]\n",
 | ||
|     "        )\n",
 | ||
|     "        model.trf_blocks[l].att.out_proj.weight = assign(\n",
 | ||
|     "            model.trf_blocks[l].att.out_proj.weight,\n",
 | ||
|     "            params[f\"layers.{l}.attention.wo.weight\"]\n",
 | ||
|     "        )\n",
 | ||
|     "        model.trf_blocks[l].norm1.weight = assign(\n",
 | ||
|     "            model.trf_blocks[l].norm1.weight,\n",
 | ||
|     "            params[f\"layers.{l}.attention_norm.weight\"]\n",
 | ||
|     "        )\n",
 | ||
|     "\n",
 | ||
|     "        # Load FeedForward weights\n",
 | ||
|     "        model.trf_blocks[l].ff.fc1.weight = assign(\n",
 | ||
|     "            model.trf_blocks[l].ff.fc1.weight,\n",
 | ||
|     "            params[f\"layers.{l}.feed_forward.w1.weight\"]\n",
 | ||
|     "        )\n",
 | ||
|     "        # For some reason w2 and w3 are provided in the wrong order in the weights file\n",
 | ||
|     "        model.trf_blocks[l].ff.fc2.weight = assign(\n",
 | ||
|     "            model.trf_blocks[l].ff.fc2.weight,\n",
 | ||
|     "            params[f\"layers.{l}.feed_forward.w3.weight\"]\n",
 | ||
|     "        )\n",
 | ||
|     "        model.trf_blocks[l].ff.fc3.weight = assign(\n",
 | ||
|     "            model.trf_blocks[l].ff.fc3.weight,\n",
 | ||
|     "            params[f\"layers.{l}.feed_forward.w2.weight\"]\n",
 | ||
|     "        )\n",
 | ||
|     "        model.trf_blocks[l].norm2.weight = assign(\n",
 | ||
|     "            model.trf_blocks[l].norm2.weight,\n",
 | ||
|     "            params[f\"layers.{l}.ffn_norm.weight\"]\n",
 | ||
|     "        )\n",
 | ||
|     "\n",
 | ||
|     "    # Load output layer weights\n",
 | ||
|     "    model.final_norm.weight = assign(model.final_norm.weight, params[\"norm.weight\"])\n",
 | ||
|     "    model.out_head.weight = assign(model.out_head.weight, params[\"output.weight\"])\n",
 | ||
|     "\n",
 | ||
|     "\n",
 | ||
|     "load_weights_into_llama(model, LLAMA2_CONFIG_7B, weights)\n",
 | ||
|     "model.to(device);"
 | ||
|    ]
 | ||
|   },
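|   {
|    "cell_type": "markdown",
|    "id": "weight-transfer-check-note",
|    "metadata": {},
|    "source": [
|     "- As a quick optional sanity check, the sketch below verifies that the token-embedding weights in the model now match the corresponding tensor from the loaded checkpoint"
|    ]
|   },
|   {
|    "cell_type": "code",
|    "execution_count": null,
|    "id": "weight-transfer-check-code",
|    "metadata": {},
|    "outputs": [],
|    "source": [
|     "# Optional: the copied token embeddings should be identical to the checkpoint tensor\n",
|     "# (moved back to the CPU for the comparison)\n",
|     "print(torch.equal(model.tok_emb.weight.detach().cpu(), weights[\"tok_embeddings.weight\"]))"
|    ]
|   },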
 | ||
|   {
 | ||
|    "cell_type": "markdown",
 | ||
|    "id": "TDuv_Us2rNvk",
 | ||
|    "metadata": {
 | ||
|     "id": "TDuv_Us2rNvk"
 | ||
|    },
 | ||
|    "source": [
 | ||
|     "- Next, we are ready to use the model for text generation"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": 30,
 | ||
|    "id": "240987e8-a023-462e-9376-9edfb27559ec",
 | ||
|    "metadata": {
 | ||
|     "colab": {
 | ||
|      "base_uri": "https://localhost:8080/"
 | ||
|     },
 | ||
|     "id": "240987e8-a023-462e-9376-9edfb27559ec",
 | ||
|     "outputId": "044f24b3-4018-4860-834d-6c2731b9e47c"
 | ||
|    },
 | ||
|    "outputs": [
 | ||
|     {
 | ||
|      "name": "stdout",
 | ||
|      "output_type": "stream",
 | ||
|      "text": [
 | ||
|       "Output text:\n",
 | ||
|       " Every effort has been made to ensure that the information contained in this website is accurate and up to date and correct at the time of publication\n"
 | ||
|      ]
 | ||
|     }
 | ||
|    ],
 | ||
|    "source": [
 | ||
|     "torch.manual_seed(123)\n",
 | ||
|     "\n",
 | ||
|     "token_ids = generate(\n",
 | ||
|     "    model=model,\n",
 | ||
|     "    idx=text_to_token_ids(\"Every effort\", tokenizer).to(device),\n",
 | ||
|     "    max_new_tokens=25,\n",
 | ||
|     "    context_size=LLAMA2_CONFIG_7B[\"context_length\"],\n",
 | ||
|     "    top_k=1,\n",
 | ||
|     "    temperature=0.\n",
 | ||
|     ")\n",
 | ||
|     "\n",
 | ||
|     "print(\"Output text:\\n\", token_ids_to_text(token_ids, tokenizer))"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "markdown",
 | ||
|    "id": "d72ed949-b6c0-4966-922f-eb0da732c404",
 | ||
|    "metadata": {},
 | ||
|    "source": [
 | ||
|     " \n",
 | ||
|     "## 5. Using the instruction-finetuned model"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "markdown",
 | ||
|    "id": "akyo7WNyF_YL",
 | ||
|    "metadata": {
 | ||
|     "id": "akyo7WNyF_YL"
 | ||
|    },
 | ||
|    "source": [
 | ||
|     "- As mentioned earlier, above we used the pretrained base model; if you want to use a model capable of following instructions, use the `\"meta-llama/Llama-2-7b-chat\"` model instead, as shown below"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": 34,
 | ||
|    "id": "nbvAV7vaz6yc",
 | ||
|    "metadata": {
 | ||
|     "colab": {
 | ||
|      "base_uri": "https://localhost:8080/",
 | ||
|      "height": 101,
 | ||
|      "referenced_widgets": [
 | ||
|       "3b2448a60f5f4ba5b2c686037c8ecd78",
 | ||
|       "60c5932944f24f5fad1d8da89c8e5ae9",
 | ||
|       "aa31aed1b8854a4281fd7e81c60e1205",
 | ||
|       "d4acf06c2414412f8f2fb4f48981c954",
 | ||
|       "693d69251d3d48219c084af17b54b851",
 | ||
|       "ff36d28c55dd4db3a0f76a87640fdfe2",
 | ||
|       "71c49ef820494d5f8908a3daf39f0755",
 | ||
|       "525dc406534f4369b11208816f8fd0d7",
 | ||
|       "865f39213a7341b68f2fe73caaf801b1",
 | ||
|       "eaf4c0231b6d4993b2f8e9e63d8b6921",
 | ||
|       "a11edf3b018e42c88a63a515cf7fe478"
 | ||
|      ]
 | ||
|     },
 | ||
|     "id": "nbvAV7vaz6yc",
 | ||
|     "outputId": "724f5508-d976-4e31-b3d7-95fa65b2c1e8"
 | ||
|    },
 | ||
|    "outputs": [
 | ||
|     {
 | ||
|      "data": {
 | ||
|       "application/vnd.jupyter.widget-view+json": {
 | ||
|        "model_id": "3b2448a60f5f4ba5b2c686037c8ecd78",
 | ||
|        "version_major": 2,
 | ||
|        "version_minor": 0
 | ||
|       },
 | ||
|       "text/plain": [
 | ||
|        "consolidated.00.pth:   0%|          | 0.00/13.5G [00:00<?, ?B/s]"
 | ||
|       ]
 | ||
|      },
 | ||
|      "metadata": {},
 | ||
|      "output_type": "display_data"
 | ||
|     },
 | ||
|     {
 | ||
|      "name": "stdout",
 | ||
|      "output_type": "stream",
 | ||
|      "text": [
 | ||
|       "Output text:\n",
 | ||
|       " What do llamas eat?\n",
 | ||
|       "Llamas and alpacas are herbivores, which means they eat grasses, leaves, grass\n"
 | ||
|      ]
 | ||
|     }
 | ||
|    ],
 | ||
|    "source": [
 | ||
|     "del model  # to free up memory\n",
 | ||
|     "\n",
 | ||
|     "weights_file = hf_hub_download(\n",
 | ||
|     "   repo_id=\"meta-llama/Llama-2-7b-chat\",\n",
 | ||
|     "   filename=\"consolidated.00.pth\",\n",
 | ||
|     "   local_dir=\"Llama-2-7b-chat\"\n",
 | ||
|     ")\n",
 | ||
|     "\n",
 | ||
|     "model = Llama2Model(LLAMA2_CONFIG_7B)\n",
 | ||
|     "load_weights_into_llama(model, LLAMA2_CONFIG_7B, weights)\n",
 | ||
|     "model.to(device);\n",
 | ||
|     "\n",
 | ||
|     "torch.manual_seed(123)\n",
 | ||
|     "\n",
 | ||
|     "token_ids = generate(\n",
 | ||
|     "    model=model,\n",
 | ||
|     "    idx=text_to_token_ids(\"What do llamas eat?\", tokenizer).to(device),\n",
 | ||
|     "    max_new_tokens=25,\n",
 | ||
|     "    context_size=LLAMA2_CONFIG_7B[\"context_length\"],\n",
 | ||
|     "    top_k=1,\n",
 | ||
|     "    temperature=0.\n",
 | ||
|     ")\n",
 | ||
|     "\n",
 | ||
|     "print(\"Output text:\\n\", token_ids_to_text(token_ids, tokenizer))"
 | ||
|    ]
 | ||
|   },
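|   {
|    "cell_type": "markdown",
|    "id": "save-weights-note",
|    "metadata": {},
|    "source": [
|     "- To avoid downloading and converting the checkpoint again after restarting the notebook, you can optionally save the converted model weights locally and reload them later; the commented-out cell below is a minimal sketch of this (the \`llama-2-7b-chat.pth\` filename is only an example)"
|    ]
|   },
|   {
|    "cell_type": "code",
|    "execution_count": null,
|    "id": "save-weights-code",
|    "metadata": {},
|    "outputs": [],
|    "source": [
|     "# Optional: save the converted weights for later reuse, e.g., reload them via\n",
|     "# model.load_state_dict(torch.load(\"llama-2-7b-chat.pth\", weights_only=True))\n",
|     "\n",
|     "# torch.save(model.state_dict(), \"llama-2-7b-chat.pth\")"
|    ]
|   },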
 | ||
|   {
 | ||
|    "cell_type": "markdown",
 | ||
|    "id": "0f693da1-a07c-4e1d-af5a-c3923525f1e2",
 | ||
|    "metadata": {},
 | ||
|    "source": [
 | ||
|     " \n",
 | ||
|     "# What's next?"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "markdown",
 | ||
|    "id": "fae93739-ca12-46ba-8ca7-7c07c59f669b",
 | ||
|    "metadata": {},
 | ||
|    "source": [
 | ||
|     "- This notebook converted the original GPT-2 architecture into a Llama 2 model\n",
 | ||
|     "- If you are interested in how to convert Llama 2 into Llama 3, Llama 3.1, and Llama 3.2, check out the [converting-llama2-to-llama3.ipynb](converting-llama2-to-llama3.ipynb) notebook"
 | ||
|    ]
 | ||
|   }
 | ||
|  ],
 | ||
|  "metadata": {
 | ||
|   "accelerator": "GPU",
 | ||
|   "colab": {
 | ||
|    "gpuType": "A100",
 | ||
|    "provenance": []
 | ||
|   },
 | ||
|   "kernelspec": {
 | ||
|    "display_name": "Python 3 (ipykernel)",
 | ||
|    "language": "python",
 | ||
|    "name": "python3"
 | ||
|   },
 | ||
|   "language_info": {
 | ||
|    "codemirror_mode": {
 | ||
|     "name": "ipython",
 | ||
|     "version": 3
 | ||
|    },
 | ||
|    "file_extension": ".py",
 | ||
|    "mimetype": "text/x-python",
 | ||
|    "name": "python",
 | ||
|    "nbconvert_exporter": "python",
 | ||
|    "pygments_lexer": "ipython3",
 | ||
|    "version": "3.10.16"
 | ||
|   }
 | ||
|  },
 | ||
|  "nbformat": 4,
 | ||
|  "nbformat_minor": 5
 | ||
| }
 | 
