diff --git a/ch05/07_gpt_to_llama/README.md b/ch05/07_gpt_to_llama/README.md index 280d43e..fda7ab7 100644 --- a/ch05/07_gpt_to_llama/README.md +++ b/ch05/07_gpt_to_llama/README.md @@ -2,6 +2,10 @@ -This folder contains code for converting the GPT implementation from chapter 4 and 5 to Meta AI's Llama architecture: +This folder contains code for converting the GPT implementation from chapter 4 and 5 to Meta AI's Llama architecture in the following recommended reading order: -- [converting-gpt-to-llama2.ipynb](converting-gpt-to-llama2.ipynb): contains code to convert GPT to Llama 2 7B step by step and loads pretrained weights from Meta AI \ No newline at end of file +- [converting-gpt-to-llama2.ipynb](converting-gpt-to-llama2.ipynb): contains code to convert GPT to Llama 2 7B step by step and loads pretrained weights from Meta AI +- [converting-llama2-to-llama3.ipynb](converting-llama2-to-llama3.ipynb): contains code to convert the Llama 2 model to Llama 3, Llama 3.1, and Llama 3.2 +- [standalone-llama32.ipynb](standalone-llama32.ipynb): a standalone notebook implementing Llama 3.2 + + \ No newline at end of file diff --git a/ch05/07_gpt_to_llama/converting-gpt-to-llama2.ipynb b/ch05/07_gpt_to_llama/converting-gpt-to-llama2.ipynb index 731ab98..e8c5bf6 100644 --- a/ch05/07_gpt_to_llama/converting-gpt-to-llama2.ipynb +++ b/ch05/07_gpt_to_llama/converting-gpt-to-llama2.ipynb @@ -108,6 +108,7 @@ "id": "UJJneXpTEg4W" }, "source": [ + " \n", "# 1. Convert the GPT model implementation step by step" ] }, @@ -129,6 +130,7 @@ "id": "979c7b6d-1370-4da1-8bfb-a2b27537bf2f" }, "source": [ + " \n", "## 1.1 Replace LayerNorm with RMSNorm layer" ] }, @@ -228,6 +230,7 @@ "id": "5eb81f83-c38c-46a4-b763-aa630a32e357" }, "source": [ + " \n", "## 1.2 Replace GELU with SiLU activation" ] }, @@ -300,6 +303,7 @@ "id": "4f9b5167-1da9-46c8-9964-8036b3b1deb9" }, "source": [ + " \n", "## 1.3 Update the FeedForward module" ] }, @@ -388,6 +392,7 @@ "id": "f6b7bf4f-99d0-42c1-807c-5074d2cc1949" }, "source": [ + " \n", "## 1.4 Implement RoPE" ] }, @@ -503,6 +508,7 @@ "id": "f78127b0-dda2-4c5a-98dd-bae8f5fe8297" }, "source": [ + " \n", "## 1.5 Add RoPE to MultiHeadAttention module" ] }, @@ -652,6 +658,7 @@ "id": "e5a1a272-a038-4b8f-aaaa-f4b241e7f23f" }, "source": [ + " \n", "## 1.6 Update the TransformerBlock module" ] }, @@ -727,6 +734,7 @@ "id": "ada953bc-e2c0-4432-a32d-3f7efa3f6e0f" }, "source": [ + " \n", "## 1.7 Update the model class" ] }, @@ -791,6 +799,7 @@ "id": "4bc94940-aaeb-45b9-9399-3a69b8043e60" }, "source": [ + " \n", "## 2. Initialize model" ] }, @@ -1029,6 +1038,7 @@ "id": "5dc64a06-27dc-46ec-9e6d-1700a8227d34" }, "source": [ + " \n", "## 3. Load tokenizer" ] }, @@ -1288,6 +1298,7 @@ "id": "f63cc248-1d27-4eb6-aa50-173b436652f8" }, "source": [ + " \n", "## 4. Load pretrained weights" ] }, @@ -1544,6 +1555,15 @@ "print(\"Output text:\\n\", token_ids_to_text(token_ids, tokenizer))" ] }, + { + "cell_type": "markdown", + "id": "d72ed949-b6c0-4966-922f-eb0da732c404", + "metadata": {}, + "source": [ + " \n", + "## 5. 
Using the instruction-finetuned model" + ] + }, { "cell_type": "markdown", "id": "akyo7WNyF_YL", @@ -1551,7 +1571,7 @@ "id": "akyo7WNyF_YL" }, "source": [ - "- Tip: as mentioned earlier, this is the pretrained base model; if you want to use a model capable of following instructions, use the `\"meta-llama/Llama-2-7b-chat\"` model instead" + "- As mentioned earlier, above we used the pretrained base model; if you want to use a model capable of following instructions, use the `\"meta-llama/Llama-2-7b-chat\"` model instead, as shown below" ] }, { @@ -1630,6 +1650,24 @@ "\n", "print(\"Output text:\\n\", token_ids_to_text(token_ids, tokenizer))" ] + }, + { + "cell_type": "markdown", + "id": "0f693da1-a07c-4e1d-af5a-c3923525f1e2", + "metadata": {}, + "source": [ + " \n", + "# What's next?" + ] + }, + { + "cell_type": "markdown", + "id": "fae93739-ca12-46ba-8ca7-7c07c59f669b", + "metadata": {}, + "source": [ + "- This notebook converted the original GPT-2 architecture into a Llama 2 model\n", + "- If you are interested in how to convert Llama 2 into Llama 3, Llama 3.1, and Llama 3.2, check out the [converting-llama2-to-llama3.ipynb](converting-llama2-to-llama3.ipynb) notebook" + ] } ], "metadata": { @@ -1653,7 +1691,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.6" + "version": "3.11.4" }, "widgets": { "application/vnd.jupyter.widget-state+json": { diff --git a/ch05/07_gpt_to_llama/converting-llama2-to-llama3.ipynb b/ch05/07_gpt_to_llama/converting-llama2-to-llama3.ipynb new file mode 100644 index 0000000..9a69c5d --- /dev/null +++ b/ch05/07_gpt_to_llama/converting-llama2-to-llama3.ipynb @@ -0,0 +1,7848 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0_xya1nyDHfY", + "metadata": { + "id": "0_xya1nyDHfY" + }, + "source": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "
\n",
+    "Supplementary code for the Build a Large Language Model From Scratch book by Sebastian Raschka\n",
+    "\n",
+    "Code repository: https://github.com/rasbt/LLMs-from-scratch\n",
+    "
" + ] + }, + { + "cell_type": "markdown", + "id": "l62zIRRSBy_R", + "metadata": { + "id": "l62zIRRSBy_R" + }, + "source": [ + "# Converting Llama 2 to Llama 3.2 From Scratch" + ] + }, + { + "cell_type": "markdown", + "id": "aFmxTQbwCUMl", + "metadata": { + "id": "aFmxTQbwCUMl" + }, + "source": [ + "- This is a follow-up notebook to [Converting a From-Scratch GPT Architecture to Llama 2](./converting-gpt-to-llama2.ipynb), converting Meta AI's Llama 2 architecture model step by step to Llama 3, Llama 3.1, and Llama 3.2\n", + "- The explanations are purposefully kept minimal in this notebook so as not to bloat it unnecessarily and focus on the main code\n", + "- For more information about the architectures, please see the Llama 2 and Llama 3 papers\n", + " - [Llama 2: Open Foundation and Fine-Tuned Chat Models (2023)](https://arxiv.org/abs/2307.09288)\n", + " - [The Llama 3 Herd of Models](https://arxiv.org/abs/2407.21783)" + ] + }, + { + "cell_type": "markdown", + "id": "ohhMKUWvGm9z", + "metadata": { + "id": "ohhMKUWvGm9z" + }, + "source": [ + "" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ws0wsUzwLH2k", + "metadata": { + "id": "ws0wsUzwLH2k" + }, + "outputs": [], + "source": [ + "# pip install -r requirements-extra.txt" + ] + }, + { + "cell_type": "markdown", + "id": "JBpQwU89ETA1", + "metadata": { + "id": "JBpQwU89ETA1" + }, + "source": [ + "- Packages that are being used in this notebook:" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "34a9a440-84c2-42cc-808b-38677cb6af8a", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "34a9a440-84c2-42cc-808b-38677cb6af8a", + "outputId": "1a64035b-daeb-4514-a49f-6bfde84357e7" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "blobfile version: 3.0.0\n", + "huggingface_hub version: 0.24.7\n", + "tiktoken version: 0.8.0\n", + "torch version: 2.4.1+cu121\n" + ] + } + ], + "source": [ + "from importlib.metadata import version\n", + "\n", + "pkgs = [\n", + " \"blobfile\", # to download pretrained weights\n", + " \"huggingface_hub\", # to download pretrained weights\n", + " \"tiktoken\", # to implement the tokenizer\n", + " \"torch\", # to implement the model\n", + "]\n", + "for p in pkgs:\n", + " print(f\"{p} version: {version(p)}\")" + ] + }, + { + "cell_type": "markdown", + "id": "UJJneXpTEg4W", + "metadata": { + "id": "UJJneXpTEg4W" + }, + "source": [ + " \n", + "# 1. Convert the Llama model implementation step by step" + ] + }, + { + "cell_type": "markdown", + "id": "v1zpfX2GHBKa", + "metadata": { + "id": "v1zpfX2GHBKa" + }, + "source": [ + "- If you are new to implementing LLM architectures, I recommend starting with [chapter 4](../../ch04/01_main-chapter-code/ch04.ipynb), which walks you through the implementation of the original GPT architecture step by step\n", + "- The [Converting a From-Scratch GPT Architecture to Llama 2](./converting-gpt-to-llama2.ipynb) then implements the Llama-specific components, such as RMSNorm layers, SiLU and SwiGLU activations, RoPE (rotary position embeddings), and the SentencePiece tokenizer\n", + "- This notebook takes the Llama 2 architecture and transforms it into Llama 3 architecture by\n", + " 1. modifying the rotary embeddings\n", + " 2. implementing grouped-query attention\n", + " 3. 
and using a customized version of the GPT-4 tokenizer\n", + "- Later, we then load the original Llama 3 weights shared by Meta AI into the architecture" + ] + }, + { + "cell_type": "markdown", + "id": "c14b9121-abe1-4a46-99b8-acdef71e5b41", + "metadata": { + "id": "c14b9121-abe1-4a46-99b8-acdef71e5b41" + }, + "source": [ + " \n", + "## 1.1 Reusing Llama 2 components" + ] + }, + { + "cell_type": "markdown", + "id": "dgDhJGJ6xR4e", + "metadata": { + "id": "dgDhJGJ6xR4e" + }, + "source": [ + "- Llama 2 is actually quite similar to Llama 3, as mentioned above and illustrated in the figure at the top of this notebook\n", + "- This means that we can import several building blocks from the [Llama 2 notebook](./converting-gpt-to-llama2.ipynb) using the following code" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "a5bc3948-231b-4f1f-8d41-24ad0b7643d0", + "metadata": { + "id": "a5bc3948-231b-4f1f-8d41-24ad0b7643d0" + }, + "outputs": [], + "source": [ + "import os\n", + "import sys\n", + "import io\n", + "import nbformat\n", + "import types\n", + "\n", + "def import_from_notebook():\n", + " def import_definitions_from_notebook(fullname, names):\n", + " current_dir = os.getcwd()\n", + " path = os.path.join(current_dir, fullname + \".ipynb\")\n", + " path = os.path.normpath(path)\n", + "\n", + " # Load the notebook\n", + " if not os.path.exists(path):\n", + " raise FileNotFoundError(f\"Notebook file not found at: {path}\")\n", + "\n", + " with io.open(path, \"r\", encoding=\"utf-8\") as f:\n", + " nb = nbformat.read(f, as_version=4)\n", + "\n", + " # Create a module to store the imported functions and classes\n", + " mod = types.ModuleType(fullname)\n", + " sys.modules[fullname] = mod\n", + "\n", + " # Go through the notebook cells and only execute function or class definitions\n", + " for cell in nb.cells:\n", + " if cell.cell_type == \"code\":\n", + " cell_code = cell.source\n", + " for name in names:\n", + " # Check for function or class definitions\n", + " if f\"def {name}\" in cell_code or f\"class {name}\" in cell_code:\n", + " exec(cell_code, mod.__dict__)\n", + " return mod\n", + "\n", + " fullname = \"converting-gpt-to-llama2\"\n", + " names = [\"precompute_rope_params\", \"compute_rope\", \"SiLU\", \"FeedForward\", \"RMSNorm\", \"MultiHeadAttention\"]\n", + "\n", + " return import_definitions_from_notebook(fullname, names)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "d546032d-fce4-47cf-8d0e-682b78b21c61", + "metadata": { + "id": "d546032d-fce4-47cf-8d0e-682b78b21c61" + }, + "outputs": [], + "source": [ + "imported_module = import_from_notebook()\n", + "\n", + "# We need to redefine precompute_rope_params\n", + "# precompute_rope_params = getattr(imported_module, \"precompute_rope_params\", None)\n", + "compute_rope = getattr(imported_module, \"compute_rope\", None)\n", + "SiLU = getattr(imported_module, \"SiLU\", None)\n", + "FeedForward = getattr(imported_module, \"FeedForward\", None)\n", + "RMSNorm = getattr(imported_module, \"RMSNorm\", None)\n", + "\n", + "# MultiHeadAttention only for comparison purposes\n", + "MultiHeadAttention = getattr(imported_module, \"MultiHeadAttention\", None)" + ] + }, + { + "cell_type": "markdown", + "id": "979c7b6d-1370-4da1-8bfb-a2b27537bf2f", + "metadata": { + "id": "979c7b6d-1370-4da1-8bfb-a2b27537bf2f" + }, + "source": [ + " \n", + "## 1.2 Modified RoPE" + ] + }, + { + "cell_type": "markdown", + "id": "m9_oDcHCx8VI", + "metadata": { + "id": "m9_oDcHCx8VI" + }, + "source": [ + "- Llama 3 uses rotary position 
embeddings (RoPE) similar to Llama 2 (for a detailed explanation, please see the [RoPE paper](https://arxiv.org/abs/2104.09864))\n", + "- There are some subtle differences in the RoPE settings, though\n", + " - Llama 3 now supports up to 8,192 tokens, twice as many as Llama 2 (4,096)\n", + " - The base value for the so-called RoPE $\\theta$ (see equation below) was increased from 10,000 (Llama 2) to 50,000 (Llama 3) in the following equation (adapted from the [RoPE paper](https://arxiv.org/abs/2104.09864))\n", + "\n", + "$$\\Theta = \\left\\{\\theta_i = \\text{base}^{\\frac{2(i-1)}{d}}, i \\in \\left[1, 2, ..., d/2\\right]\\right\\}$$\n", + "\n", + "- These $\\theta$ values are a set of predefined parameters that are used to determine the rotational angles in the rotary matrix, where $d$ is the dimensionality of the embedding space\n", + "- Increasing the base from 10,000 to 50,000 makes the frequencies (or rotation angles) decay more slowly across the dimensions, which means that higher dimensions will be associated with larger angles than before (essentially, it's a decompression of the frequencies)\n", + "- In addition, we introduce a `freq_config` section in the code below that adjusts the frequency; however, we won't be needing it in Llama 3 (only Llama 3.1 and Llama 3.2), so we will revisit this `freq_config` later (it's set to `None` and ignored by default)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "6Upl109OOAcu", + "metadata": { + "id": "6Upl109OOAcu" + }, + "outputs": [], + "source": [ + "import torch\n", + "\n", + "def precompute_rope_params(head_dim, theta_base=10000, context_length=4096, freq_config=None):\n", + " assert head_dim % 2 == 0, \"Embedding dimension must be even\"\n", + "\n", + " # Compute the inverse frequencies\n", + " inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim // 2) / (head_dim // 2)))\n", + "\n", + " ################################ NEW ###############################################\n", + " # Frequency adjustments\n", + " if freq_config is not None:\n", + " low_freq_wavelen = freq_config[\"original_context_length\"] / freq_config[\"low_freq_factor\"]\n", + " high_freq_wavelen = freq_config[\"original_context_length\"] / freq_config[\"high_freq_factor\"]\n", + "\n", + " wavelen = 2 * torch.pi / inv_freq\n", + "\n", + " inv_freq_llama = torch.where(\n", + " wavelen > low_freq_wavelen, inv_freq / freq_config[\"factor\"], inv_freq\n", + " )\n", + "\n", + " smooth_factor = (freq_config[\"original_context_length\"] / wavelen - freq_config[\"low_freq_factor\"]) / (\n", + " freq_config[\"high_freq_factor\"] - freq_config[\"low_freq_factor\"]\n", + " )\n", + "\n", + " smoothed_inv_freq = (\n", + " (1 - smooth_factor) * (inv_freq / freq_config[\"factor\"]) + smooth_factor * inv_freq\n", + " )\n", + "\n", + " is_medium_freq = (wavelen <= low_freq_wavelen) & (wavelen >= high_freq_wavelen)\n", + " inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)\n", + " inv_freq = inv_freq_llama\n", + " ####################################################################################\n", + "\n", + "\n", + " # Generate position indices\n", + " positions = torch.arange(context_length)\n", + "\n", + " # Compute the angles\n", + " angles = positions[:, None] * inv_freq[None, :] # Shape: (context_length, head_dim // 2)\n", + "\n", + " # Expand angles to match the head_dim\n", + " angles = torch.cat([angles, angles], dim=1) # Shape: (context_length, head_dim)\n", + "\n", + " # Precompute sine and cosine\n", + " cos = 
torch.cos(angles)\n", + " sin = torch.sin(angles)\n", + "\n", + " return cos, sin" + ] + }, + { + "cell_type": "markdown", + "id": "jJBvO0YMJBXR", + "metadata": { + "id": "jJBvO0YMJBXR" + }, + "source": [ + "- To summarize, what's new so far for Llama 3 compared to Llama 2 are the context length and theta base parameter:" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "56c37216-e022-4603-be16-f9d3eaeaf4a1", + "metadata": { + "id": "56c37216-e022-4603-be16-f9d3eaeaf4a1" + }, + "outputs": [], + "source": [ + "# Instantiate RoPE parameters\n", + "\n", + "llama_2_context_len = 4096\n", + "llama_3_context_len = 8192\n", + "\n", + "llama_2_theta_base = 10_000\n", + "llama_3_theta_base = 50_000" + ] + }, + { + "cell_type": "markdown", + "id": "_V8v6i7MJItU", + "metadata": { + "id": "_V8v6i7MJItU" + }, + "source": [ + "- The usage remains the same as before in Llama 2:" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "dae70c8a-eb18-40f9-a2e5-a6af2a57628b", + "metadata": { + "id": "dae70c8a-eb18-40f9-a2e5-a6af2a57628b" + }, + "outputs": [], + "source": [ + "# Settings\n", + "batch_size = 2\n", + "num_heads = 4\n", + "head_dim = 16\n", + "\n", + "# Instantiate RoPE parameters\n", + "cos, sin = precompute_rope_params(\n", + " head_dim=head_dim,\n", + " theta_base=llama_3_theta_base,\n", + " context_length=llama_3_context_len\n", + ")\n", + "\n", + "# Dummy query and key tensors\n", + "torch.manual_seed(123)\n", + "queries = torch.randn(batch_size, llama_3_context_len, num_heads, head_dim)\n", + "keys = torch.randn(batch_size, llama_3_context_len, num_heads, head_dim)\n", + "\n", + "# Apply rotary position embeddings\n", + "queries_rot = compute_rope(queries, cos, sin)\n", + "keys_rot = compute_rope(keys, cos, sin)" + ] + }, + { + "cell_type": "markdown", + "id": "cd19b75c-cf25-47b8-a010-6733fc0e9a8a", + "metadata": { + "id": "cd19b75c-cf25-47b8-a010-6733fc0e9a8a" + }, + "source": [ + " \n", + "## 1.3 Grouped-query attention" + ] + }, + { + "cell_type": "markdown", + "id": "111c7d3f-fded-49e8-a617-9fe67b81dddc", + "metadata": { + "id": "111c7d3f-fded-49e8-a617-9fe67b81dddc" + }, + "source": [ + "- In this section, we replace multi-head attention (MHA) with an alternative mechanism called grouped-query attention (GQA)\n", + "- In short, one can think of GQA as a more compute- and parameter-efficient version of MHA\n", + "- In GQA, we reduce the number of key and value projections by sharing them among multiple attention heads\n", + "- Each attention head still has its unique query, but these queries attend to the same group of keys and values\n", + "- Below is an illustration of GQA with 2 key-value-groups (kv-groups):\n", + "\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "id": "perAYa2R_KW2", + "metadata": { + "id": "perAYa2R_KW2" + }, + "source": [ + "- The main idea behind GQA is to reduce the number of unique query groups that attend to the key-value pairs, reducing the size of some of the matrix multiplications and the number of parameters in MHA without significantly reducing modeling performance\n", + "- The GQA code is very similar to MHA (I highlighted the changes below via the \"NEW\" sections)\n", + "- In short, the main change in GQA is that each query group needs to be repeated to match the number of heads it is associated with, as implemented below" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "9b12e674-ef08-4dd7-8843-615b65b39c91", + "metadata": { + "id": "9b12e674-ef08-4dd7-8843-615b65b39c91" + }, + "outputs": [], + 
"source": [ + "import torch.nn as nn\n", + "\n", + "class GroupedQueryAttention(nn.Module):\n", + " def __init__(\n", + " self, d_in, d_out, context_length, num_heads,\n", + " num_kv_groups, # NEW\n", + " rope_base=10_000, # NEW\n", + " rope_config=None, # NEW\n", + " dtype=None\n", + " ):\n", + " super().__init__()\n", + " assert d_out % num_heads == 0, \"d_out must be divisible by num_heads\"\n", + " assert num_heads % num_kv_groups == 0, \"num_heads must be divisible by num_kv_groups\"\n", + "\n", + " self.d_out = d_out\n", + " self.num_heads = num_heads\n", + " self.head_dim = d_out // num_heads\n", + "\n", + " ############################# NEW #############################\n", + " # self.W_key = nn.Linear(d_in, d_out, bias=False, dtype=dtype)\n", + " # self.W_value = nn.Linear(d_in, d_out, bias=False, dtype=dtype)\n", + " self.W_key = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype)\n", + " self.W_value = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype)\n", + " self.num_kv_groups = num_kv_groups\n", + " self.group_size = num_heads // num_kv_groups\n", + " ################################################################\n", + "\n", + " self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype)\n", + " self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype)\n", + "\n", + " self.register_buffer(\"mask\", torch.triu(torch.ones(context_length, context_length), diagonal=1))\n", + " cos, sin = precompute_rope_params(\n", + " head_dim=self.head_dim,\n", + " theta_base=rope_base, # NEW\n", + " freq_config=rope_config, # NEW\n", + " context_length=8192\n", + " )\n", + " self.register_buffer(\"cos\", cos)\n", + " self.register_buffer(\"sin\", sin)\n", + "\n", + " def forward(self, x):\n", + " b, num_tokens, d_in = x.shape\n", + "\n", + " queries = self.W_query(x) # Shape: (b, num_tokens, d_out)\n", + " keys = self.W_key(x) # Shape: (b, num_tokens, num_kv_groups * head_dim)\n", + " values = self.W_value(x) # Shape: (b, num_tokens, num_kv_groups * head_dim)\n", + "\n", + " # Reshape queries, keys, and values\n", + " queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)\n", + "\n", + " ##################### NEW #####################\n", + " # keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)\n", + " # values = values.view(b, num_tokens, self.num_heads, self.head_dim)\n", + " keys = keys.view(b, num_tokens, self.num_kv_groups, self.head_dim)\n", + " values = values.view(b, num_tokens, self.num_kv_groups, self.head_dim)\n", + " ################################################\n", + "\n", + " # Transpose keys, values, and queries\n", + " keys = keys.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim)\n", + " values = values.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim)\n", + " queries = queries.transpose(1, 2) # Shape: (b, num_query_groups, num_tokens, head_dim)\n", + "\n", + " # Apply RoPE\n", + " keys = compute_rope(keys, self.cos, self.sin)\n", + " queries = compute_rope(queries, self.cos, self.sin)\n", + "\n", + " ##################### NEW #####################\n", + " # Expand keys and values to match the number of heads\n", + " # Shape: (b, num_heads, num_tokens, head_dim)\n", + "\n", + " keys = keys.repeat_interleave(self.group_size, dim=1) # Shape: (b, num_heads, num_tokens, head_dim)\n", + " values = values.repeat_interleave(self.group_size, dim=1) # Shape: (b, num_heads, num_tokens, head_dim)\n", + " # For example, before repeat_interleave along dim=1 (query groups):\n", + " # 
[K1, K2]\n", + " # After repeat_interleave (each query group is repeated group_size times):\n", + " # [K1, K1, K2, K2]\n", + " # If we used regular repeat instead of repeat_interleave, we'd get:\n", + " # [K1, K2, K1, K2]\n", + " ################################################\n", + "\n", + " # Compute scaled dot-product attention (aka self-attention) with a causal mask\n", + " # Shape: (b, num_heads, num_tokens, num_tokens)\n", + " attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head\n", + "\n", + " # Original mask truncated to the number of tokens and converted to boolean\n", + " mask_bool = self.mask.bool()[:num_tokens, :num_tokens]\n", + "\n", + " # Use the mask to fill attention scores\n", + " attn_scores.masked_fill_(mask_bool, -torch.inf)\n", + "\n", + " attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n", + " assert keys.shape[-1] == self.head_dim\n", + "\n", + " # Shape: (b, num_tokens, num_heads, head_dim)\n", + " context_vec = (attn_weights @ values).transpose(1, 2)\n", + "\n", + " # Combine heads, where self.d_out = self.num_heads * self.head_dim\n", + " context_vec = context_vec.reshape(b, num_tokens, self.d_out)\n", + " context_vec = self.out_proj(context_vec) # optional projection\n", + "\n", + " return context_vec" + ] + }, + { + "cell_type": "markdown", + "id": "roAXSwJs9hR8", + "metadata": { + "id": "roAXSwJs9hR8" + }, + "source": [ + "- To illustrate the parameter savings, consider the following multi-head attention example from the GPT and Llama 2 code:" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "b4b8f085-349e-4674-a3f0-78fde0664fac", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "b4b8f085-349e-4674-a3f0-78fde0664fac", + "outputId": "16b37235-d4a0-41ac-b878-f0d2f9584174" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "W_key: torch.Size([4096, 4096])\n", + "W_value: torch.Size([4096, 4096])\n", + "W_query: torch.Size([4096, 4096])\n" + ] + } + ], + "source": [ + "# Settings\n", + "batch_size = 1\n", + "context_len = 3000\n", + "max_context_len = 8192\n", + "embed_dim = 4096\n", + "num_heads = 32\n", + "\n", + "\n", + "example_batch = torch.randn((batch_size, context_len, embed_dim))\n", + "\n", + "mha = MultiHeadAttention(\n", + " d_in=embed_dim,\n", + " d_out=embed_dim,\n", + " context_length=max_context_len,\n", + " num_heads=num_heads\n", + ")\n", + "\n", + "mha(example_batch)\n", + "\n", + "print(\"W_key:\", mha.W_key.weight.shape)\n", + "print(\"W_value:\", mha.W_value.weight.shape)\n", + "print(\"W_query:\", mha.W_query.weight.shape)" + ] + }, + { + "cell_type": "markdown", + "id": "IMQtFkcQ9sXC", + "metadata": { + "id": "IMQtFkcQ9sXC" + }, + "source": [ + "- Now, if we use grouped-query attention instead, with 8 kv-groups (that's how many Llama 3 8B uses), we can see that the number of rows of the key and value matrices are reduced by a factor of 4 (because 32 attention heads divided by 8 kv-groups is 4)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "15e65d3c-7b42-4ed3-bfee-bb09578657bb", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "15e65d3c-7b42-4ed3-bfee-bb09578657bb", + "outputId": "85432dba-827b-4a27-aedd-63f9c5044352" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "W_key: torch.Size([1024, 4096])\n", + "W_value: torch.Size([1024, 4096])\n", + "W_query: torch.Size([4096, 4096])\n" + ] + } + ], + "source": [ + "gqa = 
GroupedQueryAttention(\n", + " d_in=embed_dim,\n", + " d_out=embed_dim,\n", + " context_length=max_context_len,\n", + " num_heads=num_heads,\n", + " num_kv_groups=8,\n", + " rope_base=llama_3_theta_base\n", + ")\n", + "\n", + "gqa(example_batch)\n", + "\n", + "print(\"W_key:\", gqa.W_key.weight.shape)\n", + "print(\"W_value:\", gqa.W_value.weight.shape)\n", + "print(\"W_query:\", gqa.W_query.weight.shape)" + ] + }, + { + "cell_type": "markdown", + "id": "1a5d4c88-c66a-483b-b4e2-419ff9fd60d5", + "metadata": { + "id": "1a5d4c88-c66a-483b-b4e2-419ff9fd60d5" + }, + "source": [ + "- As a side note, to make the GroupedQueryAttention equivalent to standard multi-head attention, you can set the number of query groups (`num_kv_groups`) equal to the number of heads (`num_heads`)\n", + "- Lastly, let's compare the number of parameters below:" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "58f713aa-ac00-4e2f-8247-94609aa01350", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "58f713aa-ac00-4e2f-8247-94609aa01350", + "outputId": "20caa098-41bd-4572-e8b7-1020073c5912" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Total number of parameters:\n", + "MHA: 67,108,864\n", + "GQA: 41,943,040\n" + ] + } + ], + "source": [ + "print(\"Total number of parameters:\")\n", + "\n", + "mha_total_params = sum(p.numel() for p in mha.parameters())\n", + "print(f\"MHA: {mha_total_params:,}\")\n", + "\n", + "gqa_total_params = sum(p.numel() for p in gqa.parameters())\n", + "print(f\"GQA: {gqa_total_params:,}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "78b60dfd-6c0f-41f7-8f0c-8e57116f07f5", + "metadata": { + "id": "78b60dfd-6c0f-41f7-8f0c-8e57116f07f5" + }, + "outputs": [], + "source": [ + "# Free up memory:\n", + "del mha\n", + "del gqa" + ] + }, + { + "cell_type": "markdown", + "id": "8fcd8802-2859-45a2-905a-f4fe96629dd9", + "metadata": { + "id": "8fcd8802-2859-45a2-905a-f4fe96629dd9" + }, + "source": [ + " \n", + "## 1.4 Update the TransformerBlock module" + ] + }, + { + "cell_type": "markdown", + "id": "KABNccft_YnR", + "metadata": { + "id": "KABNccft_YnR" + }, + "source": [ + "- Next, we update the `TransformerBlock`\n", + "- Here, we simply swap `MultiHeadAttention` with `GroupedQueryAttention` and add the new RoPE settings" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "f9fa8eb4-7196-4dee-aec6-0dcbc70921c4", + "metadata": { + "id": "f9fa8eb4-7196-4dee-aec6-0dcbc70921c4" + }, + "outputs": [], + "source": [ + "class TransformerBlock(nn.Module):\n", + " def __init__(self, cfg):\n", + " super().__init__()\n", + " self.att = GroupedQueryAttention( # MultiHeadAttention(\n", + " d_in=cfg[\"emb_dim\"],\n", + " d_out=cfg[\"emb_dim\"],\n", + " context_length=cfg[\"context_length\"],\n", + " num_heads=cfg[\"n_heads\"],\n", + " num_kv_groups=cfg[\"n_kv_groups\"], # NEW\n", + " rope_base=cfg[\"rope_base\"], # NEW\n", + " rope_config=cfg[\"rope_freq\"], # NEW\n", + " dtype=cfg[\"dtype\"]\n", + " )\n", + " self.ff = FeedForward(cfg)\n", + " self.norm1 = RMSNorm(cfg[\"emb_dim\"], eps=1e-5)\n", + " self.norm2 = RMSNorm(cfg[\"emb_dim\"], eps=1e-5)\n", + "\n", + " def forward(self, x):\n", + " # Shortcut connection for attention block\n", + " shortcut = x\n", + " x = self.norm1(x)\n", + " x = self.att(x.to(torch.bfloat16)) # Shape [batch_size, num_tokens, emb_size]\n", + " x = x + shortcut # Add the original input back\n", + "\n", + " # Shortcut connection for feed-forward block\n", + " 
shortcut = x\n", + " x = self.norm2(x)\n", + " x = self.ff(x.to(torch.bfloat16))\n", + " x = x + shortcut # Add the original input back\n", + "\n", + " return x" + ] + }, + { + "cell_type": "markdown", + "id": "fd921ab5-c48c-4c52-bf41-b847b3b822b9", + "metadata": { + "id": "fd921ab5-c48c-4c52-bf41-b847b3b822b9" + }, + "source": [ + " \n", + "## 1.5 Defining the model class" + ] + }, + { + "cell_type": "markdown", + "id": "M_tLAq_r_llN", + "metadata": { + "id": "M_tLAq_r_llN" + }, + "source": [ + "- When setting up the model class, we fortunately don't have to do much; we just update the name to `Llama3Model`" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "475755d6-01f7-4e6e-ad9a-cec6f031ebf6", + "metadata": { + "id": "475755d6-01f7-4e6e-ad9a-cec6f031ebf6" + }, + "outputs": [], + "source": [ + "# class Llama2Model(nn.Module):\n", + "class Llama3Model(nn.Module):\n", + " def __init__(self, cfg):\n", + " super().__init__()\n", + " self.tok_emb = nn.Embedding(cfg[\"vocab_size\"], cfg[\"emb_dim\"], dtype=cfg[\"dtype\"])\n", + "\n", + " self.trf_blocks = nn.Sequential(\n", + " *[TransformerBlock(cfg) for _ in range(cfg[\"n_layers\"])])\n", + "\n", + " self.final_norm = RMSNorm(cfg[\"emb_dim\"], eps=1e-5)\n", + " self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n", + "\n", + " def forward(self, in_idx):\n", + " batch_size, seq_len = in_idx.shape\n", + " tok_embeds = self.tok_emb(in_idx)\n", + " x = tok_embeds\n", + " x = self.trf_blocks(x)\n", + " x = self.final_norm(x)\n", + " logits = self.out_head(x.to(torch.bfloat16))\n", + " return logits" + ] + }, + { + "cell_type": "markdown", + "id": "4bc94940-aaeb-45b9-9399-3a69b8043e60", + "metadata": { + "id": "4bc94940-aaeb-45b9-9399-3a69b8043e60" + }, + "source": [ + " \n", + "## 2. 
Initialize model" + ] + }, + { + "cell_type": "markdown", + "id": "HoGGRAGykQTE", + "metadata": { + "id": "HoGGRAGykQTE" + }, + "source": [ + "- Now we can define a Llama 3 config file (the Llama 2 config file is shown for comparison)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "e0564727-2d35-4f0c-b0fc-cde1e9134a18", + "metadata": { + "id": "e0564727-2d35-4f0c-b0fc-cde1e9134a18" + }, + "outputs": [], + "source": [ + "LLAMA2_CONFIG_7B = {\n", + " \"vocab_size\": 32_000, # Vocabulary size\n", + " \"context_length\": 4096, # Context length\n", + " \"emb_dim\": 4096, # Embedding dimension\n", + " \"n_heads\": 32, # Number of attention heads\n", + " \"n_layers\": 32, # Number of layers\n", + " \"hidden_dim\": 11_008, # Size of the intermediate dimension in FeedForward\n", + " \"dtype\": torch.bfloat16 # Lower-precision dtype to save memory\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "2ad90f82-15c7-4806-b509-e45b56f57db5", + "metadata": { + "id": "2ad90f82-15c7-4806-b509-e45b56f57db5" + }, + "outputs": [], + "source": [ + "LLAMA3_CONFIG_8B = {\n", + " \"vocab_size\": 128_256, # NEW: Larger vocabulary size\n", + " \"context_length\": 8192, # NEW: Larger context length\n", + " \"emb_dim\": 4096, # Embedding dimension\n", + " \"n_heads\": 32, # Number of attention heads\n", + " \"n_layers\": 32, # Number of layers\n", + " \"hidden_dim\": 14_336, # NEW: Larger size of the intermediate dimension in FeedForward\n", + " \"n_kv_groups\": 8, # NEW: Key-Value groups for grouped-query attention\n", + " \"rope_base\": 50_000, # NEW: The base in RoPE's \"theta\" was increased to 50_000\n", + " \"rope_freq\": None, # NEW: Additional configuration for adjusting the RoPE frequencies\n", + " \"dtype\": torch.bfloat16 # Lower-precision dtype to save memory\n", + "}" + ] + }, + { + "cell_type": "markdown", + "id": "FAP7fiBzkaBz", + "metadata": { + "id": "FAP7fiBzkaBz" + }, + "source": [ + "- Using these settings, we can now initialize a Llama 3 8B model\n", + "- Note that this requires ~34 GB of memory (for comparison, Llama 2 7B required ~26 GB of memory)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "7004d785-ac9a-4df5-8760-6807fc604686", + "metadata": { + "id": "7004d785-ac9a-4df5-8760-6807fc604686" + }, + "outputs": [], + "source": [ + "model = Llama3Model(LLAMA3_CONFIG_8B)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "6079f747-8f20-4c6b-8d38-7156f1101729", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "6079f747-8f20-4c6b-8d38-7156f1101729", + "outputId": "8ce6476f-ea77-4513-d31a-3d3cdffa3044" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Total number of parameters: 8,030,261,248\n" + ] + } + ], + "source": [ + "total_params = sum(p.numel() for p in model.parameters())\n", + "print(f\"Total number of parameters: {total_params:,}\")" + ] + }, + { + "cell_type": "markdown", + "id": "Bx14NtzWk2wj", + "metadata": { + "id": "Bx14NtzWk2wj" + }, + "source": [ + "- As shown above, the model contains 8 billion parameters\n", + "- Additionally, we can calculate the memory requirements for this model using the code below:" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "0df1c79e-27a7-4b0f-ba4e-167fe107125a", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "0df1c79e-27a7-4b0f-ba4e-167fe107125a", + "outputId": "13683a74-017a-41d1-a49f-c264f899c4cc" + }, + "outputs": [ + { + "name": 
"stdout", + "output_type": "stream", + "text": [ + "float32 (PyTorch default): 68.08 GB\n", + "bfloat16: 34.04 GB\n" + ] + } + ], + "source": [ + "def model_memory_size(model, input_dtype=torch.float32):\n", + " total_params = 0\n", + " total_grads = 0\n", + " for param in model.parameters():\n", + " # Calculate total number of elements per parameter\n", + " param_size = param.numel()\n", + " total_params += param_size\n", + " # Check if gradients are stored for this parameter\n", + " if param.requires_grad:\n", + " total_grads += param_size\n", + "\n", + " # Calculate buffer size (non-parameters that require memory)\n", + " total_buffers = sum(buf.numel() for buf in model.buffers())\n", + "\n", + " # Size in bytes = (Number of elements) * (Size of each element in bytes)\n", + " # We assume parameters and gradients are stored in the same type as input dtype\n", + " element_size = torch.tensor(0, dtype=input_dtype).element_size()\n", + " total_memory_bytes = (total_params + total_grads + total_buffers) * element_size\n", + "\n", + " # Convert bytes to gigabytes\n", + " total_memory_gb = total_memory_bytes / (1024**3)\n", + "\n", + " return total_memory_gb\n", + "\n", + "print(f\"float32 (PyTorch default): {model_memory_size(model, input_dtype=torch.float32):.2f} GB\")\n", + "print(f\"bfloat16: {model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB\")" + ] + }, + { + "cell_type": "markdown", + "id": "zudd-5PulKFL", + "metadata": { + "id": "zudd-5PulKFL" + }, + "source": [ + "- Lastly, we can also transfer the model to an NVIDIA or Apple Silicon GPU if applicable:" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "a4c50e19-1402-45b6-8ccd-9077b2ba836d", + "metadata": { + "id": "a4c50e19-1402-45b6-8ccd-9077b2ba836d" + }, + "outputs": [], + "source": [ + "if torch.cuda.is_available():\n", + " device = torch.device(\"cuda\")\n", + "elif torch.backends.mps.is_available():\n", + " device = torch.device(\"mps\")\n", + "else:\n", + " device = torch.device(\"cpu\")\n", + "\n", + "model.to(device);" + ] + }, + { + "cell_type": "markdown", + "id": "5dc64a06-27dc-46ec-9e6d-1700a8227d34", + "metadata": { + "id": "5dc64a06-27dc-46ec-9e6d-1700a8227d34" + }, + "source": [ + " \n", + "## 3. 
Load tokenizer" + ] + }, + { + "cell_type": "markdown", + "id": "0eb30f0c-6144-4bed-87d9-6b2bac377005", + "metadata": { + "id": "0eb30f0c-6144-4bed-87d9-6b2bac377005" + }, + "source": [ + "- In this section, we are going to load the tokenizer for the model\n", + "- Llama 2 used Google's [SentencePiece](https://github.com/google/sentencepiece) tokenizer instead of OpenAI's BPE tokenizer based on the [Tiktoken](https://github.com/openai/tiktoken) library\n", + "- Llama 3, however, reverted back to using the BPE tokenizer from Tiktoken; specifically, it uses the GPT-4 tokenizer with an extended vocabulary\n", + "- You can find the original Tiktoken-adaptation by Meta AI [here](https://github.com/meta-llama/llama3/blob/main/llama/tokenizer.py) in their official Llama 3 repository\n", + "- Below, I rewrote the tokenizer code to make it more readable and minimal for this notebook (but the behavior should be similar)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "5f390cbf-8f92-46dc-afe3-d90b5affae10", + "metadata": { + "id": "5f390cbf-8f92-46dc-afe3-d90b5affae10" + }, + "outputs": [], + "source": [ + "import os\n", + "from pathlib import Path\n", + "\n", + "import tiktoken\n", + "from tiktoken.load import load_tiktoken_bpe\n", + "\n", + "\n", + "class Tokenizer:\n", + " def __init__(self, model_path):\n", + " assert os.path.isfile(model_path), f\"Model file {model_path} not found\"\n", + " mergeable_ranks = load_tiktoken_bpe(model_path)\n", + " num_base_tokens = len(mergeable_ranks)\n", + "\n", + " self.special_tokens = {\n", + " \"<|begin_of_text|>\": 128000,\n", + " \"<|end_of_text|>\": 128001,\n", + " \"<|start_header_id|>\": 128006,\n", + " \"<|end_header_id|>\": 128007,\n", + " \"<|eot_id|>\": 128009,\n", + " }\n", + " self.special_tokens.update({\n", + " f\"<|reserved_{i}|>\": 128002 + i for i in range(256) if (128002 + i) not in self.special_tokens.values()\n", + " })\n", + "\n", + " self.model = tiktoken.Encoding(\n", + " name=Path(model_path).name,\n", + " pat_str=r\"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+\",\n", + " mergeable_ranks=mergeable_ranks,\n", + " special_tokens=self.special_tokens\n", + " )\n", + "\n", + "\n", + " def encode(self, text, bos=False, eos=False, allowed_special=set(), disallowed_special=()):\n", + " if bos:\n", + " tokens = [self.special_tokens[\"<|begin_of_text|>\"]]\n", + " else:\n", + " tokens = []\n", + "\n", + " tokens += self.model.encode(text, allowed_special=allowed_special, disallowed_special=disallowed_special)\n", + "\n", + " if eos:\n", + " tokens.append(self.special_tokens[\"<|end_of_text|>\"])\n", + " return tokens\n", + "\n", + " def decode(self, tokens):\n", + " return self.model.decode(tokens)" + ] + }, + { + "cell_type": "markdown", + "id": "0a1509f8-8778-4fec-ba32-14d95c646167", + "metadata": { + "id": "0a1509f8-8778-4fec-ba32-14d95c646167" + }, + "source": [ + "- Meta AI shared the original Llama 3 model weights and tokenizer vocabulary on the Hugging Face Hub\n", + "- We will first download the tokenizer vocabulary from the Hub and load it into the code above" + ] + }, + { + "cell_type": "markdown", + "id": "KbnlzsbYmJU6", + "metadata": { + "id": "KbnlzsbYmJU6" + }, + "source": [ + "- Please note that Meta AI requires that you accept the Llama 3 licensing terms before you can download the files; to do this, you have to create a Hugging Face Hub account and visit the 
[meta-llama/Meta-Llama-3-8B](https://huggingface.co/meta-llama/Meta-Llama-3-8B) repository to accept the terms\n", + "- Next, you will need to create an access token; to generate an access token with READ permissions, click on the profile picture in the upper right and click on \"Settings\"\n", + "\n", + "\n", + "\n", + "\n", + "- Then, create and copy the access token so you can copy & paste it into the next code cell\n", + "\n", + "" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "3357a230-b678-4691-a238-257ee4e80185", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "3357a230-b678-4691-a238-257ee4e80185", + "outputId": "57472e6e-95c4-4b41-bc18-231f6ff69e95" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.\n", + "Token is valid (permission: read).\n", + "Your token has been saved to /root/.cache/huggingface/token\n", + "Login successful\n" + ] + } + ], + "source": [ + "from huggingface_hub import login\n", + "import json\n", + "\n", + "with open(\"config.json\", \"r\") as config_file:\n", + " config = json.load(config_file)\n", + " access_token = config[\"HF_ACCESS_TOKEN\"]\n", + "\n", + "login(token=access_token)" + ] + }, + { + "cell_type": "markdown", + "id": "IxGh6ZYQo0VN", + "metadata": { + "id": "IxGh6ZYQo0VN" + }, + "source": [ + "- After login via the access token, which is necessary to verify that we accepted the Llama 3 licensing terms, we can now download the tokenizer vocabulary:" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "69714ea8-b9b8-4687-8392-f3abb8f93a32", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 153, + "referenced_widgets": [ + "dd1779b5e0484f0c9c72af34a6a3e638", + "2dc88f14cf83432fbfb62c914a40a9d3", + "f42effc8bf4b443eba7d108b69d4d417", + "3ea3c0f23f1746ce82685c92056ee83d", + "cc6d7bc9b1034e208d14ef0a2e2766cd", + "0f6cd37c1bf14d32922d1f24fe57f895", + "5cad22d53fe34dc4af4d6a2bcc0f3081", + "80615905cbd8495dbe72924048de5fec", + "9c4420d3100440f1bf217d30b5ef74c5", + "e98c47789e4d43e9950f6c496dc8ccea", + "0f61d20b92c54ce2843790b4acbd49b5" + ] + }, + "id": "69714ea8-b9b8-4687-8392-f3abb8f93a32", + "outputId": "13ad9048-e4b3-4232-f2e3-ce9546738af0" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_token.py:89: UserWarning: \n", + "The secret `HF_TOKEN` does not exist in your Colab secrets.\n", + "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n", + "You will be able to reuse this secret in all of your notebooks.\n", + "Please note that authentication is recommended but still optional to access public models or datasets.\n", + " warnings.warn(\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "dd1779b5e0484f0c9c72af34a6a3e638", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "tokenizer.model: 0%| | 0.00/2.18M [00:00',\n", + "\n" + ] + } + ], + "source": [ + "from previous_chapters import generate, text_to_token_ids, token_ids_to_text\n", + "\n", + "\n", + "torch.manual_seed(123)\n", + "\n", + "token_ids = 
generate(\n", + " model=model,\n", + " idx=text_to_token_ids(\"Every effort\", tokenizer).to(device),\n", + " max_new_tokens=30,\n", + " context_size=LLAMA3_CONFIG_8B[\"context_length\"],\n", + " top_k=1,\n", + " temperature=0.\n", + ")\n", + "\n", + "print(\"Output text:\\n\", token_ids_to_text(token_ids, tokenizer))" + ] + }, + { + "cell_type": "markdown", + "id": "93WTtAA5paYV", + "metadata": { + "id": "93WTtAA5paYV" + }, + "source": [ + "- Of course, as we can see above, the text is nonsensical since we haven't trained the Llama 3 model yet\n", + "- In the next section, instead of training it ourselves, which would cost tens to hundreds of thousands of dollars, we load the pretrained weights from Meta AI" + ] + }, + { + "cell_type": "markdown", + "id": "f63cc248-1d27-4eb6-aa50-173b436652f8", + "metadata": { + "id": "f63cc248-1d27-4eb6-aa50-173b436652f8" + }, + "source": [ + " \n", + "## 4. Load pretrained weights" + ] + }, + { + "cell_type": "markdown", + "id": "aKeN7rUfqZMI", + "metadata": { + "id": "aKeN7rUfqZMI" + }, + "source": [ + "- We are loading the [\"meta-llama/Meta-Llama-3-8B\"](https://huggingface.co/meta-llama/Meta-Llama-3-8B) base model below, which is a simple text completion model before finetuning\n", + "- Alternatively, you can load the instruction-finetuned and aligned [\"meta-llama/Meta-Llama-3-8B-Instruct\"](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) model by modifying the string in the next code cell accordingly\n", + "- Combined, the weight files are about 16 GB large" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "5fa9c06c-7a53-4b4d-9ce4-acc027322ee4", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 145, + "referenced_widgets": [ + "bf1ddeaf6985478ca466b71e4724a0f9", + "da547372b7024322a3ff455757ee264d", + "5a3b5c70adf444908b29f8e206986b07", + "a471de0efd7b439bb809e80a95f39b35", + "b891846006734ff79b3f1b2306b0d1df", + "23ad2086ff48400aae0b3d9061cda257", + "7e4ece38b0034c8496fe7bc4ed5eba85", + "fc68351e057d47b3a0fb3bf7d0304a91", + "b424e3df452d4d6095549a2e2e3e7840", + "c3b8c70e6907463aaa21a055dbbf0487", + "88478778657542babae7a4be9018b6e5", + "058766ddf03a4148a685170ccaea1831", + "acd39bf9aa0e43afbfd39cf21e21c31a", + "beec9356d8594f1fa1dc97d239ffbbb8", + "965cf0c53bf74930bf967fbb157cc1b5", + "54b34287f3714003b86948d13a076cc4", + "9d08e6234835499db4ba81f52b57fe92", + "3ce254c8287e4c7c8bfbb7e2d36cb781", + "1e4ca258ea624bb59a07d1f0e14c0bd4", + "f3963a4f634b4ef3aed7c9eaeabca281", + "cba668f75df04ef9b844d6f16ea0d1a5", + "e5ae5853ccb440cc9be2eeebcfcdec7a", + "6a1b639c131a4f3383b33d7f542b558d", + "654b9359234f45d6810ef319119acc2e", + "8925c4a2fb0a451e8865a1f1319a2ecd", + "4bd434edb0c84777b7d8893c9525d9e9", + "a11b4e04e0234c1ab6b6fbccde598195", + "11d38236c32140c296665a41107e2a77", + "f15678cd373f469ba9e9fa3b09a790f2", + "8e86af3182ea4f9c9fa9a18b4a17195b", + "9e26bbe275ee49aab7af5916a40c6ba2", + "b0b342edc852407ca06d65684abfd81c", + "c6dacb418b0b4bf3aa9eecffb380f44c", + "aaecfb3b66644f0982be5fb4d27dd484", + "d13023e5cc564765bbbba2f0908c4850", + "dd58c67ccd42464d9300e9b97432230a", + "8bac15c0853b4cb286da885bec977533", + "f67ec3c99d9243f5a02c2ddebef6ea14", + "fcaec6bf58164472a7a193c0b16c0eb6", + "812418ee5e324ae7a2fb9e3b7c34693d", + "d748373c23c249d1843f77e56955f5e2", + "f93e1328bf9b424e98bbcd46792efb51", + "8fbc37f0a5804f8bb790bd82f95e7dd3", + "63113c4df8e6413f82bfa8ccb1cfa78d" + ] + }, + "id": "5fa9c06c-7a53-4b4d-9ce4-acc027322ee4", + "outputId": "3b5c3da6-8fe3-4654-e8bb-f8b55014a1e8" + }, + 
"outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "bf1ddeaf6985478ca466b71e4724a0f9", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "model-00001-of-00004.safetensors: 0%| | 0.00/4.98G [00:00\"])\n", + " tokens.extend(self.tokenizer.encode(message[\"role\"], bos=False, eos=False))\n", + " tokens.append(self.tokenizer.special_tokens[\"<|end_header_id|>\"])\n", + " tokens.extend(self.tokenizer.encode(\"\\n\\n\", bos=False, eos=False))\n", + " return tokens\n", + "\n", + " def encode(self, text):\n", + " message = {\n", + " \"role\": \"user\",\n", + " \"content\": text\n", + " }\n", + "\n", + " tokens = self.encode_header(message)\n", + " tokens.extend(\n", + " self.tokenizer.encode(message[\"content\"].strip(), bos=False, eos=False)\n", + " )\n", + " tokens.append(self.tokenizer.special_tokens[\"<|eot_id|>\"])\n", + " return tokens\n", + "\n", + " def decode(self, token_ids):\n", + " return self.tokenizer.decode(token_ids)\n", + "\n", + "\n", + "chat_tokenizer = ChatFormat(tokenizer)" + ] + }, + { + "cell_type": "markdown", + "id": "M-dkSNvwDttN", + "metadata": { + "id": "M-dkSNvwDttN" + }, + "source": [ + "- The usage is as follows:" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "id": "nwBrTGTsUNhn", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "nwBrTGTsUNhn", + "outputId": "587b5259-3124-46de-b13f-ef2d1662026d" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[128006, 882, 128007, 271, 9906, 4435, 0, 128009]\n" + ] + } + ], + "source": [ + "token_ids = chat_tokenizer.encode(\"Hello World!\")\n", + "print(token_ids)" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "id": "0fpmpVgYVTRZ", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 36 + }, + "id": "0fpmpVgYVTRZ", + "outputId": "fb24c3f2-da13-4429-e5c3-13e016d51eac" + }, + "outputs": [ + { + "data": { + "application/vnd.google.colaboratory.intrinsic+json": { + "type": "string" + }, + "text/plain": [ + "'<|start_header_id|>user<|end_header_id|>\\n\\nHello World!<|eot_id|>'" + ] + }, + "execution_count": 35, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tokenizer.decode(token_ids)" + ] + }, + { + "cell_type": "markdown", + "id": "Wo-aUGeKDvqq", + "metadata": { + "id": "Wo-aUGeKDvqq" + }, + "source": [ + "- Let's now see the Llama 3 instruction model in action:" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "id": "ozGOBu6XOkEW", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "ozGOBu6XOkEW", + "outputId": "bae98ed7-0e6b-439d-978e-9c48710b7e25" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Output text:\n", + " Llamas are herbivores, which means they primarily eat plants and plant-based foods. Here are some of the things llamas like to eat:\n", + "\n", + "1. Grass: Llamas love to graze on grass, especially in the spring and summer months.\n", + "2. Hay: Hay is a staple in a llama's diet. They like to eat timothy hay, alfalfa hay, and other types of hay.\n", + "3. Grains: Llamas may also be fed grains like oats, barley, and corn. However, grains should not make up more than 10% of a llama's diet.\n", + "4. 
Fruits and vegetables: Llamas may enjoy fruits and vegetables as treats, such as apples,\n" + ] + } + ], + "source": [ + "import re\n", + "\n", + "\n", + "torch.manual_seed(123)\n", + "\n", + "token_ids = generate(\n", + " model=model,\n", + " idx=text_to_token_ids(\"What do llamas eat?\", chat_tokenizer).to(device),\n", + " max_new_tokens=150,\n", + " context_size=LLAMA3_CONFIG_8B[\"context_length\"],\n", + " top_k=1,\n", + " temperature=0.\n", + ")\n", + "\n", + "output_text = token_ids_to_text(token_ids, tokenizer)\n", + "\n", + "\n", + "def clean_text(text, header_end=\"assistant<|end_header_id|>\\n\\n\"):\n", + " # Find the index of the first occurrence of \"<|end_header_id|>\"\n", + " index = text.find(header_end)\n", + "\n", + " if index != -1:\n", + " # Return the substring starting after \"<|end_header_id|>\"\n", + " return text[index + len(header_end):].strip() # Strip removes leading/trailing whitespace\n", + " else:\n", + " # If the token is not found, return the original text\n", + " return text\n", + "\n", + "print(\"Output text:\\n\", clean_text(output_text))" + ] + }, + { + "cell_type": "markdown", + "id": "2r5JKrO-ZOHK", + "metadata": { + "id": "2r5JKrO-ZOHK" + }, + "source": [ + " \n", + "# Llama 3.1 8B" + ] + }, + { + "cell_type": "markdown", + "id": "QiQxX0XnP_iC", + "metadata": { + "id": "QiQxX0XnP_iC" + }, + "source": [ + "- A few months after the initial Llama 3 release, Meta AI followed up with their Llama 3.1 suite of models (see the official [Introducing Llama 3.1: Our most capable models to date](https://ai.meta.com/blog/meta-llama-3-1/) announcement blog post for details)\n", + "- Conveniently, we can reuse our previous Llama 3 code from above to implement Llama 3.1 8B\n", + "\n", + "\n", + "\n", + "- The architecture is identical, with the only change being a rescaling of the RoPE frequencies as indicated in the configuration file below\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "id": "X5Fg8XUHMv4M", + "metadata": { + "id": "X5Fg8XUHMv4M" + }, + "outputs": [], + "source": [ + "LLAMA3_CONFIG_8B = {\n", + " \"vocab_size\": 128_256, # Vocabulary size\n", + " \"context_length\": 8192, # Context length\n", + " \"emb_dim\": 4096, # Embedding dimension\n", + " \"n_heads\": 32, # Number of attention heads\n", + " \"n_layers\": 32, # Number of layers\n", + " \"hidden_dim\": 14_336, # Size of the intermediate dimension in FeedForward\n", + " \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n", + " \"rope_base\": 50_000, # The base in RoPE's \"theta\"\n", + " \"rope_freq\": None, # Additional configuration for adjusting the RoPE frequencies\n", + " \"dtype\": torch.bfloat16 # Lower-precision dtype to save memory\n", + "}\n", + "\n", + "LLAMA31_CONFIG_8B = {\n", + " \"vocab_size\": 128_256, # Vocabulary size\n", + " \"context_length\": 8192, # Context length\n", + " \"emb_dim\": 4096, # Embedding dimension\n", + " \"n_heads\": 32, # Number of attention heads\n", + " \"n_layers\": 32, # Number of layers\n", + " \"hidden_dim\": 14_336, # Size of the intermediate dimension in FeedForward\n", + " \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n", + " \"rope_base\": 50_000, # The base in RoPE's \"theta\"\n", + " \"dtype\": torch.bfloat16, # Lower-precision dtype to save memory\n", + " \"rope_freq\": { # NEW: RoPE frequency scaling\n", + " \"factor\": 8.0,\n", + " \"low_freq_factor\": 1.0,\n", + " \"high_freq_factor\": 4.0,\n", + " \"original_context_length\": 8192,\n", + " }\n", + "}" + ] + }, + { + 
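"cell_type": "markdown",
+   "id": "rope-freq-sanity-check-md",
+   "metadata": {},
+   "source": [
+    "- As an optional sanity check, the following cell is a small sketch (with a dummy `head_dim` and illustrative variable names) that reuses the `precompute_rope_params` function from section 1.2 to show that the `rope_freq` dictionary rescales the precomputed RoPE angles for the low-frequency components"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "rope-freq-sanity-check-code",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Sketch: compare the precomputed RoPE angles with and without the Llama 3.1 frequency scaling\n",
+    "head_dim = 16  # small dummy value for illustration only\n",
+    "\n",
+    "cos_plain, sin_plain = precompute_rope_params(\n",
+    "    head_dim=head_dim,\n",
+    "    theta_base=LLAMA31_CONFIG_8B[\"rope_base\"],\n",
+    "    context_length=LLAMA31_CONFIG_8B[\"context_length\"],\n",
+    "    freq_config=None\n",
+    ")\n",
+    "\n",
+    "cos_scaled, sin_scaled = precompute_rope_params(\n",
+    "    head_dim=head_dim,\n",
+    "    theta_base=LLAMA31_CONFIG_8B[\"rope_base\"],\n",
+    "    context_length=LLAMA31_CONFIG_8B[\"context_length\"],\n",
+    "    freq_config=LLAMA31_CONFIG_8B[\"rope_freq\"]\n",
+    ")\n",
+    "\n",
+    "# The low-frequency angles get rescaled, so this is expected to print False\n",
+    "print(torch.equal(cos_plain, cos_scaled))"
+   ]
+  },
+  {
+   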
"cell_type": "markdown", + "id": "xa3bpMDtTdBs", + "metadata": { + "id": "xa3bpMDtTdBs" + }, + "source": [ + "- As we've seen in the code earlier, the RoPE method uses sinusoidal functions (sine and cosine) to embed positional information directly into the attention mechanism\n", + "- In Llama 3.1, via the additional configuration, we introduce additional adjustments to the inverse frequency calculations\n", + "- These adjustments influence how different frequency components contribute to the positional embeddings (a detailed explanation is a topic for another time)\n", + "- Let's try out the Llama 3.1 model in practice; first, we clear out the old model to free up some GPU memory" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "id": "7dUtYnNUOqhL", + "metadata": { + "id": "7dUtYnNUOqhL" + }, + "outputs": [], + "source": [ + "# free up memory\n", + "del model\n", + "\n", + "gc.collect() # Run Python garbage collector\n", + "\n", + "if torch.cuda.is_available():\n", + " torch.cuda.empty_cache()" + ] + }, + { + "cell_type": "markdown", + "id": "DbbVsll6TYWR", + "metadata": { + "id": "DbbVsll6TYWR" + }, + "source": [ + "- Next, we download the tokenizer\n", + "- Note that since the Llama 3.1 family is distinct from the Llama 3 family, you'd have to go to the [meta-llama/Llama-3.1-8B](https://huggingface.co/meta-llama/Llama-3.1-8B) repository and acknowledge the license terms for your Hugging Face access token to work for the download\n", + "- Tip: For simplicity, we only load the base model below, but there's also an instruction-finetuned version you can use by replacing `\"meta-llama/Llama-3.1-8B\"` with `\"meta-llama/Llama-3.1-8B-Instruct\"`" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "id": "8xDk4chtPNU4", + "metadata": { + "id": "8xDk4chtPNU4" + }, + "outputs": [], + "source": [ + "tokenizer_file_path = hf_hub_download(\n", + " repo_id=\"meta-llama/Llama-3.1-8B\",\n", + " filename=\"original/tokenizer.model\",\n", + " local_dir=\"llama3-files\"\n", + ")\n", + "\n", + "tokenizer = Tokenizer(tokenizer_file_path)" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "id": "a7l21VE4Otcs", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "a7l21VE4Otcs", + "outputId": "ea31712b-6678-4ad3-d246-a7e1ad7cbd66" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Total number of parameters: 8,030,261,248\n" + ] + } + ], + "source": [ + "model = Llama3Model(LLAMA31_CONFIG_8B)\n", + "\n", + "total_params = sum(p.numel() for p in model.parameters())\n", + "print(f\"Total number of parameters: {total_params:,}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "id": "u4J7IxOvOyPM", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 145, + "referenced_widgets": [ + "f5c112068f5a4850b6291fe85fa6ad7d", + "feedc809874345bd8e7954ea79c67045", + "2e90cedb16f3482ab0e8e59a2883d559", + "abda6f03955d4f019739dc2b06c3fe10", + "155961a28bea4e29b7e983b22570344f", + "83d5bfe6c1d14d799306449b82d2aa05", + "847bebf73639433e93c64030591fe7cb", + "6d1b30c011934ede879f32c3ab8b259f", + "54ea67e62e9a4cd78fefe661450c3389", + "a2f70ad00c844a99b8c18707d30e670c", + "054169cdf794410b8c00cc059c2cfdd3", + "05dd05934e164c159bd73cc015ba2d39", + "cfb94182fcaf4329bb3fa18241943fba", + "5c829a299bf84a73851449c399f44f6d", + "b196454543b4482884e4c74f2fd3f1ce", + "6cedcdd2b27f4892b865bced95516559", + "88b77252689f46f08d95a19dd405ad19", + "fa51af329d62423d86491cb2ab1b8dad", + 
"356d5a5cfe9a42839339fb59aeeec50f", + "9c5dc483690d45d6ad41922bde16fd02", + "c970327a316d4846a16be7a3ffd96b52", + "1ce3da98ea894053a415a2d23292784e", + "7c55b07c8efa4c2a97ea201f09fbeb56", + "d43de2dba0a84bb781cb79b57947ccb6", + "104c09d32366453eab0e7f18c1e77cc2", + "7de2d78f12974c9f944458d566b55b6e", + "0649cdb8c83c449f8cf9eb5e8d094659", + "2c7cb6e8c3ec4254b5dc5982b161f39d", + "4af4cd15802846f69938d00c80116788", + "f7375354f1044927a0f70dff97c289f9", + "f45f67087839412f8071fdf0fd5a9e77", + "d1a3f516bb0a45a7ab44c0aac4ae23f4", + "e23a141dbe294a7b91112eb0628c3ca6", + "3ec5cf160bd146f397a34d72008efabf", + "c2ea10cf970a4973a30bf716eb6397cf", + "9eeeebd4a17e4031ac51c735e9ab5f67", + "0987a11e19fe4ae286edba3b3f9cb1fa", + "f9ec37bb0a754d5eb20398a7a3f8b3c7", + "c5b17cf09ce4481b9e8b601ca24fe7bc", + "7914c5eae30842b784e3e1f453c26fde", + "e1117dc64c3648629ef7434e2280d1b6", + "5e6b266059d54265b9a62372fac06ba4", + "e121b62fef0f40f686ddfa24f117c9f3", + "987e18e93abd4a1c9eeb4ae6f0f5231c" + ] + }, + "id": "u4J7IxOvOyPM", + "outputId": "8ea92361-66a9-4c45-f1c4-cdc7d322f0ba" + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "f5c112068f5a4850b6291fe85fa6ad7d", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "model-00001-of-00004.safetensors: 0%| | 0.00/4.98G [00:00" + ] + }, + { + "cell_type": "markdown", + "id": "K0KgjwCCJ9Fb", + "metadata": { + "id": "K0KgjwCCJ9Fb" + }, + "source": [ + "- As we can see based on the figure above, the main difference between the Llama 3.1 8B and Llama 3.2 1B architectures are the respective sizes\n", + "- A small additional change is an increased RoPE rescaling factor, which is reflected in the configuration file below" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "id": "Yv_yF3NCQTBx", + "metadata": { + "id": "Yv_yF3NCQTBx" + }, + "outputs": [], + "source": [ + "LLAMA31_CONFIG_8B = {\n", + " \"vocab_size\": 128_256, # Vocabulary size\n", + " \"context_length\": 8192, # Context length\n", + " \"emb_dim\": 4096, # Embedding dimension\n", + " \"n_heads\": 32, # Number of attention heads\n", + " \"n_layers\": 32, # Number of layers\n", + " \"hidden_dim\": 14_336, # Size of the intermediate dimension in FeedForward\n", + " \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n", + " \"rope_base\": 50_000, # The base in RoPE's \"theta\"\n", + " \"dtype\": torch.bfloat16, # Lower-precision dtype to save memory\n", + " \"rope_freq\": { # RoPE frequency scaling\n", + " \"factor\": 8.0,\n", + " \"low_freq_factor\": 1.0,\n", + " \"high_freq_factor\": 4.0,\n", + " \"original_context_length\": 8192,\n", + " }\n", + "}\n", + "\n", + "\n", + "LLAMA32_CONFIG_1B = {\n", + " \"vocab_size\": 128_256, # Vocabulary size\n", + " \"context_length\": 8192, # Context length\n", + " \"emb_dim\": 2048, # NEW: Half the embedding dimension\n", + " \"n_heads\": 32, # Number of attention heads\n", + " \"n_layers\": 16, # NEW: Half the number of layers\n", + " \"hidden_dim\": 8192, # NEW: Almopst half the size of the intermediate dimension in FeedForward\n", + " \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n", + " \"rope_base\": 50_000, # The base in RoPE's \"theta\"\n", + " \"dtype\": torch.bfloat16, # Lower-precision dtype to save memory\n", + " \"rope_freq\": { # RoPE frequency scaling\n", + " \"factor\": 32.0, # NEW: Adjustment of the rescaling factor\n", + " \"low_freq_factor\": 1.0,\n", + " \"high_freq_factor\": 4.0,\n", + " \"original_context_length\": 8192,\n", + " }\n", + "}" + ] 
+ }, + { + "cell_type": "markdown", + "id": "Dl4_0EoJKKYv", + "metadata": { + "id": "Dl4_0EoJKKYv" + }, + "source": [ + "- Below, we can reuse the code from the Llama 3.1 8B section to load the Llama 3.2 1B model\n", + "- Again, since the Llama 3.2 family is distinct from the Llama 3.1 family, you'd have to go to the [meta-llama/Llama-3.2-1B](https://huggingface.co/meta-llama/Llama-3.2-1B) repository and acknowledge the license terms for your Hugging Face access token to work for the download\n", + "- Tip: For simplicity, we only load the base model below, but there's also an instruction-finetuned version you can use by replacing `\"meta-llama/Llama-3.2-1B\"` with `\"meta-llama/Llama-3.2-1B-Instruct\"`" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "id": "tCstHgyRRD2x", + "metadata": { + "id": "tCstHgyRRD2x" + }, + "outputs": [], + "source": [ + "# free up memory\n", + "del model\n", + "\n", + "\n", + "gc.collect() # Run Python garbage collector\n", + "\n", + "if torch.cuda.is_available():\n", + " torch.cuda.empty_cache()" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "id": "jt8BKAHXRCPI", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 49, + "referenced_widgets": [ + "db0b4112ca1c4070b6c08ae77198a07b", + "1795601570fa45599e28f012d7bbfea5", + "ec7f2c55d44e415c95fcdca659567ca2", + "00be0d9cc80e43d78d551b50a08b5bbf", + "bc506e05c702459ead566a98c17f6a34", + "05940cb26ce9448cb38b2e5b29e5f81b", + "f80e2773edfc4adcad5707c52acf3758", + "875451e9744b4bbd8798979adab5c6b7", + "44cd0fe55d62476390444f79b4e6c2b8", + "73228d2d2a6440bf903866f491344a7a", + "e083f2fa6a6a42a78923148a2c5f020f" + ] + }, + "id": "jt8BKAHXRCPI", + "outputId": "2a95b5da-69f5-40ba-c180-27c8150af310" + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "db0b4112ca1c4070b6c08ae77198a07b", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "tokenizer.model: 0%| | 0.00/2.18M [00:00=0.24.7 -sentencepiece>=0.1.99 \ No newline at end of file + modified: ../../README.md + modified: README.md + modified: converting-gpt-to-llama2.ipynb + modified: requirements-extra.txt + modified: ../README.md + +Untracked files: + (use "git add ..." to include in what will be committed) + converting-llama2-to-llama3.ipynb + llama3-files/ + llama32-files/ + standalone-llama32.ipynb diff --git a/ch05/07_gpt_to_llama/standalone-llama32.ipynb b/ch05/07_gpt_to_llama/standalone-llama32.ipynb new file mode 100644 index 0000000..f760eed --- /dev/null +++ b/ch05/07_gpt_to_llama/standalone-llama32.ipynb @@ -0,0 +1,968 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "e1b280ab-b61f-4d1a-bf7e-44e5f9ed3a5c", + "metadata": {}, + "source": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "
\n", + "\n", + "Supplementary code for the Build a Large Language Model From Scratch book by Sebastian Raschka
\n", + "
Code repository: https://github.com/rasbt/LLMs-from-scratch\n", + "
\n", + "
\n", + "\n", + "
" + ] + }, + { + "cell_type": "markdown", + "id": "efde77f2-6af3-4781-8597-89ecd3f41a52", + "metadata": {}, + "source": [ + "# Llama 3.2 From Scratch (A Standalone Notebook)" + ] + }, + { + "cell_type": "markdown", + "id": "55cdef4d-de59-4a65-89f9-fa2a8ef3471d", + "metadata": {}, + "source": [ + "- This notebook is purposefully minimal and focuses on the code to implement the Llama 3.2 1B and 3B LLMs\n", + "- For a step-by-step guide that explains the individual components and the relationship between GPT, Llama 2, and Llama 3, please see the following companion notebooks:\n", + " - [Converting a From-Scratch GPT Architecture to Llama 2](converting-gpt-to-llama2.ipynb)\n", + " - [Converting Llama 2 to Llama 3.2 From Scratch](converting-llama2-to-llama3.ipynb)\n", + " \n", + " \n", + "\n", + " \n", + " \n", + "- About the code:\n", + " - all code is my own code, mapping the Llama 3 architecture onto the model code implemented in my [Build A Large Language Model (From Scratch)](http://mng.bz/orYv) book; the code is released under a permissive open-source Apache 2.0 license (see [LICENSE.txt](https://github.com/rasbt/LLMs-from-scratch/blob/main/LICENSE.txt))\n", + " - the tokenizer code is inspired by the original [Llama 3 tokenizer code](https://github.com/meta-llama/llama3/blob/main/llama/tokenizer.py), which Meta AI used to to extends the Tiktoken GPT-4 tokenizer\n", + " - the RoPE rescaling section is inspired by the [_compute_llama3_parameters function](https://github.com/huggingface/transformers/blob/5c1027bf09717f664b579e01cbb8ec3ef5aeb140/src/transformers/modeling_rope_utils.py#L329-L347) in the `transformers` library" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "beef121b-2454-4577-8b56-aa00961089cb", + "metadata": {}, + "outputs": [], + "source": [ + "# pip install -r https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/refs/heads/main/ch05/07_gpt_to_llama/requirements-extra.txt" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "dd1b65a8-4301-444a-bd7c-a6f2bd1df9df", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "blobfile version: 3.0.0\n", + "huggingface_hub version: 0.25.1\n", + "tiktoken version: 0.7.0\n", + "torch version: 2.4.0\n" + ] + } + ], + "source": [ + "from importlib.metadata import version\n", + "\n", + "pkgs = [\n", + " \"blobfile\", # to download pretrained weights\n", + " \"huggingface_hub\", # to download pretrained weights\n", + " \"tiktoken\", # to implement the tokenizer\n", + " \"torch\", # to implement the model\n", + "]\n", + "for p in pkgs:\n", + " print(f\"{p} version: {version(p)}\")" + ] + }, + { + "cell_type": "markdown", + "id": "653410a6-dd2b-4eb2-a722-23d9782e726d", + "metadata": {}, + "source": [ + " \n", + "# 1. 
Architecture code" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "82076c21-9331-4dcd-b017-42b046cf1a60", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "\n", + "\n", + "class FeedForward(nn.Module):\n", + " def __init__(self, cfg):\n", + " super().__init__()\n", + " self.fc1 = nn.Linear(cfg[\"emb_dim\"], cfg[\"hidden_dim\"], dtype=cfg[\"dtype\"], bias=False)\n", + " self.fc2 = nn.Linear(cfg[\"emb_dim\"], cfg[\"hidden_dim\"], dtype=cfg[\"dtype\"], bias=False)\n", + " self.fc3 = nn.Linear(cfg[\"hidden_dim\"], cfg[\"emb_dim\"], dtype=cfg[\"dtype\"], bias=False)\n", + "\n", + " def forward(self, x):\n", + " x_fc1 = self.fc1(x)\n", + " x_fc2 = self.fc2(x)\n", + " x = nn.functional.silu(x_fc1) * x_fc2\n", + " return self.fc3(x)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "4b9a346f-5826-4083-9162-abd56afc03f0", + "metadata": {}, + "outputs": [], + "source": [ + "def precompute_rope_params(head_dim, theta_base=10000, context_length=4096, freq_config=None):\n", + " assert head_dim % 2 == 0, \"Embedding dimension must be even\"\n", + "\n", + " # Compute the inverse frequencies\n", + " inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim // 2) / (head_dim // 2)))\n", + "\n", + " # Frequency adjustments\n", + " if freq_config is not None:\n", + " low_freq_wavelen = freq_config[\"original_context_length\"] / freq_config[\"low_freq_factor\"]\n", + " high_freq_wavelen = freq_config[\"original_context_length\"] / freq_config[\"high_freq_factor\"]\n", + "\n", + " wavelen = 2 * torch.pi / inv_freq\n", + "\n", + " inv_freq_llama = torch.where(\n", + " wavelen > low_freq_wavelen, inv_freq / freq_config[\"factor\"], inv_freq\n", + " )\n", + "\n", + " smooth_factor = (freq_config[\"original_context_length\"] / wavelen - freq_config[\"low_freq_factor\"]) / (\n", + " freq_config[\"high_freq_factor\"] - freq_config[\"low_freq_factor\"]\n", + " )\n", + "\n", + " smoothed_inv_freq = (\n", + " (1 - smooth_factor) * (inv_freq / freq_config[\"factor\"]) + smooth_factor * inv_freq\n", + " )\n", + "\n", + " is_medium_freq = (wavelen <= low_freq_wavelen) & (wavelen >= high_freq_wavelen)\n", + " inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)\n", + " inv_freq = inv_freq_llama\n", + "\n", + " # Generate position indices\n", + " positions = torch.arange(context_length)\n", + "\n", + " # Compute the angles\n", + " angles = positions[:, None] * inv_freq[None, :] # Shape: (context_length, head_dim // 2)\n", + "\n", + " # Expand angles to match the head_dim\n", + " angles = torch.cat([angles, angles], dim=1) # Shape: (context_length, head_dim)\n", + "\n", + " # Precompute sine and cosine\n", + " cos = torch.cos(angles)\n", + " sin = torch.sin(angles)\n", + "\n", + " return cos, sin\n", + "\n", + "\n", + "def compute_rope(x, cos, sin):\n", + " # x: (batch_size, num_heads, seq_len, head_dim)\n", + " batch_size, num_heads, seq_len, head_dim = x.shape\n", + " assert head_dim % 2 == 0, \"Head dimension must be even\"\n", + "\n", + " # Split x into first half and second half\n", + " x1 = x[..., : head_dim // 2] # First half\n", + " x2 = x[..., head_dim // 2 :] # Second half\n", + "\n", + " # Adjust sin and cos shapes\n", + " cos = cos[:seq_len, :].unsqueeze(0).unsqueeze(0) # Shape: (1, 1, seq_len, head_dim)\n", + " sin = sin[:seq_len, :].unsqueeze(0).unsqueeze(0)\n", + "\n", + " # Apply the rotary transformation\n", + " rotated = torch.cat((-x2, x1), dim=-1)\n", + " x_rotated = (x * cos) + (rotated * 
sin)\n", + "\n", + " return x_rotated.to(dtype=x.dtype)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "e8169ab5-f976-4222-a2e1-eb1cabf267cb", + "metadata": {}, + "outputs": [], + "source": [ + "class GroupedQueryAttention(nn.Module):\n", + " def __init__(\n", + " self, d_in, d_out, context_length, num_heads,\n", + " num_kv_groups,\n", + " rope_base=10_000,\n", + " rope_config=None,\n", + " dtype=None\n", + " ):\n", + " super().__init__()\n", + " assert d_out % num_heads == 0, \"d_out must be divisible by num_heads\"\n", + " assert num_heads % num_kv_groups == 0, \"num_heads must be divisible by num_kv_groups\"\n", + "\n", + " self.d_out = d_out\n", + " self.num_heads = num_heads\n", + " self.head_dim = d_out // num_heads\n", + "\n", + " self.W_key = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype)\n", + " self.W_value = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype)\n", + " self.num_kv_groups = num_kv_groups\n", + " self.group_size = num_heads // num_kv_groups\n", + "\n", + " self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype)\n", + " self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype)\n", + "\n", + " self.register_buffer(\"mask\", torch.triu(torch.ones(context_length, context_length), diagonal=1))\n", + " cos, sin = precompute_rope_params(\n", + " head_dim=self.head_dim,\n", + " theta_base=rope_base,\n", + " freq_config=rope_config,\n", + " context_length=8192\n", + " )\n", + " self.register_buffer(\"cos\", cos)\n", + " self.register_buffer(\"sin\", sin)\n", + "\n", + " def forward(self, x):\n", + " b, num_tokens, d_in = x.shape\n", + "\n", + " queries = self.W_query(x) # Shape: (b, num_tokens, d_out)\n", + " keys = self.W_key(x) # Shape: (b, num_tokens, num_kv_groups * head_dim)\n", + " values = self.W_value(x) # Shape: (b, num_tokens, num_kv_groups * head_dim)\n", + "\n", + " # Reshape queries, keys, and values\n", + " queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)\n", + " keys = keys.view(b, num_tokens, self.num_kv_groups, self.head_dim)\n", + " values = values.view(b, num_tokens, self.num_kv_groups, self.head_dim)\n", + "\n", + " # Transpose keys, values, and queries\n", + " keys = keys.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim)\n", + " values = values.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim)\n", + " queries = queries.transpose(1, 2) # Shape: (b, num_query_groups, num_tokens, head_dim)\n", + "\n", + " # Apply RoPE\n", + " keys = compute_rope(keys, self.cos, self.sin)\n", + " queries = compute_rope(queries, self.cos, self.sin)\n", + "\n", + " # Expand keys and values to match the number of heads\n", + " # Shape: (b, num_heads, num_tokens, head_dim)\n", + " keys = keys.repeat_interleave(self.group_size, dim=1) # Shape: (b, num_heads, num_tokens, head_dim)\n", + " values = values.repeat_interleave(self.group_size, dim=1) # Shape: (b, num_heads, num_tokens, head_dim)\n", + " # For example, before repeat_interleave along dim=1 (query groups):\n", + " # [K1, K2]\n", + " # After repeat_interleave (each query group is repeated group_size times):\n", + " # [K1, K1, K2, K2]\n", + " # If we used regular repeat instead of repeat_interleave, we'd get:\n", + " # [K1, K2, K1, K2]\n", + "\n", + " # Compute scaled dot-product attention (aka self-attention) with a causal mask\n", + " # Shape: (b, num_heads, num_tokens, num_tokens)\n", + " attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head\n", + "\n", + " # Original mask 
truncated to the number of tokens and converted to boolean\n", + " mask_bool = self.mask.bool()[:num_tokens, :num_tokens]\n", + "\n", + " # Use the mask to fill attention scores\n", + " attn_scores.masked_fill_(mask_bool, -torch.inf)\n", + "\n", + " attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n", + " assert keys.shape[-1] == self.head_dim\n", + "\n", + " # Shape: (b, num_tokens, num_heads, head_dim)\n", + " context_vec = (attn_weights @ values).transpose(1, 2)\n", + "\n", + " # Combine heads, where self.d_out = self.num_heads * self.head_dim\n", + " context_vec = context_vec.reshape(b, num_tokens, self.d_out)\n", + " context_vec = self.out_proj(context_vec) # optional projection\n", + "\n", + " return context_vec" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "457cb2f8-50c1-4045-8a74-f181bfb5fea9", + "metadata": {}, + "outputs": [], + "source": [ + "class TransformerBlock(nn.Module):\n", + " def __init__(self, cfg):\n", + " super().__init__()\n", + " self.att = GroupedQueryAttention(\n", + " d_in=cfg[\"emb_dim\"],\n", + " d_out=cfg[\"emb_dim\"],\n", + " context_length=cfg[\"context_length\"],\n", + " num_heads=cfg[\"n_heads\"],\n", + " num_kv_groups=cfg[\"n_kv_groups\"],\n", + " rope_base=cfg[\"rope_base\"],\n", + " rope_config=cfg[\"rope_freq\"],\n", + " dtype=cfg[\"dtype\"]\n", + " )\n", + " self.ff = FeedForward(cfg)\n", + " self.norm1 = nn.RMSNorm(cfg[\"emb_dim\"], eps=1e-5)\n", + " self.norm2 = nn.RMSNorm(cfg[\"emb_dim\"], eps=1e-5)\n", + "\n", + " def forward(self, x):\n", + " # Shortcut connection for attention block\n", + " shortcut = x\n", + " x = self.norm1(x)\n", + " x = self.att(x.to(torch.bfloat16)) # Shape [batch_size, num_tokens, emb_size]\n", + " x = x + shortcut # Add the original input back\n", + "\n", + " # Shortcut connection for feed-forward block\n", + " shortcut = x\n", + " x = self.norm2(x)\n", + " x = self.ff(x.to(torch.bfloat16))\n", + " x = x + shortcut # Add the original input back\n", + "\n", + " return x" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "e88de3e3-9f07-42cc-816b-28dbd46e96c4", + "metadata": {}, + "outputs": [], + "source": [ + "class Llama3Model(nn.Module):\n", + " def __init__(self, cfg):\n", + " super().__init__()\n", + " self.tok_emb = nn.Embedding(cfg[\"vocab_size\"], cfg[\"emb_dim\"], dtype=cfg[\"dtype\"])\n", + "\n", + " self.trf_blocks = nn.Sequential(\n", + " *[TransformerBlock(cfg) for _ in range(cfg[\"n_layers\"])])\n", + "\n", + " self.final_norm = nn.RMSNorm(cfg[\"emb_dim\"], eps=1e-5)\n", + " self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n", + "\n", + " def forward(self, in_idx):\n", + " batch_size, seq_len = in_idx.shape\n", + " tok_embeds = self.tok_emb(in_idx)\n", + " x = tok_embeds\n", + " x = self.trf_blocks(x)\n", + " x = self.final_norm(x)\n", + " logits = self.out_head(x.to(torch.bfloat16))\n", + " return logits" + ] + }, + { + "cell_type": "markdown", + "id": "be2d201f-74ad-4d63-ab9c-601b00674a48", + "metadata": {}, + "source": [ + " \n", + "# 2. 
Initialize model" + ] + }, + { + "cell_type": "markdown", + "id": "23dea40c-fe20-4a75-be25-d6fce5863c01", + "metadata": {}, + "source": [ + "- The remainder of this notebook uses the Llama 3.2 1B model; to use the 3B model variant, just uncomment the second configuration file in the following code cell" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "caa142fa-b375-4e78-b392-2072ced666f3", + "metadata": {}, + "outputs": [], + "source": [ + "# Llama 3.2 1B\n", + "\n", + "LLAMA32_CONFIG = {\n", + " \"vocab_size\": 128_256, # Vocabulary size\n", + " \"context_length\": 8192, # Context length\n", + " \"emb_dim\": 2048, # Embedding dimension\n", + " \"n_heads\": 32, # Number of attention heads\n", + " \"n_layers\": 16, # Number of layers\n", + " \"hidden_dim\": 8192, # Size of the intermediate dimension in FeedForward\n", + " \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n", + " \"rope_base\": 50_000, # The base in RoPE's \"theta\"\n", + " \"dtype\": torch.bfloat16, # Lower-precision dtype to save memory\n", + " \"rope_freq\": { # RoPE frequency scaling\n", + " \"factor\": 32.0,\n", + " \"low_freq_factor\": 1.0,\n", + " \"high_freq_factor\": 4.0,\n", + " \"original_context_length\": 8192,\n", + " }\n", + "}\n", + "\n", + "# Llama 3.2 3B\n", + "\n", + "# LLAMA32_CONFIG = {\n", + "# \"vocab_size\": 128_256, # Vocabulary size\n", + "# \"context_length\": 8192, # Context length\n", + "# \"emb_dim\": 3072, # Embedding dimension\n", + "# \"n_heads\": 24, # Number of attention heads\n", + "# \"n_layers\": 28, # Number of layers\n", + "# \"hidden_dim\": 8192, # Size of the intermediate dimension in FeedForward\n", + "# \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n", + "# \"rope_base\": 50_000, # The base in RoPE's \"theta\"\n", + "# \"dtype\": torch.bfloat16, # Lower-precision dtype to save memory\n", + "# \"rope_freq\": { # RoPE frequency scaling\n", + "# \"factor\": 32.0,\n", + "# \"low_freq_factor\": 1.0,\n", + "# \"high_freq_factor\": 4.0,\n", + "# \"original_context_length\": 8192,\n", + "# }\n", + "# }\n", + "\n", + "LLAMA_SIZE_STR = \"1B\" if LLAMA32_CONFIG[\"emb_dim\"] == 2048 else \"3B\"" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "156253fe-aacd-4da2-8f13-705f05c4b11e", + "metadata": {}, + "outputs": [], + "source": [ + "model = Llama3Model(LLAMA32_CONFIG)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "364e76ca-52f8-4fa5-af37-c4069f9694bc", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Total number of parameters: 1,498,482,688\n" + ] + } + ], + "source": [ + "total_params = sum(p.numel() for p in model.parameters())\n", + "print(f\"Total number of parameters: {total_params:,}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "fd5efb03-5a07-46e8-8607-93ed47549d2b", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "float32 (PyTorch default): 15.23 GB\n", + "bfloat16: 7.61 GB\n" + ] + } + ], + "source": [ + "def model_memory_size(model, input_dtype=torch.float32):\n", + " total_params = 0\n", + " total_grads = 0\n", + " for param in model.parameters():\n", + " # Calculate total number of elements per parameter\n", + " param_size = param.numel()\n", + " total_params += param_size\n", + " # Check if gradients are stored for this parameter\n", + " if param.requires_grad:\n", + " total_grads += param_size\n", + "\n", + " # Calculate buffer size (non-parameters that 
require memory)\n", + " total_buffers = sum(buf.numel() for buf in model.buffers())\n", + "\n", + " # Size in bytes = (Number of elements) * (Size of each element in bytes)\n", + " # We assume parameters and gradients are stored in the same type as input dtype\n", + " element_size = torch.tensor(0, dtype=input_dtype).element_size()\n", + " total_memory_bytes = (total_params + total_grads + total_buffers) * element_size\n", + "\n", + " # Convert bytes to gigabytes\n", + " total_memory_gb = total_memory_bytes / (1024**3)\n", + "\n", + " return total_memory_gb\n", + "\n", + "print(f\"float32 (PyTorch default): {model_memory_size(model, input_dtype=torch.float32):.2f} GB\")\n", + "print(f\"bfloat16: {model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB\")" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "31f12baf-f79b-499f-85c0-51328a6a20f5", + "metadata": {}, + "outputs": [], + "source": [ + "if torch.cuda.is_available():\n", + " device = torch.device(\"cuda\")\n", + "elif torch.backends.mps.is_available():\n", + " device = torch.device(\"mps\")\n", + "else:\n", + " device = torch.device(\"cpu\")\n", + "\n", + "model.to(device);" + ] + }, + { + "cell_type": "markdown", + "id": "78e091e1-afa8-4d23-9aea-cced86181bfd", + "metadata": {}, + "source": [ + " \n", + "# 3. Load tokenizer" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "9482b01c-49f9-48e4-ab2c-4a4c75240e77", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "from pathlib import Path\n", + "\n", + "import tiktoken\n", + "from tiktoken.load import load_tiktoken_bpe\n", + "\n", + "\n", + "class Tokenizer:\n", + " def __init__(self, model_path):\n", + " assert os.path.isfile(model_path), f\"Model file {model_path} not found\"\n", + " mergeable_ranks = load_tiktoken_bpe(model_path)\n", + " num_base_tokens = len(mergeable_ranks)\n", + "\n", + " self.special_tokens = {\n", + " \"<|begin_of_text|>\": 128000,\n", + " \"<|end_of_text|>\": 128001,\n", + " \"<|start_header_id|>\": 128006,\n", + " \"<|end_header_id|>\": 128007,\n", + " \"<|eot_id|>\": 128009,\n", + " }\n", + " self.special_tokens.update({\n", + " f\"<|reserved_{i}|>\": 128002 + i for i in range(256) if (128002 + i) not in self.special_tokens.values()\n", + " })\n", + "\n", + " self.model = tiktoken.Encoding(\n", + " name=Path(model_path).name,\n", + " pat_str=r\"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+\",\n", + " mergeable_ranks=mergeable_ranks,\n", + " special_tokens=self.special_tokens\n", + " )\n", + "\n", + "\n", + " def encode(self, text, bos=False, eos=False, allowed_special=set(), disallowed_special=()):\n", + " if bos:\n", + " tokens = [self.special_tokens[\"<|begin_of_text|>\"]]\n", + " else:\n", + " tokens = []\n", + "\n", + " tokens += self.model.encode(text, allowed_special=allowed_special, disallowed_special=disallowed_special)\n", + "\n", + " if eos:\n", + " tokens.append(self.special_tokens[\"<|end_of_text|>\"])\n", + " return tokens\n", + "\n", + " def decode(self, tokens):\n", + " return self.model.decode(tokens)\n", + " \n", + "\n", + "class ChatFormat:\n", + " def __init__(self, tokenizer):\n", + " self.tokenizer = tokenizer\n", + "\n", + " def encode_header(self, message):\n", + " tokens = []\n", + " tokens.append(self.tokenizer.special_tokens[\"<|start_header_id|>\"])\n", + " tokens.extend(self.tokenizer.encode(message[\"role\"], bos=False, eos=False))\n", + " 
tokens.append(self.tokenizer.special_tokens[\"<|end_header_id|>\"])\n", + " tokens.extend(self.tokenizer.encode(\"\\n\\n\", bos=False, eos=False))\n", + " return tokens\n", + "\n", + " def encode(self, text):\n", + " message = {\n", + " \"role\": \"user\",\n", + " \"content\": text\n", + " }\n", + "\n", + " tokens = self.encode_header(message)\n", + " tokens.extend(\n", + " self.tokenizer.encode(message[\"content\"].strip(), bos=False, eos=False)\n", + " )\n", + " tokens.append(self.tokenizer.special_tokens[\"<|eot_id|>\"])\n", + " return tokens\n", + "\n", + " def decode(self, token_ids):\n", + " return self.tokenizer.decode(token_ids)" + ] + }, + { + "cell_type": "markdown", + "id": "b771b60c-c198-4b30-bf10-42031197ae86", + "metadata": {}, + "source": [ + "- Please note that Meta AI requires that you accept the Llama 3,2 licensing terms before you can download the files; to do this, you have to create a Hugging Face Hub account and visit the [meta-llama/Llama-3.2-1B](https://huggingface.co/meta-llama/Llama-3.2-1B) repository to accept the terms\n", + "- Next, you will need to create an access token; to generate an access token with READ permissions, click on the profile picture in the upper right and click on \"Settings\"\n", + "\n", + "\n", + "\n", + "\n", + "- Then, create and copy the access token so you can copy & paste it into the next code cell\n", + "\n", + "" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "edcc384a-adb7-43f6-acc3-ebe4b182ec91", + "metadata": {}, + "outputs": [], + "source": [ + "from huggingface_hub import login\n", + "\n", + "login()" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "986bc1a0-804f-4154-80f8-44cefbee1368", + "metadata": {}, + "outputs": [], + "source": [ + "from huggingface_hub import hf_hub_download\n", + "\n", + "tokenizer_file_path = hf_hub_download(\n", + " repo_id=f\"meta-llama/Llama-3.2-{LLAMA_SIZE_STR}-Instruct\",\n", + " filename=\"original/tokenizer.model\",\n", + " local_dir=\"llama32-files\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "f5a3014f-4c66-4fe2-874e-7b57562c49ad", + "metadata": {}, + "outputs": [], + "source": [ + "tokenizer = Tokenizer(tokenizer_file_path)\n", + "chat_tokenizer = ChatFormat(tokenizer)" + ] + }, + { + "cell_type": "markdown", + "id": "c172f89f-d301-439f-b809-46169e5f5945", + "metadata": {}, + "source": [ + " \n", + "# 4. Load pretrained weights" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "75166128-5899-4995-9b88-9672e135650e", + "metadata": {}, + "outputs": [], + "source": [ + "def assign(left, right, tensor_name=\"unknown\"):\n", + " if left.shape != right.shape:\n", + " raise ValueError(f\"Shape mismatch in tensor '{tensor_name}'. 
Left: {left.shape}, Right: {right.shape}\")\n", + "\n", + " if isinstance(right, torch.Tensor):\n", + " return torch.nn.Parameter(right.clone().detach())\n", + " else:\n", + " return torch.nn.Parameter(torch.tensor(right))\n", + "\n", + "\n", + "def load_weights_into_llama(model, param_config, params):\n", + " model.tok_emb.weight = assign(model.tok_emb.weight, params[\"model.embed_tokens.weight\"], \"model.embed_tokens.weight\")\n", + "\n", + " for l in range(param_config[\"n_layers\"]):\n", + "\n", + " # Load attention weights\n", + " model.trf_blocks[l].att.W_query.weight = assign(\n", + " model.trf_blocks[l].att.W_query.weight,\n", + " params[f\"model.layers.{l}.self_attn.q_proj.weight\"],\n", + " f\"model.layers.{l}.self_attn.q_proj.weight\"\n", + " )\n", + " model.trf_blocks[l].att.W_key.weight = assign(\n", + " model.trf_blocks[l].att.W_key.weight,\n", + " params[f\"model.layers.{l}.self_attn.k_proj.weight\"],\n", + " f\"model.layers.{l}.self_attn.k_proj.weight\"\n", + " )\n", + " model.trf_blocks[l].att.W_value.weight = assign(\n", + " model.trf_blocks[l].att.W_value.weight,\n", + " params[f\"model.layers.{l}.self_attn.v_proj.weight\"],\n", + " f\"model.layers.{l}.self_attn.v_proj.weight\"\n", + " )\n", + " model.trf_blocks[l].att.out_proj.weight = assign(\n", + " model.trf_blocks[l].att.out_proj.weight,\n", + " params[f\"model.layers.{l}.self_attn.o_proj.weight\"],\n", + " f\"model.layers.{l}.self_attn.o_proj.weight\"\n", + " )\n", + " model.trf_blocks[l].norm1.weight = assign(\n", + " model.trf_blocks[l].norm1.weight,\n", + " params[f\"model.layers.{l}.input_layernorm.weight\"],\n", + " f\"model.layers.{l}.input_layernorm.weight\"\n", + " )\n", + "\n", + " # Load FeedForward weights\n", + " model.trf_blocks[l].ff.fc1.weight = assign(\n", + " model.trf_blocks[l].ff.fc1.weight,\n", + " params[f\"model.layers.{l}.mlp.gate_proj.weight\"],\n", + " f\"model.layers.{l}.mlp.gate_proj.weight\"\n", + " )\n", + " model.trf_blocks[l].ff.fc2.weight = assign(\n", + " model.trf_blocks[l].ff.fc2.weight,\n", + " params[f\"model.layers.{l}.mlp.up_proj.weight\"],\n", + " f\"model.layers.{l}.mlp.up_proj.weight\"\n", + " )\n", + " model.trf_blocks[l].ff.fc3.weight = assign(\n", + " model.trf_blocks[l].ff.fc3.weight,\n", + " params[f\"model.layers.{l}.mlp.down_proj.weight\"],\n", + " f\"model.layers.{l}.mlp.down_proj.weight\"\n", + " )\n", + " model.trf_blocks[l].norm2.weight = assign(\n", + " model.trf_blocks[l].norm2.weight,\n", + " params[f\"model.layers.{l}.post_attention_layernorm.weight\"],\n", + " f\"model.layers.{l}.post_attention_layernorm.weight\"\n", + " )\n", + "\n", + " # Load output layer weights\n", + " model.final_norm.weight = assign(model.final_norm.weight, params[\"model.norm.weight\"], \"model.norm.weight\")\n", + "\n", + " if \"lm_head.weight\" in params.keys():\n", + " model.out_head.weight = assign(model.out_head.weight, params[\"lm_head.weight\"], \"lm_head.weight\")\n", + " else:\n", + " model.out_head.weight = assign(model.out_head.weight, params[\"model.embed_tokens.weight\"], \"model.embed_tokens.weight\")" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "699cb1b8-a67d-49fb-80a6-0dad9d81f392", + "metadata": {}, + "outputs": [], + "source": [ + "from safetensors.torch import load_file\n", + "\n", + "\n", + "if LLAMA_SIZE_STR == \"1B\":\n", + " weights_file = hf_hub_download(\n", + " repo_id=f\"meta-llama/Llama-3.2-{LLAMA_SIZE_STR}-Instruct\",\n", + " filename=f\"model.safetensors\",\n", + " local_dir=\"llama32-files\"\n", + " )\n", + " combined_weights = 
load_file(weights_file)\n", + "\n", + "\n", + "else:\n", + " combined_weights = {}\n", + " for i in range(1, 3): # the 3B weights are split across two safetensors shards\n", + " weights_file = hf_hub_download(\n", + " repo_id=f\"meta-llama/Llama-3.2-{LLAMA_SIZE_STR}-Instruct\",\n", + " filename=f\"model-0000{i}-of-00002.safetensors\",\n", + " local_dir=\"llama32-files\"\n", + " )\n", + " current_weights = load_file(weights_file)\n", + " combined_weights.update(current_weights)\n", + " \n", + "\n", + "load_weights_into_llama(model, LLAMA32_CONFIG, combined_weights)\n", + "model.to(device);" + ] + }, + { + "cell_type": "markdown", + "id": "57d07df1-4401-4792-b549-7c4cc5632323", + "metadata": {}, + "source": [ + " \n", + "# 5. Generate text" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5", + "metadata": {}, + "outputs": [], + "source": [ + "def text_to_token_ids(text, tokenizer):\n", + " encoded = tokenizer.encode(text)\n", + " encoded_tensor = torch.tensor(encoded).unsqueeze(0) # add batch dimension\n", + " return encoded_tensor\n", + "\n", + "\n", + "def token_ids_to_text(token_ids, tokenizer):\n", + " flat = token_ids.squeeze(0) # remove batch dimension\n", + " return tokenizer.decode(flat.tolist())\n", + "\n", + "\n", + "def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=None, eos_id=None):\n", + "\n", + " # For-loop is the same as before: Get logits, and only focus on last time step\n", + " for _ in range(max_new_tokens):\n", + " idx_cond = idx[:, -context_size:]\n", + " with torch.no_grad():\n", + " logits = model(idx_cond)\n", + " logits = logits[:, -1, :]\n", + "\n", + " # New: Filter logits with top_k sampling\n", + " if top_k is not None:\n", + " # Keep only top_k values\n", + " top_logits, _ = torch.topk(logits, top_k)\n", + " min_val = top_logits[:, -1]\n", + " logits = torch.where(logits < min_val, torch.tensor(float('-inf')).to(logits.device), logits)\n", + "\n", + " # New: Apply temperature scaling\n", + " if temperature > 0.0:\n", + " logits = logits / temperature\n", + "\n", + " # Apply softmax to get probabilities\n", + " probs = torch.softmax(logits, dim=-1) # (batch_size, context_len)\n", + "\n", + " # Sample from the distribution\n", + " idx_next = torch.multinomial(probs, num_samples=1) # (batch_size, 1)\n", + "\n", + " # Otherwise same as before: get idx of the vocab entry with the highest logits value\n", + " else:\n", + " idx_next = torch.argmax(logits, dim=-1, keepdim=True) # (batch_size, 1)\n", + "\n", + " if idx_next == eos_id: # Stop generating early if end-of-sequence token is encountered and eos_id is specified\n", + " break\n", + "\n", + " # Same as before: append sampled index to the running sequence\n", + " idx = torch.cat((idx, idx_next), dim=1) # (batch_size, num_tokens+1)\n", + "\n", + " return idx" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "1c7a04fa-6aac-416b-8f63-f1e19227633d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Output text:\n", + " Llamas are herbivores, which means they primarily eat plants and plant-based foods. Their diet typically consists of:\n", + "\n", + "1. Grasses: Llamas love to graze on various types of grasses, including tall grasses and short grasses.\n", + "2. Hay: Llamas also eat hay, which is a dry, compressed form of grass or other plants.\n", + "3. Alfalfa: Alfalfa is a legume that is commonly fed to llamas, as it is high in protein and fiber.\n", + "4. 
Other plants: Llamas will also eat other plants, such as clover, wild grasses, and shrubs.\n", + "\n", + "It's worth noting that llamas are adapted to high altitudes and\n" + ] + } + ], + "source": [ + "import re\n", + "\n", + "\n", + "PROMPT = \"What do llamas eat?\"\n", + "\n", + "torch.manual_seed(123)\n", + "\n", + "token_ids = generate(\n", + " model=model,\n", + " idx=text_to_token_ids(PROMPT, chat_tokenizer).to(device),\n", + " max_new_tokens=150,\n", + " context_size=LLAMA32_CONFIG[\"context_length\"],\n", + " top_k=1,\n", + " temperature=0.\n", + ")\n", + "\n", + "output_text = token_ids_to_text(token_ids, tokenizer)\n", + "\n", + "\n", + "def clean_text(text, header_end=\"assistant<|end_header_id|>\\n\\n\"):\n", + " # Find the index of the first occurrence of \"<|end_header_id|>\"\n", + " index = text.find(header_end)\n", + "\n", + " if index != -1:\n", + " # Return the substring starting after \"<|end_header_id|>\"\n", + " return text[index + len(header_end):].strip() # Strip removes leading/trailing whitespace\n", + " else:\n", + " # If the token is not found, return the original text\n", + " return text\n", + "\n", + "print(\"Output text:\\n\", clean_text(output_text))" + ] + }, + { + "cell_type": "markdown", + "id": "549324d6-5c71-4147-ae21-2e67675faa3d", + "metadata": {}, + "source": [ + " \n", + "# What's next?" + ] + }, + { + "cell_type": "markdown", + "id": "e6edaaae-2de1-406c-8ffa-897cdfa3808c", + "metadata": {}, + "source": [ + "- The notebook was kept purposefully minimal; if you are interested in additional explanation about the individual components, check out the following two companion notebooks:\n", + "\n", + "\n", + "\n", + " 1. [Converting a From-Scratch GPT Architecture to Llama 2](converting-gpt-to-llama2.ipynb)\n", + " 2. [Converting Llama 2 to Llama 3.2 From Scratch](converting-llama2-to-llama3.ipynb)\n", + " \n", + "- For those interested in a comprehensive guide on building a large language model from scratch and gaining a deeper understanding of its mechanics, you might like my [Build a Large Language Model (From Scratch)](http://mng.bz/orYv)\n", + "\n", + "" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bf864c28-2ce1-44bf-84e4-c0671f494d62", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.4" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}
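- One optional follow-up: if you plan to reuse the converted model, you can save the instantiated model's weights locally so the safetensors shards don't have to be re-downloaded and re-mapped on every run; the snippet below is a minimal sketch, and the `llama3.2-1B-instruct.pth` filename is just an example

```python
import torch

# Save the weights of the model object created above
torch.save(model.state_dict(), "llama3.2-1B-instruct.pth")

# Later, re-create the architecture and load the weights back:
# model = Llama3Model(LLAMA32_CONFIG)
# model.load_state_dict(torch.load("llama3.2-1B-instruct.pth", weights_only=True))
# model.to(device)
```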