diff --git a/README.md b/README.md index bd48b95..df82efc 100644 --- a/README.md +++ b/README.md @@ -36,7 +36,7 @@ Alternatively, you can view this and other files on GitHub at [https://github.co | Ch 1: Understanding Large Language Models | No code | No code | | Ch 2: Working with Text Data | - [ch02.ipynb](ch02/01_main-chapter-code/ch02.ipynb)
- [dataloader.ipynb](ch02/01_main-chapter-code/dataloader.ipynb) (summary)
- [exercise-solutions.ipynb](ch02/01_main-chapter-code/exercise-solutions.ipynb) | [./ch02](./ch02) | | Ch 3: Coding Attention Mechanisms | - [ch03.ipynb](ch03/01_main-chapter-code/ch03.ipynb)
- [multihead-attention.ipynb](ch03/01_main-chapter-code/multihead-attention.ipynb) (summary)
- [exercise-solutions.ipynb](ch03/01_main-chapter-code/exercise-solutions.ipynb)| [./ch03](./ch03) | -| Ch 4: Implementing a GPT Model from Scratch | - [ch04.ipynb](ch04/01_main-chapter-code/ch04.ipynb)
- [gpt.py](ch04/01_main-chapter-code/gpt.py) (summary) | [./ch04](./ch04) | +| Ch 4: Implementing a GPT Model from Scratch | - [ch04.ipynb](ch04/01_main-chapter-code/ch04.ipynb)
- [gpt.py](ch04/01_main-chapter-code/gpt.py) (summary)
- [exercise-solutions.ipynb](ch04/01_main-chapter-code/exercise-solutions.ipynb) | [./ch04](./ch04) | | Ch 5: Pretraining on Unlabeled Data | Q1 2024 | ... | | Ch 6: Finetuning for Text Classification | Q2 2024 | ... | | Ch 7: Finetuning with Human Feedback | Q2 2024 | ... | diff --git a/ch04/01_main-chapter-code/ch04.ipynb b/ch04/01_main-chapter-code/ch04.ipynb index 5001926..f63676e 100644 --- a/ch04/01_main-chapter-code/ch04.ipynb +++ b/ch04/01_main-chapter-code/ch04.ipynb @@ -942,12 +942,11 @@ " super().__init__()\n", " self.tok_emb = nn.Embedding(cfg[\"vocab_size\"], cfg[\"emb_dim\"])\n", " self.pos_emb = nn.Embedding(cfg[\"ctx_len\"], cfg[\"emb_dim\"])\n", + " self.drop_emb = nn.Dropout(cfg[\"drop_rate\"])\n", " \n", - " # Use a placeholder for TransformerBlock\n", " self.trf_blocks = nn.Sequential(\n", " *[TransformerBlock(cfg) for _ in range(cfg[\"n_layers\"])])\n", " \n", - " # Use a placeholder for LayerNorm\n", " self.final_norm = LayerNorm(cfg[\"emb_dim\"])\n", " self.out_head = nn.Linear(\n", " cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False\n", @@ -1210,7 +1209,7 @@ }, { "cell_type": "code", - "execution_count": 41, + "execution_count": 26, "id": "c9b428a9-8764-4b36-80cd-7d4e00595ba6", "metadata": {}, "outputs": [], @@ -1264,7 +1263,7 @@ }, { "cell_type": "code", - "execution_count": 54, + "execution_count": 27, "id": "bb3ffc8e-f95f-4a24-a978-939b8953ea3e", "metadata": {}, "outputs": [ @@ -1282,7 +1281,7 @@ " 0.0000], grad_fn=)" ] }, - "execution_count": 54, + "execution_count": 27, "metadata": {}, "output_type": "execute_result" } @@ -1299,7 +1298,7 @@ }, { "cell_type": "code", - "execution_count": 53, + "execution_count": 28, "id": "3d7e3e94-df0f-4c0f-a6a1-423f500ac1d3", "metadata": {}, "outputs": [ @@ -1324,7 +1323,7 @@ }, { "cell_type": "code", - "execution_count": 43, + "execution_count": 29, "id": "a72a9b60-de66-44cf-b2f9-1e638934ada4", "metadata": {}, "outputs": [ @@ -1332,9 +1331,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "Output: tensor([[15496, 11, 314, 716, 27018, 24086, 47843, 30961, 42348, 7267,\n", - " 49706, 43231, 47062, 34657]])\n", - "Output length: 14\n" + "Output: tensor([[15496, 11, 314, 716, 27018, 24086, 47843, 30961, 42348, 7267]])\n", + "Output length: 10\n" ] } ], @@ -1344,7 +1342,7 @@ "out = generate_text_simple(\n", " model=model,\n", " idx=encoded_tensor, \n", - " max_new_tokens=10, \n", + " max_new_tokens=6, \n", " context_size=GPT_CONFIG_124M[\"ctx_len\"]\n", ")\n", "\n", @@ -1362,7 +1360,7 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 30, "id": "053d99f6-5710-4446-8d52-117fb34ea9f6", "metadata": {}, "outputs": [ @@ -1370,7 +1368,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Hello, I am Featureiman Byeswickattribute argue logger Normandy Compton analogous\n" + "Hello, I am Featureiman Byeswickattribute argue\n" ] } ], diff --git a/ch04/01_main-chapter-code/exercise-solutions.ipynb b/ch04/01_main-chapter-code/exercise-solutions.ipynb new file mode 100644 index 0000000..5291396 --- /dev/null +++ b/ch04/01_main-chapter-code/exercise-solutions.ipynb @@ -0,0 +1,381 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "51c9672d-8d0c-470d-ac2d-1271f8ec3f14", + "metadata": {}, + "source": [ + "# Chapter 4 Exercise solutions" + ] + }, + { + "cell_type": "markdown", + "id": "33dfa199-9aee-41d4-a64b-7e3811b9a616", + "metadata": {}, + "source": [ + "# Exercise 4.1: Using separate dropout parameters" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": 
"5fee2cf5-61c3-4167-81b5-44ea155bbaf2", + "metadata": {}, + "outputs": [], + "source": [ + "GPT_CONFIG_124M = {\n", + " \"vocab_size\": 50257,\n", + " \"ctx_len\": 1024,\n", + " \"emb_dim\": 768,\n", + " \"n_heads\": 12,\n", + " \"n_layers\": 12,\n", + " \"drop_rate_emb\": 0.1, # NEW: dropout for embedding layers\n", + " \"drop_rate_ffn\": 0.1, # NEW: dropout for feed forward module\n", + " \"drop_rate_attn\": 0.1, # NEW: dropout for multi-head attention \n", + " \"drop_rate_resid\": 0.1, # NEW: dropout for residual connections \n", + " \"qkv_bias\": False\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "5aa1b0c1-d78a-48fc-ad08-4802458b43f7", + "metadata": {}, + "outputs": [], + "source": [ + "import torch.nn as nn\n", + "from gpt import MultiHeadAttention, LayerNorm, GELU\n", + "\n", + "class FeedForward(nn.Module):\n", + " def __init__(self, cfg):\n", + " super().__init__()\n", + " self.layers = nn.Sequential(\n", + " nn.Linear(cfg[\"emb_dim\"], 4 * cfg[\"emb_dim\"]),\n", + " GELU(),\n", + " nn.Linear(4 * cfg[\"emb_dim\"], cfg[\"emb_dim\"]),\n", + " nn.Dropout(cfg[\"drop_rate_ffn\"]) # NEW: dropout for feed forward module\n", + " )\n", + "\n", + " def forward(self, x):\n", + " return self.layers(x)\n", + "\n", + "\n", + "class TransformerBlock(nn.Module):\n", + " def __init__(self, cfg):\n", + " super().__init__()\n", + " self.att = MultiHeadAttention(\n", + " d_in=cfg[\"emb_dim\"],\n", + " d_out=cfg[\"emb_dim\"],\n", + " block_size=cfg[\"ctx_len\"],\n", + " num_heads=cfg[\"n_heads\"], \n", + " dropout=cfg[\"drop_rate_attn\"], # NEW: dropout for multi-head attention\n", + " qkv_bias=cfg[\"qkv_bias\"])\n", + " self.ff = FeedForward(cfg)\n", + " self.norm1 = LayerNorm(cfg[\"emb_dim\"])\n", + " self.norm2 = LayerNorm(cfg[\"emb_dim\"])\n", + " self.drop_resid = nn.Dropout(cfg[\"drop_rate_resid\"])\n", + "\n", + " def forward(self, x):\n", + " # Shortcut connection for attention block\n", + " shortcut = x\n", + " x = self.norm1(x)\n", + " x = self.att(x) # Shape [batch_size, num_tokens, emb_size]\n", + " x = self.drop_resid(x)\n", + " x = x + shortcut # Add the original input back\n", + "\n", + " # Shortcut connection for feed-forward block\n", + " shortcut = x\n", + " x = self.norm2(x)\n", + " x = self.ff(x)\n", + " x = self.drop_resid(x)\n", + " x = x + shortcut # Add the original input back\n", + "\n", + " return x\n", + "\n", + "\n", + "class GPTModel(nn.Module):\n", + " def __init__(self, cfg):\n", + " super().__init__()\n", + " self.tok_emb = nn.Embedding(cfg[\"vocab_size\"], cfg[\"emb_dim\"])\n", + " self.pos_emb = nn.Embedding(cfg[\"ctx_len\"], cfg[\"emb_dim\"])\n", + " self.drop_emb = nn.Dropout(cfg[\"drop_rate_emb\"]) # NEW: dropout for embedding layers\n", + "\n", + " self.trf_blocks = nn.Sequential(\n", + " *[TransformerBlock(cfg) for _ in range(cfg[\"n_layers\"])])\n", + "\n", + " self.final_norm = LayerNorm(cfg[\"emb_dim\"])\n", + " self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False)\n", + "\n", + " def forward(self, in_idx):\n", + " batch_size, seq_len = in_idx.shape\n", + " tok_embeds = self.tok_emb(in_idx)\n", + " pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))\n", + " x = tok_embeds + pos_embeds # Shape [batch_size, num_tokens, emb_size]\n", + " x = self.trf_blocks(x)\n", + " x = self.final_norm(x)\n", + " logits = self.out_head(x)\n", + " return logits" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "1d013d32-c275-4f42-be21-9010f1537227", + "metadata": {}, + "outputs": [], + 
"source": [ + "import torch\n", + "import tiktoken\n", + "\n", + "torch.manual_seed(123)\n", + "model = GPTModel(GPT_CONFIG_124M)" + ] + }, + { + "cell_type": "markdown", + "id": "5fea8be3-30a1-4623-a6d7-b095c6c1092e", + "metadata": {}, + "source": [ + "# Exercise 4.2: Parameters in the feed forward versus attention module" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "2751b0e5-ffd3-4be2-8db3-e20dd4d61d69", + "metadata": {}, + "outputs": [], + "source": [ + "from gpt import TransformerBlock\n", + "\n", + "GPT_CONFIG_124M = {\n", + " \"vocab_size\": 50257,\n", + " \"ctx_len\": 1024,\n", + " \"emb_dim\": 768,\n", + " \"n_heads\": 12,\n", + " \"n_layers\": 12,\n", + " \"drop_rate\": 0.1,\n", + " \"qkv_bias\": False\n", + "}\n", + "\n", + "model = TransformerBlock(GPT_CONFIG_124M)" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "1bcaffd1-0cf6-4f8f-bd53-ab88a37f443e", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Total number of parameters in feed forward module: 4,722,432\n" + ] + } + ], + "source": [ + "total_params = sum(p.numel() for p in block.ff.parameters())\n", + "print(f\"Total number of parameters in feed forward module: {total_params:,}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "c1dd06c1-ab6c-4df7-ba73-f9cd54b31138", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Total number of parameters in feed forward module: 2,360,064\n" + ] + } + ], + "source": [ + "total_params = sum(p.numel() for p in block.att.parameters())\n", + "print(f\"Total number of parameters in attention module: {total_params:,}\")" + ] + }, + { + "cell_type": "markdown", + "id": "15463dec-520a-47b4-b3ad-e180394fd076", + "metadata": {}, + "source": [ + "- The results above are for a single transformer block\n", + "- Optionally multiply by 12 to capture all transformer blocks in the 124M GPT model" + ] + }, + { + "cell_type": "markdown", + "id": "0f7b7c7f-0fa1-4d30-ab44-e499edd55b6d", + "metadata": {}, + "source": [ + "# Exercise 4.3: Initialize larger GPT models" + ] + }, + { + "cell_type": "markdown", + "id": "310b2e05-3ec8-47fc-afd9-83bf03d4aad8", + "metadata": {}, + "source": [ + "- **GPT2-small** (the 124M configuration we already implemented):\n", + " - \"emb_dim\" = 768\n", + " - \"n_layers\" = 12\n", + " - \"n_heads\" = 12\n", + "\n", + "- **GPT2-medium:**\n", + " - \"emb_dim\" = 1024\n", + " - \"n_layers\" = 24\n", + " - \"n_heads\" = 16\n", + "\n", + "- **GPT2-large:**\n", + " - \"emb_dim\" = 1280\n", + " - \"n_layers\" = 36\n", + " - \"n_heads\" = 20\n", + "\n", + "- **GPT2-XL:**\n", + " - \"emb_dim\" = 1600\n", + " - \"n_layers\" = 48\n", + " - \"n_heads\" = 25" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "90185dea-81ca-4cdc-aef7-4aaf95cba946", + "metadata": {}, + "outputs": [], + "source": [ + "GPT_CONFIG_124M = {\n", + " \"vocab_size\": 50257,\n", + " \"ctx_len\": 1024,\n", + " \"emb_dim\": 768,\n", + " \"n_heads\": 12,\n", + " \"n_layers\": 12,\n", + " \"drop_rate\": 0.1,\n", + " \"qkv_bias\": False\n", + "}\n", + "\n", + "\n", + "def get_config(base_config, model_name=\"gpt2-small\"):\n", + " GPT_CONFIG = base_config.copy()\n", + "\n", + " if model_name == \"gpt2-small\":\n", + " GPT_CONFIG[\"emb_dim\"] = 768\n", + " GPT_CONFIG[\"n_layers\"] = 12\n", + " GPT_CONFIG[\"n_heads\"] = 12\n", + "\n", + " elif model_name == \"gpt2-medium\":\n", + " GPT_CONFIG[\"emb_dim\"] = 1024\n", + " 
GPT_CONFIG[\"n_layers\"] = 24\n", + " GPT_CONFIG[\"n_heads\"] = 16\n", + "\n", + " elif model_name == \"gpt2-large\":\n", + " GPT_CONFIG[\"emb_dim\"] = 1280\n", + " GPT_CONFIG[\"n_layers\"] = 36\n", + " GPT_CONFIG[\"n_heads\"] = 20\n", + "\n", + " elif model_name == \"gpt2-xl\":\n", + " GPT_CONFIG[\"emb_dim\"] = 1600\n", + " GPT_CONFIG[\"n_layers\"] = 48\n", + " GPT_CONFIG[\"n_heads\"] = 25\n", + "\n", + " else:\n", + " raise ValueError(f\"Incorrect model name {model_name}\")\n", + "\n", + " return GPT_CONFIG\n", + "\n", + "\n", + "def calculate_size(model): # based on chapter code\n", + " \n", + " total_params = sum(p.numel() for p in model.parameters())\n", + " print(f\"Total number of parameters: {total_params:,}\")\n", + "\n", + " total_params_gpt2 = total_params - sum(p.numel() for p in model.out_head.parameters())\n", + " print(f\"Number of trainable parameters considering weight tying: {total_params_gpt2:,}\")\n", + " \n", + " # Calculate the total size in bytes (assuming float32, 4 bytes per parameter)\n", + " total_size_bytes = total_params * 4\n", + " \n", + " # Convert to megabytes\n", + " total_size_mb = total_size_bytes / (1024 * 1024)\n", + " \n", + " print(f\"Total size of the model: {total_size_mb:.2f} MB\")" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "2587e011-78a4-479c-a8fd-961cc40a5fd4", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "gpt2-small:\n", + "Total number of parameters: 163,009,536\n", + "Number of trainable parameters considering weight tying: 124,412,160\n", + "Total size of the model: 621.83 MB\n", + "\n", + "\n", + "gpt2-medium:\n", + "Total number of parameters: 406,212,608\n", + "Number of trainable parameters considering weight tying: 354,749,440\n", + "Total size of the model: 1549.58 MB\n", + "\n", + "\n", + "gpt2-large:\n", + "Total number of parameters: 838,220,800\n", + "Number of trainable parameters considering weight tying: 773,891,840\n", + "Total size of the model: 3197.56 MB\n", + "\n", + "\n", + "gpt2-xl:\n", + "Total number of parameters: 1,637,792,000\n", + "Number of trainable parameters considering weight tying: 1,557,380,800\n", + "Total size of the model: 6247.68 MB\n" + ] + } + ], + "source": [ + "from gpt import GPTModel\n", + "\n", + "\n", + "for model_abbrev in (\"small\", \"medium\", \"large\", \"xl\"):\n", + " model_name = f\"gpt2-{model_abbrev}\"\n", + " CONFIG = get_config(GPT_CONFIG_124M, model_name=model_name)\n", + " model = GPTModel(CONFIG)\n", + " print(f\"\\n\\n{model_name}:\")\n", + " calculate_size(model)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/ch04/01_main-chapter-code/gpt.py b/ch04/01_main-chapter-code/gpt.py index a508786..1301d52 100644 --- a/ch04/01_main-chapter-code/gpt.py +++ b/ch04/01_main-chapter-code/gpt.py @@ -187,12 +187,11 @@ class GPTModel(nn.Module): super().__init__() self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"]) self.pos_emb = nn.Embedding(cfg["ctx_len"], cfg["emb_dim"]) + self.drop_emb = nn.Dropout(cfg["drop_rate"]) - # Use a placeholder for TransformerBlock self.trf_blocks = 
nn.Sequential( *[TransformerBlock(cfg) for _ in range(cfg["n_layers"])]) - # Use a placeholder for LayerNorm self.final_norm = LayerNorm(cfg["emb_dim"]) self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)
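
Note on the `drop_emb` addition above: the hunks for ch04.ipynb and gpt.py only register the new dropout layer in `GPTModel.__init__`; the matching `forward` change is not part of the diff shown here. As a rough, self-contained sketch of where such an embedding dropout is typically applied (after the token and positional embeddings are summed, before the transformer blocks), the `EmbeddingFrontEnd` module below is a hypothetical stand-in, not code from the repository:

```python
# Minimal sketch, assuming the same config keys used in the diff
# (vocab_size, ctx_len, emb_dim, drop_rate); transformer blocks omitted.
import torch
import torch.nn as nn

cfg = {"vocab_size": 50257, "ctx_len": 1024, "emb_dim": 768, "drop_rate": 0.1}

class EmbeddingFrontEnd(nn.Module):  # hypothetical helper, for illustration only
    def __init__(self, cfg):
        super().__init__()
        self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
        self.pos_emb = nn.Embedding(cfg["ctx_len"], cfg["emb_dim"])
        self.drop_emb = nn.Dropout(cfg["drop_rate"])  # the layer added in the diff

    def forward(self, in_idx):
        batch_size, seq_len = in_idx.shape
        tok_embeds = self.tok_emb(in_idx)                                       # (batch, seq, emb)
        pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))  # (seq, emb)
        x = tok_embeds + pos_embeds
        return self.drop_emb(x)  # randomly zeroes elements during training only

torch.manual_seed(123)
frontend = EmbeddingFrontEnd(cfg)
out = frontend(torch.randint(0, cfg["vocab_size"], (2, 4)))
print(out.shape)  # torch.Size([2, 4, 768])
```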
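The `calculate_size` helper in the new exercise notebook also reports a second parameter count that subtracts the output head to account for weight tying. A small standalone check of that arithmetic for the gpt2-small configuration, using only the figures printed in the notebook output above (no repository code assumed):

```python
# Sanity check of the gpt2-small numbers printed by calculate_size
# (vocab_size=50257, emb_dim=768); figures taken from the notebook output.
vocab_size, emb_dim = 50257, 768
total_params = 163_009_536                        # all parameters, including out_head

out_head_params = vocab_size * emb_dim            # untied output projection (bias=False)
tied_params = total_params - out_head_params
print(f"{tied_params:,}")                         # 124,412,160 -> matches the notebook

total_size_mb = total_params * 4 / (1024 * 1024)  # float32: 4 bytes per parameter
print(f"{total_size_mb:.2f} MB")                  # 621.83 MB -> matches the notebook
```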