diff --git a/.gitignore b/.gitignore
index 02399f3..038cfcf 100644
--- a/.gitignore
+++ b/.gitignore
@@ -35,8 +35,9 @@
 ch05/01_main-chapter-code/model.pth
 ch05/01_main-chapter-code/model_and_optimizer.pth
 ch05/03_bonus_pretraining_on_gutenberg/model_checkpoints
 ch05/06_user_interface/gpt2
-ch05/07_gpt_to_llama/models--meta-llama--Llama-2-7b
-ch05/07_gpt_to_llama/models--meta-llama--Llama-2-7b-chat
+ch05/07_gpt_to_llama/Llama-2-7b
+ch05/07_gpt_to_llama/Llama-2-7b-chat
+ch05/07_gpt_to_llama/.cache
 ch06/01_main-chapter-code/gpt2
 ch06/02_bonus_additional-experiments/gpt2
diff --git a/ch05/07_gpt_to_llama/converting-gpt-to-llama2.ipynb b/ch05/07_gpt_to_llama/converting-gpt-to-llama2.ipynb
index 01fb09e..4454af0 100644
--- a/ch05/07_gpt_to_llama/converting-gpt-to-llama2.ipynb
+++ b/ch05/07_gpt_to_llama/converting-gpt-to-llama2.ipynb
@@ -180,7 +180,7 @@
     "\n",
     "\n",
     "class RMSNorm(nn.Module):\n",
-    "    def __init__(self, emb_dim, eps=1e-6):\n",
+    "    def __init__(self, emb_dim, eps=1e-5):\n",
     "        super().__init__()\n",
     "        self.eps = eps\n",
     "        self.emb_dim = emb_dim\n",
@@ -216,7 +216,7 @@
     "example_batch = torch.randn(2, 3, 4)\n",
     "\n",
     "rms_norm = RMSNorm(emb_dim=example_batch.shape[-1])\n",
-    "rmsnorm_pytorch = torch.nn.RMSNorm(example_batch.shape[-1], eps=1e-6)\n",
+    "rmsnorm_pytorch = torch.nn.RMSNorm(example_batch.shape[-1], eps=1e-5)\n",
     "\n",
     "assert torch.allclose(rms_norm(example_batch), rmsnorm_pytorch(example_batch))"
    ]
