diff --git a/ch05/07_gpt_to_llama/converting-llama2-to-llama3.ipynb b/ch05/07_gpt_to_llama/converting-llama2-to-llama3.ipynb index 79659c3..1c5fbf9 100644 --- a/ch05/07_gpt_to_llama/converting-llama2-to-llama3.ipynb +++ b/ch05/07_gpt_to_llama/converting-llama2-to-llama3.ipynb @@ -430,6 +430,14 @@ "- In short, the main change in GQA is that each query group needs to be repeated to match the number of heads it is associated with, as implemented below" ] }, + { + "cell_type": "markdown", + "id": "842aa71a-4659-424e-8830-392bd6ae86af", + "metadata": {}, + "source": [ + "- In addition, we also introduce a `SharedBuffers` class that will allow us to reuse the `mask`, `cos`, and `sin` tensors in the transformer blocks to improve efficiency (this will be crucial when working with models such as Llama 3.1 and 3.2 later, which support up to 131k input tokens)" + ] + }, { "cell_type": "code", "execution_count": 8, @@ -441,6 +449,28 @@ "source": [ "import torch.nn as nn\n", "\n", + "\n", + "############################# NEW #############################\n", + "class SharedBuffers:\n", + " _buffers = {}\n", + "\n", + " @staticmethod\n", + " def get_buffers(context_length, head_dim, rope_base, freq_config, dtype=torch.float32):\n", + " key = (context_length, head_dim, rope_base, tuple(freq_config.values()) if freq_config else freq_config, dtype)\n", + "\n", + " if key not in SharedBuffers._buffers:\n", + " # Create or fetch the buffers\n", + " mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)\n", + " cos, sin = precompute_rope_params(head_dim, rope_base, context_length, freq_config)\n", + " if dtype is not None:\n", + " cos = cos.to(dtype)\n", + " sin = sin.to(dtype)\n", + " SharedBuffers._buffers[key] = (mask, cos, sin)\n", + "\n", + " return SharedBuffers._buffers[key]\n", + "############################# NEW #############################\n", + "\n", + "\n", "class GroupedQueryAttention(nn.Module):\n", " def __init__(\n", " self, d_in, d_out, context_length, num_heads,\n", @@ -469,13 +499,12 @@ " self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype)\n", " self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype)\n", "\n", - " self.register_buffer(\"mask\", torch.triu(torch.ones(context_length, context_length), diagonal=1))\n", - " cos, sin = precompute_rope_params(\n", - " head_dim=self.head_dim,\n", - " theta_base=rope_base, # NEW\n", - " freq_config=rope_config, # NEW\n", - " context_length=8192\n", - " )\n", + " ############################# NEW #############################\n", + " # Fetch buffers using SharedBuffers\n", + " mask, cos, sin = SharedBuffers.get_buffers(context_length, self.head_dim, rope_base, rope_config, dtype)\n", + " ############################# NEW #############################\n", + " \n", + " self.register_buffer(\"mask\", mask)\n", " self.register_buffer(\"cos\", cos)\n", " self.register_buffer(\"sin\", sin)\n", "\n", @@ -907,6 +936,35 @@ "model = Llama3Model(LLAMA3_CONFIG_8B)" ] }, + { + "cell_type": "markdown", + "id": "edea6334-d1fc-427d-9cf2-4af963ff4bfc", + "metadata": {}, + "source": [ + "- The following is expected to print True to confirm buffers are reused instead of being (wastefully) recreated:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ee9625cc-9afa-4b11-8aab-d536fd170761", + "metadata": {}, + "outputs": [], + "source": [ + "# Check buffers\n", + "print(model.trf_blocks[0].att.mask is model.trf_blocks[-1].att.mask)\n", + "print(model.trf_blocks[0].att.cos is model.trf_blocks[-1].att.cos)\n", + 
"print(model.trf_blocks[0].att.sin is model.trf_blocks[-1].att.sin) " + ] + }, + { + "cell_type": "markdown", + "id": "8056a521-91a6-440f-8473-591409c3177b", + "metadata": {}, + "source": [ + "- Let's now also compute the number of trainable parameters:" + ] + }, { "cell_type": "code", "execution_count": 18, @@ -2008,16 +2066,16 @@ "}\n", "\n", "LLAMA31_CONFIG_8B = {\n", - " \"vocab_size\": 128_256, # Vocabulary size\n", - " \"context_length\": 8192, # Context length\n", - " \"emb_dim\": 4096, # Embedding dimension\n", - " \"n_heads\": 32, # Number of attention heads\n", - " \"n_layers\": 32, # Number of layers\n", - " \"hidden_dim\": 14_336, # Size of the intermediate dimension in FeedForward\n", - " \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n", - " \"rope_base\": 50_000, # The base in RoPE's \"theta\"\n", - " \"dtype\": torch.bfloat16, # Lower-precision dtype to save memory\n", - " \"rope_freq\": { # NEW: RoPE frequency scaling\n", + " \"vocab_size\": 128_256, # Vocabulary size\n", + " \"context_length\": 131_072, # NEW: Larger supported context length\n", + " \"emb_dim\": 4096, # Embedding dimension\n", + " \"n_heads\": 32, # Number of attention heads\n", + " \"n_layers\": 32, # Number of layers\n", + " \"hidden_dim\": 14_336, # Size of the intermediate dimension in FeedForward\n", + " \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n", + " \"rope_base\": 50_000, # The base in RoPE's \"theta\"\n", + " \"dtype\": torch.bfloat16, # Lower-precision dtype to save memory\n", + " \"rope_freq\": { # NEW: RoPE frequency scaling\n", " \"factor\": 8.0,\n", " \"low_freq_factor\": 1.0,\n", " \"high_freq_factor\": 4.0,\n", @@ -2026,6 +2084,24 @@ "}" ] }, + { + "cell_type": "markdown", + "id": "d81ee464-c112-43b0-9ee8-70df6ac942d0", + "metadata": {}, + "source": [ + "- Reduce the context length so the model would work fine on a MacBook Air (if you have more RAM, feel free to comment out the lines below):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9bdbe32f-4c96-4e60-8bf4-52b5217df1e6", + "metadata": {}, + "outputs": [], + "source": [ + "LLAMA32_CONFIG[\"context_length\"] = 8192" + ] + }, { "cell_type": "markdown", "id": "xa3bpMDtTdBs", @@ -2338,16 +2414,16 @@ "outputs": [], "source": [ "LLAMA31_CONFIG_8B = {\n", - " \"vocab_size\": 128_256, # Vocabulary size\n", - " \"context_length\": 8192, # Context length\n", - " \"emb_dim\": 4096, # Embedding dimension\n", - " \"n_heads\": 32, # Number of attention heads\n", - " \"n_layers\": 32, # Number of layers\n", - " \"hidden_dim\": 14_336, # Size of the intermediate dimension in FeedForward\n", - " \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n", - " \"rope_base\": 50_000, # The base in RoPE's \"theta\"\n", - " \"dtype\": torch.bfloat16, # Lower-precision dtype to save memory\n", - " \"rope_freq\": { # RoPE frequency scaling\n", + " \"vocab_size\": 128_256, # Vocabulary size\n", + " \"context_length\": 131_072, # NEW: Larger supported context length\n", + " \"emb_dim\": 4096, # Embedding dimension\n", + " \"n_heads\": 32, # Number of attention heads\n", + " \"n_layers\": 32, # Number of layers\n", + " \"hidden_dim\": 14_336, # Size of the intermediate dimension in FeedForward\n", + " \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n", + " \"rope_base\": 50_000, # The base in RoPE's \"theta\"\n", + " \"dtype\": torch.bfloat16, # Lower-precision dtype to save memory\n", + " \"rope_freq\": { # NEW: RoPE frequency scaling\n", " \"factor\": 8.0,\n", 
" \"low_freq_factor\": 1.0,\n", " \"high_freq_factor\": 4.0,\n", @@ -2357,17 +2433,17 @@ "\n", "\n", "LLAMA32_CONFIG_1B = {\n", - " \"vocab_size\": 128_256, # Vocabulary size\n", - " \"context_length\": 8192, # Context length\n", - " \"emb_dim\": 2048, # NEW: Half the embedding dimension\n", - " \"n_heads\": 32, # Number of attention heads\n", - " \"n_layers\": 16, # NEW: Half the number of layers\n", - " \"hidden_dim\": 8192, # NEW: Almost half the size of the intermediate dimension in FeedForward\n", - " \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n", - " \"rope_base\": 50_000, # The base in RoPE's \"theta\"\n", - " \"dtype\": torch.bfloat16, # Lower-precision dtype to save memory\n", - " \"rope_freq\": { # RoPE frequency scaling\n", - " \"factor\": 32.0, # NEW: Adjustment of the rescaling factor\n", + " \"vocab_size\": 128_256, # Vocabulary size\n", + " \"context_length\": 131_072, # Context length\n", + " \"emb_dim\": 2048, # NEW: Half the embedding dimension\n", + " \"n_heads\": 32, # Number of attention heads\n", + " \"n_layers\": 16, # NEW: Half the number of layers\n", + " \"hidden_dim\": 8192, # NEW: Almost half the size of the intermediate dimension in FeedForward\n", + " \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n", + " \"rope_base\": 50_000, # The base in RoPE's \"theta\"\n", + " \"dtype\": torch.bfloat16, # Lower-precision dtype to save memory\n", + " \"rope_freq\": { # RoPE frequency scaling\n", + " \"factor\": 32.0, # NEW: Adjustment of the rescaling factor\n", " \"low_freq_factor\": 1.0,\n", " \"high_freq_factor\": 4.0,\n", " \"original_context_length\": 8192,\n", @@ -2375,6 +2451,24 @@ "}" ] }, + { + "cell_type": "markdown", + "id": "b5cd351b-d883-460d-9cdc-47e15ddb884a", + "metadata": {}, + "source": [ + "- Reduce the context length so the model would work fine on a MacBook Air (if you have more RAM, feel free to comment out the lines below):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "387456c3-c6a1-46fe-8830-6e00eb46ac13", + "metadata": {}, + "outputs": [], + "source": [ + "LLAMA32_CONFIG[\"context_length\"] = 8192" + ] + }, { "cell_type": "markdown", "id": "Dl4_0EoJKKYv", @@ -2593,7 +2687,7 @@ "provenance": [] }, "kernelspec": { - "display_name": "base", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -2607,7 +2701,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.11" + "version": "3.11.4" }, "widgets": { "application/vnd.jupyter.widget-state+json": { diff --git a/ch05/07_gpt_to_llama/standalone-llama32.ipynb b/ch05/07_gpt_to_llama/standalone-llama32.ipynb index e2ac747..a9398a2 100644 --- a/ch05/07_gpt_to_llama/standalone-llama32.ipynb +++ b/ch05/07_gpt_to_llama/standalone-llama32.ipynb @@ -69,9 +69,9 @@ "output_type": "stream", "text": [ "blobfile version: 3.0.0\n", - "huggingface_hub version: 0.25.0\n", + "huggingface_hub version: 0.25.1\n", "tiktoken version: 0.7.0\n", - "torch version: 2.5.0.dev20240812+cu121\n" + "torch version: 2.4.0\n" ] } ], @@ -201,6 +201,25 @@ "metadata": {}, "outputs": [], "source": [ + "class SharedBuffers:\n", + " _buffers = {}\n", + "\n", + " @staticmethod\n", + " def get_buffers(context_length, head_dim, rope_base, freq_config, dtype=torch.float32):\n", + " key = (context_length, head_dim, rope_base, tuple(freq_config.values()) if freq_config else freq_config, dtype)\n", + "\n", + " if key not in SharedBuffers._buffers:\n", + " # Create or fetch the buffers\n", + " mask 
= torch.triu(torch.ones(context_length, context_length), diagonal=1)\n", + " cos, sin = precompute_rope_params(head_dim, rope_base, context_length, freq_config)\n", + " if dtype is not None:\n", + " cos = cos.to(dtype)\n", + " sin = sin.to(dtype)\n", + " SharedBuffers._buffers[key] = (mask, cos, sin)\n", + "\n", + " return SharedBuffers._buffers[key]\n", + "\n", + "\n", "class GroupedQueryAttention(nn.Module):\n", " def __init__(\n", " self, d_in, d_out, context_length, num_heads,\n", @@ -225,13 +244,10 @@ " self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype)\n", " self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype)\n", "\n", - " self.register_buffer(\"mask\", torch.triu(torch.ones(context_length, context_length), diagonal=1))\n", - " cos, sin = precompute_rope_params(\n", - " head_dim=self.head_dim,\n", - " theta_base=rope_base,\n", - " freq_config=rope_config,\n", - " context_length=8192\n", - " )\n", + " # Fetch buffers using SharedBuffers\n", + " mask, cos, sin = SharedBuffers.get_buffers(context_length, self.head_dim, rope_base, rope_config, dtype)\n", + " self.register_buffer(\"mask\", mask)\n", + "\n", " self.register_buffer(\"cos\", cos)\n", " self.register_buffer(\"sin\", sin)\n", "\n", @@ -384,16 +400,16 @@ "# Llama 3.2 1B\n", "\n", "LLAMA32_CONFIG = {\n", - " \"vocab_size\": 128_256, # Vocabulary size\n", - " \"context_length\": 8192, # Context length\n", - " \"emb_dim\": 2048, # Embedding dimension\n", - " \"n_heads\": 32, # Number of attention heads\n", - " \"n_layers\": 16, # Number of layers\n", - " \"hidden_dim\": 8192, # Size of the intermediate dimension in FeedForward\n", - " \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n", - " \"rope_base\": 50_000, # The base in RoPE's \"theta\"\n", - " \"dtype\": torch.bfloat16, # Lower-precision dtype to save memory\n", - " \"rope_freq\": { # RoPE frequency scaling\n", + " \"vocab_size\": 128_256, # Vocabulary size\n", + " \"context_length\": 131_072, # Context length\n", + " \"emb_dim\": 2048, # Embedding dimension\n", + " \"n_heads\": 32, # Number of attention heads\n", + " \"n_layers\": 16, # Number of layers\n", + " \"hidden_dim\": 8192, # Size of the intermediate dimension in FeedForward\n", + " \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n", + " \"rope_base\": 50_000, # The base in RoPE's \"theta\"\n", + " \"dtype\": torch.bfloat16, # Lower-precision dtype to save memory\n", + " \"rope_freq\": { # RoPE frequency scaling\n", " \"factor\": 32.0,\n", " \"low_freq_factor\": 1.0,\n", " \"high_freq_factor\": 4.0,\n", @@ -404,16 +420,16 @@ "# Llama 3.2 3B\n", "\n", "# LLAMA32_CONFIG = {\n", - "# \"vocab_size\": 128_256, # Vocabulary size\n", - "# \"context_length\": 8192, # Context length\n", - "# \"emb_dim\": 3072, # Embedding dimension\n", - "# \"n_heads\": 24, # Number of attention heads\n", - "# \"n_layers\": 28, # Number of layers\n", - "# \"hidden_dim\": 8192, # Size of the intermediate dimension in FeedForward\n", - "# \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n", - "# \"rope_base\": 50_000, # The base in RoPE's \"theta\"\n", - "# \"dtype\": torch.bfloat16, # Lower-precision dtype to save memory\n", - "# \"rope_freq\": { # RoPE frequency scaling\n", + "# \"vocab_size\": 128_256, # Vocabulary size\n", + "# \"context_length\": 131_000, # Context length\n", + "# \"emb_dim\": 3072, # Embedding dimension\n", + "# \"n_heads\": 24, # Number of attention heads\n", + "# \"n_layers\": 28, # Number of layers\n", + "# \"hidden_dim\": 8192, # Size 
of the intermediate dimension in FeedForward\n", + "# \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n", + "# \"rope_base\": 50_000, # The base in RoPE's \"theta\"\n", + "# \"dtype\": torch.bfloat16, # Lower-precision dtype to save memory\n", + "# \"rope_freq\": { # RoPE frequency scaling\n", "# \"factor\": 32.0,\n", "# \"low_freq_factor\": 1.0,\n", "# \"high_freq_factor\": 4.0,\n", @@ -424,9 +440,27 @@ "LLAMA_SIZE_STR = \"1B\" if LLAMA32_CONFIG[\"emb_dim\"] == 2048 else \"3B\"" ] }, + { + "cell_type": "markdown", + "id": "34535172-797e-4dd0-84fb-65bc75ad5b06", + "metadata": {}, + "source": [ + "- Reduce the context length so the model would work fine on a MacBook Air (if you have more RAM, feel free to comment out the lines below):" + ] + }, { "cell_type": "code", "execution_count": 9, + "id": "a8bc2370-39d2-4bfe-b4c1-6bdd75fe101c", + "metadata": {}, + "outputs": [], + "source": [ + "LLAMA32_CONFIG[\"context_length\"] = 8192" + ] + }, + { + "cell_type": "code", + "execution_count": 10, "id": "156253fe-aacd-4da2-8f13-705f05c4b11e", "metadata": {}, "outputs": [], @@ -434,9 +468,40 @@ "model = Llama3Model(LLAMA32_CONFIG)" ] }, + { + "cell_type": "markdown", + "id": "19de6c2c-83ce-456d-8be9-6ec415fe9eb1", + "metadata": {}, + "source": [ + "- The following is expected to print True to confirm buffers are reused instead of being (wastefully) recreated:" + ] + }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 11, + "id": "0e95db6d-2712-41a5-a5e0-86c49897f4cf", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "True\n", + "True\n", + "True\n" + ] + } + ], + "source": [ + "# Check buffers\n", + "print(model.trf_blocks[0].att.mask is model.trf_blocks[-1].att.mask)\n", + "print(model.trf_blocks[0].att.cos is model.trf_blocks[-1].att.cos)\n", + "print(model.trf_blocks[0].att.sin is model.trf_blocks[-1].att.sin) " + ] + }, + { + "cell_type": "code", + "execution_count": 12, "id": "364e76ca-52f8-4fa5-af37-c4069f9694bc", "metadata": {}, "outputs": [ @@ -461,7 +526,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 13, "id": "fd5efb03-5a07-46e8-8607-93ed47549d2b", "metadata": {}, "outputs": [ @@ -469,8 +534,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "float32 (PyTorch default): 15.23 GB\n", - "bfloat16: 7.61 GB\n" + "float32 (PyTorch default): 11.42 GB\n", + "bfloat16: 5.71 GB\n" ] } ], @@ -505,7 +570,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 14, "id": "31f12baf-f79b-499f-85c0-51328a6a20f5", "metadata": {}, "outputs": [], @@ -531,7 +596,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 15, "id": "9482b01c-49f9-48e4-ab2c-4a4c75240e77", "metadata": {}, "outputs": [], @@ -631,14 +696,14 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 16, "id": "e9d96dc8-603a-4cb5-8c3e-4d2ca56862ed", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "fede18d637d24f79a27220fb83bc6d2b", + "model_id": "8cdf801700d64fe9b2b827172a8eebcf", "version_major": 2, "version_minor": 0 }, @@ -658,7 +723,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 17, "id": "986bc1a0-804f-4154-80f8-44cefbee1368", "metadata": {}, "outputs": [], @@ -674,7 +739,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 18, "id": "f5a3014f-4c66-4fe2-874e-7b57562c49ad", "metadata": {}, "outputs": [], @@ -694,7 +759,7 @@ }, { "cell_type": 
"code", - "execution_count": 17, + "execution_count": 19, "id": "75166128-5899-4995-9b88-9672e135650e", "metadata": {}, "outputs": [], @@ -775,7 +840,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 20, "id": "699cb1b8-a67d-49fb-80a6-0dad9d81f392", "metadata": {}, "outputs": [ @@ -818,7 +883,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 21, "id": "7f9f7ccc-70cb-41ff-9c25-44336042fc37", "metadata": {}, "outputs": [ @@ -845,7 +910,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 22, "id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5", "metadata": {}, "outputs": [], @@ -902,7 +967,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 23, "id": "1c7a04fa-6aac-416b-8f63-f1e19227633d", "metadata": {}, "outputs": [ @@ -915,10 +980,10 @@ "\n", "1. Grasses: Llamas love to graze on various types of grasses, including tall grasses and short grasses.\n", "2. Hay: Llamas also eat hay, which is a dry, compressed form of grass or other plants.\n", - "3. Alfalfa: Alfalfa is a legume that is commonly used as a hay substitute in llama feed.\n", - "4. Other plants: Llamas will also eat other plants, such as clover, oats, and barley.\n", + "3. Alfalfa: Alfalfa is a legume that is commonly fed to llamas, as it is high in protein and fiber.\n", + "4. Other plants: Llamas will also eat other plants, such as clover, wild grasses, and shrubs.\n", "\n", - "It's worth noting that llamas are adapted to high-altitude environments and can survive on low-quality hay and\n" + "It's worth noting that llamas are adapted to high altitudes and\n" ] } ], @@ -982,7 +1047,7 @@ ], "metadata": { "kernelspec": { - "display_name": "base", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -996,7 +1061,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.11" + "version": "3.11.4" } }, "nbformat": 4,