{ "cells": [ { "cell_type": "markdown", "id": "e1b280ab-b61f-4d1a-bf7e-44e5f9ed3a5c", "metadata": {}, "source": [ "\n", "\n", "\n", "\n", "\n", "
\n", "\n", "Supplementary code for the Build a Large Language Model From Scratch book by Sebastian Raschka
\n", "
Code repository: https://github.com/rasbt/LLMs-from-scratch\n", "
\n", "
\n", "\n", "
" ] }, { "cell_type": "markdown", "id": "efde77f2-6af3-4781-8597-89ecd3f41a52", "metadata": {}, "source": [ "# Llama 3.2 From Scratch (A Standalone Notebook)" ] }, { "cell_type": "markdown", "id": "55cdef4d-de59-4a65-89f9-fa2a8ef3471d", "metadata": {}, "source": [ "- This notebook is purposefully minimal and focuses on the code to implement the Llama 3.2 1B and 3B LLMs\n", "- For a step-by-step guide that explains the individual components and the relationship between GPT, Llama 2, and Llama 3, please see the following companion notebooks:\n", " - [Converting a From-Scratch GPT Architecture to Llama 2](converting-gpt-to-llama2.ipynb)\n", " - [Converting Llama 2 to Llama 3.2 From Scratch](converting-llama2-to-llama3.ipynb)\n", " \n", " \n", "\n", " \n", " \n", "- About the code:\n", " - all code is my own code, mapping the Llama 3 architecture onto the model code implemented in my [Build A Large Language Model (From Scratch)](http://mng.bz/orYv) book; the code is released under a permissive open-source Apache 2.0 license (see [LICENSE.txt](https://github.com/rasbt/LLMs-from-scratch/blob/main/LICENSE.txt))\n", " - the tokenizer code is inspired by the original [Llama 3 tokenizer code](https://github.com/meta-llama/llama3/blob/main/llama/tokenizer.py), which Meta AI used to extend the Tiktoken GPT-4 tokenizer\n", " - the RoPE rescaling section is inspired by the [_compute_llama3_parameters function](https://github.com/huggingface/transformers/blob/5c1027bf09717f664b579e01cbb8ec3ef5aeb140/src/transformers/modeling_rope_utils.py#L329-L347) in the `transformers` library" ] }, { "cell_type": "code", "execution_count": 1, "id": "beef121b-2454-4577-8b56-aa00961089cb", "metadata": {}, "outputs": [], "source": [ "# pip install -r https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/refs/heads/main/ch05/07_gpt_to_llama/requirements-extra.txt" ] }, { "cell_type": "code", "execution_count": 2, "id": "dd1b65a8-4301-444a-bd7c-a6f2bd1df9df", "metadata": {}, "outputs": [ { 
"name": "stdout", "output_type": "stream", "text": [ "blobfile version: 3.0.0\n", "huggingface_hub version: 0.25.1\n", "tiktoken version: 0.7.0\n", "torch version: 2.4.0\n" ] } ], "source": [ "from importlib.metadata import version\n", "\n", "pkgs = [\n", " \"blobfile\", # to download pretrained weights\n", " \"huggingface_hub\", # to download pretrained weights\n", " \"tiktoken\", # to implement the tokenizer\n", " \"torch\", # to implement the model\n", "]\n", "for p in pkgs:\n", " print(f\"{p} version: {version(p)}\")" ] }, { "cell_type": "markdown", "id": "653410a6-dd2b-4eb2-a722-23d9782e726d", "metadata": {}, "source": [ " \n", "# 1. Architecture code" ] }, { "cell_type": "code", "execution_count": 3, "id": "82076c21-9331-4dcd-b017-42b046cf1a60", "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "\n", "\n", "class FeedForward(nn.Module):\n", " def __init__(self, cfg):\n", " super().__init__()\n", " self.fc1 = nn.Linear(cfg[\"emb_dim\"], cfg[\"hidden_dim\"], dtype=cfg[\"dtype\"], bias=False)\n", " self.fc2 = nn.Linear(cfg[\"emb_dim\"], cfg[\"hidden_dim\"], dtype=cfg[\"dtype\"], bias=False)\n", " self.fc3 = nn.Linear(cfg[\"hidden_dim\"], cfg[\"emb_dim\"], dtype=cfg[\"dtype\"], bias=False)\n", "\n", " def forward(self, x):\n", " x_fc1 = self.fc1(x)\n", " x_fc2 = self.fc2(x)\n", " x = nn.functional.silu(x_fc1) * x_fc2 # SwiGLU: SiLU-gated elementwise product\n", " return self.fc3(x)" ] }, { "cell_type": "code", "execution_count": 4, "id": "4b9a346f-5826-4083-9162-abd56afc03f0", "metadata": {}, "outputs": [], "source": [ "def precompute_rope_params(head_dim, theta_base=10000, context_length=4096, freq_config=None):\n", " assert head_dim % 2 == 0, \"Head dimension must be even\"\n", "\n", " # Compute the inverse frequencies\n", " inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim // 2) / (head_dim // 2)))\n", "\n", " # Frequency adjustments\n", " if freq_config is not None:\n", " low_freq_wavelen = freq_config[\"original_context_length\"] / 
freq_config[\"low_freq_factor\"]\n", " high_freq_wavelen = freq_config[\"original_context_length\"] / freq_config[\"high_freq_factor\"]\n", "\n", " wavelen = 2 * torch.pi / inv_freq\n", "\n", " inv_freq_llama = torch.where(\n", " wavelen > low_freq_wavelen, inv_freq / freq_config[\"factor\"], inv_freq\n", " )\n", "\n", " smooth_factor = (freq_config[\"original_context_length\"] / wavelen - freq_config[\"low_freq_factor\"]) / (\n", " freq_config[\"high_freq_factor\"] - freq_config[\"low_freq_factor\"]\n", " )\n", "\n", " smoothed_inv_freq = (\n", " (1 - smooth_factor) * (inv_freq / freq_config[\"factor\"]) + smooth_factor * inv_freq\n", " )\n", "\n", " is_medium_freq = (wavelen <= low_freq_wavelen) & (wavelen >= high_freq_wavelen)\n", " inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)\n", " inv_freq = inv_freq_llama\n", "\n", " # Generate position indices\n", " positions = torch.arange(context_length)\n", "\n", " # Compute the angles\n", " angles = positions[:, None] * inv_freq[None, :] # Shape: (context_length, head_dim // 2)\n", "\n", " # Expand angles to match the head_dim\n", " angles = torch.cat([angles, angles], dim=1) # Shape: (context_length, head_dim)\n", "\n", " # Precompute sine and cosine\n", " cos = torch.cos(angles)\n", " sin = torch.sin(angles)\n", "\n", " return cos, sin\n", "\n", "\n", "def compute_rope(x, cos, sin):\n", " # x: (batch_size, num_heads, seq_len, head_dim)\n", " batch_size, num_heads, seq_len, head_dim = x.shape\n", " assert head_dim % 2 == 0, \"Head dimension must be even\"\n", "\n", " # Split x into first half and second half\n", " x1 = x[..., : head_dim // 2] # First half\n", " x2 = x[..., head_dim // 2 :] # Second half\n", "\n", " # Adjust sin and cos shapes\n", " cos = cos[:seq_len, :].unsqueeze(0).unsqueeze(0) # Shape: (1, 1, seq_len, head_dim)\n", " sin = sin[:seq_len, :].unsqueeze(0).unsqueeze(0)\n", "\n", " # Apply the rotary transformation\n", " rotated = torch.cat((-x2, x1), dim=-1)\n", " 
x_rotated = (x * cos) + (rotated * sin)\n", "\n", " return x_rotated.to(dtype=x.dtype)" ] }, { "cell_type": "code", "execution_count": 5, "id": "e8169ab5-f976-4222-a2e1-eb1cabf267cb", "metadata": {}, "outputs": [], "source": [ "class GroupedQueryAttention(nn.Module):\n", " def __init__(\n", " self, d_in, d_out, context_length, num_heads,\n", " num_kv_groups,\n", " rope_base=10_000,\n", " rope_config=None,\n", " dtype=None\n", " ):\n", " super().__init__()\n", " assert d_out % num_heads == 0, \"d_out must be divisible by num_heads\"\n", " assert num_heads % num_kv_groups == 0, \"num_heads must be divisible by num_kv_groups\"\n", "\n", " self.d_out = d_out\n", " self.num_heads = num_heads\n", " self.head_dim = d_out // num_heads\n", "\n", " self.W_key = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype)\n", " self.W_value = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype)\n", " self.num_kv_groups = num_kv_groups\n", " self.group_size = num_heads // num_kv_groups\n", "\n", " self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype)\n", " self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype)\n", "\n", " self.register_buffer(\"mask\", torch.triu(torch.ones(context_length, context_length), diagonal=1))\n", " cos, sin = precompute_rope_params(\n", " head_dim=self.head_dim,\n", " theta_base=rope_base,\n", " freq_config=rope_config,\n", " context_length=8192\n", " )\n", " self.register_buffer(\"cos\", cos)\n", " self.register_buffer(\"sin\", sin)\n", "\n", " def forward(self, x):\n", " b, num_tokens, d_in = x.shape\n", "\n", " queries = self.W_query(x) # Shape: (b, num_tokens, d_out)\n", " keys = self.W_key(x) # Shape: (b, num_tokens, num_kv_groups * head_dim)\n", " values = self.W_value(x) # Shape: (b, num_tokens, num_kv_groups * head_dim)\n", "\n", " # Reshape queries, keys, and values\n", " queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)\n", " keys = keys.view(b, num_tokens, 
self.num_kv_groups, self.head_dim)\n", " values = values.view(b, num_tokens, self.num_kv_groups, self.head_dim)\n", "\n", " # Transpose keys, values, and queries\n", " keys = keys.transpose(1, 2) # Shape: (b, num_kv_groups, num_tokens, head_dim)\n", " values = values.transpose(1, 2) # Shape: (b, num_kv_groups, num_tokens, head_dim)\n", " queries = queries.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim)\n", "\n", " # Apply RoPE\n", " keys = compute_rope(keys, self.cos, self.sin)\n", " queries = compute_rope(queries, self.cos, self.sin)\n", "\n", " # Expand keys and values to match the number of heads\n", " # Shape: (b, num_heads, num_tokens, head_dim)\n", " keys = keys.repeat_interleave(self.group_size, dim=1) # Shape: (b, num_heads, num_tokens, head_dim)\n", " values = values.repeat_interleave(self.group_size, dim=1) # Shape: (b, num_heads, num_tokens, head_dim)\n", " # For example, before repeat_interleave along dim=1 (key-value groups):\n", " # [K1, K2]\n", " # After repeat_interleave (each key-value group is repeated group_size times):\n", " # [K1, K1, K2, K2]\n", " # If we used regular repeat instead of repeat_interleave, we'd get:\n", " # [K1, K2, K1, K2]\n", "\n", " # Compute scaled dot-product attention (aka self-attention) with a causal mask\n", " # Shape: (b, num_heads, num_tokens, num_tokens)\n", " attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head\n", "\n", " # Original mask truncated to the number of tokens and converted to boolean\n", " mask_bool = self.mask.bool()[:num_tokens, :num_tokens]\n", "\n", " # Use the mask to fill attention scores\n", " attn_scores.masked_fill_(mask_bool, -torch.inf)\n", "\n", " attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n", " assert keys.shape[-1] == self.head_dim\n", "\n", " # Shape: (b, num_tokens, num_heads, head_dim)\n", " context_vec = (attn_weights @ values).transpose(1, 2)\n", "\n", " # Combine heads, where self.d_out = self.num_heads * self.head_dim\n", " 
context_vec = context_vec.reshape(b, num_tokens, self.d_out)\n", " context_vec = self.out_proj(context_vec) # optional projection\n", "\n", " return context_vec" ] }, { "cell_type": "code", "execution_count": 6, "id": "457cb2f8-50c1-4045-8a74-f181bfb5fea9", "metadata": {}, "outputs": [], "source": [ "class TransformerBlock(nn.Module):\n", " def __init__(self, cfg):\n", " super().__init__()\n", " self.att = GroupedQueryAttention(\n", " d_in=cfg[\"emb_dim\"],\n", " d_out=cfg[\"emb_dim\"],\n", " context_length=cfg[\"context_length\"],\n", " num_heads=cfg[\"n_heads\"],\n", " num_kv_groups=cfg[\"n_kv_groups\"],\n", " rope_base=cfg[\"rope_base\"],\n", " rope_config=cfg[\"rope_freq\"],\n", " dtype=cfg[\"dtype\"]\n", " )\n", " self.ff = FeedForward(cfg)\n", " self.norm1 = nn.RMSNorm(cfg[\"emb_dim\"], eps=1e-5)\n", " self.norm2 = nn.RMSNorm(cfg[\"emb_dim\"], eps=1e-5)\n", "\n", " def forward(self, x):\n", " # Shortcut connection for attention block\n", " shortcut = x\n", " x = self.norm1(x)\n", " x = self.att(x.to(torch.bfloat16)) # Shape [batch_size, num_tokens, emb_size]\n", " x = x + shortcut # Add the original input back\n", "\n", " # Shortcut connection for feed-forward block\n", " shortcut = x\n", " x = self.norm2(x)\n", " x = self.ff(x.to(torch.bfloat16))\n", " x = x + shortcut # Add the original input back\n", "\n", " return x" ] }, { "cell_type": "code", "execution_count": 7, "id": "e88de3e3-9f07-42cc-816b-28dbd46e96c4", "metadata": {}, "outputs": [], "source": [ "class Llama3Model(nn.Module):\n", " def __init__(self, cfg):\n", " super().__init__()\n", " self.tok_emb = nn.Embedding(cfg[\"vocab_size\"], cfg[\"emb_dim\"], dtype=cfg[\"dtype\"])\n", "\n", " self.trf_blocks = nn.Sequential(\n", " *[TransformerBlock(cfg) for _ in range(cfg[\"n_layers\"])])\n", "\n", " self.final_norm = nn.RMSNorm(cfg[\"emb_dim\"], eps=1e-5)\n", " self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n", "\n", " def forward(self, in_idx):\n", " 
batch_size, seq_len = in_idx.shape\n", " tok_embeds = self.tok_emb(in_idx)\n", " x = tok_embeds\n", " x = self.trf_blocks(x)\n", " x = self.final_norm(x)\n", " logits = self.out_head(x.to(torch.bfloat16))\n", " return logits" ] }, { "cell_type": "markdown", "id": "be2d201f-74ad-4d63-ab9c-601b00674a48", "metadata": {}, "source": [ " \n", "# 2. Initialize model" ] }, { "cell_type": "markdown", "id": "23dea40c-fe20-4a75-be25-d6fce5863c01", "metadata": {}, "source": [ "- The remainder of this notebook uses the Llama 3.2 1B model; to use the 3B model variant, uncomment the second configuration dictionary in the following code cell" ] }, { "cell_type": "code", "execution_count": 8, "id": "caa142fa-b375-4e78-b392-2072ced666f3", "metadata": {}, "outputs": [], "source": [ "# Llama 3.2 1B\n", "\n", "LLAMA32_CONFIG = {\n", " \"vocab_size\": 128_256, # Vocabulary size\n", " \"context_length\": 8192, # Context length\n", " \"emb_dim\": 2048, # Embedding dimension\n", " \"n_heads\": 32, # Number of attention heads\n", " \"n_layers\": 16, # Number of layers\n", " \"hidden_dim\": 8192, # Size of the intermediate dimension in FeedForward\n", " \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n", " \"rope_base\": 500_000, # The base in RoPE's \"theta\"\n", " \"dtype\": torch.bfloat16, # Lower-precision dtype to save memory\n", " \"rope_freq\": { # RoPE frequency scaling\n", " \"factor\": 32.0,\n", " \"low_freq_factor\": 1.0,\n", " \"high_freq_factor\": 4.0,\n", " \"original_context_length\": 8192,\n", " }\n", "}\n", "\n", "# Llama 3.2 3B\n", "\n", "# LLAMA32_CONFIG = {\n", "# \"vocab_size\": 128_256, # Vocabulary size\n", "# \"context_length\": 8192, # Context length\n", "# \"emb_dim\": 3072, # Embedding dimension\n", "# \"n_heads\": 24, # Number of attention heads\n", "# \"n_layers\": 28, # Number of layers\n", "# \"hidden_dim\": 8192, # Size of the intermediate dimension in FeedForward\n", "# \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n", "# 
\"rope_base\": 500_000, # The base in RoPE's \"theta\"\n", "# \"dtype\": torch.bfloat16, # Lower-precision dtype to save memory\n", "# \"rope_freq\": { # RoPE frequency scaling\n", "# \"factor\": 32.0,\n", "# \"low_freq_factor\": 1.0,\n", "# \"high_freq_factor\": 4.0,\n", "# \"original_context_length\": 8192,\n", "# }\n", "# }\n", "\n", "LLAMA_SIZE_STR = \"1B\" if LLAMA32_CONFIG[\"emb_dim\"] == 2048 else \"3B\"" ] }, { "cell_type": "code", "execution_count": 9, "id": "156253fe-aacd-4da2-8f13-705f05c4b11e", "metadata": {}, "outputs": [], "source": [ "model = Llama3Model(LLAMA32_CONFIG)" ] }, { "cell_type": "code", "execution_count": 10, "id": "364e76ca-52f8-4fa5-af37-c4069f9694bc", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Total number of parameters: 1,498,482,688\n" ] } ], "source": [ "total_params = sum(p.numel() for p in model.parameters())\n", "print(f\"Total number of parameters: {total_params:,}\")" ] }, { "cell_type": "code", "execution_count": 11, "id": "fd5efb03-5a07-46e8-8607-93ed47549d2b", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "float32 (PyTorch default): 15.23 GB\n", "bfloat16: 7.61 GB\n" ] } ], "source": [ "def model_memory_size(model, input_dtype=torch.float32):\n", " total_params = 0\n", " total_grads = 0\n", " for param in model.parameters():\n", " # Calculate total number of elements per parameter\n", " param_size = param.numel()\n", " total_params += param_size\n", " # Check if gradients are stored for this parameter\n", " if param.requires_grad:\n", " total_grads += param_size\n", "\n", " # Calculate buffer size (non-parameters that require memory)\n", " total_buffers = sum(buf.numel() for buf in model.buffers())\n", "\n", " # Size in bytes = (Number of elements) * (Size of each element in bytes)\n", " # We assume parameters, gradients, and buffers are stored in input_dtype\n", " element_size = torch.tensor(0, dtype=input_dtype).element_size()\n", " 
total_memory_bytes = (total_params + total_grads + total_buffers) * element_size\n", "\n", " # Convert bytes to gigabytes\n", " total_memory_gb = total_memory_bytes / (1024**3)\n", "\n", " return total_memory_gb\n", "\n", "print(f\"float32 (PyTorch default): {model_memory_size(model, input_dtype=torch.float32):.2f} GB\")\n", "print(f\"bfloat16: {model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB\")" ] }, { "cell_type": "code", "execution_count": 12, "id": "31f12baf-f79b-499f-85c0-51328a6a20f5", "metadata": {}, "outputs": [], "source": [ "if torch.cuda.is_available():\n", " device = torch.device(\"cuda\")\n", "elif torch.backends.mps.is_available():\n", " device = torch.device(\"mps\")\n", "else:\n", " device = torch.device(\"cpu\")\n", "\n", "model.to(device);" ] }, { "cell_type": "markdown", "id": "78e091e1-afa8-4d23-9aea-cced86181bfd", "metadata": {}, "source": [ " \n", "# 3. Load tokenizer" ] }, { "cell_type": "code", "execution_count": 13, "id": "9482b01c-49f9-48e4-ab2c-4a4c75240e77", "metadata": {}, "outputs": [], "source": [ "import os\n", "from pathlib import Path\n", "\n", "import tiktoken\n", "from tiktoken.load import load_tiktoken_bpe\n", "\n", "\n", "class Tokenizer:\n", " def __init__(self, model_path):\n", " assert os.path.isfile(model_path), f\"Model file {model_path} not found\"\n", " mergeable_ranks = load_tiktoken_bpe(model_path)\n", " num_base_tokens = len(mergeable_ranks)\n", "\n", " self.special_tokens = {\n", " \"<|begin_of_text|>\": 128000,\n", " \"<|end_of_text|>\": 128001,\n", " \"<|start_header_id|>\": 128006,\n", " \"<|end_header_id|>\": 128007,\n", " \"<|eot_id|>\": 128009,\n", " }\n", " self.special_tokens.update({\n", " f\"<|reserved_{i}|>\": 128002 + i for i in range(256) if (128002 + i) not in self.special_tokens.values()\n", " })\n", "\n", " self.model = tiktoken.Encoding(\n", " name=Path(model_path).name,\n", " pat_str=r\"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| 
?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+\",\n", " mergeable_ranks=mergeable_ranks,\n", " special_tokens=self.special_tokens\n", " )\n", "\n", "\n", " def encode(self, text, bos=False, eos=False, allowed_special=set(), disallowed_special=()):\n", " if bos:\n", " tokens = [self.special_tokens[\"<|begin_of_text|>\"]]\n", " else:\n", " tokens = []\n", "\n", " tokens += self.model.encode(text, allowed_special=allowed_special, disallowed_special=disallowed_special)\n", "\n", " if eos:\n", " tokens.append(self.special_tokens[\"<|end_of_text|>\"])\n", " return tokens\n", "\n", " def decode(self, tokens):\n", " return self.model.decode(tokens)\n", " \n", "\n", "class ChatFormat:\n", " def __init__(self, tokenizer):\n", " self.tokenizer = tokenizer\n", "\n", " def encode_header(self, message):\n", " tokens = []\n", " tokens.append(self.tokenizer.special_tokens[\"<|start_header_id|>\"])\n", " tokens.extend(self.tokenizer.encode(message[\"role\"], bos=False, eos=False))\n", " tokens.append(self.tokenizer.special_tokens[\"<|end_header_id|>\"])\n", " tokens.extend(self.tokenizer.encode(\"\\n\\n\", bos=False, eos=False))\n", " return tokens\n", "\n", " def encode(self, text):\n", " message = {\n", " \"role\": \"user\",\n", " \"content\": text\n", " }\n", "\n", " tokens = self.encode_header(message)\n", " tokens.extend(\n", " self.tokenizer.encode(message[\"content\"].strip(), bos=False, eos=False)\n", " )\n", " tokens.append(self.tokenizer.special_tokens[\"<|eot_id|>\"])\n", " return tokens\n", "\n", " def decode(self, token_ids):\n", " return self.tokenizer.decode(token_ids)" ] }, { "cell_type": "markdown", "id": "b771b60c-c198-4b30-bf10-42031197ae86", "metadata": {}, "source": [ "- Please note that Meta AI requires that you accept the Llama 3.2 licensing terms before you can download the files; to do this, you have to create a Hugging Face Hub account and visit the [meta-llama/Llama-3.2-1B](https://huggingface.co/meta-llama/Llama-3.2-1B) repository to accept 
the terms\n", "- Next, you will need to create an access token; to generate an access token with READ permissions, click on the profile picture in the upper right and click on \"Settings\"\n", "\n", "\n", "\n", "\n", "- Then, create the access token and copy it so you can paste it into the next code cell\n", "\n", "" ] }, { "cell_type": "code", "execution_count": null, "id": "edcc384a-adb7-43f6-acc3-ebe4b182ec91", "metadata": {}, "outputs": [], "source": [ "from huggingface_hub import login\n", "\n", "login()" ] }, { "cell_type": "code", "execution_count": 15, "id": "986bc1a0-804f-4154-80f8-44cefbee1368", "metadata": {}, "outputs": [], "source": [ "from huggingface_hub import hf_hub_download\n", "\n", "tokenizer_file_path = hf_hub_download(\n", " repo_id=f\"meta-llama/Llama-3.2-{LLAMA_SIZE_STR}-Instruct\",\n", " filename=\"original/tokenizer.model\",\n", " local_dir=\"llama32-files\"\n", ")" ] }, { "cell_type": "code", "execution_count": 16, "id": "f5a3014f-4c66-4fe2-874e-7b57562c49ad", "metadata": {}, "outputs": [], "source": [ "tokenizer = Tokenizer(tokenizer_file_path)\n", "chat_tokenizer = ChatFormat(tokenizer)" ] }, { "cell_type": "markdown", "id": "c172f89f-d301-439f-b809-46169e5f5945", "metadata": {}, "source": [ " \n", "# 4. Load pretrained weights" ] }, { "cell_type": "code", "execution_count": 17, "id": "75166128-5899-4995-9b88-9672e135650e", "metadata": {}, "outputs": [], "source": [ "def assign(left, right, tensor_name=\"unknown\"):\n", " if left.shape != right.shape:\n", " raise ValueError(f\"Shape mismatch in tensor '{tensor_name}'. 
Left: {left.shape}, Right: {right.shape}\")\n", "\n", " if isinstance(right, torch.Tensor):\n", " return torch.nn.Parameter(right.clone().detach())\n", " else:\n", " return torch.nn.Parameter(torch.tensor(right))\n", "\n", "\n", "def load_weights_into_llama(model, param_config, params):\n", " model.tok_emb.weight = assign(model.tok_emb.weight, params[\"model.embed_tokens.weight\"], \"model.embed_tokens.weight\")\n", "\n", " for l in range(param_config[\"n_layers\"]):\n", "\n", " # Load attention weights\n", " model.trf_blocks[l].att.W_query.weight = assign(\n", " model.trf_blocks[l].att.W_query.weight,\n", " params[f\"model.layers.{l}.self_attn.q_proj.weight\"],\n", " f\"model.layers.{l}.self_attn.q_proj.weight\"\n", " )\n", " model.trf_blocks[l].att.W_key.weight = assign(\n", " model.trf_blocks[l].att.W_key.weight,\n", " params[f\"model.layers.{l}.self_attn.k_proj.weight\"],\n", " f\"model.layers.{l}.self_attn.k_proj.weight\"\n", " )\n", " model.trf_blocks[l].att.W_value.weight = assign(\n", " model.trf_blocks[l].att.W_value.weight,\n", " params[f\"model.layers.{l}.self_attn.v_proj.weight\"],\n", " f\"model.layers.{l}.self_attn.v_proj.weight\"\n", " )\n", " model.trf_blocks[l].att.out_proj.weight = assign(\n", " model.trf_blocks[l].att.out_proj.weight,\n", " params[f\"model.layers.{l}.self_attn.o_proj.weight\"],\n", " f\"model.layers.{l}.self_attn.o_proj.weight\"\n", " )\n", " model.trf_blocks[l].norm1.weight = assign(\n", " model.trf_blocks[l].norm1.weight,\n", " params[f\"model.layers.{l}.input_layernorm.weight\"],\n", " f\"model.layers.{l}.input_layernorm.weight\"\n", " )\n", "\n", " # Load FeedForward weights\n", " model.trf_blocks[l].ff.fc1.weight = assign(\n", " model.trf_blocks[l].ff.fc1.weight,\n", " params[f\"model.layers.{l}.mlp.gate_proj.weight\"],\n", " f\"model.layers.{l}.mlp.gate_proj.weight\"\n", " )\n", " model.trf_blocks[l].ff.fc2.weight = assign(\n", " model.trf_blocks[l].ff.fc2.weight,\n", " params[f\"model.layers.{l}.mlp.up_proj.weight\"],\n", 
" f\"model.layers.{l}.mlp.up_proj.weight\"\n", " )\n", " model.trf_blocks[l].ff.fc3.weight = assign(\n", " model.trf_blocks[l].ff.fc3.weight,\n", " params[f\"model.layers.{l}.mlp.down_proj.weight\"],\n", " f\"model.layers.{l}.mlp.down_proj.weight\"\n", " )\n", " model.trf_blocks[l].norm2.weight = assign(\n", " model.trf_blocks[l].norm2.weight,\n", " params[f\"model.layers.{l}.post_attention_layernorm.weight\"],\n", " f\"model.layers.{l}.post_attention_layernorm.weight\"\n", " )\n", "\n", " # Load output layer weights\n", " model.final_norm.weight = assign(model.final_norm.weight, params[\"model.norm.weight\"], \"model.norm.weight\")\n", "\n", " if \"lm_head.weight\" in params.keys():\n", " model.out_head.weight = assign(model.out_head.weight, params[\"lm_head.weight\"], \"lm_head.weight\")\n", " else:\n", " # No separate lm_head.weight stored: reuse the (tied) token-embedding weights for the output layer\n", " model.out_head.weight = assign(model.out_head.weight, params[\"model.embed_tokens.weight\"], \"model.embed_tokens.weight\")" ] }, { "cell_type": "code", "execution_count": 18, "id": "699cb1b8-a67d-49fb-80a6-0dad9d81f392", "metadata": {}, "outputs": [], "source": [ "from safetensors.torch import load_file\n", "\n", "\n", "if LLAMA_SIZE_STR == \"1B\":\n", " weights_file = hf_hub_download(\n", " repo_id=f\"meta-llama/Llama-3.2-{LLAMA_SIZE_STR}-Instruct\",\n", " filename=\"model.safetensors\",\n", " local_dir=\"llama32-files\"\n", " )\n", " combined_weights = load_file(weights_file)\n", "\n", "\n", "else:\n", " combined_weights = {}\n", " for i in range(1, 3): # the 3B checkpoint is split into two safetensors files\n", " weights_file = hf_hub_download(\n", " repo_id=f\"meta-llama/Llama-3.2-{LLAMA_SIZE_STR}-Instruct\",\n", " filename=f\"model-0000{i}-of-00002.safetensors\",\n", " local_dir=\"llama32-files\"\n", " )\n", " current_weights = load_file(weights_file)\n", " combined_weights.update(current_weights)\n", " \n", "\n", "load_weights_into_llama(model, LLAMA32_CONFIG, combined_weights)\n", "model.to(device);" ] }, { "cell_type": "markdown", "id": "57d07df1-4401-4792-b549-7c4cc5632323", "metadata": {}, "source": [ " 
\n", "# 5. Generate text" ] }, { "cell_type": "code", "execution_count": 19, "id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5", "metadata": {}, "outputs": [], "source": [ "def text_to_token_ids(text, tokenizer):\n", " encoded = tokenizer.encode(text)\n", " encoded_tensor = torch.tensor(encoded).unsqueeze(0) # add batch dimension\n", " return encoded_tensor\n", "\n", "\n", "def token_ids_to_text(token_ids, tokenizer):\n", " flat = token_ids.squeeze(0) # remove batch dimension\n", " return tokenizer.decode(flat.tolist())\n", "\n", "\n", "def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=None, eos_id=None):\n", "\n", " # For-loop is the same as before: Get logits, and only focus on last time step\n", " for _ in range(max_new_tokens):\n", " idx_cond = idx[:, -context_size:]\n", " with torch.no_grad():\n", " logits = model(idx_cond)\n", " logits = logits[:, -1, :]\n", "\n", " # New: Filter logits with top_k sampling\n", " if top_k is not None:\n", " # Keep only top_k values\n", " top_logits, _ = torch.topk(logits, top_k)\n", " min_val = top_logits[:, -1]\n", " logits = torch.where(logits < min_val, torch.tensor(float('-inf')).to(logits.device), logits)\n", "\n", " # New: Apply temperature scaling\n", " if temperature > 0.0:\n", " logits = logits / temperature\n", "\n", " # Apply softmax to get probabilities\n", " probs = torch.softmax(logits, dim=-1) # (batch_size, vocab_size)\n", "\n", " # Sample from the distribution\n", " idx_next = torch.multinomial(probs, num_samples=1) # (batch_size, 1)\n", "\n", " # Otherwise same as before: get idx of the vocab entry with the highest logits value\n", " else:\n", " idx_next = torch.argmax(logits, dim=-1, keepdim=True) # (batch_size, 1)\n", "\n", " if idx_next == eos_id: # Stop generating early if end-of-sequence token is encountered and eos_id is specified\n", " break\n", "\n", " # Same as before: append sampled index to the running sequence\n", " idx = torch.cat((idx, idx_next), dim=1) # (batch_size, 
num_tokens+1)\n", "\n", " return idx" ] }, { "cell_type": "code", "execution_count": 20, "id": "1c7a04fa-6aac-416b-8f63-f1e19227633d", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Output text:\n", " Llamas are herbivores, which means they primarily eat plants and plant-based foods. Their diet typically consists of:\n", "\n", "1. Grasses: Llamas love to graze on various types of grasses, including tall grasses and short grasses.\n", "2. Hay: Llamas also eat hay, which is a dry, compressed form of grass or other plants.\n", "3. Alfalfa: Alfalfa is a legume that is commonly fed to llamas, as it is high in protein and fiber.\n", "4. Other plants: Llamas will also eat other plants, such as clover, wild grasses, and shrubs.\n", "\n", "It's worth noting that llamas are adapted to high altitudes and\n" ] } ], "source": [ "PROMPT = \"What do llamas eat?\"\n", "\n", "torch.manual_seed(123)\n", "\n", "token_ids = generate(\n", " model=model,\n", " idx=text_to_token_ids(PROMPT, chat_tokenizer).to(device),\n", " max_new_tokens=150,\n", " context_size=LLAMA32_CONFIG[\"context_length\"],\n", " top_k=1,\n", " temperature=0.\n", ")\n", "\n", "output_text = token_ids_to_text(token_ids, tokenizer)\n", "\n", "\n", "def clean_text(text, header_end=\"assistant<|end_header_id|>\\n\\n\"):\n", " # Find the index of the first occurrence of header_end (by default, the assistant header)\n", " index = text.find(header_end)\n", "\n", " if index != -1:\n", " # Return the substring that follows header_end\n", " return text[index + len(header_end):].strip() # Strip removes leading/trailing whitespace\n", " else:\n", " # If header_end is not found, return the original text\n", " return text\n", "\n", "print(\"Output text:\\n\", clean_text(output_text))" ] }, { "cell_type": "markdown", "id": "549324d6-5c71-4147-ae21-2e67675faa3d", "metadata": {}, "source": [ " \n", "# What's next?"
] }, { "cell_type": "markdown", "id": "e6edaaae-2de1-406c-8ffa-897cdfa3808c", "metadata": {}, "source": [ "- The notebook was kept purposefully minimal; if you are interested in additional explanation about the individual components, check out the following two companion notebooks:\n", "\n", "\n", "\n", " 1. [Converting a From-Scratch GPT Architecture to Llama 2](converting-gpt-to-llama2.ipynb)\n", " 2. [Converting Llama 2 to Llama 3.2 From Scratch](converting-llama2-to-llama3.ipynb)\n", " \n", "- For those interested in a comprehensive guide on building a large language model from scratch and gaining a deeper understanding of its mechanics, you might like my [Build a Large Language Model (From Scratch)](http://mng.bz/orYv)\n", "\n", "" ] }, { "cell_type": "code", "execution_count": null, "id": "bf864c28-2ce1-44bf-84e4-c0671f494d62", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.4" } }, "nbformat": 4, "nbformat_minor": 5 }