Mirror of https://github.com/rasbt/LLMs-from-scratch.git (synced 2025-08-31 20:08:08 +00:00)
Improve RoPE settings for Llama 3 (#380)
Commit: feb0647c79
Parent: 2ae4ad15ba

.gitignore (vendored): 5 lines changed

@@ -35,8 +35,9 @@ ch05/01_main-chapter-code/model.pth
 ch05/01_main-chapter-code/model_and_optimizer.pth
 ch05/03_bonus_pretraining_on_gutenberg/model_checkpoints
 ch05/06_user_interface/gpt2
-ch05/07_gpt_to_llama/models--meta-llama--Llama-2-7b
-ch05/07_gpt_to_llama/models--meta-llama--Llama-2-7b-chat
+ch05/07_gpt_to_llama/Llama-2-7b
+ch05/07_gpt_to_llama/Llama-2-7b-chat
+ch05/07_gpt_to_llama/.cache

 ch06/01_main-chapter-code/gpt2
 ch06/02_bonus_additional-experiments/gpt2

@@ -180,7 +180,7 @@
 "\n",
 "\n",
 "class RMSNorm(nn.Module):\n",
-"    def __init__(self, emb_dim, eps=1e-6):\n",
+"    def __init__(self, emb_dim, eps=1e-5):\n",
 "        super().__init__()\n",
 "        self.eps = eps\n",
 "        self.emb_dim = emb_dim\n",

@@ -216,7 +216,7 @@
 "example_batch = torch.randn(2, 3, 4)\n",
 "\n",
 "rms_norm = RMSNorm(emb_dim=example_batch.shape[-1])\n",
-"rmsnorm_pytorch = torch.nn.RMSNorm(example_batch.shape[-1], eps=1e-6)\n",
+"rmsnorm_pytorch = torch.nn.RMSNorm(example_batch.shape[-1], eps=1e-5)\n",
 "\n",
 "assert torch.allclose(rms_norm(example_batch), rmsnorm_pytorch(example_batch))"
 ]

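The two hunks above align the notebook's RMSNorm epsilon with the 1e-5 value used by Llama 2. For reference, a minimal self-contained sketch along these lines (the forward pass below is a standard RMSNorm, assumed rather than copied from the notebook) can be checked against torch.nn.RMSNorm (PyTorch 2.4+) the same way the cell's assert does:

import torch
import torch.nn as nn


class RMSNorm(nn.Module):
    def __init__(self, emb_dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.emb_dim = emb_dim
        self.weight = nn.Parameter(torch.ones(emb_dim))

    def forward(self, x):
        # Scale by the reciprocal root-mean-square over the last dimension
        means = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(means + self.eps) * self.weight


example_batch = torch.randn(2, 3, 4)
rms_norm = RMSNorm(emb_dim=example_batch.shape[-1])
rmsnorm_pytorch = torch.nn.RMSNorm(example_batch.shape[-1], eps=1e-5)
assert torch.allclose(rms_norm(example_batch), rmsnorm_pytorch(example_batch))
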
@@ -417,11 +417,11 @@
 },
 "outputs": [],
 "source": [
-"def precompute_rope_params(head_dim, context_length=4096):\n",
+"def precompute_rope_params(head_dim, theta_base=10_000, context_length=4096):\n",
 "    assert head_dim % 2 == 0, \"Embedding dimension must be even\"\n",
 "\n",
 "    # Compute the inverse frequencies\n",
-"    inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim // 2) / (head_dim // 2)))\n",
+"    inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim // 2) / (head_dim // 2)))\n",
 "\n",
 "    # Generate position indices\n",
 "    positions = torch.arange(context_length)\n",

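This hunk exposes the RoPE base frequency as a theta_base parameter instead of hard-coding 10,000, so the same function can serve Llama 2 (theta_base=10_000) and the larger bases used in the Llama 3-style test further down. Only the first lines of the function appear in the hunk; a sketch of how such a function typically continues, using the standard rotate-half RoPE formulation (an assumption, not the notebook's exact code):

import torch


def precompute_rope_params(head_dim, theta_base=10_000, context_length=4096):
    assert head_dim % 2 == 0, "Embedding dimension must be even"

    # Inverse frequencies, one per pair of dimensions
    inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim // 2) / (head_dim // 2)))

    # Angles for every (position, frequency) pair
    positions = torch.arange(context_length)
    angles = positions[:, None] * inv_freq[None, :]  # (context_length, head_dim // 2)
    angles = torch.cat([angles, angles], dim=1)      # (context_length, head_dim)

    return torch.cos(angles), torch.sin(angles)


def compute_rope(x, cos, sin):
    # x: (batch, num_heads, seq_len, head_dim)
    seq_len, head_dim = x.shape[-2], x.shape[-1]
    x1, x2 = x[..., : head_dim // 2], x[..., head_dim // 2:]
    cos = cos[:seq_len, :][None, None, :, :]
    sin = sin[:seq_len, :][None, None, :, :]
    rotated = torch.cat((-x2, x1), dim=-1)
    return x * cos + rotated * sin


cos, sin = precompute_rope_params(head_dim=16, theta_base=50_000, context_length=8192)
queries = torch.randn(1, 4, 8192, 16)
queries_rot = compute_rope(queries, cos, sin)
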
@@ -1151,7 +1151,7 @@
 "tokenizer_file = hf_hub_download(\n",
 "    repo_id=\"meta-llama/Llama-2-7b\",\n",
 "    filename=\"tokenizer.model\",\n",
-"    cache_dir=\".\")"
+"    local_dir=\"Llama-2-7b\")"
 ]
 },
 {

@@ -1285,7 +1285,7 @@
 "weights_file = hf_hub_download(\n",
 "    repo_id=\"meta-llama/Llama-2-7b\",\n",
 "    filename=\"consolidated.00.pth\",\n",
-"    cache_dir=\".\"\n",
+"    local_dir=\"Llama-2-7b\"\n",
 ")"
 ]
 },

@@ -1520,7 +1520,7 @@
 "weights_file = hf_hub_download(\n",
 "    repo_id=\"meta-llama/Llama-2-7b-chat\",\n",
 "    filename=\"consolidated.00.pth\",\n",
-"    cache_dir=\".\"\n",
+"    local_dir=\"Llama-2-7b-chat\"\n",
 ")\n",
 "\n",
 "model = Llama2Model(LLAMA2_CONFIG_7B)\n",
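
The three hf_hub_download hunks switch from cache_dir to local_dir. With cache_dir, huggingface_hub stores files in its hashed cache layout (a models--meta-llama--Llama-2-7b directory with snapshots inside), whereas local_dir places them in a plain folder, which is why the .gitignore entries above change from models--meta-llama--Llama-2-7b to Llama-2-7b. A usage sketch (assumes huggingface_hub is installed and the logged-in account has been granted access to the gated meta-llama repo):

from huggingface_hub import hf_hub_download

tokenizer_file = hf_hub_download(
    repo_id="meta-llama/Llama-2-7b",
    filename="tokenizer.model",
    local_dir="Llama-2-7b",  # plain folder instead of the models--... cache layout
)
print(tokenizer_file)  # e.g. Llama-2-7b/tokenizer.model
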
|
@@ -58,10 +58,10 @@ def set_seed():
     torch.manual_seed(123)


-def test_rope(notebook):
+def test_rope_llama2(notebook):
     # Settings
     batch_size = 1
-    context_len = 5
+    context_len = 4096
     num_heads = 4
     head_dim = 16

@@ -76,19 +76,51 @@ def test_rope(notebook):
     queries_rot = notebook.compute_rope(queries, cos, sin)
     keys_rot = notebook.compute_rope(keys, cos, sin)

-    class RoPEConfig:
-        rope_type = "default"
-        rope_scaling = None
-        factor = 1.0
-        dim: int = head_dim
-        rope_theta = 10000
-        max_position_embeddings: int = 4096
-        hidden_size = head_dim * num_heads
-        num_attention_heads = num_heads
-
-    config = RoPEConfig()
-
-    rot_emb = LlamaRotaryEmbedding(config=config)
+    rot_emb = LlamaRotaryEmbedding(
+        dim=head_dim,
+        max_position_embeddings=context_len,
+        base=10_000
+    )
+
+    position_ids = torch.arange(context_len, dtype=torch.long).unsqueeze(0)
+    ref_cos, ref_sin = rot_emb(queries, position_ids)
+    ref_queries_rot, ref_keys_rot = apply_rotary_pos_emb(queries, keys, ref_cos, ref_sin)
+
+    torch.testing.assert_close(sin, ref_sin.squeeze(0))
+    torch.testing.assert_close(cos, ref_cos.squeeze(0))
+    torch.testing.assert_close(keys_rot, ref_keys_rot)
+    torch.testing.assert_close(queries_rot, ref_queries_rot)
+
+
+def test_rope_llama3(notebook):
+    # Settings
+    batch_size = 1
+    context_len = 8192
+    num_heads = 4
+    head_dim = 16
+    theta_base = 50_000
+
+    # Instantiate RoPE parameters
+    cos, sin = notebook.precompute_rope_params(
+        head_dim=head_dim,
+        context_length=context_len,
+        theta_base=theta_base
+    )
+
+    # Dummy query and key tensors
+    queries = torch.randn(batch_size, num_heads, context_len, head_dim)
+    keys = torch.randn(batch_size, num_heads, context_len, head_dim)
+
+    # Apply rotary position embeddings
+    queries_rot = notebook.compute_rope(queries, cos, sin)
+    keys_rot = notebook.compute_rope(keys, cos, sin)
+
+    rot_emb = LlamaRotaryEmbedding(
+        dim=head_dim,
+        max_position_embeddings=context_len,
+        base=theta_base
+    )
+
     position_ids = torch.arange(context_len, dtype=torch.long).unsqueeze(0)
     ref_cos, ref_sin = rot_emb(queries, position_ids)
     ref_queries_rot, ref_keys_rot = apply_rotary_pos_emb(queries, keys, ref_cos, ref_sin)

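The rewritten test constructs the Hugging Face reference rotary embedding directly from dim, max_position_embeddings, and base keyword arguments instead of the ad-hoc RoPEConfig object, and adds a second test with Llama 3-style settings (longer context, larger theta_base). Pulled out of the pytest fixture, the reference comparison amounts to the following (assuming a transformers version whose LlamaRotaryEmbedding still accepts these keyword arguments; newer releases expect a config object):

import torch
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb

batch_size, num_heads, context_len, head_dim = 1, 4, 8192, 16
theta_base = 50_000

queries = torch.randn(batch_size, num_heads, context_len, head_dim)
keys = torch.randn(batch_size, num_heads, context_len, head_dim)

rot_emb = LlamaRotaryEmbedding(
    dim=head_dim,
    max_position_embeddings=context_len,
    base=theta_base,
)
position_ids = torch.arange(context_len, dtype=torch.long).unsqueeze(0)
ref_cos, ref_sin = rot_emb(queries, position_ids)
ref_queries_rot, ref_keys_rot = apply_rotary_pos_emb(queries, keys, ref_cos, ref_sin)

# The notebook's cos, sin, queries_rot, and keys_rot are then compared against
# these reference tensors with torch.testing.assert_close.
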
@@ -108,7 +140,7 @@ def test_silu(notebook):
 @pytest.mark.skipif(torch.__version__ < "2.4", reason="Requires PyTorch 2.4 or newer")
 def test_rmsnorm(notebook):
     example_batch = torch.randn(2, 3, 4)
-    rms_norm = notebook.RMSNorm(emb_dim=example_batch.shape[-1])
-    rmsnorm_pytorch = torch.nn.RMSNorm(example_batch.shape[-1], eps=1e-6)
+    rms_norm = notebook.RMSNorm(emb_dim=example_batch.shape[-1], eps=1e-5)
+    rmsnorm_pytorch = torch.nn.RMSNorm(example_batch.shape[-1], eps=1e-5)

     assert torch.allclose(rms_norm(example_batch), rmsnorm_pytorch(example_batch))
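
The eps change also has to land in the test, because RMSNorm instances with different eps values do not produce matching outputs. A quick illustration (not from the repo) of why mismatched eps values would break the allclose check:

import torch

x = torch.randn(2, 3, 4) * 1e-3  # small activations make eps matter more
out_a = torch.nn.RMSNorm(4, eps=1e-5)(x)
out_b = torch.nn.RMSNorm(4, eps=1e-6)(x)
print((out_a - out_b).abs().max())   # clearly nonzero
print(torch.allclose(out_a, out_b))  # False for inputs this small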