mirror of https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-09-21 14:14:19 +00:00

Introduce buffers to improve Llama 3.2 efficiency (#389)

* Introduce buffers to improve Llama 3.2 efficiency
* update
* update

This commit is contained in:
parent a0c0c765a8
commit 1eb0b3810a
@@ -430,6 +430,14 @@
 "- In short, the main change in GQA is that each query group needs to be repeated to match the number of heads it is associated with, as implemented below"
 ]
 },
+{
+"cell_type": "markdown",
+"id": "842aa71a-4659-424e-8830-392bd6ae86af",
+"metadata": {},
+"source": [
+"- In addition, we also introduce a `SharedBuffers` class that will allow us to reuse the `mask`, `cos`, and `sin` tensors in the transformer blocks to improve efficiency (this will be crucial when working with models such as Llama 3.1 and 3.2 later, which support up to 131k input tokens)"
+]
+},
 {
 "cell_type": "code",
 "execution_count": 8,
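To make the new cell's mechanism concrete before the code hunks below: here is a minimal, self-contained sketch of the caching pattern (the RoPE precomputation is stubbed out with a placeholder `precompute_rope_params_stub`; only the reuse logic mirrors the committed code):

```python
import torch

def precompute_rope_params_stub(head_dim, theta_base, context_length, freq_config):
    # Stand-in for the notebook's real RoPE precomputation
    return torch.zeros(context_length, head_dim), torch.zeros(context_length, head_dim)

class SharedBuffers:
    _buffers = {}  # class-level cache shared by all attention layers

    @staticmethod
    def get_buffers(context_length, head_dim, rope_base, freq_config, dtype=torch.float32):
        # dicts are unhashable, so freq_config is flattened into a tuple for the key
        key = (context_length, head_dim, rope_base,
               tuple(freq_config.values()) if freq_config else freq_config, dtype)
        if key not in SharedBuffers._buffers:
            mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
            cos, sin = precompute_rope_params_stub(head_dim, rope_base, context_length, freq_config)
            SharedBuffers._buffers[key] = (mask, cos, sin)
        return SharedBuffers._buffers[key]

# Identical arguments return the very same tensor objects, not copies:
m1, c1, s1 = SharedBuffers.get_buffers(8192, 64, 500_000, None)
m2, c2, s2 = SharedBuffers.get_buffers(8192, 64, 500_000, None)
assert m1 is m2 and c1 is c2 and s1 is s2
```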
@@ -441,6 +449,28 @@
 "source": [
 "import torch.nn as nn\n",
 "\n",
+"\n",
+"############################# NEW #############################\n",
+"class SharedBuffers:\n",
+"    _buffers = {}\n",
+"\n",
+"    @staticmethod\n",
+"    def get_buffers(context_length, head_dim, rope_base, freq_config, dtype=torch.float32):\n",
+"        key = (context_length, head_dim, rope_base, tuple(freq_config.values()) if freq_config else freq_config, dtype)\n",
+"\n",
+"        if key not in SharedBuffers._buffers:\n",
+"            # Create or fetch the buffers\n",
+"            mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)\n",
+"            cos, sin = precompute_rope_params(head_dim, rope_base, context_length, freq_config)\n",
+"            if dtype is not None:\n",
+"                cos = cos.to(dtype)\n",
+"                sin = sin.to(dtype)\n",
+"            SharedBuffers._buffers[key] = (mask, cos, sin)\n",
+"\n",
+"        return SharedBuffers._buffers[key]\n",
+"############################# NEW #############################\n",
+"\n",
+"\n",
 "class GroupedQueryAttention(nn.Module):\n",
 "    def __init__(\n",
 "        self, d_in, d_out, context_length, num_heads,\n",
@@ -469,13 +499,12 @@
 "        self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype)\n",
 "        self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype)\n",
 "\n",
-"        self.register_buffer(\"mask\", torch.triu(torch.ones(context_length, context_length), diagonal=1))\n",
-"        cos, sin = precompute_rope_params(\n",
-"            head_dim=self.head_dim,\n",
-"            theta_base=rope_base,  # NEW\n",
-"            freq_config=rope_config,  # NEW\n",
-"            context_length=8192\n",
-"        )\n",
+"        ############################# NEW #############################\n",
+"        # Fetch buffers using SharedBuffers\n",
+"        mask, cos, sin = SharedBuffers.get_buffers(context_length, self.head_dim, rope_base, rope_config, dtype)\n",
+"        ############################# NEW #############################\n",
+"        \n",
+"        self.register_buffer(\"mask\", mask)\n",
 "        self.register_buffer(\"cos\", cos)\n",
 "        self.register_buffer(\"sin\", sin)\n",
 "\n",
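This refactor works because `register_buffer` stores a reference to the tensor it receives rather than a copy, so every attention layer that registers the same `mask` points at a single allocation. A standalone illustration of that PyTorch behavior (the `Demo` module is hypothetical, not part of the commit):

```python
import torch
import torch.nn as nn

class Demo(nn.Module):
    def __init__(self, mask):
        super().__init__()
        self.register_buffer("mask", mask)  # stores a reference, not a copy

shared = torch.triu(torch.ones(4, 4), diagonal=1)
a, b = Demo(shared), Demo(shared)
print(a.mask is b.mask)  # True: one allocation backs both modules
```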
@@ -907,6 +936,35 @@
 "model = Llama3Model(LLAMA3_CONFIG_8B)"
 ]
 },
+{
+"cell_type": "markdown",
+"id": "edea6334-d1fc-427d-9cf2-4af963ff4bfc",
+"metadata": {},
+"source": [
+"- The following is expected to print True to confirm buffers are reused instead of being (wastefully) recreated:"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "ee9625cc-9afa-4b11-8aab-d536fd170761",
+"metadata": {},
+"outputs": [],
+"source": [
+"# Check buffers\n",
+"print(model.trf_blocks[0].att.mask is model.trf_blocks[-1].att.mask)\n",
+"print(model.trf_blocks[0].att.cos is model.trf_blocks[-1].att.cos)\n",
+"print(model.trf_blocks[0].att.sin is model.trf_blocks[-1].att.sin) "
+]
+},
+{
+"cell_type": "markdown",
+"id": "8056a521-91a6-440f-8473-591409c3177b",
+"metadata": {},
+"source": [
+"- Let's now also compute the number of trainable parameters:"
+]
+},
 {
 "cell_type": "code",
 "execution_count": 18,
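The parameter-count cell itself lies outside this hunk; in PyTorch it is typically a one-liner along these lines (a generic sketch, not necessarily the notebook's exact code):

```python
total_params = sum(p.numel() for p in model.parameters())
print(f"Total number of parameters: {total_params:,}")
# mask/cos/sin never appear in .parameters(): register_buffer creates
# non-trainable state, so sharing buffers changes memory use, not the count.
```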
@@ -2009,7 +2067,7 @@
 "\n",
 "LLAMA31_CONFIG_8B = {\n",
 "    \"vocab_size\": 128_256,     # Vocabulary size\n",
-"    \"context_length\": 8192,    # Context length\n",
+"    \"context_length\": 131_072, # NEW: Larger supported context length\n",
 "    \"emb_dim\": 4096,           # Embedding dimension\n",
 "    \"n_heads\": 32,             # Number of attention heads\n",
 "    \"n_layers\": 32,            # Number of layers\n",
@@ -2026,6 +2084,24 @@
 "}"
 ]
 },
+{
+"cell_type": "markdown",
+"id": "d81ee464-c112-43b0-9ee8-70df6ac942d0",
+"metadata": {},
+"source": [
+"- Reduce the context length so the model would work fine on a MacBook Air (if you have more RAM, feel free to comment out the lines below):"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "9bdbe32f-4c96-4e60-8bf4-52b5217df1e6",
+"metadata": {},
+"outputs": [],
+"source": [
+"LLAMA31_CONFIG_8B[\"context_length\"] = 8192"
+]
+},
 {
 "cell_type": "markdown",
 "id": "xa3bpMDtTdBs",
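The reduction is not cosmetic: the causal mask is a `context_length × context_length` matrix, so its footprint grows quadratically. A quick float32 back-of-the-envelope check:

```python
def mask_gigabytes(ctx, bytes_per_elem=4):
    # The causal mask is a (ctx, ctx) matrix
    return ctx * ctx * bytes_per_elem / 1024**3

print(f"{mask_gigabytes(8192):.2f} GB")     # 0.25 GB per copy
print(f"{mask_gigabytes(131_072):.2f} GB")  # 64.00 GB per copy
```

At the full 131k context even a single shared copy is impractical on consumer hardware, which is why the notebook both shares the buffers across layers and dials the working context back to 8192 before the model is constructed (the buffers are sized at construction time, so the order matters).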
@@ -2339,7 +2415,7 @@
 "source": [
 "LLAMA31_CONFIG_8B = {\n",
 "    \"vocab_size\": 128_256,     # Vocabulary size\n",
-"    \"context_length\": 8192,    # Context length\n",
+"    \"context_length\": 131_072, # NEW: Larger supported context length\n",
 "    \"emb_dim\": 4096,           # Embedding dimension\n",
 "    \"n_heads\": 32,             # Number of attention heads\n",
 "    \"n_layers\": 32,            # Number of layers\n",
@@ -2347,7 +2423,7 @@
 "    \"n_kv_groups\": 8,          # Key-Value groups for grouped-query attention\n",
 "    \"rope_base\": 50_000,       # The base in RoPE's \"theta\"\n",
 "    \"dtype\": torch.bfloat16,   # Lower-precision dtype to save memory\n",
-"    \"rope_freq\": {             # RoPE frequency scaling\n",
+"    \"rope_freq\": {             # NEW: RoPE frequency scaling\n",
 "        \"factor\": 8.0,\n",
 "        \"low_freq_factor\": 1.0,\n",
 "        \"high_freq_factor\": 4.0,\n",
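These `rope_freq` values feed Llama 3.1's frequency-smoothing scheme inside `precompute_rope_params`, which this diff does not show. For orientation, the published scaling rule looks roughly like the sketch below (based on Meta's reference implementation; treating `original_context_length = 8192` as an assumption of this sketch):

```python
import torch

def rescale_rope_freqs(inv_freq, factor=8.0, low_freq_factor=1.0,
                       high_freq_factor=4.0, original_context_length=8192):
    wavelen = 2 * torch.pi / inv_freq  # wavelength of each rotary frequency
    low_freq_wavelen = original_context_length / low_freq_factor
    high_freq_wavelen = original_context_length / high_freq_factor

    # Long wavelengths (low frequencies) are slowed down by `factor`
    adjusted = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)

    # Medium wavelengths interpolate smoothly between the two regimes
    smooth = (original_context_length / wavelen - low_freq_factor) / (
        high_freq_factor - low_freq_factor
    )
    smoothed = (1 - smooth) / factor * inv_freq + smooth * inv_freq
    is_medium = (wavelen <= low_freq_wavelen) & (wavelen >= high_freq_wavelen)
    return torch.where(is_medium, smoothed, adjusted)

head_dim = 64
inv_freq = 1.0 / (500_000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
print(rescale_rope_freqs(inv_freq)[:4])  # high-frequency entries are untouched
```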
@@ -2358,7 +2434,7 @@
 "\n",
 "LLAMA32_CONFIG_1B = {\n",
 "    \"vocab_size\": 128_256,     # Vocabulary size\n",
-"    \"context_length\": 8192,    # Context length\n",
+"    \"context_length\": 131_072, # Context length\n",
 "    \"emb_dim\": 2048,           # NEW: Half the embedding dimension\n",
 "    \"n_heads\": 32,             # Number of attention heads\n",
 "    \"n_layers\": 16,            # NEW: Half the number of layers\n",
@@ -2375,6 +2451,24 @@
 "}"
 ]
 },
+{
+"cell_type": "markdown",
+"id": "b5cd351b-d883-460d-9cdc-47e15ddb884a",
+"metadata": {},
+"source": [
+"- Reduce the context length so the model would work fine on a MacBook Air (if you have more RAM, feel free to comment out the lines below):"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "387456c3-c6a1-46fe-8830-6e00eb46ac13",
+"metadata": {},
+"outputs": [],
+"source": [
+"LLAMA32_CONFIG_1B[\"context_length\"] = 8192"
+]
+},
 {
 "cell_type": "markdown",
 "id": "Dl4_0EoJKKYv",
@@ -2593,7 +2687,7 @@
 "provenance": []
 },
 "kernelspec": {
-"display_name": "base",
+"display_name": "Python 3 (ipykernel)",
 "language": "python",
 "name": "python3"
 },
@@ -2607,7 +2701,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.10.11"
+"version": "3.11.4"
 },
 "widgets": {
 "application/vnd.jupyter.widget-state+json": {
(The hunks below belong to the second notebook touched by this commit; its filename is not preserved in this capture.)

@@ -69,9 +69,9 @@
 "output_type": "stream",
 "text": [
 "blobfile version: 3.0.0\n",
-"huggingface_hub version: 0.25.0\n",
+"huggingface_hub version: 0.25.1\n",
 "tiktoken version: 0.7.0\n",
-"torch version: 2.5.0.dev20240812+cu121\n"
+"torch version: 2.4.0\n"
 ]
 }
 ],
@@ -201,6 +201,25 @@
 "metadata": {},
 "outputs": [],
 "source": [
+"class SharedBuffers:\n",
+"    _buffers = {}\n",
+"\n",
+"    @staticmethod\n",
+"    def get_buffers(context_length, head_dim, rope_base, freq_config, dtype=torch.float32):\n",
+"        key = (context_length, head_dim, rope_base, tuple(freq_config.values()) if freq_config else freq_config, dtype)\n",
+"\n",
+"        if key not in SharedBuffers._buffers:\n",
+"            # Create or fetch the buffers\n",
+"            mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)\n",
+"            cos, sin = precompute_rope_params(head_dim, rope_base, context_length, freq_config)\n",
+"            if dtype is not None:\n",
+"                cos = cos.to(dtype)\n",
+"                sin = sin.to(dtype)\n",
+"            SharedBuffers._buffers[key] = (mask, cos, sin)\n",
+"\n",
+"        return SharedBuffers._buffers[key]\n",
+"\n",
+"\n",
 "class GroupedQueryAttention(nn.Module):\n",
 "    def __init__(\n",
 "        self, d_in, d_out, context_length, num_heads,\n",
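One caveat of this module-level cache that the notebook does not mention: entries in `SharedBuffers._buffers` live for the life of the Python process. Re-running cells with different settings adds new entries rather than replacing old ones, so if memory gets tight the cache can be cleared manually (a usage note, not part of the commit):

```python
SharedBuffers._buffers.clear()  # drop all cached (mask, cos, sin) tuples
```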
@@ -225,13 +244,10 @@
 "        self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype)\n",
 "        self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype)\n",
 "\n",
-"        self.register_buffer(\"mask\", torch.triu(torch.ones(context_length, context_length), diagonal=1))\n",
-"        cos, sin = precompute_rope_params(\n",
-"            head_dim=self.head_dim,\n",
-"            theta_base=rope_base,\n",
-"            freq_config=rope_config,\n",
-"            context_length=8192\n",
-"        )\n",
+"        # Fetch buffers using SharedBuffers\n",
+"        mask, cos, sin = SharedBuffers.get_buffers(context_length, self.head_dim, rope_base, rope_config, dtype)\n",
+"        self.register_buffer(\"mask\", mask)\n",
+"\n",
 "        self.register_buffer(\"cos\", cos)\n",
 "        self.register_buffer(\"sin\", sin)\n",
 "\n",
@@ -385,7 +401,7 @@
 "\n",
 "LLAMA32_CONFIG = {\n",
 "    \"vocab_size\": 128_256,     # Vocabulary size\n",
-"    \"context_length\": 8192,    # Context length\n",
+"    \"context_length\": 131_072, # Context length\n",
 "    \"emb_dim\": 2048,           # Embedding dimension\n",
 "    \"n_heads\": 32,             # Number of attention heads\n",
 "    \"n_layers\": 16,            # Number of layers\n",
@@ -405,7 +421,7 @@
 "\n",
 "# LLAMA32_CONFIG = {\n",
 "#     \"vocab_size\": 128_256,     # Vocabulary size\n",
-"#     \"context_length\": 8192,    # Context length\n",
+"#     \"context_length\": 131_000, # Context length\n",
 "#     \"emb_dim\": 3072,           # Embedding dimension\n",
 "#     \"n_heads\": 24,             # Number of attention heads\n",
 "#     \"n_layers\": 28,            # Number of layers\n",
@@ -424,9 +440,27 @@
 "LLAMA_SIZE_STR = \"1B\" if LLAMA32_CONFIG[\"emb_dim\"] == 2048 else \"3B\""
 ]
 },
+{
+"cell_type": "markdown",
+"id": "34535172-797e-4dd0-84fb-65bc75ad5b06",
+"metadata": {},
+"source": [
+"- Reduce the context length so the model would work fine on a MacBook Air (if you have more RAM, feel free to comment out the lines below):"
+]
+},
 {
 "cell_type": "code",
 "execution_count": 9,
+"id": "a8bc2370-39d2-4bfe-b4c1-6bdd75fe101c",
+"metadata": {},
+"outputs": [],
+"source": [
+"LLAMA32_CONFIG[\"context_length\"] = 8192"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 10,
 "id": "156253fe-aacd-4da2-8f13-705f05c4b11e",
 "metadata": {},
 "outputs": [],
@@ -434,9 +468,40 @@
 "model = Llama3Model(LLAMA32_CONFIG)"
 ]
 },
+{
+"cell_type": "markdown",
+"id": "19de6c2c-83ce-456d-8be9-6ec415fe9eb1",
+"metadata": {},
+"source": [
+"- The following is expected to print True to confirm buffers are reused instead of being (wastefully) recreated:"
+]
+},
 {
 "cell_type": "code",
-"execution_count": 10,
+"execution_count": 11,
+"id": "0e95db6d-2712-41a5-a5e0-86c49897f4cf",
+"metadata": {},
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"True\n",
+"True\n",
+"True\n"
+]
+}
+],
+"source": [
+"# Check buffers\n",
+"print(model.trf_blocks[0].att.mask is model.trf_blocks[-1].att.mask)\n",
+"print(model.trf_blocks[0].att.cos is model.trf_blocks[-1].att.cos)\n",
+"print(model.trf_blocks[0].att.sin is model.trf_blocks[-1].att.sin) "
+]
+},
+{
+"cell_type": "code",
+"execution_count": 12,
 "id": "364e76ca-52f8-4fa5-af37-c4069f9694bc",
 "metadata": {},
 "outputs": [
@@ -461,7 +526,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 11,
+"execution_count": 13,
 "id": "fd5efb03-5a07-46e8-8607-93ed47549d2b",
 "metadata": {},
 "outputs": [
@@ -469,8 +534,8 @@
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"float32 (PyTorch default): 15.23 GB\n",
-"bfloat16: 7.61 GB\n"
+"float32 (PyTorch default): 11.42 GB\n",
+"bfloat16: 5.71 GB\n"
 ]
 }
 ],
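The new, lower numbers are consistent with dropping 15 redundant buffer copies (16 layers, all but one now sharing). A rough check under the 1B config's shapes, assuming context 8192, head_dim = 2048/32 = 64, float32 buffers, and the notebook reporting GiB as "GB":

```python
ctx, head_dim, n_layers = 8192, 64, 16

mask_bytes = ctx * ctx * 4               # causal mask per copy
cos_sin_bytes = 2 * ctx * head_dim * 4   # cos and sin per copy

saved = (n_layers - 1) * (mask_bytes + cos_sin_bytes) / 1024**3
print(f"{saved:.2f} GB")  # ~3.81 GB, matching 15.23 - 11.42
```

The bfloat16 column shrinks by about half as much (7.61 - 5.71 = 1.90 GB), consistent with 2-byte elements.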
@@ -505,7 +570,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 12,
+"execution_count": 14,
 "id": "31f12baf-f79b-499f-85c0-51328a6a20f5",
 "metadata": {},
 "outputs": [],
@@ -531,7 +596,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 13,
+"execution_count": 15,
 "id": "9482b01c-49f9-48e4-ab2c-4a4c75240e77",
 "metadata": {},
 "outputs": [],
@@ -631,14 +696,14 @@
 },
 {
 "cell_type": "code",
-"execution_count": 14,
+"execution_count": 16,
 "id": "e9d96dc8-603a-4cb5-8c3e-4d2ca56862ed",
 "metadata": {},
 "outputs": [
 {
 "data": {
 "application/vnd.jupyter.widget-view+json": {
-"model_id": "fede18d637d24f79a27220fb83bc6d2b",
+"model_id": "8cdf801700d64fe9b2b827172a8eebcf",
 "version_major": 2,
 "version_minor": 0
 },
@@ -658,7 +723,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 15,
+"execution_count": 17,
 "id": "986bc1a0-804f-4154-80f8-44cefbee1368",
 "metadata": {},
 "outputs": [],
@@ -674,7 +739,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 16,
+"execution_count": 18,
 "id": "f5a3014f-4c66-4fe2-874e-7b57562c49ad",
 "metadata": {},
 "outputs": [],
@@ -694,7 +759,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 17,
+"execution_count": 19,
 "id": "75166128-5899-4995-9b88-9672e135650e",
 "metadata": {},
 "outputs": [],
@@ -775,7 +840,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 18,
+"execution_count": 20,
 "id": "699cb1b8-a67d-49fb-80a6-0dad9d81f392",
 "metadata": {},
 "outputs": [
@@ -818,7 +883,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 19,
+"execution_count": 21,
 "id": "7f9f7ccc-70cb-41ff-9c25-44336042fc37",
 "metadata": {},
 "outputs": [
@@ -845,7 +910,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 20,
+"execution_count": 22,
 "id": "7b8401c6-e244-4cb7-9849-2ba71ce758d5",
 "metadata": {},
 "outputs": [],
@@ -902,7 +967,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 21,
+"execution_count": 23,
 "id": "1c7a04fa-6aac-416b-8f63-f1e19227633d",
 "metadata": {},
 "outputs": [
@@ -915,10 +980,10 @@
 "\n",
 "1. Grasses: Llamas love to graze on various types of grasses, including tall grasses and short grasses.\n",
 "2. Hay: Llamas also eat hay, which is a dry, compressed form of grass or other plants.\n",
-"3. Alfalfa: Alfalfa is a legume that is commonly used as a hay substitute in llama feed.\n",
-"4. Other plants: Llamas will also eat other plants, such as clover, oats, and barley.\n",
+"3. Alfalfa: Alfalfa is a legume that is commonly fed to llamas, as it is high in protein and fiber.\n",
+"4. Other plants: Llamas will also eat other plants, such as clover, wild grasses, and shrubs.\n",
 "\n",
-"It's worth noting that llamas are adapted to high-altitude environments and can survive on low-quality hay and\n"
+"It's worth noting that llamas are adapted to high altitudes and\n"
 ]
 }
 ],
@@ -982,7 +1047,7 @@
 ],
 "metadata": {
 "kernelspec": {
-"display_name": "base",
+"display_name": "Python 3 (ipykernel)",
 "language": "python",
 "name": "python3"
 },
@@ -996,7 +1061,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.10.11"
+"version": "3.11.4"
 }
 },
 "nbformat": 4,