RoPE increase (#407)

Sebastian Raschka 2024-10-21 19:58:38 -05:00 committed by GitHub
parent 75133605c5
commit 534a704364
3 changed files with 15 additions and 15 deletions

File 1 of 3

@@ -254,12 +254,12 @@
 "- Llama 3 uses rotary position embeddings (RoPE) similar to Llama 2 (for a detailed explanation, please see the [RoPE paper](https://arxiv.org/abs/2104.09864))\n",
 "- There are some subtle differences in the RoPE settings, though\n",
 "  - Llama 3 now supports up to 8,192 tokens, twice as many as Llama 2 (4,096)\n",
-"  - The base value for the so-called RoPE $\\theta$ (see equation below) was increased from 10,000 (Llama 2) to 50,000 (Llama 3) in the following equation (adapted from the [RoPE paper](https://arxiv.org/abs/2104.09864))\n",
+"  - The base value for the so-called RoPE $\\theta$ (see equation below) was increased from 10,000 (Llama 2) to 500,000 (Llama 3) in the following equation (adapted from the [RoPE paper](https://arxiv.org/abs/2104.09864))\n",
 "\n",
 "$$\\Theta = \\left\\{\\theta_i = \\text{base}^{\\frac{-2(i-1)}{d}}, i \\in \\left[1, 2, ..., d/2\\right]\\right\\}$$\n",
 "\n",
 "- These $\\theta$ values are a set of predefined parameters that are used to determine the rotational angles in the rotary matrix, where $d$ is the dimensionality of the embedding space\n",
-"- Increasing the base from 10,000 to 50,000 makes the frequencies (or rotation angles) decay more slowly across the dimensions, which means that higher dimensions will be associated with larger angles than before (essentially, it's a decompression of the frequencies)\n",
+"- Increasing the base from 10,000 to 500,000 makes the frequencies (or rotation angles) decay more slowly across the dimensions, which means that higher dimensions will be associated with larger angles than before (essentially, it's a decompression of the frequencies)\n",
 "- In addition, we introduce a `freq_config` section in the code below that adjusts the frequency; however, we won't be needing it in Llama 3 (only Llama 3.1 and Llama 3.2), so we will revisit this `freq_config` later (it's set to `None` and ignored by default)"
 ]
 },
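
Note: a quick way to see what the larger base does is to evaluate the $\theta_i$ equation above directly. The following standalone sketch (illustrative, not part of the diff) compares the two bases for an example head dimension of 16:

import torch

d = 16  # example head dimension (must be even)
i = torch.arange(1, d // 2 + 1).float()  # i in [1, ..., d/2]
theta_llama2 = 10_000 ** (-2 * (i - 1) / d)
theta_llama3 = 500_000 ** (-2 * (i - 1) / d)
# With the larger base, the higher dimensions get lower frequencies
# (longer wavelengths), which suits the longer 8,192-token context
print(theta_llama2)
print(theta_llama3)
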
@@ -274,7 +274,7 @@
 "source": [
 "import torch\n",
 "\n",
-"def precompute_rope_params(head_dim, theta_base=10000, context_length=4096, freq_config=None):\n",
+"def precompute_rope_params(head_dim, theta_base=10_000, context_length=4096, freq_config=None):\n",
 "    assert head_dim % 2 == 0, \"Embedding dimension must be even\"\n",
 "\n",
 "    # Compute the inverse frequencies\n",
@@ -347,7 +347,7 @@
 "llama_3_context_len = 8192\n",
 "\n",
 "llama_2_theta_base = 10_000\n",
-"llama_3_theta_base = 50_000"
+"llama_3_theta_base = 500_000"
 ]
 },
 {
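
With the constants above, the precomputation for both models might be invoked along these lines (the head_dim of 16 is an arbitrary illustrative choice):

llama_2_cos, llama_2_sin = precompute_rope_params(
    head_dim=16, theta_base=llama_2_theta_base, context_length=4096
)
llama_3_cos, llama_3_sin = precompute_rope_params(
    head_dim=16, theta_base=llama_3_theta_base, context_length=llama_3_context_len
)
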
@@ -907,7 +907,7 @@
 "    \"n_layers\": 32,  # Number of layers\n",
 "    \"hidden_dim\": 14_336,  # NEW: Larger size of the intermediate dimension in FeedForward\n",
 "    \"n_kv_groups\": 8,  # NEW: Key-Value groups for grouped-query attention\n",
-"    \"rope_base\": 50_000,  # NEW: The base in RoPE's \"theta\" was increased to 50_000\n",
+"    \"rope_base\": 500_000,  # NEW: The base in RoPE's \"theta\" was increased to 500_000\n",
 "    \"rope_freq\": None,  # NEW: Additional configuration for adjusting the RoPE frequencies\n",
 "    \"dtype\": torch.bfloat16  # Lower-precision dtype to save memory\n",
 "}"
@@ -2060,7 +2060,7 @@
 "    \"n_layers\": 32,  # Number of layers\n",
 "    \"hidden_dim\": 14_336,  # Size of the intermediate dimension in FeedForward\n",
 "    \"n_kv_groups\": 8,  # Key-Value groups for grouped-query attention\n",
-"    \"rope_base\": 50_000,  # The base in RoPE's \"theta\"\n",
+"    \"rope_base\": 500_000,  # The base in RoPE's \"theta\"\n",
 "    \"rope_freq\": None,  # Additional configuration for adjusting the RoPE frequencies\n",
 "    \"dtype\": torch.bfloat16  # Lower-precision dtype to save memory\n",
 "}\n",
@@ -2073,7 +2073,7 @@
 "    \"n_layers\": 32,  # Number of layers\n",
 "    \"hidden_dim\": 14_336,  # Size of the intermediate dimension in FeedForward\n",
 "    \"n_kv_groups\": 8,  # Key-Value groups for grouped-query attention\n",
-"    \"rope_base\": 50_000,  # The base in RoPE's \"theta\"\n",
+"    \"rope_base\": 500_000,  # The base in RoPE's \"theta\"\n",
 "    \"dtype\": torch.bfloat16,  # Lower-precision dtype to save memory\n",
 "    \"rope_freq\": {  # NEW: RoPE frequency scaling\n",
 "        \"factor\": 8.0,\n",
@@ -2421,7 +2421,7 @@
 "    \"n_layers\": 32,  # Number of layers\n",
 "    \"hidden_dim\": 14_336,  # Size of the intermediate dimension in FeedForward\n",
 "    \"n_kv_groups\": 8,  # Key-Value groups for grouped-query attention\n",
-"    \"rope_base\": 50_000,  # The base in RoPE's \"theta\"\n",
+"    \"rope_base\": 500_000,  # The base in RoPE's \"theta\"\n",
 "    \"dtype\": torch.bfloat16,  # Lower-precision dtype to save memory\n",
 "    \"rope_freq\": {  # NEW: RoPE frequency scaling\n",
 "        \"factor\": 8.0,\n",
@@ -2440,7 +2440,7 @@
 "    \"n_layers\": 16,  # NEW: Half the number of layers\n",
 "    \"hidden_dim\": 8192,  # NEW: Almost half the size of the intermediate dimension in FeedForward\n",
 "    \"n_kv_groups\": 8,  # Key-Value groups for grouped-query attention\n",
-"    \"rope_base\": 50_000,  # The base in RoPE's \"theta\"\n",
+"    \"rope_base\": 500_000,  # The base in RoPE's \"theta\"\n",
 "    \"dtype\": torch.bfloat16,  # Lower-precision dtype to save memory\n",
 "    \"rope_freq\": {  # RoPE frequency scaling\n",
 "        \"factor\": 32.0,  # NEW: Adjustment of the rescaling factor\n",

File 2 of 3

@@ -129,7 +129,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"def precompute_rope_params(head_dim, theta_base=10000, context_length=4096, freq_config=None):\n",
+"def precompute_rope_params(head_dim, theta_base=10_000, context_length=4096, freq_config=None):\n",
 "    assert head_dim % 2 == 0, \"Embedding dimension must be even\"\n",
 "\n",
 "    # Compute the inverse frequencies\n",
@@ -407,7 +407,7 @@
 "    \"n_layers\": 16,  # Number of layers\n",
 "    \"hidden_dim\": 8192,  # Size of the intermediate dimension in FeedForward\n",
 "    \"n_kv_groups\": 8,  # Key-Value groups for grouped-query attention\n",
-"    \"rope_base\": 50_000,  # The base in RoPE's \"theta\"\n",
+"    \"rope_base\": 500_000,  # The base in RoPE's \"theta\"\n",
 "    \"dtype\": torch.bfloat16,  # Lower-precision dtype to save memory\n",
 "    \"rope_freq\": {  # RoPE frequency scaling\n",
 "        \"factor\": 32.0,\n",
@@ -427,7 +427,7 @@
 "# \"n_layers\": 28,  # Number of layers\n",
 "# \"hidden_dim\": 8192,  # Size of the intermediate dimension in FeedForward\n",
 "# \"n_kv_groups\": 8,  # Key-Value groups for grouped-query attention\n",
-"# \"rope_base\": 50_000,  # The base in RoPE's \"theta\"\n",
+"# \"rope_base\": 500_000,  # The base in RoPE's \"theta\"\n",
 "# \"dtype\": torch.bfloat16,  # Lower-precision dtype to save memory\n",
 "# \"rope_freq\": {  # RoPE frequency scaling\n",
 "# \"factor\": 32.0,\n",

File 3 of 3

@@ -111,7 +111,7 @@ def test_rope_llama3(notebook):
     context_len = 8192
     num_heads = 4
     head_dim = 16
-    theta_base = 50_000
+    theta_base = 500_000

     # Instantiate RoPE parameters
     cos, sin = nb2.precompute_rope_params(
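
The call is truncated by the hunk boundary; it presumably continues with keyword arguments matching the signature shown earlier, roughly:

cos, sin = nb2.precompute_rope_params(
    head_dim=head_dim,
    theta_base=theta_base,
    context_length=context_len,
)
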
@@ -155,7 +155,7 @@ def test_rope_llama3_12(notebook):
     context_len = 8192
     num_heads = 4
     head_dim = 16
-    rope_theta = 50_000
+    rope_theta = 500_000

     rope_config = {
         "factor": 8.0,
@@ -194,7 +194,7 @@ def test_rope_llama3_12(notebook):
         rope_scaling = hf_rope_params
         factor = 1.0
         dim: int = head_dim
-        rope_theta = 50_000
+        rope_theta = 500_000
         max_position_embeddings: int = 8192
         hidden_size = head_dim * num_heads
         num_attention_heads = num_heads