Memory efficient weight loading (#401)

* memory efficient weight loading

* remove unused code
Sebastian Raschka 2024-10-14 10:30:25 -05:00 committed by GitHub
parent 59a5c83726
commit 3d54af20f5
5 changed files with 1043 additions and 0 deletions

@ -118,6 +118,7 @@ Several folders contain optional materials as a bonus for interested readers:
- [Building a User Interface to Interact With the Pretrained LLM](ch05/06_user_interface)
- [Converting GPT to Llama](ch05/07_gpt_to_llama)
- [Llama 3.2 From Scratch](ch05/07_gpt_to_llama/standalone-llama32.ipynb)
- [Memory-efficient Model Weight Loading](ch05/08_memory_efficient_weight_loading/memory-efficient-state-dict.ipynb)
- **Chapter 6:**
- [Additional experiments finetuning different layers and using larger models](ch06/02_bonus_additional-experiments)
- [Finetuning different models on 50k IMDB movie review dataset](ch06/03_bonus_imdb-classification)

@ -0,0 +1,5 @@
# Memory-efficient Model Weight Loading
This folder contains code to illustrate how to load model weights more efficiently
- [memory-efficient-state-dict.ipynb](memory-efficient-state-dict.ipynb): contains code to load model weights via PyTorch's `load_state_dict` method more efficiently

@ -0,0 +1,866 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "1E_HhLEeYqFG"
},
"source": [
"<table style=\"width:100%\">\n",
"<tr>\n",
"<td style=\"vertical-align:middle; text-align:left;\">\n",
"<font size=\"2\">\n",
"Supplementary code for the <a href=\"http://mng.bz/orYv\">Build a Large Language Model From Scratch</a> book by <a href=\"https://sebastianraschka.com\">Sebastian Raschka</a><br>\n",
"<br>Code repository: <a href=\"https://github.com/rasbt/LLMs-from-scratch\">https://github.com/rasbt/LLMs-from-scratch</a>\n",
"</font>\n",
"</td>\n",
"<td style=\"vertical-align:middle; text-align:left;\">\n",
"<a href=\"http://mng.bz/orYv\"><img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/cover-small.webp\" width=\"100px\"></a>\n",
"</td>\n",
"</tr>\n",
"</table>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZuWudYFWYiH7"
},
"source": [
"# Memory-efficient Model Weight Loading"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qt0Qyg6ewUt6"
},
"source": [
"- This notebook provides tips for loading larger pretrained or finetuned models when GPU (or CPU) memory is limited\n",
"- Specifically, it focuses on cases where you saved the model using `torch.save(model.state_dict(), \"model.pth\")` (for example, in chapters 5-7) and want to load it in a new session later for continued pretraining or additional finetuning\n",
"- While the example uses an LLM, the methods explained in this notebook are general and apply to loading any PyTorch model, not just LLMs"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/memory-efficient-loading/memory-efficient-loading.webp\" width=\"800px\">"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "SxQzFoS-IXdY",
"outputId": "b28ebfbd-9036-4696-d95a-7f96fdf29919"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"memory_profiler version: 0.61.0\n",
"torch version: 2.4.1+cu121\n"
]
}
],
"source": [
"from importlib.metadata import version\n",
"\n",
"pkgs = [\n",
" \"torch\",\n",
"]\n",
"for p in pkgs:\n",
" print(f\"{p} version: {version(p)}\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "y47iQaQKyHap"
},
"source": [
"&nbsp;\n",
"## 1. Benchmark utilities"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nQeOEoo6yT0X"
},
"source": [
"- First, let's define some utility code to track VRAM (GPU memory)\n",
"- Later, we will also introduce a tool to track the main system RAM (CPU memory)\n",
"- The purpose of these functions will become clear when we apply them later"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "pEiqjYrVivgt"
},
"outputs": [],
"source": [
"import gc\n",
"import time\n",
"import torch\n",
"\n",
"\n",
"def start_memory_tracking():\n",
" \"\"\"Initialize GPU memory tracking.\"\"\"\n",
" if torch.cuda.is_available():\n",
" torch.cuda.reset_peak_memory_stats()\n",
" else:\n",
" print(\"This notebook is intended for CUDA GPUs but CUDA is not available.\")\n",
"\n",
"def print_memory_usage():\n",
" max_gpu_memory = torch.cuda.max_memory_allocated() / (1024 ** 3) # Convert bytes to GB\n",
" print(f\"Maximum GPU memory allocated: {max_gpu_memory:.1f} GB\")\n",
"\n",
"def cleanup():\n",
" gc.collect()\n",
" torch.cuda.empty_cache()\n",
" time.sleep(3) # some buffer time to allow memory to clear\n",
" torch.cuda.reset_peak_memory_stats()\n",
" max_memory_allocated = torch.cuda.max_memory_allocated(device) / (1024 ** 3)\n",
" print(f\"Maximum GPU memory allocated: {max_memory_allocated:.1f} GB\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "z5oJwoc-kkXs"
},
"source": [
"&nbsp;\n",
"## 2. Model setup"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "YfJE0vnMyr88"
},
"source": [
"- This code section sets up the model itself\n",
"- Here, we use the \"large\" GPT-2 model to make things more interesting (you may use the \"gpt2-small (124M)\" to lower the memory requirements and execution time of this notebook)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"id": "tMuhCYaVI0w7"
},
"outputs": [],
"source": [
"from previous_chapters import GPTModel\n",
"\n",
"\n",
"BASE_CONFIG = {\n",
" \"vocab_size\": 50257, # Vocabulary size\n",
" \"context_length\": 1024, # Context length\n",
" \"drop_rate\": 0.0, # Dropout rate\n",
" \"qkv_bias\": True # Query-key-value bias\n",
"}\n",
"\n",
"model_configs = {\n",
" \"gpt2-small (124M)\": {\"emb_dim\": 768, \"n_layers\": 12, \"n_heads\": 12},\n",
" \"gpt2-medium (355M)\": {\"emb_dim\": 1024, \"n_layers\": 24, \"n_heads\": 16},\n",
" \"gpt2-large (774M)\": {\"emb_dim\": 1280, \"n_layers\": 36, \"n_heads\": 20},\n",
" \"gpt2-xl (1558M)\": {\"emb_dim\": 1600, \"n_layers\": 48, \"n_heads\": 25},\n",
"}\n",
"\n",
"CHOOSE_MODEL = \"gpt2-xl (1558M)\"\n",
"\n",
"BASE_CONFIG.update(model_configs[CHOOSE_MODEL])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KWYoo1z5y8aX"
},
"source": [
"- Now, let's see the GPU memory functions in action:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "GK3NEA3eJv3f",
"outputId": "60573d6e-c603-45e7-8283-b1e92e2a0013"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Maximum GPU memory allocated: 6.4 GB\n"
]
}
],
"source": [
"start_memory_tracking()\n",
"\n",
"\n",
"model = GPTModel(BASE_CONFIG)\n",
"device = torch.device(\"cuda\")\n",
"model.to(device)\n",
"\n",
"print_memory_usage()"
]
},
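{
"cell_type": "markdown",
"metadata": {},
"source": [
"- As an optional aside, we can roughly sanity-check the number reported above: assuming float32 weights (4 bytes per value), the parameter tensors plus the registered causal-mask buffers should approximately account for the allocated memory (the CUDA caching allocator adds some overhead, so the estimate below will not match exactly)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Rough size estimate (an optional aside; assumes float32, i.e., 4 bytes per value)\n",
"param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())\n",
"buffer_bytes = sum(b.numel() * b.element_size() for b in model.buffers())\n",
"print(f\"Estimated model size: {(param_bytes + buffer_bytes) / 1024**3:.1f} GB\")"
]
},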
{
"cell_type": "markdown",
"metadata": {
"id": "GIhwBEBxzBsF"
},
"source": [
"- Additionally, let's make sure that the model runs okay by passing in some example tensor"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"id": "i_j6nZruUd7g"
},
"outputs": [],
"source": [
"# Test if the model works (no need to track memory here)\n",
"test_input = torch.tensor([[1, 2, 3]]).to(device)\n",
"model.eval()\n",
"\n",
"with torch.no_grad():\n",
" model(test_input)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UgNb8c32zh4g"
},
"source": [
"- Next, imagine we were pretraining the model and saving it for later use\n",
"- We skip the actual pretraining here for simplicity and just save the initialized model (but the same concept applies)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"id": "wUIXjcsimXU7"
},
"outputs": [],
"source": [
"# Training code would go here...\n",
"\n",
"model.train()\n",
"torch.save(model.state_dict(), \"model.pth\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "s9tBS4HUzz1g"
},
"source": [
"- Lastly, we delete the model and example tensor in the Python session to reset the GPU memory"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "SqmTzztqKnTs",
"outputId": "1198afb9-2d97-4b6a-9bdb-41551f25749d"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Maximum GPU memory allocated: 0.0 GB\n"
]
}
],
"source": [
"del model, test_input\n",
"cleanup()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7EnO8beUJ6Sb"
},
"source": [
"&nbsp;\n",
"## 3. Weight loading"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JtAXKjsG0AVL"
},
"source": [
"- Now begins the interesting part where we load the pretrained model weights\n",
"- Let's see how much GPU memory is required to load the previously saved model"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "wCrQNbSJJO9w",
"outputId": "9b203868-a8ef-4011-fc2b-611cc0d10994"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Maximum GPU memory allocated: 12.8 GB\n"
]
}
],
"source": [
"# Then load pretrained weights\n",
"\n",
"start_memory_tracking()\n",
"\n",
"model = GPTModel(BASE_CONFIG)\n",
"model.to(device)\n",
"\n",
"model.load_state_dict(\n",
" torch.load(\"model.pth\", map_location=device, weights_only=True)\n",
")\n",
"model.to(device)\n",
"model.eval();\n",
"\n",
"print_memory_usage()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4AGvOrcN0KdJ"
},
"source": [
"- Notice that the memory is 2x as large as in the previous session\n",
"- This is because we have the same model in memory twice, for a short period of time:\n",
" - The first time via `model.to(device)`\n",
" - The second time via the code line `model.load_state_dict(torch.load(\"model.pth\", map_location=device, weights_only=True))`; eventually, the loaded model weights will be copied into the model, and the `state_dict` will be discarded, but for a brief amount of time, we have both the main model and the loaded `state_dict` in memory\n",
"- The remaining sections focus on addressing this\n",
"- But first, let's test the model and reset the GPU memory\n"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "DvlUn-nmmbuj",
"outputId": "11d3ab68-f570-4c1e-c631-fe5547026799"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Maximum GPU memory allocated: 0.0 GB\n"
]
}
],
"source": [
"# Test if the model works (no need to track memory here)\n",
"test_input = torch.tensor([[1, 2, 3]]).to(device)\n",
"model.eval()\n",
"\n",
"with torch.no_grad():\n",
" model(test_input)\n",
"\n",
"del model, test_input\n",
"cleanup()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RdPnW3iLLrjX"
},
"source": [
"&nbsp;\n",
"## 4. Loading weights sequentially"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FYqtUON602TD"
},
"source": [
"- One workaround for the problem of having the model weights in GPU memory twice, as highlighted in the previous section, is to load the model sequentially\n",
"- Below, we:\n",
" - first load the model into GPU memory\n",
" - then load the model weights into CPU memory\n",
" - and finally copy each parameter one by one into GPU memory\n"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "DOIGTNWTmx9G",
"outputId": "145162e6-aaa6-4c2a-ed8f-f1cf068adb80"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Maximum GPU memory allocated: 6.4 GB\n",
"Maximum GPU memory allocated: 6.7 GB\n"
]
}
],
"source": [
"start_memory_tracking()\n",
"\n",
"model = GPTModel(BASE_CONFIG).to(device)\n",
"\n",
"state_dict = torch.load(\"model.pth\", map_location=\"cpu\", weights_only=True)\n",
"\n",
"print_memory_usage()\n",
"\n",
"# Sequentially copy weights to the model's parameters\n",
"with torch.no_grad():\n",
" for name, param in model.named_parameters():\n",
" if name in state_dict:\n",
" param.copy_(state_dict[name].to(device))\n",
" else:\n",
" print(f\"Warning: {name} not found in state_dict.\")\n",
"\n",
"print_memory_usage()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Pn9xD_xL1ZzM"
},
"source": [
"- As we can see above, the memory usage is much lower than before\n",
"- Notice that the memory increases from 6.4 to 6.7 GB because initially, we only have the model in memory, and then we have the model plus 1 parameter tensor in memory (we temporarily move the parameter tensor to the GPU so we can assign it using `\".to\"` the model)\n",
"- Overall, this is a significant improvement\n",
"- Again, let's briefly test the model and then reset the GPU memory for the next section"
]
},
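{
"cell_type": "markdown",
"metadata": {},
"source": [
"- To see where the extra 0.3 GB comes from (an optional aside): during the sequential copy, at most one parameter tensor is duplicated on the GPU at a time, and the largest single tensor in this model is the 50,257 × 1,600 token-embedding (and output-head) weight matrix, which takes roughly 0.3 GB in float32, as the quick check below confirms"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Quick check (an optional aside, assuming float32): the largest single parameter tensor\n",
"# roughly matches the temporary 0.3 GB increase observed above\n",
"largest_bytes = max(p.numel() * p.element_size() for p in model.parameters())\n",
"print(f\"Largest parameter tensor: {largest_bytes / 1024**3:.1f} GB\")"
]
},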
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "PRHnjA48nJgw",
"outputId": "dcd6b1b2-538f-4862-96a6-a5fcbf3326a4"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Maximum GPU memory allocated: 0.0 GB\n"
]
}
],
"source": [
"# Test if the model works (no need to track memory here)\n",
"test_input = torch.tensor([[1, 2, 3]]).to(device)\n",
"model.eval()\n",
"\n",
"with torch.no_grad():\n",
" model(test_input)\n",
"\n",
"del model, test_input, state_dict, param\n",
"cleanup()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5M92LK7usb-Z"
},
"source": [
"&nbsp;\n",
"## 5. Loading the model with low CPU memory"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "R45qgeB613e2"
},
"source": [
"- In the previous session, we reduced GPU memory use by loading the weights (`state_dict`) into CPU memory first before copying them one-by-one into the model\n",
"- However, what do we do if we have limited CPU memory?\n",
"- This section uses PyTorch's so-called `\"meta\"` device approach to load a model on machines with large GPU memory but small CPU memory\n",
"- But first, let's define a convenience function to monitor CPU memory"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"id": "BrcWy0q-3Bbe"
},
"outputs": [],
"source": [
"import os\n",
"import psutil\n",
"from threading import Thread\n",
"\n",
"\n",
"def memory_usage_in_gb(func, *args, **kwargs):\n",
" process = psutil.Process(os.getpid())\n",
"\n",
" # Measure the baseline memory usage before running the function\n",
" baseline_mem = process.memory_info().rss / 1024 ** 3 # in GB\n",
"\n",
" # Start monitoring memory in a separate thread\n",
" mem_usage = []\n",
" done = False\n",
"\n",
" def monitor_memory():\n",
" while not done:\n",
" mem_usage.append(process.memory_info().rss / 1024 ** 3) # Convert to GB\n",
" time.sleep(0.1)\n",
"\n",
" t = Thread(target=monitor_memory)\n",
" t.start()\n",
"\n",
" # Run the function\n",
" func(*args, **kwargs)\n",
"\n",
" # Stop monitoring\n",
" done = True\n",
" t.join()\n",
"\n",
" peak_mem_usage_gb = max(mem_usage) - baseline_mem\n",
" return peak_mem_usage_gb\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Ayy30Ytd5hjF"
},
"source": [
"- To start with, let's track the CPU memory of the sequential weight loading approach from the previous section"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "rCkV6IbQtpVn",
"outputId": "26c0435a-1e3d-4e8f-fbe2-f9655bad61b4"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Maximum GPU memory allocated: 6.4 GB\n",
"Maximum GPU memory allocated: 6.7 GB\n",
"-> Maximum CPU memory allocated: 6.3 GB\n"
]
}
],
"source": [
"def load_sequentially():\n",
" start_memory_tracking()\n",
"\n",
" model = GPTModel(BASE_CONFIG).to(device)\n",
"\n",
" state_dict = torch.load(\"model.pth\", map_location=\"cpu\", weights_only=True)\n",
"\n",
" print_memory_usage()\n",
"\n",
" # Sequentially copy weights to the model's parameters\n",
" with torch.no_grad():\n",
" for name, param in model.named_parameters():\n",
" if name in state_dict:\n",
" param.copy_(state_dict[name].to(device))\n",
" else:\n",
" print(f\"Warning: {name} not found in state_dict.\")\n",
"\n",
" print_memory_usage()\n",
"\n",
"\n",
"peak_memory_used = memory_usage_in_gb(load_sequentially)\n",
"print(f\"-> Maximum CPU memory allocated: {peak_memory_used:.1f} GB\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UWrmnCML5oKy"
},
"source": [
"- Now, suppose we have a machine with low CPU memory but large GPU memory\n",
"- We can trade off CPU memory and GPU memory usage by introducing PyTorch's so-called \"meta\" device\n",
"- PyTorch's meta device is a special device type that allows you to create tensors without allocating actual memory for their data, effectively creating \"meta\" tensors\n",
"- This is useful for tasks like model analysis or architecture definition, where you need tensor shapes and types without the overhead of memory allocation"
]
},
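{
"cell_type": "markdown",
"metadata": {},
"source": [
"- A minimal demonstration (an optional aside; it assumes PyTorch 2.0 or newer, where `torch.device` can be used as a context manager): tensors created under the meta device carry shape and dtype information but allocate no data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Mini-demo of the \"meta\" device (an optional aside)\n",
"with torch.device(\"meta\"):\n",
"    meta_tensor = torch.randn(1000, 1000)\n",
"\n",
"print(meta_tensor.is_meta)  # True; no data memory was allocated\n",
"print(meta_tensor.shape)    # torch.Size([1000, 1000])"
]
},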
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "PBErC_5Yt8ly",
"outputId": "8799db06-191c-47c4-92fa-fbb95d685aa9"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Maximum GPU memory allocated: 12.8 GB\n",
"Maximum GPU memory allocated: 12.8 GB\n",
"-> Maximum CPU memory allocated: 1.3 GB\n"
]
}
],
"source": [
"def load_sequentially_with_meta():\n",
" start_memory_tracking()\n",
"\n",
" with torch.device(\"meta\"):\n",
" model = GPTModel(BASE_CONFIG)\n",
"\n",
" model = model.to_empty(device=device)\n",
"\n",
" state_dict = torch.load(\"model.pth\", map_location=device, weights_only=True)\n",
"\n",
" print_memory_usage()\n",
"\n",
" # Sequentially copy weights to the model's parameters\n",
" with torch.no_grad():\n",
" for name, param in model.named_parameters():\n",
" if name in state_dict:\n",
" param.copy_(state_dict[name])\n",
" else:\n",
" print(f\"Warning: {name} not found in state_dict.\")\n",
"\n",
" print_memory_usage()\n",
"\n",
"peak_memory_used = memory_usage_in_gb(load_sequentially_with_meta)\n",
"print(f\"-> Maximum CPU memory allocated: {peak_memory_used:.1f} GB\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VpnCABp75-VQ"
},
"source": [
"- As we can see above, by creating the model on the meta-device and loading the weights directly into GPU memory, we effectively reduced the CPU memory requirements\n",
"- One might ask: \"Is the sequential weight loading still necessary then, and how does that compare to the original approach?\"\n",
"- Let's check the simple PyTorch weight loading approach for comparison (from the first weight loading section in this notebook):"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "4f-bqBNRuR39",
"outputId": "f7c0a901-b404-433a-9b93-2bbfa8183c56"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Maximum GPU memory allocated: 12.8 GB\n",
"-> Maximum CPU memory allocated: 4.4 GB\n"
]
}
],
"source": [
"def baseline():\n",
" start_memory_tracking()\n",
"\n",
" model = GPTModel(BASE_CONFIG)\n",
" model.to(device)\n",
"\n",
" model.load_state_dict(torch.load(\"model.pth\", map_location=device, weights_only=True))\n",
" model.to(device)\n",
" model.eval();\n",
"\n",
" print_memory_usage()\n",
"\n",
"peak_memory_used = memory_usage_in_gb(baseline)\n",
"print(f\"-> Maximum CPU memory allocated: {peak_memory_used:.1f} GB\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "NKAjxbX86xnb"
},
"source": [
"- As we can see above, the \"simple\" weight loading without the meta device uses more memory\n",
"- In other words, if you have a machine with limited CPU memory, you can use the meta device approach to directly load the model weights into GPU memory to reduce peak CPU memory usage"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"&nbsp;\n",
"## 6. Other methods"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- This notebook is focused on simple, built-in methods for loading weights in PyTorch.\n",
"- In case none of these methods work because you (1) don't have enough CPU memory for the `load_sequentially` approach and don't have enough GPU VRAM to have 2 copies of the weights in memory (the `load_sequentially_with_meta` approach), one option is to save and load each weight tensor separately:"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"id": "2CgPEZUIb00w"
},
"outputs": [],
"source": [
"model = GPTModel(BASE_CONFIG)\n",
"# Assume `model` is your trained model\n",
"state_dict = model.state_dict()\n",
"\n",
"# Create a directory to store individual parameter files\n",
"os.makedirs(\"model_parameters\", exist_ok=True)\n",
"\n",
"# Save each parameter tensor separately\n",
"for name, param in state_dict.items():\n",
" torch.save(param.cpu(), f\"model_parameters/{name}.pt\")\n",
"\n",
"del model"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "gTsmtJK-b4yy",
"outputId": "d361e2d3-e34c-48d7-9047-846c9bfd291e"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Maximum GPU memory allocated: 6.4 GB\n",
"Maximum GPU memory allocated: 6.4 GB\n",
"-> Maximum CPU memory allocated: 0.3 GB\n"
]
}
],
"source": [
"def load_individual_weights():\n",
"\n",
" start_memory_tracking()\n",
"\n",
" with torch.device(\"meta\"):\n",
" model = GPTModel(BASE_CONFIG)\n",
"\n",
" model = model.to_empty(device=device)\n",
"\n",
" print_memory_usage()\n",
" param_dir = \"model_parameters\"\n",
"\n",
" with torch.no_grad():\n",
" for name, param in model.named_parameters():\n",
" weight_path = os.path.join(param_dir, f\"{name}.pt\")\n",
" if os.path.exists(weight_path):\n",
" param_data = torch.load(weight_path, map_location=\"cpu\", weights_only=True)\n",
" param.copy_(param_data)\n",
" del param_data # Free memory\n",
" else:\n",
" print(f\"Warning: {name} not found in {param_dir}.\")\n",
"\n",
" print_memory_usage()\n",
"\n",
"\n",
"peak_memory_used = memory_usage_in_gb(load_individual_weights)\n",
"print(f\"-> Maximum CPU memory allocated: {peak_memory_used:.1f} GB\")"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuType": "L4",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
}
},
"nbformat": 4,
"nbformat_minor": 4
}

@ -0,0 +1,170 @@
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
#
# This file collects all the relevant code that we covered thus far
# throughout Chapters 2-5.

import torch
import torch.nn as nn


#####################################
# Chapter 3
#####################################
class MultiHeadAttention(nn.Module):
def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
super().__init__()
assert d_out % num_heads == 0, "d_out must be divisible by n_heads"
self.d_out = d_out
self.num_heads = num_heads
self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim
self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs
self.dropout = nn.Dropout(dropout)
self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))
def forward(self, x):
b, num_tokens, d_in = x.shape
keys = self.W_key(x) # Shape: (b, num_tokens, d_out)
queries = self.W_query(x)
values = self.W_value(x)
# We implicitly split the matrix by adding a `num_heads` dimension
# Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
values = values.view(b, num_tokens, self.num_heads, self.head_dim)
queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
# Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
keys = keys.transpose(1, 2)
queries = queries.transpose(1, 2)
values = values.transpose(1, 2)
# Compute scaled dot-product attention (aka self-attention) with a causal mask
attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head
# Original mask truncated to the number of tokens and converted to boolean
mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
# Use the mask to fill attention scores
attn_scores.masked_fill_(mask_bool, -torch.inf)
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
attn_weights = self.dropout(attn_weights)
# Shape: (b, num_tokens, num_heads, head_dim)
context_vec = (attn_weights @ values).transpose(1, 2)
# Combine heads, where self.d_out = self.num_heads * self.head_dim
context_vec = context_vec.reshape(b, num_tokens, self.d_out)
context_vec = self.out_proj(context_vec) # optional projection
        return context_vec


#####################################
# Chapter 4
#####################################
class LayerNorm(nn.Module):
def __init__(self, emb_dim):
super().__init__()
self.eps = 1e-5
self.scale = nn.Parameter(torch.ones(emb_dim))
self.shift = nn.Parameter(torch.zeros(emb_dim))
def forward(self, x):
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, keepdim=True, unbiased=False)
norm_x = (x - mean) / torch.sqrt(var + self.eps)
        return self.scale * norm_x + self.shift


class GELU(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
return 0.5 * x * (1 + torch.tanh(
torch.sqrt(torch.tensor(2.0 / torch.pi)) *
(x + 0.044715 * torch.pow(x, 3))
        ))


class FeedForward(nn.Module):
def __init__(self, cfg):
super().__init__()
self.layers = nn.Sequential(
nn.Linear(cfg["emb_dim"], 4 * cfg["emb_dim"]),
GELU(),
nn.Linear(4 * cfg["emb_dim"], cfg["emb_dim"]),
)
def forward(self, x):
        return self.layers(x)


class TransformerBlock(nn.Module):
def __init__(self, cfg):
super().__init__()
self.att = MultiHeadAttention(
d_in=cfg["emb_dim"],
d_out=cfg["emb_dim"],
context_length=cfg["context_length"],
num_heads=cfg["n_heads"],
dropout=cfg["drop_rate"],
qkv_bias=cfg["qkv_bias"])
self.ff = FeedForward(cfg)
self.norm1 = LayerNorm(cfg["emb_dim"])
self.norm2 = LayerNorm(cfg["emb_dim"])
self.drop_shortcut = nn.Dropout(cfg["drop_rate"])
def forward(self, x):
# Shortcut connection for attention block
shortcut = x
x = self.norm1(x)
x = self.att(x) # Shape [batch_size, num_tokens, emb_size]
x = self.drop_shortcut(x)
x = x + shortcut # Add the original input back
# Shortcut connection for feed-forward block
shortcut = x
x = self.norm2(x)
x = self.ff(x)
x = self.drop_shortcut(x)
x = x + shortcut # Add the original input back
        return x


class GPTModel(nn.Module):
def __init__(self, cfg):
super().__init__()
self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
self.drop_emb = nn.Dropout(cfg["drop_rate"])
self.trf_blocks = nn.Sequential(
*[TransformerBlock(cfg) for _ in range(cfg["n_layers"])])
self.final_norm = LayerNorm(cfg["emb_dim"])
self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)
def forward(self, in_idx):
batch_size, seq_len = in_idx.shape
tok_embeds = self.tok_emb(in_idx)
pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))
x = tok_embeds + pos_embeds # Shape [batch_size, num_tokens, emb_size]
x = self.drop_emb(x)
x = self.trf_blocks(x)
x = self.final_norm(x)
logits = self.out_head(x)
return logits

@ -14,3 +14,4 @@
- [05_bonus_hparam_tuning](05_bonus_hparam_tuning) contains an optional hyperparameter tuning script
- [06_user_interface](06_user_interface) implements an interactive user interface to interact with the pretrained LLM
- [07_gpt_to_llama](07_gpt_to_llama) contains a step-by-step guide for converting a GPT architecture implementation to Llama 3.2 and loads pretrained weights from Meta AI
- [08_memory_efficient_weight_loading](08_memory_efficient_weight_loading) contains a bonus notebook showing how to load model weights via PyTorch's `load_state_dict` method more efficiently