Reduce Llama 3 RoPE memory requirements (#658)

* Llama3 from scratch improvements

* Fix Llama 3 expensive RoPE memory issue

* updates

* update package

* benchmark

* remove unused rescale_theta
Sebastian Raschka 2025-06-12 11:08:02 -05:00 committed by GitHub
parent c278745aff
commit c4cde1c21b
9 changed files with 405 additions and 2577 deletions

3
.gitignore vendored
View File

@ -51,6 +51,9 @@ ch05/07_gpt_to_llama/Llama-3.2-3B-Instruct
ch05/10_llm-training-speed/middlemarch.txt ch05/10_llm-training-speed/middlemarch.txt
ch05/10_llm-training-speed/loss.pdf ch05/10_llm-training-speed/loss.pdf
ch05/10_llm-training-speed/model.pth ch05/10_llm-training-speed/model.pth
ch05/07_gpt_to_llama/Untitled.ipynb
ch05/07_gpt_to_llama/llama3.2-1B-instruct.pth
ch05/07_gpt_to_llama/tokenizer.model
ch06/01_main-chapter-code/gpt2 ch06/01_main-chapter-code/gpt2
ch06/02_bonus_additional-experiments/gpt2 ch06/02_bonus_additional-experiments/gpt2

View File

@ -40,8 +40,6 @@ MODEL_FILE = "llama3.2-1B-instruct.pth"
Basic text generation settings that can be defined by the user. Note that the recommended 8192-token context size requires approximately 3 GB of VRAM for the text generation example. Basic text generation settings that can be defined by the user. Note that the recommended 8192-token context size requires approximately 3 GB of VRAM for the text generation example.
```python ```python
MODEL_CONTEXT_LENGTH = 8192 # Supports up to 131_072
# Text generation settings # Text generation settings
if "instruct" in MODEL_FILE: if "instruct" in MODEL_FILE:
PROMPT = "What do llamas eat?" PROMPT = "What do llamas eat?"
@ -82,8 +80,6 @@ elif "3B" in MODEL_FILE:
else: else:
raise ValueError("Incorrect model file name") raise ValueError("Incorrect model file name")
LLAMA32_CONFIG["context_length"] = MODEL_CONTEXT_LENGTH
model = Llama3Model(LLAMA32_CONFIG) model = Llama3Model(LLAMA32_CONFIG)
model.load_state_dict(torch.load(MODEL_FILE, weights_only=True, map_location="cpu")) model.load_state_dict(torch.load(MODEL_FILE, weights_only=True, map_location="cpu"))
@ -125,7 +121,7 @@ Lastly, we can generate text via the following code:
```python ```python
import time import time
from ch05 import ( from llms_from_scratch.ch05 import (
generate, generate,
text_to_token_ids, text_to_token_ids,
token_ids_to_text token_ids_to_text
@ -192,8 +188,8 @@ The following table shows a performance comparison on an A100:
| | Tokens/sec | Memory | | | Tokens/sec | Memory |
| --------------- | ---------- | ------- | | --------------- | ---------- | ------- |
| Llama3Model | 50 | 2.91 GB | | Llama3Model | 42 | 2.91 GB |
| Llama3ModelFast | 58 | 2.85 GB | | Llama3ModelFast | 54 | 2.91 GB |
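For reference, tokens/sec figures like the ones above can be measured with a simple timing loop. The sketch below is illustrative only and assumes the `model`, `tokenizer`, `PROMPT`, `generate`, and `text_to_token_ids` objects from the snippets above; argument names follow the chapter 5 code and may differ slightly from the installed package:

```python
import time

input_ids = text_to_token_ids(PROMPT, tokenizer).to(next(model.parameters()).device)

start = time.time()
output_ids = generate(
    model=model,
    idx=input_ids,
    max_new_tokens=200,
    context_size=LLAMA32_CONFIG["context_length"],
    top_k=1,
    temperature=0.
)
elapsed = time.time() - start

num_new_tokens = output_ids.shape[1] - input_ids.shape[1]
print(f"{num_new_tokens / elapsed:.1f} tokens/sec")  # rough throughput estimate
```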
   
#### Pro tip 2: speed up inference with compilation #### Pro tip 2: speed up inference with compilation
@ -218,5 +214,5 @@ The following table shows a performance comparison on an A100 for consequent `ge
| | Tokens/sec | Memory | | | Tokens/sec | Memory |
| --------------- | ---------- | ------- | | --------------- | ---------- | ------- |
| Llama3Model | 156 | 3.12 GB | | Llama3Model | 170 | 3.12 GB |
| Llama3ModelFast | 159 | 2.84 GB | | Llama3ModelFast | 177 | 3.61 GB |
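The compilation trick itself is essentially a one-liner via `torch.compile`; a minimal sketch, assuming the same `model` instance as above (the first `generate` call after compiling is slow because it triggers compilation, which is why the table above compares subsequent calls):

```python
model = torch.compile(model)  # the first generate() call pays a one-time compilation cost
```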

View File

@ -95,9 +95,9 @@
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"blobfile version: 3.0.0\n", "blobfile version: 3.0.0\n",
"huggingface_hub version: 0.24.7\n", "huggingface_hub version: 0.30.1\n",
"tiktoken version: 0.8.0\n", "tiktoken version: 0.9.0\n",
"torch version: 2.4.1+cu121\n" "torch version: 2.6.0\n"
] ]
} }
], ],
@ -435,7 +435,7 @@
"id": "842aa71a-4659-424e-8830-392bd6ae86af", "id": "842aa71a-4659-424e-8830-392bd6ae86af",
"metadata": {}, "metadata": {},
"source": [ "source": [
"- In addition, we also introduce a `SharedBuffers` class that will allow us to reuse the `mask`, `cos`, and `sin` tensors in the transformer blocks to improve efficiency (this will be crucial when working with models such as Llama 3.1 and 3.2 later, which support up to 131k input tokens)" "- **We also redesign the attention class a bit so it receives the mask through its forward method instead of storing and accessing it as `self.mask`. This lets us build the mask on the fly to reduce memory usage. To foreshadow why: Llama3.1 can handle sequences of up to 128k tokens, and precomputing a 128k×128k causal mask would be extremely memoryintensive, so we avoid it unless absolutely necessary.**"
] ]
}, },
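As a quick back-of-the-envelope check (an aside, not part of the notebook) of why the precomputed mask is problematic: a boolean causal mask for the full supported context length already occupies about 16 GB, whereas an 8192-token mask created on the fly is only about 64 MB:

```python
context_length = 131_072                     # maximum context supported by Llama 3.1/3.2
print(context_length ** 2 / 1024**3, "GB")   # 16.0 GB for a torch.bool mask (1 byte per element)

num_tokens = 8192                            # typical prompt + generation length
print(num_tokens ** 2 / 1024**2, "MB")       # 64.0 MB for the on-the-fly mask
```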
{ {
@ -450,27 +450,6 @@
"import torch.nn as nn\n", "import torch.nn as nn\n",
"\n", "\n",
"\n", "\n",
"############################# NEW #############################\n",
"class SharedBuffers:\n",
" _buffers = {}\n",
"\n",
" @staticmethod\n",
" def get_buffers(context_length, head_dim, rope_base, freq_config, dtype=torch.float32):\n",
" key = (context_length, head_dim, rope_base, tuple(freq_config.values()) if freq_config else freq_config, dtype)\n",
"\n",
" if key not in SharedBuffers._buffers:\n",
" # Create or fetch the buffers\n",
" mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)\n",
" cos, sin = precompute_rope_params(head_dim, rope_base, context_length, freq_config)\n",
" if dtype is not None:\n",
" cos = cos.to(dtype)\n",
" sin = sin.to(dtype)\n",
" SharedBuffers._buffers[key] = (mask, cos, sin)\n",
"\n",
" return SharedBuffers._buffers[key]\n",
"############################# NEW #############################\n",
"\n",
"\n",
"class GroupedQueryAttention(nn.Module):\n", "class GroupedQueryAttention(nn.Module):\n",
" def __init__(\n", " def __init__(\n",
" self, d_in, d_out, context_length, num_heads,\n", " self, d_in, d_out, context_length, num_heads,\n",
@ -499,16 +478,12 @@
" self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype)\n", " self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype)\n",
" self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype)\n", " self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype)\n",
"\n", "\n",
" ############################# NEW #############################\n",
" # Fetch buffers using SharedBuffers\n",
" mask, cos, sin = SharedBuffers.get_buffers(context_length, self.head_dim, rope_base, rope_config, dtype)\n",
" ############################# NEW #############################\n",
" \n",
" self.register_buffer(\"mask\", mask)\n",
" self.register_buffer(\"cos\", cos)\n",
" self.register_buffer(\"sin\", sin)\n",
"\n", "\n",
" def forward(self, x):\n", " def forward(self, x, mask=None, cos=None, sin=None):\n",
" ##################### NEW #####################\n",
" # The forward method now accepts `mask` instead of accessing it via self.mask.\n",
" # Also, we now have cos and sin as input for RoPE\n",
" ################################################ \n",
" b, num_tokens, d_in = x.shape\n", " b, num_tokens, d_in = x.shape\n",
"\n", "\n",
" queries = self.W_query(x) # Shape: (b, num_tokens, d_out)\n", " queries = self.W_query(x) # Shape: (b, num_tokens, d_out)\n",
@ -530,9 +505,12 @@
" values = values.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim)\n", " values = values.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim)\n",
" queries = queries.transpose(1, 2) # Shape: (b, num_query_groups, num_tokens, head_dim)\n", " queries = queries.transpose(1, 2) # Shape: (b, num_query_groups, num_tokens, head_dim)\n",
"\n", "\n",
" ##################### NEW #####################\n",
" # Apply RoPE\n", " # Apply RoPE\n",
" keys = compute_rope(keys, self.cos, self.sin)\n", " if cos is not None:\n",
" queries = compute_rope(queries, self.cos, self.sin)\n", " keys = compute_rope(keys, cos, sin)\n",
" queries = compute_rope(queries, cos, sin)\n",
" ################################################\n",
"\n", "\n",
" ##################### NEW #####################\n", " ##################### NEW #####################\n",
" # Expand keys and values to match the number of heads\n", " # Expand keys and values to match the number of heads\n",
@ -552,11 +530,14 @@
" # Shape: (b, num_heads, num_tokens, num_tokens)\n", " # Shape: (b, num_heads, num_tokens, num_tokens)\n",
" attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head\n", " attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head\n",
"\n", "\n",
" # Original mask truncated to the number of tokens and converted to boolean\n", " ##################### NEW #####################\n",
" mask_bool = self.mask.bool()[:num_tokens, :num_tokens]\n", " # Create mask on the fly\n",
"\n", " if mask is None:\n",
" mask = torch.triu(torch.ones(num_tokens, num_tokens, device=x.device, dtype=torch.bool), diagonal=1)\n",
" ################################################\n",
" \n",
" # Use the mask to fill attention scores\n", " # Use the mask to fill attention scores\n",
" attn_scores.masked_fill_(mask_bool, -torch.inf)\n", " attn_scores.masked_fill_(mask, -torch.inf)\n",
"\n", "\n",
" attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n", " attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n",
" assert keys.shape[-1] == self.head_dim\n", " assert keys.shape[-1] == self.head_dim\n",
@ -578,7 +559,7 @@
"id": "roAXSwJs9hR8" "id": "roAXSwJs9hR8"
}, },
"source": [ "source": [
"- To illustrate the parameter savings, consider the following multi-head attention example from the GPT and Llama 2 code:" "- To illustrate the parameter savings in GQA over MHA, consider the following multi-head attention example from the GPT and Llama 2 code:"
] ]
}, },
{ {
@ -753,7 +734,8 @@
}, },
"source": [ "source": [
"- Next, we update the `TransformerBlock`\n", "- Next, we update the `TransformerBlock`\n",
"- Here, we simply swap `MultiHeadAttention` with `GroupedQueryAttention` and add the new RoPE settings" "- Here, we simply swap `MultiHeadAttention` with `GroupedQueryAttention` and add the new RoPE settings\n",
"- In addition, we also modify the `forward` method so that it receives `mask`, `cos`, and `sin`; since the values for those are the same for each transformer block, we only have to compute them once and then can reuse them"
] ]
}, },
{ {
@ -782,11 +764,15 @@
" self.norm1 = RMSNorm(cfg[\"emb_dim\"], eps=1e-5)\n", " self.norm1 = RMSNorm(cfg[\"emb_dim\"], eps=1e-5)\n",
" self.norm2 = RMSNorm(cfg[\"emb_dim\"], eps=1e-5)\n", " self.norm2 = RMSNorm(cfg[\"emb_dim\"], eps=1e-5)\n",
"\n", "\n",
" def forward(self, x):\n", " def forward(self, x, mask=None, cos=None, sin=None):\n",
" ##################### NEW #####################\n",
" # The forward method now accepts `mask` instead of accessing it via self.mask.\n",
" # Also, we now have cos and sin as input for RoPE\n",
" ################################################\n",
" # Shortcut connection for attention block\n", " # Shortcut connection for attention block\n",
" shortcut = x\n", " shortcut = x\n",
" x = self.norm1(x)\n", " x = self.norm1(x)\n",
" x = self.att(x.to(torch.bfloat16)) # Shape [batch_size, num_tokens, emb_size]\n", " x = self.att(x.to(torch.bfloat16), mask, cos, sin) # Shape [batch_size, num_tokens, emb_size]\n",
" x = x + shortcut # Add the original input back\n", " x = x + shortcut # Add the original input back\n",
"\n", "\n",
" # Shortcut connection for feed-forward block\n", " # Shortcut connection for feed-forward block\n",
@ -816,7 +802,8 @@
"id": "M_tLAq_r_llN" "id": "M_tLAq_r_llN"
}, },
"source": [ "source": [
"- When setting up the model class, we fortunately don't have to do much; we just update the name to `Llama3Model`" "- When setting up the model class, we technically don't have to do much; we just update the name to `Llama3Model`\n",
"- However, since we now pass the `mask`, `cos`, and `sin` to the transformer blocks, we also have to add them here"
] ]
}, },
{ {
@ -840,12 +827,33 @@
" self.final_norm = RMSNorm(cfg[\"emb_dim\"], eps=1e-5)\n", " self.final_norm = RMSNorm(cfg[\"emb_dim\"], eps=1e-5)\n",
" self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n", " self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n",
"\n", "\n",
" #################### NEW #####################\n",
" cos, sin = precompute_rope_params(\n",
" head_dim=cfg[\"emb_dim\"] // cfg[\"n_heads\"],\n",
" theta_base=cfg[\"rope_base\"],\n",
" context_length=cfg[\"context_length\"],\n",
" freq_config=cfg[\"rope_freq\"]\n",
" )\n",
" \n",
" self.register_buffer(\"cos\", cos, persistent=False)\n",
" self.register_buffer(\"sin\", sin, persistent=False)\n",
" ##############################################\n",
"\n",
" self.cfg = cfg\n",
"\n",
" def forward(self, in_idx):\n", " def forward(self, in_idx):\n",
" tok_embeds = self.tok_emb(in_idx)\n", " tok_embeds = self.tok_emb(in_idx)\n",
" x = tok_embeds\n", " x = tok_embeds\n",
" x = self.trf_blocks(x)\n", "\n",
" #################### NEW #####################\n",
" num_tokens = x.shape[1]\n",
" mask = torch.triu(torch.ones(num_tokens, num_tokens, device=x.device, dtype=torch.bool), diagonal=1)\n",
" ##############################################\n",
" \n",
" for block in self.trf_blocks:\n",
" x = block(x, mask, self.cos, self.sin)\n",
" x = self.final_norm(x)\n", " x = self.final_norm(x)\n",
" logits = self.out_head(x.to(torch.bfloat16))\n", " logits = self.out_head(x.to(self.cfg[\"dtype\"]))\n",
" return logits" " return logits"
] ]
}, },
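A brief aside on the `persistent=False` flag used when registering `cos` and `sin` above (not part of the notebook): non-persistent buffers still live on the module and move along with `.to()` / `.cuda()`, but they are excluded from the `state_dict`, so they neither bloat checkpoints nor need to be present in the pretrained weight files we load later:

```python
import torch
import torch.nn as nn

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("cos", torch.zeros(4), persistent=False)

print("cos" in Demo().state_dict())  # False: non-persistent buffers are not saved or loaded
```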
@ -936,33 +944,12 @@
"model = Llama3Model(LLAMA3_CONFIG_8B)" "model = Llama3Model(LLAMA3_CONFIG_8B)"
] ]
}, },
{
"cell_type": "markdown",
"id": "edea6334-d1fc-427d-9cf2-4af963ff4bfc",
"metadata": {},
"source": [
"- The following is expected to print True to confirm buffers are reused instead of being (wastefully) recreated:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ee9625cc-9afa-4b11-8aab-d536fd170761",
"metadata": {},
"outputs": [],
"source": [
"# Check buffers\n",
"print(model.trf_blocks[0].att.mask is model.trf_blocks[-1].att.mask)\n",
"print(model.trf_blocks[0].att.cos is model.trf_blocks[-1].att.cos)\n",
"print(model.trf_blocks[0].att.sin is model.trf_blocks[-1].att.sin) "
]
},
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "8056a521-91a6-440f-8473-591409c3177b", "id": "8056a521-91a6-440f-8473-591409c3177b",
"metadata": {}, "metadata": {},
"source": [ "source": [
"- Let's now also compute the number of trainable parameters:" "- Let's now compute the number of trainable parameters:"
] ]
}, },
{ {
@ -1017,8 +1004,8 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"float32 (PyTorch default): 68.08 GB\n", "float32 (PyTorch default): 59.84 GB\n",
"bfloat16: 34.04 GB\n" "bfloat16: 29.92 GB\n"
] ]
} }
], ],
@ -1121,43 +1108,47 @@
"\n", "\n",
"\n", "\n",
"class Tokenizer:\n", "class Tokenizer:\n",
" \"\"\"Thin wrapper around tiktoken that keeps track of Llama-3 special IDs.\"\"\"\n",
" def __init__(self, model_path):\n", " def __init__(self, model_path):\n",
" assert os.path.isfile(model_path), f\"Model file {model_path} not found\"\n", " if not os.path.isfile(model_path):\n",
" mergeable_ranks = load_tiktoken_bpe(model_path)\n", " raise FileNotFoundError(model_path)\n",
"\n", "\n",
" self.special_tokens = {\n", " mergeable = load_tiktoken_bpe(model_path)\n",
"\n",
" # hard-coded from Meta's tokenizer.json\n",
" self.special = {\n",
" \"<|begin_of_text|>\": 128000,\n", " \"<|begin_of_text|>\": 128000,\n",
" \"<|end_of_text|>\": 128001,\n", " \"<|end_of_text|>\": 128001,\n",
" \"<|start_header_id|>\": 128006,\n", " \"<|start_header_id|>\": 128006,\n",
" \"<|end_header_id|>\": 128007,\n", " \"<|end_header_id|>\": 128007,\n",
" \"<|eot_id|>\": 128009,\n", " \"<|eot_id|>\": 128009,\n",
" }\n", " }\n",
" self.special_tokens.update({\n", " self.special.update({f\"<|reserved_{i}|>\": 128002 + i\n",
" f\"<|reserved_{i}|>\": 128002 + i for i in range(256) if (128002 + i) not in self.special_tokens.values()\n", " for i in range(256)\n",
" })\n", " if 128002 + i not in self.special.values()})\n",
"\n", "\n",
" self.model = tiktoken.Encoding(\n", " self.model = tiktoken.Encoding(\n",
" name=Path(model_path).name,\n", " name=Path(model_path).name,\n",
" pat_str=r\"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+\",\n", " pat_str=r\"(?i:'s|'t|'re|'ve|'m|'ll|'d)\"\n",
" mergeable_ranks=mergeable_ranks,\n", " r\"|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+\"\n",
" special_tokens=self.special_tokens\n", " r\"|\\p{N}{1,3}\"\n",
" r\"| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*\"\n",
" r\"|\\s*[\\r\\n]+\"\n",
" r\"|\\s+(?!\\S)\"\n",
" r\"|\\s+\",\n",
" mergeable_ranks=mergeable,\n",
" special_tokens=self.special,\n",
" )\n", " )\n",
"\n", "\n",
"\n", " def encode(self, text, bos=False, eos=False):\n",
" def encode(self, text, bos=False, eos=False, allowed_special=set(), disallowed_special=()):\n", " ids = ([self.special[\"<|begin_of_text|>\"]] if bos else []) \\\n",
" if bos:\n", " + self.model.encode(text)\n",
" tokens = [self.special_tokens[\"<|begin_of_text|>\"]]\n",
" else:\n",
" tokens = []\n",
"\n",
" tokens += self.model.encode(text, allowed_special=allowed_special, disallowed_special=disallowed_special)\n",
"\n",
" if eos:\n", " if eos:\n",
" tokens.append(self.special_tokens[\"<|end_of_text|>\"])\n", " ids.append(self.special[\"<|end_of_text|>\"])\n",
" return tokens\n", " return ids\n",
"\n", "\n",
" def decode(self, tokens):\n", " def decode(self, ids):\n",
" return self.model.decode(tokens)" " return self.model.decode(ids)"
] ]
}, },
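A short usage sketch for the rewritten `Tokenizer` (illustrative only; the `tokenizer.model` path below is a placeholder for the downloaded Llama 3 tokenizer file):

```python
tokenizer = Tokenizer("tokenizer.model")  # hypothetical local path

ids = tokenizer.encode("Hello World!", bos=True)
print(ids)                     # begins with 128000, the <|begin_of_text|> ID
print(tokenizer.decode(ids))   # '<|begin_of_text|>Hello World!'
```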
{ {
@ -1202,13 +1193,11 @@
}, },
"outputs": [ "outputs": [
{ {
"name": "stdout", "name": "stderr",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.\n", "/Users/sebastian/Developer/LLMs-from-scratch/.venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
"Token is valid (permission: read).\n", " from .autonotebook import tqdm as notebook_tqdm\n"
"Your token has been saved to /root/.cache/huggingface/token\n",
"Login successful\n"
] ]
} }
], ],
@ -1309,7 +1298,8 @@
"base_uri": "https://localhost:8080/" "base_uri": "https://localhost:8080/"
}, },
"id": "e0a2b5cd-6cba-4d72-b8ff-04d8315d483e", "id": "e0a2b5cd-6cba-4d72-b8ff-04d8315d483e",
"outputId": "990d7b74-cb35-476b-d8bd-d544006e00f4" "outputId": "990d7b74-cb35-476b-d8bd-d544006e00f4",
"scrolled": true
}, },
"outputs": [ "outputs": [
{ {
@ -1318,7 +1308,9 @@
"text": [ "text": [
"Output text:\n", "Output text:\n",
" Every effort_dead aeros Ingredients başında.extensionégor clangmissions güc như submodule.and report官方%.Reader(\",\");\n", " Every effort_dead aeros Ingredients başında.extensionégor clangmissions güc như submodule.and report官方%.Reader(\",\");\n",
"ामल ندار Parliamentary !!! HigginsDynamicZhgmt writeln Globalsletion 사진------\n" "ामल ندار Parliamentary !!! HigginsDynamicZhamincus_beam cyc......\n",
"\n",
" haciendo\n"
] ]
} }
], ],
@ -1437,22 +1429,7 @@
"id": "5fa9c06c-7a53-4b4d-9ce4-acc027322ee4", "id": "5fa9c06c-7a53-4b4d-9ce4-acc027322ee4",
"outputId": "c05118ce-9f81-41c8-a1f2-72caa932ae86" "outputId": "c05118ce-9f81-41c8-a1f2-72caa932ae86"
}, },
"outputs": [ "outputs": [],
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "245443330e4d40c887a5649cc1663e98",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"model-00001-of-00004.safetensors: 0%| | 0.00/4.98G [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [ "source": [
"from safetensors.torch import load_file\n", "from safetensors.torch import load_file\n",
"\n", "\n",
@ -1763,64 +1740,7 @@
"id": "nbvAV7vaz6yc", "id": "nbvAV7vaz6yc",
"outputId": "9e1badc9-a6c4-48b7-9125-e0810655528b" "outputId": "9e1badc9-a6c4-48b7-9125-e0810655528b"
}, },
"outputs": [ "outputs": [],
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "f7df6bbf8e63448c8a6cb5d2f6208403",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"model-00001-of-00004.safetensors: 36%|###6 | 1.81G/4.98G [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "4772f31a1c5b4c168c9aabe7a1d2bacc",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"model-00002-of-00004.safetensors: 0%| | 0.00/5.00G [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "ad49eeb9e1204ea2bd2e371df8ccdea2",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"model-00003-of-00004.safetensors: 0%| | 0.00/4.92G [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "951b9e81613a40a2a503f61e69677f0a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"model-00004-of-00004.safetensors: 0%| | 0.00/1.17G [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [ "source": [
"combined_weights = {}\n", "combined_weights = {}\n",
"\n", "\n",
@ -1861,35 +1781,40 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"class ChatFormat:\n", "class ChatFormat:\n",
" def __init__(self, tokenizer):\n",
" self.tokenizer = tokenizer\n",
"\n", "\n",
" def encode_header(self, message):\n", " def __init__(self, tokenizer: Tokenizer, *,\n",
" tokens = []\n", " default_system=\"You are a helpful assistant.\"):\n",
" tokens.append(self.tokenizer.special_tokens[\"<|start_header_id|>\"])\n", " self.tok = tokenizer\n",
" tokens.extend(self.tokenizer.encode(message[\"role\"], bos=False, eos=False))\n", " self.default_system = default_system\n",
" tokens.append(self.tokenizer.special_tokens[\"<|end_header_id|>\"])\n",
" tokens.extend(self.tokenizer.encode(\"\\n\\n\", bos=False, eos=False))\n",
" return tokens\n",
"\n", "\n",
" def encode(self, text):\n", " def _header(self, role):\n",
" message = {\n", " \"\"\"Encode <|start_header_id|>role<|end_header_id|>\\n\\n\"\"\"\n",
" \"role\": \"user\",\n", " return (\n",
" \"content\": text\n", " [self.tok.special[\"<|start_header_id|>\"]]\n",
" }\n", " + self.tok.encode(role)\n",
"\n", " + [self.tok.special[\"<|end_header_id|>\"]]\n",
" tokens = self.encode_header(message)\n", " + self.tok.encode(\"\\n\\n\")\n",
" tokens.extend(\n",
" self.tokenizer.encode(message[\"content\"].strip(), bos=False, eos=False)\n",
" )\n", " )\n",
" tokens.append(self.tokenizer.special_tokens[\"<|eot_id|>\"])\n",
" return tokens\n",
"\n", "\n",
" def decode(self, token_ids):\n", " def encode(self, user_message, system_message=None):\n",
" return self.tokenizer.decode(token_ids)\n", " sys_msg = system_message if system_message is not None else self.default_system\n",
"\n", "\n",
" ids = [self.tok.special[\"<|begin_of_text|>\"]]\n",
"\n", "\n",
"chat_tokenizer = ChatFormat(tokenizer)" " # system\n",
" ids += self._header(\"system\")\n",
" ids += self.tok.encode(sys_msg)\n",
" ids += [self.tok.special[\"<|eot_id|>\"]]\n",
"\n",
" # user\n",
" ids += self._header(\"user\")\n",
" ids += self.tok.encode(user_message)\n",
" ids += [self.tok.special[\"<|eot_id|>\"]]\n",
"\n",
" # assistant header (no content yet)\n",
" ids += self._header(\"assistant\")\n",
"\n",
" return ids"
] ]
}, },
{ {
@ -1918,11 +1843,14 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"[128006, 882, 128007, 271, 9906, 4435, 0, 128009]\n" "[128000, 128006, 9125, 128007, 271, 2675, 527, 264, 11190, 18328, 13, 128009, 128006, 882, 128007, 271, 9906, 4435, 0, 128009, 128006, 78191, 128007, 271]\n"
] ]
} }
], ],
"source": [ "source": [
"tokenizer = Tokenizer(tokenizer_file_path)\n",
"chat_tokenizer = ChatFormat(tokenizer)\n",
"\n",
"token_ids = chat_tokenizer.encode(\"Hello World!\")\n", "token_ids = chat_tokenizer.encode(\"Hello World!\")\n",
"print(token_ids)" "print(token_ids)"
] ]
@ -1943,7 +1871,7 @@
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"'<|start_header_id|>user<|end_header_id|>\\n\\nHello World!<|eot_id|>'" "'<|begin_of_text|><|start_header_id|>system<|end_header_id|>\\n\\nYou are a helpful assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>\\n\\nHello World!<|eot_id|><|start_header_id|>assistant<|end_header_id|>\\n\\n'"
] ]
}, },
"execution_count": 35, "execution_count": 35,
@ -1982,12 +1910,13 @@
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Output text:\n", "Output text:\n",
" Llamas are herbivores, which means they primarily eat plants and plant-based foods. Here are some of the things llamas like to eat:\n", " Llamas are herbivores, which means they primarily eat plants and plant-based foods. Their diet typically consists of:\n",
"\n", "\n",
"1. Grass: Llamas love to graze on grass, especially in the spring and summer months.\n", "1. Grasses: Llamas love to graze on grasses, including tall grasses, short grasses, and even weeds.\n",
"2. Hay: Hay is a staple in a llama's diet. They like to eat timothy hay, alfalfa hay, and other types of hay.\n", "2. Hay: Hay is a staple in a llama's diet. They enjoy a variety of hays, such as timothy hay, alfalfa hay, and oat hay.\n",
"3. Grains: Llamas may also be fed grains like oats, barley, and corn. However, grains should not make up more than 10-15% of a llama's diet.\n", "3. Grains: Llamas may be fed grains like oats, corn, and barley as a supplement to their diet.\n",
"4. Fruits and vegetables: Llamas may enjoy fruits and vegetables as treats, such as\n" "4. Fruits and vegetables: Llamas enjoy fruits and vegetables like apples, carrots, and sweet potatoes as treats or additions to their diet.\n",
"5. Minerals:\n"
] ]
} }
], ],
@ -2088,49 +2017,6 @@
"}" "}"
] ]
}, },
{
"cell_type": "markdown",
"id": "d81ee464-c112-43b0-9ee8-70df6ac942d0",
"metadata": {},
"source": [
"- Reduce the context length so the model would work fine on a MacBook Air (if you have more RAM, feel free to comment out the lines below):"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "a55a8769-1a03-4265-8fd0-15f1c423da53",
"metadata": {
"id": "a8bc2370-39d2-4bfe-b4c1-6bdd75fe101c"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"New RoPE theta: 31250.0\n"
]
}
],
"source": [
"old_context_length = LLAMA31_CONFIG_8B[\"context_length\"]\n",
"LLAMA31_CONFIG_8B[\"context_length\"] = 8192\n",
"\n",
"\n",
"def rescale_theta(theta_old, context_length_old, context_length_new):\n",
" scaling_factor = context_length_new / context_length_old\n",
" theta_new = theta_old * scaling_factor\n",
" return theta_new\n",
"\n",
"LLAMA31_CONFIG_8B[\"rope_base\"] = rescale_theta(\n",
" LLAMA31_CONFIG_8B[\"rope_base\"],\n",
" old_context_length,\n",
" LLAMA31_CONFIG_8B[\"context_length\"]\n",
")\n",
"\n",
"print(\"New RoPE theta:\", LLAMA31_CONFIG_8B[\"rope_base\"])"
]
},
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "xa3bpMDtTdBs", "id": "xa3bpMDtTdBs",
@ -2277,64 +2163,7 @@
"id": "u4J7IxOvOyPM", "id": "u4J7IxOvOyPM",
"outputId": "925348d7-fc69-4d1b-90f1-7029426bcfcf" "outputId": "925348d7-fc69-4d1b-90f1-7029426bcfcf"
}, },
"outputs": [ "outputs": [],
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "eabfde3ef38b436ea750e6fb50a02b5c",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"model-00001-of-00004.safetensors: 0%| | 0.00/4.98G [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "e117ad45771747ae95c16f9876e6dc19",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"model-00002-of-00004.safetensors: 0%| | 0.00/5.00G [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "170185f2f046437dab57c2ad23163c5c",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"model-00003-of-00004.safetensors: 0%| | 0.00/4.92G [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "6e65f5d6c5af4ab78bc7b3778b98ef86",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"model-00004-of-00004.safetensors: 0%| | 0.00/1.17G [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [ "source": [
"combined_weights = {}\n", "combined_weights = {}\n",
"\n", "\n",
@ -2481,43 +2310,6 @@
"}" "}"
] ]
}, },
{
"cell_type": "markdown",
"id": "b5cd351b-d883-460d-9cdc-47e15ddb884a",
"metadata": {},
"source": [
"- Reduce the context length so the model would work fine on a MacBook Air (if you have more RAM, feel free to comment out the lines below):"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "73f001a6-7ae0-4204-aa83-a27a8878dfd2",
"metadata": {
"id": "a8bc2370-39d2-4bfe-b4c1-6bdd75fe101c"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"New RoPE theta: 31250.0\n"
]
}
],
"source": [
"old_context_length = LLAMA32_CONFIG_1B[\"context_length\"]\n",
"LLAMA32_CONFIG_1B[\"context_length\"] = 8192\n",
"\n",
"LLAMA32_CONFIG_1B[\"rope_base\"] = rescale_theta(\n",
" LLAMA32_CONFIG_1B[\"rope_base\"],\n",
" old_context_length,\n",
" LLAMA32_CONFIG_1B[\"context_length\"]\n",
")\n",
"\n",
"print(\"New RoPE theta:\", LLAMA32_CONFIG_1B[\"rope_base\"])"
]
},
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "Dl4_0EoJKKYv", "id": "Dl4_0EoJKKYv",
@ -2612,20 +2404,6 @@
"outputId": "35588405-e2e1-4871-a1db-1d4bcb852e49" "outputId": "35588405-e2e1-4871-a1db-1d4bcb852e49"
}, },
"outputs": [ "outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c309c56a6cdf426e8ba7967b6a21864e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"model.safetensors: 0%| | 0.00/2.47G [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{ {
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
@ -2688,7 +2466,7 @@
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Output text:\n", "Output text:\n",
" Every effort is made to ensure that the information on this website is accurate. However, we cannot guarantee that the information is accurate, complete\n" " Every effort is made to ensure that the information on this website is accurate and up to date. However, the information is provided without any\n"
] ]
} }
], ],

File diff suppressed because it is too large

View File

@ -56,7 +56,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": 1,
"id": "7c201adb-747e-437b-9a62-442802941e01", "id": "7c201adb-747e-437b-9a62-442802941e01",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -66,7 +66,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": 2,
"id": "dd1b65a8-4301-444a-bd7c-a6f2bd1df9df", "id": "dd1b65a8-4301-444a-bd7c-a6f2bd1df9df",
"metadata": { "metadata": {
"colab": { "colab": {
@ -81,9 +81,9 @@
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"blobfile version: 3.0.0\n", "blobfile version: 3.0.0\n",
"huggingface_hub version: 0.25.2\n", "huggingface_hub version: 0.30.1\n",
"tiktoken version: 0.8.0\n", "tiktoken version: 0.9.0\n",
"torch version: 2.5.0\n" "torch version: 2.6.0\n"
] ]
} }
], ],
@ -113,7 +113,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 4, "execution_count": 3,
"id": "82076c21-9331-4dcd-b017-42b046cf1a60", "id": "82076c21-9331-4dcd-b017-42b046cf1a60",
"metadata": { "metadata": {
"id": "82076c21-9331-4dcd-b017-42b046cf1a60" "id": "82076c21-9331-4dcd-b017-42b046cf1a60"
@ -140,18 +140,18 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 5, "execution_count": 4,
"id": "4b9a346f-5826-4083-9162-abd56afc03f0", "id": "4b9a346f-5826-4083-9162-abd56afc03f0",
"metadata": { "metadata": {
"id": "4b9a346f-5826-4083-9162-abd56afc03f0" "id": "4b9a346f-5826-4083-9162-abd56afc03f0"
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"def precompute_rope_params(head_dim, theta_base=10_000, context_length=4096, freq_config=None):\n", "def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, freq_config=None, dtype=torch.float32):\n",
" assert head_dim % 2 == 0, \"Embedding dimension must be even\"\n", " assert head_dim % 2 == 0, \"Embedding dimension must be even\"\n",
"\n", "\n",
" # Compute the inverse frequencies\n", " # Compute the inverse frequencies\n",
" inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2)[: (head_dim // 2)].float() / head_dim))\n", " inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2, dtype=dtype)[: (head_dim // 2)].float() / head_dim))\n",
"\n", "\n",
" # Frequency adjustments\n", " # Frequency adjustments\n",
" if freq_config is not None:\n", " if freq_config is not None:\n",
@ -177,7 +177,7 @@
" inv_freq = inv_freq_llama\n", " inv_freq = inv_freq_llama\n",
"\n", "\n",
" # Generate position indices\n", " # Generate position indices\n",
" positions = torch.arange(context_length)\n", " positions = torch.arange(context_length, dtype=dtype)\n",
"\n", "\n",
" # Compute the angles\n", " # Compute the angles\n",
" angles = positions[:, None] * inv_freq[None, :] # Shape: (context_length, head_dim // 2)\n", " angles = positions[:, None] * inv_freq[None, :] # Shape: (context_length, head_dim // 2)\n",
@ -192,7 +192,7 @@
" return cos, sin\n", " return cos, sin\n",
"\n", "\n",
"\n", "\n",
"def compute_rope(x, cos, sin):\n", "def apply_rope(x, cos, sin):\n",
" # x: (batch_size, num_heads, seq_len, head_dim)\n", " # x: (batch_size, num_heads, seq_len, head_dim)\n",
" batch_size, num_heads, seq_len, head_dim = x.shape\n", " batch_size, num_heads, seq_len, head_dim = x.shape\n",
" assert head_dim % 2 == 0, \"Head dimension must be even\"\n", " assert head_dim % 2 == 0, \"Head dimension must be even\"\n",
@ -209,43 +209,23 @@
" rotated = torch.cat((-x2, x1), dim=-1)\n", " rotated = torch.cat((-x2, x1), dim=-1)\n",
" x_rotated = (x * cos) + (rotated * sin)\n", " x_rotated = (x * cos) + (rotated * sin)\n",
"\n", "\n",
" # It's ok to use lower-precision after applying cos and sin rotation\n",
" return x_rotated.to(dtype=x.dtype)" " return x_rotated.to(dtype=x.dtype)"
] ]
}, },
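A small shape sanity check for the renamed RoPE helpers (a sketch, not part of the notebook; the dimensions are arbitrary):

```python
cos, sin = compute_rope_params(head_dim=64, theta_base=500_000.0, context_length=8192)
print(cos.shape, sin.shape)           # torch.Size([8192, 64]) each

x = torch.randn(1, 8, 8192, 64)       # (batch, num_heads, seq_len, head_dim)
print(apply_rope(x, cos, sin).shape)  # unchanged: torch.Size([1, 8, 8192, 64])
```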
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 6, "execution_count": 5,
"id": "e8169ab5-f976-4222-a2e1-eb1cabf267cb", "id": "e8169ab5-f976-4222-a2e1-eb1cabf267cb",
"metadata": { "metadata": {
"id": "e8169ab5-f976-4222-a2e1-eb1cabf267cb" "id": "e8169ab5-f976-4222-a2e1-eb1cabf267cb"
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"class SharedBuffers:\n",
" _buffers = {}\n",
"\n",
" @staticmethod\n",
" def get_buffers(context_length, head_dim, rope_base, freq_config, dtype=torch.float32):\n",
" key = (context_length, head_dim, rope_base, tuple(freq_config.values()) if freq_config else freq_config, dtype)\n",
"\n",
" if key not in SharedBuffers._buffers:\n",
" # Create or fetch the buffers\n",
" mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)\n",
" cos, sin = precompute_rope_params(head_dim, rope_base, context_length, freq_config)\n",
" if dtype is not None:\n",
" cos = cos.to(dtype)\n",
" sin = sin.to(dtype)\n",
" SharedBuffers._buffers[key] = (mask, cos, sin)\n",
"\n",
" return SharedBuffers._buffers[key]\n",
"\n",
"\n",
"class GroupedQueryAttention(nn.Module):\n", "class GroupedQueryAttention(nn.Module):\n",
" def __init__(\n", " def __init__(\n",
" self, d_in, d_out, context_length, num_heads,\n", " self, d_in, d_out, num_heads,\n",
" num_kv_groups,\n", " num_kv_groups,\n",
" rope_base=10_000,\n",
" rope_config=None,\n",
" dtype=None\n", " dtype=None\n",
" ):\n", " ):\n",
" super().__init__()\n", " super().__init__()\n",
@ -264,14 +244,7 @@
" self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype)\n", " self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype)\n",
" self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype)\n", " self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype)\n",
"\n", "\n",
" # Fetch buffers using SharedBuffers\n", " def forward(self, x, mask, cos, sin):\n",
" mask, cos, sin = SharedBuffers.get_buffers(context_length, self.head_dim, rope_base, rope_config, dtype)\n",
" self.register_buffer(\"mask\", mask, persistent=False)\n",
"\n",
" self.register_buffer(\"cos\", cos, persistent=False)\n",
" self.register_buffer(\"sin\", sin, persistent=False)\n",
"\n",
" def forward(self, x):\n",
" b, num_tokens, d_in = x.shape\n", " b, num_tokens, d_in = x.shape\n",
"\n", "\n",
" queries = self.W_query(x) # Shape: (b, num_tokens, d_out)\n", " queries = self.W_query(x) # Shape: (b, num_tokens, d_out)\n",
@ -289,8 +262,8 @@
" queries = queries.transpose(1, 2) # Shape: (b, num_query_groups, num_tokens, head_dim)\n", " queries = queries.transpose(1, 2) # Shape: (b, num_query_groups, num_tokens, head_dim)\n",
"\n", "\n",
" # Apply RoPE\n", " # Apply RoPE\n",
" keys = compute_rope(keys, self.cos, self.sin)\n", " keys = apply_rope(keys, cos, sin)\n",
" queries = compute_rope(queries, self.cos, self.sin)\n", " queries = apply_rope(queries, cos, sin)\n",
"\n", "\n",
" # Expand keys and values to match the number of heads\n", " # Expand keys and values to match the number of heads\n",
" # Shape: (b, num_heads, num_tokens, head_dim)\n", " # Shape: (b, num_heads, num_tokens, head_dim)\n",
@ -307,11 +280,8 @@
" # Shape: (b, num_heads, num_tokens, num_tokens)\n", " # Shape: (b, num_heads, num_tokens, num_tokens)\n",
" attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head\n", " attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head\n",
"\n", "\n",
" # Original mask truncated to the number of tokens and converted to boolean\n", " # Compute attention scores\n",
" mask_bool = self.mask.bool()[:num_tokens, :num_tokens]\n", " attn_scores = attn_scores.masked_fill(mask, -torch.inf)\n",
"\n",
" # Use the mask to fill attention scores\n",
" attn_scores.masked_fill_(mask_bool, -torch.inf)\n",
"\n", "\n",
" attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n", " attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n",
" assert keys.shape[-1] == self.head_dim\n", " assert keys.shape[-1] == self.head_dim\n",
@ -328,7 +298,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 7, "execution_count": 6,
"id": "457cb2f8-50c1-4045-8a74-f181bfb5fea9", "id": "457cb2f8-50c1-4045-8a74-f181bfb5fea9",
"metadata": { "metadata": {
"id": "457cb2f8-50c1-4045-8a74-f181bfb5fea9" "id": "457cb2f8-50c1-4045-8a74-f181bfb5fea9"
@ -338,31 +308,28 @@
"class TransformerBlock(nn.Module):\n", "class TransformerBlock(nn.Module):\n",
" def __init__(self, cfg):\n", " def __init__(self, cfg):\n",
" super().__init__()\n", " super().__init__()\n",
" self.att = GroupedQueryAttention(\n", " self.att = GroupedQueryAttention(\n",
" d_in=cfg[\"emb_dim\"],\n", " d_in=cfg[\"emb_dim\"],\n",
" d_out=cfg[\"emb_dim\"],\n", " d_out=cfg[\"emb_dim\"],\n",
" context_length=cfg[\"context_length\"],\n",
" num_heads=cfg[\"n_heads\"],\n", " num_heads=cfg[\"n_heads\"],\n",
" num_kv_groups=cfg[\"n_kv_groups\"],\n", " num_kv_groups=cfg[\"n_kv_groups\"],\n",
" rope_base=cfg[\"rope_base\"],\n",
" rope_config=cfg[\"rope_freq\"],\n",
" dtype=cfg[\"dtype\"]\n", " dtype=cfg[\"dtype\"]\n",
" )\n", " )\n",
" self.ff = FeedForward(cfg)\n", " self.ff = FeedForward(cfg)\n",
" self.norm1 = nn.RMSNorm(cfg[\"emb_dim\"], eps=1e-5)\n", " self.norm1 = nn.RMSNorm(cfg[\"emb_dim\"], eps=1e-5, dtype=cfg[\"dtype\"])\n",
" self.norm2 = nn.RMSNorm(cfg[\"emb_dim\"], eps=1e-5)\n", " self.norm2 = nn.RMSNorm(cfg[\"emb_dim\"], eps=1e-5, dtype=cfg[\"dtype\"])\n",
"\n", "\n",
" def forward(self, x):\n", " def forward(self, x, mask, cos, sin):\n",
" # Shortcut connection for attention block\n", " # Shortcut connection for attention block\n",
" shortcut = x\n", " shortcut = x\n",
" x = self.norm1(x)\n", " x = self.norm1(x)\n",
" x = self.att(x.to(torch.bfloat16)) # Shape [batch_size, num_tokens, emb_size]\n", " x = self.att(x, mask, cos, sin) # Shape [batch_size, num_tokens, emb_size]\n",
" x = x + shortcut # Add the original input back\n", " x = x + shortcut # Add the original input back\n",
"\n", "\n",
" # Shortcut connection for feed-forward block\n", " # Shortcut connection for feed-forward block\n",
" shortcut = x\n", " shortcut = x\n",
" x = self.norm2(x)\n", " x = self.norm2(x)\n",
" x = self.ff(x.to(torch.bfloat16))\n", " x = self.ff(x)\n",
" x = x + shortcut # Add the original input back\n", " x = x + shortcut # Add the original input back\n",
"\n", "\n",
" return x" " return x"
@ -370,7 +337,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 8, "execution_count": 7,
"id": "e88de3e3-9f07-42cc-816b-28dbd46e96c4", "id": "e88de3e3-9f07-42cc-816b-28dbd46e96c4",
"metadata": { "metadata": {
"id": "e88de3e3-9f07-42cc-816b-28dbd46e96c4" "id": "e88de3e3-9f07-42cc-816b-28dbd46e96c4"
@ -380,20 +347,41 @@
"class Llama3Model(nn.Module):\n", "class Llama3Model(nn.Module):\n",
" def __init__(self, cfg):\n", " def __init__(self, cfg):\n",
" super().__init__()\n", " super().__init__()\n",
"\n",
" # Main model parameters\n",
" self.tok_emb = nn.Embedding(cfg[\"vocab_size\"], cfg[\"emb_dim\"], dtype=cfg[\"dtype\"])\n", " self.tok_emb = nn.Embedding(cfg[\"vocab_size\"], cfg[\"emb_dim\"], dtype=cfg[\"dtype\"])\n",
"\n", "\n",
" self.trf_blocks = nn.Sequential(\n", " self.trf_blocks = nn.ModuleList( # ModuleList since Sequential can only accept one input, and we need `x, mask, cos, sin`\n",
" *[TransformerBlock(cfg) for _ in range(cfg[\"n_layers\"])])\n", " [TransformerBlock(cfg) for _ in range(cfg[\"n_layers\"])]\n",
" )\n",
"\n", "\n",
" self.final_norm = nn.RMSNorm(cfg[\"emb_dim\"], eps=1e-5)\n", " self.final_norm = nn.RMSNorm(cfg[\"emb_dim\"], eps=1e-5, dtype=cfg[\"dtype\"])\n",
" self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n", " self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n",
"\n", "\n",
" # Reusuable utilities\n",
" cos, sin = compute_rope_params(\n",
" head_dim=cfg[\"emb_dim\"] // cfg[\"n_heads\"],\n",
" theta_base=cfg[\"rope_base\"],\n",
" context_length=cfg[\"context_length\"],\n",
" freq_config=cfg[\"rope_freq\"]\n",
" )\n",
" self.register_buffer(\"cos\", cos, persistent=False)\n",
" self.register_buffer(\"sin\", sin, persistent=False)\n",
" self.cfg = cfg\n",
"\n",
"\n",
" def forward(self, in_idx):\n", " def forward(self, in_idx):\n",
" # Forward pass\n",
" tok_embeds = self.tok_emb(in_idx)\n", " tok_embeds = self.tok_emb(in_idx)\n",
" x = tok_embeds\n", " x = tok_embeds\n",
" x = self.trf_blocks(x)\n", "\n",
" num_tokens = x.shape[1]\n",
" mask = torch.triu(torch.ones(num_tokens, num_tokens, device=x.device, dtype=torch.bool), diagonal=1)\n",
" \n",
" for block in self.trf_blocks:\n",
" x = block(x, mask, self.cos, self.sin)\n",
" x = self.final_norm(x)\n", " x = self.final_norm(x)\n",
" logits = self.out_head(x.to(torch.bfloat16))\n", " logits = self.out_head(x.to(self.cfg[\"dtype\"]))\n",
" return logits" " return logits"
] ]
}, },
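To confirm that the refactored wiring (explicit loop over `trf_blocks`, shared `cos`/`sin` buffers, on-the-fly mask) still produces the expected output shape, a tiny smoke test can be run; the toy config values below are made up for illustration and assume the `FeedForward` helper defined earlier in the notebook:

```python
toy_cfg = {
    "vocab_size": 100, "context_length": 64, "emb_dim": 32, "n_heads": 4,
    "n_layers": 2, "hidden_dim": 64, "n_kv_groups": 2,
    "rope_base": 500_000.0, "dtype": torch.float32, "rope_freq": None,
}

toy_model = Llama3Model(toy_cfg)
logits = toy_model(torch.randint(0, toy_cfg["vocab_size"], (1, 8)))
print(logits.shape)  # torch.Size([1, 8, 100])
```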
@ -420,7 +408,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 9, "execution_count": 8,
"id": "caa142fa-b375-4e78-b392-2072ced666f3", "id": "caa142fa-b375-4e78-b392-2072ced666f3",
"metadata": { "metadata": {
"id": "caa142fa-b375-4e78-b392-2072ced666f3" "id": "caa142fa-b375-4e78-b392-2072ced666f3"
@ -430,16 +418,16 @@
"# Llama 3.2 1B\n", "# Llama 3.2 1B\n",
"\n", "\n",
"LLAMA32_CONFIG = {\n", "LLAMA32_CONFIG = {\n",
" \"vocab_size\": 128_256, # Vocabulary size\n", " \"vocab_size\": 128_256, # Vocabulary size\n",
" \"context_length\": 131_072, # Context length\n", " \"context_length\": 131_072, # Context length that was used to train the model\n",
" \"emb_dim\": 2048, # Embedding dimension\n", " \"emb_dim\": 2048, # Embedding dimension\n",
" \"n_heads\": 32, # Number of attention heads\n", " \"n_heads\": 32, # Number of attention heads\n",
" \"n_layers\": 16, # Number of layers\n", " \"n_layers\": 16, # Number of layers\n",
" \"hidden_dim\": 8192, # Size of the intermediate dimension in FeedForward\n", " \"hidden_dim\": 8192, # Size of the intermediate dimension in FeedForward\n",
" \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n", " \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n",
" \"rope_base\": 500_000.0, # The base in RoPE's \"theta\"\n", " \"rope_base\": 500_000.0, # The base in RoPE's \"theta\"\n",
" \"dtype\": torch.bfloat16, # Lower-precision dtype to reduce memory usage\n", " \"dtype\": torch.bfloat16, # Lower-precision dtype to reduce memory usage\n",
" \"rope_freq\": { # RoPE frequency scaling\n", " \"rope_freq\": { # RoPE frequency scaling\n",
" \"factor\": 32.0,\n", " \"factor\": 32.0,\n",
" \"low_freq_factor\": 1.0,\n", " \"low_freq_factor\": 1.0,\n",
" \"high_freq_factor\": 4.0,\n", " \"high_freq_factor\": 4.0,\n",
@ -450,16 +438,16 @@
"# Llama 3.2 3B\n", "# Llama 3.2 3B\n",
"\n", "\n",
"# LLAMA32_CONFIG = {\n", "# LLAMA32_CONFIG = {\n",
"# \"vocab_size\": 128_256, # Vocabulary size\n", "# \"vocab_size\": 128_256, # Vocabulary size\n",
"# \"context_length\": 131_072, # Context length\n", "# \"context_length\": 131_072, # Context length that was used to train the model\n",
"# \"emb_dim\": 3072, # Embedding dimension\n", "# \"emb_dim\": 3072, # Embedding dimension\n",
"# \"n_heads\": 24, # Number of attention heads\n", "# \"n_heads\": 24, # Number of attention heads\n",
"# \"n_layers\": 28, # Number of layers\n", "# \"n_layers\": 28, # Number of layers\n",
"# \"hidden_dim\": 8192, # Size of the intermediate dimension in FeedForward\n", "# \"hidden_dim\": 8192, # Size of the intermediate dimension in FeedForward\n",
"# \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n", "# \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n",
"# \"rope_base\": 500_000.0, # The base in RoPE's \"theta\"\n", "# \"rope_base\": 500_000.0, # The base in RoPE's \"theta\"\n",
"# \"dtype\": torch.bfloat16, # Lower-precision dtype to reduce memory usage\n", "# \"dtype\": torch.bfloat16, # Lower-precision dtype to reduce memory usage\n",
"# \"rope_freq\": { # RoPE frequency scaling\n", "# \"rope_freq\": { # RoPE frequency scaling\n",
"# \"factor\": 32.0,\n", "# \"factor\": 32.0,\n",
"# \"low_freq_factor\": 1.0,\n", "# \"low_freq_factor\": 1.0,\n",
"# \"high_freq_factor\": 4.0,\n", "# \"high_freq_factor\": 4.0,\n",
@ -470,54 +458,9 @@
"LLAMA_SIZE_STR = \"1B\" if LLAMA32_CONFIG[\"emb_dim\"] == 2048 else \"3B\"" "LLAMA_SIZE_STR = \"1B\" if LLAMA32_CONFIG[\"emb_dim\"] == 2048 else \"3B\""
] ]
}, },
{
"cell_type": "markdown",
"id": "34535172-797e-4dd0-84fb-65bc75ad5b06",
"metadata": {
"id": "34535172-797e-4dd0-84fb-65bc75ad5b06"
},
"source": [
"- Reduce the context length so the model would work fine on a MacBook Air (if you have more RAM, feel free to comment out the lines below):"
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 10, "execution_count": 9,
"id": "a8bc2370-39d2-4bfe-b4c1-6bdd75fe101c",
"metadata": {
"id": "a8bc2370-39d2-4bfe-b4c1-6bdd75fe101c"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"New RoPE theta: 31250.0\n"
]
}
],
"source": [
"old_context_length = LLAMA32_CONFIG[\"context_length\"]\n",
"LLAMA32_CONFIG[\"context_length\"] = 8192\n",
"\n",
"\n",
"def rescale_theta(theta_old, context_length_old, context_length_new):\n",
" scaling_factor = context_length_new / context_length_old\n",
" theta_new = theta_old * scaling_factor\n",
" return theta_new\n",
"\n",
"LLAMA32_CONFIG[\"rope_base\"] = rescale_theta(\n",
" LLAMA32_CONFIG[\"rope_base\"],\n",
" old_context_length,\n",
" LLAMA32_CONFIG[\"context_length\"]\n",
")\n",
"\n",
"print(\"New RoPE theta:\", LLAMA32_CONFIG[\"rope_base\"])"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "156253fe-aacd-4da2-8f13-705f05c4b11e", "id": "156253fe-aacd-4da2-8f13-705f05c4b11e",
"metadata": { "metadata": {
"id": "156253fe-aacd-4da2-8f13-705f05c4b11e" "id": "156253fe-aacd-4da2-8f13-705f05c4b11e"
@ -539,36 +482,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 12, "execution_count": 10,
"id": "0e95db6d-2712-41a5-a5e0-86c49897f4cf",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "0e95db6d-2712-41a5-a5e0-86c49897f4cf",
"outputId": "8efc4937-e616-40d0-cd59-670d7eb3e841"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"True\n",
"True\n",
"True\n"
]
}
],
"source": [
"# Check buffers\n",
"print(model.trf_blocks[0].att.mask is model.trf_blocks[-1].att.mask)\n",
"print(model.trf_blocks[0].att.cos is model.trf_blocks[-1].att.cos)\n",
"print(model.trf_blocks[0].att.sin is model.trf_blocks[-1].att.sin)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "364e76ca-52f8-4fa5-af37-c4069f9694bc", "id": "364e76ca-52f8-4fa5-af37-c4069f9694bc",
"metadata": { "metadata": {
"colab": { "colab": {
@ -599,7 +513,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 14, "execution_count": 11,
"id": "fd5efb03-5a07-46e8-8607-93ed47549d2b", "id": "fd5efb03-5a07-46e8-8607-93ed47549d2b",
"metadata": { "metadata": {
"colab": { "colab": {
@ -613,8 +527,8 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"float32 (PyTorch default): 11.42 GB\n", "float32 (PyTorch default): 11.23 GB\n",
"bfloat16: 5.71 GB\n" "bfloat16: 5.61 GB\n"
] ]
} }
], ],
@ -649,7 +563,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 15, "execution_count": 12,
"id": "31f12baf-f79b-499f-85c0-51328a6a20f5", "id": "31f12baf-f79b-499f-85c0-51328a6a20f5",
"metadata": { "metadata": {
"id": "31f12baf-f79b-499f-85c0-51328a6a20f5" "id": "31f12baf-f79b-499f-85c0-51328a6a20f5"
@ -679,7 +593,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 16, "execution_count": 13,
"id": "9482b01c-49f9-48e4-ab2c-4a4c75240e77", "id": "9482b01c-49f9-48e4-ab2c-4a4c75240e77",
"metadata": { "metadata": {
"id": "9482b01c-49f9-48e4-ab2c-4a4c75240e77" "id": "9482b01c-49f9-48e4-ab2c-4a4c75240e77"
@ -693,73 +607,86 @@
"from tiktoken.load import load_tiktoken_bpe\n", "from tiktoken.load import load_tiktoken_bpe\n",
"\n", "\n",
"\n", "\n",
"class Tokenizer:\n",
" def __init__(self, model_path):\n",
" assert os.path.isfile(model_path), f\"Model file {model_path} not found\"\n",
" mergeable_ranks = load_tiktoken_bpe(model_path)\n",
"\n", "\n",
" self.special_tokens = {\n", "class Tokenizer:\n",
" \"\"\"Thin wrapper around tiktoken that keeps track of Llama-3 special IDs.\"\"\"\n",
" def __init__(self, model_path):\n",
" if not os.path.isfile(model_path):\n",
" raise FileNotFoundError(model_path)\n",
"\n",
" mergeable = load_tiktoken_bpe(model_path)\n",
"\n",
" # hard-coded from Meta's tokenizer.json\n",
" self.special = {\n",
" \"<|begin_of_text|>\": 128000,\n", " \"<|begin_of_text|>\": 128000,\n",
" \"<|end_of_text|>\": 128001,\n", " \"<|end_of_text|>\": 128001,\n",
" \"<|start_header_id|>\": 128006,\n", " \"<|start_header_id|>\": 128006,\n",
" \"<|end_header_id|>\": 128007,\n", " \"<|end_header_id|>\": 128007,\n",
" \"<|eot_id|>\": 128009,\n", " \"<|eot_id|>\": 128009,\n",
" }\n", " }\n",
" self.special_tokens.update({\n", " self.special.update({f\"<|reserved_{i}|>\": 128002 + i\n",
" f\"<|reserved_{i}|>\": 128002 + i for i in range(256) if (128002 + i) not in self.special_tokens.values()\n", " for i in range(256)\n",
" })\n", " if 128002 + i not in self.special.values()})\n",
"\n", "\n",
" self.model = tiktoken.Encoding(\n", " self.model = tiktoken.Encoding(\n",
" name=Path(model_path).name,\n", " name=Path(model_path).name,\n",
" pat_str=r\"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+\",\n", " pat_str=r\"(?i:'s|'t|'re|'ve|'m|'ll|'d)\"\n",
" mergeable_ranks=mergeable_ranks,\n", " r\"|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+\"\n",
" special_tokens=self.special_tokens\n", " r\"|\\p{N}{1,3}\"\n",
" r\"| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*\"\n",
" r\"|\\s*[\\r\\n]+\"\n",
" r\"|\\s+(?!\\S)\"\n",
" r\"|\\s+\",\n",
" mergeable_ranks=mergeable,\n",
" special_tokens=self.special,\n",
" )\n", " )\n",
"\n", "\n",
"\n", " def encode(self, text, bos=False, eos=False):\n",
" def encode(self, text, bos=False, eos=False, allowed_special=set(), disallowed_special=()):\n", " ids = ([self.special[\"<|begin_of_text|>\"]] if bos else []) \\\n",
" if bos:\n", " + self.model.encode(text)\n",
" tokens = [self.special_tokens[\"<|begin_of_text|>\"]]\n",
" else:\n",
" tokens = []\n",
"\n",
" tokens += self.model.encode(text, allowed_special=allowed_special, disallowed_special=disallowed_special)\n",
"\n",
" if eos:\n", " if eos:\n",
" tokens.append(self.special_tokens[\"<|end_of_text|>\"])\n", " ids.append(self.special[\"<|end_of_text|>\"])\n",
" return tokens\n", " return ids\n",
"\n", "\n",
" def decode(self, tokens):\n", " def decode(self, ids):\n",
" return self.model.decode(tokens)\n", " return self.model.decode(ids)\n",
"\n", "\n",
"\n", "\n",
"class ChatFormat:\n", "class ChatFormat:\n",
" def __init__(self, tokenizer):\n",
" self.tokenizer = tokenizer\n",
"\n", "\n",
" def encode_header(self, message):\n", " def __init__(self, tokenizer: Tokenizer, *,\n",
" tokens = []\n", " default_system=\"You are a helpful assistant.\"):\n",
" tokens.append(self.tokenizer.special_tokens[\"<|start_header_id|>\"])\n", " self.tok = tokenizer\n",
" tokens.extend(self.tokenizer.encode(message[\"role\"], bos=False, eos=False))\n", " self.default_system = default_system\n",
" tokens.append(self.tokenizer.special_tokens[\"<|end_header_id|>\"])\n",
" tokens.extend(self.tokenizer.encode(\"\\n\\n\", bos=False, eos=False))\n",
" return tokens\n",
"\n", "\n",
" def encode(self, text):\n", " def _header(self, role):\n",
" message = {\n", " \"\"\"Encode <|start_header_id|>role<|end_header_id|>\\n\\n\"\"\"\n",
" \"role\": \"user\",\n", " return (\n",
" \"content\": text\n", " [self.tok.special[\"<|start_header_id|>\"]]\n",
" }\n", " + self.tok.encode(role)\n",
"\n", " + [self.tok.special[\"<|end_header_id|>\"]]\n",
" tokens = self.encode_header(message)\n", " + self.tok.encode(\"\\n\\n\")\n",
" tokens.extend(\n",
" self.tokenizer.encode(message[\"content\"].strip(), bos=False, eos=False)\n",
" )\n", " )\n",
" tokens.append(self.tokenizer.special_tokens[\"<|eot_id|>\"])\n",
" return tokens\n",
"\n", "\n",
" def decode(self, token_ids):\n", " def encode(self, user_message, system_message=None):\n",
" return self.tokenizer.decode(token_ids)" " sys_msg = system_message if system_message is not None else self.default_system\n",
"\n",
" ids = [self.tok.special[\"<|begin_of_text|>\"]]\n",
"\n",
" # system\n",
" ids += self._header(\"system\")\n",
" ids += self.tok.encode(sys_msg)\n",
" ids += [self.tok.special[\"<|eot_id|>\"]]\n",
"\n",
" # user\n",
" ids += self._header(\"user\")\n",
" ids += self.tok.encode(user_message)\n",
" ids += [self.tok.special[\"<|eot_id|>\"]]\n",
"\n",
" # assistant header (no content yet)\n",
" ids += self._header(\"assistant\")\n",
"\n",
" return ids"
] ]
}, },
{ {
@ -782,7 +709,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 17, "execution_count": 14,
"id": "e9d96dc8-603a-4cb5-8c3e-4d2ca56862ed", "id": "e9d96dc8-603a-4cb5-8c3e-4d2ca56862ed",
"metadata": { "metadata": {
"colab": { "colab": {
@ -793,25 +720,24 @@
}, },
"outputs": [ "outputs": [
{ {
"name": "stdout", "name": "stderr",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.\n", "/Users/sebastian/Developer/LLMs-from-scratch/.venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
"Token is valid (permission: read).\n", " from .autonotebook import tqdm as notebook_tqdm\n"
"Your token has been saved to /teamspace/studios/this_studio/.cache/huggingface/token\n",
"Login successful\n"
] ]
} }
], ],
"source": [ "source": [
"from huggingface_hub import login\n", "# Uncomment and run the following code if you are executing the notebook for the first time\n",
"\n", "\n",
"login()" "# from huggingface_hub import login\n",
"# login()"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 18, "execution_count": 15,
"id": "986bc1a0-804f-4154-80f8-44cefbee1368", "id": "986bc1a0-804f-4154-80f8-44cefbee1368",
"metadata": { "metadata": {
"colab": { "colab": {
@ -847,7 +773,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 19, "execution_count": 16,
"id": "_gBhxDtU_nxo", "id": "_gBhxDtU_nxo",
"metadata": { "metadata": {
"id": "_gBhxDtU_nxo" "id": "_gBhxDtU_nxo"
@ -871,7 +797,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 20, "execution_count": 17,
"id": "75166128-5899-4995-9b88-9672e135650e", "id": "75166128-5899-4995-9b88-9672e135650e",
"metadata": { "metadata": {
"id": "75166128-5899-4995-9b88-9672e135650e" "id": "75166128-5899-4995-9b88-9672e135650e"
@ -954,7 +880,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 21, "execution_count": 18,
"id": "699cb1b8-a67d-49fb-80a6-0dad9d81f392", "id": "699cb1b8-a67d-49fb-80a6-0dad9d81f392",
"metadata": { "metadata": {
"colab": { "colab": {
@ -1018,7 +944,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 22, "execution_count": 19,
"id": "7f9f7ccc-70cb-41ff-9c25-44336042fc37", "id": "7f9f7ccc-70cb-41ff-9c25-44336042fc37",
"metadata": { "metadata": {
"id": "7f9f7ccc-70cb-41ff-9c25-44336042fc37" "id": "7f9f7ccc-70cb-41ff-9c25-44336042fc37"
@ -1049,7 +975,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 23, "execution_count": 20,
"id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5", "id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5",
"metadata": { "metadata": {
"id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5" "id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5"
@ -1108,7 +1034,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 24, "execution_count": 21,
"id": "1c7a04fa-6aac-416b-8f63-f1e19227633d", "id": "1c7a04fa-6aac-416b-8f63-f1e19227633d",
"metadata": { "metadata": {
"id": "1c7a04fa-6aac-416b-8f63-f1e19227633d" "id": "1c7a04fa-6aac-416b-8f63-f1e19227633d"
@ -1118,23 +1044,31 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Time: 18.20 sec\n",
"\n",
"\n",
"Output text:\n", "Output text:\n",
" Llamas are herbivores, which means they primarily eat plants. Their diet consists mainly of:\n",
"\n", "\n",
"1. Grasses: Llamas love to graze on various types of grasses, including tall grasses and grassy meadows.\n", " Llamas are herbivores, which means they primarily eat plants and plant-based foods. Their diet typically consists of:\n",
"2. Hay: Llamas also eat hay, which is a dry, compressed form of grass or other plants.\n",
"3. Alfalfa: Alfalfa is a legume that is commonly fed to llamas. It is high in protein and fiber.\n",
"4. Other plants: Llamas will also eat other plants, such as wild grasses, shrubs, and trees.\n",
"\n", "\n",
"It's worth noting that the diet of llamas can vary depending on the region, climate,\n" "1. Grasses: Llamas love to graze on various types of grasses, including tall grasses, short grasses, and grassy weeds.\n",
"2. Hay: They also enjoy munching on hay, which is a dry, compressed form of grass or other plant material.\n",
"3. Leaves: Llamas will eat leaves from trees and shrubs, including leaves from plants like clover, alfalfa, and grasses.\n",
"4. Fruits and vegetables: In the wild, llamas will eat fruits and vegetables like berries, apples, and carrots.\n",
"5. Browse: Llamas will also\n"
] ]
} }
], ],
"source": [ "source": [
"import time\n",
"\n",
"\n",
"PROMPT = \"What do llamas eat?\"\n", "PROMPT = \"What do llamas eat?\"\n",
"\n", "\n",
"torch.manual_seed(123)\n", "torch.manual_seed(123)\n",
"\n", "\n",
"start = time.time()\n",
"\n",
"token_ids = generate(\n", "token_ids = generate(\n",
" model=model,\n", " model=model,\n",
" idx=text_to_token_ids(PROMPT, chat_tokenizer).to(device),\n", " idx=text_to_token_ids(PROMPT, chat_tokenizer).to(device),\n",
@ -1144,6 +1078,13 @@
" temperature=0.\n", " temperature=0.\n",
")\n", ")\n",
"\n", "\n",
"print(f\"Time: {time.time() - start:.2f} sec\")\n",
"\n",
"if torch.cuda.is_available():\n",
" max_mem_bytes = torch.cuda.max_memory_allocated()\n",
" max_mem_gb = max_mem_bytes / (1024 ** 3)\n",
" print(f\"Max memory allocated: {max_mem_gb:.2f} GB\")\n",
"\n",
"output_text = token_ids_to_text(token_ids, tokenizer)\n", "output_text = token_ids_to_text(token_ids, tokenizer)\n",
"\n", "\n",
"\n", "\n",
@ -1158,7 +1099,7 @@
" # If the token is not found, return the original text\n", " # If the token is not found, return the original text\n",
" return text\n", " return text\n",
"\n", "\n",
"print(\"Output text:\\n\", clean_text(output_text))" "print(\"\\n\\nOutput text:\\n\\n\", clean_text(output_text))"
] ]
}, },
{ {
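The timing and peak-memory readout added to the generation cell relies on `torch.cuda.max_memory_allocated()`, which reports a high-water mark since process start (or since the last reset). A minimal sketch of the same pattern as a reusable helper; the helper name and the reset call are illustrative additions, not part of the notebook:

```python
import time
import torch

def timed_generate(generate_fn, *args, **kwargs):
    # Reset the peak-memory counter so the readout reflects only this call
    # (assumption: without a reset, the peak is cumulative for the process).
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()
    start = time.time()
    out = generate_fn(*args, **kwargs)
    print(f"Time: {time.time() - start:.2f} sec")
    if torch.cuda.is_available():
        peak_gb = torch.cuda.max_memory_allocated() / (1024 ** 3)
        print(f"Max memory allocated: {peak_gb:.2f} GB")
    return out
```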

View File

@ -110,12 +110,21 @@ from llms_from_scratch.appendix_a import NeuralNetwork, ToyDataset
from llms_from_scratch.appendix_d import find_highest_gradient, train_model from llms_from_scratch.appendix_d import find_highest_gradient, train_model
```
### Llama 3 (Bonus material)
```python
from llms_from_scratch.llama3 import ( from llms_from_scratch.llama3 import (
Llama3Model, Llama3Model,
Llama3ModelFast,
Llama3Tokenizer, Llama3Tokenizer,
ChatFormat, ChatFormat,
clean_text clean_text
) )
``` ```
(For the `llms_from_scratch.llama3` usage information, please see [this bonus section](../../ch05/07_gpt_to_llama/README.md).
For the `llms_from_scratch.llama3` usage information, please see [this bonus section](../../ch05/07_gpt_to_llama/README.md).

View File

@ -15,8 +15,7 @@ from tiktoken.load import load_tiktoken_bpe
LLAMA32_CONFIG_1B = { LLAMA32_CONFIG_1B = {
"vocab_size": 128_256, # Vocabulary size "vocab_size": 128_256, # Vocabulary size
"context_length": 8192, # Maximum context length to use (reduced to save memory) "context_length": 131_072, # Context length that was used to train the model
"orig_context_length": 131_072, # Context length that was used to train the model
"emb_dim": 2048, # Embedding dimension "emb_dim": 2048, # Embedding dimension
"n_heads": 32, # Number of attention heads "n_heads": 32, # Number of attention heads
"n_layers": 16, # Number of layers "n_layers": 16, # Number of layers
@ -34,8 +33,7 @@ LLAMA32_CONFIG_1B = {
LLAMA32_CONFIG_3B = { LLAMA32_CONFIG_3B = {
"vocab_size": 128_256, # Vocabulary size "vocab_size": 128_256, # Vocabulary size
"context_length": 8192, # Maximum context length to use (reduced to save memory) "context_length": 131_072, # Context length that was used to train the model
"orig_context_length": 131_072, # Context length that was used to train the model
"emb_dim": 3072, # Embedding dimension "emb_dim": 3072, # Embedding dimension
"n_heads": 24, # Number of attention heads "n_heads": 24, # Number of attention heads
"n_layers": 28, # Number of layers "n_layers": 28, # Number of layers
@ -67,17 +65,6 @@ class Llama3Model(nn.Module):
self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"]) self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"])
# Reusable utilities # Reusable utilities
self.register_buffer(
"mask", torch.triu(torch.ones(cfg["context_length"], cfg["context_length"]), diagonal=1).bool(),
persistent=False
)
if cfg["orig_context_length"] != cfg["context_length"]:
cfg["rope_base"] = rescale_theta(
cfg["rope_base"],
cfg["orig_context_length"],
cfg["context_length"]
)
cos, sin = compute_rope_params( cos, sin = compute_rope_params(
head_dim=cfg["emb_dim"] // cfg["n_heads"], head_dim=cfg["emb_dim"] // cfg["n_heads"],
theta_base=cfg["rope_base"], theta_base=cfg["rope_base"],
@ -92,8 +79,11 @@ class Llama3Model(nn.Module):
tok_embeds = self.tok_emb(in_idx) tok_embeds = self.tok_emb(in_idx)
x = tok_embeds x = tok_embeds
num_tokens = x.shape[1]
mask = torch.triu(torch.ones(num_tokens, num_tokens, device=x.device, dtype=torch.bool), diagonal=1)
for block in self.trf_blocks: for block in self.trf_blocks:
x = block(x, self.mask, self.cos, self.sin) x = block(x, mask, self.cos, self.sin)
x = self.final_norm(x) x = self.final_norm(x)
logits = self.out_head(x.to(self.cfg["dtype"])) logits = self.out_head(x.to(self.cfg["dtype"]))
return logits return logits
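Building the boolean causal mask per forward call, sized to the actual number of tokens, is what removes the need for a persistent full-context buffer. A back-of-envelope sketch of the difference (the numbers are illustrative, assuming one byte per bool element):

```python
import torch

full_len = 131_072   # trained context length kept in the config
seq_len = 8_192      # a typical generation context

# A registered full-context bool mask would hold full_len**2 one-byte elements:
print(f"{full_len**2 / 1024**3:.0f} GiB")  # 16 GiB

# The on-the-fly mask only covers the current sequence:
mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
print(f"{mask.numel() * mask.element_size() / 1024**2:.0f} MiB")  # 64 MiB
```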
@ -281,88 +271,104 @@ def apply_rope(x, cos, sin):
return x_rotated.to(dtype=x.dtype) return x_rotated.to(dtype=x.dtype)
def rescale_theta(theta_old, context_length_old, context_length_new):
scaling_factor = context_length_new / context_length_old
theta_new = theta_old * scaling_factor
return theta_new
########################################## ##########################################
# Tokenizer # Tokenizer
########################################## ##########################################
class Llama3Tokenizer: class Llama3Tokenizer:
"""Thin wrapper around tiktoken that keeps track of Llama-3 special IDs."""
def __init__(self, model_path): def __init__(self, model_path):
assert os.path.isfile(model_path), f"Model file {model_path} not found" if not os.path.isfile(model_path):
mergeable_ranks = load_tiktoken_bpe(model_path) raise FileNotFoundError(model_path)
self.special_tokens = { mergeable = load_tiktoken_bpe(model_path)
# hard-coded from Meta's tokenizer.json
self.special = {
"<|begin_of_text|>": 128000, "<|begin_of_text|>": 128000,
"<|end_of_text|>": 128001, "<|end_of_text|>": 128001,
"<|start_header_id|>": 128006, "<|start_header_id|>": 128006,
"<|end_header_id|>": 128007, "<|end_header_id|>": 128007,
"<|eot_id|>": 128009, "<|eot_id|>": 128009,
} }
self.special_tokens.update({ self.special.update({f"<|reserved_{i}|>": 128002 + i
f"<|reserved_{i}|>": 128002 + i for i in range(256) if (128002 + i) not in self.special_tokens.values() for i in range(256)
}) if 128002 + i not in self.special.values()})
self.model = tiktoken.Encoding( self.model = tiktoken.Encoding(
name=Path(model_path).name, name=Path(model_path).name,
pat_str=r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+", pat_str=r"(?i:'s|'t|'re|'ve|'m|'ll|'d)"
mergeable_ranks=mergeable_ranks, r"|[^\r\n\p{L}\p{N}]?\p{L}+"
special_tokens=self.special_tokens r"|\p{N}{1,3}"
r"| ?[^\s\p{L}\p{N}]+[\r\n]*"
r"|\s*[\r\n]+"
r"|\s+(?!\S)"
r"|\s+",
mergeable_ranks=mergeable,
special_tokens=self.special,
) )
def encode(self, text, bos=False, eos=False, allowed_special=set(), disallowed_special=()): def encode(self, text, bos=False, eos=False, allowed_special=set()):
ids: list[int] = []
if bos: if bos:
tokens = [self.special_tokens["<|begin_of_text|>"]] ids.append(self.special["<|begin_of_text|>"])
else:
tokens = []
tokens += self.model.encode(text, allowed_special=allowed_special, disallowed_special=disallowed_special)
# delegate to underlying tiktoken.Encoding.encode
ids.extend(
self.model.encode(
text,
allowed_special=allowed_special,
)
)
if eos: if eos:
tokens.append(self.special_tokens["<|end_of_text|>"]) ids.append(self.special["<|end_of_text|>"])
return tokens
def decode(self, tokens): return ids
return self.model.decode(tokens)
def decode(self, ids):
return self.model.decode(ids)
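A short usage sketch of the wrapper above (the `tokenizer.model` path is an assumption; special tokens that appear in raw text still have to be allow-listed, as with any tiktoken encoding):

```python
tok = Llama3Tokenizer("tokenizer.model")  # path to Meta's tiktoken BPE file

ids = tok.encode("Hello, llamas!", bos=True, eos=True)
print(tok.decode(ids))

# Special tokens embedded in raw text must be explicitly allowed:
ids = tok.encode("<|eot_id|>", allowed_special={"<|eot_id|>"})
```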
class ChatFormat: class ChatFormat:
def __init__(self, tokenizer):
self.tokenizer = tokenizer
def encode_header(self, message): def __init__(self, tokenizer: Llama3Tokenizer, *,
tokens = [] default_system="You are a helpful assistant."):
tokens.append(self.tokenizer.special_tokens["<|start_header_id|>"]) self.tok = tokenizer
tokens.extend(self.tokenizer.encode(message["role"], bos=False, eos=False)) self.default_system = default_system
tokens.append(self.tokenizer.special_tokens["<|end_header_id|>"])
tokens.extend(self.tokenizer.encode("\n\n", bos=False, eos=False))
return tokens
def encode(self, text, allowed_special=None): def _header(self, role):
message = { """Encode <|start_header_id|>role<|end_header_id|>\n\n"""
"role": "user", return (
"content": text [self.tok.special["<|start_header_id|>"]]
} + self.tok.encode(role)
+ [self.tok.special["<|end_header_id|>"]]
tokens = self.encode_header(message) + self.tok.encode("\n\n")
tokens.extend(
self.tokenizer.encode(
message["content"].strip(),
bos=False,
eos=False,
allowed_special=allowed_special
)
) )
tokens.append(self.tokenizer.special_tokens["<|eot_id|>"])
return tokens
def decode(self, token_ids): def encode(self, user_message, system_message=None, allowed_special=None):
return self.tokenizer.decode(token_ids) sys_msg = system_message if system_message is not None else self.default_system
ids = [self.tok.special["<|begin_of_text|>"]]
# system
ids += self._header("system")
ids += self.tok.encode(sys_msg, allowed_special=allowed_special)
ids += [self.tok.special["<|eot_id|>"]]
# user
ids += self._header("user")
ids += self.tok.encode(user_message)
ids += [self.tok.special["<|eot_id|>"]]
# assistant header (no content yet)
ids += self._header("assistant")
return ids
def decode(self, ids):
return self.tok.decode(ids)
def clean_text(text, header_end="assistant<|end_header_id|>\n\n"): def clean_text(text, header_end="assistant<|end_header_id|>\n\n"):
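Decoding the IDs produced by `ChatFormat.encode` should show the familiar Llama 3 prompt layout, which is also what `clean_text` later strips the assistant header from. A rough sketch (default system prompt and whitespace as defined above; the `tok` tokenizer object is assumed to exist):

```python
chat = ChatFormat(tok)
ids = chat.encode("What do llamas eat?")
print(tok.decode(ids))
# <|begin_of_text|><|start_header_id|>system<|end_header_id|>
#
# You are a helpful assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>
#
# What do llamas eat?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
#
```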
@ -483,12 +489,6 @@ class Llama3ModelFast(nn.Module):
self.final_norm = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"]) self.final_norm = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"])
self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"]) self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"])
if cfg["orig_context_length"] != cfg["context_length"]:
cfg["rope_base"] = rescale_theta(
cfg["rope_base"],
cfg["orig_context_length"],
cfg["context_length"]
)
cos, sin = compute_rope_params( cos, sin = compute_rope_params(
head_dim=cfg["emb_dim"] // cfg["n_heads"], head_dim=cfg["emb_dim"] // cfg["n_heads"],
theta_base=cfg["rope_base"], theta_base=cfg["rope_base"],

View File

@ -7,7 +7,6 @@ from llms_from_scratch.ch04 import generate_text_simple
from llms_from_scratch.llama3 import ( from llms_from_scratch.llama3 import (
compute_rope_params, compute_rope_params,
apply_rope, apply_rope,
rescale_theta,
LLAMA32_CONFIG_1B, LLAMA32_CONFIG_1B,
GroupedQueryAttention, GroupedQueryAttention,
GroupedQueryAttentionFast, GroupedQueryAttentionFast,
@ -102,23 +101,6 @@ GPT_CONFIG_124M = {
} }
def test_rescale():
new_theta = rescale_theta(
theta_old=500_000.,
context_length_old=131_072,
context_length_new=8192
)
assert new_theta == 31250.
old_theta = rescale_theta(
theta_old=new_theta,
context_length_old=8192,
context_length_new=131_072
)
assert old_theta == 500_000.
def test_grouped_query_attention_equivalence(): def test_grouped_query_attention_equivalence():
torch.manual_seed(42) torch.manual_seed(42)
b, t, d_in, d_out, num_heads, num_kv_groups = 2, 8, 32, 64, 4, 2 b, t, d_in, d_out, num_heads, num_kv_groups = 2, 8, 32, 64, 4, 2
@ -194,6 +176,6 @@ def test_gpt_model_variants(ModelClass, llama3_weights_path):
) )
print("Encoded output text:", out) print("Encoded output text:", out)
expect = torch.tensor([ expect = torch.tensor([
[43, 2543, 292, 4483, 100383, 8113, 21197, 33804, 54419] [43, 2543, 292, 4483, 100383, 8113, 76873, 42175, 72641]
]) ])
assert torch.equal(expect, out) assert torch.equal(expect, out)

View File

@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project] [project]
name = "llms-from-scratch" name = "llms-from-scratch"
version = "1.0.6" version = "1.0.7"
description = "Implement a ChatGPT-like LLM in PyTorch from scratch, step by step" description = "Implement a ChatGPT-like LLM in PyTorch from scratch, step by step"
readme = "README.md" readme = "README.md"
requires-python = ">=3.10" requires-python = ">=3.10"
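Since the version bump is what carries these Llama 3 changes into the published package, a quick way to confirm which release is installed (assuming the distribution keeps the name `llms-from-scratch`):

```python
from importlib.metadata import version

print(version("llms-from-scratch"))  # expected to print "1.0.7" or later
```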