Alt weight loading code via PyTorch (#585)

* Alt weight loading code via PyTorch

* commit additional files
This commit is contained in:
Sebastian Raschka 2025-03-27 20:10:23 -05:00 committed by GitHub
parent e07a7abdd5
commit e55e3e88e1
7 changed files with 535 additions and 18 deletions

View File

@ -113,7 +113,7 @@ Several folders contain optional materials as a bonus for interested readers:
- **Chapter 4: Implementing a GPT model from scratch** - **Chapter 4: Implementing a GPT model from scratch**
- [FLOPS Analysis](ch04/02_performance-analysis/flops-analysis.ipynb) - [FLOPS Analysis](ch04/02_performance-analysis/flops-analysis.ipynb)
- **Chapter 5: Pretraining on unlabeled data:** - **Chapter 5: Pretraining on unlabeled data:**
- [Alternative Weight Loading from Hugging Face Model Hub using Transformers](ch05/02_alternative_weight_loading/weight-loading-hf-transformers.ipynb) - [Alternative Weight Loading Methods](ch05/02_alternative_weight_loading/)
- [Pretraining GPT on the Project Gutenberg Dataset](ch05/03_bonus_pretraining_on_gutenberg) - [Pretraining GPT on the Project Gutenberg Dataset](ch05/03_bonus_pretraining_on_gutenberg)
- [Adding Bells and Whistles to the Training Loop](ch05/04_learning_rate_schedulers) - [Adding Bells and Whistles to the Training Loop](ch05/04_learning_rate_schedulers)
- [Optimizing Hyperparameters for Pretraining](ch05/05_bonus_hparam_tuning) - [Optimizing Hyperparameters for Pretraining](ch05/05_bonus_hparam_tuning)

View File

@ -2133,20 +2133,53 @@
"id": "127ddbdb-3878-4669-9a39-d231fbdfb834", "id": "127ddbdb-3878-4669-9a39-d231fbdfb834",
"metadata": {}, "metadata": {},
"source": [ "source": [
"<span style=\"color:darkred\">\n", "---\n",
" <ul>\n", "\n",
" <li>For an alternative way to load the weights from the Hugging Face Hub, see <a href=\"../02_alternative_weight_loading\">../02_alternative_weight_loading</a></li>\n", "---\n",
" <ul>\n", "\n",
" <li>This is useful if:</li>\n", "\n",
" <ul>\n", "⚠️ **Note: Some users may encounter issues in this section due to TensorFlow compatibility problems, particularly on certain Windows systems. TensorFlow is required here only to load the original OpenAI GPT-2 weight files, which we then convert to PyTorch.\n",
" <li>the weights are temporarily unavailable</li>\n", "If you're running into TensorFlow-related issues, you can use the alternative code below instead of the remaining code in this section.\n",
" <li>a company VPN only permits downloads from the Hugging Face Hub but not from the OpenAI CDN, for example</li>\n", "This alternative is based on pre-converted PyTorch weights, created using the same conversion process described in the previous section. For details, refer to the notebook:\n",
" <li>you are having trouble with the TensorFlow installation (the original weights are stored in TensorFlow files)</li>\n", "[../02_alternative_weight_loading/weight-loading-pytorch.ipynb](../02_alternative_weight_loading/weight-loading-pytorch.ipynb) notebook.**\n",
" </ul>\n", "\n",
" </ul>\n", "```python\n",
" <li>The <a href=\"../02_alternative_weight_loading\">../02_alternative_weight_loading</a> code notebooks are replacements for the remainder of this section 5.5</li>\n", "file_name = \"gpt2-small-124M.pth\"\n",
" </ul>\n", "# file_name = \"gpt2-medium-355M.pth\"\n",
"</span>\n" "# file_name = \"gpt2-large-774M.pth\"\n",
"# file_name = \"gpt2-xl-1558M.pth\"\n",
"\n",
"url = f\"https://huggingface.co/rasbt/gpt2-from-scratch-pytorch/resolve/main/{file_name}\"\n",
"\n",
"if not os.path.exists(file_name):\n",
" urllib.request.urlretrieve(url, file_name)\n",
" print(f\"Downloaded to {file_name}\")\n",
"\n",
"gpt = GPTModel(BASE_CONFIG)\n",
"gpt.load_state_dict(torch.load(file_name, weights_only=True))\n",
"gpt.eval()\n",
"\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"gpt.to(device);\n",
"\n",
"\n",
"torch.manual_seed(123)\n",
"\n",
"token_ids = generate(\n",
" model=gpt,\n",
" idx=text_to_token_ids(\"Every effort moves you\", tokenizer).to(device),\n",
" max_new_tokens=25,\n",
" context_size=NEW_CONFIG[\"context_length\"],\n",
" top_k=50,\n",
" temperature=1.5\n",
")\n",
"\n",
"print(\"Output text:\\n\", token_ids_to_text(token_ids, tokenizer))\n",
"```\n",
"\n",
"---\n",
"\n",
"---"
] ]
}, },
{ {
@ -2197,7 +2230,10 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"# Relative import from the gpt_download.py contained in this folder\n", "# Relative import from the gpt_download.py contained in this folder\n",
"from gpt_download import download_and_load_gpt2" "\n",
"from gpt_download import download_and_load_gpt2\n",
"# Alternatively:\n",
"# from llms_from_scratch.ch05 import download_and_load_gpt2"
] ]
}, },
{ {

View File

@ -2,6 +2,8 @@
This folder contains alternative weight loading strategies in case the weights become unavailable from OpenAI. This folder contains alternative weight loading strategies in case the weights become unavailable from OpenAI.
- [weight-loading-pytorch.ipynb](weight-loading-pytorch.ipynb): (Recommended) contains code to load the weights from PyTorch state dicts that I created by converting the original TensorFlow weights
- [weight-loading-hf-transformers.ipynb](weight-loading-hf-transformers.ipynb): contains code to load the weights from the Hugging Face Model Hub via the `transformers` library - [weight-loading-hf-transformers.ipynb](weight-loading-hf-transformers.ipynb): contains code to load the weights from the Hugging Face Model Hub via the `transformers` library
- [weight-loading-hf-safetensors.ipynb](weight-loading-hf-safetensors.ipynb): contains code to load the weights from the Hugging Face Model Hub via the `safetensors` library directly (skipping the instantiation of a Hugging Face transformer model) - [weight-loading-hf-safetensors.ipynb](weight-loading-hf-safetensors.ipynb): contains code to load the weights from the Hugging Face Model Hub via the `safetensors` library directly (skipping the instantiation of a Hugging Face transformer model)

View File

@ -0,0 +1,356 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "6d6bc54f-2b16-4b0f-be69-957eed5d112f",
"metadata": {},
"source": [
"<table style=\"width:100%\">\n",
"<tr>\n",
"<td style=\"vertical-align:middle; text-align:left;\">\n",
"<font size=\"2\">\n",
"Supplementary code for the <a href=\"http://mng.bz/orYv\">Build a Large Language Model From Scratch</a> book by <a href=\"https://sebastianraschka.com\">Sebastian Raschka</a><br>\n",
"<br>Code repository: <a href=\"https://github.com/rasbt/LLMs-from-scratch\">https://github.com/rasbt/LLMs-from-scratch</a>\n",
"</font>\n",
"</td>\n",
"<td style=\"vertical-align:middle; text-align:left;\">\n",
"<a href=\"http://mng.bz/orYv\"><img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/cover-small.webp\" width=\"100px\"></a>\n",
"</td>\n",
"</tr>\n",
"</table>"
]
},
{
"cell_type": "markdown",
"id": "72953590-5363-4398-85ce-54bde07f3d8a",
"metadata": {},
"source": [
"# Bonus Code for Chapter 5"
]
},
{
"cell_type": "markdown",
"id": "1a4ab5ee-e7b9-45d3-a82b-a12bcfc0945a",
"metadata": {},
"source": [
"## Alternative Weight Loading from PyTorch state dicts"
]
},
{
"cell_type": "markdown",
"id": "b2feea87-49f0-48b9-b925-b8f0dda4096f",
"metadata": {},
"source": [
"- In the main chapter, we loaded the GPT model weights directly from OpenAI\n",
"- This notebook provides alternative weight loading code to load the model weights from PyTorch state dict files that I created from the original TensorFlow files and uploaded to the [Hugging Face Model Hub](https://huggingface.co/docs/hub/en/models-the-hub) at [https://huggingface.co/rasbt/gpt2-from-scratch-pytorch](https://huggingface.co/rasbt/gpt2-from-scratch-pytorch)\n",
"- This is conceptually the same as loading weights of a PyTorch model from via the state-dict method described in chapter 5:\n",
"\n",
"```python\n",
"state_dict = torch.load(\"model_state_dict.pth\")\n",
"model.load_state_dict(state_dict) \n",
"```"
]
},
{
"cell_type": "markdown",
"id": "e3f9fbb2-3e39-41ee-8a08-58ba0434a8f3",
"metadata": {},
"source": [
"### Choose model"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "b0467eff-b43c-4a38-93e8-5ed87a5fc2b1",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch version: 2.6.0\n"
]
}
],
"source": [
"from importlib.metadata import version\n",
"\n",
"pkgs = [\"torch\"]\n",
"for p in pkgs:\n",
" print(f\"{p} version: {version(p)}\")"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "9ea9b1bc-7881-46ad-9555-27a9cf23faa7",
"metadata": {},
"outputs": [],
"source": [
"BASE_CONFIG = {\n",
" \"vocab_size\": 50257, # Vocabulary size\n",
" \"context_length\": 1024, # Context length\n",
" \"drop_rate\": 0.0, # Dropout rate\n",
" \"qkv_bias\": True # Query-key-value bias\n",
"}\n",
"\n",
"model_configs = {\n",
" \"gpt2-small (124M)\": {\"emb_dim\": 768, \"n_layers\": 12, \"n_heads\": 12},\n",
" \"gpt2-medium (355M)\": {\"emb_dim\": 1024, \"n_layers\": 24, \"n_heads\": 16},\n",
" \"gpt2-large (774M)\": {\"emb_dim\": 1280, \"n_layers\": 36, \"n_heads\": 20},\n",
" \"gpt2-xl (1558M)\": {\"emb_dim\": 1600, \"n_layers\": 48, \"n_heads\": 25},\n",
"}\n",
"\n",
"\n",
"CHOOSE_MODEL = \"gpt2-small (124M)\"\n",
"BASE_CONFIG.update(model_configs[CHOOSE_MODEL])"
]
},
{
"cell_type": "markdown",
"id": "d78fc2b0-ba27-4aff-8aa3-bc6e04fca69d",
"metadata": {},
"source": [
"### Download file"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "ca224672-a0f7-4b39-9bc9-19ddde69487b",
"metadata": {},
"outputs": [],
"source": [
"file_name = \"gpt2-small-124M.pth\"\n",
"# file_name = \"gpt2-medium-355M.pth\"\n",
"# file_name = \"gpt2-large-774M.pth\"\n",
"# file_name = \"gpt2-xl-1558M.pth\""
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "e7b22375-6fac-4e90-9063-daa4de86c778",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Downloaded to gpt2-small-124M.pth\n"
]
}
],
"source": [
"import os\n",
"import urllib.request\n",
"\n",
"url = f\"https://huggingface.co/rasbt/gpt2-from-scratch-pytorch/resolve/main/{file_name}\"\n",
"\n",
"if not os.path.exists(file_name):\n",
" urllib.request.urlretrieve(url, file_name)\n",
" print(f\"Downloaded to {file_name}\")"
]
},
{
"cell_type": "markdown",
"id": "e61f0990-74cf-4b6d-85e5-4c7d0554db32",
"metadata": {},
"source": [
"### Load weights"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "cda44d37-92c0-4c19-a70a-15711513afce",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from llms_from_scratch.ch04 import GPTModel\n",
"# For llms_from_scratch installation instructions, see:\n",
"# https://github.com/rasbt/LLMs-from-scratch/tree/main/pkg\n",
"\n",
"\n",
"gpt = GPTModel(BASE_CONFIG)\n",
"gpt.load_state_dict(torch.load(file_name, weights_only=True))\n",
"gpt.eval()\n",
"\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"gpt.to(device);"
]
},
{
"cell_type": "markdown",
"id": "e0297fc4-11dc-4093-922f-dcaf85a75344",
"metadata": {},
"source": [
"### Generate text"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "4ddd0d51-3ade-4890-9bab-d63f141d095f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Output text:\n",
" Every effort moves forward, but it's not enough.\n",
"\n",
"\"I'm not going to sit here and say, 'I'm not going to do this,'\n"
]
}
],
"source": [
"import tiktoken\n",
"from llms_from_scratch.ch05 import generate, text_to_token_ids, token_ids_to_text\n",
"\n",
"\n",
"torch.manual_seed(123)\n",
"\n",
"tokenizer = tiktoken.get_encoding(\"gpt2\")\n",
"\n",
"token_ids = generate(\n",
" model=gpt.to(device),\n",
" idx=text_to_token_ids(\"Every effort moves\", tokenizer).to(device),\n",
" max_new_tokens=30,\n",
" context_size=BASE_CONFIG[\"context_length\"],\n",
" top_k=1,\n",
" temperature=1.0\n",
")\n",
"\n",
"print(\"Output text:\\n\", token_ids_to_text(token_ids, tokenizer))"
]
},
{
"cell_type": "markdown",
"id": "aa4a7912-ae51-4786-8ef4-42bd53682932",
"metadata": {},
"source": [
"## Alternative safetensors file"
]
},
{
"cell_type": "markdown",
"id": "2f774001-9cda-4b1f-88c5-ef99786a612b",
"metadata": {},
"source": [
"- In addition, the [https://huggingface.co/rasbt/gpt2-from-scratch-pytorch](https://huggingface.co/rasbt/gpt2-from-scratch-pytorch) repository contains so-called `.safetensors` versions of the state dicts\n",
"- The appeal of `.safetensors` files lies in their secure design, as they only store tensor data and avoid the execution of potentially malicious code during loading\n",
"- In newer versions of PyTorch (e.g., 2.0 and newer), a `weights_only=True` argument can be used with `torch.load` (e.g., `torch.load(\"model_state_dict.pth\", weights_only=True)`) to improve safety by skipping the execution of code and loading only the weights (this is now enabled by default in PyTorch 2.6 and newer); so in that case loading the weights from the state dict files should not be a concern (anymore)\n",
"- However, the code block below briefly shows how to load the model from these `.safetensor` files"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "c0a4fd86-4119-4a94-ae5e-13fb60d198bc",
"metadata": {},
"outputs": [],
"source": [
"file_name = \"gpt2-small-124M.safetensors\"\n",
"# file_name = \"gpt2-medium-355M.safetensors\"\n",
"# file_name = \"gpt2-large-774M.safetensors\"\n",
"# file_name = \"gpt2-xl-1558M.safetensors\""
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "20f96c2e-3469-47fb-bad3-e9173a1f1ba3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Downloaded to gpt2-small-124M.safetensors\n"
]
}
],
"source": [
"import os\n",
"import urllib.request\n",
"\n",
"url = f\"https://huggingface.co/rasbt/gpt2-from-scratch-pytorch/resolve/main/{file_name}\"\n",
"\n",
"if not os.path.exists(file_name):\n",
" urllib.request.urlretrieve(url, file_name)\n",
" print(f\"Downloaded to {file_name}\")"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "d16a69b3-9bb4-42f8-8e4f-cc62a1a1a083",
"metadata": {},
"outputs": [],
"source": [
"# Load file\n",
"\n",
"from safetensors.torch import load_file\n",
"\n",
"gpt = GPTModel(BASE_CONFIG)\n",
"gpt.load_state_dict(load_file(file_name))\n",
"gpt.eval();"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "352e57f7-8d82-4c12-900c-03e41bc9de58",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Output text:\n",
" Every effort moves forward, but it's not enough.\n",
"\n",
"\"I'm not going to sit here and say, 'I'm not going to do this,'\n"
]
}
],
"source": [
"token_ids = generate(\n",
" model=gpt.to(device),\n",
" idx=text_to_token_ids(\"Every effort moves\", tokenizer).to(device),\n",
" max_new_tokens=30,\n",
" context_size=BASE_CONFIG[\"context_length\"],\n",
" top_k=1,\n",
" temperature=1.0\n",
")\n",
"\n",
"print(\"Output text:\\n\", token_ids_to_text(token_ids, tokenizer))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@ -79,7 +79,8 @@ from llms_from_scratch.ch05 import (
token_ids_to_text, token_ids_to_text,
calc_loss_batch, calc_loss_batch,
calc_loss_loader, calc_loss_loader,
plot_losses plot_losses,
download_and_load_gpt2
) )
from llms_from_scratch.ch06 import ( from llms_from_scratch.ch06 import (

View File

@ -4,10 +4,16 @@
# Code: https://github.com/rasbt/LLMs-from-scratch # Code: https://github.com/rasbt/LLMs-from-scratch
from .ch04 import generate_text_simple from .ch04 import generate_text_simple
import json
import os
import urllib.request
import numpy as np import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator from matplotlib.ticker import MaxNLocator
import torch import torch
from tqdm import tqdm
def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=None, eos_id=None): def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=None, eos_id=None):
@ -231,3 +237,119 @@ def plot_losses(epochs_seen, tokens_seen, train_losses, val_losses):
fig.tight_layout() # Adjust layout to make room fig.tight_layout() # Adjust layout to make room
plt.savefig("loss-plot.pdf") plt.savefig("loss-plot.pdf")
plt.show() plt.show()
def download_and_load_gpt2(model_size, models_dir):
import tensorflow as tf
# Validate model size
allowed_sizes = ("124M", "355M", "774M", "1558M")
if model_size not in allowed_sizes:
raise ValueError(f"Model size not in {allowed_sizes}")
# Define paths
model_dir = os.path.join(models_dir, model_size)
base_url = "https://openaipublic.blob.core.windows.net/gpt-2/models"
backup_base_url = "https://f001.backblazeb2.com/file/LLMs-from-scratch/gpt2"
filenames = [
"checkpoint", "encoder.json", "hparams.json",
"model.ckpt.data-00000-of-00001", "model.ckpt.index",
"model.ckpt.meta", "vocab.bpe"
]
# Download files
os.makedirs(model_dir, exist_ok=True)
for filename in filenames:
file_url = os.path.join(base_url, model_size, filename)
backup_url = os.path.join(backup_base_url, model_size, filename)
file_path = os.path.join(model_dir, filename)
download_file(file_url, file_path, backup_url)
# Load settings and params
tf_ckpt_path = tf.train.latest_checkpoint(model_dir)
settings = json.load(open(os.path.join(model_dir, "hparams.json"), "r", encoding="utf-8"))
params = load_gpt2_params_from_tf_ckpt(tf_ckpt_path, settings)
return settings, params
def download_file(url, destination, backup_url=None):
def _attempt_download(download_url):
with urllib.request.urlopen(download_url) as response:
# Get the total file size from headers, defaulting to 0 if not present
file_size = int(response.headers.get("Content-Length", 0))
# Check if file exists and has the same size
if os.path.exists(destination):
file_size_local = os.path.getsize(destination)
if file_size == file_size_local:
print(f"File already exists and is up-to-date: {destination}")
return True # Indicate success without re-downloading
block_size = 1024 # 1 Kilobyte
# Initialize the progress bar with total file size
progress_bar_description = os.path.basename(download_url)
with tqdm(total=file_size, unit="iB", unit_scale=True, desc=progress_bar_description) as progress_bar:
with open(destination, "wb") as file:
while True:
chunk = response.read(block_size)
if not chunk:
break
file.write(chunk)
progress_bar.update(len(chunk))
return True
try:
if _attempt_download(url):
return
except (urllib.error.HTTPError, urllib.error.URLError):
if backup_url is not None:
print(f"Primary URL ({url}) failed. Attempting backup URL: {backup_url}")
try:
if _attempt_download(backup_url):
return
except urllib.error.HTTPError:
pass
# If we reach here, both attempts have failed
error_message = (
f"Failed to download from both primary URL ({url})"
f"{' and backup URL (' + backup_url + ')' if backup_url else ''}."
"\nCheck your internet connection or the file availability.\n"
"For help, visit: https://github.com/rasbt/LLMs-from-scratch/discussions/273"
)
print(error_message)
except Exception as e:
print(f"An unexpected error occurred: {e}")
def load_gpt2_params_from_tf_ckpt(ckpt_path, settings):
import tensorflow as tf
# Initialize parameters dictionary with empty blocks for each layer
params = {"blocks": [{} for _ in range(settings["n_layer"])]}
# Iterate over each variable in the checkpoint
for name, _ in tf.train.list_variables(ckpt_path):
# Load the variable and remove singleton dimensions
variable_array = np.squeeze(tf.train.load_variable(ckpt_path, name))
# Process the variable name to extract relevant parts
variable_name_parts = name.split("/")[1:] # Skip the 'model/' prefix
# Identify the target dictionary for the variable
target_dict = params
if variable_name_parts[0].startswith("h"):
layer_number = int(variable_name_parts[0][1:])
target_dict = params["blocks"][layer_number]
# Recursively access or create nested dictionaries
for key in variable_name_parts[1:-1]:
target_dict = target_dict.setdefault(key, {})
# Assign the variable array to the last key
last_key = variable_name_parts[-1]
target_dict[last_key] = variable_array
return params

View File

@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
[project] [project]
name = "llms-from-scratch" name = "llms-from-scratch"
version = "1.0.1" version = "1.0.2"
description = "Implement a ChatGPT-like LLM in PyTorch from scratch, step by step" description = "Implement a ChatGPT-like LLM in PyTorch from scratch, step by step"
readme = "README.md" readme = "README.md"
requires-python = ">=3.10" requires-python = ">=3.10"