mirror of
				https://github.com/rasbt/LLMs-from-scratch.git
				synced 2025-10-26 07:20:09 +00:00 
			
		
		
		
	Improve rope settings for llama3 (#380)
This commit is contained in:
		
							parent
							
								
									2ae4ad15ba
								
							
						
					
					
						commit
						feb0647c79
					
				
							
								
								
									
										5
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										5
									
								
								.gitignore
									
									
									
									
										vendored
									
									
								
							| @ -35,8 +35,9 @@ ch05/01_main-chapter-code/model.pth | ||||
| ch05/01_main-chapter-code/model_and_optimizer.pth | ||||
| ch05/03_bonus_pretraining_on_gutenberg/model_checkpoints | ||||
| ch05/06_user_interface/gpt2 | ||||
| ch05/07_gpt_to_llama/models--meta-llama--Llama-2-7b | ||||
| ch05/07_gpt_to_llama/models--meta-llama--Llama-2-7b-chat | ||||
| ch05/07_gpt_to_llama/Llama-2-7b | ||||
| ch05/07_gpt_to_llama/Llama-2-7b-chat | ||||
| ch05/07_gpt_to_llama/.cache | ||||
| 
 | ||||
| ch06/01_main-chapter-code/gpt2 | ||||
| ch06/02_bonus_additional-experiments/gpt2 | ||||
|  | ||||
| @ -180,7 +180,7 @@ | ||||
|     "\n", | ||||
|     "\n", | ||||
|     "class RMSNorm(nn.Module):\n", | ||||
|     "    def __init__(self, emb_dim, eps=1e-6):\n", | ||||
|     "    def __init__(self, emb_dim, eps=1e-5):\n", | ||||
|     "        super().__init__()\n", | ||||
|     "        self.eps = eps\n", | ||||
|     "        self.emb_dim = emb_dim\n", | ||||
| @ -216,7 +216,7 @@ | ||||
|     "example_batch = torch.randn(2, 3, 4)\n", | ||||
|     "\n", | ||||
|     "rms_norm = RMSNorm(emb_dim=example_batch.shape[-1])\n", | ||||
|     "rmsnorm_pytorch = torch.nn.RMSNorm(example_batch.shape[-1], eps=1e-6)\n", | ||||
|     "rmsnorm_pytorch = torch.nn.RMSNorm(example_batch.shape[-1], eps=1e-5)\n", | ||||
|     "\n", | ||||
|     "assert torch.allclose(rms_norm(example_batch), rmsnorm_pytorch(example_batch))" | ||||
|    ] | ||||
| @ -417,11 +417,11 @@ | ||||
|    }, | ||||
|    "outputs": [], | ||||
|    "source": [ | ||||
|     "def precompute_rope_params(head_dim, context_length=4096):\n", | ||||
|     "def precompute_rope_params(head_dim, theta_base=10_000, context_length=4096):\n", | ||||
|     "    assert head_dim % 2 == 0, \"Embedding dimension must be even\"\n", | ||||
|     "\n", | ||||
|     "    # Compute the inverse frequencies\n", | ||||
|     "    inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim // 2) / (head_dim // 2)))\n", | ||||
|     "    inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim // 2) / (head_dim // 2)))\n", | ||||
|     "\n", | ||||
|     "    # Generate position indices\n", | ||||
|     "    positions = torch.arange(context_length)\n", | ||||
| @ -1151,7 +1151,7 @@ | ||||
|     "tokenizer_file = hf_hub_download(\n", | ||||
|     "    repo_id=\"meta-llama/Llama-2-7b\",\n", | ||||
|     "    filename=\"tokenizer.model\",\n", | ||||
|     "    cache_dir=\".\")" | ||||
|     "    local_dir=\"Llama-2-7b\")" | ||||
|    ] | ||||
|   }, | ||||
|   { | ||||
| @ -1285,7 +1285,7 @@ | ||||
|     "weights_file = hf_hub_download(\n", | ||||
|     "   repo_id=\"meta-llama/Llama-2-7b\",\n", | ||||
|     "   filename=\"consolidated.00.pth\",\n", | ||||
|     "   cache_dir=\".\"\n", | ||||
|     "   local_dir=\"Llama-2-7b\"\n", | ||||
|     ")" | ||||
|    ] | ||||
|   }, | ||||
| @ -1520,7 +1520,7 @@ | ||||
|     "weights_file = hf_hub_download(\n", | ||||
|     "   repo_id=\"meta-llama/Llama-2-7b-chat\",\n", | ||||
|     "   filename=\"consolidated.00.pth\",\n", | ||||
|     "   cache_dir=\".\"\n", | ||||
|     "   local_dir=\"Llama-2-7b-chat\"\n", | ||||
|     ")\n", | ||||
|     "\n", | ||||
|     "model = Llama2Model(LLAMA2_CONFIG_7B)\n", | ||||
|  | ||||
| @ -58,10 +58,10 @@ def set_seed(): | ||||
|     torch.manual_seed(123) | ||||
| 
 | ||||
