{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "\n", "\n", "\n", "\n", "
\n", "\n", "Supplementary code for the Build a Large Language Model From Scratch book by Sebastian Raschka
\n", "
Code repository: https://github.com/rasbt/LLMs-from-scratch\n", "
\n", "
\n", "\n", "
" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## FLOPS Analysis" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- FLOPs (Floating Point Operations Per Second) measure the computational complexity of neural network models by counting the number of floating-point operations executed\n", "- High FLOPs indicate more intensive computation and energy consumption" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# pip install -r requirements-extra.txt" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "thop version: 0.1.1-2209072238\n", "torch version: 2.2.1+cu121\n" ] } ], "source": [ "from importlib.metadata import version\n", "\n", "pkgs = [\n", " \"thop\",\n", " \"torch\",\n", "]\n", "for p in pkgs:\n", " print(f\"{p} version: {version(p)}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ " \n", "# Simple benchmark with fixed batch size" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "GerIdRMXd6g9", "outputId": "ccdd5c71-d221-4a84-f9bc-09557e77162d" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "gpt-small (124M) : 5.1e+11 FLOPS\n", "gpt-medium (355M) : 1.4e+12 FLOPS\n", "gpt-large (774M) : 3.2e+12 FLOPS\n", "gpt-xl (1558M) : 6.4e+12 FLOPS\n" ] } ], "source": [ "import torch\n", "from thop import profile\n", "\n", "from previous_chapters import GPTModel\n", "\n", "\n", "BASE_CONFIG = {\n", " \"vocab_size\": 50257, # Vocabulary size\n", " \"context_length\": 1024, # Context length\n", " \"drop_rate\": 0.0, # Dropout rate\n", " \"qkv_bias\": True # Query-key-value bias\n", "}\n", "\n", "model_configs = {\n", " \"gpt-small (124M)\": {\"emb_dim\": 768, \"n_layers\": 12, \"n_heads\": 12},\n", " \"gpt-medium (355M)\": {\"emb_dim\": 1024, \"n_layers\": 24, \"n_heads\": 16},\n", " \"gpt-large (774M)\": {\"emb_dim\": 1280, \"n_layers\": 36, \"n_heads\": 20},\n", " \"gpt-xl (1558M)\": {\"emb_dim\": 1600, \"n_layers\": 48, \"n_heads\": 25},\n", "}\n", "\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "batch_size = 2\n", "input_tensor = torch.randint(0, 50257, (batch_size, 1024)).to(device)\n", "\n", "for size in model_configs:\n", " BASE_CONFIG.update(model_configs[size])\n", " \n", " model = GPTModel(BASE_CONFIG).bfloat16()\n", " model.to(device)\n", "\n", " # MACS = multiply-accumulate operations\n", " # MACS are typically counted as two FLOPS (one multiply and one accumulate)\n", " macs, params = profile(model, inputs=(input_tensor,), verbose=False)\n", " flops = 2*macs\n", " print(f\"{size:18}: {flops:.1e} FLOPS\")\n", " \n", " del model\n", " torch.cuda.empty_cache()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ " \n", "# Simple benchmark with automatic batch size finding" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "Processing gpt-small (124M)\n", " Batch size 128: 3.2e+13 FLOPS\n", " Batch size 160: 4.0e+13 FLOPS\n", " Batch size 176: 4.5e+13 FLOPS\n", " Batch size 184: 4.7e+13 FLOPS\n", " Batch size 186: 4.7e+13 FLOPS\n", "\n", "Processing gpt-medium (355M)\n", " Batch size 128: 9.3e+13 FLOPS\n", " Batch size 136: 9.8e+13 FLOPS\n", " Batch size 140: 1.0e+14 FLOPS\n", " Batch size 142: 1.0e+14 FLOPS\n", " Batch size 143: 1.0e+14 FLOPS\n", "\n", "Processing gpt-large (774M)\n", " Batch size 128: 2.0e+14 FLOPS\n", "\n", "Processing gpt-xl (1558M)\n", " Batch size 64: 2.0e+14 FLOPS\n", " Batch size 96: 3.1e+14 FLOPS\n" ] } ], "source": [ "for size in model_configs:\n", " print(f\"\\nProcessing {size}\")\n", " config = BASE_CONFIG.copy()\n", " config.update(model_configs[size])\n", "\n", " min_batch_size = 1\n", " max_batch_size = None\n", " max_possible_batch_size = 4096\n", "\n", " while min_batch_size <= max_possible_batch_size:\n", " batch_size = (min_batch_size + max_possible_batch_size) // 2\n", " try:\n", " input_tensor = torch.randint(\n", " 0, config[\"vocab_size\"],\n", " (batch_size, config[\"context_length\"]),\n", " device=device\n", " )\n", "\n", " model = GPTModel(config).bfloat16().to(device)\n", "\n", " # MACS = multiply-accumulate operations\n", " # MACS are typically counted as two FLOPS (one multiply and one accumulate)\n", " macs, params = profile(model, inputs=(input_tensor,), verbose=False)\n", " flops = 2 * macs\n", " print(f\" Batch size {batch_size}: {flops:.1e} FLOPS\")\n", "\n", " # If successful, try a larger batch size\n", " min_batch_size = batch_size + 1\n", " max_batch_size = batch_size\n", "\n", " # Clean up\n", " del model, input_tensor\n", " torch.cuda.empty_cache()\n", "\n", " except RuntimeError as e:\n", " if \"out of memory\" in str(e):\n", " # Try smaller batch size\n", " max_possible_batch_size = batch_size - 1\n", "\n", " # Clean up\n", " try:\n", " del model, input_tensor\n", " torch.cuda.empty_cache()\n", " except NameError:\n", " pass\n", " else:\n", " raise e" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ " \n", "# Benchmark with automatic batch size finding and Model FLOP Utilization (MFU)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- Model FLOPs Utilization (MFU) explanation from the [PaLM paper](https://arxiv.org/abs/2204.02311)\n", "\n", "> We propose a new metric for efficiency that is implementation-independent and permits a cleaner comparison of system efficiency, called model FLOPs utilization (MFU). This is the ratio of the observed throughput (tokens-per-second) relative to the theoretical maximum throughput of a system operating at peak FLOPs. Crucially, the “theoretical maximum” throughput only accounts for the required operations to compute the forward+backward passes, and not rematerialization.\n", "\n", "\n", "$$\\text{MFU} = \\frac{\\text{Observed Tokens per Second}}{\\text{Theoretical Max Tokens per Second}}$$\n", "\n", "where \n", "\n", "$$\\text{Theoretical Max Tokens per Second} = \\frac{\\text{Max FLOPs per Second}}{\\text{Total FLOPs per Token}}$$\n", "\n", "and\n", "\n", "$$\\text{Tokens per Second} = \\frac{\\text{Batch Size} \\times \\text{Sequence Length}}{\\text{Total Time}}$$" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# Max flops per second provided by the GPU manufacturer\n", "\n", "flops_per_second = {\n", " \"H100\": {\n", " torch.float32: 60e12, # 60 TFLOPs for FP32 on NVIDIA H100\n", " torch.float16: 1.979e15, # 1979 TFLOPs for FP16 on NVIDIA H100\n", " torch.bfloat16: 1.979e15\n", " },\n", " \"L4\": {\n", " torch.float32: 15e12, # 15 TFLOPs for FP32 on NVIDIA L4\n", " torch.float16: 30e12, # 30 TFLOPs for FP16 on NVIDIA L4\n", " torch.bfloat16: 30e12 \n", " },\n", " \"T4\": {\n", " torch.float32: 8.1e12, # 8.1 TFLOPs for FP32 on NVIDIA T4\n", " torch.float16: 130e12, # 130 TFLOPs for FP16 on NVIDIA T4\n", " torch.bfloat16: 130e12\n", " },\n", " \"A10G\": {\n", " torch.float32: 15.6e12, # 15.6 TFLOPs for FP32 on NVIDIA A10G\n", " torch.float16: 78e12, # 78 TFLOPs for FP16 on NVIDIA A10G\n", " torch.bfloat16: 78e12\n", " },\n", " \"A100\": {\n", " torch.float32: 19.5e12, # 19.5 TFLOPs for FP32 on NVIDIA A100\n", " torch.float16: 1.248e15, # 1248 TFLOPs for FP16 on NVIDIA A100\n", " torch.bfloat16: 1.248e15\n", " },\n", " \"H200\": {\n", " torch.float32: 70e12, # 70 TFLOPs for FP32 on NVIDIA H200\n", " torch.float16: 1.2e15, # Assuming 1200 TFLOPs for FP16 on NVIDIA H200\n", " torch.bfloat16: 1.2e15\n", " },\n", " \"RTX_3080\": {\n", " torch.float32: 29.8e12, # 29.8 TFLOPs for FP32 on NVIDIA RTX 3080\n", " torch.float16: 59.6e12, # 59.6 TFLOPs for FP16 on NVIDIA RTX 3080\n", " torch.bfloat16: 59.6e12\n", " },\n", " \"RTX_3090\": {\n", " torch.float32: 35.6e12, # 35.6 TFLOPs for FP32 on NVIDIA RTX 3090\n", " torch.float16: 71.2e12, # 71.2 TFLOPs for FP16 on NVIDIA RTX 3090\n", " torch.bfloat16: 71.2e12\n", " },\n", " \"GTX_1080\": {\n", " torch.float32: 8.9e12, # 8.9 TFLOPs for FP32 on NVIDIA GTX 1080\n", " torch.float16: 8.9e12, # No dedicated FP16 performance; using FP32 value\n", " torch.bfloat16: 8.9e12\n", " },\n", " \"GTX_1080Ti\": {\n", " torch.float32: 11.3e12, # 11.3 TFLOPs for FP32 on NVIDIA GTX 1080Ti\n", " torch.float16: 11.3e12, # No dedicated FP16 performance; using FP32 value\n", " torch.bfloat16: 11.3e12\n", " },\n", " \"GTX_1660\": {\n", " torch.float32: 5e12, # 5 TFLOPs for FP32 on NVIDIA GTX 1660\n", " torch.float16: 5e12, # No dedicated FP16 performance; using FP32 value\n", " torch.bfloat16: 5e12\n", " },\n", " \"GTX_1660Ti\": {\n", " torch.float32: 5.5e12, # 5.5 TFLOPs for FP32 on NVIDIA GTX 1660Ti\n", " torch.float16: 5.5e12, # No dedicated FP16 performance; using FP32 value\n", " torch.bfloat16: 5.5e12\n", " }\n", "}\n" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "GPU Model: L4\n", "\n", "Processing gpt-small (124M)\n", " Batch size 8: Tokens/sec: 14488.21, MFU: 0.3580\n", " Batch size 12: Tokens/sec: 15378.16, MFU: 0.3799\n", "\n", "Processing gpt-medium (355M)\n", " Batch size 2: Tokens/sec: 6493.81, MFU: 0.4591\n", " Batch size 3: Tokens/sec: 6328.82, MFU: 0.4474\n", "\n", "Processing gpt-large (774M)\n", " Batch size 4: Tokens/sec: 3130.38, MFU: 0.4834\n", "\n", "Processing gpt-xl (1558M)\n", " Batch size 2: Tokens/sec: 1896.17, MFU: 0.5897\n" ] } ], "source": [ "import time\n", "\n", "def get_gpu_model(flops_per_second_dict):\n", " device_name = torch.cuda.get_device_name(0)\n", " for model in flops_per_second_dict.keys():\n", " if model in device_name:\n", " return model\n", " return \"Unknown\" # Default if no matching model is found\n", "\n", "\n", "gpu_model = get_gpu_model(flops_per_second)\n", "print(\"GPU Model:\", gpu_model)\n", "\n", "if gpu_model != \"Unknown\":\n", "\n", " for size in model_configs:\n", " print(f\"\\nProcessing {size}\")\n", " config = BASE_CONFIG.copy()\n", " config.update(model_configs[size])\n", "\n", " min_batch_size = 1\n", " max_batch_size = None\n", " max_possible_batch_size = 4096\n", "\n", " while min_batch_size <= max_possible_batch_size:\n", " batch_size = (min_batch_size + max_possible_batch_size) // 2\n", " try:\n", " input_tensor = torch.randint(\n", " 0, config[\"vocab_size\"],\n", " (batch_size, config[\"context_length\"]),\n", " device=device\n", " )\n", "\n", " model = GPTModel(config).bfloat16().to(device)\n", " model.train()\n", "\n", " # Start timing\n", " torch.cuda.synchronize()\n", " start_time = time.time()\n", "\n", " # Forward & backward pass\n", " output = model(input_tensor)\n", " loss = output.sum() # Compute a dummy loss \n", " loss.backward()\n", "\n", " # End timing\n", " torch.cuda.synchronize()\n", " end_time = time.time()\n", "\n", " total_time_seconds = end_time - start_time\n", "\n", " # Calculate FLOPs for forward pass\n", " macs, params = profile(model, inputs=(input_tensor,), verbose=False)\n", " flops_forward = 2 * macs # Assuming one MAC equals two FLOPs\n", "\n", " # Estimate FLOPs for backward pass (typically 2x forward FLOPs)\n", " flops_backward = 2 * flops_forward\n", "\n", " # Total FLOPs for forward + backward passes\n", " total_flops = flops_forward + flops_backward # Or total_flops = flops_forward * 3\n", "\n", " data_type = next(model.parameters()).dtype\n", " max_flops_per_second = flops_per_second[gpu_model].get(data_type, 0)\n", "\n", " # Compute tokens per second\n", " tokens_processed = batch_size * config[\"context_length\"]\n", " tokens_per_second = tokens_processed / total_time_seconds\n", "\n", " # Compute FLOPs per token\n", " flops_per_token = total_flops / tokens_processed\n", "\n", " # Compute theoretical max tokens per second\n", " if flops_per_token > 0:\n", " theoretical_max_tokens_per_second = max_flops_per_second / flops_per_token\n", " else:\n", " theoretical_max_tokens_per_second = 0 # Avoid division by zero\n", "\n", " # Compute MFU\n", " if theoretical_max_tokens_per_second > 0:\n", " mfu = tokens_per_second / theoretical_max_tokens_per_second\n", " else:\n", " mfu = 0 # Avoid division by zero\n", "\n", " print(f\" Batch size {batch_size}: Tokens/sec: {tokens_per_second:.2f}, MFU: {mfu:.4f}\")\n", "\n", " # If successful, try a larger batch size\n", " min_batch_size = batch_size + 1\n", " max_batch_size = batch_size\n", "\n", " # Clean up\n", " del model, input_tensor, output, loss\n", " torch.cuda.empty_cache()\n", "\n", " except RuntimeError as e:\n", " if \"out of memory\" in str(e).lower():\n", " # Try smaller batch size\n", " max_possible_batch_size = batch_size - 1\n", "\n", " # Clean up\n", " try:\n", " del model, input_tensor\n", " torch.cuda.empty_cache()\n", " except NameError:\n", " pass\n", " else:\n", " raise e\n", "\n", "else:\n", " print(\"Unknown GPU model. Please update the flops_per_second dictionary with your GPU information.\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "- Note that the batch sizes are smaller than previously because we also carry out the backward pass here, which is more memory-intensive" ] } ], "metadata": { "accelerator": "GPU", "colab": { "gpuType": "A100", "machine_shape": "hm", "provenance": [] }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.4" } }, "nbformat": 4, "nbformat_minor": 4 }