From c4cde1c21bc38603cd11b06d55884a82555520b1 Mon Sep 17 00:00:00 2001 From: Sebastian Raschka Date: Thu, 12 Jun 2025 11:08:02 -0500 Subject: [PATCH] Reduce Llama 3 RoPE memory requirements (#658) * Llama3 from scratch improvements * Fix Llama 3 expensive RoPE memory issue * updates * update package * benchmark * remove unused rescale_theta --- .gitignore | 3 + ch05/07_gpt_to_llama/README.md | 14 +- .../converting-llama2-to-llama3.ipynb | 492 ++--- .../standalone-llama32-mem-opt.ipynb | 1881 ----------------- ch05/07_gpt_to_llama/standalone-llama32.ipynb | 411 ++-- pkg/llms_from_scratch/README.md | 11 +- pkg/llms_from_scratch/llama3.py | 148 +- pkg/llms_from_scratch/tests/test_llama3.py | 20 +- pyproject.toml | 2 +- 9 files changed, 405 insertions(+), 2577 deletions(-) delete mode 100644 ch05/07_gpt_to_llama/standalone-llama32-mem-opt.ipynb diff --git a/.gitignore b/.gitignore index 04a0886..f23dddb 100644 --- a/.gitignore +++ b/.gitignore @@ -51,6 +51,9 @@ ch05/07_gpt_to_llama/Llama-3.2-3B-Instruct ch05/10_llm-training-speed/middlemarch.txt ch05/10_llm-training-speed/loss.pdf ch05/10_llm-training-speed/model.pth +ch05/07_gpt_to_llama/Untitled.ipynb +ch05/07_gpt_to_llama/llama3.2-1B-instruct.pth +ch05/07_gpt_to_llama/tokenizer.model ch06/01_main-chapter-code/gpt2 ch06/02_bonus_additional-experiments/gpt2 diff --git a/ch05/07_gpt_to_llama/README.md b/ch05/07_gpt_to_llama/README.md index 8df9a45..233bfca 100644 --- a/ch05/07_gpt_to_llama/README.md +++ b/ch05/07_gpt_to_llama/README.md @@ -40,8 +40,6 @@ MODEL_FILE = "llama3.2-1B-instruct.pth" Basic text generation settings that can be defined by the user. Note that the recommended 8192-token context size requires approximately 3 GB of VRAM for the text generation example. ```python -MODEL_CONTEXT_LENGTH = 8192 # Supports up to 131_072 - # Text generation settings if "instruct" in MODEL_FILE: PROMPT = "What do llamas eat?" 
@@ -82,8 +80,6 @@ elif "3B" in MODEL_FILE: else: raise ValueError("Incorrect model file name") -LLAMA32_CONFIG["context_length"] = MODEL_CONTEXT_LENGTH - model = Llama3Model(LLAMA32_CONFIG) model.load_state_dict(torch.load(MODEL_FILE, weights_only=True, map_location="cpu")) @@ -125,7 +121,7 @@ Lastly, we can generate text via the following code: ```python import time -from ch05 import ( +from llms_from_scratch.ch05 import ( generate, text_to_token_ids, token_ids_to_text @@ -192,8 +188,8 @@ The following table shows a performance comparison on an A100: | | Tokens/sec | Memory | | --------------- | ---------- | ------- | -| Llama3Model | 50 | 2.91 GB | -| Llama3ModelFast | 58 | 2.85 GB | +| Llama3Model | 42 | 2.91 GB | +| Llama3ModelFast | 54 | 2.91 GB |   #### Pro tip 2: speed up inference with compilation @@ -218,5 +214,5 @@ The following table shows a performance comparison on an A100 for consequent `ge | | Tokens/sec | Memory | | --------------- | ---------- | ------- | -| Llama3Model | 156 | 3.12 GB | -| Llama3ModelFast | 159 | 2.84 GB | +| Llama3Model | 170 | 3.12 GB | +| Llama3ModelFast | 177 | 3.61 GB | diff --git a/ch05/07_gpt_to_llama/converting-llama2-to-llama3.ipynb b/ch05/07_gpt_to_llama/converting-llama2-to-llama3.ipynb index 78d0183..908a034 100644 --- a/ch05/07_gpt_to_llama/converting-llama2-to-llama3.ipynb +++ b/ch05/07_gpt_to_llama/converting-llama2-to-llama3.ipynb @@ -95,9 +95,9 @@ "output_type": "stream", "text": [ "blobfile version: 3.0.0\n", - "huggingface_hub version: 0.24.7\n", - "tiktoken version: 0.8.0\n", - "torch version: 2.4.1+cu121\n" + "huggingface_hub version: 0.30.1\n", + "tiktoken version: 0.9.0\n", + "torch version: 2.6.0\n" ] } ], @@ -435,7 +435,7 @@ "id": "842aa71a-4659-424e-8830-392bd6ae86af", "metadata": {}, "source": [ - "- In addition, we also introduce a `SharedBuffers` class that will allow us to reuse the `mask`, `cos`, and `sin` tensors in the transformer blocks to improve efficiency (this will be crucial when working with models such as Llama 3.1 and 3.2 later, which support up to 131k input tokens)" + "- **We also redesign the attention class a bit so it receives the mask through its forward method instead of storing and accessing it as `self.mask`. This lets us build the mask on the fly to reduce memory usage. 
To foreshadow why: Llama 3.1 can handle sequences of up to 128 k tokens, and precomputing a 128 k × 128 k causal mask would be extremely memory‑intensive, so we avoid it unless absolutely necessary.**" ] }, { @@ -450,27 +450,6 @@ "import torch.nn as nn\n", "\n", "\n", - "############################# NEW #############################\n", - "class SharedBuffers:\n", - " _buffers = {}\n", - "\n", - " @staticmethod\n", - " def get_buffers(context_length, head_dim, rope_base, freq_config, dtype=torch.float32):\n", - " key = (context_length, head_dim, rope_base, tuple(freq_config.values()) if freq_config else freq_config, dtype)\n", - "\n", - " if key not in SharedBuffers._buffers:\n", - " # Create or fetch the buffers\n", - " mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)\n", - " cos, sin = precompute_rope_params(head_dim, rope_base, context_length, freq_config)\n", - " if dtype is not None:\n", - " cos = cos.to(dtype)\n", - " sin = sin.to(dtype)\n", - " SharedBuffers._buffers[key] = (mask, cos, sin)\n", - "\n", - " return SharedBuffers._buffers[key]\n", - "############################# NEW #############################\n", - "\n", - "\n", "class GroupedQueryAttention(nn.Module):\n", " def __init__(\n", " self, d_in, d_out, context_length, num_heads,\n", @@ -499,16 +478,12 @@ " self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype)\n", " self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype)\n", "\n", - " ############################# NEW #############################\n", - " # Fetch buffers using SharedBuffers\n", - " mask, cos, sin = SharedBuffers.get_buffers(context_length, self.head_dim, rope_base, rope_config, dtype)\n", - " ############################# NEW #############################\n", - " \n", - " self.register_buffer(\"mask\", mask)\n", - " self.register_buffer(\"cos\", cos)\n", - " self.register_buffer(\"sin\", sin)\n", "\n", - " def forward(self, x):\n", + " def forward(self, x, mask=None, cos=None, sin=None):\n", + " ##################### NEW #####################\n", + " # The forward method now accepts `mask` instead of accessing it via self.mask.\n", + " # Also, we now have cos and sin as input for RoPE\n", + " ################################################ \n", " b, num_tokens, d_in = x.shape\n", "\n", " queries = self.W_query(x) # Shape: (b, num_tokens, d_out)\n", @@ -530,9 +505,12 @@ " values = values.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim)\n", " queries = queries.transpose(1, 2) # Shape: (b, num_query_groups, num_tokens, head_dim)\n", "\n", + " ##################### NEW #####################\n", " # Apply RoPE\n", - " keys = compute_rope(keys, self.cos, self.sin)\n", - " queries = compute_rope(queries, self.cos, self.sin)\n", + " if cos is not None:\n", + " keys = compute_rope(keys, cos, sin)\n", + " queries = compute_rope(queries, cos, sin)\n", + " ################################################\n", "\n", " ##################### NEW #####################\n", " # Expand keys and values to match the number of heads\n", @@ -552,11 +530,14 @@ " # Shape: (b, num_heads, num_tokens, num_tokens)\n", " attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head\n", "\n", - " # Original mask truncated to the number of tokens and converted to boolean\n", - " mask_bool = self.mask.bool()[:num_tokens, :num_tokens]\n", - "\n", + " ##################### NEW #####################\n", + " # Create mask on the fly\n", + " if mask is None:\n", + " mask = torch.triu(torch.ones(num_tokens, num_tokens, 
device=x.device, dtype=torch.bool), diagonal=1)\n", + " ################################################\n", + " \n", " # Use the mask to fill attention scores\n", - " attn_scores.masked_fill_(mask_bool, -torch.inf)\n", + " attn_scores.masked_fill_(mask, -torch.inf)\n", "\n", " attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n", " assert keys.shape[-1] == self.head_dim\n", @@ -578,7 +559,7 @@ "id": "roAXSwJs9hR8" }, "source": [ - "- To illustrate the parameter savings, consider the following multi-head attention example from the GPT and Llama 2 code:" + "- To illustrate the parameter savings in GQA over MHA, consider the following multi-head attention example from the GPT and Llama 2 code:" ] }, { @@ -753,7 +734,8 @@ }, "source": [ "- Next, we update the `TransformerBlock`\n", - "- Here, we simply swap `MultiHeadAttention` with `GroupedQueryAttention` and add the new RoPE settings" + "- Here, we simply swap `MultiHeadAttention` with `GroupedQueryAttention` and add the new RoPE settings\n", + "- In addition, we also modify the `forward` method so that it receives `mask`, `cos`, and `sin`; since the values for those are the same for each transformer block, we only have to compute them once and then can reuse them" ] }, { @@ -782,11 +764,15 @@ " self.norm1 = RMSNorm(cfg[\"emb_dim\"], eps=1e-5)\n", " self.norm2 = RMSNorm(cfg[\"emb_dim\"], eps=1e-5)\n", "\n", - " def forward(self, x):\n", + " def forward(self, x, mask=None, cos=None, sin=None):\n", + " ##################### NEW #####################\n", + " # The forward method now accepts `mask` instead of accessing it via self.mask.\n", + " # Also, we now have cos and sin as input for RoPE\n", + " ################################################\n", " # Shortcut connection for attention block\n", " shortcut = x\n", " x = self.norm1(x)\n", - " x = self.att(x.to(torch.bfloat16)) # Shape [batch_size, num_tokens, emb_size]\n", + " x = self.att(x.to(torch.bfloat16), mask, cos, sin) # Shape [batch_size, num_tokens, emb_size]\n", " x = x + shortcut # Add the original input back\n", "\n", " # Shortcut connection for feed-forward block\n", @@ -816,7 +802,8 @@ "id": "M_tLAq_r_llN" }, "source": [ - "- When setting up the model class, we fortunately don't have to do much; we just update the name to `Llama3Model`" + "- When setting up the model class, we technically don't have to do much; we just update the name to `Llama3Model`\n", + "- However, since we now pass the `mask`, `cos`, and `sin` to the transformer blocks, we also have to add them here" ] }, { @@ -840,12 +827,33 @@ " self.final_norm = RMSNorm(cfg[\"emb_dim\"], eps=1e-5)\n", " self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n", "\n", + " #################### NEW #####################\n", + " cos, sin = precompute_rope_params(\n", + " head_dim=cfg[\"emb_dim\"] // cfg[\"n_heads\"],\n", + " theta_base=cfg[\"rope_base\"],\n", + " context_length=cfg[\"context_length\"],\n", + " freq_config=cfg[\"rope_freq\"]\n", + " )\n", + " \n", + " self.register_buffer(\"cos\", cos, persistent=False)\n", + " self.register_buffer(\"sin\", sin, persistent=False)\n", + " ##############################################\n", + "\n", + " self.cfg = cfg\n", + "\n", " def forward(self, in_idx):\n", " tok_embeds = self.tok_emb(in_idx)\n", " x = tok_embeds\n", - " x = self.trf_blocks(x)\n", + "\n", + " #################### NEW #####################\n", + " num_tokens = x.shape[1]\n", + " mask = torch.triu(torch.ones(num_tokens, num_tokens, 
device=x.device, dtype=torch.bool), diagonal=1)\n", + " ##############################################\n", + " \n", + " for block in self.trf_blocks:\n", + " x = block(x, mask, self.cos, self.sin)\n", " x = self.final_norm(x)\n", - " logits = self.out_head(x.to(torch.bfloat16))\n", + " logits = self.out_head(x.to(self.cfg[\"dtype\"]))\n", " return logits" ] }, @@ -936,33 +944,12 @@ "model = Llama3Model(LLAMA3_CONFIG_8B)" ] }, - { - "cell_type": "markdown", - "id": "edea6334-d1fc-427d-9cf2-4af963ff4bfc", - "metadata": {}, - "source": [ - "- The following is expected to print True to confirm buffers are reused instead of being (wastefully) recreated:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ee9625cc-9afa-4b11-8aab-d536fd170761", - "metadata": {}, - "outputs": [], - "source": [ - "# Check buffers\n", - "print(model.trf_blocks[0].att.mask is model.trf_blocks[-1].att.mask)\n", - "print(model.trf_blocks[0].att.cos is model.trf_blocks[-1].att.cos)\n", - "print(model.trf_blocks[0].att.sin is model.trf_blocks[-1].att.sin) " - ] - }, { "cell_type": "markdown", "id": "8056a521-91a6-440f-8473-591409c3177b", "metadata": {}, "source": [ - "- Let's now also compute the number of trainable parameters:" + "- Let's now compute the number of trainable parameters:" ] }, { @@ -1017,8 +1004,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "float32 (PyTorch default): 68.08 GB\n", - "bfloat16: 34.04 GB\n" + "float32 (PyTorch default): 59.84 GB\n", + "bfloat16: 29.92 GB\n" ] } ], @@ -1121,43 +1108,47 @@ "\n", "\n", "class Tokenizer:\n", + " \"\"\"Thin wrapper around tiktoken that keeps track of Llama-3 special IDs.\"\"\"\n", " def __init__(self, model_path):\n", - " assert os.path.isfile(model_path), f\"Model file {model_path} not found\"\n", - " mergeable_ranks = load_tiktoken_bpe(model_path)\n", + " if not os.path.isfile(model_path):\n", + " raise FileNotFoundError(model_path)\n", "\n", - " self.special_tokens = {\n", + " mergeable = load_tiktoken_bpe(model_path)\n", + "\n", + " # hard-coded from Meta's tokenizer.json\n", + " self.special = {\n", " \"<|begin_of_text|>\": 128000,\n", " \"<|end_of_text|>\": 128001,\n", " \"<|start_header_id|>\": 128006,\n", " \"<|end_header_id|>\": 128007,\n", " \"<|eot_id|>\": 128009,\n", " }\n", - " self.special_tokens.update({\n", - " f\"<|reserved_{i}|>\": 128002 + i for i in range(256) if (128002 + i) not in self.special_tokens.values()\n", - " })\n", + " self.special.update({f\"<|reserved_{i}|>\": 128002 + i\n", + " for i in range(256)\n", + " if 128002 + i not in self.special.values()})\n", "\n", " self.model = tiktoken.Encoding(\n", " name=Path(model_path).name,\n", - " pat_str=r\"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+\",\n", - " mergeable_ranks=mergeable_ranks,\n", - " special_tokens=self.special_tokens\n", + " pat_str=r\"(?i:'s|'t|'re|'ve|'m|'ll|'d)\"\n", + " r\"|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+\"\n", + " r\"|\\p{N}{1,3}\"\n", + " r\"| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*\"\n", + " r\"|\\s*[\\r\\n]+\"\n", + " r\"|\\s+(?!\\S)\"\n", + " r\"|\\s+\",\n", + " mergeable_ranks=mergeable,\n", + " special_tokens=self.special,\n", " )\n", "\n", - "\n", - " def encode(self, text, bos=False, eos=False, allowed_special=set(), disallowed_special=()):\n", - " if bos:\n", - " tokens = [self.special_tokens[\"<|begin_of_text|>\"]]\n", - " else:\n", - " tokens = []\n", - "\n", - " tokens += self.model.encode(text, allowed_special=allowed_special, 
disallowed_special=disallowed_special)\n", - "\n", + " def encode(self, text, bos=False, eos=False):\n", + " ids = ([self.special[\"<|begin_of_text|>\"]] if bos else []) \\\n", + " + self.model.encode(text)\n", " if eos:\n", - " tokens.append(self.special_tokens[\"<|end_of_text|>\"])\n", - " return tokens\n", + " ids.append(self.special[\"<|end_of_text|>\"])\n", + " return ids\n", "\n", - " def decode(self, tokens):\n", - " return self.model.decode(tokens)" + " def decode(self, ids):\n", + " return self.model.decode(ids)" ] }, { @@ -1202,13 +1193,11 @@ }, "outputs": [ { - "name": "stdout", + "name": "stderr", "output_type": "stream", "text": [ - "The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.\n", - "Token is valid (permission: read).\n", - "Your token has been saved to /root/.cache/huggingface/token\n", - "Login successful\n" + "/Users/sebastian/Developer/LLMs-from-scratch/.venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], @@ -1309,7 +1298,8 @@ "base_uri": "https://localhost:8080/" }, "id": "e0a2b5cd-6cba-4d72-b8ff-04d8315d483e", - "outputId": "990d7b74-cb35-476b-d8bd-d544006e00f4" + "outputId": "990d7b74-cb35-476b-d8bd-d544006e00f4", + "scrolled": true }, "outputs": [ { @@ -1318,7 +1308,9 @@ "text": [ "Output text:\n", " Every effort_dead aeros Ingredients başında.extensionégor clangmissions güc như submodule.and report官方%,.Reader(\",\");\n", - "ामल ندار Parliamentary !!! HigginsDynamicZhgmt writeln Globalsletion 사진------\n" + "ामल ندار Parliamentary !!! 
HigginsDynamicZhamincus_beam cyc......\n", + "\n", + " haciendo\n" ] } ], @@ -1437,22 +1429,7 @@ "id": "5fa9c06c-7a53-4b4d-9ce4-acc027322ee4", "outputId": "c05118ce-9f81-41c8-a1f2-72caa932ae86" }, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "245443330e4d40c887a5649cc1663e98", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "model-00001-of-00004.safetensors: 0%| | 0.00/4.98G [00:00\"])\n", - " tokens.extend(self.tokenizer.encode(message[\"role\"], bos=False, eos=False))\n", - " tokens.append(self.tokenizer.special_tokens[\"<|end_header_id|>\"])\n", - " tokens.extend(self.tokenizer.encode(\"\\n\\n\", bos=False, eos=False))\n", - " return tokens\n", + " def __init__(self, tokenizer: Tokenizer, *,\n", + " default_system=\"You are a helpful assistant.\"):\n", + " self.tok = tokenizer\n", + " self.default_system = default_system\n", "\n", - " def encode(self, text):\n", - " message = {\n", - " \"role\": \"user\",\n", - " \"content\": text\n", - " }\n", - "\n", - " tokens = self.encode_header(message)\n", - " tokens.extend(\n", - " self.tokenizer.encode(message[\"content\"].strip(), bos=False, eos=False)\n", + " def _header(self, role):\n", + " \"\"\"Encode <|start_header_id|>role<|end_header_id|>\\n\\n\"\"\"\n", + " return (\n", + " [self.tok.special[\"<|start_header_id|>\"]]\n", + " + self.tok.encode(role)\n", + " + [self.tok.special[\"<|end_header_id|>\"]]\n", + " + self.tok.encode(\"\\n\\n\")\n", " )\n", - " tokens.append(self.tokenizer.special_tokens[\"<|eot_id|>\"])\n", - " return tokens\n", "\n", - " def decode(self, token_ids):\n", - " return self.tokenizer.decode(token_ids)\n", + " def encode(self, user_message, system_message=None):\n", + " sys_msg = system_message if system_message is not None else self.default_system\n", "\n", + " ids = [self.tok.special[\"<|begin_of_text|>\"]]\n", "\n", - "chat_tokenizer = ChatFormat(tokenizer)" + " # system\n", + " ids += self._header(\"system\")\n", + " ids += self.tok.encode(sys_msg)\n", + " ids += [self.tok.special[\"<|eot_id|>\"]]\n", + "\n", + " # user\n", + " ids += self._header(\"user\")\n", + " ids += self.tok.encode(user_message)\n", + " ids += [self.tok.special[\"<|eot_id|>\"]]\n", + "\n", + " # assistant header (no content yet)\n", + " ids += self._header(\"assistant\")\n", + "\n", + " return ids" ] }, { @@ -1918,11 +1843,14 @@ "name": "stdout", "output_type": "stream", "text": [ - "[128006, 882, 128007, 271, 9906, 4435, 0, 128009]\n" + "[128000, 128006, 9125, 128007, 271, 2675, 527, 264, 11190, 18328, 13, 128009, 128006, 882, 128007, 271, 9906, 4435, 0, 128009, 128006, 78191, 128007, 271]\n" ] } ], "source": [ + "tokenizer = Tokenizer(tokenizer_file_path)\n", + "chat_tokenizer = ChatFormat(tokenizer)\n", + "\n", "token_ids = chat_tokenizer.encode(\"Hello World!\")\n", "print(token_ids)" ] @@ -1943,7 +1871,7 @@ { "data": { "text/plain": [ - "'<|start_header_id|>user<|end_header_id|>\\n\\nHello World!<|eot_id|>'" + "'<|begin_of_text|><|start_header_id|>system<|end_header_id|>\\n\\nYou are a helpful assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>\\n\\nHello World!<|eot_id|><|start_header_id|>assistant<|end_header_id|>\\n\\n'" ] }, "execution_count": 35, @@ -1982,12 +1910,13 @@ "output_type": "stream", "text": [ "Output text:\n", - " Llamas are herbivores, which means they primarily eat plants and plant-based foods. Here are some of the things llamas like to eat:\n", + " Llamas are herbivores, which means they primarily eat plants and plant-based foods. 
Their diet typically consists of:\n", "\n", - "1. Grass: Llamas love to graze on grass, especially in the spring and summer months.\n", - "2. Hay: Hay is a staple in a llama's diet. They like to eat timothy hay, alfalfa hay, and other types of hay.\n", - "3. Grains: Llamas may also be fed grains like oats, barley, and corn. However, grains should not make up more than 10-15% of a llama's diet.\n", - "4. Fruits and vegetables: Llamas may enjoy fruits and vegetables as treats, such as\n" + "1. Grasses: Llamas love to graze on grasses, including tall grasses, short grasses, and even weeds.\n", + "2. Hay: Hay is a staple in a llama's diet. They enjoy a variety of hays, such as timothy hay, alfalfa hay, and oat hay.\n", + "3. Grains: Llamas may be fed grains like oats, corn, and barley as a supplement to their diet.\n", + "4. Fruits and vegetables: Llamas enjoy fruits and vegetables like apples, carrots, and sweet potatoes as treats or additions to their diet.\n", + "5. Minerals:\n" ] } ], @@ -2088,49 +2017,6 @@ "}" ] }, - { - "cell_type": "markdown", - "id": "d81ee464-c112-43b0-9ee8-70df6ac942d0", - "metadata": {}, - "source": [ - "- Reduce the context length so the model would work fine on a MacBook Air (if you have more RAM, feel free to comment out the lines below):" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "a55a8769-1a03-4265-8fd0-15f1c423da53", - "metadata": { - "id": "a8bc2370-39d2-4bfe-b4c1-6bdd75fe101c" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "New RoPE theta: 31250.0\n" - ] - } - ], - "source": [ - "old_context_length = LLAMA31_CONFIG_8B[\"context_length\"]\n", - "LLAMA31_CONFIG_8B[\"context_length\"] = 8192\n", - "\n", - "\n", - "def rescale_theta(theta_old, context_length_old, context_length_new):\n", - " scaling_factor = context_length_new / context_length_old\n", - " theta_new = theta_old * scaling_factor\n", - " return theta_new\n", - "\n", - "LLAMA31_CONFIG_8B[\"rope_base\"] = rescale_theta(\n", - " LLAMA31_CONFIG_8B[\"rope_base\"],\n", - " old_context_length,\n", - " LLAMA31_CONFIG_8B[\"context_length\"]\n", - ")\n", - "\n", - "print(\"New RoPE theta:\", LLAMA31_CONFIG_8B[\"rope_base\"])" - ] - }, { "cell_type": "markdown", "id": "xa3bpMDtTdBs", @@ -2277,64 +2163,7 @@ "id": "u4J7IxOvOyPM", "outputId": "925348d7-fc69-4d1b-90f1-7029426bcfcf" }, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "eabfde3ef38b436ea750e6fb50a02b5c", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "model-00001-of-00004.safetensors: 0%| | 0.00/4.98G [00:00\n", - "\n", - "\n", - "\n", - "Supplementary code for the Build a Large Language Model From Scratch book by Sebastian Raschka
\n", - "
Code repository: https://github.com/rasbt/LLMs-from-scratch\n", - "
\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "" - ] - }, - { - "cell_type": "markdown", - "id": "efde77f2-6af3-4781-8597-89ecd3f41a52", - "metadata": { - "id": "efde77f2-6af3-4781-8597-89ecd3f41a52" - }, - "source": [ - "# Llama 3.2 From Scratch (A Standalone Notebook)" - ] - }, - { - "cell_type": "markdown", - "id": "55cdef4d-de59-4a65-89f9-fa2a8ef3471d", - "metadata": { - "id": "55cdef4d-de59-4a65-89f9-fa2a8ef3471d" - }, - "source": [ - "**Note: This notebook is an alternative to the [standalone-llama32.ipynb](standalone-llama32.ipynb) notebook but optimized for memory efficiency by using a global mask, cos, and sin. On an A100, based on a 8192 context length, this only uses 3.1 GB (vs 7.07 GB) VRAM.** \n", - "\n", - "\n", - "- This notebook is purposefully minimal and focuses on the code to implement the Llama 3.2 1B and 3B LLMs\n", - "- For a step-by-step guide that explains the individual components and the relationship between GPT, Llama 2, and Llama 3, please see the following companion notebooks:\n", - " - [Converting a From-Scratch GPT Architecture to Llama 2](converting-gpt-to-llama2.ipynb)\n", - " - [Converting Llama 2 to Llama 3.2 From Scratch](converting-llama2-to-llama3.ipynb)\n", - " \n", - "\n", - "\n", - " \n", - " \n", - "- About the code:\n", - " - all code is my own code, mapping the Llama 3 architecture onto the model code implemented in my [Build A Large Language Model (From Scratch)](http://mng.bz/orYv) book; the code is released under a permissive open-source Apache 2.0 license (see [LICENSE.txt](https://github.com/rasbt/LLMs-from-scratch/blob/main/LICENSE.txt))\n", - " - the tokenizer code is inspired by the original [Llama 3 tokenizer code](https://github.com/meta-llama/llama3/blob/main/llama/tokenizer.py), which Meta AI used to extend the Tiktoken GPT-4 tokenizer\n", - " - the RoPE rescaling section is inspired by the [_compute_llama3_parameters function](https://github.com/huggingface/transformers/blob/5c1027bf09717f664b579e01cbb8ec3ef5aeb140/src/transformers/modeling_rope_utils.py#L329-L347) in the `transformers` library" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "7c201adb-747e-437b-9a62-442802941e01", - "metadata": {}, - "outputs": [], - "source": [ - "# pip install -r https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/refs/heads/main/ch05/07_gpt_to_llama/requirements-extra.txt" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "dd1b65a8-4301-444a-bd7c-a6f2bd1df9df", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "dd1b65a8-4301-444a-bd7c-a6f2bd1df9df", - "outputId": "4f762354-e0a3-4cc2-e5d4-e61a227a202c" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "blobfile version: 3.0.0\n", - "huggingface_hub version: 0.29.3\n", - "tiktoken version: 0.9.0\n", - "torch version: 2.6.0\n" - ] - } - ], - "source": [ - "from importlib.metadata import version\n", - "\n", - "pkgs = [\n", - " \"blobfile\", # to download pretrained weights\n", - " \"huggingface_hub\", # to download pretrained weights\n", - " \"tiktoken\", # to implement the tokenizer\n", - " \"torch\", # to implement the model\n", - "]\n", - "for p in pkgs:\n", - " print(f\"{p} version: {version(p)}\")" - ] - }, - { - "cell_type": "markdown", - "id": "653410a6-dd2b-4eb2-a722-23d9782e726d", - "metadata": { - "id": "653410a6-dd2b-4eb2-a722-23d9782e726d" - }, - "source": [ - " \n", - "# 1. 
Architecture code" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "82076c21-9331-4dcd-b017-42b046cf1a60", - "metadata": { - "id": "82076c21-9331-4dcd-b017-42b046cf1a60" - }, - "outputs": [], - "source": [ - "import torch\n", - "import torch.nn as nn\n", - "\n", - "\n", - "class FeedForward(nn.Module):\n", - " def __init__(self, cfg):\n", - " super().__init__()\n", - " self.fc1 = nn.Linear(cfg[\"emb_dim\"], cfg[\"hidden_dim\"], dtype=cfg[\"dtype\"], bias=False)\n", - " self.fc2 = nn.Linear(cfg[\"emb_dim\"], cfg[\"hidden_dim\"], dtype=cfg[\"dtype\"], bias=False)\n", - " self.fc3 = nn.Linear(cfg[\"hidden_dim\"], cfg[\"emb_dim\"], dtype=cfg[\"dtype\"], bias=False)\n", - "\n", - " def forward(self, x):\n", - " x_fc1 = self.fc1(x)\n", - " x_fc2 = self.fc2(x)\n", - " x = nn.functional.silu(x_fc1) * x_fc2\n", - " return self.fc3(x)" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "4b9a346f-5826-4083-9162-abd56afc03f0", - "metadata": { - "id": "4b9a346f-5826-4083-9162-abd56afc03f0" - }, - "outputs": [], - "source": [ - "def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, freq_config=None, dtype=torch.float32):\n", - " assert head_dim % 2 == 0, \"Embedding dimension must be even\"\n", - "\n", - " # Compute the inverse frequencies\n", - " inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2, dtype=dtype)[: (head_dim // 2)].float() / head_dim))\n", - "\n", - " # Frequency adjustments\n", - " if freq_config is not None:\n", - " low_freq_wavelen = freq_config[\"original_context_length\"] / freq_config[\"low_freq_factor\"]\n", - " high_freq_wavelen = freq_config[\"original_context_length\"] / freq_config[\"high_freq_factor\"]\n", - "\n", - " wavelen = 2 * torch.pi / inv_freq\n", - "\n", - " inv_freq_llama = torch.where(\n", - " wavelen > low_freq_wavelen, inv_freq / freq_config[\"factor\"], inv_freq\n", - " )\n", - "\n", - " smooth_factor = (freq_config[\"original_context_length\"] / wavelen - freq_config[\"low_freq_factor\"]) / (\n", - " freq_config[\"high_freq_factor\"] - freq_config[\"low_freq_factor\"]\n", - " )\n", - "\n", - " smoothed_inv_freq = (\n", - " (1 - smooth_factor) * (inv_freq / freq_config[\"factor\"]) + smooth_factor * inv_freq\n", - " )\n", - "\n", - " is_medium_freq = (wavelen <= low_freq_wavelen) & (wavelen >= high_freq_wavelen)\n", - " inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)\n", - " inv_freq = inv_freq_llama\n", - "\n", - " # Generate position indices\n", - " positions = torch.arange(context_length, dtype=dtype)\n", - "\n", - " # Compute the angles\n", - " angles = positions[:, None] * inv_freq[None, :] # Shape: (context_length, head_dim // 2)\n", - "\n", - " # Expand angles to match the head_dim\n", - " angles = torch.cat([angles, angles], dim=1) # Shape: (context_length, head_dim)\n", - "\n", - " # Precompute sine and cosine\n", - " cos = torch.cos(angles)\n", - " sin = torch.sin(angles)\n", - "\n", - " return cos, sin\n", - "\n", - "\n", - "def apply_rope(x, cos, sin):\n", - " # x: (batch_size, num_heads, seq_len, head_dim)\n", - " batch_size, num_heads, seq_len, head_dim = x.shape\n", - " assert head_dim % 2 == 0, \"Head dimension must be even\"\n", - "\n", - " # Split x into first half and second half\n", - " x1 = x[..., : head_dim // 2] # First half\n", - " x2 = x[..., head_dim // 2 :] # Second half\n", - "\n", - " # Adjust sin and cos shapes\n", - " cos = cos[:seq_len, :].unsqueeze(0).unsqueeze(0) # Shape: (1, 1, seq_len, head_dim)\n", - " sin = sin[:seq_len, 
:].unsqueeze(0).unsqueeze(0)\n", - "\n", - " # Apply the rotary transformation\n", - " rotated = torch.cat((-x2, x1), dim=-1)\n", - " x_rotated = (x * cos) + (rotated * sin)\n", - "\n", - " # It's ok to use lower-precision after applying cos and sin rotation\n", - " return x_rotated.to(dtype=x.dtype)\n", - "\n", - "\n", - "def rescale_theta(theta_old, context_length_old, context_length_new):\n", - " scaling_factor = context_length_new / context_length_old\n", - " theta_new = theta_old * scaling_factor\n", - " return theta_new" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "e8169ab5-f976-4222-a2e1-eb1cabf267cb", - "metadata": { - "id": "e8169ab5-f976-4222-a2e1-eb1cabf267cb" - }, - "outputs": [], - "source": [ - "class GroupedQueryAttention(nn.Module):\n", - " def __init__(\n", - " self, d_in, d_out, num_heads,\n", - " num_kv_groups,\n", - " dtype=None\n", - " ):\n", - " super().__init__()\n", - " assert d_out % num_heads == 0, \"d_out must be divisible by num_heads\"\n", - " assert num_heads % num_kv_groups == 0, \"num_heads must be divisible by num_kv_groups\"\n", - "\n", - " self.d_out = d_out\n", - " self.num_heads = num_heads\n", - " self.head_dim = d_out // num_heads\n", - "\n", - " self.W_key = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype)\n", - " self.W_value = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype)\n", - " self.num_kv_groups = num_kv_groups\n", - " self.group_size = num_heads // num_kv_groups\n", - "\n", - " self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype)\n", - " self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype)\n", - "\n", - " def forward(self, x, mask, cos, sin):\n", - " b, num_tokens, d_in = x.shape\n", - "\n", - " queries = self.W_query(x) # Shape: (b, num_tokens, d_out)\n", - " keys = self.W_key(x) # Shape: (b, num_tokens, num_kv_groups * head_dim)\n", - " values = self.W_value(x) # Shape: (b, num_tokens, num_kv_groups * head_dim)\n", - "\n", - " # Reshape queries, keys, and values\n", - " queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)\n", - " keys = keys.view(b, num_tokens, self.num_kv_groups, self.head_dim)\n", - " values = values.view(b, num_tokens, self.num_kv_groups, self.head_dim)\n", - "\n", - " # Transpose keys, values, and queries\n", - " keys = keys.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim)\n", - " values = values.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim)\n", - " queries = queries.transpose(1, 2) # Shape: (b, num_query_groups, num_tokens, head_dim)\n", - "\n", - " # Apply RoPE\n", - " keys = apply_rope(keys, cos, sin)\n", - " queries = apply_rope(queries, cos, sin)\n", - "\n", - " # Expand keys and values to match the number of heads\n", - " # Shape: (b, num_heads, num_tokens, head_dim)\n", - " keys = keys.repeat_interleave(self.group_size, dim=1) # Shape: (b, num_heads, num_tokens, head_dim)\n", - " values = values.repeat_interleave(self.group_size, dim=1) # Shape: (b, num_heads, num_tokens, head_dim)\n", - " # For example, before repeat_interleave along dim=1 (query groups):\n", - " # [K1, K2]\n", - " # After repeat_interleave (each query group is repeated group_size times):\n", - " # [K1, K1, K2, K2]\n", - " # If we used regular repeat instead of repeat_interleave, we'd get:\n", - " # [K1, K2, K1, K2]\n", - "\n", - " # Compute scaled dot-product attention (aka self-attention) with a causal mask\n", - " # Shape: (b, num_heads, num_tokens, num_tokens)\n", - " attn_scores = queries @ 
keys.transpose(2, 3) # Dot product for each head\n", - "\n", - " # Use the mask to fill attention scores\n", - " attn_scores = attn_scores.masked_fill(mask[:num_tokens, :num_tokens], -torch.inf)\n", - "\n", - " attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n", - " assert keys.shape[-1] == self.head_dim\n", - "\n", - " # Shape: (b, num_tokens, num_heads, head_dim)\n", - " context_vec = (attn_weights @ values).transpose(1, 2)\n", - "\n", - " # Combine heads, where self.d_out = self.num_heads * self.head_dim\n", - " context_vec = context_vec.reshape(b, num_tokens, self.d_out)\n", - " context_vec = self.out_proj(context_vec) # optional projection\n", - "\n", - " return context_vec" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "457cb2f8-50c1-4045-8a74-f181bfb5fea9", - "metadata": { - "id": "457cb2f8-50c1-4045-8a74-f181bfb5fea9" - }, - "outputs": [], - "source": [ - "class TransformerBlock(nn.Module):\n", - " def __init__(self, cfg):\n", - " super().__init__()\n", - " self.att = GroupedQueryAttention(\n", - " d_in=cfg[\"emb_dim\"],\n", - " d_out=cfg[\"emb_dim\"],\n", - " num_heads=cfg[\"n_heads\"],\n", - " num_kv_groups=cfg[\"n_kv_groups\"],\n", - " dtype=cfg[\"dtype\"]\n", - " )\n", - " self.ff = FeedForward(cfg)\n", - " self.norm1 = nn.RMSNorm(cfg[\"emb_dim\"], eps=1e-5, dtype=cfg[\"dtype\"])\n", - " self.norm2 = nn.RMSNorm(cfg[\"emb_dim\"], eps=1e-5, dtype=cfg[\"dtype\"])\n", - "\n", - " def forward(self, x, mask, cos, sin):\n", - " # Shortcut connection for attention block\n", - " shortcut = x\n", - " x = self.norm1(x)\n", - " x = self.att(x, mask, cos, sin) # Shape [batch_size, num_tokens, emb_size]\n", - " x = x + shortcut # Add the original input back\n", - "\n", - " # Shortcut connection for feed-forward block\n", - " shortcut = x\n", - " x = self.norm2(x)\n", - " x = self.ff(x)\n", - " x = x + shortcut # Add the original input back\n", - "\n", - " return x" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "e88de3e3-9f07-42cc-816b-28dbd46e96c4", - "metadata": { - "id": "e88de3e3-9f07-42cc-816b-28dbd46e96c4" - }, - "outputs": [], - "source": [ - "class Llama3Model(nn.Module):\n", - " def __init__(self, cfg):\n", - " super().__init__()\n", - "\n", - " # Main model parameters\n", - " self.tok_emb = nn.Embedding(cfg[\"vocab_size\"], cfg[\"emb_dim\"], dtype=cfg[\"dtype\"])\n", - "\n", - " self.trf_blocks = nn.ModuleList( # ModuleList since Sequential can only accept one input, and we need `x, mask, cos, sin`\n", - " [TransformerBlock(cfg) for _ in range(cfg[\"n_layers\"])]\n", - " )\n", - "\n", - " self.final_norm = nn.RMSNorm(cfg[\"emb_dim\"], eps=1e-5, dtype=cfg[\"dtype\"])\n", - " self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n", - "\n", - " # Reusuable utilities\n", - " self.register_buffer(\n", - " \"mask\", torch.triu(torch.ones(cfg[\"context_length\"], cfg[\"context_length\"]), diagonal=1).bool(),\n", - " persistent=False\n", - " )\n", - " cfg[\"rope_base\"] = rescale_theta(\n", - " cfg[\"rope_base\"],\n", - " cfg[\"orig_context_length\"],\n", - " cfg[\"context_length\"]\n", - " )\n", - " cos, sin = compute_rope_params(\n", - " head_dim=cfg[\"emb_dim\"] // cfg[\"n_heads\"],\n", - " theta_base=cfg[\"rope_base\"],\n", - " context_length=cfg[\"context_length\"],\n", - " freq_config=cfg[\"rope_freq\"]\n", - " )\n", - " self.register_buffer(\"cos\", cos)\n", - " self.register_buffer(\"sin\", sin)\n", - " self.cfg = cfg\n", - "\n", - "\n", - " def forward(self, 
in_idx):\n", - " # Forward pass\n", - " tok_embeds = self.tok_emb(in_idx)\n", - " x = tok_embeds\n", - " \n", - " for block in self.trf_blocks:\n", - " x = block(x, self.mask, self.cos, self.sin)\n", - " x = self.final_norm(x)\n", - " logits = self.out_head(x.to(self.cfg[\"dtype\"]))\n", - " return logits" - ] - }, - { - "cell_type": "markdown", - "id": "be2d201f-74ad-4d63-ab9c-601b00674a48", - "metadata": { - "id": "be2d201f-74ad-4d63-ab9c-601b00674a48" - }, - "source": [ - " \n", - "# 2. Initialize model" - ] - }, - { - "cell_type": "markdown", - "id": "23dea40c-fe20-4a75-be25-d6fce5863c01", - "metadata": { - "id": "23dea40c-fe20-4a75-be25-d6fce5863c01" - }, - "source": [ - "- The remainder of this notebook uses the Llama 3.2 1B model; to use the 3B model variant, just uncomment the second configuration file in the following code cell" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "caa142fa-b375-4e78-b392-2072ced666f3", - "metadata": { - "id": "caa142fa-b375-4e78-b392-2072ced666f3" - }, - "outputs": [], - "source": [ - "# Llama 3.2 1B\n", - "\n", - "LLAMA32_CONFIG = {\n", - " \"vocab_size\": 128_256, # Vocabulary size\n", - " \"context_length\": 8192, # Maximum context length to use (reduced to save memory)\n", - " \"orig_context_length\": 131_072, # Context length that was used to train the model\n", - " \"emb_dim\": 2048, # Embedding dimension\n", - " \"n_heads\": 32, # Number of attention heads\n", - " \"n_layers\": 16, # Number of layers\n", - " \"hidden_dim\": 8192, # Size of the intermediate dimension in FeedForward\n", - " \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n", - " \"rope_base\": 500_000.0, # The base in RoPE's \"theta\"\n", - " \"dtype\": torch.bfloat16, # Lower-precision dtype to reduce memory usage\n", - " \"rope_freq\": { # RoPE frequency scaling\n", - " \"factor\": 32.0,\n", - " \"low_freq_factor\": 1.0,\n", - " \"high_freq_factor\": 4.0,\n", - " \"original_context_length\": 8192,\n", - " }\n", - "}\n", - "\n", - "# Llama 3.2 3B\n", - "\n", - "# LLAMA32_CONFIG = {\n", - "# \"vocab_size\": 128_256, # Vocabulary size\n", - "# \"context_length\": 8192, # Maximum context length to use (reduced to save memory)\n", - "# \"orig_context_length\": 131_072, # Context length that was used to train the model\n", - "# \"emb_dim\": 3072, # Embedding dimension\n", - "# \"n_heads\": 24, # Number of attention heads\n", - "# \"n_layers\": 28, # Number of layers\n", - "# \"hidden_dim\": 8192, # Size of the intermediate dimension in FeedForward\n", - "# \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n", - "# \"rope_base\": 500_000.0, # The base in RoPE's \"theta\"\n", - "# \"dtype\": torch.bfloat16, # Lower-precision dtype to reduce memory usage\n", - "# \"rope_freq\": { # RoPE frequency scaling\n", - "# \"factor\": 32.0,\n", - "# \"low_freq_factor\": 1.0,\n", - "# \"high_freq_factor\": 4.0,\n", - "# \"original_context_length\": 8192,\n", - "# }\n", - "# }\n", - "\n", - "LLAMA_SIZE_STR = \"1B\" if LLAMA32_CONFIG[\"emb_dim\"] == 2048 else \"3B\"" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "156253fe-aacd-4da2-8f13-705f05c4b11e", - "metadata": { - "id": "156253fe-aacd-4da2-8f13-705f05c4b11e" - }, - "outputs": [], - "source": [ - "model = Llama3Model(LLAMA32_CONFIG)" - ] - }, - { - "cell_type": "markdown", - "id": "19de6c2c-83ce-456d-8be9-6ec415fe9eb1", - "metadata": { - "id": "19de6c2c-83ce-456d-8be9-6ec415fe9eb1" - }, - "source": [ - "- The following is expected to print True to confirm buffers are 
reused instead of being (wastefully) recreated:" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "364e76ca-52f8-4fa5-af37-c4069f9694bc", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "364e76ca-52f8-4fa5-af37-c4069f9694bc", - "outputId": "00d7e983-262e-4c65-f322-f4d999311988" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Total number of parameters: 1,498,482,688\n", - "\n", - "Total number of unique parameters: 1,235,814,400\n" - ] - } - ], - "source": [ - "total_params = sum(p.numel() for p in model.parameters())\n", - "print(f\"Total number of parameters: {total_params:,}\")\n", - "\n", - "# Account for weight tying\n", - "total_params_normalized = total_params - model.tok_emb.weight.numel()\n", - "print(f\"\\nTotal number of unique parameters: {total_params_normalized:,}\")" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "fd5efb03-5a07-46e8-8607-93ed47549d2b", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "fd5efb03-5a07-46e8-8607-93ed47549d2b", - "outputId": "65c1a95e-b502-4150-9e2e-da619d9053d5" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "float32 (PyTorch default): 11.42 GB\n", - "bfloat16: 5.71 GB\n" - ] - } - ], - "source": [ - "def model_memory_size(model, input_dtype=torch.float32):\n", - " total_params = 0\n", - " total_grads = 0\n", - " for param in model.parameters():\n", - " # Calculate total number of elements per parameter\n", - " param_size = param.numel()\n", - " total_params += param_size\n", - " # Check if gradients are stored for this parameter\n", - " if param.requires_grad:\n", - " total_grads += param_size\n", - "\n", - " # Calculate buffer size (non-parameters that require memory)\n", - " total_buffers = sum(buf.numel() for buf in model.buffers())\n", - "\n", - " # Size in bytes = (Number of elements) * (Size of each element in bytes)\n", - " # We assume parameters and gradients are stored in the same type as input dtype\n", - " element_size = torch.tensor(0, dtype=input_dtype).element_size()\n", - " total_memory_bytes = (total_params + total_grads + total_buffers) * element_size\n", - "\n", - " # Convert bytes to gigabytes\n", - " total_memory_gb = total_memory_bytes / (1024**3)\n", - "\n", - " return total_memory_gb\n", - "\n", - "print(f\"float32 (PyTorch default): {model_memory_size(model, input_dtype=torch.float32):.2f} GB\")\n", - "print(f\"bfloat16: {model_memory_size(model, input_dtype=torch.bfloat16):.2f} GB\")" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "31f12baf-f79b-499f-85c0-51328a6a20f5", - "metadata": { - "id": "31f12baf-f79b-499f-85c0-51328a6a20f5" - }, - "outputs": [], - "source": [ - "if torch.cuda.is_available():\n", - " device = torch.device(\"cuda\")\n", - "elif torch.backends.mps.is_available():\n", - " device = torch.device(\"mps\")\n", - "else:\n", - " device = torch.device(\"cpu\")\n", - "\n", - "model.to(device);" - ] - }, - { - "cell_type": "markdown", - "id": "78e091e1-afa8-4d23-9aea-cced86181bfd", - "metadata": { - "id": "78e091e1-afa8-4d23-9aea-cced86181bfd" - }, - "source": [ - " \n", - "# 3. 
Load tokenizer" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "9482b01c-49f9-48e4-ab2c-4a4c75240e77", - "metadata": { - "id": "9482b01c-49f9-48e4-ab2c-4a4c75240e77" - }, - "outputs": [], - "source": [ - "import os\n", - "from pathlib import Path\n", - "\n", - "import tiktoken\n", - "from tiktoken.load import load_tiktoken_bpe\n", - "\n", - "\n", - "class Tokenizer:\n", - " def __init__(self, model_path):\n", - " assert os.path.isfile(model_path), f\"Model file {model_path} not found\"\n", - " mergeable_ranks = load_tiktoken_bpe(model_path)\n", - "\n", - " self.special_tokens = {\n", - " \"<|begin_of_text|>\": 128000,\n", - " \"<|end_of_text|>\": 128001,\n", - " \"<|start_header_id|>\": 128006,\n", - " \"<|end_header_id|>\": 128007,\n", - " \"<|eot_id|>\": 128009,\n", - " }\n", - " self.special_tokens.update({\n", - " f\"<|reserved_{i}|>\": 128002 + i for i in range(256) if (128002 + i) not in self.special_tokens.values()\n", - " })\n", - "\n", - " self.model = tiktoken.Encoding(\n", - " name=Path(model_path).name,\n", - " pat_str=r\"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+\",\n", - " mergeable_ranks=mergeable_ranks,\n", - " special_tokens=self.special_tokens\n", - " )\n", - "\n", - "\n", - " def encode(self, text, bos=False, eos=False, allowed_special=set(), disallowed_special=()):\n", - " if bos:\n", - " tokens = [self.special_tokens[\"<|begin_of_text|>\"]]\n", - " else:\n", - " tokens = []\n", - "\n", - " tokens += self.model.encode(text, allowed_special=allowed_special, disallowed_special=disallowed_special)\n", - "\n", - " if eos:\n", - " tokens.append(self.special_tokens[\"<|end_of_text|>\"])\n", - " return tokens\n", - "\n", - " def decode(self, tokens):\n", - " return self.model.decode(tokens)\n", - "\n", - "\n", - "class ChatFormat:\n", - " def __init__(self, tokenizer):\n", - " self.tokenizer = tokenizer\n", - "\n", - " def encode_header(self, message):\n", - " tokens = []\n", - " tokens.append(self.tokenizer.special_tokens[\"<|start_header_id|>\"])\n", - " tokens.extend(self.tokenizer.encode(message[\"role\"], bos=False, eos=False))\n", - " tokens.append(self.tokenizer.special_tokens[\"<|end_header_id|>\"])\n", - " tokens.extend(self.tokenizer.encode(\"\\n\\n\", bos=False, eos=False))\n", - " return tokens\n", - "\n", - " def encode(self, text):\n", - " message = {\n", - " \"role\": \"user\",\n", - " \"content\": text\n", - " }\n", - "\n", - " tokens = self.encode_header(message)\n", - " tokens.extend(\n", - " self.tokenizer.encode(message[\"content\"].strip(), bos=False, eos=False)\n", - " )\n", - " tokens.append(self.tokenizer.special_tokens[\"<|eot_id|>\"])\n", - " return tokens\n", - "\n", - " def decode(self, token_ids):\n", - " return self.tokenizer.decode(token_ids)" - ] - }, - { - "cell_type": "markdown", - "id": "b771b60c-c198-4b30-bf10-42031197ae86", - "metadata": { - "id": "b771b60c-c198-4b30-bf10-42031197ae86" - }, - "source": [ - "- Please note that Meta AI requires that you accept the Llama 3.2 licensing terms before you can download the files; to do this, you have to create a Hugging Face Hub account and visit the [meta-llama/Llama-3.2-1B](https://huggingface.co/meta-llama/Llama-3.2-1B) repository to accept the terms\n", - "- Next, you will need to create an access token; to generate an access token with READ permissions, click on the profile picture in the upper right and click on \"Settings\"\n", - "\n", - "\n", - "\n", - "\n", - "- Then, create and 
copy the access token so you can copy & paste it into the next code cell\n", - "\n", - "" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "e9d96dc8-603a-4cb5-8c3e-4d2ca56862ed", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "e9d96dc8-603a-4cb5-8c3e-4d2ca56862ed", - "outputId": "e6e6dc05-7330-45bc-a9a7-331919155bdd" - }, - "outputs": [], - "source": [ - "from huggingface_hub import login\n", - "\n", - "login()" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "986bc1a0-804f-4154-80f8-44cefbee1368", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 141, - "referenced_widgets": [ - "a1608feac06d4687967a3e398f01c489", - "518fb202e4b44aaba47f07d1a61b6762", - "672cdc5aea954de3af851c001a667ad3", - "eebf8874618746b39cf4a21a2728dc7f", - "5176834aa8784bba9ec21234b87a8948", - "e2dc407afcd945c798e30597fddfcb3c", - "0dccd57dcc5c43a588157cef957c07e8", - "33ca0cdf2c7f41598a381c4ebe6a4ee1", - "ee44487f58454dacb522b1e084ffb733", - "d2c41e71a3f441deaed091b620ac5603", - "3326b6141a1a4eba9f316df528a9b99a" - ] - }, - "id": "986bc1a0-804f-4154-80f8-44cefbee1368", - "outputId": "5dd7334b-4c71-465a-94d2-c3e95b9ddc58" - }, - "outputs": [], - "source": [ - "from huggingface_hub import hf_hub_download\n", - "\n", - "tokenizer_file_path = hf_hub_download(\n", - " repo_id=f\"meta-llama/Llama-3.2-{LLAMA_SIZE_STR}-Instruct\",\n", - " filename=\"original/tokenizer.model\",\n", - " local_dir=f\"Llama-3.2-{LLAMA_SIZE_STR}-Instruct\"\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "id": "_gBhxDtU_nxo", - "metadata": { - "id": "_gBhxDtU_nxo" - }, - "outputs": [], - "source": [ - "tokenizer = Tokenizer(tokenizer_file_path)\n", - "chat_tokenizer = ChatFormat(tokenizer)" - ] - }, - { - "cell_type": "markdown", - "id": "c172f89f-d301-439f-b809-46169e5f5945", - "metadata": { - "id": "c172f89f-d301-439f-b809-46169e5f5945" - }, - "source": [ - " \n", - "# 4. Load pretrained weights" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "id": "75166128-5899-4995-9b88-9672e135650e", - "metadata": { - "id": "75166128-5899-4995-9b88-9672e135650e" - }, - "outputs": [], - "source": [ - "def assign(left, right, tensor_name=\"unknown\"):\n", - " if left.shape != right.shape:\n", - " raise ValueError(f\"Shape mismatch in tensor '{tensor_name}'. 
Left: {left.shape}, Right: {right.shape}\")\n", - "\n", - " if isinstance(right, torch.Tensor):\n", - " return torch.nn.Parameter(right.clone().detach())\n", - " else:\n", - " return torch.nn.Parameter(torch.tensor(right))\n", - "\n", - "\n", - "def load_weights_into_llama(model, param_config, params):\n", - " model.tok_emb.weight = assign(model.tok_emb.weight, params[\"model.embed_tokens.weight\"], \"model.embed_tokens.weight\")\n", - "\n", - " for l in range(param_config[\"n_layers\"]):\n", - "\n", - " # Load attention weights\n", - " model.trf_blocks[l].att.W_query.weight = assign(\n", - " model.trf_blocks[l].att.W_query.weight,\n", - " params[f\"model.layers.{l}.self_attn.q_proj.weight\"],\n", - " f\"model.layers.{l}.self_attn.q_proj.weight\"\n", - " )\n", - " model.trf_blocks[l].att.W_key.weight = assign(\n", - " model.trf_blocks[l].att.W_key.weight,\n", - " params[f\"model.layers.{l}.self_attn.k_proj.weight\"],\n", - " f\"model.layers.{l}.self_attn.k_proj.weight\"\n", - " )\n", - " model.trf_blocks[l].att.W_value.weight = assign(\n", - " model.trf_blocks[l].att.W_value.weight,\n", - " params[f\"model.layers.{l}.self_attn.v_proj.weight\"],\n", - " f\"model.layers.{l}.self_attn.v_proj.weight\"\n", - " )\n", - " model.trf_blocks[l].att.out_proj.weight = assign(\n", - " model.trf_blocks[l].att.out_proj.weight,\n", - " params[f\"model.layers.{l}.self_attn.o_proj.weight\"],\n", - " f\"model.layers.{l}.self_attn.o_proj.weight\"\n", - " )\n", - " model.trf_blocks[l].norm1.weight = assign(\n", - " model.trf_blocks[l].norm1.weight,\n", - " params[f\"model.layers.{l}.input_layernorm.weight\"],\n", - " f\"model.layers.{l}.input_layernorm.weight\"\n", - " )\n", - "\n", - " # Load FeedForward weights\n", - " model.trf_blocks[l].ff.fc1.weight = assign(\n", - " model.trf_blocks[l].ff.fc1.weight,\n", - " params[f\"model.layers.{l}.mlp.gate_proj.weight\"],\n", - " f\"model.layers.{l}.mlp.gate_proj.weight\"\n", - " )\n", - " model.trf_blocks[l].ff.fc2.weight = assign(\n", - " model.trf_blocks[l].ff.fc2.weight,\n", - " params[f\"model.layers.{l}.mlp.up_proj.weight\"],\n", - " f\"model.layers.{l}.mlp.up_proj.weight\"\n", - " )\n", - " model.trf_blocks[l].ff.fc3.weight = assign(\n", - " model.trf_blocks[l].ff.fc3.weight,\n", - " params[f\"model.layers.{l}.mlp.down_proj.weight\"],\n", - " f\"model.layers.{l}.mlp.down_proj.weight\"\n", - " )\n", - " model.trf_blocks[l].norm2.weight = assign(\n", - " model.trf_blocks[l].norm2.weight,\n", - " params[f\"model.layers.{l}.post_attention_layernorm.weight\"],\n", - " f\"model.layers.{l}.post_attention_layernorm.weight\"\n", - " )\n", - "\n", - " # Load output layer weights\n", - " model.final_norm.weight = assign(model.final_norm.weight, params[\"model.norm.weight\"], \"model.norm.weight\")\n", - "\n", - " if \"lm_head.weight\" in params.keys():\n", - " model.out_head.weight = assign(model.out_head.weight, params[\"lm_head.weight\"], \"lm_head.weight\")\n", - " else:\n", - " model.out_head.weight = assign(model.out_head.weight, params[\"model.embed_tokens.weight\"], \"model.embed_tokens.weight\")\n", - " print(\"Model uses weight tying.\")" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "id": "699cb1b8-a67d-49fb-80a6-0dad9d81f392", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 17, - "referenced_widgets": [ - "9881b6995c3f49dc89e6992fd9ab660b", - "17a3174e65c54476b2e0d1faf8f011ca", - "1bbf2e62c0754d1593beb4105a7f1ac1", - "b82112e1dec645d98aa1c1ba64abcb61", - "271e2bd6a35e4a8b92de8697f7c0be5f", - 
"90a79523187446dfa692723b2e5833a7", - "431ffb83b8c14bf182f0430e07ea6154", - "a8f1b72a33dd4b548de23fbd95e0da18", - "25cc36132d384189acfbecc59483134b", - "bfd06423ad544218968648016e731a46", - "d029630b63ff44cf807ade428d2eb421" - ] - }, - "id": "699cb1b8-a67d-49fb-80a6-0dad9d81f392", - "outputId": "55b2f28c-142f-4698-9d23-d27456d3ed6d" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Model uses weight tying.\n" - ] - } - ], - "source": [ - "from safetensors.torch import load_file\n", - "\n", - "\n", - "if LLAMA_SIZE_STR == \"1B\":\n", - " weights_file = hf_hub_download(\n", - " repo_id=f\"meta-llama/Llama-3.2-{LLAMA_SIZE_STR}-Instruct\",\n", - " filename=\"model.safetensors\",\n", - " local_dir=f\"Llama-3.2-{LLAMA_SIZE_STR}-Instruct\"\n", - " )\n", - " combined_weights = load_file(weights_file)\n", - "\n", - "\n", - "else:\n", - " combined_weights = {}\n", - " for i in range(1, 3):\n", - " weights_file = hf_hub_download(\n", - " repo_id=f\"meta-llama/Llama-3.2-{LLAMA_SIZE_STR}-Instruct\",\n", - " filename=f\"model-0000{i}-of-00002.safetensors\",\n", - " local_dir=f\"Llama-3.2-{LLAMA_SIZE_STR}-Instruct\"\n", - " )\n", - " current_weights = load_file(weights_file)\n", - " combined_weights.update(current_weights)\n", - "\n", - "\n", - "load_weights_into_llama(model, LLAMA32_CONFIG, combined_weights)\n", - "model.to(device)\n", - "del combined_weights # free up memory" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "id": "7f9f7ccc-70cb-41ff-9c25-44336042fc37", - "metadata": { - "id": "7f9f7ccc-70cb-41ff-9c25-44336042fc37" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Weight tying: True\n" - ] - } - ], - "source": [ - "print(\"Weight tying:\", torch.equal(model.tok_emb.weight, model.out_head.weight))" - ] - }, - { - "cell_type": "markdown", - "id": "57d07df1-4401-4792-b549-7c4cc5632323", - "metadata": { - "id": "57d07df1-4401-4792-b549-7c4cc5632323" - }, - "source": [ - " \n", - "# 5. 
Generate text" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5", - "metadata": { - "id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5" - }, - "outputs": [], - "source": [ - "def text_to_token_ids(text, tokenizer):\n", - " encoded = tokenizer.encode(text)\n", - " encoded_tensor = torch.tensor(encoded).unsqueeze(0) # add batch dimension\n", - " return encoded_tensor\n", - "\n", - "\n", - "def token_ids_to_text(token_ids, tokenizer):\n", - " flat = token_ids.squeeze(0) # remove batch dimension\n", - " return tokenizer.decode(flat.tolist())\n", - "\n", - "\n", - "def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=None, eos_id=None):\n", - "\n", - " # For-loop is the same as before: Get logits, and only focus on last time step\n", - " for _ in range(max_new_tokens):\n", - " idx_cond = idx[:, -context_size:]\n", - " with torch.no_grad():\n", - " logits = model(idx_cond)\n", - " logits = logits[:, -1, :]\n", - "\n", - " # New: Filter logits with top_k sampling\n", - " if top_k is not None:\n", - " # Keep only top_k values\n", - " top_logits, _ = torch.topk(logits, top_k)\n", - " min_val = top_logits[:, -1]\n", - " logits = torch.where(logits < min_val, torch.tensor(float('-inf')).to(logits.device), logits)\n", - "\n", - " # New: Apply temperature scaling\n", - " if temperature > 0.0:\n", - " logits = logits / temperature\n", - "\n", - " # Apply softmax to get probabilities\n", - " probs = torch.softmax(logits, dim=-1) # (batch_size, context_len)\n", - "\n", - " # Sample from the distribution\n", - " idx_next = torch.multinomial(probs, num_samples=1) # (batch_size, 1)\n", - "\n", - " # Otherwise same as before: get idx of the vocab entry with the highest logits value\n", - " else:\n", - " idx_next = torch.argmax(logits, dim=-1, keepdim=True) # (batch_size, 1)\n", - "\n", - " if idx_next == eos_id: # Stop generating early if end-of-sequence token is encountered and eos_id is specified\n", - " break\n", - "\n", - " # Same as before: append sampled index to the running sequence\n", - " idx = torch.cat((idx, idx_next), dim=1) # (batch_size, num_tokens+1)\n", - "\n", - " return idx" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "id": "1c7a04fa-6aac-416b-8f63-f1e19227633d", - "metadata": { - "id": "1c7a04fa-6aac-416b-8f63-f1e19227633d" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Time: 19.49 sec\n", - "\n", - "\n", - "Output text:\n", - "\n", - " Llamas are herbivores, which means they primarily eat plants and plant-based foods. Their diet typically consists of:\n", - "\n", - "1. Grasses: Llamas love to graze on various types of grasses, including tall grasses and short grasses.\n", - "2. Hay: Llamas also eat hay, which is a dry, compressed form of grass or other plants.\n", - "3. Alfalfa: Alfalfa is a legume that is commonly used as a hay substitute in llama feed.\n", - "4. 
Other plants: Llamas will also eat other plants, such as clover, dandelions, and wild grasses.\n", - "\n", - "It's worth noting that the specific diet of llamas can vary depending on factors such as\n" - ] - } - ], - "source": [ - "import time\n", - "\n", - "\n", - "PROMPT = \"What do llamas eat?\"\n", - "\n", - "torch.manual_seed(123)\n", - "\n", - "start = time.time()\n", - "\n", - "token_ids = generate(\n", - " model=model,\n", - " idx=text_to_token_ids(PROMPT, chat_tokenizer).to(device),\n", - " max_new_tokens=150,\n", - " context_size=LLAMA32_CONFIG[\"context_length\"],\n", - " top_k=1,\n", - " temperature=0.\n", - ")\n", - "\n", - "print(f\"Time: {time.time() - start:.2f} sec\")\n", - "\n", - "if torch.cuda.is_available():\n", - " max_mem_bytes = torch.cuda.max_memory_allocated()\n", - " max_mem_gb = max_mem_bytes / (1024 ** 3)\n", - " print(f\"Max memory allocated: {max_mem_gb:.2f} GB\")\n", - "\n", - "output_text = token_ids_to_text(token_ids, tokenizer)\n", - "\n", - "\n", - "def clean_text(text, header_end=\"assistant<|end_header_id|>\\n\\n\"):\n", - " # Find the index of the first occurrence of \"<|end_header_id|>\"\n", - " index = text.find(header_end)\n", - "\n", - " if index != -1:\n", - " # Return the substring starting after \"<|end_header_id|>\"\n", - " return text[index + len(header_end):].strip() # Strip removes leading/trailing whitespace\n", - " else:\n", - " # If the token is not found, return the original text\n", - " return text\n", - "\n", - "print(\"\\n\\nOutput text:\\n\\n\", clean_text(output_text))" - ] - }, - { - "cell_type": "markdown", - "id": "549324d6-5c71-4147-ae21-2e67675faa3d", - "metadata": { - "id": "549324d6-5c71-4147-ae21-2e67675faa3d" - }, - "source": [ - " \n", - "# What's next?" - ] - }, - { - "cell_type": "markdown", - "id": "e6edaaae-2de1-406c-8ffa-897cdfa3808c", - "metadata": { - "id": "e6edaaae-2de1-406c-8ffa-897cdfa3808c" - }, - "source": [ - "- The notebook was kept purposefully minimal; if you are interested in additional explanation about the individual components, check out the following two companion notebooks:\n", - "\n", - "\n", - "\n", - " 1. [Converting a From-Scratch GPT Architecture to Llama 2](converting-gpt-to-llama2.ipynb)\n", - " 2. 
[Converting Llama 2 to Llama 3.2 From Scratch](converting-llama2-to-llama3.ipynb)\n", - " \n", - "- For those interested in a comprehensive guide on building a large language model from scratch and gaining a deeper understanding of its mechanics, you might like my [Build a Large Language Model (From Scratch)](http://mng.bz/orYv)\n", - "\n", - "" - ] - } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "gpuType": "A100", - "provenance": [] - }, - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.16" - }, - "widgets": { - "application/vnd.jupyter.widget-state+json": { - "0dccd57dcc5c43a588157cef957c07e8": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "2.0.0", - "model_name": "HTMLStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "2.0.0", - "_model_name": "HTMLStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "2.0.0", - "_view_name": "StyleView", - "background": null, - "description_width": "", - "font_size": null, - "text_color": null - } - }, - "17a3174e65c54476b2e0d1faf8f011ca": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "2.0.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "2.0.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "2.0.0", - "_view_name": "HTMLView", - "description": "", - "description_allow_html": false, - "layout": "IPY_MODEL_90a79523187446dfa692723b2e5833a7", - "placeholder": "​", - "style": "IPY_MODEL_431ffb83b8c14bf182f0430e07ea6154", - "tabbable": null, - "tooltip": null, - "value": "model.safetensors:  35%" - } - }, - "1bbf2e62c0754d1593beb4105a7f1ac1": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "2.0.0", - "model_name": "FloatProgressModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "2.0.0", - "_model_name": "FloatProgressModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "2.0.0", - "_view_name": "ProgressView", - "bar_style": "", - "description": "", - "description_allow_html": false, - "layout": "IPY_MODEL_a8f1b72a33dd4b548de23fbd95e0da18", - "max": 2471645608, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_25cc36132d384189acfbecc59483134b", - "tabbable": null, - "tooltip": null, - "value": 880803840 - } - }, - "25cc36132d384189acfbecc59483134b": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "2.0.0", - "model_name": "ProgressStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "2.0.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "2.0.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } - }, - "271e2bd6a35e4a8b92de8697f7c0be5f": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "2.0.0", - "model_name": "LayoutModel", - "state": { - 
"_model_module": "@jupyter-widgets/base", - "_model_module_version": "2.0.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "2.0.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border_bottom": null, - "border_left": null, - "border_right": null, - "border_top": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "3326b6141a1a4eba9f316df528a9b99a": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "2.0.0", - "model_name": "HTMLStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "2.0.0", - "_model_name": "HTMLStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "2.0.0", - "_view_name": "StyleView", - "background": null, - "description_width": "", - "font_size": null, - "text_color": null - } - }, - "33ca0cdf2c7f41598a381c4ebe6a4ee1": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "2.0.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "2.0.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "2.0.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border_bottom": null, - "border_left": null, - "border_right": null, - "border_top": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "431ffb83b8c14bf182f0430e07ea6154": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "2.0.0", - "model_name": "HTMLStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "2.0.0", - "_model_name": "HTMLStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "2.0.0", - "_view_name": "StyleView", - "background": null, - "description_width": "", - "font_size": null, - "text_color": null - } - }, - "5176834aa8784bba9ec21234b87a8948": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "2.0.0", - 
"model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "2.0.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "2.0.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border_bottom": null, - "border_left": null, - "border_right": null, - "border_top": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "518fb202e4b44aaba47f07d1a61b6762": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "2.0.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "2.0.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "2.0.0", - "_view_name": "HTMLView", - "description": "", - "description_allow_html": false, - "layout": "IPY_MODEL_e2dc407afcd945c798e30597fddfcb3c", - "placeholder": "​", - "style": "IPY_MODEL_0dccd57dcc5c43a588157cef957c07e8", - "tabbable": null, - "tooltip": null, - "value": "tokenizer.model: 100%" - } - }, - "672cdc5aea954de3af851c001a667ad3": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "2.0.0", - "model_name": "FloatProgressModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "2.0.0", - "_model_name": "FloatProgressModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "2.0.0", - "_view_name": "ProgressView", - "bar_style": "success", - "description": "", - "description_allow_html": false, - "layout": "IPY_MODEL_33ca0cdf2c7f41598a381c4ebe6a4ee1", - "max": 2183982, - "min": 0, - "orientation": "horizontal", - "style": "IPY_MODEL_ee44487f58454dacb522b1e084ffb733", - "tabbable": null, - "tooltip": null, - "value": 2183982 - } - }, - "90a79523187446dfa692723b2e5833a7": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "2.0.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "2.0.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "2.0.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border_bottom": null, - "border_left": null, - "border_right": null, - "border_top": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - 
"height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "9881b6995c3f49dc89e6992fd9ab660b": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "2.0.0", - "model_name": "HBoxModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "2.0.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "2.0.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_17a3174e65c54476b2e0d1faf8f011ca", - "IPY_MODEL_1bbf2e62c0754d1593beb4105a7f1ac1", - "IPY_MODEL_b82112e1dec645d98aa1c1ba64abcb61" - ], - "layout": "IPY_MODEL_271e2bd6a35e4a8b92de8697f7c0be5f", - "tabbable": null, - "tooltip": null - } - }, - "a1608feac06d4687967a3e398f01c489": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "2.0.0", - "model_name": "HBoxModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "2.0.0", - "_model_name": "HBoxModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "2.0.0", - "_view_name": "HBoxView", - "box_style": "", - "children": [ - "IPY_MODEL_518fb202e4b44aaba47f07d1a61b6762", - "IPY_MODEL_672cdc5aea954de3af851c001a667ad3", - "IPY_MODEL_eebf8874618746b39cf4a21a2728dc7f" - ], - "layout": "IPY_MODEL_5176834aa8784bba9ec21234b87a8948", - "tabbable": null, - "tooltip": null - } - }, - "a8f1b72a33dd4b548de23fbd95e0da18": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "2.0.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "2.0.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "2.0.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border_bottom": null, - "border_left": null, - "border_right": null, - "border_top": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "b82112e1dec645d98aa1c1ba64abcb61": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "2.0.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "2.0.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "2.0.0", - "_view_name": "HTMLView", - "description": "", - 
"description_allow_html": false, - "layout": "IPY_MODEL_bfd06423ad544218968648016e731a46", - "placeholder": "​", - "style": "IPY_MODEL_d029630b63ff44cf807ade428d2eb421", - "tabbable": null, - "tooltip": null, - "value": " 870M/2.47G [00:20<00:37, 42.8MB/s]" - } - }, - "bfd06423ad544218968648016e731a46": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "2.0.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "2.0.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "2.0.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border_bottom": null, - "border_left": null, - "border_right": null, - "border_top": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "d029630b63ff44cf807ade428d2eb421": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "2.0.0", - "model_name": "HTMLStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "2.0.0", - "_model_name": "HTMLStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "2.0.0", - "_view_name": "StyleView", - "background": null, - "description_width": "", - "font_size": null, - "text_color": null - } - }, - "d2c41e71a3f441deaed091b620ac5603": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "2.0.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - "_model_module_version": "2.0.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "2.0.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border_bottom": null, - "border_left": null, - "border_right": null, - "border_top": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "e2dc407afcd945c798e30597fddfcb3c": { - "model_module": "@jupyter-widgets/base", - "model_module_version": "2.0.0", - "model_name": "LayoutModel", - "state": { - "_model_module": "@jupyter-widgets/base", - 
"_model_module_version": "2.0.0", - "_model_name": "LayoutModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "2.0.0", - "_view_name": "LayoutView", - "align_content": null, - "align_items": null, - "align_self": null, - "border_bottom": null, - "border_left": null, - "border_right": null, - "border_top": null, - "bottom": null, - "display": null, - "flex": null, - "flex_flow": null, - "grid_area": null, - "grid_auto_columns": null, - "grid_auto_flow": null, - "grid_auto_rows": null, - "grid_column": null, - "grid_gap": null, - "grid_row": null, - "grid_template_areas": null, - "grid_template_columns": null, - "grid_template_rows": null, - "height": null, - "justify_content": null, - "justify_items": null, - "left": null, - "margin": null, - "max_height": null, - "max_width": null, - "min_height": null, - "min_width": null, - "object_fit": null, - "object_position": null, - "order": null, - "overflow": null, - "padding": null, - "right": null, - "top": null, - "visibility": null, - "width": null - } - }, - "ee44487f58454dacb522b1e084ffb733": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "2.0.0", - "model_name": "ProgressStyleModel", - "state": { - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "2.0.0", - "_model_name": "ProgressStyleModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/base", - "_view_module_version": "2.0.0", - "_view_name": "StyleView", - "bar_color": null, - "description_width": "" - } - }, - "eebf8874618746b39cf4a21a2728dc7f": { - "model_module": "@jupyter-widgets/controls", - "model_module_version": "2.0.0", - "model_name": "HTMLModel", - "state": { - "_dom_classes": [], - "_model_module": "@jupyter-widgets/controls", - "_model_module_version": "2.0.0", - "_model_name": "HTMLModel", - "_view_count": null, - "_view_module": "@jupyter-widgets/controls", - "_view_module_version": "2.0.0", - "_view_name": "HTMLView", - "description": "", - "description_allow_html": false, - "layout": "IPY_MODEL_d2c41e71a3f441deaed091b620ac5603", - "placeholder": "​", - "style": "IPY_MODEL_3326b6141a1a4eba9f316df528a9b99a", - "tabbable": null, - "tooltip": null, - "value": " 2.18M/2.18M [00:00<00:00, 9.47MB/s]" - } - } - } - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/ch05/07_gpt_to_llama/standalone-llama32.ipynb b/ch05/07_gpt_to_llama/standalone-llama32.ipynb index 9275955..dbec8ad 100644 --- a/ch05/07_gpt_to_llama/standalone-llama32.ipynb +++ b/ch05/07_gpt_to_llama/standalone-llama32.ipynb @@ -56,7 +56,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "id": "7c201adb-747e-437b-9a62-442802941e01", "metadata": {}, "outputs": [], @@ -66,7 +66,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "id": "dd1b65a8-4301-444a-bd7c-a6f2bd1df9df", "metadata": { "colab": { @@ -81,9 +81,9 @@ "output_type": "stream", "text": [ "blobfile version: 3.0.0\n", - "huggingface_hub version: 0.25.2\n", - "tiktoken version: 0.8.0\n", - "torch version: 2.5.0\n" + "huggingface_hub version: 0.30.1\n", + "tiktoken version: 0.9.0\n", + "torch version: 2.6.0\n" ] } ], @@ -113,7 +113,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "id": "82076c21-9331-4dcd-b017-42b046cf1a60", "metadata": { "id": "82076c21-9331-4dcd-b017-42b046cf1a60" @@ -140,18 +140,18 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "id": "4b9a346f-5826-4083-9162-abd56afc03f0", "metadata": { "id": 
"4b9a346f-5826-4083-9162-abd56afc03f0" }, "outputs": [], "source": [ - "def precompute_rope_params(head_dim, theta_base=10_000, context_length=4096, freq_config=None):\n", + "def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, freq_config=None, dtype=torch.float32):\n", " assert head_dim % 2 == 0, \"Embedding dimension must be even\"\n", "\n", " # Compute the inverse frequencies\n", - " inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2)[: (head_dim // 2)].float() / head_dim))\n", + " inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2, dtype=dtype)[: (head_dim // 2)].float() / head_dim))\n", "\n", " # Frequency adjustments\n", " if freq_config is not None:\n", @@ -177,7 +177,7 @@ " inv_freq = inv_freq_llama\n", "\n", " # Generate position indices\n", - " positions = torch.arange(context_length)\n", + " positions = torch.arange(context_length, dtype=dtype)\n", "\n", " # Compute the angles\n", " angles = positions[:, None] * inv_freq[None, :] # Shape: (context_length, head_dim // 2)\n", @@ -192,7 +192,7 @@ " return cos, sin\n", "\n", "\n", - "def compute_rope(x, cos, sin):\n", + "def apply_rope(x, cos, sin):\n", " # x: (batch_size, num_heads, seq_len, head_dim)\n", " batch_size, num_heads, seq_len, head_dim = x.shape\n", " assert head_dim % 2 == 0, \"Head dimension must be even\"\n", @@ -209,43 +209,23 @@ " rotated = torch.cat((-x2, x1), dim=-1)\n", " x_rotated = (x * cos) + (rotated * sin)\n", "\n", + " # It's ok to use lower-precision after applying cos and sin rotation\n", " return x_rotated.to(dtype=x.dtype)" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "id": "e8169ab5-f976-4222-a2e1-eb1cabf267cb", "metadata": { "id": "e8169ab5-f976-4222-a2e1-eb1cabf267cb" }, "outputs": [], "source": [ - "class SharedBuffers:\n", - " _buffers = {}\n", - "\n", - " @staticmethod\n", - " def get_buffers(context_length, head_dim, rope_base, freq_config, dtype=torch.float32):\n", - " key = (context_length, head_dim, rope_base, tuple(freq_config.values()) if freq_config else freq_config, dtype)\n", - "\n", - " if key not in SharedBuffers._buffers:\n", - " # Create or fetch the buffers\n", - " mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)\n", - " cos, sin = precompute_rope_params(head_dim, rope_base, context_length, freq_config)\n", - " if dtype is not None:\n", - " cos = cos.to(dtype)\n", - " sin = sin.to(dtype)\n", - " SharedBuffers._buffers[key] = (mask, cos, sin)\n", - "\n", - " return SharedBuffers._buffers[key]\n", - "\n", - "\n", "class GroupedQueryAttention(nn.Module):\n", " def __init__(\n", - " self, d_in, d_out, context_length, num_heads,\n", + " self, d_in, d_out, num_heads,\n", " num_kv_groups,\n", - " rope_base=10_000,\n", - " rope_config=None,\n", " dtype=None\n", " ):\n", " super().__init__()\n", @@ -264,14 +244,7 @@ " self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype)\n", " self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype)\n", "\n", - " # Fetch buffers using SharedBuffers\n", - " mask, cos, sin = SharedBuffers.get_buffers(context_length, self.head_dim, rope_base, rope_config, dtype)\n", - " self.register_buffer(\"mask\", mask, persistent=False)\n", - "\n", - " self.register_buffer(\"cos\", cos, persistent=False)\n", - " self.register_buffer(\"sin\", sin, persistent=False)\n", - "\n", - " def forward(self, x):\n", + " def forward(self, x, mask, cos, sin):\n", " b, num_tokens, d_in = x.shape\n", "\n", " queries = self.W_query(x) # Shape: (b, num_tokens, d_out)\n", @@ 
-289,8 +262,8 @@ " queries = queries.transpose(1, 2) # Shape: (b, num_query_groups, num_tokens, head_dim)\n", "\n", " # Apply RoPE\n", - " keys = compute_rope(keys, self.cos, self.sin)\n", - " queries = compute_rope(queries, self.cos, self.sin)\n", + " keys = apply_rope(keys, cos, sin)\n", + " queries = apply_rope(queries, cos, sin)\n", "\n", " # Expand keys and values to match the number of heads\n", " # Shape: (b, num_heads, num_tokens, head_dim)\n", @@ -307,11 +280,8 @@ " # Shape: (b, num_heads, num_tokens, num_tokens)\n", " attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head\n", "\n", - " # Original mask truncated to the number of tokens and converted to boolean\n", - " mask_bool = self.mask.bool()[:num_tokens, :num_tokens]\n", - "\n", - " # Use the mask to fill attention scores\n", - " attn_scores.masked_fill_(mask_bool, -torch.inf)\n", + " # Compute attention scores\n", + " attn_scores = attn_scores.masked_fill(mask, -torch.inf)\n", "\n", " attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n", " assert keys.shape[-1] == self.head_dim\n", @@ -328,7 +298,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "id": "457cb2f8-50c1-4045-8a74-f181bfb5fea9", "metadata": { "id": "457cb2f8-50c1-4045-8a74-f181bfb5fea9" @@ -338,31 +308,28 @@ "class TransformerBlock(nn.Module):\n", " def __init__(self, cfg):\n", " super().__init__()\n", - " self.att = GroupedQueryAttention(\n", + " self.att = GroupedQueryAttention(\n", " d_in=cfg[\"emb_dim\"],\n", " d_out=cfg[\"emb_dim\"],\n", - " context_length=cfg[\"context_length\"],\n", " num_heads=cfg[\"n_heads\"],\n", " num_kv_groups=cfg[\"n_kv_groups\"],\n", - " rope_base=cfg[\"rope_base\"],\n", - " rope_config=cfg[\"rope_freq\"],\n", " dtype=cfg[\"dtype\"]\n", " )\n", " self.ff = FeedForward(cfg)\n", - " self.norm1 = nn.RMSNorm(cfg[\"emb_dim\"], eps=1e-5)\n", - " self.norm2 = nn.RMSNorm(cfg[\"emb_dim\"], eps=1e-5)\n", + " self.norm1 = nn.RMSNorm(cfg[\"emb_dim\"], eps=1e-5, dtype=cfg[\"dtype\"])\n", + " self.norm2 = nn.RMSNorm(cfg[\"emb_dim\"], eps=1e-5, dtype=cfg[\"dtype\"])\n", "\n", - " def forward(self, x):\n", + " def forward(self, x, mask, cos, sin):\n", " # Shortcut connection for attention block\n", " shortcut = x\n", " x = self.norm1(x)\n", - " x = self.att(x.to(torch.bfloat16)) # Shape [batch_size, num_tokens, emb_size]\n", + " x = self.att(x, mask, cos, sin) # Shape [batch_size, num_tokens, emb_size]\n", " x = x + shortcut # Add the original input back\n", "\n", " # Shortcut connection for feed-forward block\n", " shortcut = x\n", " x = self.norm2(x)\n", - " x = self.ff(x.to(torch.bfloat16))\n", + " x = self.ff(x)\n", " x = x + shortcut # Add the original input back\n", "\n", " return x" @@ -370,7 +337,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 7, "id": "e88de3e3-9f07-42cc-816b-28dbd46e96c4", "metadata": { "id": "e88de3e3-9f07-42cc-816b-28dbd46e96c4" @@ -380,20 +347,41 @@ "class Llama3Model(nn.Module):\n", " def __init__(self, cfg):\n", " super().__init__()\n", + "\n", + " # Main model parameters\n", " self.tok_emb = nn.Embedding(cfg[\"vocab_size\"], cfg[\"emb_dim\"], dtype=cfg[\"dtype\"])\n", "\n", - " self.trf_blocks = nn.Sequential(\n", - " *[TransformerBlock(cfg) for _ in range(cfg[\"n_layers\"])])\n", + " self.trf_blocks = nn.ModuleList( # ModuleList since Sequential can only accept one input, and we need `x, mask, cos, sin`\n", + " [TransformerBlock(cfg) for _ in range(cfg[\"n_layers\"])]\n", + " )\n", "\n", - " self.final_norm = 
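The renamed `compute_rope_params`/`apply_rope` pair used in the attention hunks above follows the standard rotary-embedding recipe: the rotation angles are precomputed once (now at the model level and passed into each block) and then used to rotate the query and key tensors. A toy sketch of that recipe with small shapes and without the Llama 3.1-style frequency scaling:

```python
import torch

def toy_rope_params(head_dim, theta_base=10_000, context_length=8):
    # One inverse frequency per pair of embedding dimensions
    inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    positions = torch.arange(context_length)
    angles = positions[:, None] * inv_freq[None, :]      # (context_length, head_dim // 2)
    angles = torch.cat([angles, angles], dim=1)          # (context_length, head_dim)
    return torch.cos(angles), torch.sin(angles)

def toy_apply_rope(x, cos, sin):
    # x: (batch, num_heads, seq_len, head_dim)
    seq_len, head_dim = x.shape[2], x.shape[3]
    x1, x2 = x[..., : head_dim // 2], x[..., head_dim // 2:]
    rotated = torch.cat((-x2, x1), dim=-1)
    cos = cos[:seq_len].unsqueeze(0).unsqueeze(0)        # (1, 1, seq_len, head_dim)
    sin = sin[:seq_len].unsqueeze(0).unsqueeze(0)
    return x * cos + rotated * sin

cos, sin = toy_rope_params(head_dim=4, context_length=8)
q = torch.randn(1, 2, 6, 4)                              # (batch, heads, tokens, head_dim)
print(toy_apply_rope(q, cos, sin).shape)                 # torch.Size([1, 2, 6, 4])
```

Because `cos` and `sin` depend only on the config, computing them once in the model and handing them to every block replaces the previous `SharedBuffers` caching scheme without changing the math.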
nn.RMSNorm(cfg[\"emb_dim\"], eps=1e-5)\n", + " self.final_norm = nn.RMSNorm(cfg[\"emb_dim\"], eps=1e-5, dtype=cfg[\"dtype\"])\n", " self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n", "\n", + " # Reusuable utilities\n", + " cos, sin = compute_rope_params(\n", + " head_dim=cfg[\"emb_dim\"] // cfg[\"n_heads\"],\n", + " theta_base=cfg[\"rope_base\"],\n", + " context_length=cfg[\"context_length\"],\n", + " freq_config=cfg[\"rope_freq\"]\n", + " )\n", + " self.register_buffer(\"cos\", cos, persistent=False)\n", + " self.register_buffer(\"sin\", sin, persistent=False)\n", + " self.cfg = cfg\n", + "\n", + "\n", " def forward(self, in_idx):\n", + " # Forward pass\n", " tok_embeds = self.tok_emb(in_idx)\n", " x = tok_embeds\n", - " x = self.trf_blocks(x)\n", + "\n", + " num_tokens = x.shape[1]\n", + " mask = torch.triu(torch.ones(num_tokens, num_tokens, device=x.device, dtype=torch.bool), diagonal=1)\n", + " \n", + " for block in self.trf_blocks:\n", + " x = block(x, mask, self.cos, self.sin)\n", " x = self.final_norm(x)\n", - " logits = self.out_head(x.to(torch.bfloat16))\n", + " logits = self.out_head(x.to(self.cfg[\"dtype\"]))\n", " return logits" ] }, @@ -420,7 +408,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 8, "id": "caa142fa-b375-4e78-b392-2072ced666f3", "metadata": { "id": "caa142fa-b375-4e78-b392-2072ced666f3" @@ -430,16 +418,16 @@ "# Llama 3.2 1B\n", "\n", "LLAMA32_CONFIG = {\n", - " \"vocab_size\": 128_256, # Vocabulary size\n", - " \"context_length\": 131_072, # Context length\n", - " \"emb_dim\": 2048, # Embedding dimension\n", - " \"n_heads\": 32, # Number of attention heads\n", - " \"n_layers\": 16, # Number of layers\n", - " \"hidden_dim\": 8192, # Size of the intermediate dimension in FeedForward\n", - " \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n", - " \"rope_base\": 500_000.0, # The base in RoPE's \"theta\"\n", - " \"dtype\": torch.bfloat16, # Lower-precision dtype to reduce memory usage\n", - " \"rope_freq\": { # RoPE frequency scaling\n", + " \"vocab_size\": 128_256, # Vocabulary size\n", + " \"context_length\": 131_072, # Context length that was used to train the model\n", + " \"emb_dim\": 2048, # Embedding dimension\n", + " \"n_heads\": 32, # Number of attention heads\n", + " \"n_layers\": 16, # Number of layers\n", + " \"hidden_dim\": 8192, # Size of the intermediate dimension in FeedForward\n", + " \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n", + " \"rope_base\": 500_000.0, # The base in RoPE's \"theta\"\n", + " \"dtype\": torch.bfloat16, # Lower-precision dtype to reduce memory usage\n", + " \"rope_freq\": { # RoPE frequency scaling\n", " \"factor\": 32.0,\n", " \"low_freq_factor\": 1.0,\n", " \"high_freq_factor\": 4.0,\n", @@ -450,16 +438,16 @@ "# Llama 3.2 3B\n", "\n", "# LLAMA32_CONFIG = {\n", - "# \"vocab_size\": 128_256, # Vocabulary size\n", - "# \"context_length\": 131_072, # Context length\n", - "# \"emb_dim\": 3072, # Embedding dimension\n", - "# \"n_heads\": 24, # Number of attention heads\n", - "# \"n_layers\": 28, # Number of layers\n", - "# \"hidden_dim\": 8192, # Size of the intermediate dimension in FeedForward\n", - "# \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n", - "# \"rope_base\": 500_000.0, # The base in RoPE's \"theta\"\n", - "# \"dtype\": torch.bfloat16, # Lower-precision dtype to reduce memory usage\n", - "# \"rope_freq\": { # RoPE frequency scaling\n", + "# \"vocab_size\": 128_256, # Vocabulary 
size\n", + "# \"context_length\": 131_072, # Context length that was used to train the model\n", + "# \"emb_dim\": 3072, # Embedding dimension\n", + "# \"n_heads\": 24, # Number of attention heads\n", + "# \"n_layers\": 28, # Number of layers\n", + "# \"hidden_dim\": 8192, # Size of the intermediate dimension in FeedForward\n", + "# \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n", + "# \"rope_base\": 500_000.0, # The base in RoPE's \"theta\"\n", + "# \"dtype\": torch.bfloat16, # Lower-precision dtype to reduce memory usage\n", + "# \"rope_freq\": { # RoPE frequency scaling\n", "# \"factor\": 32.0,\n", "# \"low_freq_factor\": 1.0,\n", "# \"high_freq_factor\": 4.0,\n", @@ -470,54 +458,9 @@ "LLAMA_SIZE_STR = \"1B\" if LLAMA32_CONFIG[\"emb_dim\"] == 2048 else \"3B\"" ] }, - { - "cell_type": "markdown", - "id": "34535172-797e-4dd0-84fb-65bc75ad5b06", - "metadata": { - "id": "34535172-797e-4dd0-84fb-65bc75ad5b06" - }, - "source": [ - "- Reduce the context length so the model would work fine on a MacBook Air (if you have more RAM, feel free to comment out the lines below):" - ] - }, { "cell_type": "code", - "execution_count": 10, - "id": "a8bc2370-39d2-4bfe-b4c1-6bdd75fe101c", - "metadata": { - "id": "a8bc2370-39d2-4bfe-b4c1-6bdd75fe101c" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "New RoPE theta: 31250.0\n" - ] - } - ], - "source": [ - "old_context_length = LLAMA32_CONFIG[\"context_length\"]\n", - "LLAMA32_CONFIG[\"context_length\"] = 8192\n", - "\n", - "\n", - "def rescale_theta(theta_old, context_length_old, context_length_new):\n", - " scaling_factor = context_length_new / context_length_old\n", - " theta_new = theta_old * scaling_factor\n", - " return theta_new\n", - "\n", - "LLAMA32_CONFIG[\"rope_base\"] = rescale_theta(\n", - " LLAMA32_CONFIG[\"rope_base\"],\n", - " old_context_length,\n", - " LLAMA32_CONFIG[\"context_length\"]\n", - ")\n", - "\n", - "print(\"New RoPE theta:\", LLAMA32_CONFIG[\"rope_base\"])" - ] - }, - { - "cell_type": "code", - "execution_count": 11, + "execution_count": 9, "id": "156253fe-aacd-4da2-8f13-705f05c4b11e", "metadata": { "id": "156253fe-aacd-4da2-8f13-705f05c4b11e" @@ -539,36 +482,7 @@ }, { "cell_type": "code", - "execution_count": 12, - "id": "0e95db6d-2712-41a5-a5e0-86c49897f4cf", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "0e95db6d-2712-41a5-a5e0-86c49897f4cf", - "outputId": "8efc4937-e616-40d0-cd59-670d7eb3e841" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "True\n", - "True\n", - "True\n" - ] - } - ], - "source": [ - "# Check buffers\n", - "print(model.trf_blocks[0].att.mask is model.trf_blocks[-1].att.mask)\n", - "print(model.trf_blocks[0].att.cos is model.trf_blocks[-1].att.cos)\n", - "print(model.trf_blocks[0].att.sin is model.trf_blocks[-1].att.sin)" - ] - }, - { - "cell_type": "code", - "execution_count": 13, + "execution_count": 10, "id": "364e76ca-52f8-4fa5-af37-c4069f9694bc", "metadata": { "colab": { @@ -599,7 +513,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 11, "id": "fd5efb03-5a07-46e8-8607-93ed47549d2b", "metadata": { "colab": { @@ -613,8 +527,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "float32 (PyTorch default): 11.42 GB\n", - "bfloat16: 5.71 GB\n" + "float32 (PyTorch default): 11.23 GB\n", + "bfloat16: 5.61 GB\n" ] } ], @@ -649,7 +563,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 12, "id": 
"31f12baf-f79b-499f-85c0-51328a6a20f5", "metadata": { "id": "31f12baf-f79b-499f-85c0-51328a6a20f5" @@ -679,7 +593,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 13, "id": "9482b01c-49f9-48e4-ab2c-4a4c75240e77", "metadata": { "id": "9482b01c-49f9-48e4-ab2c-4a4c75240e77" @@ -693,73 +607,86 @@ "from tiktoken.load import load_tiktoken_bpe\n", "\n", "\n", - "class Tokenizer:\n", - " def __init__(self, model_path):\n", - " assert os.path.isfile(model_path), f\"Model file {model_path} not found\"\n", - " mergeable_ranks = load_tiktoken_bpe(model_path)\n", "\n", - " self.special_tokens = {\n", + "class Tokenizer:\n", + " \"\"\"Thin wrapper around tiktoken that keeps track of Llama-3 special IDs.\"\"\"\n", + " def __init__(self, model_path):\n", + " if not os.path.isfile(model_path):\n", + " raise FileNotFoundError(model_path)\n", + "\n", + " mergeable = load_tiktoken_bpe(model_path)\n", + "\n", + " # hard-coded from Meta's tokenizer.json\n", + " self.special = {\n", " \"<|begin_of_text|>\": 128000,\n", " \"<|end_of_text|>\": 128001,\n", " \"<|start_header_id|>\": 128006,\n", " \"<|end_header_id|>\": 128007,\n", " \"<|eot_id|>\": 128009,\n", " }\n", - " self.special_tokens.update({\n", - " f\"<|reserved_{i}|>\": 128002 + i for i in range(256) if (128002 + i) not in self.special_tokens.values()\n", - " })\n", + " self.special.update({f\"<|reserved_{i}|>\": 128002 + i\n", + " for i in range(256)\n", + " if 128002 + i not in self.special.values()})\n", "\n", " self.model = tiktoken.Encoding(\n", " name=Path(model_path).name,\n", - " pat_str=r\"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+\",\n", - " mergeable_ranks=mergeable_ranks,\n", - " special_tokens=self.special_tokens\n", + " pat_str=r\"(?i:'s|'t|'re|'ve|'m|'ll|'d)\"\n", + " r\"|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+\"\n", + " r\"|\\p{N}{1,3}\"\n", + " r\"| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*\"\n", + " r\"|\\s*[\\r\\n]+\"\n", + " r\"|\\s+(?!\\S)\"\n", + " r\"|\\s+\",\n", + " mergeable_ranks=mergeable,\n", + " special_tokens=self.special,\n", " )\n", "\n", - "\n", - " def encode(self, text, bos=False, eos=False, allowed_special=set(), disallowed_special=()):\n", - " if bos:\n", - " tokens = [self.special_tokens[\"<|begin_of_text|>\"]]\n", - " else:\n", - " tokens = []\n", - "\n", - " tokens += self.model.encode(text, allowed_special=allowed_special, disallowed_special=disallowed_special)\n", - "\n", + " def encode(self, text, bos=False, eos=False):\n", + " ids = ([self.special[\"<|begin_of_text|>\"]] if bos else []) \\\n", + " + self.model.encode(text)\n", " if eos:\n", - " tokens.append(self.special_tokens[\"<|end_of_text|>\"])\n", - " return tokens\n", + " ids.append(self.special[\"<|end_of_text|>\"])\n", + " return ids\n", "\n", - " def decode(self, tokens):\n", - " return self.model.decode(tokens)\n", + " def decode(self, ids):\n", + " return self.model.decode(ids)\n", "\n", "\n", "class ChatFormat:\n", - " def __init__(self, tokenizer):\n", - " self.tokenizer = tokenizer\n", "\n", - " def encode_header(self, message):\n", - " tokens = []\n", - " tokens.append(self.tokenizer.special_tokens[\"<|start_header_id|>\"])\n", - " tokens.extend(self.tokenizer.encode(message[\"role\"], bos=False, eos=False))\n", - " tokens.append(self.tokenizer.special_tokens[\"<|end_header_id|>\"])\n", - " tokens.extend(self.tokenizer.encode(\"\\n\\n\", bos=False, eos=False))\n", - " return tokens\n", + " def __init__(self, tokenizer: Tokenizer, *,\n", + " 
default_system=\"You are a helpful assistant.\"):\n", + " self.tok = tokenizer\n", + " self.default_system = default_system\n", "\n", - " def encode(self, text):\n", - " message = {\n", - " \"role\": \"user\",\n", - " \"content\": text\n", - " }\n", - "\n", - " tokens = self.encode_header(message)\n", - " tokens.extend(\n", - " self.tokenizer.encode(message[\"content\"].strip(), bos=False, eos=False)\n", + " def _header(self, role):\n", + " \"\"\"Encode <|start_header_id|>role<|end_header_id|>\\n\\n\"\"\"\n", + " return (\n", + " [self.tok.special[\"<|start_header_id|>\"]]\n", + " + self.tok.encode(role)\n", + " + [self.tok.special[\"<|end_header_id|>\"]]\n", + " + self.tok.encode(\"\\n\\n\")\n", " )\n", - " tokens.append(self.tokenizer.special_tokens[\"<|eot_id|>\"])\n", - " return tokens\n", "\n", - " def decode(self, token_ids):\n", - " return self.tokenizer.decode(token_ids)" + " def encode(self, user_message, system_message=None):\n", + " sys_msg = system_message if system_message is not None else self.default_system\n", + "\n", + " ids = [self.tok.special[\"<|begin_of_text|>\"]]\n", + "\n", + " # system\n", + " ids += self._header(\"system\")\n", + " ids += self.tok.encode(sys_msg)\n", + " ids += [self.tok.special[\"<|eot_id|>\"]]\n", + "\n", + " # user\n", + " ids += self._header(\"user\")\n", + " ids += self.tok.encode(user_message)\n", + " ids += [self.tok.special[\"<|eot_id|>\"]]\n", + "\n", + " # assistant header (no content yet)\n", + " ids += self._header(\"assistant\")\n", + "\n", + " return ids" ] }, { @@ -782,7 +709,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 14, "id": "e9d96dc8-603a-4cb5-8c3e-4d2ca56862ed", "metadata": { "colab": { @@ -793,25 +720,24 @@ }, "outputs": [ { - "name": "stdout", + "name": "stderr", "output_type": "stream", "text": [ - "The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.\n", - "Token is valid (permission: read).\n", - "Your token has been saved to /teamspace/studios/this_studio/.cache/huggingface/token\n", - "Login successful\n" + "/Users/sebastian/Developer/LLMs-from-scratch/.venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], "source": [ - "from huggingface_hub import login\n", + "# Uncomment and run the following code if you are executing the notebook for the first time\n", "\n", - "login()" + "# from huggingface_hub import login\n", + "# login()" ] }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 15, "id": "986bc1a0-804f-4154-80f8-44cefbee1368", "metadata": { "colab": { @@ -847,7 +773,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 16, "id": "_gBhxDtU_nxo", "metadata": { "id": "_gBhxDtU_nxo" @@ -871,7 +797,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 17, "id": "75166128-5899-4995-9b88-9672e135650e", "metadata": { "id": "75166128-5899-4995-9b88-9672e135650e" @@ -954,7 +880,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 18, "id": "699cb1b8-a67d-49fb-80a6-0dad9d81f392", "metadata": { "colab": { @@ -1018,7 +944,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 19, "id": "7f9f7ccc-70cb-41ff-9c25-44336042fc37", "metadata": { "id": "7f9f7ccc-70cb-41ff-9c25-44336042fc37" @@ -1049,7 +975,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 20, "id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5", "metadata": { "id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5" @@ -1108,7 +1034,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 21, "id": "1c7a04fa-6aac-416b-8f63-f1e19227633d", "metadata": { "id": "1c7a04fa-6aac-416b-8f63-f1e19227633d" @@ -1118,23 +1044,31 @@ "name": "stdout", "output_type": "stream", "text": [ + "Time: 18.20 sec\n", + "\n", + "\n", "Output text:\n", - " Llamas are herbivores, which means they primarily eat plants. Their diet consists mainly of:\n", "\n", - "1. Grasses: Llamas love to graze on various types of grasses, including tall grasses and grassy meadows.\n", - "2. Hay: Llamas also eat hay, which is a dry, compressed form of grass or other plants.\n", - "3. Alfalfa: Alfalfa is a legume that is commonly fed to llamas. It is high in protein and fiber.\n", - "4. Other plants: Llamas will also eat other plants, such as wild grasses, shrubs, and trees.\n", + " Llamas are herbivores, which means they primarily eat plants and plant-based foods. Their diet typically consists of:\n", "\n", - "It's worth noting that the diet of llamas can vary depending on the region, climate,\n" + "1. Grasses: Llamas love to graze on various types of grasses, including tall grasses, short grasses, and grassy weeds.\n", + "2. Hay: They also enjoy munching on hay, which is a dry, compressed form of grass or other plant material.\n", + "3. Leaves: Llamas will eat leaves from trees and shrubs, including leaves from plants like clover, alfalfa, and grasses.\n", + "4. Fruits and vegetables: In the wild, llamas will eat fruits and vegetables like berries, apples, and carrots.\n", + "5. 
Browse: Llamas will also\n" ] } ], "source": [ + "import time\n", + "\n", + "\n", "PROMPT = \"What do llamas eat?\"\n", "\n", "torch.manual_seed(123)\n", "\n", + "start = time.time()\n", + "\n", "token_ids = generate(\n", " model=model,\n", " idx=text_to_token_ids(PROMPT, chat_tokenizer).to(device),\n", @@ -1144,6 +1078,13 @@ " temperature=0.\n", ")\n", "\n", + "print(f\"Time: {time.time() - start:.2f} sec\")\n", + "\n", + "if torch.cuda.is_available():\n", + " max_mem_bytes = torch.cuda.max_memory_allocated()\n", + " max_mem_gb = max_mem_bytes / (1024 ** 3)\n", + " print(f\"Max memory allocated: {max_mem_gb:.2f} GB\")\n", + "\n", "output_text = token_ids_to_text(token_ids, tokenizer)\n", "\n", "\n", @@ -1158,7 +1099,7 @@ " # If the token is not found, return the original text\n", " return text\n", "\n", - "print(\"Output text:\\n\", clean_text(output_text))" + "print(\"\\n\\nOutput text:\\n\\n\", clean_text(output_text))" ] }, { diff --git a/pkg/llms_from_scratch/README.md b/pkg/llms_from_scratch/README.md index 7b2bddd..dc423b6 100644 --- a/pkg/llms_from_scratch/README.md +++ b/pkg/llms_from_scratch/README.md @@ -110,12 +110,21 @@ from llms_from_scratch.appendix_a import NeuralNetwork, ToyDataset from llms_from_scratch.appendix_d import find_highest_gradient, train_model +``` + + + +### Llama 3 (Bonus material) + +```python from llms_from_scratch.llama3 import ( Llama3Model, + Llama3ModelFast, Llama3Tokenizer, ChatFormat, clean_text ) ``` -(For the `llms_from_scratch.llama3` usage information, please see [this bonus section](../../ch05/07_gpt_to_llama/README.md). + +For the `llms_from_scratch.llama3` usage information, please see [this bonus section](../../ch05/07_gpt_to_llama/README.md). diff --git a/pkg/llms_from_scratch/llama3.py b/pkg/llms_from_scratch/llama3.py index 2776882..df7bc72 100644 --- a/pkg/llms_from_scratch/llama3.py +++ b/pkg/llms_from_scratch/llama3.py @@ -15,8 +15,7 @@ from tiktoken.load import load_tiktoken_bpe LLAMA32_CONFIG_1B = { "vocab_size": 128_256, # Vocabulary size - "context_length": 8192, # Maximum context length to use (reduced to save memory) - "orig_context_length": 131_072, # Context length that was used to train the model + "context_length": 131_072, # Context length that was used to train the model "emb_dim": 2048, # Embedding dimension "n_heads": 32, # Number of attention heads "n_layers": 16, # Number of layers @@ -34,8 +33,7 @@ LLAMA32_CONFIG_1B = { LLAMA32_CONFIG_3B = { "vocab_size": 128_256, # Vocabulary size - "context_length": 8192, # Maximum context length to use (reduced to save memory) - "orig_context_length": 131_072, # Context length that was used to train the model + "context_length": 131_072, # Context length that was used to train the model "emb_dim": 3072, # Embedding dimension "n_heads": 24, # Number of attention heads "n_layers": 28, # Number of layers @@ -67,17 +65,6 @@ class Llama3Model(nn.Module): self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"]) # Reusuable utilities - self.register_buffer( - "mask", torch.triu(torch.ones(cfg["context_length"], cfg["context_length"]), diagonal=1).bool(), - persistent=False - ) - - if cfg["orig_context_length"] != cfg["context_length"]: - cfg["rope_base"] = rescale_theta( - cfg["rope_base"], - cfg["orig_context_length"], - cfg["context_length"] - ) cos, sin = compute_rope_params( head_dim=cfg["emb_dim"] // cfg["n_heads"], theta_base=cfg["rope_base"], @@ -92,8 +79,11 @@ class Llama3Model(nn.Module): tok_embeds = self.tok_emb(in_idx) x = tok_embeds + 
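The timing and peak-memory printout added to the generation cell above uses `torch.cuda.max_memory_allocated()`, which reports the peak since the start of the program (or since the last reset). For repeated runs in one session, resetting the counter and synchronizing around the timed region keeps the numbers comparable; a small sketch, assuming a CUDA device:

```python
import time
import torch

def benchmark(run_fn):
    # run_fn: a zero-argument callable that performs one generation run
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    start = time.time()
    result = run_fn()
    torch.cuda.synchronize()
    print(f"Time: {time.time() - start:.2f} sec")
    print(f"Max memory allocated: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
    return result
```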
num_tokens = x.shape[1] + mask = torch.triu(torch.ones(num_tokens, num_tokens, device=x.device, dtype=torch.bool), diagonal=1) + for block in self.trf_blocks: - x = block(x, self.mask, self.cos, self.sin) + x = block(x, mask, self.cos, self.sin) x = self.final_norm(x) logits = self.out_head(x.to(self.cfg["dtype"])) return logits @@ -281,88 +271,104 @@ def apply_rope(x, cos, sin): return x_rotated.to(dtype=x.dtype) -def rescale_theta(theta_old, context_length_old, context_length_new): - scaling_factor = context_length_new / context_length_old - theta_new = theta_old * scaling_factor - return theta_new - - ########################################## # Tokenizer ########################################## class Llama3Tokenizer: + """Thin wrapper around tiktoken that keeps track of Llama-3 special IDs.""" def __init__(self, model_path): - assert os.path.isfile(model_path), f"Model file {model_path} not found" - mergeable_ranks = load_tiktoken_bpe(model_path) + if not os.path.isfile(model_path): + raise FileNotFoundError(model_path) - self.special_tokens = { + mergeable = load_tiktoken_bpe(model_path) + + # hard-coded from Meta's tokenizer.json + self.special = { "<|begin_of_text|>": 128000, "<|end_of_text|>": 128001, "<|start_header_id|>": 128006, "<|end_header_id|>": 128007, "<|eot_id|>": 128009, } - self.special_tokens.update({ - f"<|reserved_{i}|>": 128002 + i for i in range(256) if (128002 + i) not in self.special_tokens.values() - }) + self.special.update({f"<|reserved_{i}|>": 128002 + i + for i in range(256) + if 128002 + i not in self.special.values()}) self.model = tiktoken.Encoding( name=Path(model_path).name, - pat_str=r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+", - mergeable_ranks=mergeable_ranks, - special_tokens=self.special_tokens + pat_str=r"(?i:'s|'t|'re|'ve|'m|'ll|'d)" + r"|[^\r\n\p{L}\p{N}]?\p{L}+" + r"|\p{N}{1,3}" + r"| ?[^\s\p{L}\p{N}]+[\r\n]*" + r"|\s*[\r\n]+" + r"|\s+(?!\S)" + r"|\s+", + mergeable_ranks=mergeable, + special_tokens=self.special, ) - def encode(self, text, bos=False, eos=False, allowed_special=set(), disallowed_special=()): + def encode(self, text, bos=False, eos=False, allowed_special=set()): + ids: list[int] = [] + if bos: - tokens = [self.special_tokens["<|begin_of_text|>"]] - else: - tokens = [] - - tokens += self.model.encode(text, allowed_special=allowed_special, disallowed_special=disallowed_special) + ids.append(self.special_tokens["<|begin_of_text|>"]) + # delegate to underlying tiktoken.Encoding.encode + ids.extend( + self.model.encode( + text, + allowed_special=allowed_special, + ) + ) if eos: - tokens.append(self.special_tokens["<|end_of_text|>"]) - return tokens + ids.append(self.special_tokens["<|end_of_text|>"]) - def decode(self, tokens): - return self.model.decode(tokens) + return ids + + def decode(self, ids): + return self.model.decode(ids) class ChatFormat: - def __init__(self, tokenizer): - self.tokenizer = tokenizer - def encode_header(self, message): - tokens = [] - tokens.append(self.tokenizer.special_tokens["<|start_header_id|>"]) - tokens.extend(self.tokenizer.encode(message["role"], bos=False, eos=False)) - tokens.append(self.tokenizer.special_tokens["<|end_header_id|>"]) - tokens.extend(self.tokenizer.encode("\n\n", bos=False, eos=False)) - return tokens + def __init__(self, tokenizer: Llama3Tokenizer, *, + default_system="You are a helpful assistant."): + self.tok = tokenizer + self.default_system = default_system - def encode(self, text, 
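The packaged `Llama3Model` now builds its mask the same way and keeps the full 131,072-token `context_length` in its configs. A quick smoke test is to instantiate it straight from the bundled config and push a dummy batch through it; a minimal sketch (random, untrained weights, so the logits are meaningless):

```python
import torch
from llms_from_scratch.llama3 import Llama3Model, LLAMA32_CONFIG_1B

torch.manual_seed(123)
model = Llama3Model(LLAMA32_CONFIG_1B)

dummy_ids = torch.randint(0, LLAMA32_CONFIG_1B["vocab_size"], (1, 8))  # (batch, num_tokens)
with torch.no_grad():
    logits = model(dummy_ids)

print(logits.shape)  # torch.Size([1, 8, 128256])
```

Loading real weights and generating text works as described in the bonus-section README referenced above.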
allowed_special=None): - message = { - "role": "user", - "content": text - } - - tokens = self.encode_header(message) - tokens.extend( - self.tokenizer.encode( - message["content"].strip(), - bos=False, - eos=False, - allowed_special=allowed_special - ) + def _header(self, role): + """Encode <|start_header_id|>role<|end_header_id|>\n\n""" + return ( + [self.tok.special["<|start_header_id|>"]] + + self.tok.encode(role) + + [self.tok.special["<|end_header_id|>"]] + + self.tok.encode("\n\n") ) - tokens.append(self.tokenizer.special_tokens["<|eot_id|>"]) - return tokens - def decode(self, token_ids): - return self.tokenizer.decode(token_ids) + def encode(self, user_message, system_message=None, allowed_special=None): + sys_msg = system_message if system_message is not None else self.default_system + + ids = [self.tok.special["<|begin_of_text|>"]] + + # system + ids += self._header("system") + ids += self.tok.encode(sys_msg, allowed_special=allowed_special) + ids += [self.tok.special["<|eot_id|>"]] + + # user + ids += self._header("user") + ids += self.tok.encode(user_message) + ids += [self.tok.special["<|eot_id|>"]] + + # assistant header (no content yet) + ids += self._header("assistant") + + return ids + + def decode(self, ids): + return self.tok.decode(ids) def clean_text(text, header_end="assistant<|end_header_id|>\n\n"): @@ -483,12 +489,6 @@ class Llama3ModelFast(nn.Module): self.final_norm = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"]) self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"]) - if cfg["orig_context_length"] != cfg["context_length"]: - cfg["rope_base"] = rescale_theta( - cfg["rope_base"], - cfg["orig_context_length"], - cfg["context_length"] - ) cos, sin = compute_rope_params( head_dim=cfg["emb_dim"] // cfg["n_heads"], theta_base=cfg["rope_base"], diff --git a/pkg/llms_from_scratch/tests/test_llama3.py b/pkg/llms_from_scratch/tests/test_llama3.py index 0ffdc09..1719976 100644 --- a/pkg/llms_from_scratch/tests/test_llama3.py +++ b/pkg/llms_from_scratch/tests/test_llama3.py @@ -7,7 +7,6 @@ from llms_from_scratch.ch04 import generate_text_simple from llms_from_scratch.llama3 import ( compute_rope_params, apply_rope, - rescale_theta, LLAMA32_CONFIG_1B, GroupedQueryAttention, GroupedQueryAttentionFast, @@ -102,23 +101,6 @@ GPT_CONFIG_124M = { } -def test_rescale(): - - new_theta = rescale_theta( - theta_old=500_000., - context_length_old=131_072, - context_length_new=8192 - ) - assert new_theta == 31250. - - old_theta = rescale_theta( - theta_old=new_theta, - context_length_old=8192, - context_length_new=131_072 - ) - assert old_theta == 500_000. - - def test_grouped_query_attention_equivalence(): torch.manual_seed(42) b, t, d_in, d_out, num_heads, num_kv_groups = 2, 8, 32, 64, 4, 2 @@ -194,6 +176,6 @@ def test_gpt_model_variants(ModelClass, llama3_weights_path): ) print("Encoded output text:", out) expect = torch.tensor([ - [43, 2543, 292, 4483, 100383, 8113, 21197, 33804, 54419] + [43, 2543, 292, 4483, 100383, 8113, 76873, 42175, 72641] ]) assert torch.equal(expect, out) diff --git a/pyproject.toml b/pyproject.toml index d9997a9..52b6e28 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "llms-from-scratch" -version = "1.0.6" +version = "1.0.7" description = "Implement a ChatGPT-like LLM in PyTorch from scratch, step by step" readme = "README.md" requires-python = ">=3.10"