Mirror of https://github.com/rasbt/LLMs-from-scratch.git (synced 2025-12-03 10:30:50 +00:00)

Reduce Llama 3 RoPE memory requirements (#658)

* Llama3 from scratch improvements
* Fix Llama 3 expensive RoPE memory issue
* updates
* update package
* benchmark
* remove unused rescale_theta

This commit is contained in:
parent c278745aff
commit c4cde1c21b

.gitignore (vendored): 3 changed lines
@ -51,6 +51,9 @@ ch05/07_gpt_to_llama/Llama-3.2-3B-Instruct
ch05/10_llm-training-speed/middlemarch.txt
ch05/10_llm-training-speed/loss.pdf
ch05/10_llm-training-speed/model.pth
ch05/07_gpt_to_llama/Untitled.ipynb
ch05/07_gpt_to_llama/llama3.2-1B-instruct.pth
ch05/07_gpt_to_llama/tokenizer.model

ch06/01_main-chapter-code/gpt2
ch06/02_bonus_additional-experiments/gpt2
@ -40,8 +40,6 @@ MODEL_FILE = "llama3.2-1B-instruct.pth"

Basic text generation settings that can be defined by the user. Note that the recommended 8192-token context size requires approximately 3 GB of VRAM for the text generation example.

```python
MODEL_CONTEXT_LENGTH = 8192  # Supports up to 131_072

# Text generation settings
if "instruct" in MODEL_FILE:
    PROMPT = "What do llamas eat?"
@ -82,8 +80,6 @@ elif "3B" in MODEL_FILE:
else:
    raise ValueError("Incorrect model file name")

LLAMA32_CONFIG["context_length"] = MODEL_CONTEXT_LENGTH

model = Llama3Model(LLAMA32_CONFIG)
model.load_state_dict(torch.load(MODEL_FILE, weights_only=True, map_location="cpu"))

@ -125,7 +121,7 @@ Lastly, we can generate text via the following code:

```python
import time

from ch05 import (
from llms_from_scratch.ch05 import (
    generate,
    text_to_token_ids,
    token_ids_to_text
@ -192,8 +188,8 @@ The following table shows a performance comparison on an A100:

| | Tokens/sec | Memory |
| --------------- | ---------- | ------- |
| Llama3Model | 50 | 2.91 GB |
| Llama3ModelFast | 58 | 2.85 GB |
| Llama3Model | 42 | 2.91 GB |
| Llama3ModelFast | 54 | 2.91 GB |


#### Pro tip 2: speed up inference with compilation

@ -218,5 +214,5 @@ The following table shows a performance comparison on an A100 for consequent `generate` calls:

| | Tokens/sec | Memory |
| --------------- | ---------- | ------- |
| Llama3Model | 156 | 3.12 GB |
| Llama3ModelFast | 159 | 2.84 GB |
| Llama3Model | 170 | 3.12 GB |
| Llama3ModelFast | 177 | 3.61 GB |
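The compilation tip above boils down to wrapping the model with `torch.compile` before generating. A minimal sketch (the README's exact recommended snippet may differ; the first `generate` call after compiling is slow because it triggers compilation):

```python
import torch

# Sketch only: assumes `model` and `device` are set up as in the snippets above
model.to(device)
model = torch.compile(model)
# Subsequent generate() calls run at the higher tokens/sec shown in the table above
```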
@ -95,9 +95,9 @@
"output_type": "stream",
"text": [
"blobfile version: 3.0.0\n",
"huggingface_hub version: 0.24.7\n",
"tiktoken version: 0.8.0\n",
"torch version: 2.4.1+cu121\n"
"huggingface_hub version: 0.30.1\n",
"tiktoken version: 0.9.0\n",
"torch version: 2.6.0\n"
]
}
],
@ -435,7 +435,7 @@
"id": "842aa71a-4659-424e-8830-392bd6ae86af",
"metadata": {},
"source": [
"- In addition, we also introduce a `SharedBuffers` class that will allow us to reuse the `mask`, `cos`, and `sin` tensors in the transformer blocks to improve efficiency (this will be crucial when working with models such as Llama 3.1 and 3.2 later, which support up to 131k input tokens)"
"- **We also redesign the attention class a bit so it receives the mask through its forward method instead of storing and accessing it as `self.mask`. This lets us build the mask on the fly to reduce memory usage. To foreshadow why: Llama 3.1 can handle sequences of up to 128k tokens, and precomputing a 128k × 128k causal mask would be extremely memory-intensive, so we avoid it unless absolutely necessary.**"
]
},
{
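To put a rough number on the memory concern mentioned in the cell above, here is a back-of-the-envelope sketch (assuming one byte per mask entry, as for a `torch.bool` tensor; the exact figure depends on the dtype used):

```python
# Why precomputing a full causal mask for the maximum context is avoided
context_length = 131_072  # the 128k-token context supported by Llama 3.1/3.2
mask_bytes = context_length ** 2  # one byte per entry for a boolean mask
print(f"Full causal mask: ~{mask_bytes / 1024**3:.0f} GB")  # ~16 GB
```

Building the mask on the fly for the actual number of input tokens keeps this cost negligible for typical prompt lengths.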
@ -450,27 +450,6 @@
|
||||
"import torch.nn as nn\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"############################# NEW #############################\n",
|
||||
"class SharedBuffers:\n",
|
||||
" _buffers = {}\n",
|
||||
"\n",
|
||||
" @staticmethod\n",
|
||||
" def get_buffers(context_length, head_dim, rope_base, freq_config, dtype=torch.float32):\n",
|
||||
" key = (context_length, head_dim, rope_base, tuple(freq_config.values()) if freq_config else freq_config, dtype)\n",
|
||||
"\n",
|
||||
" if key not in SharedBuffers._buffers:\n",
|
||||
" # Create or fetch the buffers\n",
|
||||
" mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)\n",
|
||||
" cos, sin = precompute_rope_params(head_dim, rope_base, context_length, freq_config)\n",
|
||||
" if dtype is not None:\n",
|
||||
" cos = cos.to(dtype)\n",
|
||||
" sin = sin.to(dtype)\n",
|
||||
" SharedBuffers._buffers[key] = (mask, cos, sin)\n",
|
||||
"\n",
|
||||
" return SharedBuffers._buffers[key]\n",
|
||||
"############################# NEW #############################\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class GroupedQueryAttention(nn.Module):\n",
|
||||
" def __init__(\n",
|
||||
" self, d_in, d_out, context_length, num_heads,\n",
|
||||
@ -499,16 +478,12 @@
|
||||
" self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype)\n",
|
||||
" self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype)\n",
|
||||
"\n",
|
||||
" ############################# NEW #############################\n",
|
||||
" # Fetch buffers using SharedBuffers\n",
|
||||
" mask, cos, sin = SharedBuffers.get_buffers(context_length, self.head_dim, rope_base, rope_config, dtype)\n",
|
||||
" ############################# NEW #############################\n",
|
||||
" \n",
|
||||
" self.register_buffer(\"mask\", mask)\n",
|
||||
" self.register_buffer(\"cos\", cos)\n",
|
||||
" self.register_buffer(\"sin\", sin)\n",
|
||||
"\n",
|
||||
" def forward(self, x):\n",
|
||||
" def forward(self, x, mask=None, cos=None, sin=None):\n",
|
||||
" ##################### NEW #####################\n",
|
||||
" # The forward method now accepts `mask` instead of accessing it via self.mask.\n",
|
||||
" # Also, we now have cos and sin as input for RoPE\n",
|
||||
" ################################################ \n",
|
||||
" b, num_tokens, d_in = x.shape\n",
|
||||
"\n",
|
||||
" queries = self.W_query(x) # Shape: (b, num_tokens, d_out)\n",
|
||||
@ -530,9 +505,12 @@
|
||||
" values = values.transpose(1, 2) # Shape: (b, num_heads, num_tokens, head_dim)\n",
|
||||
" queries = queries.transpose(1, 2) # Shape: (b, num_query_groups, num_tokens, head_dim)\n",
|
||||
"\n",
|
||||
" ##################### NEW #####################\n",
|
||||
" # Apply RoPE\n",
|
||||
" keys = compute_rope(keys, self.cos, self.sin)\n",
|
||||
" queries = compute_rope(queries, self.cos, self.sin)\n",
|
||||
" if cos is not None:\n",
|
||||
" keys = compute_rope(keys, cos, sin)\n",
|
||||
" queries = compute_rope(queries, cos, sin)\n",
|
||||
" ################################################\n",
|
||||
"\n",
|
||||
" ##################### NEW #####################\n",
|
||||
" # Expand keys and values to match the number of heads\n",
|
||||
@ -552,11 +530,14 @@
|
||||
" # Shape: (b, num_heads, num_tokens, num_tokens)\n",
|
||||
" attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head\n",
|
||||
"\n",
|
||||
" # Original mask truncated to the number of tokens and converted to boolean\n",
|
||||
" mask_bool = self.mask.bool()[:num_tokens, :num_tokens]\n",
|
||||
"\n",
|
||||
" ##################### NEW #####################\n",
|
||||
" # Create mask on the fly\n",
|
||||
" if mask is None:\n",
|
||||
" mask = torch.triu(torch.ones(num_tokens, num_tokens, device=x.device, dtype=torch.bool), diagonal=1)\n",
|
||||
" ################################################\n",
|
||||
" \n",
|
||||
" # Use the mask to fill attention scores\n",
|
||||
" attn_scores.masked_fill_(mask_bool, -torch.inf)\n",
|
||||
" attn_scores.masked_fill_(mask, -torch.inf)\n",
|
||||
"\n",
|
||||
" attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n",
|
||||
" assert keys.shape[-1] == self.head_dim\n",
|
||||
@ -578,7 +559,7 @@
|
||||
"id": "roAXSwJs9hR8"
|
||||
},
|
||||
"source": [
|
||||
"- To illustrate the parameter savings, consider the following multi-head attention example from the GPT and Llama 2 code:"
|
||||
"- To illustrate the parameter savings in GQA over MHA, consider the following multi-head attention example from the GPT and Llama 2 code:"
|
||||
]
|
||||
},
|
||||
{
|
||||
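To get a rough sense of the parameter savings that the markdown cell above refers to, here is a small arithmetic sketch. The numbers below are illustrative only (hypothetical settings, not a specific config from this notebook): `emb_dim=4096`, `n_heads=32`, `n_kv_groups=8`.

```python
# Per-layer parameter count of the key/value projections, MHA vs. GQA (illustrative values)
emb_dim, n_heads, n_kv_groups = 4096, 32, 8
head_dim = emb_dim // n_heads

mha_kv = 2 * emb_dim * emb_dim                   # W_key + W_value in multi-head attention
gqa_kv = 2 * emb_dim * (n_kv_groups * head_dim)  # W_key + W_value in grouped-query attention
print(f"MHA K/V params per layer: {mha_kv:,}")   # 33,554,432
print(f"GQA K/V params per layer: {gqa_kv:,}")   # 8,388,608
```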
@ -753,7 +734,8 @@
|
||||
},
|
||||
"source": [
|
||||
"- Next, we update the `TransformerBlock`\n",
|
||||
"- Here, we simply swap `MultiHeadAttention` with `GroupedQueryAttention` and add the new RoPE settings"
|
||||
"- Here, we simply swap `MultiHeadAttention` with `GroupedQueryAttention` and add the new RoPE settings\n",
|
||||
"- In addition, we also modify the `forward` method so that it receives `mask`, `cos`, and `sin`; since the values for those are the same for each transformer block, we only have to compute them once and then can reuse them"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -782,11 +764,15 @@
|
||||
" self.norm1 = RMSNorm(cfg[\"emb_dim\"], eps=1e-5)\n",
|
||||
" self.norm2 = RMSNorm(cfg[\"emb_dim\"], eps=1e-5)\n",
|
||||
"\n",
|
||||
" def forward(self, x):\n",
|
||||
" def forward(self, x, mask=None, cos=None, sin=None):\n",
|
||||
" ##################### NEW #####################\n",
|
||||
" # The forward method now accepts `mask` instead of accessing it via self.mask.\n",
|
||||
" # Also, we now have cos and sin as input for RoPE\n",
|
||||
" ################################################\n",
|
||||
" # Shortcut connection for attention block\n",
|
||||
" shortcut = x\n",
|
||||
" x = self.norm1(x)\n",
|
||||
" x = self.att(x.to(torch.bfloat16)) # Shape [batch_size, num_tokens, emb_size]\n",
|
||||
" x = self.att(x.to(torch.bfloat16), mask, cos, sin) # Shape [batch_size, num_tokens, emb_size]\n",
|
||||
" x = x + shortcut # Add the original input back\n",
|
||||
"\n",
|
||||
" # Shortcut connection for feed-forward block\n",
|
||||
@ -816,7 +802,8 @@
|
||||
"id": "M_tLAq_r_llN"
|
||||
},
|
||||
"source": [
|
||||
"- When setting up the model class, we fortunately don't have to do much; we just update the name to `Llama3Model`"
|
||||
"- When setting up the model class, we technically don't have to do much; we just update the name to `Llama3Model`\n",
|
||||
"- However, since we now pass the `mask`, `cos`, and `sin` to the transformer blocks, we also have to add them here"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -840,12 +827,33 @@
|
||||
" self.final_norm = RMSNorm(cfg[\"emb_dim\"], eps=1e-5)\n",
|
||||
" self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n",
|
||||
"\n",
|
||||
" #################### NEW #####################\n",
|
||||
" cos, sin = precompute_rope_params(\n",
|
||||
" head_dim=cfg[\"emb_dim\"] // cfg[\"n_heads\"],\n",
|
||||
" theta_base=cfg[\"rope_base\"],\n",
|
||||
" context_length=cfg[\"context_length\"],\n",
|
||||
" freq_config=cfg[\"rope_freq\"]\n",
|
||||
" )\n",
|
||||
" \n",
|
||||
" self.register_buffer(\"cos\", cos, persistent=False)\n",
|
||||
" self.register_buffer(\"sin\", sin, persistent=False)\n",
|
||||
" ##############################################\n",
|
||||
"\n",
|
||||
" self.cfg = cfg\n",
|
||||
"\n",
|
||||
" def forward(self, in_idx):\n",
|
||||
" tok_embeds = self.tok_emb(in_idx)\n",
|
||||
" x = tok_embeds\n",
|
||||
" x = self.trf_blocks(x)\n",
|
||||
"\n",
|
||||
" #################### NEW #####################\n",
|
||||
" num_tokens = x.shape[1]\n",
|
||||
" mask = torch.triu(torch.ones(num_tokens, num_tokens, device=x.device, dtype=torch.bool), diagonal=1)\n",
|
||||
" ##############################################\n",
|
||||
" \n",
|
||||
" for block in self.trf_blocks:\n",
|
||||
" x = block(x, mask, self.cos, self.sin)\n",
|
||||
" x = self.final_norm(x)\n",
|
||||
" logits = self.out_head(x.to(torch.bfloat16))\n",
|
||||
" logits = self.out_head(x.to(self.cfg[\"dtype\"]))\n",
|
||||
" return logits"
|
||||
]
|
||||
},
|
||||
@ -936,33 +944,12 @@
|
||||
"model = Llama3Model(LLAMA3_CONFIG_8B)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "edea6334-d1fc-427d-9cf2-4af963ff4bfc",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"- The following is expected to print True to confirm buffers are reused instead of being (wastefully) recreated:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "ee9625cc-9afa-4b11-8aab-d536fd170761",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Check buffers\n",
|
||||
"print(model.trf_blocks[0].att.mask is model.trf_blocks[-1].att.mask)\n",
|
||||
"print(model.trf_blocks[0].att.cos is model.trf_blocks[-1].att.cos)\n",
|
||||
"print(model.trf_blocks[0].att.sin is model.trf_blocks[-1].att.sin) "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "8056a521-91a6-440f-8473-591409c3177b",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"- Let's now also compute the number of trainable parameters:"
|
||||
"- Let's now compute the number of trainable parameters:"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -1017,8 +1004,8 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"float32 (PyTorch default): 68.08 GB\n",
|
||||
"bfloat16: 34.04 GB\n"
|
||||
"float32 (PyTorch default): 59.84 GB\n",
|
||||
"bfloat16: 29.92 GB\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -1121,43 +1108,47 @@
|
||||
"\n",
|
||||
"\n",
|
||||
"class Tokenizer:\n",
|
||||
" \"\"\"Thin wrapper around tiktoken that keeps track of Llama-3 special IDs.\"\"\"\n",
|
||||
" def __init__(self, model_path):\n",
|
||||
" assert os.path.isfile(model_path), f\"Model file {model_path} not found\"\n",
|
||||
" mergeable_ranks = load_tiktoken_bpe(model_path)\n",
|
||||
" if not os.path.isfile(model_path):\n",
|
||||
" raise FileNotFoundError(model_path)\n",
|
||||
"\n",
|
||||
" self.special_tokens = {\n",
|
||||
" mergeable = load_tiktoken_bpe(model_path)\n",
|
||||
"\n",
|
||||
" # hard-coded from Meta's tokenizer.json\n",
|
||||
" self.special = {\n",
|
||||
" \"<|begin_of_text|>\": 128000,\n",
|
||||
" \"<|end_of_text|>\": 128001,\n",
|
||||
" \"<|start_header_id|>\": 128006,\n",
|
||||
" \"<|end_header_id|>\": 128007,\n",
|
||||
" \"<|eot_id|>\": 128009,\n",
|
||||
" }\n",
|
||||
" self.special_tokens.update({\n",
|
||||
" f\"<|reserved_{i}|>\": 128002 + i for i in range(256) if (128002 + i) not in self.special_tokens.values()\n",
|
||||
" })\n",
|
||||
" self.special.update({f\"<|reserved_{i}|>\": 128002 + i\n",
|
||||
" for i in range(256)\n",
|
||||
" if 128002 + i not in self.special.values()})\n",
|
||||
"\n",
|
||||
" self.model = tiktoken.Encoding(\n",
|
||||
" name=Path(model_path).name,\n",
|
||||
" pat_str=r\"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+\",\n",
|
||||
" mergeable_ranks=mergeable_ranks,\n",
|
||||
" special_tokens=self.special_tokens\n",
|
||||
" pat_str=r\"(?i:'s|'t|'re|'ve|'m|'ll|'d)\"\n",
|
||||
" r\"|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+\"\n",
|
||||
" r\"|\\p{N}{1,3}\"\n",
|
||||
" r\"| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*\"\n",
|
||||
" r\"|\\s*[\\r\\n]+\"\n",
|
||||
" r\"|\\s+(?!\\S)\"\n",
|
||||
" r\"|\\s+\",\n",
|
||||
" mergeable_ranks=mergeable,\n",
|
||||
" special_tokens=self.special,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" def encode(self, text, bos=False, eos=False, allowed_special=set(), disallowed_special=()):\n",
|
||||
" if bos:\n",
|
||||
" tokens = [self.special_tokens[\"<|begin_of_text|>\"]]\n",
|
||||
" else:\n",
|
||||
" tokens = []\n",
|
||||
"\n",
|
||||
" tokens += self.model.encode(text, allowed_special=allowed_special, disallowed_special=disallowed_special)\n",
|
||||
"\n",
|
||||
" def encode(self, text, bos=False, eos=False):\n",
|
||||
" ids = ([self.special[\"<|begin_of_text|>\"]] if bos else []) \\\n",
|
||||
" + self.model.encode(text)\n",
|
||||
" if eos:\n",
|
||||
" tokens.append(self.special_tokens[\"<|end_of_text|>\"])\n",
|
||||
" return tokens\n",
|
||||
" ids.append(self.special[\"<|end_of_text|>\"])\n",
|
||||
" return ids\n",
|
||||
"\n",
|
||||
" def decode(self, tokens):\n",
|
||||
" return self.model.decode(tokens)"
|
||||
" def decode(self, ids):\n",
|
||||
" return self.model.decode(ids)"
|
||||
]
|
||||
},
|
||||
{
|
||||
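For orientation, here is a small usage sketch of the rewritten `Tokenizer` wrapper from the hunk above (assuming `tokenizer_file_path` points to the downloaded `tokenizer.model` file, as in the later cells of this notebook):

```python
# Sketch: exercise the new encode/decode API of the Tokenizer wrapper
tokenizer = Tokenizer(tokenizer_file_path)

ids = tokenizer.encode("Hello World!", bos=True)
print(ids)                    # starts with 128000, the <|begin_of_text|> ID
print(tokenizer.decode(ids))  # '<|begin_of_text|>Hello World!'
```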
@ -1202,13 +1193,11 @@
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.\n",
|
||||
"Token is valid (permission: read).\n",
|
||||
"Your token has been saved to /root/.cache/huggingface/token\n",
|
||||
"Login successful\n"
|
||||
"/Users/sebastian/Developer/LLMs-from-scratch/.venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
||||
" from .autonotebook import tqdm as notebook_tqdm\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -1309,7 +1298,8 @@
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "e0a2b5cd-6cba-4d72-b8ff-04d8315d483e",
|
||||
"outputId": "990d7b74-cb35-476b-d8bd-d544006e00f4"
|
||||
"outputId": "990d7b74-cb35-476b-d8bd-d544006e00f4",
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
@ -1318,7 +1308,9 @@
|
||||
"text": [
|
||||
"Output text:\n",
|
||||
" Every effort_dead aeros Ingredients başında.extensionégor clangmissions güc như submodule.and report官方%,.Reader(\",\");\n",
|
||||
"ामल ندار Parliamentary !!! HigginsDynamicZhgmt writeln Globalsletion 사진------\n"
|
||||
"ामल ندار Parliamentary !!! HigginsDynamicZhamincus_beam cyc......\n",
|
||||
"\n",
|
||||
" haciendo\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -1437,22 +1429,7 @@
|
||||
"id": "5fa9c06c-7a53-4b4d-9ce4-acc027322ee4",
|
||||
"outputId": "c05118ce-9f81-41c8-a1f2-72caa932ae86"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "245443330e4d40c887a5649cc1663e98",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"model-00001-of-00004.safetensors: 0%| | 0.00/4.98G [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from safetensors.torch import load_file\n",
|
||||
"\n",
|
||||
@ -1763,64 +1740,7 @@
|
||||
"id": "nbvAV7vaz6yc",
|
||||
"outputId": "9e1badc9-a6c4-48b7-9125-e0810655528b"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "f7df6bbf8e63448c8a6cb5d2f6208403",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"model-00001-of-00004.safetensors: 36%|###6 | 1.81G/4.98G [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "4772f31a1c5b4c168c9aabe7a1d2bacc",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"model-00002-of-00004.safetensors: 0%| | 0.00/5.00G [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "ad49eeb9e1204ea2bd2e371df8ccdea2",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"model-00003-of-00004.safetensors: 0%| | 0.00/4.92G [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "951b9e81613a40a2a503f61e69677f0a",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"model-00004-of-00004.safetensors: 0%| | 0.00/1.17G [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"combined_weights = {}\n",
|
||||
"\n",
|
||||
@ -1861,35 +1781,40 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class ChatFormat:\n",
|
||||
" def __init__(self, tokenizer):\n",
|
||||
" self.tokenizer = tokenizer\n",
|
||||
"\n",
|
||||
" def encode_header(self, message):\n",
|
||||
" tokens = []\n",
|
||||
" tokens.append(self.tokenizer.special_tokens[\"<|start_header_id|>\"])\n",
|
||||
" tokens.extend(self.tokenizer.encode(message[\"role\"], bos=False, eos=False))\n",
|
||||
" tokens.append(self.tokenizer.special_tokens[\"<|end_header_id|>\"])\n",
|
||||
" tokens.extend(self.tokenizer.encode(\"\\n\\n\", bos=False, eos=False))\n",
|
||||
" return tokens\n",
|
||||
" def __init__(self, tokenizer: Tokenizer, *,\n",
|
||||
" default_system=\"You are a helpful assistant.\"):\n",
|
||||
" self.tok = tokenizer\n",
|
||||
" self.default_system = default_system\n",
|
||||
"\n",
|
||||
" def encode(self, text):\n",
|
||||
" message = {\n",
|
||||
" \"role\": \"user\",\n",
|
||||
" \"content\": text\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" tokens = self.encode_header(message)\n",
|
||||
" tokens.extend(\n",
|
||||
" self.tokenizer.encode(message[\"content\"].strip(), bos=False, eos=False)\n",
|
||||
" def _header(self, role):\n",
|
||||
" \"\"\"Encode <|start_header_id|>role<|end_header_id|>\\n\\n\"\"\"\n",
|
||||
" return (\n",
|
||||
" [self.tok.special[\"<|start_header_id|>\"]]\n",
|
||||
" + self.tok.encode(role)\n",
|
||||
" + [self.tok.special[\"<|end_header_id|>\"]]\n",
|
||||
" + self.tok.encode(\"\\n\\n\")\n",
|
||||
" )\n",
|
||||
" tokens.append(self.tokenizer.special_tokens[\"<|eot_id|>\"])\n",
|
||||
" return tokens\n",
|
||||
"\n",
|
||||
" def decode(self, token_ids):\n",
|
||||
" return self.tokenizer.decode(token_ids)\n",
|
||||
" def encode(self, user_message, system_message=None):\n",
|
||||
" sys_msg = system_message if system_message is not None else self.default_system\n",
|
||||
"\n",
|
||||
" ids = [self.tok.special[\"<|begin_of_text|>\"]]\n",
|
||||
"\n",
|
||||
"chat_tokenizer = ChatFormat(tokenizer)"
|
||||
" # system\n",
|
||||
" ids += self._header(\"system\")\n",
|
||||
" ids += self.tok.encode(sys_msg)\n",
|
||||
" ids += [self.tok.special[\"<|eot_id|>\"]]\n",
|
||||
"\n",
|
||||
" # user\n",
|
||||
" ids += self._header(\"user\")\n",
|
||||
" ids += self.tok.encode(user_message)\n",
|
||||
" ids += [self.tok.special[\"<|eot_id|>\"]]\n",
|
||||
"\n",
|
||||
" # assistant header (no content yet)\n",
|
||||
" ids += self._header(\"assistant\")\n",
|
||||
"\n",
|
||||
" return ids"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -1918,11 +1843,14 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"[128006, 882, 128007, 271, 9906, 4435, 0, 128009]\n"
|
||||
"[128000, 128006, 9125, 128007, 271, 2675, 527, 264, 11190, 18328, 13, 128009, 128006, 882, 128007, 271, 9906, 4435, 0, 128009, 128006, 78191, 128007, 271]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"tokenizer = Tokenizer(tokenizer_file_path)\n",
|
||||
"chat_tokenizer = ChatFormat(tokenizer)\n",
|
||||
"\n",
|
||||
"token_ids = chat_tokenizer.encode(\"Hello World!\")\n",
|
||||
"print(token_ids)"
|
||||
]
|
||||
@ -1943,7 +1871,7 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"'<|start_header_id|>user<|end_header_id|>\\n\\nHello World!<|eot_id|>'"
|
||||
"'<|begin_of_text|><|start_header_id|>system<|end_header_id|>\\n\\nYou are a helpful assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>\\n\\nHello World!<|eot_id|><|start_header_id|>assistant<|end_header_id|>\\n\\n'"
|
||||
]
|
||||
},
|
||||
"execution_count": 35,
|
||||
@ -1982,12 +1910,13 @@
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Output text:\n",
|
||||
" Llamas are herbivores, which means they primarily eat plants and plant-based foods. Here are some of the things llamas like to eat:\n",
|
||||
" Llamas are herbivores, which means they primarily eat plants and plant-based foods. Their diet typically consists of:\n",
|
||||
"\n",
|
||||
"1. Grass: Llamas love to graze on grass, especially in the spring and summer months.\n",
|
||||
"2. Hay: Hay is a staple in a llama's diet. They like to eat timothy hay, alfalfa hay, and other types of hay.\n",
|
||||
"3. Grains: Llamas may also be fed grains like oats, barley, and corn. However, grains should not make up more than 10-15% of a llama's diet.\n",
|
||||
"4. Fruits and vegetables: Llamas may enjoy fruits and vegetables as treats, such as\n"
|
||||
"1. Grasses: Llamas love to graze on grasses, including tall grasses, short grasses, and even weeds.\n",
|
||||
"2. Hay: Hay is a staple in a llama's diet. They enjoy a variety of hays, such as timothy hay, alfalfa hay, and oat hay.\n",
|
||||
"3. Grains: Llamas may be fed grains like oats, corn, and barley as a supplement to their diet.\n",
|
||||
"4. Fruits and vegetables: Llamas enjoy fruits and vegetables like apples, carrots, and sweet potatoes as treats or additions to their diet.\n",
|
||||
"5. Minerals:\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -2088,49 +2017,6 @@
|
||||
"}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "d81ee464-c112-43b0-9ee8-70df6ac942d0",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"- Reduce the context length so the model would work fine on a MacBook Air (if you have more RAM, feel free to comment out the lines below):"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"id": "a55a8769-1a03-4265-8fd0-15f1c423da53",
|
||||
"metadata": {
|
||||
"id": "a8bc2370-39d2-4bfe-b4c1-6bdd75fe101c"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"New RoPE theta: 31250.0\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"old_context_length = LLAMA31_CONFIG_8B[\"context_length\"]\n",
|
||||
"LLAMA31_CONFIG_8B[\"context_length\"] = 8192\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def rescale_theta(theta_old, context_length_old, context_length_new):\n",
|
||||
" scaling_factor = context_length_new / context_length_old\n",
|
||||
" theta_new = theta_old * scaling_factor\n",
|
||||
" return theta_new\n",
|
||||
"\n",
|
||||
"LLAMA31_CONFIG_8B[\"rope_base\"] = rescale_theta(\n",
|
||||
" LLAMA31_CONFIG_8B[\"rope_base\"],\n",
|
||||
" old_context_length,\n",
|
||||
" LLAMA31_CONFIG_8B[\"context_length\"]\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"print(\"New RoPE theta:\", LLAMA31_CONFIG_8B[\"rope_base\"])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "xa3bpMDtTdBs",
|
||||
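The cell removed in the hunk above used to rescale RoPE's theta linearly with the shortened context window. For reference, this is the arithmetic behind the printed `New RoPE theta: 31250.0` (a sketch; 500,000 is the Llama 3.1 base theta):

```python
# Linear theta rescaling, as performed by the removed rescale_theta helper
theta_old = 500_000.0          # original rope_base
old_context_length = 131_072   # context length the model was trained with
new_context_length = 8_192     # reduced context length

theta_new = theta_old * new_context_length / old_context_length
print(theta_new)  # 31250.0
```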
@ -2277,64 +2163,7 @@
|
||||
"id": "u4J7IxOvOyPM",
|
||||
"outputId": "925348d7-fc69-4d1b-90f1-7029426bcfcf"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "eabfde3ef38b436ea750e6fb50a02b5c",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"model-00001-of-00004.safetensors: 0%| | 0.00/4.98G [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "e117ad45771747ae95c16f9876e6dc19",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"model-00002-of-00004.safetensors: 0%| | 0.00/5.00G [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "170185f2f046437dab57c2ad23163c5c",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"model-00003-of-00004.safetensors: 0%| | 0.00/4.92G [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "6e65f5d6c5af4ab78bc7b3778b98ef86",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"model-00004-of-00004.safetensors: 0%| | 0.00/1.17G [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
}
|
||||
],
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"combined_weights = {}\n",
|
||||
"\n",
|
||||
@ -2481,43 +2310,6 @@
|
||||
"}"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "b5cd351b-d883-460d-9cdc-47e15ddb884a",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"- Reduce the context length so the model would work fine on a MacBook Air (if you have more RAM, feel free to comment out the lines below):"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"id": "73f001a6-7ae0-4204-aa83-a27a8878dfd2",
|
||||
"metadata": {
|
||||
"id": "a8bc2370-39d2-4bfe-b4c1-6bdd75fe101c"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"New RoPE theta: 31250.0\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"old_context_length = LLAMA32_CONFIG_1B[\"context_length\"]\n",
|
||||
"LLAMA32_CONFIG_1B[\"context_length\"] = 8192\n",
|
||||
"\n",
|
||||
"LLAMA32_CONFIG_1B[\"rope_base\"] = rescale_theta(\n",
|
||||
" LLAMA32_CONFIG_1B[\"rope_base\"],\n",
|
||||
" old_context_length,\n",
|
||||
" LLAMA32_CONFIG_1B[\"context_length\"]\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"print(\"New RoPE theta:\", LLAMA32_CONFIG_1B[\"rope_base\"])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "Dl4_0EoJKKYv",
|
||||
@ -2612,20 +2404,6 @@
|
||||
"outputId": "35588405-e2e1-4871-a1db-1d4bcb852e49"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"application/vnd.jupyter.widget-view+json": {
|
||||
"model_id": "c309c56a6cdf426e8ba7967b6a21864e",
|
||||
"version_major": 2,
|
||||
"version_minor": 0
|
||||
},
|
||||
"text/plain": [
|
||||
"model.safetensors: 0%| | 0.00/2.47G [00:00<?, ?B/s]"
|
||||
]
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
@ -2688,7 +2466,7 @@
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Output text:\n",
|
||||
" Every effort is made to ensure that the information on this website is accurate. However, we cannot guarantee that the information is accurate, complete\n"
|
||||
" Every effort is made to ensure that the information on this website is accurate and up to date. However, the information is provided without any\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
File diff suppressed because it is too large
@ -56,7 +56,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 1,
|
||||
"id": "7c201adb-747e-437b-9a62-442802941e01",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@ -66,7 +66,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": 2,
|
||||
"id": "dd1b65a8-4301-444a-bd7c-a6f2bd1df9df",
|
||||
"metadata": {
|
||||
"colab": {
|
||||
@ -81,9 +81,9 @@
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"blobfile version: 3.0.0\n",
|
||||
"huggingface_hub version: 0.25.2\n",
|
||||
"tiktoken version: 0.8.0\n",
|
||||
"torch version: 2.5.0\n"
|
||||
"huggingface_hub version: 0.30.1\n",
|
||||
"tiktoken version: 0.9.0\n",
|
||||
"torch version: 2.6.0\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -113,7 +113,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"execution_count": 3,
|
||||
"id": "82076c21-9331-4dcd-b017-42b046cf1a60",
|
||||
"metadata": {
|
||||
"id": "82076c21-9331-4dcd-b017-42b046cf1a60"
|
||||
@ -140,18 +140,18 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"execution_count": 4,
|
||||
"id": "4b9a346f-5826-4083-9162-abd56afc03f0",
|
||||
"metadata": {
|
||||
"id": "4b9a346f-5826-4083-9162-abd56afc03f0"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def precompute_rope_params(head_dim, theta_base=10_000, context_length=4096, freq_config=None):\n",
|
||||
"def compute_rope_params(head_dim, theta_base=10_000, context_length=4096, freq_config=None, dtype=torch.float32):\n",
|
||||
" assert head_dim % 2 == 0, \"Embedding dimension must be even\"\n",
|
||||
"\n",
|
||||
" # Compute the inverse frequencies\n",
|
||||
" inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2)[: (head_dim // 2)].float() / head_dim))\n",
|
||||
" inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2, dtype=dtype)[: (head_dim // 2)].float() / head_dim))\n",
|
||||
"\n",
|
||||
" # Frequency adjustments\n",
|
||||
" if freq_config is not None:\n",
|
||||
@ -177,7 +177,7 @@
|
||||
" inv_freq = inv_freq_llama\n",
|
||||
"\n",
|
||||
" # Generate position indices\n",
|
||||
" positions = torch.arange(context_length)\n",
|
||||
" positions = torch.arange(context_length, dtype=dtype)\n",
|
||||
"\n",
|
||||
" # Compute the angles\n",
|
||||
" angles = positions[:, None] * inv_freq[None, :] # Shape: (context_length, head_dim // 2)\n",
|
||||
@ -192,7 +192,7 @@
|
||||
" return cos, sin\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def compute_rope(x, cos, sin):\n",
|
||||
"def apply_rope(x, cos, sin):\n",
|
||||
" # x: (batch_size, num_heads, seq_len, head_dim)\n",
|
||||
" batch_size, num_heads, seq_len, head_dim = x.shape\n",
|
||||
" assert head_dim % 2 == 0, \"Head dimension must be even\"\n",
|
||||
@ -209,43 +209,23 @@
|
||||
" rotated = torch.cat((-x2, x1), dim=-1)\n",
|
||||
" x_rotated = (x * cos) + (rotated * sin)\n",
|
||||
"\n",
|
||||
" # It's ok to use lower-precision after applying cos and sin rotation\n",
|
||||
" return x_rotated.to(dtype=x.dtype)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"execution_count": 5,
|
||||
"id": "e8169ab5-f976-4222-a2e1-eb1cabf267cb",
|
||||
"metadata": {
|
||||
"id": "e8169ab5-f976-4222-a2e1-eb1cabf267cb"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"class SharedBuffers:\n",
|
||||
" _buffers = {}\n",
|
||||
"\n",
|
||||
" @staticmethod\n",
|
||||
" def get_buffers(context_length, head_dim, rope_base, freq_config, dtype=torch.float32):\n",
|
||||
" key = (context_length, head_dim, rope_base, tuple(freq_config.values()) if freq_config else freq_config, dtype)\n",
|
||||
"\n",
|
||||
" if key not in SharedBuffers._buffers:\n",
|
||||
" # Create or fetch the buffers\n",
|
||||
" mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)\n",
|
||||
" cos, sin = precompute_rope_params(head_dim, rope_base, context_length, freq_config)\n",
|
||||
" if dtype is not None:\n",
|
||||
" cos = cos.to(dtype)\n",
|
||||
" sin = sin.to(dtype)\n",
|
||||
" SharedBuffers._buffers[key] = (mask, cos, sin)\n",
|
||||
"\n",
|
||||
" return SharedBuffers._buffers[key]\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class GroupedQueryAttention(nn.Module):\n",
|
||||
" def __init__(\n",
|
||||
" self, d_in, d_out, context_length, num_heads,\n",
|
||||
" self, d_in, d_out, num_heads,\n",
|
||||
" num_kv_groups,\n",
|
||||
" rope_base=10_000,\n",
|
||||
" rope_config=None,\n",
|
||||
" dtype=None\n",
|
||||
" ):\n",
|
||||
" super().__init__()\n",
|
||||
@ -264,14 +244,7 @@
|
||||
" self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype)\n",
|
||||
" self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype)\n",
|
||||
"\n",
|
||||
" # Fetch buffers using SharedBuffers\n",
|
||||
" mask, cos, sin = SharedBuffers.get_buffers(context_length, self.head_dim, rope_base, rope_config, dtype)\n",
|
||||
" self.register_buffer(\"mask\", mask, persistent=False)\n",
|
||||
"\n",
|
||||
" self.register_buffer(\"cos\", cos, persistent=False)\n",
|
||||
" self.register_buffer(\"sin\", sin, persistent=False)\n",
|
||||
"\n",
|
||||
" def forward(self, x):\n",
|
||||
" def forward(self, x, mask, cos, sin):\n",
|
||||
" b, num_tokens, d_in = x.shape\n",
|
||||
"\n",
|
||||
" queries = self.W_query(x) # Shape: (b, num_tokens, d_out)\n",
|
||||
@ -289,8 +262,8 @@
|
||||
" queries = queries.transpose(1, 2) # Shape: (b, num_query_groups, num_tokens, head_dim)\n",
|
||||
"\n",
|
||||
" # Apply RoPE\n",
|
||||
" keys = compute_rope(keys, self.cos, self.sin)\n",
|
||||
" queries = compute_rope(queries, self.cos, self.sin)\n",
|
||||
" keys = apply_rope(keys, cos, sin)\n",
|
||||
" queries = apply_rope(queries, cos, sin)\n",
|
||||
"\n",
|
||||
" # Expand keys and values to match the number of heads\n",
|
||||
" # Shape: (b, num_heads, num_tokens, head_dim)\n",
|
||||
@ -307,11 +280,8 @@
|
||||
" # Shape: (b, num_heads, num_tokens, num_tokens)\n",
|
||||
" attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head\n",
|
||||
"\n",
|
||||
" # Original mask truncated to the number of tokens and converted to boolean\n",
|
||||
" mask_bool = self.mask.bool()[:num_tokens, :num_tokens]\n",
|
||||
"\n",
|
||||
" # Use the mask to fill attention scores\n",
|
||||
" attn_scores.masked_fill_(mask_bool, -torch.inf)\n",
|
||||
" # Compute attention scores\n",
|
||||
" attn_scores = attn_scores.masked_fill(mask, -torch.inf)\n",
|
||||
"\n",
|
||||
" attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n",
|
||||
" assert keys.shape[-1] == self.head_dim\n",
|
||||
@ -328,7 +298,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"execution_count": 6,
|
||||
"id": "457cb2f8-50c1-4045-8a74-f181bfb5fea9",
|
||||
"metadata": {
|
||||
"id": "457cb2f8-50c1-4045-8a74-f181bfb5fea9"
|
||||
@ -338,31 +308,28 @@
|
||||
"class TransformerBlock(nn.Module):\n",
|
||||
" def __init__(self, cfg):\n",
|
||||
" super().__init__()\n",
|
||||
" self.att = GroupedQueryAttention(\n",
|
||||
" self.att = GroupedQueryAttention(\n",
|
||||
" d_in=cfg[\"emb_dim\"],\n",
|
||||
" d_out=cfg[\"emb_dim\"],\n",
|
||||
" context_length=cfg[\"context_length\"],\n",
|
||||
" num_heads=cfg[\"n_heads\"],\n",
|
||||
" num_kv_groups=cfg[\"n_kv_groups\"],\n",
|
||||
" rope_base=cfg[\"rope_base\"],\n",
|
||||
" rope_config=cfg[\"rope_freq\"],\n",
|
||||
" dtype=cfg[\"dtype\"]\n",
|
||||
" )\n",
|
||||
" self.ff = FeedForward(cfg)\n",
|
||||
" self.norm1 = nn.RMSNorm(cfg[\"emb_dim\"], eps=1e-5)\n",
|
||||
" self.norm2 = nn.RMSNorm(cfg[\"emb_dim\"], eps=1e-5)\n",
|
||||
" self.norm1 = nn.RMSNorm(cfg[\"emb_dim\"], eps=1e-5, dtype=cfg[\"dtype\"])\n",
|
||||
" self.norm2 = nn.RMSNorm(cfg[\"emb_dim\"], eps=1e-5, dtype=cfg[\"dtype\"])\n",
|
||||
"\n",
|
||||
" def forward(self, x):\n",
|
||||
" def forward(self, x, mask, cos, sin):\n",
|
||||
" # Shortcut connection for attention block\n",
|
||||
" shortcut = x\n",
|
||||
" x = self.norm1(x)\n",
|
||||
" x = self.att(x.to(torch.bfloat16)) # Shape [batch_size, num_tokens, emb_size]\n",
|
||||
" x = self.att(x, mask, cos, sin) # Shape [batch_size, num_tokens, emb_size]\n",
|
||||
" x = x + shortcut # Add the original input back\n",
|
||||
"\n",
|
||||
" # Shortcut connection for feed-forward block\n",
|
||||
" shortcut = x\n",
|
||||
" x = self.norm2(x)\n",
|
||||
" x = self.ff(x.to(torch.bfloat16))\n",
|
||||
" x = self.ff(x)\n",
|
||||
" x = x + shortcut # Add the original input back\n",
|
||||
"\n",
|
||||
" return x"
|
||||
@ -370,7 +337,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"execution_count": 7,
|
||||
"id": "e88de3e3-9f07-42cc-816b-28dbd46e96c4",
|
||||
"metadata": {
|
||||
"id": "e88de3e3-9f07-42cc-816b-28dbd46e96c4"
|
||||
@ -380,20 +347,41 @@
|
||||
"class Llama3Model(nn.Module):\n",
|
||||
" def __init__(self, cfg):\n",
|
||||
" super().__init__()\n",
|
||||
"\n",
|
||||
" # Main model parameters\n",
|
||||
" self.tok_emb = nn.Embedding(cfg[\"vocab_size\"], cfg[\"emb_dim\"], dtype=cfg[\"dtype\"])\n",
|
||||
"\n",
|
||||
" self.trf_blocks = nn.Sequential(\n",
|
||||
" *[TransformerBlock(cfg) for _ in range(cfg[\"n_layers\"])])\n",
|
||||
" self.trf_blocks = nn.ModuleList( # ModuleList since Sequential can only accept one input, and we need `x, mask, cos, sin`\n",
|
||||
" [TransformerBlock(cfg) for _ in range(cfg[\"n_layers\"])]\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" self.final_norm = nn.RMSNorm(cfg[\"emb_dim\"], eps=1e-5)\n",
|
||||
" self.final_norm = nn.RMSNorm(cfg[\"emb_dim\"], eps=1e-5, dtype=cfg[\"dtype\"])\n",
|
||||
" self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n",
|
||||
"\n",
|
||||
" # Reusuable utilities\n",
|
||||
" cos, sin = compute_rope_params(\n",
|
||||
" head_dim=cfg[\"emb_dim\"] // cfg[\"n_heads\"],\n",
|
||||
" theta_base=cfg[\"rope_base\"],\n",
|
||||
" context_length=cfg[\"context_length\"],\n",
|
||||
" freq_config=cfg[\"rope_freq\"]\n",
|
||||
" )\n",
|
||||
" self.register_buffer(\"cos\", cos, persistent=False)\n",
|
||||
" self.register_buffer(\"sin\", sin, persistent=False)\n",
|
||||
" self.cfg = cfg\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" def forward(self, in_idx):\n",
|
||||
" # Forward pass\n",
|
||||
" tok_embeds = self.tok_emb(in_idx)\n",
|
||||
" x = tok_embeds\n",
|
||||
" x = self.trf_blocks(x)\n",
|
||||
"\n",
|
||||
" num_tokens = x.shape[1]\n",
|
||||
" mask = torch.triu(torch.ones(num_tokens, num_tokens, device=x.device, dtype=torch.bool), diagonal=1)\n",
|
||||
" \n",
|
||||
" for block in self.trf_blocks:\n",
|
||||
" x = block(x, mask, self.cos, self.sin)\n",
|
||||
" x = self.final_norm(x)\n",
|
||||
" logits = self.out_head(x.to(torch.bfloat16))\n",
|
||||
" logits = self.out_head(x.to(self.cfg[\"dtype\"]))\n",
|
||||
" return logits"
|
||||
]
|
||||
},
|
||||
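Even though `TransformerBlock.forward` now takes `mask`, `cos`, and `sin`, the model itself is still called with token IDs only; the mask is built on the fly inside `Llama3Model.forward`. A minimal smoke-test sketch (hypothetical; assumes the classes and `LLAMA32_CONFIG` defined in the cells above, and enough RAM to instantiate the 1B model):

```python
import torch

torch.manual_seed(123)
model = Llama3Model(LLAMA32_CONFIG)

dummy_ids = torch.randint(0, LLAMA32_CONFIG["vocab_size"], (1, 8))
with torch.no_grad():
    logits = model(dummy_ids)    # mask, cos, and sin are handled internally
print(logits.shape)              # torch.Size([1, 8, 128256])
```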
@ -420,7 +408,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"execution_count": 8,
|
||||
"id": "caa142fa-b375-4e78-b392-2072ced666f3",
|
||||
"metadata": {
|
||||
"id": "caa142fa-b375-4e78-b392-2072ced666f3"
|
||||
@ -430,16 +418,16 @@
|
||||
"# Llama 3.2 1B\n",
|
||||
"\n",
|
||||
"LLAMA32_CONFIG = {\n",
|
||||
" \"vocab_size\": 128_256, # Vocabulary size\n",
|
||||
" \"context_length\": 131_072, # Context length\n",
|
||||
" \"emb_dim\": 2048, # Embedding dimension\n",
|
||||
" \"n_heads\": 32, # Number of attention heads\n",
|
||||
" \"n_layers\": 16, # Number of layers\n",
|
||||
" \"hidden_dim\": 8192, # Size of the intermediate dimension in FeedForward\n",
|
||||
" \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n",
|
||||
" \"rope_base\": 500_000.0, # The base in RoPE's \"theta\"\n",
|
||||
" \"dtype\": torch.bfloat16, # Lower-precision dtype to reduce memory usage\n",
|
||||
" \"rope_freq\": { # RoPE frequency scaling\n",
|
||||
" \"vocab_size\": 128_256, # Vocabulary size\n",
|
||||
" \"context_length\": 131_072, # Context length that was used to train the model\n",
|
||||
" \"emb_dim\": 2048, # Embedding dimension\n",
|
||||
" \"n_heads\": 32, # Number of attention heads\n",
|
||||
" \"n_layers\": 16, # Number of layers\n",
|
||||
" \"hidden_dim\": 8192, # Size of the intermediate dimension in FeedForward\n",
|
||||
" \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n",
|
||||
" \"rope_base\": 500_000.0, # The base in RoPE's \"theta\"\n",
|
||||
" \"dtype\": torch.bfloat16, # Lower-precision dtype to reduce memory usage\n",
|
||||
" \"rope_freq\": { # RoPE frequency scaling\n",
|
||||
" \"factor\": 32.0,\n",
|
||||
" \"low_freq_factor\": 1.0,\n",
|
||||
" \"high_freq_factor\": 4.0,\n",
|
||||
@ -450,16 +438,16 @@
|
||||
"# Llama 3.2 3B\n",
|
||||
"\n",
|
||||
"# LLAMA32_CONFIG = {\n",
|
||||
"# \"vocab_size\": 128_256, # Vocabulary size\n",
|
||||
"# \"context_length\": 131_072, # Context length\n",
|
||||
"# \"emb_dim\": 3072, # Embedding dimension\n",
|
||||
"# \"n_heads\": 24, # Number of attention heads\n",
|
||||
"# \"n_layers\": 28, # Number of layers\n",
|
||||
"# \"hidden_dim\": 8192, # Size of the intermediate dimension in FeedForward\n",
|
||||
"# \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n",
|
||||
"# \"rope_base\": 500_000.0, # The base in RoPE's \"theta\"\n",
|
||||
"# \"dtype\": torch.bfloat16, # Lower-precision dtype to reduce memory usage\n",
|
||||
"# \"rope_freq\": { # RoPE frequency scaling\n",
|
||||
"# \"vocab_size\": 128_256, # Vocabulary size\n",
|
||||
"# \"context_length\": 131_072, # Context length that was used to train the model\n",
|
||||
"# \"emb_dim\": 3072, # Embedding dimension\n",
|
||||
"# \"n_heads\": 24, # Number of attention heads\n",
|
||||
"# \"n_layers\": 28, # Number of layers\n",
|
||||
"# \"hidden_dim\": 8192, # Size of the intermediate dimension in FeedForward\n",
|
||||
"# \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n",
|
||||
"# \"rope_base\": 500_000.0, # The base in RoPE's \"theta\"\n",
|
||||
"# \"dtype\": torch.bfloat16, # Lower-precision dtype to reduce memory usage\n",
|
||||
"# \"rope_freq\": { # RoPE frequency scaling\n",
|
||||
"# \"factor\": 32.0,\n",
|
||||
"# \"low_freq_factor\": 1.0,\n",
|
||||
"# \"high_freq_factor\": 4.0,\n",
|
||||
@ -470,54 +458,9 @@
|
||||
"LLAMA_SIZE_STR = \"1B\" if LLAMA32_CONFIG[\"emb_dim\"] == 2048 else \"3B\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "34535172-797e-4dd0-84fb-65bc75ad5b06",
|
||||
"metadata": {
|
||||
"id": "34535172-797e-4dd0-84fb-65bc75ad5b06"
|
||||
},
|
||||
"source": [
|
||||
"- Reduce the context length so the model would work fine on a MacBook Air (if you have more RAM, feel free to comment out the lines below):"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"id": "a8bc2370-39d2-4bfe-b4c1-6bdd75fe101c",
|
||||
"metadata": {
|
||||
"id": "a8bc2370-39d2-4bfe-b4c1-6bdd75fe101c"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"New RoPE theta: 31250.0\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"old_context_length = LLAMA32_CONFIG[\"context_length\"]\n",
|
||||
"LLAMA32_CONFIG[\"context_length\"] = 8192\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def rescale_theta(theta_old, context_length_old, context_length_new):\n",
|
||||
" scaling_factor = context_length_new / context_length_old\n",
|
||||
" theta_new = theta_old * scaling_factor\n",
|
||||
" return theta_new\n",
|
||||
"\n",
|
||||
"LLAMA32_CONFIG[\"rope_base\"] = rescale_theta(\n",
|
||||
" LLAMA32_CONFIG[\"rope_base\"],\n",
|
||||
" old_context_length,\n",
|
||||
" LLAMA32_CONFIG[\"context_length\"]\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"print(\"New RoPE theta:\", LLAMA32_CONFIG[\"rope_base\"])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"execution_count": 9,
|
||||
"id": "156253fe-aacd-4da2-8f13-705f05c4b11e",
|
||||
"metadata": {
|
||||
"id": "156253fe-aacd-4da2-8f13-705f05c4b11e"
|
||||
@ -539,36 +482,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"id": "0e95db6d-2712-41a5-a5e0-86c49897f4cf",
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "0e95db6d-2712-41a5-a5e0-86c49897f4cf",
|
||||
"outputId": "8efc4937-e616-40d0-cd59-670d7eb3e841"
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"True\n",
|
||||
"True\n",
|
||||
"True\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Check buffers\n",
|
||||
"print(model.trf_blocks[0].att.mask is model.trf_blocks[-1].att.mask)\n",
|
||||
"print(model.trf_blocks[0].att.cos is model.trf_blocks[-1].att.cos)\n",
|
||||
"print(model.trf_blocks[0].att.sin is model.trf_blocks[-1].att.sin)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"execution_count": 10,
|
||||
"id": "364e76ca-52f8-4fa5-af37-c4069f9694bc",
|
||||
"metadata": {
|
||||
"colab": {
|
||||
@ -599,7 +513,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"execution_count": 11,
|
||||
"id": "fd5efb03-5a07-46e8-8607-93ed47549d2b",
|
||||
"metadata": {
|
||||
"colab": {
|
||||
@ -613,8 +527,8 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"float32 (PyTorch default): 11.42 GB\n",
|
||||
"bfloat16: 5.71 GB\n"
|
||||
"float32 (PyTorch default): 11.23 GB\n",
|
||||
"bfloat16: 5.61 GB\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -649,7 +563,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"execution_count": 12,
|
||||
"id": "31f12baf-f79b-499f-85c0-51328a6a20f5",
|
||||
"metadata": {
|
||||
"id": "31f12baf-f79b-499f-85c0-51328a6a20f5"
|
||||
@ -679,7 +593,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"execution_count": 13,
|
||||
"id": "9482b01c-49f9-48e4-ab2c-4a4c75240e77",
|
||||
"metadata": {
|
||||
"id": "9482b01c-49f9-48e4-ab2c-4a4c75240e77"
|
||||
@ -693,73 +607,86 @@
|
||||
"from tiktoken.load import load_tiktoken_bpe\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class Tokenizer:\n",
|
||||
" def __init__(self, model_path):\n",
|
||||
" assert os.path.isfile(model_path), f\"Model file {model_path} not found\"\n",
|
||||
" mergeable_ranks = load_tiktoken_bpe(model_path)\n",
|
||||
"\n",
|
||||
" self.special_tokens = {\n",
|
||||
"class Tokenizer:\n",
|
||||
" \"\"\"Thin wrapper around tiktoken that keeps track of Llama-3 special IDs.\"\"\"\n",
|
||||
" def __init__(self, model_path):\n",
|
||||
" if not os.path.isfile(model_path):\n",
|
||||
" raise FileNotFoundError(model_path)\n",
|
||||
"\n",
|
||||
" mergeable = load_tiktoken_bpe(model_path)\n",
|
||||
"\n",
|
||||
" # hard-coded from Meta's tokenizer.json\n",
|
||||
" self.special = {\n",
|
||||
" \"<|begin_of_text|>\": 128000,\n",
|
||||
" \"<|end_of_text|>\": 128001,\n",
|
||||
" \"<|start_header_id|>\": 128006,\n",
|
||||
" \"<|end_header_id|>\": 128007,\n",
|
||||
" \"<|eot_id|>\": 128009,\n",
|
||||
" }\n",
|
||||
" self.special_tokens.update({\n",
|
||||
" f\"<|reserved_{i}|>\": 128002 + i for i in range(256) if (128002 + i) not in self.special_tokens.values()\n",
|
||||
" })\n",
|
||||
" self.special.update({f\"<|reserved_{i}|>\": 128002 + i\n",
|
||||
" for i in range(256)\n",
|
||||
" if 128002 + i not in self.special.values()})\n",
|
||||
"\n",
|
||||
" self.model = tiktoken.Encoding(\n",
|
||||
" name=Path(model_path).name,\n",
|
||||
" pat_str=r\"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+\",\n",
|
||||
" mergeable_ranks=mergeable_ranks,\n",
|
||||
" special_tokens=self.special_tokens\n",
|
||||
" pat_str=r\"(?i:'s|'t|'re|'ve|'m|'ll|'d)\"\n",
|
||||
" r\"|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+\"\n",
|
||||
" r\"|\\p{N}{1,3}\"\n",
|
||||
" r\"| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*\"\n",
|
||||
" r\"|\\s*[\\r\\n]+\"\n",
|
||||
" r\"|\\s+(?!\\S)\"\n",
|
||||
" r\"|\\s+\",\n",
|
||||
" mergeable_ranks=mergeable,\n",
|
||||
" special_tokens=self.special,\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" def encode(self, text, bos=False, eos=False, allowed_special=set(), disallowed_special=()):\n",
|
||||
" if bos:\n",
|
||||
" tokens = [self.special_tokens[\"<|begin_of_text|>\"]]\n",
|
||||
" else:\n",
|
||||
" tokens = []\n",
|
||||
"\n",
|
||||
" tokens += self.model.encode(text, allowed_special=allowed_special, disallowed_special=disallowed_special)\n",
|
||||
"\n",
|
||||
" def encode(self, text, bos=False, eos=False):\n",
|
||||
" ids = ([self.special[\"<|begin_of_text|>\"]] if bos else []) \\\n",
|
||||
" + self.model.encode(text)\n",
|
||||
" if eos:\n",
|
||||
" tokens.append(self.special_tokens[\"<|end_of_text|>\"])\n",
|
||||
" return tokens\n",
|
||||
" ids.append(self.special[\"<|end_of_text|>\"])\n",
|
||||
" return ids\n",
|
||||
"\n",
|
||||
" def decode(self, tokens):\n",
|
||||
" return self.model.decode(tokens)\n",
|
||||
" def decode(self, ids):\n",
|
||||
" return self.model.decode(ids)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class ChatFormat:\n",
|
||||
" def __init__(self, tokenizer):\n",
|
||||
" self.tokenizer = tokenizer\n",
|
||||
"\n",
|
||||
" def encode_header(self, message):\n",
|
||||
" tokens = []\n",
|
||||
" tokens.append(self.tokenizer.special_tokens[\"<|start_header_id|>\"])\n",
|
||||
" tokens.extend(self.tokenizer.encode(message[\"role\"], bos=False, eos=False))\n",
|
||||
" tokens.append(self.tokenizer.special_tokens[\"<|end_header_id|>\"])\n",
|
||||
" tokens.extend(self.tokenizer.encode(\"\\n\\n\", bos=False, eos=False))\n",
|
||||
" return tokens\n",
|
||||
" def __init__(self, tokenizer: Tokenizer, *,\n",
|
||||
" default_system=\"You are a helpful assistant.\"):\n",
|
||||
" self.tok = tokenizer\n",
|
||||
" self.default_system = default_system\n",
|
||||
"\n",
|
||||
" def encode(self, text):\n",
|
||||
" message = {\n",
|
||||
" \"role\": \"user\",\n",
|
||||
" \"content\": text\n",
|
||||
" }\n",
|
||||
"\n",
|
||||
" tokens = self.encode_header(message)\n",
|
||||
" tokens.extend(\n",
|
||||
" self.tokenizer.encode(message[\"content\"].strip(), bos=False, eos=False)\n",
|
||||
" def _header(self, role):\n",
|
||||
" \"\"\"Encode <|start_header_id|>role<|end_header_id|>\\n\\n\"\"\"\n",
|
||||
" return (\n",
|
||||
" [self.tok.special[\"<|start_header_id|>\"]]\n",
|
||||
" + self.tok.encode(role)\n",
|
||||
" + [self.tok.special[\"<|end_header_id|>\"]]\n",
|
||||
" + self.tok.encode(\"\\n\\n\")\n",
|
||||
" )\n",
|
||||
" tokens.append(self.tokenizer.special_tokens[\"<|eot_id|>\"])\n",
|
||||
" return tokens\n",
|
||||
"\n",
|
||||
" def decode(self, token_ids):\n",
|
||||
" return self.tokenizer.decode(token_ids)"
|
||||
" def encode(self, user_message, system_message=None):\n",
|
||||
" sys_msg = system_message if system_message is not None else self.default_system\n",
|
||||
"\n",
|
||||
" ids = [self.tok.special[\"<|begin_of_text|>\"]]\n",
|
||||
"\n",
|
||||
" # system\n",
|
||||
" ids += self._header(\"system\")\n",
|
||||
" ids += self.tok.encode(sys_msg)\n",
|
||||
" ids += [self.tok.special[\"<|eot_id|>\"]]\n",
|
||||
"\n",
|
||||
" # user\n",
|
||||
" ids += self._header(\"user\")\n",
|
||||
" ids += self.tok.encode(user_message)\n",
|
||||
" ids += [self.tok.special[\"<|eot_id|>\"]]\n",
|
||||
"\n",
|
||||
" # assistant header (no content yet)\n",
|
||||
" ids += self._header(\"assistant\")\n",
|
||||
"\n",
|
||||
" return ids"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -782,7 +709,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 14,
"id": "e9d96dc8-603a-4cb5-8c3e-4d2ca56862ed",
"metadata": {
"colab": {
@ -793,25 +720,24 @@
},
"outputs": [
{
"name": "stdout",
"name": "stderr",
"output_type": "stream",
"text": [
"The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.\n",
"Token is valid (permission: read).\n",
"Your token has been saved to /teamspace/studios/this_studio/.cache/huggingface/token\n",
"Login successful\n"
"/Users/sebastian/Developer/LLMs-from-scratch/.venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
"  from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"from huggingface_hub import login\n",
"# Uncomment and run the following code if you are executing the notebook for the first time\n",
"\n",
"login()"
"# from huggingface_hub import login\n",
"# login()"
]
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 15,
"id": "986bc1a0-804f-4154-80f8-44cefbee1368",
"metadata": {
"colab": {
@ -847,7 +773,7 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 16,
"id": "_gBhxDtU_nxo",
"metadata": {
"id": "_gBhxDtU_nxo"
@ -871,7 +797,7 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 17,
"id": "75166128-5899-4995-9b88-9672e135650e",
"metadata": {
"id": "75166128-5899-4995-9b88-9672e135650e"
@ -954,7 +880,7 @@
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": 18,
"id": "699cb1b8-a67d-49fb-80a6-0dad9d81f392",
"metadata": {
"colab": {
@ -1018,7 +944,7 @@
},
{
"cell_type": "code",
"execution_count": 22,
"execution_count": 19,
"id": "7f9f7ccc-70cb-41ff-9c25-44336042fc37",
"metadata": {
"id": "7f9f7ccc-70cb-41ff-9c25-44336042fc37"
@ -1049,7 +975,7 @@
},
{
"cell_type": "code",
"execution_count": 23,
"execution_count": 20,
"id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5",
"metadata": {
"id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5"
@ -1108,7 +1034,7 @@
},
{
"cell_type": "code",
"execution_count": 24,
"execution_count": 21,
"id": "1c7a04fa-6aac-416b-8f63-f1e19227633d",
"metadata": {
"id": "1c7a04fa-6aac-416b-8f63-f1e19227633d"
@ -1118,23 +1044,31 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Time: 18.20 sec\n",
"\n",
"\n",
"Output text:\n",
" Llamas are herbivores, which means they primarily eat plants. Their diet consists mainly of:\n",
"\n",
"1. Grasses: Llamas love to graze on various types of grasses, including tall grasses and grassy meadows.\n",
"2. Hay: Llamas also eat hay, which is a dry, compressed form of grass or other plants.\n",
"3. Alfalfa: Alfalfa is a legume that is commonly fed to llamas. It is high in protein and fiber.\n",
"4. Other plants: Llamas will also eat other plants, such as wild grasses, shrubs, and trees.\n",
" Llamas are herbivores, which means they primarily eat plants and plant-based foods. Their diet typically consists of:\n",
"\n",
"It's worth noting that the diet of llamas can vary depending on the region, climate,\n"
"1. Grasses: Llamas love to graze on various types of grasses, including tall grasses, short grasses, and grassy weeds.\n",
"2. Hay: They also enjoy munching on hay, which is a dry, compressed form of grass or other plant material.\n",
"3. Leaves: Llamas will eat leaves from trees and shrubs, including leaves from plants like clover, alfalfa, and grasses.\n",
"4. Fruits and vegetables: In the wild, llamas will eat fruits and vegetables like berries, apples, and carrots.\n",
"5. Browse: Llamas will also\n"
]
}
],
"source": [
"import time\n",
"\n",
"\n",
"PROMPT = \"What do llamas eat?\"\n",
"\n",
"torch.manual_seed(123)\n",
"\n",
"start = time.time()\n",
"\n",
"token_ids = generate(\n",
"    model=model,\n",
"    idx=text_to_token_ids(PROMPT, chat_tokenizer).to(device),\n",
@ -1144,6 +1078,13 @@
"    temperature=0.\n",
")\n",
"\n",
"print(f\"Time: {time.time() - start:.2f} sec\")\n",
"\n",
"if torch.cuda.is_available():\n",
"    max_mem_bytes = torch.cuda.max_memory_allocated()\n",
"    max_mem_gb = max_mem_bytes / (1024 ** 3)\n",
"    print(f\"Max memory allocated: {max_mem_gb:.2f} GB\")\n",
"\n",
"output_text = token_ids_to_text(token_ids, tokenizer)\n",
"\n",
"\n",
@ -1158,7 +1099,7 @@
"    # If the token is not found, return the original text\n",
"    return text\n",
"\n",
"print(\"Output text:\\n\", clean_text(output_text))"
"print(\"\\n\\nOutput text:\\n\\n\", clean_text(output_text))"
]
},
{
@ -110,12 +110,21 @@ from llms_from_scratch.appendix_a import NeuralNetwork, ToyDataset

from llms_from_scratch.appendix_d import find_highest_gradient, train_model

```


### Llama 3 (Bonus material)

```python
from llms_from_scratch.llama3 import (
    Llama3Model,
    Llama3ModelFast,
    Llama3Tokenizer,
    ChatFormat,
    clean_text
)
```

(For the `llms_from_scratch.llama3` usage information, please see [this bonus section](../../ch05/07_gpt_to_llama/README.md).

For the `llms_from_scratch.llama3` usage information, please see [this bonus section](../../ch05/07_gpt_to_llama/README.md).
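
As a minimal usage sketch (not part of this diff, and no pretrained weights are loaded), the imports above can be combined with the package's default `LLAMA32_CONFIG_1B` config:

```python
import torch
from llms_from_scratch.llama3 import Llama3Model, LLAMA32_CONFIG_1B

# Illustrative only: instantiate the 1B architecture with the default
# 131_072-token context configuration and count its parameters.
torch.manual_seed(123)
model = Llama3Model(LLAMA32_CONFIG_1B)
print(f"{sum(p.numel() for p in model.parameters()):,} parameters")
```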
@ -15,8 +15,7 @@ from tiktoken.load import load_tiktoken_bpe

LLAMA32_CONFIG_1B = {
"vocab_size": 128_256,        # Vocabulary size
"context_length": 8192,  # Maximum context length to use (reduced to save memory)
"orig_context_length": 131_072,  # Context length that was used to train the model
"context_length": 131_072,  # Context length that was used to train the model
"emb_dim": 2048,              # Embedding dimension
"n_heads": 32,                # Number of attention heads
"n_layers": 16,               # Number of layers
@ -34,8 +33,7 @@ LLAMA32_CONFIG_1B = {

LLAMA32_CONFIG_3B = {
"vocab_size": 128_256,        # Vocabulary size
"context_length": 8192,  # Maximum context length to use (reduced to save memory)
"orig_context_length": 131_072,  # Context length that was used to train the model
"context_length": 131_072,  # Context length that was used to train the model
"emb_dim": 3072,              # Embedding dimension
"n_heads": 24,                # Number of attention heads
"n_layers": 28,               # Number of layers
@ -67,17 +65,6 @@ class Llama3Model(nn.Module):
self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"])

# Reusuable utilities
self.register_buffer(
"mask", torch.triu(torch.ones(cfg["context_length"], cfg["context_length"]), diagonal=1).bool(),
persistent=False
)

if cfg["orig_context_length"] != cfg["context_length"]:
cfg["rope_base"] = rescale_theta(
cfg["rope_base"],
cfg["orig_context_length"],
cfg["context_length"]
)
cos, sin = compute_rope_params(
head_dim=cfg["emb_dim"] // cfg["n_heads"],
theta_base=cfg["rope_base"],
@ -92,8 +79,11 @@
tok_embeds = self.tok_emb(in_idx)
x = tok_embeds

num_tokens = x.shape[1]
mask = torch.triu(torch.ones(num_tokens, num_tokens, device=x.device, dtype=torch.bool), diagonal=1)

for block in self.trf_blocks:
x = block(x, self.mask, self.cos, self.sin)
x = block(x, mask, self.cos, self.sin)
x = self.final_norm(x)
logits = self.out_head(x.to(self.cfg["dtype"]))
return logits
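
To illustrate why the precomputed `mask` buffer was dropped in favor of the on-the-fly mask above (a back-of-the-envelope sketch, not part of the commit): a boolean causal mask grows quadratically with the context length, so materializing it for the full 131_072-token context costs orders of magnitude more memory than building it per forward pass for the actual sequence length.

```python
import torch

full_ctx = 131_072  # "context_length" in LLAMA32_CONFIG_1B above
num_tokens = 8_192  # an example sequence length at inference time

# Size of the removed full-context bool buffer (1 byte per element)
print(f"{full_ctx * full_ctx / 1024**3:.0f} GB")  # ~16 GB

# Size of the mask built on the fly in forward(), as in the diff above
mask = torch.triu(torch.ones(num_tokens, num_tokens, dtype=torch.bool), diagonal=1)
print(f"{mask.element_size() * mask.nelement() / 1024**2:.0f} MB")  # ~64 MB
```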
@ -281,88 +271,104 @@ def apply_rope(x, cos, sin):
return x_rotated.to(dtype=x.dtype)


def rescale_theta(theta_old, context_length_old, context_length_new):
scaling_factor = context_length_new / context_length_old
theta_new = theta_old * scaling_factor
return theta_new
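
For reference, `rescale_theta` (removed by this commit together with its test further down) simply scales the RoPE base proportionally to the context-length change, e.g.:

```python
# 500_000 * (8_192 / 131_072) = 31_250.0, matching the old test_rescale assertion
theta_new = 500_000.0 * (8_192 / 131_072)
print(theta_new)  # 31250.0
```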
##########################################
# Tokenizer
##########################################


class Llama3Tokenizer:
"""Thin wrapper around tiktoken that keeps track of Llama-3 special IDs."""
def __init__(self, model_path):
assert os.path.isfile(model_path), f"Model file {model_path} not found"
mergeable_ranks = load_tiktoken_bpe(model_path)
if not os.path.isfile(model_path):
raise FileNotFoundError(model_path)

self.special_tokens = {
mergeable = load_tiktoken_bpe(model_path)

# hard-coded from Meta's tokenizer.json
self.special = {
"<|begin_of_text|>": 128000,
"<|end_of_text|>": 128001,
"<|start_header_id|>": 128006,
"<|end_header_id|>": 128007,
"<|eot_id|>": 128009,
}
self.special_tokens.update({
f"<|reserved_{i}|>": 128002 + i for i in range(256) if (128002 + i) not in self.special_tokens.values()
})
self.special.update({f"<|reserved_{i}|>": 128002 + i
for i in range(256)
if 128002 + i not in self.special.values()})

self.model = tiktoken.Encoding(
name=Path(model_path).name,
pat_str=r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+",
mergeable_ranks=mergeable_ranks,
special_tokens=self.special_tokens
pat_str=r"(?i:'s|'t|'re|'ve|'m|'ll|'d)"
r"|[^\r\n\p{L}\p{N}]?\p{L}+"
r"|\p{N}{1,3}"
r"| ?[^\s\p{L}\p{N}]+[\r\n]*"
r"|\s*[\r\n]+"
r"|\s+(?!\S)"
r"|\s+",
mergeable_ranks=mergeable,
special_tokens=self.special,
)

def encode(self, text, bos=False, eos=False, allowed_special=set(), disallowed_special=()):
def encode(self, text, bos=False, eos=False, allowed_special=set()):
ids: list[int] = []

if bos:
tokens = [self.special_tokens["<|begin_of_text|>"]]
else:
tokens = []

tokens += self.model.encode(text, allowed_special=allowed_special, disallowed_special=disallowed_special)
ids.append(self.special_tokens["<|begin_of_text|>"])

# delegate to underlying tiktoken.Encoding.encode
ids.extend(
self.model.encode(
text,
allowed_special=allowed_special,
)
)
if eos:
tokens.append(self.special_tokens["<|end_of_text|>"])
return tokens
ids.append(self.special_tokens["<|end_of_text|>"])

def decode(self, tokens):
return self.model.decode(tokens)
return ids

def decode(self, ids):
return self.model.decode(ids)


class ChatFormat:
def __init__(self, tokenizer):
self.tokenizer = tokenizer

def encode_header(self, message):
tokens = []
tokens.append(self.tokenizer.special_tokens["<|start_header_id|>"])
tokens.extend(self.tokenizer.encode(message["role"], bos=False, eos=False))
tokens.append(self.tokenizer.special_tokens["<|end_header_id|>"])
tokens.extend(self.tokenizer.encode("\n\n", bos=False, eos=False))
return tokens
def __init__(self, tokenizer: Llama3Tokenizer, *,
default_system="You are a helpful assistant."):
self.tok = tokenizer
self.default_system = default_system

def encode(self, text, allowed_special=None):
message = {
"role": "user",
"content": text
}

tokens = self.encode_header(message)
tokens.extend(
self.tokenizer.encode(
message["content"].strip(),
bos=False,
eos=False,
allowed_special=allowed_special
)
def _header(self, role):
"""Encode <|start_header_id|>role<|end_header_id|>\n\n"""
return (
[self.tok.special["<|start_header_id|>"]]
+ self.tok.encode(role)
+ [self.tok.special["<|end_header_id|>"]]
+ self.tok.encode("\n\n")
)
tokens.append(self.tokenizer.special_tokens["<|eot_id|>"])
return tokens

def decode(self, token_ids):
return self.tokenizer.decode(token_ids)
def encode(self, user_message, system_message=None, allowed_special=None):
sys_msg = system_message if system_message is not None else self.default_system

ids = [self.tok.special["<|begin_of_text|>"]]

# system
ids += self._header("system")
ids += self.tok.encode(sys_msg, allowed_special=allowed_special)
ids += [self.tok.special["<|eot_id|>"]]

# user
ids += self._header("user")
ids += self.tok.encode(user_message)
ids += [self.tok.special["<|eot_id|>"]]

# assistant header (no content yet)
ids += self._header("assistant")

return ids

def decode(self, ids):
return self.tok.decode(ids)
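
A hedged usage sketch of the reworked tokenizer and chat wrapper (the `tokenizer.model` path is a placeholder, and `allowed_special` is passed explicitly to keep the call self-contained): `ChatFormat.encode` now renders `<|begin_of_text|>`, a system turn, the user turn, and an open assistant header, so generation can continue directly from the returned IDs.

```python
from llms_from_scratch.llama3 import Llama3Tokenizer, ChatFormat

tokenizer = Llama3Tokenizer("tokenizer.model")  # placeholder path to Meta's tokenizer.model file
chat_tokenizer = ChatFormat(tokenizer)

ids = chat_tokenizer.encode("What do llamas eat?", allowed_special=set())
print(tokenizer.decode(ids))
# Roughly:
# <|begin_of_text|><|start_header_id|>system<|end_header_id|>
#
# You are a helpful assistant.<|eot_id|><|start_header_id|>user<|end_header_id|>
#
# What do llamas eat?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
```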
def clean_text(text, header_end="assistant<|end_header_id|>\n\n"):
@ -483,12 +489,6 @@ class Llama3ModelFast(nn.Module):
self.final_norm = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"])
self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"])

if cfg["orig_context_length"] != cfg["context_length"]:
cfg["rope_base"] = rescale_theta(
cfg["rope_base"],
cfg["orig_context_length"],
cfg["context_length"]
)
cos, sin = compute_rope_params(
head_dim=cfg["emb_dim"] // cfg["n_heads"],
theta_base=cfg["rope_base"],
@ -7,7 +7,6 @@ from llms_from_scratch.ch04 import generate_text_simple
from llms_from_scratch.llama3 import (
compute_rope_params,
apply_rope,
rescale_theta,
LLAMA32_CONFIG_1B,
GroupedQueryAttention,
GroupedQueryAttentionFast,
@ -102,23 +101,6 @@ GPT_CONFIG_124M = {
}


def test_rescale():

new_theta = rescale_theta(
theta_old=500_000.,
context_length_old=131_072,
context_length_new=8192
)
assert new_theta == 31250.

old_theta = rescale_theta(
theta_old=new_theta,
context_length_old=8192,
context_length_new=131_072
)
assert old_theta == 500_000.


def test_grouped_query_attention_equivalence():
torch.manual_seed(42)
b, t, d_in, d_out, num_heads, num_kv_groups = 2, 8, 32, 64, 4, 2
@ -194,6 +176,6 @@ def test_gpt_model_variants(ModelClass, llama3_weights_path):
)
print("Encoded output text:", out)
expect = torch.tensor([
[43, 2543, 292, 4483, 100383, 8113, 21197, 33804, 54419]
[43, 2543, 292, 4483, 100383, 8113, 76873, 42175, 72641]
])
assert torch.equal(expect, out)
@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "llms-from-scratch"
version = "1.0.6"
version = "1.0.7"
description = "Implement a ChatGPT-like LLM in PyTorch from scratch, step by step"
readme = "README.md"
requires-python = ">=3.10"