Introduce buffers to improve Llama 3.2 efficiency (#389)

* Introduce buffers to improve Llama 3.2 efficiency

* update

* update
Sebastian Raschka 2024-10-06 12:49:04 -05:00 committed by GitHub
parent a0c0c765a8
commit 1eb0b3810a
2 changed files with 248 additions and 89 deletions

View File

@@ -430,6 +430,14 @@
"- In short, the main change in GQA is that each query group needs to be repeated to match the number of heads it is associated with, as implemented below"
]
},
{
"cell_type": "markdown",
"id": "842aa71a-4659-424e-8830-392bd6ae86af",
"metadata": {},
"source": [
"- In addition, we also introduce a `SharedBuffers` class that will allow us to reuse the `mask`, `cos`, and `sin` tensors in the transformer blocks to improve efficiency (this will be crucial when working with models such as Llama 3.1 and 3.2 later, which support up to 131k input tokens)"
]
},
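To get a feel for why reusing these tensors matters, here is a quick back-of-the-envelope calculation (mine, not the notebook's); it assumes the float32 mask that `torch.ones` creates by default and the 16 transformer blocks of the 1B model:

    # Rough arithmetic: cost of every block holding its own causal mask
    context_length, n_layers = 8192, 16
    mask_bytes = context_length**2 * 4                                  # float32 mask
    print(f"one shared mask: {mask_bytes / 1e9:.2f} GB")                # ~0.27 GB
    print(f"one mask per block: {n_layers * mask_bytes / 1e9:.2f} GB")  # ~4.29 GB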
{
"cell_type": "code",
"execution_count": 8,
@@ -441,6 +449,28 @@
"source": [
"import torch.nn as nn\n",
"\n",
"\n",
"############################# NEW #############################\n",
"class SharedBuffers:\n",
" _buffers = {}\n",
"\n",
" @staticmethod\n",
" def get_buffers(context_length, head_dim, rope_base, freq_config, dtype=torch.float32):\n",
" key = (context_length, head_dim, rope_base, tuple(freq_config.values()) if freq_config else freq_config, dtype)\n",
"\n",
" if key not in SharedBuffers._buffers:\n",
" # Create or fetch the buffers\n",
" mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)\n",
" cos, sin = precompute_rope_params(head_dim, rope_base, context_length, freq_config)\n",
" if dtype is not None:\n",
" cos = cos.to(dtype)\n",
" sin = sin.to(dtype)\n",
" SharedBuffers._buffers[key] = (mask, cos, sin)\n",
"\n",
" return SharedBuffers._buffers[key]\n",
"############################# NEW #############################\n",
"\n",
"\n",
"class GroupedQueryAttention(nn.Module):\n",
" def __init__(\n",
" self, d_in, d_out, context_length, num_heads,\n",
@@ -469,13 +499,12 @@
" self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype)\n",
" self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype)\n",
"\n",
" self.register_buffer(\"mask\", torch.triu(torch.ones(context_length, context_length), diagonal=1))\n",
" cos, sin = precompute_rope_params(\n",
" head_dim=self.head_dim,\n",
" theta_base=rope_base, # NEW\n",
" freq_config=rope_config, # NEW\n",
" context_length=8192\n",
" )\n",
" ############################# NEW #############################\n",
" # Fetch buffers using SharedBuffers\n",
" mask, cos, sin = SharedBuffers.get_buffers(context_length, self.head_dim, rope_base, rope_config, dtype)\n",
" ############################# NEW #############################\n",
" \n",
" self.register_buffer(\"mask\", mask)\n",
" self.register_buffer(\"cos\", cos)\n",
" self.register_buffer(\"sin\", sin)\n",
"\n",
@@ -907,6 +936,35 @@
"model = Llama3Model(LLAMA3_CONFIG_8B)"
]
},
{
"cell_type": "markdown",
"id": "edea6334-d1fc-427d-9cf2-4af963ff4bfc",
"metadata": {},
"source": [
"- The following is expected to print True to confirm buffers are reused instead of being (wastefully) recreated:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ee9625cc-9afa-4b11-8aab-d536fd170761",
"metadata": {},
"outputs": [],
"source": [
"# Check buffers\n",
"print(model.trf_blocks[0].att.mask is model.trf_blocks[-1].att.mask)\n",
"print(model.trf_blocks[0].att.cos is model.trf_blocks[-1].att.cos)\n",
"print(model.trf_blocks[0].att.sin is model.trf_blocks[-1].att.sin) "
]
},
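Because `SharedBuffers.get_buffers` caches on its full argument tuple, identical calls return the very same tensor objects, while any differing argument creates a fresh cache entry. A small sketch of this behavior (the argument values mirror the 8B config above; `precompute_rope_params` must be in scope):

    mask_a, cos_a, sin_a = SharedBuffers.get_buffers(8192, 128, 50_000, None)
    mask_b, cos_b, sin_b = SharedBuffers.get_buffers(8192, 128, 50_000, None)
    print(mask_a is mask_b)   # True  -- same cache entry
    mask_c, _, _ = SharedBuffers.get_buffers(8192, 128, 50_000, None, torch.bfloat16)
    print(mask_a is mask_c)   # False -- different dtype, separate entry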
{
"cell_type": "markdown",
"id": "8056a521-91a6-440f-8473-591409c3177b",
"metadata": {},
"source": [
"- Let's now also compute the number of trainable parameters:"
]
},
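The parameter-count cell itself is truncated in this diff; the standard PyTorch idiom for such a count looks roughly like this sketch (not necessarily the notebook's exact code):

    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total number of trainable parameters: {total_params:,}")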
{
"cell_type": "code",
"execution_count": 18,
@@ -2008,16 +2066,16 @@
"}\n",
"\n",
"LLAMA31_CONFIG_8B = {\n",
" \"vocab_size\": 128_256, # Vocabulary size\n",
" \"context_length\": 8192, # Context length\n",
" \"emb_dim\": 4096, # Embedding dimension\n",
" \"n_heads\": 32, # Number of attention heads\n",
" \"n_layers\": 32, # Number of layers\n",
" \"hidden_dim\": 14_336, # Size of the intermediate dimension in FeedForward\n",
" \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n",
" \"rope_base\": 50_000, # The base in RoPE's \"theta\"\n",
" \"dtype\": torch.bfloat16, # Lower-precision dtype to save memory\n",
" \"rope_freq\": { # NEW: RoPE frequency scaling\n",
" \"vocab_size\": 128_256, # Vocabulary size\n",
" \"context_length\": 131_072, # NEW: Larger supported context length\n",
" \"emb_dim\": 4096, # Embedding dimension\n",
" \"n_heads\": 32, # Number of attention heads\n",
" \"n_layers\": 32, # Number of layers\n",
" \"hidden_dim\": 14_336, # Size of the intermediate dimension in FeedForward\n",
" \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n",
" \"rope_base\": 50_000, # The base in RoPE's \"theta\"\n",
" \"dtype\": torch.bfloat16, # Lower-precision dtype to save memory\n",
" \"rope_freq\": { # NEW: RoPE frequency scaling\n",
" \"factor\": 8.0,\n",
" \"low_freq_factor\": 1.0,\n",
" \"high_freq_factor\": 4.0,\n",
@@ -2026,6 +2084,24 @@
"}"
]
},
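The diff does not show how `precompute_rope_params` consumes these `rope_freq` settings, so here is a minimal sketch of the Llama 3.1-style frequency rescaling they parameterize (function and variable names are illustrative, not the notebook's):

    import torch

    def rescale_inv_freq(inv_freq, factor=8.0, low_freq_factor=1.0,
                         high_freq_factor=4.0, original_context_length=8192):
        wavelen = 2 * torch.pi / inv_freq  # wavelength of each frequency component
        low_freq_wavelen = original_context_length / low_freq_factor
        high_freq_wavelen = original_context_length / high_freq_factor
        # Long-wavelength (low-frequency) components are slowed down by `factor`
        adjusted = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
        # Medium wavelengths are blended smoothly between the two regimes
        smooth = (original_context_length / wavelen - low_freq_factor) / (
            high_freq_factor - low_freq_factor)
        smoothed = (1 - smooth) * (inv_freq / factor) + smooth * inv_freq
        medium = (wavelen <= low_freq_wavelen) & (wavelen >= high_freq_wavelen)
        return torch.where(medium, smoothed, adjusted)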
{
"cell_type": "markdown",
"id": "d81ee464-c112-43b0-9ee8-70df6ac942d0",
"metadata": {},
"source": [
"- Reduce the context length so the model would work fine on a MacBook Air (if you have more RAM, feel free to comment out the lines below):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9bdbe32f-4c96-4e60-8bf4-52b5217df1e6",
"metadata": {},
"outputs": [],
"source": [
"LLAMA32_CONFIG[\"context_length\"] = 8192"
]
},
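A quick calculation (again mine, not the notebook's) shows why this helps: the causal mask alone grows quadratically with context length, assuming the float32 mask that `torch.ones` produces:

    for ctx in (131_072, 8192):
        print(f"context {ctx:>7,}: mask alone needs {ctx**2 * 4 / 1e9:,.2f} GB")
    # context 131,072: mask alone needs 68.72 GB  -> far beyond laptop RAM
    # context   8,192: mask alone needs  0.27 GB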
{
"cell_type": "markdown",
"id": "xa3bpMDtTdBs",
@@ -2338,16 +2414,16 @@
"outputs": [],
"source": [
"LLAMA31_CONFIG_8B = {\n",
" \"vocab_size\": 128_256, # Vocabulary size\n",
" \"context_length\": 8192, # Context length\n",
" \"emb_dim\": 4096, # Embedding dimension\n",
" \"n_heads\": 32, # Number of attention heads\n",
" \"n_layers\": 32, # Number of layers\n",
" \"hidden_dim\": 14_336, # Size of the intermediate dimension in FeedForward\n",
" \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n",
" \"rope_base\": 50_000, # The base in RoPE's \"theta\"\n",
" \"dtype\": torch.bfloat16, # Lower-precision dtype to save memory\n",
" \"rope_freq\": { # RoPE frequency scaling\n",
" \"vocab_size\": 128_256, # Vocabulary size\n",
" \"context_length\": 131_072, # NEW: Larger supported context length\n",
" \"emb_dim\": 4096, # Embedding dimension\n",
" \"n_heads\": 32, # Number of attention heads\n",
" \"n_layers\": 32, # Number of layers\n",
" \"hidden_dim\": 14_336, # Size of the intermediate dimension in FeedForward\n",
" \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n",
" \"rope_base\": 50_000, # The base in RoPE's \"theta\"\n",
" \"dtype\": torch.bfloat16, # Lower-precision dtype to save memory\n",
" \"rope_freq\": { # NEW: RoPE frequency scaling\n",
" \"factor\": 8.0,\n",
" \"low_freq_factor\": 1.0,\n",
" \"high_freq_factor\": 4.0,\n",
@@ -2357,17 +2433,17 @@
"\n",
"\n",
"LLAMA32_CONFIG_1B = {\n",
" \"vocab_size\": 128_256, # Vocabulary size\n",
" \"context_length\": 8192, # Context length\n",
" \"emb_dim\": 2048, # NEW: Half the embedding dimension\n",
" \"n_heads\": 32, # Number of attention heads\n",
" \"n_layers\": 16, # NEW: Half the number of layers\n",
" \"hidden_dim\": 8192, # NEW: Almost half the size of the intermediate dimension in FeedForward\n",
" \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n",
" \"rope_base\": 50_000, # The base in RoPE's \"theta\"\n",
" \"dtype\": torch.bfloat16, # Lower-precision dtype to save memory\n",
" \"rope_freq\": { # RoPE frequency scaling\n",
" \"factor\": 32.0, # NEW: Adjustment of the rescaling factor\n",
" \"vocab_size\": 128_256, # Vocabulary size\n",
" \"context_length\": 131_072, # Context length\n",
" \"emb_dim\": 2048, # NEW: Half the embedding dimension\n",
" \"n_heads\": 32, # Number of attention heads\n",
" \"n_layers\": 16, # NEW: Half the number of layers\n",
" \"hidden_dim\": 8192, # NEW: Almost half the size of the intermediate dimension in FeedForward\n",
" \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n",
" \"rope_base\": 50_000, # The base in RoPE's \"theta\"\n",
" \"dtype\": torch.bfloat16, # Lower-precision dtype to save memory\n",
" \"rope_freq\": { # RoPE frequency scaling\n",
" \"factor\": 32.0, # NEW: Adjustment of the rescaling factor\n",
" \"low_freq_factor\": 1.0,\n",
" \"high_freq_factor\": 4.0,\n",
" \"original_context_length\": 8192,\n",
@@ -2375,6 +2451,24 @@
"}"
]
},
{
"cell_type": "markdown",
"id": "b5cd351b-d883-460d-9cdc-47e15ddb884a",
"metadata": {},
"source": [
"- Reduce the context length so the model would work fine on a MacBook Air (if you have more RAM, feel free to comment out the lines below):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "387456c3-c6a1-46fe-8830-6e00eb46ac13",
"metadata": {},
"outputs": [],
"source": [
"LLAMA32_CONFIG[\"context_length\"] = 8192"
]
},
{
"cell_type": "markdown",
"id": "Dl4_0EoJKKYv",
@@ -2593,7 +2687,7 @@
"provenance": []
},
"kernelspec": {
"display_name": "base",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
@@ -2607,7 +2701,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
"version": "3.11.4"
},
"widgets": {
"application/vnd.jupyter.widget-state+json": {

View File

@@ -69,9 +69,9 @@
"output_type": "stream",
"text": [
"blobfile version: 3.0.0\n",
"huggingface_hub version: 0.25.0\n",
"huggingface_hub version: 0.25.1\n",
"tiktoken version: 0.7.0\n",
"torch version: 2.5.0.dev20240812+cu121\n"
"torch version: 2.4.0\n"
]
}
],
@@ -201,6 +201,25 @@
"metadata": {},
"outputs": [],
"source": [
"class SharedBuffers:\n",
" _buffers = {}\n",
"\n",
" @staticmethod\n",
" def get_buffers(context_length, head_dim, rope_base, freq_config, dtype=torch.float32):\n",
" key = (context_length, head_dim, rope_base, tuple(freq_config.values()) if freq_config else freq_config, dtype)\n",
"\n",
" if key not in SharedBuffers._buffers:\n",
" # Create or fetch the buffers\n",
" mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)\n",
" cos, sin = precompute_rope_params(head_dim, rope_base, context_length, freq_config)\n",
" if dtype is not None:\n",
" cos = cos.to(dtype)\n",
" sin = sin.to(dtype)\n",
" SharedBuffers._buffers[key] = (mask, cos, sin)\n",
"\n",
" return SharedBuffers._buffers[key]\n",
"\n",
"\n",
"class GroupedQueryAttention(nn.Module):\n",
" def __init__(\n",
" self, d_in, d_out, context_length, num_heads,\n",
@@ -225,13 +244,10 @@
" self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype)\n",
" self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype)\n",
"\n",
" self.register_buffer(\"mask\", torch.triu(torch.ones(context_length, context_length), diagonal=1))\n",
" cos, sin = precompute_rope_params(\n",
" head_dim=self.head_dim,\n",
" theta_base=rope_base,\n",
" freq_config=rope_config,\n",
" context_length=8192\n",
" )\n",
" # Fetch buffers using SharedBuffers\n",
" mask, cos, sin = SharedBuffers.get_buffers(context_length, self.head_dim, rope_base, rope_config, dtype)\n",
" self.register_buffer(\"mask\", mask)\n",
"\n",
" self.register_buffer(\"cos\", cos)\n",
" self.register_buffer(\"sin\", sin)\n",
"\n",
@@ -384,16 +400,16 @@
"# Llama 3.2 1B\n",
"\n",
"LLAMA32_CONFIG = {\n",
" \"vocab_size\": 128_256, # Vocabulary size\n",
" \"context_length\": 8192, # Context length\n",
" \"emb_dim\": 2048, # Embedding dimension\n",
" \"n_heads\": 32, # Number of attention heads\n",
" \"n_layers\": 16, # Number of layers\n",
" \"hidden_dim\": 8192, # Size of the intermediate dimension in FeedForward\n",
" \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n",
" \"rope_base\": 50_000, # The base in RoPE's \"theta\"\n",
" \"dtype\": torch.bfloat16, # Lower-precision dtype to save memory\n",
" \"rope_freq\": { # RoPE frequency scaling\n",
" \"vocab_size\": 128_256, # Vocabulary size\n",
" \"context_length\": 131_072, # Context length\n",
" \"emb_dim\": 2048, # Embedding dimension\n",
" \"n_heads\": 32, # Number of attention heads\n",
" \"n_layers\": 16, # Number of layers\n",
" \"hidden_dim\": 8192, # Size of the intermediate dimension in FeedForward\n",
" \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n",
" \"rope_base\": 50_000, # The base in RoPE's \"theta\"\n",
" \"dtype\": torch.bfloat16, # Lower-precision dtype to save memory\n",
" \"rope_freq\": { # RoPE frequency scaling\n",
" \"factor\": 32.0,\n",
" \"low_freq_factor\": 1.0,\n",
" \"high_freq_factor\": 4.0,\n",
@@ -404,16 +420,16 @@
"# Llama 3.2 3B\n",
"\n",
"# LLAMA32_CONFIG = {\n",
"# \"vocab_size\": 128_256, # Vocabulary size\n",
"# \"context_length\": 8192, # Context length\n",
"# \"emb_dim\": 3072, # Embedding dimension\n",
"# \"n_heads\": 24, # Number of attention heads\n",
"# \"n_layers\": 28, # Number of layers\n",
"# \"hidden_dim\": 8192, # Size of the intermediate dimension in FeedForward\n",
"# \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n",
"# \"rope_base\": 50_000, # The base in RoPE's \"theta\"\n",
"# \"dtype\": torch.bfloat16, # Lower-precision dtype to save memory\n",
"# \"rope_freq\": { # RoPE frequency scaling\n",
"# \"vocab_size\": 128_256, # Vocabulary size\n",
"# \"context_length\": 131_000, # Context length\n",
"# \"emb_dim\": 3072, # Embedding dimension\n",
"# \"n_heads\": 24, # Number of attention heads\n",
"# \"n_layers\": 28, # Number of layers\n",
"# \"hidden_dim\": 8192, # Size of the intermediate dimension in FeedForward\n",
"# \"n_kv_groups\": 8, # Key-Value groups for grouped-query attention\n",
"# \"rope_base\": 50_000, # The base in RoPE's \"theta\"\n",
"# \"dtype\": torch.bfloat16, # Lower-precision dtype to save memory\n",
"# \"rope_freq\": { # RoPE frequency scaling\n",
"# \"factor\": 32.0,\n",
"# \"low_freq_factor\": 1.0,\n",
"# \"high_freq_factor\": 4.0,\n",
@@ -424,9 +440,27 @@
"LLAMA_SIZE_STR = \"1B\" if LLAMA32_CONFIG[\"emb_dim\"] == 2048 else \"3B\""
]
},
{
"cell_type": "markdown",
"id": "34535172-797e-4dd0-84fb-65bc75ad5b06",
"metadata": {},
"source": [
"- Reduce the context length so the model would work fine on a MacBook Air (if you have more RAM, feel free to comment out the lines below):"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "a8bc2370-39d2-4bfe-b4c1-6bdd75fe101c",
"metadata": {},
"outputs": [],
"source": [
"LLAMA32_CONFIG[\"context_length\"] = 8192"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "156253fe-aacd-4da2-8f13-705f05c4b11e",
"metadata": {},
"outputs": [],
@@ -434,9 +468,40 @@
"model = Llama3Model(LLAMA32_CONFIG)"
]
},
{
"cell_type": "markdown",
"id": "19de6c2c-83ce-456d-8be9-6ec415fe9eb1",
"metadata": {},
"source": [
"- The following is expected to print True to confirm buffers are reused instead of being (wastefully) recreated:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 11,
"id": "0e95db6d-2712-41a5-a5e0-86c49897f4cf",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"True\n",
"True\n",
"True\n"
]
}
],
"source": [
"# Check buffers\n",
"print(model.trf_blocks[0].att.mask is model.trf_blocks[-1].att.mask)\n",
"print(model.trf_blocks[0].att.cos is model.trf_blocks[-1].att.cos)\n",
"print(model.trf_blocks[0].att.sin is model.trf_blocks[-1].att.sin) "
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "364e76ca-52f8-4fa5-af37-c4069f9694bc",
"metadata": {},
"outputs": [
@@ -461,7 +526,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 13,
"id": "fd5efb03-5a07-46e8-8607-93ed47549d2b",
"metadata": {},
"outputs": [
@@ -469,8 +534,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
"float32 (PyTorch default): 15.23 GB\n",
"bfloat16: 7.61 GB\n"
"float32 (PyTorch default): 11.42 GB\n",
"bfloat16: 5.71 GB\n"
]
}
],
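Most of the drop from 15.23 GB to 11.42 GB is plausibly the shared mask at work: previously each of the 16 blocks registered its own 8192 x 8192 float32 mask, whereas a memory estimate that iterates `model.buffers()` (which deduplicates shared tensors by default) now counts the single shared copy once. A rough sanity check, assuming that accounting:

    print(15 * 8192**2 * 4 / 2**30)  # 3.75 GiB, close to the observed 15.23 - 11.42 = 3.81 GB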
@@ -505,7 +570,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 14,
"id": "31f12baf-f79b-499f-85c0-51328a6a20f5",
"metadata": {},
"outputs": [],
@@ -531,7 +596,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 15,
"id": "9482b01c-49f9-48e4-ab2c-4a4c75240e77",
"metadata": {},
"outputs": [],
@@ -631,14 +696,14 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 16,
"id": "e9d96dc8-603a-4cb5-8c3e-4d2ca56862ed",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "fede18d637d24f79a27220fb83bc6d2b",
"model_id": "8cdf801700d64fe9b2b827172a8eebcf",
"version_major": 2,
"version_minor": 0
},
@@ -658,7 +723,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 17,
"id": "986bc1a0-804f-4154-80f8-44cefbee1368",
"metadata": {},
"outputs": [],
@ -674,7 +739,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 18,
"id": "f5a3014f-4c66-4fe2-874e-7b57562c49ad",
"metadata": {},
"outputs": [],
@ -694,7 +759,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 19,
"id": "75166128-5899-4995-9b88-9672e135650e",
"metadata": {},
"outputs": [],
@ -775,7 +840,7 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 20,
"id": "699cb1b8-a67d-49fb-80a6-0dad9d81f392",
"metadata": {},
"outputs": [
@ -818,7 +883,7 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 21,
"id": "7f9f7ccc-70cb-41ff-9c25-44336042fc37",
"metadata": {},
"outputs": [
@ -845,7 +910,7 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 22,
"id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5",
"metadata": {},
"outputs": [],
@ -902,7 +967,7 @@
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": 23,
"id": "1c7a04fa-6aac-416b-8f63-f1e19227633d",
"metadata": {},
"outputs": [
@ -915,10 +980,10 @@
"\n",
"1. Grasses: Llamas love to graze on various types of grasses, including tall grasses and short grasses.\n",
"2. Hay: Llamas also eat hay, which is a dry, compressed form of grass or other plants.\n",
"3. Alfalfa: Alfalfa is a legume that is commonly used as a hay substitute in llama feed.\n",
"4. Other plants: Llamas will also eat other plants, such as clover, oats, and barley.\n",
"3. Alfalfa: Alfalfa is a legume that is commonly fed to llamas, as it is high in protein and fiber.\n",
"4. Other plants: Llamas will also eat other plants, such as clover, wild grasses, and shrubs.\n",
"\n",
"It's worth noting that llamas are adapted to high-altitude environments and can survive on low-quality hay and\n"
"It's worth noting that llamas are adapted to high altitudes and\n"
]
}
],
@ -982,7 +1047,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "base",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
@ -996,7 +1061,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
"version": "3.11.4"
}
},
"nbformat": 4,