{
"cells": [
{
"cell_type": "markdown",
"id": "ce9295b2-182b-490b-8325-83a67c4a001d",
"metadata": {},
"source": [
"# Chapter 4: Implementing a GPT model from Scratch To Generate Text \n",
"\n",
"## (Notes are in progress ...)"
]
},
{
"cell_type": "markdown",
"id": "e7da97ed-e02f-4d7f-b68e-a0eba3716e02",
"metadata": {},
"source": [
"- In this chapter, we implement the architecture of a GPT-like LLM; in the next chapter, we will train this LLM"
]
},
{
"cell_type": "markdown",
"id": "53fe99ab-0bcf-4778-a6b5-6db81fb826ef",
"metadata": {},
"source": [
"## 4.1 Coding the decoder"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "5ed66875-1f24-445d-add6-006aae3c5707",
"metadata": {},
"outputs": [],
"source": [
"GPT_CONFIG_124M = {\n",
" \"vocab_size\": 50257, # Vocabulary size\n",
" \"ctx_len\": 1024, # Context length\n",
" \"emb_dim\": 768, # Embedding dimension\n",
" \"n_heads\": 12, # Number of attention heads\n",
" \"n_layers\": 12, # Number of layers\n",
" \"drop_rate\": 0.1, # Dropout rate\n",
" \"qkv_bias\": True # Query-Key-Value bias\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "619c2eed-f8ea-4ff5-92c3-feda0f29b227",
"metadata": {},
"outputs": [],
"source": [
"import torch.nn as nn\n",
"\n",
"\n",
"class DummyGPTModel(nn.Module):\n",
" def __init__(self, cfg):\n",
" super().__init__()\n",
" self.tok_emb = nn.Embedding(cfg[\"vocab_size\"], cfg[\"emb_dim\"])\n",
" self.pos_emb = nn.Embedding(cfg[\"ctx_len\"], cfg[\"emb_dim\"])\n",
" self.drop_emb = nn.Dropout(cfg[\"drop_rate\"])\n",
" \n",
" # Use a placeholder for TransformerBlock\n",
" self.trf_blocks = nn.Sequential(\n",
" *[DummyTransformerBlock(cfg) for _ in range(cfg[\"n_layers\"])])\n",
" \n",
" # Use a placeholder for LayerNorm\n",
" self.final_norm = DummyLayerNorm(cfg[\"emb_dim\"])\n",
" self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False)\n",
"\n",
" def forward(self, in_idx):\n",
" batch_size, seq_len = in_idx.shape\n",
" tok_embeds = self.tok_emb(in_idx)\n",
" pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))\n",
" x = tok_embeds + pos_embeds\n",
" x = self.drop_emb(x)\n",
" x = self.trf_blocks(x)\n",
" x = self.final_norm(x)\n",
" logits = self.out_head(x)\n",
" return logits\n",
"\n",
"\n",
"class DummyTransformerBlock(nn.Module):\n",
" def __init__(self, cfg):\n",
" super().__init__()\n",
" # A simple placeholder\n",
"\n",
" def forward(self, x):\n",
" # This block does nothing and just returns its input.\n",
" return x\n",
"\n",
"\n",
"class DummyLayerNorm(nn.Module):\n",
" def __init__(self, normalized_shape, eps=1e-5):\n",
" super().__init__()\n",
" # The parameters here are just to mimic the LayerNorm interface.\n",
"\n",
" def forward(self, x):\n",
" # This layer does nothing and just returns its input.\n",
" return x"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "794b6b6c-d36f-411e-a7db-8ac566a87fee",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ 6109, 3626, 6100, 345, 2651, 13],\n",
" [ 6109, 1110, 6622, 257, 11483, 13]])"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import tiktoken\n",
"import torch\n",
"\n",
"tokenizer = tiktoken.get_encoding(\"gpt2\")\n",
"\n",
"batch = []\n",
"\n",
"txt1 = \"Every effort moves you forward.\"\n",
"txt2 = \"Every day holds a lesson.\"\n",
"\n",
"batch.append(torch.tensor(tokenizer.encode(txt1)))\n",
"batch.append(torch.tensor(tokenizer.encode(txt2)))\n",
"batch = torch.stack(batch, dim=0)\n",
"batch"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "009238cd-0160-4834-979c-309710986bb0",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Output shape: torch.Size([2, 6, 50257])\n",
"tensor([[[-1.2034, 0.3201, -0.7130, ..., -1.5548, -0.2390, -0.4667],\n",
" [-0.1192, 0.4539, -0.4432, ..., 0.2392, 1.3469, 1.2430],\n",
" [ 0.5307, 1.6720, -0.4695, ..., 1.1966, 0.0111, 0.5835],\n",
" [ 0.0139, 1.6755, -0.3388, ..., 1.1586, -0.0435, -1.0400],\n",
" [ 0.0106, -1.6711, 0.7797, ..., 0.3561, -0.0867, -0.5452],\n",
" [ 0.1821, 1.1189, 0.1641, ..., 1.9012, 1.2240, 0.8853]],\n",
"\n",
" [[-1.0341, 0.2765, -1.1252, ..., -0.8381, 0.0773, 0.1147],\n",
" [-0.2632, 0.5427, -0.2828, ..., 0.1357, 0.3707, 1.3615],\n",
" [ 0.9695, 1.2466, -0.3515, ..., -0.0171, -0.3478, 0.2616],\n",
" [-0.0237, -0.7329, 0.3184, ..., 1.5946, -0.1334, -0.2981],\n",
" [-0.1876, -0.7909, 0.8811, ..., 1.1121, -0.3781, -1.4438],\n",
" [ 0.0405, 1.2000, 0.0702, ..., 1.4740, 1.1567, 1.2077]]],\n",
" grad_fn=<UnsafeViewBackward0>)\n"
]
}
],
"source": [
"torch.manual_seed(123)\n",
"model = DummyGPTModel(GPT_CONFIG_124M)\n",
"\n",
"out = model(batch)\n",
"print(\"Output shape:\", out.shape)\n",
"print(out)"
]
},
{
"cell_type": "markdown",
"id": "62598daa-f819-40da-95ca-899988b6f8da",
"metadata": {},
"source": [
"## 4.2 Normalizing activations with LayerNorm"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "3333a305-aa3d-460a-bcce-b80662d464d9",
"metadata": {},
"outputs": [],
"source": [
"class LayerNorm(nn.Module):\n",
" def __init__(self, emb_dim):\n",
" super().__init__()\n",
" self.eps = 1e-5\n",
" self.scale = nn.Parameter(torch.ones(emb_dim))\n",
" self.shift = nn.Parameter(torch.zeros(emb_dim))\n",
"\n",
" def forward(self, x):\n",
" mean = x.mean(-1, keepdim=True)\n",
" var = x.var(-1, keepdim=True, unbiased=False)\n",
" norm_x = (x - mean) / torch.sqrt(var + self.eps)\n",
" return self.scale * norm_x + self.shift"
]
},
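{
"cell_type": "markdown",
"id": "f1a2b3c4-7d8e-4f90-a1b2-c3d4e5f6a7b8",
"metadata": {},
"source": [
"- A quick sanity check (an illustrative addition; the tensor shape below is arbitrary): after applying `LayerNorm`, each token vector should have a mean of ~0 and a variance of ~1 along the embedding dimension"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e2b3c4d5-8e9f-4a01-b2c3-d4e5f6a7b8c9",
"metadata": {},
"outputs": [],
"source": [
"# Sanity check: after LayerNorm, each token vector has ~0 mean and ~1 variance\n",
"torch.manual_seed(123)\n",
"example = torch.randn(2, 5, 768) # arbitrary (batch, tokens, emb_dim) shape\n",
"ln = LayerNorm(emb_dim=768)\n",
"out_ln = ln(example)\n",
"print(\"Mean:\", out_ln.mean(dim=-1)[0, :3])\n",
"print(\"Variance:\", out_ln.var(dim=-1, unbiased=False)[0, :3])"
]
},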
{
"cell_type": "markdown",
"id": "fd9d772b-c833-4a5c-9d58-9b208d2a0b68",
"metadata": {},
"source": [
"## 4.3 Adding GeLU activation functions"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "9275c879-b148-4579-a107-86827ca14d4d",
"metadata": {},
"outputs": [],
"source": [
"class GELU(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
"\n",
" def forward(self, x):\n",
" return 0.5 * x * (1 + torch.tanh(torch.sqrt(torch.tensor(2 / torch.pi)) *\n",
" (x + 0.044715 * x ** 3)))\n",
"\n",
"\n",
"class FeedForward(nn.Module):\n",
" def __init__(self, cfg):\n",
" super().__init__()\n",
" self.net = nn.Sequential(\n",
" nn.Linear(cfg[\"emb_dim\"], 4 * cfg[\"emb_dim\"]),\n",
" GELU(),\n",
" nn.Linear(4 * cfg[\"emb_dim\"], cfg[\"emb_dim\"]),\n",
" nn.Dropout(cfg[\"drop_rate\"])\n",
" )\n",
"\n",
" def forward(self, x):\n",
" return self.net(x)"
]
},
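{
"cell_type": "markdown",
"id": "a3c4d5e6-9f01-4b12-c3d4-e5f6a7b8c9d0",
"metadata": {},
"source": [
"- As an optional comparison (assuming PyTorch >= 1.12, which added the `approximate=\"tanh\"` option), the hand-rolled `GELU` above should closely match PyTorch's built-in tanh approximation"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b4d5e6f7-0a12-4c23-d4e5-f6a7b8c9d0e1",
"metadata": {},
"outputs": [],
"source": [
"import torch.nn.functional as F\n",
"\n",
"# Compare the custom GELU against PyTorch's built-in tanh approximation\n",
"x = torch.linspace(-3, 3, 7)\n",
"print(GELU()(x))\n",
"print(F.gelu(x, approximate=\"tanh\")) # should agree up to floating-point error"
]
},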
{
"cell_type": "markdown",
"id": "4ffcb905-53c7-4886-87d2-4464c5fecf89",
"metadata": {},
"source": [
"## 4.4 Understanding shortcut connections"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "05473938-799c-49fd-86d4-8ed65f94fee6",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[-1.1785],\n",
" [-0.0278],\n",
" [-0.5737],\n",
" [-1.5400],\n",
" [ 0.1513]], grad_fn=<AddmmBackward0>)"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class ExampleWithShortcut(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.fc1 = nn.Linear(10, 10)\n",
" self.fc2 = nn.Linear(10, 10)\n",
" self.fc3 = nn.Linear(10, 1)\n",
" self.relu = nn.ReLU()\n",
"\n",
" def forward(self, x):\n",
" identity = x\n",
" x = self.relu(self.fc1(x))\n",
" x = self.relu(self.fc2(x)) + identity # Shortcut connection\n",
" x = self.fc3(x)\n",
" return x\n",
"\n",
"torch.manual_seed(123)\n",
"ex_short = ExampleWithShortcut()\n",
"inputs = torch.randn(5, 10)\n",
"ex_short(inputs)"
]
},
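{
"cell_type": "markdown",
"id": "c5e6f7a8-1b23-4d34-e5f6-a7b8c9d0e1f2",
"metadata": {},
"source": [
"- To illustrate why shortcuts help, here is a small, illustrative comparison (the `ExampleWithoutShortcut` variant below is hypothetical and simply drops the `+ identity` term): we compare the gradient magnitude that reaches `fc1` with and without the shortcut"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d6f7a8b9-2c34-4e45-f6a7-b8c9d0e1f2a3",
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical variant without the shortcut, for comparison only\n",
"class ExampleWithoutShortcut(ExampleWithShortcut):\n",
" def forward(self, x):\n",
" x = self.relu(self.fc1(x))\n",
" x = self.relu(self.fc2(x)) # no identity added here\n",
" return self.fc3(x)\n",
"\n",
"# Same seed before each run, so weights and inputs are identical\n",
"for model_cls in (ExampleWithShortcut, ExampleWithoutShortcut):\n",
" torch.manual_seed(123)\n",
" m = model_cls()\n",
" m(torch.randn(5, 10)).mean().backward()\n",
" print(model_cls.__name__, \"fc1 grad norm:\", m.fc1.weight.grad.norm().item())"
]
},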
{
"cell_type": "markdown",
"id": "cae578ca-e564-42cf-8635-a2267047cdff",
"metadata": {},
"source": [
"## 4.5 Connecting attention and linear layers"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "0e1e8176-e5e3-4152-b1aa-0bbd7891dfd9",
"metadata": {},
"outputs": [],
"source": [
"from previous_chapters import MultiHeadAttention\n",
"\n",
"\n",
"class TransformerBlock(nn.Module):\n",
" def __init__(self, cfg):\n",
" super().__init__()\n",
" self.att = MultiHeadAttention(\n",
" d_in=cfg[\"emb_dim\"],\n",
" d_out=cfg[\"emb_dim\"],\n",
" block_size=cfg[\"ctx_len\"],\n",
" num_heads=cfg[\"n_heads\"], \n",
" dropout=cfg[\"drop_rate\"],\n",
" qkv_bias=cfg[\"qkv_bias\"])\n",
" self.ff = FeedForward(cfg)\n",
" self.norm1 = LayerNorm(cfg[\"emb_dim\"])\n",
" self.norm2 = LayerNorm(cfg[\"emb_dim\"])\n",
" self.drop_resid = nn.Dropout(cfg[\"drop_rate\"])\n",
"\n",
" def forward(self, x):\n",
" x = x + self.drop_resid(self.att(self.norm1(x)))\n",
" x = x + self.drop_resid(self.ff(self.norm2(x)))\n",
" return x"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "c61de39c-d03c-4a32-8b57-f49ac3834857",
"metadata": {},
"outputs": [],
"source": [
"class GPTModel(nn.Module):\n",
" def __init__(self, cfg):\n",
" super().__init__()\n",
" self.tok_emb = nn.Embedding(cfg[\"vocab_size\"], cfg[\"emb_dim\"])\n",
" self.pos_emb = nn.Embedding(cfg[\"ctx_len\"], cfg[\"emb_dim\"])\n",
" \n",
" # Use a placeholder for TransformerBlock\n",
" self.trf_blocks = nn.Sequential(\n",
" *[TransformerBlock(cfg) for _ in range(cfg[\"n_layers\"])])\n",
" \n",
" # Use a placeholder for LayerNorm\n",
" self.final_norm = LayerNorm(cfg[\"emb_dim\"])\n",
" self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False)\n",
"\n",
" def forward(self, in_idx):\n",
" batch_size, seq_len = in_idx.shape\n",
" tok_embeds = self.tok_emb(in_idx)\n",
" pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))\n",
" x = tok_embeds + pos_embeds\n",
" x = self.trf_blocks(x)\n",
" x = self.final_norm(x)\n",
" logits = self.out_head(x)\n",
" return logits"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "252b78c2-4404-483b-84fe-a412e55c16fc",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Output shape: torch.Size([2, 6, 50257])\n",
"tensor([[[-0.7971, -0.6232, -0.1815, ..., 0.1020, -0.0916, 0.1885],\n",
" [ 0.5491, -0.5220, 0.7559, ..., -0.3137, -0.8780, 0.2182],\n",
" [ 0.3107, 0.0346, -0.4637, ..., -0.3700, -0.4346, -0.0747],\n",
" [ 0.5681, 0.3940, 0.5397, ..., -0.1027, 0.5461, 0.4834],\n",
" [-0.2948, -0.1605, -0.5878, ..., 0.0054, -0.0207, -0.1100],\n",
" [-0.3096, -0.7744, -0.0254, ..., 0.7480, 0.3515, 0.3208]],\n",
"\n",
" [[-0.6910, -0.3758, -0.1458, ..., -0.1824, -0.5231, 0.0873],\n",
" [-0.2562, -0.4204, 1.5507, ..., -0.7057, -0.3989, 0.0084],\n",
" [-0.4263, -0.2257, -0.2074, ..., -0.2160, -1.1648, 0.4744],\n",
" [-0.0245, 1.3792, 0.2234, ..., -0.7153, -0.7858, -0.3762],\n",
" [-0.4696, -0.4584, -0.4812, ..., 0.5044, -0.8911, 0.1549],\n",
" [-0.7727, -0.6125, -0.3203, ..., 1.0753, -0.0878, 0.2805]]],\n",
" grad_fn=<UnsafeViewBackward0>)\n"
]
}
],
"source": [
"torch.manual_seed(123)\n",
"model = GPTModel(GPT_CONFIG_124M)\n",
"\n",
"out = model(batch)\n",
"print(\"Output shape:\", out.shape)\n",
"print(out)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "84fb8be4-9d3b-402b-b3da-86b663aac33a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Total number of parameters: 163,037,184\n",
"Number of trainable parameters considering weight tying: 124,439,808\n"
]
}
],
"source": [
"total_params = sum(p.numel() for p in model.parameters())\n",
"print(f\"Total number of parameters: {total_params:,}\")\n",
"\n",
"total_params_gpt2 = total_params - sum(p.numel() for p in model.tok_emb.parameters())\n",
"print(f\"Number of trainable parameters considering weight tying: {total_params_gpt2:,}\")"
]
},
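{
"cell_type": "markdown",
"id": "a9c0d1e2-5f67-4b78-c9d0-e1f2a3b4c5d6",
"metadata": {},
"source": [
"- The arithmetic behind these numbers: the token embedding layer holds 50,257 x 768 = 38,597,376 parameters; weight tying reuses this matrix as the output layer, so 163,037,184 - 38,597,376 = 124,439,808 parameters remain, matching the \"124M\" in `GPT_CONFIG_124M`"
]
},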
{
"cell_type": "code",
"execution_count": 12,
"id": "5131a752-fab8-4d70-a600-e29870b33528",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Total size of the model: 621.94 MB\n"
]
}
],
"source": [
"# Calculate the total size in bytes (assuming float32, 4 bytes per parameter)\n",
"total_size_bytes = total_params * 4\n",
"\n",
"# Convert to megabytes\n",
"total_size_mb = total_size_bytes / (1024 * 1024)\n",
"\n",
"print(f\"Total size of the model: {total_size_mb:.2f} MB\")"
]
},
{
"cell_type": "markdown",
"id": "da5d9bc0-95ab-45d4-9378-417628d86e35",
"metadata": {},
"source": [
"## 4.6 Implementing the forward pass"
]
},
{
"cell_type": "markdown",
"id": "2cf2f7ac-d531-45c0-a556-41b7d13c992e",
"metadata": {},
"source": [
"- The following `generate_text_simple` function implements greedy decoding, which is a simple and fast method to generate text\n",
"- In greedy decoding, at each step, the model chooses the word (or token) with the highest probability as its next output (the highest logit corresponds to the highest probability, so we don't have to compute the softmax function explicitly)\n",
"- In the next chapter, we will implement a more advanced `generate_text` function"
]
},
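{
"cell_type": "markdown",
"id": "b0d1e2f3-6a78-4c89-d0e1-f2a3b4c5d6e7",
"metadata": {},
"source": [
"- A minimal demonstration (with arbitrary example logits) that softmax preserves ordering, so the argmax over logits equals the argmax over the softmax probabilities"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c1e2f3a4-7b89-4d90-e1f2-a3b4c5d6e7f8",
"metadata": {},
"outputs": [],
"source": [
"# Softmax preserves ordering, so both argmax calls select the same index\n",
"logits_demo = torch.tensor([2.0, -1.0, 0.5]) # arbitrary example logits\n",
"print(torch.argmax(logits_demo))\n",
"print(torch.argmax(torch.softmax(logits_demo, dim=-1)))"
]
},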
{
"cell_type": "code",
"execution_count": 13,
"id": "07700ec8-32e8-4775-9c13-5c43671d6728",
"metadata": {},
"outputs": [],
"source": [
"def generate_text_simple(model, idx, max_new_tokens, context_size):\n",
" # idx is (B, T) array of indices in the current context\n",
" for _ in range(max_new_tokens):\n",
" # Crop index to the last block_size tokens\n",
" idx_cond = idx[:, -context_size:]\n",
"\n",
" # Get the predictions\n",
" with torch.no_grad():\n",
" logits = model(idx_cond)\n",
" \n",
" # Focus only on the last time step\n",
" logits = logits[:, -1, :] # becomes (B, C)\n",
"\n",
" idx_next = torch.argmax(logits, dim=-1, keepdim=True) # (B, 1)\n",
"\n",
" # Append sampled index to the running sequence\n",
" idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)\n",
"\n",
" return idx"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "7b1444cd-f7b9-4348-9034-523f7bd20597",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"encoded: [17250, 11, 314, 1101]\n",
"encoded_tensor.shape: torch.Size([1, 4])\n"
]
}
],
"source": [
"start_context = \"Hi, I'm\"\n",
"\n",
"encoded = tokenizer.encode(start_context)\n",
"print(\"encoded:\", encoded)\n",
"\n",
"encoded_tensor = torch.tensor(encoded).unsqueeze(0)\n",
"print(\"encoded_tensor.shape:\", encoded_tensor.shape)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "1bfb9fc6-dd14-457b-b0ee-cf0e33e2bb33",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"tensor([[17250, 11, 314, 1101, 19106, 34827, 44404, 13427, 42336, 17502,\n",
" 48541, 27970, 23115, 37903, 42378, 47769, 5746, 4830, 22564, 32971,\n",
" 41836, 32124, 36101, 33912, 23027, 32252, 2951, 27247, 10075, 44911,\n",
" 6399, 42214, 3226, 10305, 35112, 9234, 9466, 27622, 44047, 32440,\n",
" 23270, 31510, 15309, 35186]])\n"
]
}
],
"source": [
"model.eval() # disable dropout\n",
"\n",
"out = generate_text_simple(\n",
" model=model,\n",
" idx=encoded_tensor, \n",
" max_new_tokens=40, \n",
" context_size=GPT_CONFIG_124M[\"ctx_len\"]\n",
")\n",
"\n",
"print(out)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "4e6d40d2-a542-44da-acdc-703a703d80a8",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Hi, I'mshadow maiden motel acceptanceRN lobbying↑312 Lac Sioux Styarrellà iTunesluxWHAT groomingridges Ved†illon Scy eyes Wong cyber Celebration subsequ dodging OfSD liberated monster */ Ducks contestant competed Partnership 226Alex residue\n"
]
}
],
"source": [
"decoded_text = tokenizer.decode(out.squeeze(0).tolist())\n",
"print(decoded_text)"
]
},
{
"cell_type": "markdown",
"id": "847888d2-0f92-4840-8a87-8836b529cee5",
"metadata": {},
"source": [
"- Note that the model is untrained hence the random output texts above\n",
"- We will train the model in the next chapter"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}