Mirror of https://github.com/rasbt/LLMs-from-scratch.git, synced 2025-11-04 03:40:21 +00:00

[minor] typo & comments (#441)

* typo & comment
  - safe -> save
  - commenting code: batch_size, seq_len = in_idx.shape
* comment - adding # NEW for assert num_heads % num_kv_groups == 0
* update memory wording

---------

Co-authored-by: rasbt <mail@sebastianraschka.com>

This commit is contained in:
parent dcef9b7d6f
commit 57fdd94358

@@ -381,7 +381,7 @@
     "id": "qcD8LSHNhBRW"
    },
    "source": [
-    "- Note that we also added a `dtype=cfg[\"dtype\"]` setting above, which will allow us to load the model directly in lower precision formats later to save memory (versus instantiating it in the original 32-bit precision format and then converting it)\n",
+    "- Note that we also added a `dtype=cfg[\"dtype\"]` setting above, which will allow us to load the model directly in lower precision formats later to reduce memory usage (versus instantiating it in the original 32-bit precision format and then converting it)\n",
     "- We also set `bias=False` since Llama doesn't use any bias units"
    ]
   },
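
Aside on the `dtype=cfg["dtype"]` point in the hunk above: a minimal sketch, with illustrative layer sizes that are not taken from the notebook, of why creating weights directly in bfloat16 uses half the memory of creating them in float32 and converting afterwards:

    import torch
    import torch.nn as nn

    # Illustrative sizes only; the notebook's actual layers are configured via cfg
    layer_fp32 = nn.Linear(4096, 4096, bias=False)                        # default float32 weights
    layer_bf16 = nn.Linear(4096, 4096, bias=False, dtype=torch.bfloat16)  # 2 bytes per weight instead of 4

    bytes_fp32 = sum(p.numel() * p.element_size() for p in layer_fp32.parameters())
    bytes_bf16 = sum(p.numel() * p.element_size() for p in layer_bf16.parameters())
    print(f"{bytes_fp32 / 2**20:.0f} MB vs {bytes_bf16 / 2**20:.0f} MB")  # ~64 MB vs ~32 MB

Instantiating in float32 and converting later would also require briefly holding the full 32-bit copy in memory, which is exactly what the wording change above calls out.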

@@ -648,7 +648,7 @@
     "\n",
     "mha(example_batch)\n",
     "\n",
-    "del mha  # delete to safe memory"
+    "del mha  # delete to free up memory"
    ]
   },
  {
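
On the `del mha  # delete to free up memory` line: `del` only drops the Python reference to the attention module created in the hunk above; the general pattern below (an assumption about typical usage, not part of the notebook) is what actually returns the memory:

    import gc
    import torch

    del mha                       # drop the reference to the attention module from the cell above
    gc.collect()                  # let Python reclaim the CPU-side memory
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # release cached GPU memory back to the driver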

@@ -781,7 +781,7 @@
     "        self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False, dtype=cfg[\"dtype\"])\n",
     "\n",
     "    def forward(self, in_idx):\n",
-    "        batch_size, seq_len = in_idx.shape\n",
+    "        # batch_size, seq_len = in_idx.shape\n",
     "        tok_embeds = self.tok_emb(in_idx)\n",
     "        # pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))\n",
     "        x = tok_embeds  # + pos_embeds  # Shape [batch_size, num_tokens, emb_size]\n",
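
Why `batch_size, seq_len = in_idx.shape` can be commented out: `seq_len` was only needed for the absolute positional embeddings, which Llama replaces with RoPE applied inside the attention layers. A minimal sketch (illustrative dimensions, not the notebook's config) of the resulting forward path:

    import torch
    import torch.nn as nn

    vocab_size, emb_dim = 32_000, 4096             # illustrative values
    tok_emb = nn.Embedding(vocab_size, emb_dim)

    in_idx = torch.randint(0, vocab_size, (2, 8))  # (batch_size, seq_len)
    x = tok_emb(in_idx)                            # no pos_emb added; RoPE handles positions later
    print(x.shape)                                 # torch.Size([2, 8, 4096])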

@@ -890,7 +890,7 @@
     "    \"n_heads\": 32,           # Number of attention heads\n",
     "    \"n_layers\": 32,          # Number of layers\n",
     "    \"hidden_dim\": 11008,     # NEW: Size of the intermediate dimension in FeedForward\n",
-    "    \"dtype\": torch.bfloat16  # NEW: Lower-precision dtype to save memory\n",
+    "    \"dtype\": torch.bfloat16  # NEW: Lower-precision dtype to reduce memory usage\n",
     "}"
    ]
   },
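
The `torch.bfloat16` comment above is about raw weight storage; back-of-the-envelope arithmetic (the ~6.7 billion parameter count is an approximation for a 7B-class model, not a figure from the notebook):

    num_params = 6.7e9                              # rough parameter count for a 7B-class model
    bytes_per_param = {"float32": 4, "bfloat16": 2}

    for dtype, nbytes in bytes_per_param.items():
        print(f"{dtype}: ~{num_params * nbytes / 1e9:.1f} GB for the weights alone")
    # float32: ~26.8 GB, bfloat16: ~13.4 GB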

@@ -1691,7 +1691,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.10.6"
+   "version": "3.11.4"
   },
   "widgets": {
    "application/vnd.jupyter.widget-state+json": {

@@ -481,7 +481,7 @@
     "        ):\n",
     "        super().__init__()\n",
     "        assert d_out % num_heads == 0, \"d_out must be divisible by num_heads\"\n",
-    "        assert num_heads % num_kv_groups == 0, \"num_heads must be divisible by num_kv_groups\"\n",
+    "        assert num_heads % num_kv_groups == 0, \"num_heads must be divisible by num_kv_groups\"  # NEW\n",
     "\n",
     "        self.d_out = d_out\n",
     "        self.num_heads = num_heads\n",
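
The newly marked `# NEW` assert guards grouped-query attention: every key/value head must serve an equal number of query heads. A small sketch of that grouping (head counts chosen to match the configs in this diff; the expansion step is a common GQA pattern, not copied from the notebook):

    import torch

    num_heads, num_kv_groups = 32, 8
    assert num_heads % num_kv_groups == 0, "num_heads must be divisible by num_kv_groups"
    group_size = num_heads // num_kv_groups        # 4 query heads share each K/V head

    b, num_tokens, head_dim = 1, 6, 16             # illustrative shapes
    keys = torch.randn(b, num_kv_groups, num_tokens, head_dim)

    # Expand the 8 K/V heads so they align with the 32 query heads
    keys_expanded = keys.repeat_interleave(group_size, dim=1)
    print(keys_expanded.shape)                     # torch.Size([1, 32, 6, 16])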

@@ -886,7 +886,7 @@
     "    \"n_heads\": 32,           # Number of attention heads\n",
     "    \"n_layers\": 32,          # Number of layers\n",
     "    \"hidden_dim\": 11_008,    # Size of the intermediate dimension in FeedForward\n",
-    "    \"dtype\": torch.bfloat16  # Lower-precision dtype to save memory\n",
+    "    \"dtype\": torch.bfloat16  # Lower-precision dtype to reduce memory usage\n",
     "}"
    ]
   },

@@ -909,7 +909,7 @@
     "    \"n_kv_groups\": 8,        # NEW: Key-Value groups for grouped-query attention\n",
     "    \"rope_base\": 500_000.0,  # NEW: The base in RoPE's \"theta\" was increased to 500_000\n",
     "    \"rope_freq\": None,       # NEW: Additional configuration for adjusting the RoPE frequencies\n",
-    "    \"dtype\": torch.bfloat16  # Lower-precision dtype to save memory\n",
+    "    \"dtype\": torch.bfloat16  # Lower-precision dtype to reduce memory usage\n",
     "}"
    ]
   },
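
Where the `rope_base` of 500_000 enters: it is the base of the geometric progression of RoPE frequencies, and a larger base stretches the wavelengths, which helps with longer contexts. A minimal sketch (the head dimension and number of positions are illustrative):

    import torch

    head_dim, rope_base = 64, 500_000.0
    inv_freq = 1.0 / (rope_base ** (torch.arange(0, head_dim, 2).float() / head_dim))

    positions = torch.arange(8)
    angles = positions[:, None] * inv_freq[None, :]  # (num_positions, head_dim // 2)
    print(angles.shape)                              # torch.Size([8, 32])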

@@ -2062,7 +2062,7 @@
     "    \"n_kv_groups\": 8,        # Key-Value groups for grouped-query attention\n",
     "    \"rope_base\": 500_000.0,  # The base in RoPE's \"theta\"\n",
     "    \"rope_freq\": None,       # Additional configuration for adjusting the RoPE frequencies\n",
-    "    \"dtype\": torch.bfloat16  # Lower-precision dtype to save memory\n",
+    "    \"dtype\": torch.bfloat16  # Lower-precision dtype to reduce memory usage\n",
     "}\n",
     "\n",
     "LLAMA31_CONFIG_8B = {\n",

@@ -2074,7 +2074,7 @@
     "    \"hidden_dim\": 14_336,       # Size of the intermediate dimension in FeedForward\n",
     "    \"n_kv_groups\": 8,           # Key-Value groups for grouped-query attention\n",
     "    \"rope_base\": 500_000.0,     # The base in RoPE's \"theta\"\n",
-    "    \"dtype\": torch.bfloat16,    # Lower-precision dtype to save memory\n",
+    "    \"dtype\": torch.bfloat16,    # Lower-precision dtype to reduce memory usage\n",
     "    \"rope_freq\": {              # NEW: RoPE frequency scaling\n",
     "        \"factor\": 8.0,\n",
     "        \"low_freq_factor\": 1.0,\n",
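
The `rope_freq` block with `factor` and `low_freq_factor` configures RoPE frequency rescaling for longer contexts. The sketch below conveys the general idea only; the smoothing rule, the `high_freq_factor`, and the original context length are illustrative assumptions and are not taken from this diff:

    import torch

    def rescale_rope_freqs(inv_freq, factor=8.0, low_freq_factor=1.0,
                           high_freq_factor=4.0, orig_context_len=8192):
        # high_freq_factor and orig_context_len above are assumed values for illustration
        wavelen = 2 * torch.pi / inv_freq
        low_freq_wavelen = orig_context_len / low_freq_factor
        high_freq_wavelen = orig_context_len / high_freq_factor

        # Interpolate between the scaled and the original frequencies in the middle band ...
        smooth = (orig_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
        smooth = smooth.clamp(0.0, 1.0)
        scaled = (1 - smooth) * inv_freq / factor + smooth * inv_freq
        # ... slow down the low-frequency (long-wavelength) components by `factor` ...
        scaled = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, scaled)
        # ... and keep the high-frequency components unchanged
        return torch.where(wavelen < high_freq_wavelen, inv_freq, scaled)

    head_dim = 64
    inv_freq = 1.0 / (500_000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim))
    print(rescale_rope_freqs(inv_freq)[:4])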

@@ -2448,7 +2448,7 @@
     "    \"hidden_dim\": 14_336,       # Size of the intermediate dimension in FeedForward\n",
     "    \"n_kv_groups\": 8,           # Key-Value groups for grouped-query attention\n",
     "    \"rope_base\": 500_000.0,     # The base in RoPE's \"theta\"\n",
-    "    \"dtype\": torch.bfloat16,    # Lower-precision dtype to save memory\n",
+    "    \"dtype\": torch.bfloat16,    # Lower-precision dtype to reduce memory usage\n",
     "    \"rope_freq\": {              # NEW: RoPE frequency scaling\n",
     "        \"factor\": 8.0,\n",
     "        \"low_freq_factor\": 1.0,\n",

@@ -2467,7 +2467,7 @@
     "    \"hidden_dim\": 8192,         # NEW: Almost half the size of the intermediate dimension in FeedForward\n",
     "    \"n_kv_groups\": 8,           # Key-Value groups for grouped-query attention\n",
     "    \"rope_base\": 500_000.0,     # The base in RoPE's \"theta\"\n",
-    "    \"dtype\": torch.bfloat16,    # Lower-precision dtype to save memory\n",
+    "    \"dtype\": torch.bfloat16,    # Lower-precision dtype to reduce memory usage\n",
     "    \"rope_freq\": {              # RoPE frequency scaling\n",
     "        \"factor\": 32.0,         # NEW: Adjustment of the rescaling factor\n",
     "        \"low_freq_factor\": 1.0,\n",

@@ -438,7 +438,7 @@
     "    \"hidden_dim\": 8192,         # Size of the intermediate dimension in FeedForward\n",
     "    \"n_kv_groups\": 8,           # Key-Value groups for grouped-query attention\n",
     "    \"rope_base\": 500_000.0,     # The base in RoPE's \"theta\"\n",
-    "    \"dtype\": torch.bfloat16,    # Lower-precision dtype to save memory\n",
+    "    \"dtype\": torch.bfloat16,    # Lower-precision dtype to reduce memory usage\n",
     "    \"rope_freq\": {              # RoPE frequency scaling\n",
     "        \"factor\": 32.0,\n",
     "        \"low_freq_factor\": 1.0,\n",

@@ -458,7 +458,7 @@
     "#     \"hidden_dim\": 8192,         # Size of the intermediate dimension in FeedForward\n",
     "#     \"n_kv_groups\": 8,           # Key-Value groups for grouped-query attention\n",
     "#     \"rope_base\": 500_000.0,     # The base in RoPE's \"theta\"\n",
-    "#     \"dtype\": torch.bfloat16,    # Lower-precision dtype to save memory\n",
+    "#     \"dtype\": torch.bfloat16,    # Lower-precision dtype to reduce memory usage\n",
     "#     \"rope_freq\": {              # RoPE frequency scaling\n",
     "#         \"factor\": 32.0,\n",
     "#         \"low_freq_factor\": 1.0,\n",