| 
 | ||||
| def test_rope(notebook): | ||||
| def test_rope_llama2(notebook): | ||||
|     # Settings | ||||
|     batch_size = 1 | ||||
|     context_len = 5 | ||||
|     context_len = 4096 | ||||
|     num_heads = 4 | ||||
|     head_dim = 16 | ||||
| 
 | ||||
| @ -76,19 +76,51 @@ def test_rope(notebook): | ||||
|     queries_rot = notebook.compute_rope(queries, cos, sin) | ||||
|     keys_rot = notebook.compute_rope(keys, cos, sin) | ||||
| 
 | ||||
|     class RoPEConfig: | ||||
|         rope_type = "default" | ||||
|         rope_scaling = None | ||||
|         factor = 1.0 | ||||
|         dim: int = head_dim | ||||
|         rope_theta = 10000 | ||||
|         max_position_embeddings: int = 4096 | ||||
|         hidden_size = head_dim * num_heads | ||||
|         num_attention_heads = num_heads | ||||
|     rot_emb = LlamaRotaryEmbedding( | ||||
|         dim=head_dim, | ||||
|         max_position_embeddings=context_len, | ||||
|         base=10_000 | ||||
|     ) | ||||
| 
 | ||||
|     config = RoPEConfig() | ||||
|     position_ids = torch.arange(context_len, dtype=torch.long).unsqueeze(0) | ||||
|     ref_cos, ref_sin = rot_emb(queries, position_ids) | ||||
|     ref_queries_rot, ref_keys_rot = apply_rotary_pos_emb(queries, keys, ref_cos, ref_sin) | ||||
| 
 | ||||
|     torch.testing.assert_close(sin, ref_sin.squeeze(0)) | ||||
|     torch.testing.assert_close(cos, ref_cos.squeeze(0)) | ||||
|     torch.testing.assert_close(keys_rot, ref_keys_rot) | ||||
|     torch.testing.assert_close(queries_rot, ref_queries_rot) | ||||
| 
 | ||||
| 
 | ||||
| def test_rope_llama3(notebook): | ||||
|     # Settings | ||||
|     batch_size = 1 | ||||
|     context_len = 8192 | ||||
|     num_heads = 4 | ||||
|     head_dim = 16 | ||||
|     theta_base = 50_000 | ||||
| 
 | ||||
|     # Instantiate RoPE parameters | ||||
|     cos, sin = notebook.precompute_rope_params( | ||||
|         head_dim=head_dim, | ||||
|         context_length=context_len, | ||||
|         theta_base=theta_base | ||||
|     ) | ||||
| 
 | ||||
|     # Dummy query and key tensors | ||||
|     queries = torch.randn(batch_size, num_heads, context_len, head_dim) | ||||
|     keys = torch.randn(batch_size, num_heads, context_len, head_dim) | ||||
| 
 | ||||
|     # Apply rotary position embeddings | ||||
|     queries_rot = notebook.compute_rope(queries, cos, sin) | ||||
|     keys_rot = notebook.compute_rope(keys, cos, sin) | ||||
| 
 | ||||
|     rot_emb = LlamaRotaryEmbedding( | ||||
|         dim=head_dim, | ||||
|         max_position_embeddings=context_len, | ||||
|         base=theta_base | ||||
|     ) | ||||
| 
 | ||||
|     rot_emb = LlamaRotaryEmbedding(config=config) | ||||
|     position_ids = torch.arange(context_len, dtype=torch.long).unsqueeze(0) | ||||
|     ref_cos, ref_sin = rot_emb(queries, position_ids) | ||||
|     ref_queries_rot, ref_keys_rot = apply_rotary_pos_emb(queries, keys, ref_cos, ref_sin) | ||||
| @ -108,7 +140,7 @@ def test_silu(notebook): | ||||
| @pytest.mark.skipif(torch.__version__ < "2.4", reason="Requires PyTorch 2.4 or newer") | ||||
| def test_rmsnorm(notebook): | ||||
|     example_batch = torch.randn(2, 3, 4) | ||||
|     rms_norm = notebook.RMSNorm(emb_dim=example_batch.shape[-1]) | ||||
|     rmsnorm_pytorch = torch.nn.RMSNorm(example_batch.shape[-1], eps=1e-6) | ||||
|     rms_norm = notebook.RMSNorm(emb_dim=example_batch.shape[-1], eps=1e-5) | ||||
|     rmsnorm_pytorch = torch.nn.RMSNorm(example_batch.shape[-1], eps=1e-5) | ||||
| 
 | ||||
|     assert torch.allclose(rms_norm(example_batch), rmsnorm_pytorch(example_batch)) | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Sebastian Raschka
						Sebastian Raschka