Bonus material: extending tokenizers (#496)

* Bonus material: extending tokenizers

* small wording update
Sebastian Raschka 2025-01-22 09:26:54 -06:00 committed by GitHub
parent 9175590ea4
commit dcaac28b92
7 changed files with 1224 additions and 2 deletions


@ -120,6 +120,7 @@ Several folders contain optional materials as a bonus for interested readers:
- [Converting GPT to Llama](ch05/07_gpt_to_llama)
- [Llama 3.2 From Scratch](ch05/07_gpt_to_llama/standalone-llama32.ipynb)
- [Memory-efficient Model Weight Loading](ch05/08_memory_efficient_weight_loading/memory-efficient-state-dict.ipynb)
- [Extending the Tiktoken BPE Tokenizer with New Tokens](ch05/09_extending-tokenizers/extend-tiktoken.ipynb)
- **Chapter 6: Finetuning for classification**
- [Additional experiments finetuning different layers and using larger models](ch06/02_bonus_additional-experiments)
- [Finetuning different models on 50k IMDB movie review dataset](ch06/03_bonus_imdb-classification)


@ -0,0 +1,3 @@
# Byte Pair Encoding (BPE) Tokenizer From Scratch
- [bpe-from-scratch.ipynb](bpe-from-scratch.ipynb) contains optional (bonus) code that explains and shows how the BPE tokenizer works under the hood.


@ -0,0 +1,3 @@
# Extending the Tiktoken BPE Tokenizer with New Tokens
- [extend-tiktoken.ipynb](extend-tiktoken.ipynb) contains optional (bonus) code that explains how to add special tokens to a tokenizer implemented via `tiktoken` and how to update the LLM accordingly


@ -0,0 +1,771 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "cbbc1fe3-bff1-4631-bf35-342e19c54cc0",
"metadata": {},
"source": [
"<table style=\"width:100%\">\n",
"<tr>\n",
"<td style=\"vertical-align:middle; text-align:left;\">\n",
"<font size=\"2\">\n",
"Supplementary code for the <a href=\"http://mng.bz/orYv\">Build a Large Language Model From Scratch</a> book by <a href=\"https://sebastianraschka.com\">Sebastian Raschka</a><br>\n",
"<br>Code repository: <a href=\"https://github.com/rasbt/LLMs-from-scratch\">https://github.com/rasbt/LLMs-from-scratch</a>\n",
"</font>\n",
"</td>\n",
"<td style=\"vertical-align:middle; text-align:left;\">\n",
"<a href=\"http://mng.bz/orYv\"><img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/cover-small.webp\" width=\"100px\"></a>\n",
"</td>\n",
"</tr>\n",
"</table>"
]
},
{
"cell_type": "markdown",
"id": "2b022374-e3f6-4437-b86f-e6f8f94cbebc",
"metadata": {},
"source": [
"# Extending the Tiktoken BPE Tokenizer with New Tokens"
]
},
{
"cell_type": "markdown",
"id": "bcd624b1-2060-49af-bbf6-40517a58c128",
"metadata": {},
"source": [
"- This notebook explains how we can extend an existing BPE tokenizer; specifically, we will focus on how to do it for the popular [tiktoken](https://github.com/openai/tiktoken) implementation\n",
"- For a general introduction to tokenization, please refer to [Chapter 2](https://github.com/rasbt/LLMs-from-scratch/blob/main/ch02/01_main-chapter-code/ch02.ipynb) and the BPE from Scratch [link] tutorial\n",
"- For example, suppose we have a GPT-2 tokenizer and want to encode the following text:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "798d4355-a146-48a8-a1a5-c5cec91edf2c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[15496, 11, 2011, 3791, 30642, 62, 16, 318, 257, 649, 11241, 13, 220, 50256]\n"
]
}
],
"source": [
"import tiktoken\n",
"\n",
"base_tokenizer = tiktoken.get_encoding(\"gpt2\")\n",
"sample_text = \"Hello, MyNewToken_1 is a new token. <|endoftext|>\"\n",
"\n",
"token_ids = base_tokenizer.encode(sample_text, allowed_special={\"<|endoftext|>\"})\n",
"print(token_ids)"
]
},
{
"cell_type": "markdown",
"id": "5b09b19b-772d-4449-971b-8ab052ee726d",
"metadata": {},
"source": [
"- Iterating over each token ID can give us a better understanding of how the token IDs are decoded via the vocabulary:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "21fd634b-bb4c-4ba3-8b69-9322b727bf58",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"15496 -> Hello\n",
"11 -> ,\n",
"2011 -> My\n",
"3791 -> New\n",
"30642 -> Token\n",
"62 -> _\n",
"16 -> 1\n",
"318 -> is\n",
"257 -> a\n",
"649 -> new\n",
"11241 -> token\n",
"13 -> .\n",
"220 -> \n",
"50256 -> <|endoftext|>\n"
]
}
],
"source": [
"for token_id in token_ids:\n",
" print(f\"{token_id} -> {base_tokenizer.decode([token_id])}\")"
]
},
{
"cell_type": "markdown",
"id": "fd5b1b9b-b1a9-489e-9711-c15a8e081813",
"metadata": {},
"source": [
"- As we can see above, the `\"MyNewToken_1\"` is broken down into 5 individual subword tokens -- this is normal behavior for BPE when handling unknown words\n",
"- However, suppose that it's a special token that we want to encode as a single token, similar to some of the other words or `\"<|endoftext|>\"`; this notebook explains how"
]
},
{
"cell_type": "markdown",
"id": "65f62ab6-df96-4f88-ab9a-37702cd30f5f",
"metadata": {},
"source": [
"&nbsp;\n",
"## 1. Adding special tokens"
]
},
{
"cell_type": "markdown",
"id": "c4379fdb-57ba-4a75-9183-0aee0836c391",
"metadata": {},
"source": [
"- Note that we have to add new tokens as special tokens; the reason is that we don't have the \"merges\" for the new tokens that are created during the tokenizer training process -- even if we had them, it would be very challenging to incorporate them without breaking the existing tokenization scheme (see the BPE from scratch notebook [link] to understand the \"merges\")\n",
"- Suppose we want to add 2 new tokens:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "265f1bba-c478-497d-b7fc-f4bd191b7d55",
"metadata": {},
"outputs": [],
"source": [
"# Define custom tokens and their token IDs\n",
"custom_tokens = [\"MyNewToken_1\", \"MyNewToken_2\"]\n",
"custom_token_ids = {\n",
" token: base_tokenizer.n_vocab + i for i, token in enumerate(custom_tokens)\n",
"}"
]
},
{
"cell_type": "markdown",
"id": "1c6f3d98-1ab6-43cf-9ae2-2bf53860f99e",
"metadata": {},
"source": [
"- Next, we create a custom `Encoding` object that holds our special tokens as follows:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "1f519852-59ea-4069-a8c7-0f647bfaea09",
"metadata": {},
"outputs": [],
"source": [
"# Create a new Encoding object with extended tokens\n",
"extended_tokenizer = tiktoken.Encoding(\n",
" name=\"gpt2_custom\",\n",
" pat_str=base_tokenizer._pat_str,\n",
" mergeable_ranks=base_tokenizer._mergeable_ranks,\n",
" special_tokens={**base_tokenizer._special_tokens, **custom_token_ids},\n",
")"
]
},
{
"cell_type": "markdown",
"id": "90af6cfa-e0cc-4c80-89dc-3a824e7bdeb2",
"metadata": {},
"source": [
"- That's it, we can now check that it can encode the sample text:"
]
},
{
"cell_type": "markdown",
"id": "153e8e1d-c4cb-41ff-9c55-1701e9bcae1c",
"metadata": {},
"source": [
"- As we can see, the new tokens `50257` and `50258` are now encoded in the output:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "eccc78a4-1fd4-47ba-a114-83ee0a3aec31",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[36674, 2420, 351, 220, 50257, 290, 220, 50258, 13, 220, 50256]\n"
]
}
],
"source": [
"special_tokens_set = set(custom_tokens) | {\"<|endoftext|>\"}\n",
"\n",
"token_ids = extended_tokenizer.encode(\n",
" \"Sample text with MyNewToken_1 and MyNewToken_2. <|endoftext|>\",\n",
" allowed_special=special_tokens_set\n",
")\n",
"print(token_ids)"
]
},
{
"cell_type": "markdown",
"id": "dc0547c1-bbb5-4915-8cf4-caaebcf922eb",
"metadata": {},
"source": [
"- Again, we can also look at it on a per-token level:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "7583eff9-b10d-4e3d-802c-f0464e1ef030",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"36674 -> Sample\n",
"2420 -> text\n",
"351 -> with\n",
"220 -> \n",
"50257 -> MyNewToken_1\n",
"290 -> and\n",
"220 -> \n",
"50258 -> MyNewToken_2\n",
"13 -> .\n",
"220 -> \n",
"50256 -> <|endoftext|>\n"
]
}
],
"source": [
"for token_id in token_ids:\n",
" print(f\"{token_id} -> {extended_tokenizer.decode([token_id])}\")"
]
},
{
"cell_type": "markdown",
"id": "17f0764e-e5a9-4226-a384-18c11bd5fec3",
"metadata": {},
"source": [
"- As we can see above, we have successfully updated the tokenizer\n",
"- However, to use it with a pretrained LLM, we also have to update the embedding and output layers of the LLM, which is discussed in the next section"
]
},
{
"cell_type": "markdown",
"id": "8ec7f98d-8f09-4386-83f0-9bec68ef7f66",
"metadata": {},
"source": [
"&nbsp;\n",
"## 2. Updating a pretrained LLM"
]
},
{
"cell_type": "markdown",
"id": "b8a4f68b-04e9-4524-8df4-8718c7b566f2",
"metadata": {},
"source": [
"- In this section, we will take a look at how we have to update an existing pretrained LLM after updating the tokenizer\n",
"- For this, we are using the original pretrained GPT-2 model that is used in the main book"
]
},
{
"cell_type": "markdown",
"id": "1a9b252e-1d1d-4ddf-b9f3-95bd6ba505a9",
"metadata": {},
"source": [
"&nbsp;\n",
"### 2.1 Loading a pretrained GPT model"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "ded29b4e-9b39-4191-b61c-29d6b2360bae",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"checkpoint: 100%|███████████████████████████| 77.0/77.0 [00:00<00:00, 34.4kiB/s]\n",
"encoder.json: 100%|███████████████████████| 1.04M/1.04M [00:00<00:00, 4.78MiB/s]\n",
"hparams.json: 100%|█████████████████████████| 90.0/90.0 [00:00<00:00, 24.7kiB/s]\n",
"model.ckpt.data-00000-of-00001: 100%|███████| 498M/498M [00:33<00:00, 14.7MiB/s]\n",
"model.ckpt.index: 100%|███████████████████| 5.21k/5.21k [00:00<00:00, 1.05MiB/s]\n",
"model.ckpt.meta: 100%|██████████████████████| 471k/471k [00:00<00:00, 2.33MiB/s]\n",
"vocab.bpe: 100%|████████████████████████████| 456k/456k [00:00<00:00, 2.45MiB/s]\n"
]
}
],
"source": [
"# Relative import from the gpt_download.py contained in this folder\n",
"from gpt_download import download_and_load_gpt2\n",
"\n",
"settings, params = download_and_load_gpt2(model_size=\"124M\", models_dir=\"gpt2\")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "93dc0d8e-b549-415b-840e-a00023bddcf9",
"metadata": {},
"outputs": [],
"source": [
"# Relative import from the gpt_download.py contained in this folder\n",
"from previous_chapters import GPTModel\n",
"\n",
"GPT_CONFIG_124M = {\n",
" \"vocab_size\": 50257, # Vocabulary size\n",
" \"context_length\": 256, # Shortened context length (orig: 1024)\n",
" \"emb_dim\": 768, # Embedding dimension\n",
" \"n_heads\": 12, # Number of attention heads\n",
" \"n_layers\": 12, # Number of layers\n",
" \"drop_rate\": 0.1, # Dropout rate\n",
" \"qkv_bias\": False # Query-key-value bias\n",
"}\n",
"\n",
"# Define model configurations in a dictionary for compactness\n",
"model_configs = {\n",
" \"gpt2-small (124M)\": {\"emb_dim\": 768, \"n_layers\": 12, \"n_heads\": 12},\n",
" \"gpt2-medium (355M)\": {\"emb_dim\": 1024, \"n_layers\": 24, \"n_heads\": 16},\n",
" \"gpt2-large (774M)\": {\"emb_dim\": 1280, \"n_layers\": 36, \"n_heads\": 20},\n",
" \"gpt2-xl (1558M)\": {\"emb_dim\": 1600, \"n_layers\": 48, \"n_heads\": 25},\n",
"}\n",
"\n",
"# Copy the base configuration and update with specific model settings\n",
"model_name = \"gpt2-small (124M)\" # Example model name\n",
"NEW_CONFIG = GPT_CONFIG_124M.copy()\n",
"NEW_CONFIG.update(model_configs[model_name])\n",
"NEW_CONFIG.update({\"context_length\": 1024, \"qkv_bias\": True})\n",
"\n",
"gpt = GPTModel(NEW_CONFIG)\n",
"gpt.eval();"
]
},
{
"cell_type": "markdown",
"id": "83f898c0-18f4-49ce-9b1f-3203a277b29e",
"metadata": {},
"source": [
"### 2.2 Using the pretrained GPT model"
]
},
{
"cell_type": "markdown",
"id": "a5a1f5e1-e806-4c60-abaa-42ae8564908c",
"metadata": {},
"source": [
"- Next, consider our sample text below, which we tokenize using the original and the new tokenizer:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "9a88017d-cc8f-4ba1-bba9-38161a30f673",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"sample_text = \"Sample text with MyNewToken_1 and MyNewToken_2. <|endoftext|>\"\n",
"\n",
"original_token_ids = base_tokenizer.encode(\n",
" sample_text, allowed_special={\"<|endoftext|>\"}\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "1ee01bc3-ca24-497b-b540-3d13c52c29ed",
"metadata": {},
"outputs": [],
"source": [
"new_token_ids = extended_tokenizer.encode(\n",
" \"Sample text with MyNewToken_1 and MyNewToken_2. <|endoftext|>\",\n",
" allowed_special=special_tokens_set\n",
")"
]
},
{
"cell_type": "markdown",
"id": "1143106b-68fe-4234-98ad-eaff420a4d08",
"metadata": {},
"source": [
"- Now, let's feed the original token IDs to the GPT model:"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "6b06827f-b411-42cc-b978-5c1d568a3200",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[[ 0.2204, 0.8901, 1.0138, ..., 0.2585, -0.9192, -0.2298],\n",
" [ 0.6745, -0.0726, 0.8218, ..., -0.1768, -0.4217, 0.0703],\n",
" [-0.2009, 0.0814, 0.2417, ..., 0.3166, 0.3629, 1.3400],\n",
" ...,\n",
" [ 0.1137, -0.1258, 2.0193, ..., -0.0314, -0.4288, -0.1487],\n",
" [-1.1983, -0.2050, -0.1337, ..., -0.0849, -0.4863, -0.1076],\n",
" [-1.0675, -0.5905, 0.2873, ..., -0.0979, -0.8713, 0.8415]]])\n"
]
}
],
"source": [
"import torch\n",
"\n",
"with torch.no_grad():\n",
" out = gpt(torch.tensor([original_token_ids]))\n",
"\n",
"print(out)"
]
},
{
"cell_type": "markdown",
"id": "082c7a78-35a8-473e-a08d-b099a6348a74",
"metadata": {},
"source": [
"- As we can see above, this works without problems (note that the code shows the raw output without converting the outputs back into text for simplicity; for more details on that, please check out the `generate` function in Chapter 5 [link] section 5.3.3"
]
},
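{
"cell_type": "markdown",
"id": "a3f1c2d4-5e6f-4a7b-8c9d-0e1f2a3b4c5d",
"metadata": {},
"source": [
"- As a minimal (optional) decoding sketch, the cell below greedily picks the most likely next token from the last position of the raw output above and decodes it with the base tokenizer; it is a simplified stand-in for the `generate` function, not a replacement for it"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b4e2d3c5-6f7a-4b8c-9d0e-1f2a3b4c5d6e",
"metadata": {},
"outputs": [],
"source": [
"# Simplified greedy decoding sketch (see the `generate` function in Chapter 5\n",
"# for the full text-generation loop)\n",
"next_token_id = torch.argmax(out[:, -1, :], dim=-1)  # most likely next token\n",
"print(base_tokenizer.decode(next_token_id.tolist()))"
]
},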
{
"cell_type": "markdown",
"id": "628265b5-3dde-44e7-bde2-8fc594a2547d",
"metadata": {},
"source": [
"- What happens if we try the same on the token IDs generated by the updated tokenizer now?"
]
},
{
"cell_type": "markdown",
"id": "9796ad09-787c-4c25-a7f5-6d1dfe048ac3",
"metadata": {},
"source": [
"```python\n",
"with torch.no_grad():\n",
" gpt(torch.tensor([new_token_ids]))\n",
"\n",
"print(out)\n",
"\n",
"...\n",
"# IndexError: index out of range in self\n",
"```"
]
},
{
"cell_type": "markdown",
"id": "77d00244-7e40-4de0-942e-e15cdd8e3b18",
"metadata": {},
"source": [
"- As we can see, this results in an index error\n",
"- The reason is that the GPT model expects a fixed vocabulary size via its input embedding layer and its output layer:\n",
"\n",
"<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/extend-tiktoken/gpt-updates.webp\" width=\"400px\">"
]
},
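{
"cell_type": "markdown",
"id": "c5d3e4f6-7a8b-4c9d-0e1f-2a3b4c5d6e7f",
"metadata": {},
"source": [
"- Concretely, the largest token ID produced by the extended tokenizer (`50258`) exceeds the highest valid row index of the model's embedding layer, which only has `50257` rows; the optional check below verifies this"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d6e4f5a7-8b9c-4d0e-1f2a-3b4c5d6e7f8a",
"metadata": {},
"outputs": [],
"source": [
"# Optional check: the new token IDs exceed the model's original vocabulary size\n",
"print(\"Largest token ID:\", max(new_token_ids))\n",
"print(\"Embedding rows:  \", gpt.tok_emb.weight.shape[0])"
]
},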
{
"cell_type": "markdown",
"id": "dec38b24-c845-4090-96a4-0d3c4ec241d6",
"metadata": {},
"source": [
"&nbsp;\n",
"### 2.3 Updating the embedding layer"
]
},
{
"cell_type": "markdown",
"id": "b1328726-8297-4162-878b-a5daff7de742",
"metadata": {},
"source": [
"- Let's start with updating the embedding layer\n",
"- First, notice that the embedding layer has 50,257 entries, which corresponds to the vocabulary size:"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "23ecab6e-1232-47c7-a318-042f90e1dff3",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Embedding(50257, 768)"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"gpt.tok_emb"
]
},
{
"cell_type": "markdown",
"id": "d760c683-d082-470a-bff8-5a08b30d3b61",
"metadata": {},
"source": [
"- We want to extend this embedding layer by adding 2 more entries\n",
"- In short, we create a new embedding layer with a bigger size, and then we copy over the old embedding layer values"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "4ec5c48e-c6fe-4e84-b290-04bd4da9483f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Embedding(50259, 768)\n"
]
}
],
"source": [
"num_tokens, emb_size = gpt.tok_emb.weight.shape\n",
"new_num_tokens = num_tokens + 2\n",
"\n",
"# Create a new embedding layer\n",
"new_embedding = torch.nn.Embedding(new_num_tokens, emb_size)\n",
"\n",
"# Copy weights from the old embedding layer\n",
"new_embedding.weight.data[:num_tokens] = gpt.tok_emb.weight.data\n",
"\n",
"# Replace the old embedding layer with the new one in the model\n",
"gpt.tok_emb = new_embedding\n",
"\n",
"print(gpt.tok_emb)"
]
},
{
"cell_type": "markdown",
"id": "63954928-31a5-4e7e-9688-2e0c156b7302",
"metadata": {},
"source": [
"- As we can see above, we now have an increased embedding layer"
]
},
{
"cell_type": "markdown",
"id": "6e68bea5-255b-47bb-b352-09ea9539bc25",
"metadata": {},
"source": [
"&nbsp;\n",
"### 2.4 Updating the output layer"
]
},
{
"cell_type": "markdown",
"id": "90a4a519-bf0f-4502-912d-ef0ac7a9deab",
"metadata": {},
"source": [
"- Next, we have to extend the output layer, which has 50,257 output features corresponding to the vocabulary size similar to the embedding layer (by the way, you may find the bonus material, which discusses the similarity between Linear and Embedding layers in PyTorch, useful)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "6105922f-d889-423e-bbcc-bc49156d78df",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Linear(in_features=768, out_features=50257, bias=False)"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"gpt.out_head"
]
},
{
"cell_type": "markdown",
"id": "29f1ff24-9c00-40f6-a94f-82d03aaf0890",
"metadata": {},
"source": [
"- The procedure for extending the output layer is similar to extending the embedding layer:"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "354589db-b148-4dae-8068-62132e3fb38e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Linear(in_features=768, out_features=50259, bias=True)\n"
]
}
],
"source": [
"original_out_features, original_in_features = gpt.out_head.weight.shape\n",
"\n",
"# Define the new number of output features (e.g., adding 2 new tokens)\n",
"new_out_features = original_out_features + 2\n",
"\n",
"# Create a new linear layer with the extended output size\n",
"new_linear = torch.nn.Linear(original_in_features, new_out_features)\n",
"\n",
"# Copy the weights and biases from the original linear layer\n",
"with torch.no_grad():\n",
" new_linear.weight[:original_out_features] = gpt.out_head.weight\n",
" if gpt.out_head.bias is not None:\n",
" new_linear.bias[:original_out_features] = gpt.out_head.bias\n",
"\n",
"# Replace the original linear layer with the new one\n",
"gpt.out_head = new_linear\n",
"\n",
"print(gpt.out_head)"
]
},
{
"cell_type": "markdown",
"id": "df5d2205-1fae-4a4f-a7bd-fa8fc37eeec2",
"metadata": {},
"source": [
"- Let's try this updated model on the original token IDs first:"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "df604bbc-6c13-4792-8ba8-ecb692117c25",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[[ 0.2267, 0.9132, 1.0494, ..., -0.2330, -0.3008, -1.1458],\n",
" [ 0.6808, -0.0495, 0.8574, ..., 0.0671, 0.5572, -0.7873],\n",
" [-0.1947, 0.1045, 0.2773, ..., 1.3368, 0.8479, -0.9660],\n",
" ...,\n",
" [ 0.1200, -0.1027, 2.0549, ..., -0.1519, -0.2096, 0.5651],\n",
" [-1.1920, -0.1819, -0.0981, ..., -0.1108, 0.8435, -0.3771],\n",
" [-1.0612, -0.5674, 0.3229, ..., 0.8383, -0.7121, -0.4850]]])\n"
]
}
],
"source": [
"with torch.no_grad():\n",
" output = gpt(torch.tensor([original_token_ids]))\n",
"print(output)"
]
},
{
"cell_type": "markdown",
"id": "3d80717e-50e6-4927-8129-0aadfa2628f5",
"metadata": {},
"source": [
"- Next, let's try it on the updated tokens:"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "75f11ec9-bdd2-440f-b8c8-6646b75891c6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[[ 0.2267, 0.9132, 1.0494, ..., -0.2330, -0.3008, -1.1458],\n",
" [ 0.6808, -0.0495, 0.8574, ..., 0.0671, 0.5572, -0.7873],\n",
" [-0.1947, 0.1045, 0.2773, ..., 1.3368, 0.8479, -0.9660],\n",
" ...,\n",
" [-0.0656, -1.2451, 0.7957, ..., -1.2124, 0.1044, 0.5088],\n",
" [-1.1561, -0.7380, -0.0645, ..., -0.4373, 1.1401, -0.3903],\n",
" [-0.8961, -0.6437, -0.1667, ..., 0.5663, -0.5862, -0.4020]]])\n"
]
}
],
"source": [
"with torch.no_grad():\n",
" output = gpt(torch.tensor([new_token_ids]))\n",
"print(output)"
]
},
{
"cell_type": "markdown",
"id": "d88a1bba-db01-4090-97e4-25dfc23ed54c",
"metadata": {},
"source": [
"- As we can see, the model works on the extended token set\n",
"- In practice, we want to now finetune (or continually pretrain) the model (specifically the new embedding and output layers) on data containing the new tokens"
]
},
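{
"cell_type": "markdown",
"id": "e7f5a6b8-9c0d-4e1f-2a3b-4c5d6e7f8a9b",
"metadata": {},
"source": [
"- The cell below is a minimal sketch of one possible setup for such a finetuning run: it freezes all parameters except the extended embedding and output layers (the optimizer hyperparameters are placeholder values); more fine-grained schemes that only update the two newly added rows are also possible"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f8a6b7c9-0d1e-4f2a-3b4c-5d6e7f8a9b0c",
"metadata": {},
"outputs": [],
"source": [
"# Minimal sketch: freeze everything except the (extended) embedding and output layers\n",
"for param in gpt.parameters():\n",
"    param.requires_grad = False\n",
"\n",
"for param in gpt.tok_emb.parameters():\n",
"    param.requires_grad = True\n",
"\n",
"for param in gpt.out_head.parameters():\n",
"    param.requires_grad = True\n",
"\n",
"# Pass only the trainable parameters to the optimizer (placeholder hyperparameters)\n",
"optimizer = torch.optim.AdamW(\n",
"    (p for p in gpt.parameters() if p.requires_grad), lr=5e-5, weight_decay=0.1\n",
")"
]
},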
{
"cell_type": "markdown",
"id": "6de573ad-0338-40d9-9dad-de60ae349c4f",
"metadata": {},
"source": [
"**A note about weight tying**\n",
"\n",
"- If the model uses weight tying, which means that the embedding layer and output layer share the same weights, similar to Llama 3 [link], updating the output layer is much simpler\n",
"- In this case, we can simply copy over the weights from the embedding layer:"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "4cbc5f51-c7a8-49d0-b87f-d3d87510953b",
"metadata": {},
"outputs": [],
"source": [
"gpt.out_head.weight = gpt.tok_emb.weight"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "d0d553a8-edff-40f0-bdc4-dff900e16caf",
"metadata": {},
"outputs": [],
"source": [
"with torch.no_grad():\n",
" output = gpt(torch.tensor([new_token_ids]))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
}
},
"nbformat": 4,
"nbformat_minor": 5
}


@ -0,0 +1,142 @@
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
import os
import urllib.request
# import requests
import json
import numpy as np
import tensorflow as tf
from tqdm import tqdm
def download_and_load_gpt2(model_size, models_dir):
# Validate model size
allowed_sizes = ("124M", "355M", "774M", "1558M")
if model_size not in allowed_sizes:
raise ValueError(f"Model size not in {allowed_sizes}")
# Define paths
model_dir = os.path.join(models_dir, model_size)
base_url = "https://openaipublic.blob.core.windows.net/gpt-2/models"
filenames = [
"checkpoint", "encoder.json", "hparams.json",
"model.ckpt.data-00000-of-00001", "model.ckpt.index",
"model.ckpt.meta", "vocab.bpe"
]
# Download files
os.makedirs(model_dir, exist_ok=True)
for filename in filenames:
file_url = os.path.join(base_url, model_size, filename)
file_path = os.path.join(model_dir, filename)
download_file(file_url, file_path)
# Load settings and params
tf_ckpt_path = tf.train.latest_checkpoint(model_dir)
settings = json.load(open(os.path.join(model_dir, "hparams.json")))
params = load_gpt2_params_from_tf_ckpt(tf_ckpt_path, settings)
return settings, params
def download_file(url, destination):
# Send a GET request to download the file
try:
with urllib.request.urlopen(url) as response:
# Get the total file size from headers, defaulting to 0 if not present
file_size = int(response.headers.get("Content-Length", 0))
# Check if file exists and has the same size
if os.path.exists(destination):
file_size_local = os.path.getsize(destination)
if file_size == file_size_local:
print(f"File already exists and is up-to-date: {destination}")
return
# Define the block size for reading the file
block_size = 1024 # 1 Kilobyte
# Initialize the progress bar with total file size
progress_bar_description = os.path.basename(url) # Extract filename from URL
with tqdm(total=file_size, unit="iB", unit_scale=True, desc=progress_bar_description) as progress_bar:
# Open the destination file in binary write mode
with open(destination, "wb") as file:
# Read the file in chunks and write to destination
while True:
chunk = response.read(block_size)
if not chunk:
break
file.write(chunk)
progress_bar.update(len(chunk)) # Update progress bar
except urllib.error.HTTPError:
s = (
f"The specified URL ({url}) is incorrect, the internet connection cannot be established,"
"\nor the requested file is temporarily unavailable.\nPlease visit the following website"
" for help: https://github.com/rasbt/LLMs-from-scratch/discussions/273")
print(s)
# Alternative way using `requests`
"""
def download_file(url, destination):
# Send a GET request to download the file in streaming mode
response = requests.get(url, stream=True)
# Get the total file size from headers, defaulting to 0 if not present
file_size = int(response.headers.get("content-length", 0))
# Check if file exists and has the same size
if os.path.exists(destination):
file_size_local = os.path.getsize(destination)
if file_size == file_size_local:
print(f"File already exists and is up-to-date: {destination}")
return
# Define the block size for reading the file
block_size = 1024 # 1 Kilobyte
# Initialize the progress bar with total file size
progress_bar_description = url.split("/")[-1] # Extract filename from URL
with tqdm(total=file_size, unit="iB", unit_scale=True, desc=progress_bar_description) as progress_bar:
# Open the destination file in binary write mode
with open(destination, "wb") as file:
# Iterate over the file data in chunks
for chunk in response.iter_content(block_size):
progress_bar.update(len(chunk)) # Update progress bar
file.write(chunk) # Write the chunk to the file
"""
def load_gpt2_params_from_tf_ckpt(ckpt_path, settings):
# Initialize parameters dictionary with empty blocks for each layer
params = {"blocks": [{} for _ in range(settings["n_layer"])]}
# Iterate over each variable in the checkpoint
for name, _ in tf.train.list_variables(ckpt_path):
# Load the variable and remove singleton dimensions
variable_array = np.squeeze(tf.train.load_variable(ckpt_path, name))
# Process the variable name to extract relevant parts
variable_name_parts = name.split("/")[1:] # Skip the 'model/' prefix
# Identify the target dictionary for the variable
target_dict = params
if variable_name_parts[0].startswith("h"):
layer_number = int(variable_name_parts[0][1:])
target_dict = params["blocks"][layer_number]
# Recursively access or create nested dictionaries
for key in variable_name_parts[1:-1]:
target_dict = target_dict.setdefault(key, {})
# Assign the variable array to the last key
last_key = variable_name_parts[-1]
target_dict[last_key] = variable_array
return params


@ -0,0 +1,279 @@
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
#
# This file collects all the relevant code that we covered thus far
# throughout Chapters 2-4.
# This file can be run as a standalone script.
import tiktoken
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
#####################################
# Chapter 2
#####################################
class GPTDatasetV1(Dataset):
def __init__(self, txt, tokenizer, max_length, stride):
self.input_ids = []
self.target_ids = []
# Tokenize the entire text
token_ids = tokenizer.encode(txt, allowed_special={"<|endoftext|>"})
# Use a sliding window to chunk the book into overlapping sequences of max_length
for i in range(0, len(token_ids) - max_length, stride):
input_chunk = token_ids[i:i + max_length]
target_chunk = token_ids[i + 1: i + max_length + 1]
self.input_ids.append(torch.tensor(input_chunk))
self.target_ids.append(torch.tensor(target_chunk))
def __len__(self):
return len(self.input_ids)
def __getitem__(self, idx):
return self.input_ids[idx], self.target_ids[idx]
def create_dataloader_v1(txt, batch_size=4, max_length=256,
stride=128, shuffle=True, drop_last=True, num_workers=0):
# Initialize the tokenizer
tokenizer = tiktoken.get_encoding("gpt2")
# Create dataset
dataset = GPTDatasetV1(txt, tokenizer, max_length, stride)
# Create dataloader
dataloader = DataLoader(
dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, num_workers=num_workers)
return dataloader
#####################################
# Chapter 3
#####################################
class MultiHeadAttention(nn.Module):
def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
super().__init__()
assert d_out % num_heads == 0, "d_out must be divisible by n_heads"
self.d_out = d_out
self.num_heads = num_heads
self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim
self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs
self.dropout = nn.Dropout(dropout)
self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))
def forward(self, x):
b, num_tokens, d_in = x.shape
keys = self.W_key(x) # Shape: (b, num_tokens, d_out)
queries = self.W_query(x)
values = self.W_value(x)
# We implicitly split the matrix by adding a `num_heads` dimension
# Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
values = values.view(b, num_tokens, self.num_heads, self.head_dim)
queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
# Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
keys = keys.transpose(1, 2)
queries = queries.transpose(1, 2)
values = values.transpose(1, 2)
# Compute scaled dot-product attention (aka self-attention) with a causal mask
attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head
# Original mask truncated to the number of tokens and converted to boolean
mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
# Use the mask to fill attention scores
attn_scores.masked_fill_(mask_bool, -torch.inf)
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
attn_weights = self.dropout(attn_weights)
# Shape: (b, num_tokens, num_heads, head_dim)
context_vec = (attn_weights @ values).transpose(1, 2)
# Combine heads, where self.d_out = self.num_heads * self.head_dim
context_vec = context_vec.reshape(b, num_tokens, self.d_out)
context_vec = self.out_proj(context_vec) # optional projection
return context_vec
#####################################
# Chapter 4
#####################################
class LayerNorm(nn.Module):
def __init__(self, emb_dim):
super().__init__()
self.eps = 1e-5
self.scale = nn.Parameter(torch.ones(emb_dim))
self.shift = nn.Parameter(torch.zeros(emb_dim))
def forward(self, x):
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, keepdim=True, unbiased=False)
norm_x = (x - mean) / torch.sqrt(var + self.eps)
return self.scale * norm_x + self.shift
class GELU(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return 0.5 * x * (1 + torch.tanh(
torch.sqrt(torch.tensor(2.0 / torch.pi)) *
(x + 0.044715 * torch.pow(x, 3))
))
class FeedForward(nn.Module):
def __init__(self, cfg):
super().__init__()
self.layers = nn.Sequential(
nn.Linear(cfg["emb_dim"], 4 * cfg["emb_dim"]),
GELU(),
nn.Linear(4 * cfg["emb_dim"], cfg["emb_dim"]),
)
def forward(self, x):
return self.layers(x)
class TransformerBlock(nn.Module):
def __init__(self, cfg):
super().__init__()
self.att = MultiHeadAttention(
d_in=cfg["emb_dim"],
d_out=cfg["emb_dim"],
context_length=cfg["context_length"],
num_heads=cfg["n_heads"],
dropout=cfg["drop_rate"],
qkv_bias=cfg["qkv_bias"])
self.ff = FeedForward(cfg)
self.norm1 = LayerNorm(cfg["emb_dim"])
self.norm2 = LayerNorm(cfg["emb_dim"])
self.drop_shortcut = nn.Dropout(cfg["drop_rate"])
def forward(self, x):
# Shortcut connection for attention block
shortcut = x
x = self.norm1(x)
x = self.att(x) # Shape [batch_size, num_tokens, emb_size]
x = self.drop_shortcut(x)
x = x + shortcut # Add the original input back
# Shortcut connection for feed-forward block
shortcut = x
x = self.norm2(x)
x = self.ff(x)
x = self.drop_shortcut(x)
x = x + shortcut # Add the original input back
return x
class GPTModel(nn.Module):
def __init__(self, cfg):
super().__init__()
self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
self.drop_emb = nn.Dropout(cfg["drop_rate"])
self.trf_blocks = nn.Sequential(
*[TransformerBlock(cfg) for _ in range(cfg["n_layers"])])
self.final_norm = LayerNorm(cfg["emb_dim"])
self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)
def forward(self, in_idx):
batch_size, seq_len = in_idx.shape
tok_embeds = self.tok_emb(in_idx)
pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))
x = tok_embeds + pos_embeds # Shape [batch_size, num_tokens, emb_size]
x = self.drop_emb(x)
x = self.trf_blocks(x)
x = self.final_norm(x)
logits = self.out_head(x)
return logits
def generate_text_simple(model, idx, max_new_tokens, context_size):
# idx is (B, T) array of indices in the current context
for _ in range(max_new_tokens):
# Crop current context if it exceeds the supported context size
# E.g., if LLM supports only 5 tokens, and the context size is 10
# then only the last 5 tokens are used as context
idx_cond = idx[:, -context_size:]
# Get the predictions
with torch.no_grad():
logits = model(idx_cond)
# Focus only on the last time step
# (batch, n_token, vocab_size) becomes (batch, vocab_size)
logits = logits[:, -1, :]
# Get the idx of the vocab entry with the highest logits value
idx_next = torch.argmax(logits, dim=-1, keepdim=True) # (batch, 1)
# Append sampled index to the running sequence
idx = torch.cat((idx, idx_next), dim=1) # (batch, n_tokens+1)
return idx
if __name__ == "__main__":
GPT_CONFIG_124M = {
"vocab_size": 50257, # Vocabulary size
"context_length": 1024, # Context length
"emb_dim": 768, # Embedding dimension
"n_heads": 12, # Number of attention heads
"n_layers": 12, # Number of layers
"drop_rate": 0.1, # Dropout rate
"qkv_bias": False # Query-Key-Value bias
}
torch.manual_seed(123)
model = GPTModel(GPT_CONFIG_124M)
model.eval() # disable dropout
start_context = "Hello, I am"
tokenizer = tiktoken.get_encoding("gpt2")
encoded = tokenizer.encode(start_context)
encoded_tensor = torch.tensor(encoded).unsqueeze(0)
print(f"\n{50*'='}\n{22*' '}IN\n{50*'='}")
print("\nInput text:", start_context)
print("Encoded input text:", encoded)
print("encoded_tensor.shape:", encoded_tensor.shape)
out = generate_text_simple(
model=model,
idx=encoded_tensor,
max_new_tokens=10,
context_size=GPT_CONFIG_124M["context_length"]
)
decoded_text = tokenizer.decode(out.squeeze(0).tolist())
print(f"\n\n{50*'='}\n{22*' '}OUT\n{50*'='}")
print("\nOutput:", out)
print("Output length:", len(out[0]))
print("Output text:", decoded_text)


@ -309,7 +309,30 @@
"Average score: 48.87\n",
"```\n",
"\n",
"The score is close to 50, which is in the same ballpark as the score we previously achieved with the Alpaca-style prompts."
"The score is close to 50, which is in the same ballpark as the score we previously achieved with the Alpaca-style prompts.\n",
"\n",
"There is no inherent advantage or rationale why the Phi prompt-style should be better, but it can be more concise and efficient, except for the caveat mentioned in the *Tip* section below."
]
},
{
"cell_type": "markdown",
"id": "156bc574-3f3e-4479-8f58-c8c8c472416e",
"metadata": {},
"source": [
"#### Tip: Considering special tokens"
]
},
{
"cell_type": "markdown",
"id": "65cacf90-21c2-48f2-8f21-5c0c86749ff2",
"metadata": {},
"source": [
"- Note that the Phi-3 prompt template contains special tokens such as `<|user|>` and `<|assistant|>`, which can be suboptimal for the GPT-2 tokenizer\n",
"- While the GPT-2 tokenizer recognizes `<|endoftext|>` as a special token (encoded into token ID 50256), it is inefficient at handling other special tokens, such as the aforementioned ones\n",
"- For instance, `<|user|>` is encoded into 5 individual token IDs (27, 91, 7220, 91, 29), which is very inefficient\n",
"- We could add `<|user|>` as a new special token in `tiktoken` via the `allowed_special` argument, but please keep in mind that the GPT-2 vocabulary would not be able to handle it without additional modification\n",
"- If you are curious about how a tokenizer and LLM can be extended to handle special tokens, please see the [extend-tiktoken.ipynb](../../ch05/09_extending-tokenizers/extend-tiktoken.ipynb) bonus materials (note that this is not required here but is just an interesting/bonus consideration for curious readers)\n",
"- Furthermore, we can hypothesize that models that support these special tokens of a prompt template via their vocabulary may perform more efficiently and better overall"
]
},
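{
"cell_type": "markdown",
"id": "a9b7c8d0-1e2f-4a3b-4c5d-6e7f8a9b0c1d",
"metadata": {},
"source": [
"- The optional cell below illustrates this (it is self-contained and assumes only that `tiktoken` is installed): the GPT-2 tokenizer splits `<|user|>` into the 5 token IDs mentioned above, whereas `<|endoftext|>` maps to a single ID"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b0c8d9e1-2f3a-4b4c-5d6e-7f8a9b0c1d2e",
"metadata": {},
"outputs": [],
"source": [
"import tiktoken\n",
"\n",
"gpt2_tokenizer = tiktoken.get_encoding(\"gpt2\")\n",
"print(gpt2_tokenizer.encode(\"<|user|>\"))  # [27, 91, 7220, 91, 29]\n",
"print(gpt2_tokenizer.encode(\"<|endoftext|>\", allowed_special={\"<|endoftext|>\"}))  # [50256]"
]
},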
{
@ -994,7 +1017,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
"version": "3.11.4"
}
},
"nbformat": 4,