{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "FtQYMbLvgzO-"
   },
   "source": [
    "<table style=\"width:100%\">\n",
    "<tr>\n",
    "<td style=\"vertical-align:middle; text-align:left;\">\n",
    "<font size=\"2\">\n",
    "Supplementary code for the <a href=\"http://mng.bz/orYv\">Build a Large Language Model From Scratch</a> book by <a href=\"https://sebastianraschka.com\">Sebastian Raschka</a><br>\n",
    "<br>Code repository: <a href=\"https://github.com/rasbt/LLMs-from-scratch\">https://github.com/rasbt/LLMs-from-scratch</a>\n",
    "</font>\n",
    "</td>\n",
    "<td style=\"vertical-align:middle; text-align:left;\">\n",
    "<a href=\"http://mng.bz/orYv\"><img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/cover-small.webp\" width=\"100px\"></a>\n",
    "</td>\n",
    "</tr>\n",
    "</table>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "EbrESHKtgzPA"
   },
   "source": [
    "# FLOPS Analysis"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "xS2WjniMgzPB"
   },
   "source": [
    "- FLOPs (floating point operations) measure the computational complexity of neural network models by counting the number of floating-point operations executed (not to be confused with FLOPS, floating point operations *per second*, which measures throughput)\n",
    "- A higher FLOP count indicates more intensive computation and energy consumption"
   ]
  },
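  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- As a quick, added sanity check of what is being counted (a side note, not part of the original analysis): multiplying an $(m, k)$ by a $(k, n)$ matrix takes about $m \\cdot k \\cdot n$ MACs, i.e., roughly $2 \\cdot m \\cdot k \\cdot n$ FLOPs (one multiply plus one add per accumulated term)\n",
    "- The small sketch below compares this analytic count against `thop` for a single linear layer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Added sketch (assumed mini-example, not from the original notebook):\n",
    "# compare thop's MAC count for one linear layer against the analytic\n",
    "# count batch * in_features * out_features\n",
    "import torch\n",
    "from thop import profile\n",
    "\n",
    "linear = torch.nn.Linear(1024, 4096, bias=False)\n",
    "x = torch.randn(8, 1024)\n",
    "\n",
    "macs, _ = profile(linear, inputs=(x,), verbose=False)\n",
    "print(f\"thop MACs    : {macs:.3e}\")\n",
    "print(f\"analytic MACs: {8 * 1024 * 4096:.3e}\")\n",
    "print(f\"FLOPs = 2 * MACs = {2 * macs:.3e}\")"
   ]
  },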
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "L01-NzkggzPB"
   },
   "outputs": [],
   "source": [
    "# pip install -r requirements-extra.txt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "ObzfVatqgzPC",
    "outputId": "3ead6a41-ac38-4db1-9fc3-012fb3ad18cd"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "thop version: 0.1.1-2209072238\n",
      "torch version: 2.4.1+cu121\n"
     ]
    }
   ],
   "source": [
    "from importlib.metadata import version\n",
    "\n",
    "pkgs = [\n",
    "    \"thop\",\n",
    "    \"torch\",\n",
    "]\n",
    "for p in pkgs:\n",
    "    print(f\"{p} version: {version(p)}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "74UpjSLjgzPC"
   },
   "source": [
    " \n",
    "# Simple benchmark with fixed batch size"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "90pnCK39gzPD"
   },
   "source": [
    "- forward pass only"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "GerIdRMXd6g9",
    "outputId": "177c6d00-a817-40fe-badd-95cfa8ac9b51"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "gpt-small (124M)  : 5.1e+11 FLOPS\n",
      "gpt-medium (355M) : 1.4e+12 FLOPS\n",
      "gpt-large (774M)  : 3.2e+12 FLOPS\n",
      "gpt-xl (1558M)    : 6.4e+12 FLOPS\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from thop import profile\n",
    "\n",
    "# For installation instructions, see:\n",
    "# https://github.com/rasbt/LLMs-from-scratch/tree/main/pkg\n",
    "from llms_from_scratch.ch04 import GPTModel\n",
    "\n",
    "\n",
    "BASE_CONFIG = {\n",
    "    \"vocab_size\": 50257,     # Vocabulary size\n",
    "    \"context_length\": 1024,  # Context length\n",
    "    \"drop_rate\": 0.0,        # Dropout rate\n",
    "    \"qkv_bias\": True         # Query-key-value bias\n",
    "}\n",
    "\n",
    "model_configs = {\n",
    "    \"gpt-small (124M)\": {\"emb_dim\": 768, \"n_layers\": 12, \"n_heads\": 12},\n",
    "    \"gpt-medium (355M)\": {\"emb_dim\": 1024, \"n_layers\": 24, \"n_heads\": 16},\n",
    "    \"gpt-large (774M)\": {\"emb_dim\": 1280, \"n_layers\": 36, \"n_heads\": 20},\n",
    "    \"gpt-xl (1558M)\": {\"emb_dim\": 1600, \"n_layers\": 48, \"n_heads\": 25},\n",
    "}\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "batch_size = 2\n",
    "input_tensor = torch.randint(0, 50257, (batch_size, 1024)).to(device)\n",
    "\n",
    "for size in model_configs:\n",
    "    BASE_CONFIG.update(model_configs[size])\n",
    "\n",
    "    model = GPTModel(BASE_CONFIG).bfloat16()\n",
    "    model.to(device)\n",
    "\n",
    "    # MACs = multiply-accumulate operations\n",
    "    # MACs are typically counted as two FLOPs (one multiply and one accumulate)\n",
    "    macs, params = profile(model, inputs=(input_tensor,), verbose=False)\n",
    "    flops = 2 * macs\n",
    "    print(f\"{size:18}: {flops:.1e} FLOPS\")\n",
    "\n",
    "    del model\n",
    "    torch.cuda.empty_cache()"
   ]
  },
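  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- A back-of-the-envelope cross-check (added note): a common rule of thumb estimates the forward pass at roughly $2 \\cdot N \\cdot T$ FLOPs for a model with $N$ parameters processing $T$ tokens\n",
    "- For gpt-small this gives $2 \\times 124\\text{M} \\times 2 \\times 1024 \\approx 5.1 \\times 10^{11}$, in line with the measured value above; the sketch below tabulates all four sizes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Added rule-of-thumb sketch (assumes the nominal parameter counts in\n",
    "# the model names; forward pass ~ 2 FLOPs per parameter per token)\n",
    "approx_params = {\n",
    "    \"gpt-small (124M)\": 124e6,\n",
    "    \"gpt-medium (355M)\": 355e6,\n",
    "    \"gpt-large (774M)\": 774e6,\n",
    "    \"gpt-xl (1558M)\": 1558e6,\n",
    "}\n",
    "\n",
    "num_tokens = 2 * 1024  # batch_size * context_length from the cell above\n",
    "\n",
    "for size, n_params in approx_params.items():\n",
    "    est_flops = 2 * n_params * num_tokens\n",
    "    print(f\"{size:18}: {est_flops:.1e} FLOPs (estimate)\")"
   ]
  },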
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "_S6V05QmgzPD"
   },
   "source": [
    " \n",
    "# Simple benchmark with automatic batch size finding"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "amw4E983gzPD"
   },
   "source": [
    "- forward pass only"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "h08VOiqpgzPE",
    "outputId": "a6a90ef8-28fb-4b55-9268-6915b0c84c51"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Processing gpt-small (124M)\n",
      "  Batch size 256: 6.5e+13 FLOPS\n",
      "  Batch size 384: 9.7e+13 FLOPS\n",
      "  Batch size 388: 9.8e+13 FLOPS\n",
      "  Batch size 389: 9.8e+13 FLOPS\n",
      "\n",
      "Processing gpt-medium (355M)\n",
      "  Batch size 256: 1.9e+14 FLOPS\n",
      "  Batch size 260: 1.9e+14 FLOPS\n",
      "  Batch size 262: 1.9e+14 FLOPS\n",
      "  Batch size 263: 1.9e+14 FLOPS\n",
      "\n",
      "Processing gpt-large (774M)\n",
      "  Batch size 256: 4.0e+14 FLOPS\n",
      "\n",
      "Processing gpt-xl (1558M)\n",
      "  Batch size 128: 4.1e+14 FLOPS\n",
      "  Batch size 136: 4.3e+14 FLOPS\n",
      "  Batch size 140: 4.5e+14 FLOPS\n",
      "  Batch size 142: 4.5e+14 FLOPS\n",
      "  Batch size 143: 4.6e+14 FLOPS\n"
     ]
    }
   ],
   "source": [
    "for size in model_configs:\n",
    "    print(f\"\\nProcessing {size}\")\n",
    "    config = BASE_CONFIG.copy()\n",
    "    config.update(model_configs[size])\n",
    "\n",
    "    min_batch_size = 1\n",
    "    max_batch_size = None\n",
    "    max_possible_batch_size = 4096\n",
    "\n",
    "    while min_batch_size <= max_possible_batch_size:\n",
    "        batch_size = (min_batch_size + max_possible_batch_size) // 2\n",
    "        try:\n",
    "            input_tensor = torch.randint(\n",
    "                0, config[\"vocab_size\"],\n",
    "                (batch_size, config[\"context_length\"]),\n",
    "                device=device\n",
    "            )\n",
    "\n",
    "            model = GPTModel(config).bfloat16().to(device)\n",
    "\n",
    "            # MACs = multiply-accumulate operations\n",
    "            # MACs are typically counted as two FLOPs (one multiply and one accumulate)\n",
    "            macs, params = profile(model, inputs=(input_tensor,), verbose=False)\n",
    "            flops = 2 * macs\n",
    "            print(f\"  Batch size {batch_size}: {flops:.1e} FLOPS\")\n",
    "\n",
    "            # If successful, try a larger batch size\n",
    "            min_batch_size = batch_size + 1\n",
    "            max_batch_size = batch_size\n",
    "\n",
    "            # Clean up\n",
    "            del model, input_tensor\n",
    "            torch.cuda.empty_cache()\n",
    "\n",
    "        except RuntimeError as e:\n",
    "            if \"out of memory\" in str(e):\n",
    "                # Try smaller batch size\n",
    "                max_possible_batch_size = batch_size - 1\n",
    "\n",
    "                # Clean up\n",
    "                try:\n",
    "                    del model, input_tensor\n",
    "                    torch.cuda.empty_cache()\n",
    "                except NameError:\n",
    "                    pass\n",
    "            else:\n",
    "                raise e"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "V4lD7tfcgzPE"
   },
   "source": [
    " \n",
    "# Benchmark with automatic batch size finding and Model FLOP Utilization (MFU)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "70Y2mblVgzPE"
   },
   "source": [
    "- Model FLOPs Utilization (MFU) explanation from the [PaLM paper](https://arxiv.org/abs/2204.02311)\n",
    "\n",
    "> We propose a new metric for efficiency that is implementation-independent and permits a cleaner comparison of system efficiency, called model FLOPs utilization (MFU). This is the ratio of the observed throughput (tokens-per-second) relative to the theoretical maximum throughput of a system operating at peak FLOPs. Crucially, the “theoretical maximum” throughput only accounts for the required operations to compute the forward+backward passes, and not rematerialization.\n",
    "\n",
    "\n",
    "$$\\text{MFU} = \\frac{\\text{Observed Tokens per Second}}{\\text{Theoretical Max Tokens per Second}}$$\n",
    "\n",
    "where\n",
    "\n",
    "$$\\text{Theoretical Max Tokens per Second} = \\frac{\\text{Max FLOPs per Second}}{\\text{Total FLOPs per Token}}$$\n",
    "\n",
    "and\n",
    "\n",
    "$$\\text{Tokens per Second} = \\frac{\\text{Batch Size} \\times \\text{Sequence Length}}{\\text{Total Time}}$$"
   ]
  },
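  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- A minimal numeric sketch of the MFU formulas above (added example with made-up numbers; the benchmark below measures these quantities for real)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Added MFU sketch; all values here are hypothetical placeholders:\n",
    "# peak bfloat16 throughput of a T4 and ~6 FLOPs per parameter per\n",
    "# token for a combined forward+backward pass of a 124M-parameter model\n",
    "max_flops_per_second = 65.13e12\n",
    "flops_per_token = 6 * 124e6\n",
    "\n",
    "theoretical_max_tps = max_flops_per_second / flops_per_token\n",
    "observed_tps = 50_000  # hypothetical measured tokens per second\n",
    "\n",
    "mfu = observed_tps / theoretical_max_tps\n",
    "print(f\"Theoretical max tokens/sec: {theoretical_max_tps:.1f}\")\n",
    "print(f\"MFU: {mfu:.3f}\")"
   ]
  },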
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "TKttjC8xgzPF"
   },
   "source": [
    "- forward and backward pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "6aO4rjtNgzPF"
   },
   "outputs": [],
   "source": [
    "# Theoretical max FLOPS provided by the GPU manufacturer\n",
    "\n",
    "flops_per_second = {\n",
    "    # https://www.techpowerup.com/gpu-specs/h100-pcie-80-gb.c3899\n",
    "    \"H100\": {\n",
    "        torch.float32: 51.22e12,  # 51.22 TFLOPS for FP32 on NVIDIA H100\n",
    "        torch.float16: 204.9e12,  # 204.9 TFLOPS for FP16 on NVIDIA H100\n",
    "        torch.bfloat16: 204.9e12\n",
    "    },\n",
    "    # https://www.techpowerup.com/gpu-specs/l4.c4091\n",
    "    \"L4\": {\n",
    "        torch.float32: 30.29e12,  # 30.29 TFLOPS for FP32 on NVIDIA L4\n",
    "        torch.float16: 30.29e12,  # 30.29 TFLOPS for FP16 on NVIDIA L4\n",
    "        torch.bfloat16: 30.29e12\n",
    "    },\n",
    "    # https://www.techpowerup.com/gpu-specs/tesla-t4.c3316\n",
    "    \"T4\": {\n",
    "        torch.float32: 8.1e12,  # 8.1 TFLOPS for FP32 on NVIDIA T4\n",
    "        torch.float16: 65.13e12,  # 65.13 TFLOPS for FP16 on NVIDIA T4\n",
    "        torch.bfloat16: 65.13e12\n",
    "    },\n",
    "    # https://www.techpowerup.com/gpu-specs/a10g.c3798\n",
    "    \"A10G\": {\n",
    "        torch.float32: 31.52e12,  # 31.52 TFLOPS for FP32 on NVIDIA A10G\n",
    "        torch.float16: 31.52e12,  # 31.52 TFLOPS for FP16 on NVIDIA A10G\n",
    "        torch.bfloat16: 31.52e12\n",
    "    },\n",
    "    # https://www.techpowerup.com/gpu-specs/a100-pcie-40-gb.c3623\n",
    "    \"A100\": {\n",
    "        torch.float32: 19.49e12,  # 19.49 TFLOPS for FP32 on NVIDIA A100\n",
    "        torch.float16: 77.97e12,  # 77.97 TFLOPS for FP16 on NVIDIA A100\n",
    "        torch.bfloat16: 77.97e12\n",
    "    },\n",
    "    # https://www.techpowerup.com/gpu-specs/geforce-rtx-3080.c3621\n",
    "    \"RTX_3080\": {\n",
    "        torch.float32: 29.77e12,  # 29.77 TFLOPS for FP32 on NVIDIA RTX 3080\n",
    "        torch.float16: 29.77e12,  # 29.77 TFLOPS for FP16 on NVIDIA RTX 3080\n",
    "        torch.bfloat16: 29.77e12\n",
    "    },\n",
    "    # https://www.techpowerup.com/gpu-specs/geforce-rtx-3090.c3622\n",
    "    \"RTX_3090\": {\n",
    "        torch.float32: 35.58e12,  # 35.58 TFLOPS for FP32 on NVIDIA RTX 3090\n",
    "        torch.float16: 35.58e12,  # 35.58 TFLOPS for FP16 on NVIDIA RTX 3090\n",
    "        torch.bfloat16: 35.58e12\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    }\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "}\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "background_save": true,
     "base_uri": "https://localhost:8080/"
    },
    "id": "HW5qWfE7gzPF",
    "outputId": "bb1663bc-ee66-44f1-f54d-0bb66ee0d0c2"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
								      "GPU Model: A100\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "Processing gpt-small (124M)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "  Batch size 16: Tokens/sec: 34248.82, MFU: 0.3256\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "  Batch size 24: Tokens/sec: 62568.34, MFU: 0.5948\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "Processing gpt-medium (355M)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "  Batch size 4: Tokens/sec: 20159.93, MFU: 0.5483\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "  Batch size 6: Tokens/sec: 21717.66, MFU: 0.5907\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "  Batch size 7: Tokens/sec: 22536.25, MFU: 0.6130\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "Processing gpt-large (774M)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "  Batch size 8: Tokens/sec: 12465.21, MFU: 0.7406\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "Processing gpt-xl (1558M)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								      "  Batch size 4: Tokens/sec: 6779.92, MFU: 0.8113\n"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								     ]
							 
						 
					
						
							
								
									
										
										
										
											2024-10-10 19:42:53 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								    }
							 
						 
					
						
							
								
									
										
										
										
											2024-10-11 12:15:01 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "import time\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "def get_gpu_model(flops_per_second_dict):\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    device_name = torch.cuda.get_device_name(0)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    for model in flops_per_second_dict.keys():\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        if model in device_name:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            return model\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    return \"Unknown\"  # Default if no matching model is found\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "gpu_model = get_gpu_model(flops_per_second)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "print(\"GPU Model:\", gpu_model)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "if gpu_model != \"Unknown\":\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    for size in model_configs:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        print(f\"\\nProcessing {size}\")\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        config = BASE_CONFIG.copy()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        config.update(model_configs[size])\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        min_batch_size = 1\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        max_batch_size = None\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "        max_possible_batch_size = 4096\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
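    "        # Binary search for the largest batch size that fits into GPU memory\n",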
								    "        while min_batch_size <= max_possible_batch_size:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            batch_size = (min_batch_size + max_possible_batch_size) // 2\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            try:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                input_tensor = torch.randint(\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                    0, config[\"vocab_size\"],\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                    (batch_size, config[\"context_length\"]),\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                    device=device\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                )\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
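    "                # Casting to bfloat16 means the dtype lookup below selects the GPU's bfloat16 peak FLOPs\n",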
								    "                model = GPTModel(config).bfloat16().to(device)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                model.train()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                # Start timing\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                torch.cuda.synchronize()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                start_time = time.time()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                # Forward & backward pass\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                output = model(input_tensor)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                loss = output.sum()  # Compute a dummy loss\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                loss.backward()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                # End timing\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                torch.cuda.synchronize()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                end_time = time.time()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                total_time_seconds = end_time - start_time\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                # Calculate FLOPs for forward pass\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                macs, params = profile(model, inputs=(input_tensor,), verbose=False)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
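    "                # macs = multiply-accumulate operations, params = parameter count (profile is presumably thop's profiler, imported earlier in the notebook)\n",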
								    "                flops_forward = 2 * macs  # Assuming one MAC equals two FLOPs\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                # Estimate FLOPs for backward pass (typically 2x forward FLOPs)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                flops_backward = 2 * flops_forward\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                # Total FLOPs for forward + backward passes\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                total_flops = flops_forward + flops_backward  # Or total_flops = flops_forward * 3\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                data_type = next(model.parameters()).dtype\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                max_flops_per_second = flops_per_second[gpu_model].get(data_type, 0)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                # Compute tokens per second\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                tokens_processed = batch_size * config[\"context_length\"]\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                tokens_per_second = tokens_processed / total_time_seconds\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                # Compute FLOPs per token\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                flops_per_token = total_flops / tokens_processed\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                # Compute theoretical max tokens per second\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
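    "                # (= the GPU's peak FLOPs/sec divided by the FLOPs required per token)\n",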
								    "                if flops_per_token > 0:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                    theoretical_max_tokens_per_second = max_flops_per_second / flops_per_token\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                else:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                    theoretical_max_tokens_per_second = 0  # Avoid division by zero\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                # Compute MFU\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
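    "                # MFU = observed tokens/sec divided by theoretical max tokens/sec; 1.0 means full utilization of peak throughput\n",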
								    "                if theoretical_max_tokens_per_second > 0:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                    mfu = tokens_per_second / theoretical_max_tokens_per_second\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                else:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                    mfu = 0  # Avoid division by zero\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                print(f\"  Batch size {batch_size}: Tokens/sec: {tokens_per_second:.2f}, MFU: {mfu:.4f}\")\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                # If successful, try a larger batch size\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                min_batch_size = batch_size + 1\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                max_batch_size = batch_size\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                # Clean up\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                del model, input_tensor, output, loss\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                torch.cuda.empty_cache()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "            except RuntimeError as e:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                if \"out of memory\" in str(e).lower():\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                    # Try smaller batch size\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                    max_possible_batch_size = batch_size - 1\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                    # Clean up\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                    try:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                        del model, input_tensor\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                        torch.cuda.empty_cache()\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                    except NameError:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                        pass\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                else:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "                    raise e\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "else:\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "    print(\"Unknown GPU model. Please update the flops_per_second dictionary with your GPU information.\")"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "cell_type": "markdown",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "id": "LovmswRigzPG"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "source": [
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "- a value of 1.0 is best (equal to 100%)\n",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
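    "- MFU is the measured tokens/sec divided by the theoretical maximum tokens/sec that the GPU's peak FLOPs rate allows; as a quick sanity check, here is a minimal sketch of that arithmetic using the gpt-xl numbers printed above:\n",
    "\n",
    "```python\n",
    "# Minimal sketch: back out the theoretical maximum from the printed gpt-xl result\n",
    "tokens_per_second = 6779.92   # measured above\n",
    "mfu = 0.8113                  # reported above\n",
    "theoretical_max = tokens_per_second / mfu   # roughly 8357 tokens/sec on this A100 run\n",
    "```\n",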
								    "- Note that the batch sizes are smaller than previously because we also carry out the backward pass here, which is more memory-intensive"
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   ]
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 ],
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 "metadata": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  "accelerator": "GPU",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  "colab": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "gpuType": "A100",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "machine_shape": "hm",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "provenance": []
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								  "kernelspec": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "display_name": "Python 3 (ipykernel)",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "language": "python",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "name": "python3"
							 
						 
					
						
							
								
									
										
										
										
											2024-05-23 20:35:41 -05:00 
										
									 
								 
							 
							
								
							 
							
								 
							
							
								  },
							 
						 
					
						
							
								
									
										
										
										
											2024-10-11 12:15:01 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  "language_info": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "codemirror_mode": {
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "name": "ipython",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								    "version": 3
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "file_extension": ".py",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "mimetype": "text/x-python",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "name": "python",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "nbconvert_exporter": "python",
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								   "pygments_lexer": "ipython3",
							 
						 
					
						
							
								
									
										
										
										
											2025-03-23 19:28:49 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								   "version": "3.10.16"
							 
						 
					
						
							
								
									
										
										
										
											2024-10-11 12:15:01 -05:00 
										
									 
								 
							 
							
								
									
										 
								
							 
							
								 
							
							
								  }
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 },
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 "nbformat": 4,
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								 "nbformat_minor": 4
							 
						 
					
						
							
								
							 
							
								
							 
							
								 
							
							
								}