{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "1E_HhLEeYqFG"
   },
   "source": [
    "<table style=\"width:100%\">\n",
    "<tr>\n",
    "<td style=\"vertical-align:middle; text-align:left;\">\n",
    "<font size=\"2\">\n",
    "Supplementary code for the <a href=\"http://mng.bz/orYv\">Build a Large Language Model From Scratch</a> book by <a href=\"https://sebastianraschka.com\">Sebastian Raschka</a><br>\n",
    "<br>Code repository: <a href=\"https://github.com/rasbt/LLMs-from-scratch\">https://github.com/rasbt/LLMs-from-scratch</a>\n",
    "</font>\n",
    "</td>\n",
    "<td style=\"vertical-align:middle; text-align:left;\">\n",
    "<a href=\"http://mng.bz/orYv\"><img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/cover-small.webp\" width=\"100px\"></a>\n",
    "</td>\n",
    "</tr>\n",
    "</table>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ZuWudYFWYiH7"
   },
   "source": [
    "# Memory-efficient Model Weight Loading"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "qt0Qyg6ewUt6"
   },
   "source": [
    "- This notebook provides tips for loading larger pretrained or finetuned models when GPU (or CPU) memory is limited\n",
    "- Specifically, it focuses on cases where you saved the model using `torch.save(model.state_dict(), \"model.pth\")` (for example, in chapters 5-7) and want to load it in a new session later for continued pretraining or additional finetuning\n",
    "- While the example uses an LLM, the methods explained in this notebook are general and apply to loading any PyTorch model, not just LLMs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/memory-efficient-loading/memory-efficient-loading.webp\" width=\"800px\">"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "SxQzFoS-IXdY",
    "outputId": "b28ebfbd-9036-4696-d95a-7f96fdf29919"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "memory_profiler version: 0.61.0\n",
      "torch version: 2.4.1+cu121\n"
     ]
    }
   ],
   "source": [
    "from importlib.metadata import version\n",
    "\n",
    "pkgs = [\n",
    "    \"memory_profiler\",\n",
    "    \"torch\",\n",
    "]\n",
    "for p in pkgs:\n",
    "    print(f\"{p} version: {version(p)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "y47iQaQKyHap"
   },
   "source": [
    " \n",
    "## 1. Benchmark utilities"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "nQeOEoo6yT0X"
   },
   "source": [
    "- First, let's define some utility code to track VRAM (GPU memory)\n",
    "- Later, we will also introduce a tool to track the main system RAM (CPU memory); a minimal sketch of one such tool follows right after the next code cell\n",
    "- The purpose of these functions will become clear when we apply them later"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "id": "pEiqjYrVivgt"
   },
   "outputs": [],
   "source": [
    "import gc\n",
    "import time\n",
    "import torch\n",
    "\n",
    "\n",
    "def start_memory_tracking():\n",
    "    \"\"\"Initialize GPU memory tracking.\"\"\"\n",
    "    if torch.cuda.is_available():\n",
    "        torch.cuda.reset_peak_memory_stats()\n",
    "    else:\n",
    "        print(\"This notebook is intended for CUDA GPUs but CUDA is not available.\")\n",
    "\n",
    "def print_memory_usage():\n",
    "    max_gpu_memory = torch.cuda.max_memory_allocated() / (1024 ** 3)  # Convert bytes to GB\n",
    "    print(f\"Maximum GPU memory allocated: {max_gpu_memory:.1f} GB\")\n",
    "\n",
    "def cleanup():\n",
    "    gc.collect()\n",
    "    torch.cuda.empty_cache()\n",
    "    time.sleep(3)  # some buffer time to allow memory to clear\n",
    "    torch.cuda.reset_peak_memory_stats()\n",
    "    max_memory_allocated = torch.cuda.max_memory_allocated(device) / (1024 ** 3)\n",
    "    print(f\"Maximum GPU memory allocated: {max_memory_allocated:.1f} GB\")"
   ]
  },
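  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- As a first taste of the CPU-side counterpart, here is a minimal sketch (an added illustration assuming the `memory_profiler` package from the version check above, not necessarily the tool introduced later) that samples the RAM usage of the current process:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: track main system RAM (CPU memory) with memory_profiler\n",
    "from memory_profiler import memory_usage\n",
    "\n",
    "\n",
    "def print_cpu_memory_usage():\n",
    "    # memory_usage(-1) samples this process' RAM usage (in MiB) for ~1 second\n",
    "    ram_mib = memory_usage(-1, interval=0.1, timeout=1)\n",
    "    print(f\"Current CPU RAM used: {max(ram_mib) / 1024:.1f} GB\")"
   ]
  },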
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "z5oJwoc-kkXs"
   },
   "source": [
    " \n",
    "## 2. Model setup"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "YfJE0vnMyr88"
   },
   "source": [
    "- This code section sets up the model itself\n",
    "- Here, we use the largest variant, \"gpt2-xl (1558M)\", to make things more interesting (you may use \"gpt2-small (124M)\" instead to lower the memory requirements and execution time of this notebook)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "id": "tMuhCYaVI0w7"
   },
   "outputs": [],
   "source": [
    "from previous_chapters import GPTModel\n",
    "\n",
    "\n",
    "BASE_CONFIG = {\n",
    "    \"vocab_size\": 50257,     # Vocabulary size\n",
    "    \"context_length\": 1024,  # Context length\n",
    "    \"drop_rate\": 0.0,        # Dropout rate\n",
    "    \"qkv_bias\": True         # Query-key-value bias\n",
    "}\n",
    "\n",
    "model_configs = {\n",
    "    \"gpt2-small (124M)\": {\"emb_dim\": 768, \"n_layers\": 12, \"n_heads\": 12},\n",
    "    \"gpt2-medium (355M)\": {\"emb_dim\": 1024, \"n_layers\": 24, \"n_heads\": 16},\n",
    "    \"gpt2-large (774M)\": {\"emb_dim\": 1280, \"n_layers\": 36, \"n_heads\": 20},\n",
    "    \"gpt2-xl (1558M)\": {\"emb_dim\": 1600, \"n_layers\": 48, \"n_heads\": 25},\n",
    "}\n",
    "\n",
    "CHOOSE_MODEL = \"gpt2-xl (1558M)\"\n",
    "\n",
    "BASE_CONFIG.update(model_configs[CHOOSE_MODEL])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "KWYoo1z5y8aX"
   },
   "source": [
    "- Now, let's see the GPU memory functions in action:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "GK3NEA3eJv3f",
    "outputId": "60573d6e-c603-45e7-8283-b1e92e2a0013"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Maximum GPU memory allocated: 6.4 GB\n"
     ]
    }
   ],
   "source": [
    "start_memory_tracking()\n",
    "\n",
    "\n",
    "model = GPTModel(BASE_CONFIG)\n",
    "device = torch.device(\"cuda\")\n",
    "model.to(device)\n",
    "\n",
    "print_memory_usage()"
   ]
  },
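  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- As a quick sanity check of this number (an added back-of-the-envelope estimate): the model stores one float32 value (4 bytes) per parameter, so the parameter count alone should roughly account for the peak reported above:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rough estimate: number of parameters x 4 bytes per float32 value\n",
    "n_params = sum(p.numel() for p in model.parameters())\n",
    "print(f\"Parameters: {n_params / 1e9:.2f} B\")\n",
    "print(f\"Estimated float32 size: {n_params * 4 / 1024**3:.1f} GB\")"
   ]
  },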
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "GIhwBEBxzBsF"
   },
   "source": [
    "- Additionally, let's make sure that the model runs okay by passing in an example tensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "id": "i_j6nZruUd7g"
   },
   "outputs": [],
   "source": [
    "# Test if the model works (no need to track memory here)\n",
    "test_input = torch.tensor([[1, 2, 3]]).to(device)\n",
    "model.eval()\n",
    "\n",
    "with torch.no_grad():\n",
    "    model(test_input)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "UgNb8c32zh4g"
   },
   "source": [
    "- Next, imagine we were pretraining the model and saving it for later use\n",
    "- We skip the actual pretraining here for simplicity and just save the initialized model (but the same concept applies)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "id": "wUIXjcsimXU7"
   },
   "outputs": [],
   "source": [
    "# Training code would go here...\n",
    "\n",
    "model.train()\n",
    "torch.save(model.state_dict(), \"model.pth\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "s9tBS4HUzz1g"
   },
   "source": [
    "- Lastly, we delete the model and example tensor in the Python session to reset the GPU memory"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "SqmTzztqKnTs",
    "outputId": "1198afb9-2d97-4b6a-9bdb-41551f25749d"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Maximum GPU memory allocated: 0.0 GB\n"
     ]
    }
   ],
   "source": [
    "del model, test_input\n",
    "cleanup()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "7EnO8beUJ6Sb"
   },
   "source": [
    " \n",
    "## 3. Weight loading"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "JtAXKjsG0AVL"
   },
   "source": [
    "- Now begins the interesting part where we load the pretrained model weights\n",
    "- Let's see how much GPU memory is required to load the previously saved model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "wCrQNbSJJO9w",
    "outputId": "9b203868-a8ef-4011-fc2b-611cc0d10994"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Maximum GPU memory allocated: 12.8 GB\n"
     ]
    }
   ],
   "source": [
    "# Then load pretrained weights\n",
    "\n",
    "start_memory_tracking()\n",
    "\n",
    "model = GPTModel(BASE_CONFIG)\n",
    "model.to(device)\n",
    "\n",
    "model.load_state_dict(\n",
    "    torch.load(\"model.pth\", map_location=device, weights_only=True)\n",
    ")\n",
    "model.to(device)\n",
    "model.eval();\n",
    "\n",
    "print_memory_usage()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "4AGvOrcN0KdJ"
   },
   "source": [
    "- Notice that the peak memory usage is twice as large as in the previous session (12.8 GB vs. 6.4 GB)\n",
    "- This is because we have the same model in memory twice, for a short period of time:\n",
    "  - The first time via `model.to(device)`\n",
    "  - The second time via the code line `model.load_state_dict(torch.load(\"model.pth\", map_location=device, weights_only=True))`; eventually, the loaded model weights will be copied into the model, and the `state_dict` will be discarded, but for a brief amount of time, we have both the main model and the loaded `state_dict` in memory\n",
    "- The remaining sections focus on addressing this; as a quick preview, the two cells below sketch one possible approach\n",
    "- After that, let's test the model and reset the GPU memory\n"
   ]
  },
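  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- Preview (a minimal sketch, not necessarily the exact approach of the remaining sections): load the saved `state_dict` onto the CPU and copy it into the GPU model one tensor at a time, so the full set of weights never exists twice on the GPU:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: sequential weight loading to avoid a second full copy on the GPU\n",
    "state_dict = torch.load(\"model.pth\", map_location=\"cpu\", weights_only=True)\n",
    "\n",
    "with torch.no_grad():\n",
    "    for name, param in model.named_parameters():\n",
    "        if name in state_dict:\n",
    "            # move one tensor at a time to the GPU and copy it in place\n",
    "            param.copy_(state_dict[name].to(device))\n",
    "        else:\n",
    "            print(f\"Warning: {name} not found in state_dict.\")\n",
    "\n",
    "del state_dict  # free the CPU copy"
   ]
  },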
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "DvlUn-nmmbuj",
    "outputId": "11d3ab68-f570-4c1e-c631-fe5547026799"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Maximum GPU memory allocated: 0.0 GB\n"
     ]
    }
   ],
   "source": [
    "# Test if the model works (no need to track memory here)\n",
    "test_input = torch.tensor([[1, 2, 3]]).to(device)\n",
    "model.eval()\n",
    "\n",
    "with torch.no_grad():\n",
							 
							
								
							 
							
								 
							
							
								    "    model(test_input)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "del model, test_input\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "cleanup()"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "RdPnW3iLLrjX"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    " \n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "## 4. Loading weights sequentially"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "FYqtUON602TD"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- One workaround for the problem of having the model weights in GPU memory twice, as highlighted in the previous section, is to load the model sequentially\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Below, we:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "  - first load the model into GPU memory\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "  - then load the model weights into CPU memory\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "  - and finally copy each parameter one by one into GPU memory\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "DOIGTNWTmx9G",
    "outputId": "145162e6-aaa6-4c2a-ed8f-f1cf068adb80"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Maximum GPU memory allocated: 6.4 GB\n",
      "Maximum GPU memory allocated: 6.7 GB\n"
     ]
    }
   ],
   "source": [
    "start_memory_tracking()\n",
    "\n",
    "model = GPTModel(BASE_CONFIG).to(device)\n",
    "\n",
    "state_dict = torch.load(\"model.pth\", map_location=\"cpu\", weights_only=True)\n",
    "\n",
    "print_memory_usage()\n",
    "\n",
    "# Sequentially copy weights to the model's parameters\n",
    "with torch.no_grad():\n",
    "    for name, param in model.named_parameters():\n",
    "        if name in state_dict:\n",
    "            param.copy_(state_dict[name].to(device))\n",
    "        else:\n",
    "            print(f\"Warning: {name} not found in state_dict.\")\n",
    "\n",
    "print_memory_usage()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Pn9xD_xL1ZzM"
   },
   "source": [
    "- As we can see above, the memory usage is much lower than before\n",
    "- Notice that the memory increases from 6.4 to 6.7 GB because, initially, we only have the model in memory, and then we have the model plus one parameter tensor in memory (we temporarily move each parameter tensor to the GPU via `.to(device)` so that we can copy it into the model)\n",
    "- Overall, this is a significant improvement\n",
    "- Again, let's briefly test the model and then reset the GPU memory for the next section"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "PRHnjA48nJgw",
    "outputId": "dcd6b1b2-538f-4862-96a6-a5fcbf3326a4"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Maximum GPU memory allocated: 0.0 GB\n"
     ]
    }
   ],
   "source": [
    "# Test if the model works (no need to track memory here)\n",
    "test_input = torch.tensor([[1, 2, 3]]).to(device)\n",
    "model.eval()\n",
    "\n",
    "with torch.no_grad():\n",
    "    model(test_input)\n",
    "\n",
    "del model, test_input, state_dict, param\n",
    "cleanup()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "5M92LK7usb-Z"
   },
   "source": [
    " \n",
    "## 5. Loading the model with low CPU memory"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "R45qgeB613e2"
   },
   "source": [
    "- In the previous section, we reduced GPU memory use by loading the weights (`state_dict`) into CPU memory first before copying them one-by-one into the model\n",
    "- However, what do we do if we have limited CPU memory?\n",
    "- This section uses PyTorch's so-called `\"meta\"` device approach to load a model on machines with large GPU memory but small CPU memory\n",
    "- But first, let's define a convenience function to monitor CPU memory"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "id": "BrcWy0q-3Bbe"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import time\n",
    "import psutil\n",
    "from threading import Thread\n",
    "\n",
    "\n",
    "def memory_usage_in_gb(func, *args, **kwargs):\n",
    "    process = psutil.Process(os.getpid())\n",
    "\n",
    "    # Measure the baseline memory usage before running the function\n",
    "    baseline_mem = process.memory_info().rss / 1024 ** 3  # in GB\n",
    "\n",
    "    # Start monitoring memory in a separate thread\n",
    "    mem_usage = []\n",
    "    done = False\n",
    "\n",
    "    def monitor_memory():\n",
    "        while not done:\n",
    "            mem_usage.append(process.memory_info().rss / 1024 ** 3)  # Convert to GB\n",
    "            time.sleep(0.1)\n",
    "\n",
    "    t = Thread(target=monitor_memory)\n",
    "    t.start()\n",
    "\n",
    "    # Run the function\n",
    "    func(*args, **kwargs)\n",
    "\n",
    "    # Stop monitoring\n",
    "    done = True\n",
    "    t.join()\n",
    "\n",
    "    peak_mem_usage_gb = max(mem_usage) - baseline_mem\n",
    "    return peak_mem_usage_gb\n"
   ]
  },
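  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- As a quick, hypothetical sanity check of this helper (the workload below is illustrative and not part of the weight-loading experiments), we can measure a function that allocates roughly 0.8 GB on the CPU:\n",
    "\n",
    "```python\n",
    "def allocate_big_list():\n",
    "    data = [0] * 10**8  # ~0.8 GB of list slots on a 64-bit machine\n",
    "\n",
    "print(f\"{memory_usage_in_gb(allocate_big_list):.1f} GB\")\n",
    "```"
   ]
  },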
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Ayy30Ytd5hjF"
   },
   "source": [
    "- To start with, let's track the CPU memory of the sequential weight loading approach from the previous section"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "rCkV6IbQtpVn",
    "outputId": "26c0435a-1e3d-4e8f-fbe2-f9655bad61b4"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Maximum GPU memory allocated: 6.4 GB\n",
      "Maximum GPU memory allocated: 6.7 GB\n",
      "-> Maximum CPU memory allocated: 6.3 GB\n"
     ]
    }
   ],
   "source": [
    "def load_sequentially():\n",
    "    start_memory_tracking()\n",
    "\n",
    "    model = GPTModel(BASE_CONFIG).to(device)\n",
    "\n",
    "    state_dict = torch.load(\"model.pth\", map_location=\"cpu\", weights_only=True)\n",
    "\n",
    "    print_memory_usage()\n",
    "\n",
    "    # Sequentially copy weights to the model's parameters\n",
    "    with torch.no_grad():\n",
    "        for name, param in model.named_parameters():\n",
    "            if name in state_dict:\n",
    "                param.copy_(state_dict[name].to(device))\n",
    "            else:\n",
    "                print(f\"Warning: {name} not found in state_dict.\")\n",
    "\n",
    "    print_memory_usage()\n",
    "\n",
    "\n",
    "peak_memory_used = memory_usage_in_gb(load_sequentially)\n",
    "print(f\"-> Maximum CPU memory allocated: {peak_memory_used:.1f} GB\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "UWrmnCML5oKy"
   },
   "source": [
    "- Now, suppose we have a machine with low CPU memory but large GPU memory\n",
    "- We can trade off CPU memory and GPU memory usage by introducing PyTorch's so-called \"meta\" device\n",
    "- PyTorch's meta device is a special device type that allows you to create tensors without allocating actual memory for their data, effectively creating \"meta\" tensors (see the short demo after this list)\n",
    "- This is useful for tasks like model analysis or architecture definition, where you need tensor shapes and types without the overhead of memory allocation"
   ]
  },
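  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- A minimal sketch of meta tensors (illustrative, not part of the measured runs; the context-manager form assumes PyTorch 2.0 or newer):\n",
    "\n",
    "```python\n",
    "with torch.device(\"meta\"):\n",
    "    t = torch.empty(1024, 1024)  # shape and dtype only; no storage is allocated\n",
    "\n",
    "print(t.device)  # meta\n",
    "print(t.shape)   # torch.Size([1024, 1024])\n",
    "# t has no data; a module built this way is later materialized, e.g., via .to_empty(device=...)\n",
    "```"
   ]
  },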
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "PBErC_5Yt8ly",
    "outputId": "8799db06-191c-47c4-92fa-fbb95d685aa9"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Maximum GPU memory allocated: 12.8 GB\n",
      "Maximum GPU memory allocated: 12.8 GB\n",
      "-> Maximum CPU memory allocated: 1.3 GB\n"
     ]
    }
   ],
   "source": [
    "def load_sequentially_with_meta():\n",
    "    start_memory_tracking()\n",
    "\n",
    "    with torch.device(\"meta\"):\n",
    "        model = GPTModel(BASE_CONFIG)\n",
    "\n",
    "    model = model.to_empty(device=device)\n",
    "\n",
    "    state_dict = torch.load(\"model.pth\", map_location=device, weights_only=True)\n",
    "\n",
    "    print_memory_usage()\n",
    "\n",
    "    # Sequentially copy weights to the model's parameters\n",
    "    with torch.no_grad():\n",
    "        for name, param in model.named_parameters():\n",
    "            if name in state_dict:\n",
    "                param.copy_(state_dict[name])\n",
    "            else:\n",
    "                print(f\"Warning: {name} not found in state_dict.\")\n",
    "\n",
    "    print_memory_usage()\n",
    "\n",
    "peak_memory_used = memory_usage_in_gb(load_sequentially_with_meta)\n",
    "print(f\"-> Maximum CPU memory allocated: {peak_memory_used:.1f} GB\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "VpnCABp75-VQ"
   },
   "source": [
    "- As we can see above, by creating the model on the meta device and loading the weights directly into GPU memory, we effectively reduced the CPU memory requirements\n",
    "- One might ask: \"Is the sequential weight loading still necessary then, and how does that compare to the original approach?\"\n",
    "- Let's check the simple PyTorch weight loading approach for comparison (from the first weight loading section in this notebook):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "4f-bqBNRuR39",
    "outputId": "f7c0a901-b404-433a-9b93-2bbfa8183c56"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Maximum GPU memory allocated: 12.8 GB\n",
      "-> Maximum CPU memory allocated: 4.4 GB\n"
     ]
    }
   ],
   "source": [
    "def baseline():\n",
    "    start_memory_tracking()\n",
    "\n",
    "    model = GPTModel(BASE_CONFIG)\n",
    "    model.to(device)\n",
    "\n",
    "    model.load_state_dict(torch.load(\"model.pth\", map_location=device, weights_only=True))\n",
    "    model.to(device)\n",
    "    model.eval();\n",
    "\n",
    "    print_memory_usage()\n",
    "\n",
    "peak_memory_used = memory_usage_in_gb(baseline)\n",
    "print(f\"-> Maximum CPU memory allocated: {peak_memory_used:.1f} GB\")"
   ]
  },
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "NKAjxbX86xnb"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- As we can see above, the \"simple\" weight loading without the meta device uses more memory\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- In other words, if you have a machine with limited CPU memory, you can use the meta device approach to directly load the model weights into GPU memory to reduce peak CPU memory usage"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
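  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- Below is a minimal recap sketch of the meta-device loading pattern summarized above, assuming the `GPTModel`, `BASE_CONFIG`, `device`, and `model.pth` objects defined earlier in this notebook"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Minimal sketch of the meta-device pattern (assumes the objects defined above)\n",
    "with torch.device(\"meta\"):\n",
    "    sketch_model = GPTModel(BASE_CONFIG)  # no parameter memory is allocated yet\n",
    "\n",
    "sketch_model.load_state_dict(\n",
    "    torch.load(\"model.pth\", map_location=device, weights_only=True),\n",
    "    assign=True  # assign the loaded tensors in place of the meta tensors\n",
    ");"
   ]
  },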
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " \n",
    "## 6. Using `mmap=True` (recommended)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
								    "- As an intermediate or advanced `torch.load` user, you may wonder how these approaches compare to the `mmap=True` setting in PyTorch\n",
							 
						 
					
						
							
								
									
										
										
										
											2024-10-14 14:27:19 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    "- The `mmap=True` setting in PyTorch enables memory-mapped file I/O, which allows the tensor to access data directly from disk storage, thus reducing memory usage by not loading the entire file into RAM if RAM is limited\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- Also, see the helpful comment by [mikaylagawarecki](https://github.com/rasbt/LLMs-from-scratch/issues/402)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- At first glance, it may look less efficient than the sequential approaches above:"
							 
						 
					
						
							
								
									
										
										
										
											2024-10-14 11:09:55 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "GKwV0AMNemuR",
    "outputId": "e207f2bf-5c87-498e-80fe-e8c4016ac711"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Maximum GPU memory allocated: 6.4 GB\n",
      "-> Maximum CPU memory allocated: 5.9 GB\n"
     ]
    }
   ],
   "source": [
    "def best_practices():\n",
    "    with torch.device(\"meta\"):\n",
    "        model = GPTModel(BASE_CONFIG)\n",
    "\n",
    "    model.load_state_dict(\n",
    "        torch.load(\"model.pth\", map_location=device, weights_only=True, mmap=True),\n",
    "        assign=True\n",
    "    )\n",
    "\n",
    "    print_memory_usage()\n",
    "\n",
    "peak_memory_used = memory_usage_in_gb(best_practices)\n",
    "print(f\"-> Maximum CPU memory allocated: {peak_memory_used:.1f} GB\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- The reason the CPU RAM usage is so high here is that this machine has plenty of free CPU RAM, so the operating system can afford to keep the memory-mapped pages of the checkpoint cached in memory\n",
    "- However, if you were to run this on a machine with limited CPU RAM, the `mmap` approach would use less memory"
   ]
  },
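  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- As a small optional experiment, the following sketch (assuming the `memory_usage_in_gb` helper and the `model.pth` checkpoint from above) measures the CPU-side cost of merely mmap-loading the checkpoint without instantiating a model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: mmap-load the checkpoint on the CPU without building a model;\n",
    "# tensor data is paged in from disk on demand rather than read up front\n",
    "def mmap_load_only():\n",
    "    state_dict = torch.load(\"model.pth\", map_location=\"cpu\", weights_only=True, mmap=True)\n",
    "    print(f\"Loaded {len(state_dict)} tensors\")\n",
    "\n",
    "peak_memory_used = memory_usage_in_gb(mmap_load_only)\n",
    "print(f\"-> Maximum CPU memory allocated: {peak_memory_used:.1f} GB\")"
   ]
  },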
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " \n",
    "## 7. Other methods"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- This notebook is focused on simple, built-in methods for loading weights in PyTorch\n",
    "- The recommended approach for limited CPU memory cases is the `mmap=True` approach explained above\n",
    "- Alternatively, one other option is a brute-force approach that saves and loads each weight tensor separately:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "id": "2CgPEZUIb00w"
   },
   "outputs": [],
   "source": [
    "model = GPTModel(BASE_CONFIG)\n",
    "# Assume `model` is your trained model\n",
    "state_dict = model.state_dict()\n",
    "\n",
    "# Create a directory to store individual parameter files\n",
    "os.makedirs(\"model_parameters\", exist_ok=True)\n",
    "\n",
    "# Save each parameter tensor separately\n",
    "for name, param in state_dict.items():\n",
    "    torch.save(param.cpu(), f\"model_parameters/{name}.pt\")\n",
    "\n",
    "del model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "gTsmtJK-b4yy",
    "outputId": "d361e2d3-e34c-48d7-9047-846c9bfd291e"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Maximum GPU memory allocated: 6.4 GB\n",
      "Maximum GPU memory allocated: 6.4 GB\n",
      "-> Maximum CPU memory allocated: 0.3 GB\n"
     ]
    }
   ],
   "source": [
    "def load_individual_weights():\n",
    "\n",
    "    start_memory_tracking()\n",
    "\n",
    "    # Create the model shell on the meta device (no parameter memory is allocated yet)\n",
    "    with torch.device(\"meta\"):\n",
    "        model = GPTModel(BASE_CONFIG)\n",
    "\n",
    "    # Allocate (uninitialized) parameter memory on the target device\n",
    "    model = model.to_empty(device=device)\n",
    "\n",
    "    print_memory_usage()\n",
    "    param_dir = \"model_parameters\"\n",
    "\n",
    "    # Load each tensor from disk and copy it into the model, one at a time\n",
    "    with torch.no_grad():\n",
    "        for name, param in model.named_parameters():\n",
    "            weight_path = os.path.join(param_dir, f\"{name}.pt\")\n",
    "            if os.path.exists(weight_path):\n",
    "                param_data = torch.load(weight_path, map_location=\"cpu\", weights_only=True)\n",
    "                param.copy_(param_data)\n",
    "                del param_data  # Free memory\n",
    "            else:\n",
    "                print(f\"Warning: {name} not found in {param_dir}.\")\n",
    "\n",
    "    print_memory_usage()\n",
    "\n",
    "\n",
    "peak_memory_used = memory_usage_in_gb(load_individual_weights)\n",
    "print(f\"-> Maximum CPU memory allocated: {peak_memory_used:.1f} GB\")"
   ]
  },
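  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- The design trade-off here: because each tensor is loaded and freed one at a time, peak CPU memory stays near the size of the largest individual tensor (hence the 0.3 GB reading above versus several GB for the earlier approaches), at the cost of many small checkpoint files and extra file-system overhead"
   ]
  }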
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "L4",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}