@@ -417,11 +417,11 @@
    },
    "outputs": [],
    "source": [
-    "def precompute_rope_params(head_dim, context_length=4096):\n",
+    "def precompute_rope_params(head_dim, theta_base=10_000, context_length=4096):\n",
     "    assert head_dim % 2 == 0, \"Embedding dimension must be even\"\n",
     "\n",
     "    # Compute the inverse frequencies\n",
-    "    inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim // 2) / (head_dim // 2)))\n",
+    "    inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim // 2) / (head_dim // 2)))\n",
     "\n",
     "    # Generate position indices\n",
     "    positions = torch.arange(context_length)\n",
@@ -1151,7 +1151,7 @@
     "tokenizer_file = hf_hub_download(\n",
     "    repo_id=\"meta-llama/Llama-2-7b\",\n",
     "    filename=\"tokenizer.model\",\n",
-    "    cache_dir=\".\")"
+    "    local_dir=\"Llama-2-7b\")"
    ]
   },
   {
@@ -1285,7 +1285,7 @@
     "weights_file = hf_hub_download(\n",
     "    repo_id=\"meta-llama/Llama-2-7b\",\n",
     "    filename=\"consolidated.00.pth\",\n",
-    "    cache_dir=\".\"\n",
+    "    local_dir=\"Llama-2-7b\"\n",
     ")"
    ]
   },
   {
@@ -1520,7 +1520,7 @@
     "weights_file = hf_hub_download(\n",
     "    repo_id=\"meta-llama/Llama-2-7b-chat\",\n",
     "    filename=\"consolidated.00.pth\",\n",
-    "    cache_dir=\".\"\n",
+    "    local_dir=\"Llama-2-7b-chat\"\n",
     ")\n",
     "\n",
     "model = Llama2Model(LLAMA2_CONFIG_7B)\n",
diff --git a/ch05/07_gpt_to_llama/tests/tests.py b/ch05/07_gpt_to_llama/tests/tests.py
index 99d7b3f..eae1fc7 100644
--- a/ch05/07_gpt_to_llama/tests/tests.py
+++ b/ch05/07_gpt_to_llama/tests/tests.py
@@ -58,10 +58,10 @@ def set_seed():
     torch.manual_seed(123)
 
 
-def test_rope(notebook):
+def test_rope_llama2(notebook):
     # Settings
     batch_size = 1
-    context_len = 5
+    context_len = 4096
     num_heads = 4
     head_dim = 16
 
@@ -76,19 +76,51 @@ def test_rope(notebook):
     queries_rot = notebook.compute_rope(queries, cos, sin)
     keys_rot = notebook.compute_rope(keys, cos, sin)
 
-    class RoPEConfig:
-        rope_type = "default"
-        rope_scaling = None
-        factor = 1.0
-        dim: int = head_dim
-        rope_theta = 10000
-        max_position_embeddings: int = 4096
-        hidden_size = head_dim * num_heads
-        num_attention_heads = num_heads
+    rot_emb = LlamaRotaryEmbedding(
+        dim=head_dim,
+        max_position_embeddings=context_len,
+        base=10_000
+    )
 
-    config = RoPEConfig()
+    position_ids = torch.arange(context_len, dtype=torch.long).unsqueeze(0)
+    ref_cos, ref_sin = rot_emb(queries, position_ids)
+    ref_queries_rot, ref_keys_rot = apply_rotary_pos_emb(queries, keys, ref_cos, ref_sin)
+
+    torch.testing.assert_close(sin, ref_sin.squeeze(0))
+    torch.testing.assert_close(cos, ref_cos.squeeze(0))
+    torch.testing.assert_close(keys_rot, ref_keys_rot)
+    torch.testing.assert_close(queries_rot, ref_queries_rot)
+
+
+def test_rope_llama3(notebook):
+    # Settings
+    batch_size = 1
+    context_len = 8192
+    num_heads = 4
+    head_dim = 16
+    theta_base = 50_000
+
+    # Instantiate RoPE parameters
+    cos, sin = notebook.precompute_rope_params(
+        head_dim=head_dim,
+        context_length=context_len,
+        theta_base=theta_base
+    )
+
+    # Dummy query and key tensors
+    queries = torch.randn(batch_size, num_heads, context_len, head_dim)
+    keys = torch.randn(batch_size, num_heads, context_len, head_dim)
+
+    # Apply rotary position embeddings
+    queries_rot = notebook.compute_rope(queries, cos, sin)
+    keys_rot = notebook.compute_rope(keys, cos, sin)
+
+    rot_emb = LlamaRotaryEmbedding(
+        dim=head_dim,
+        max_position_embeddings=context_len,
+        base=theta_base
+    )
 
-    rot_emb = LlamaRotaryEmbedding(config=config)
     position_ids = torch.arange(context_len, dtype=torch.long).unsqueeze(0)
     ref_cos, ref_sin = rot_emb(queries, position_ids)
     ref_queries_rot, ref_keys_rot = apply_rotary_pos_emb(queries, keys, ref_cos, ref_sin)
@@ -108,7 +140,7 @@ def test_silu(notebook):
 @pytest.mark.skipif(torch.__version__ < "2.4", reason="Requires PyTorch 2.4 or newer")
 def test_rmsnorm(notebook):
     example_batch = torch.randn(2, 3, 4)
-    rms_norm = notebook.RMSNorm(emb_dim=example_batch.shape[-1])
-    rmsnorm_pytorch = torch.nn.RMSNorm(example_batch.shape[-1], eps=1e-6)
+    rms_norm = notebook.RMSNorm(emb_dim=example_batch.shape[-1], eps=1e-5)
+    rmsnorm_pytorch = torch.nn.RMSNorm(example_batch.shape[-1], eps=1e-5)
 
     assert torch.allclose(rms_norm(example_batch), rmsnorm_pytorch(example_batch))