{
"cells": [
{
"cell_type": "markdown",
"id": "c024bfa4-1a7a-4751-b5a1-827225a3478b",
"metadata": {
"id": "c024bfa4-1a7a-4751-b5a1-827225a3478b"
},
"source": [
"<table style=\"width:100%\">\n",
"<tr>\n",
"<td style=\"vertical-align:middle; text-align:left;\">\n",
"<font size=\"2\">\n",
"Supplementary code for the <a href=\"http://mng.bz/orYv\">Build a Large Language Model From Scratch</a> book by <a href=\"https://sebastianraschka.com\">Sebastian Raschka</a><br>\n",
"<br>Code repository: <a href=\"https://github.com/rasbt/LLMs-from-scratch\">https://github.com/rasbt/LLMs-from-scratch</a>\n",
"</font>\n",
"</td>\n",
"<td style=\"vertical-align:middle; text-align:left;\">\n",
"<a href=\"http://mng.bz/orYv\"><img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/cover-small.webp\" width=\"100px\"></a>\n",
"</td>\n",
"</tr>\n",
"</table>\n"
]
},
{
"cell_type": "markdown",
"id": "58b8c870-fb72-490e-8916-d8129bd5d1ff",
"metadata": {
"id": "58b8c870-fb72-490e-8916-d8129bd5d1ff"
},
"source": [
"# Appendix E: Parameter-efficient Finetuning with LoRA"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "5b7e01c2-1c84-4f2a-bb51-2e0b74abda90",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "5b7e01c2-1c84-4f2a-bb51-2e0b74abda90",
"outputId": "316166b4-027a-4756-e9b4-fe88ae75dd4f"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"matplotlib version: 3.10.0\n",
"numpy version: 2.0.2\n",
"tiktoken version: 0.9.0\n",
"torch version: 2.6.0\n",
"tensorflow version: 2.18.0\n",
"pandas version: 2.2.3\n"
]
}
],
"source": [
"from importlib.metadata import version\n",
"\n",
"pkgs = [\"matplotlib\",\n",
" \"numpy\",\n",
" \"tiktoken\",\n",
" \"torch\",\n",
" \"tensorflow\", # For OpenAI's pretrained weights\n",
" \"pandas\" # Dataset loading\n",
" ]\n",
"for p in pkgs:\n",
" print(f\"{p} version: {version(p)}\")"
]
},
{
"cell_type": "markdown",
"id": "21532056-0ef4-4c98-82c7-e91f61c6485e",
"metadata": {
"id": "21532056-0ef4-4c98-82c7-e91f61c6485e"
},
"source": [
"## E.1 Introduction to LoRA"
]
},
{
"cell_type": "markdown",
"id": "66edc999-3d91-4a1c-a157-9d056392e8d8",
"metadata": {
"id": "66edc999-3d91-4a1c-a157-9d056392e8d8"
},
"source": [
"- No code in this section\n",
"- Low-rank adaptation (LoRA) is a machine learning technique that modifies a pretrained model to better suit a specific, often smaller, dataset by adjusting only a small, low-rank subset of the model's parameters\n",
"- This approach is important because it allows for efficient finetuning of large models on task-specific data, significantly reducing the computational cost and time required for finetuning"
]
},
{
"cell_type": "markdown",
"id": "5bb75b5d-d59c-4948-821a-1594a5883dc1",
"metadata": {
"id": "5bb75b5d-d59c-4948-821a-1594a5883dc1"
},
"source": [
"- Suppose we have a large weight matrix $W$ for a given layer\n",
"- During backpropagation, we learn a $\\Delta W$ matrix, which contains information on how much we want to update the original weights to minimize the loss function during training\n",
"- In regular training and finetuning, the weight update is defined as follows:\n",
"\n",
"$$W_{\\text{updated}} = W + \\Delta W$$\n",
"\n",
"- The LoRA method proposed by [Hu et al.](https://arxiv.org/abs/2106.09685) offers a more efficient alternative to computing the weight updates $\\Delta W$ by learning an approximation of it, $\\Delta W \\approx AB$.\n",
"- In other words, in LoRA, we have the following, where $A$ and $B$ are two small weight matrices:\n",
"\n",
"$$W_{\\text{updated}} = W + AB$$\n",
"\n",
"- The figure below illustrates these formulas for full finetuning and LoRA side by side"
]
},
{
"cell_type": "markdown",
"id": "a8a7419d-cae9-4525-bb44-1641f6ef4f3b",
"metadata": {
"id": "a8a7419d-cae9-4525-bb44-1641f6ef4f3b"
},
"source": [
"<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/appendix-e_compressed/lora-1.webp\" width=\"500px\">"
]
},
{
"cell_type": "markdown",
"id": "4edd43c9-8ec5-48e6-b3fc-5fb3c16037cc",
"metadata": {
"id": "4edd43c9-8ec5-48e6-b3fc-5fb3c16037cc"
},
"source": [
"- If you paid close attention, the full finetuning and LoRA depictions in the figure above look slightly different from the formulas I have shown earlier\n",
"- That's due to the distributive law of matrix multiplication: we don't have to add the weights with the updated weights but can keep them separate\n",
"- For instance, if $x$ is the input data, then we can write the following for regular finetuning:\n",
"\n",
"$$x (W+\\Delta W) = x W + x \\Delta W$$\n",
"\n",
"- Similarly, we can write the following for LoRA:\n",
"\n",
"$$x (W+A B) = x W + x A B$$\n",
"\n",
"- The fact that we can keep the LoRA weight matrices separate makes LoRA especially attractive\n",
"- In practice, this means that we don't have to modify the weights of the pretrained model at all, as we can apply the LoRA matrices on the fly\n",
"- After setting up the dataset and loading the model, we will implement LoRA in the code to make these concepts less abstract"
]
},
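{
"cell_type": "markdown",
"id": "e1-lora-numeric-check-md",
"metadata": {},
"source": [
"- The following sketch is a small numerical illustration added here (it is not part of the LoRA implementation in section E.4). It uses random stand-ins for $x$, $W$, $A$, and $B$, with $W$ sized $768 \\times 768$ like the attention projection matrices in GPT-2 small and $r = 16$, the rank used later in this appendix, to check that $x (W + AB)$ and $x W + x A B$ agree up to floating-point rounding\n",
"- It also compares the number of entries in a full $\\Delta W$ with the number in $A$ and $B$: for this shape, $\\Delta W$ has $768 \\cdot 768 = 589{,}824$ entries, whereas $A$ and $B$ together have $768 \\cdot 16 + 16 \\cdot 768 = 24{,}576$, that is, 24 times fewer"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e1-lora-numeric-check-code",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"torch.manual_seed(123)\n",
"d_in, d_out, rank = 768, 768, 16\n",
"\n",
"# Random stand-ins; float64 keeps rounding differences negligible\n",
"x = torch.randn(2, d_in, dtype=torch.float64)      # two example input vectors\n",
"W = torch.randn(d_in, d_out, dtype=torch.float64)  # plays the role of the frozen pretrained weights\n",
"A = torch.randn(d_in, rank, dtype=torch.float64)   # LoRA factor A, shape (d_in, r)\n",
"B = torch.randn(rank, d_out, dtype=torch.float64)  # LoRA factor B, shape (r, d_out)\n",
"\n",
"# 1) Distributive law: x(W + AB) equals xW + xAB\n",
"merged = x @ (W + A @ B)      # add the update to the weights first\n",
"separate = x @ W + x @ A @ B  # keep W and the LoRA branch separate\n",
"print(\"Results match:\", torch.allclose(merged, separate))\n",
"\n",
"# 2) Parameter count: a full Delta W has the same shape as W\n",
"full_update = W.numel()\n",
"lora_params = A.numel() + B.numel()\n",
"print(f\"Full Delta W: {full_update:,} parameters\")\n",
"print(f\"LoRA factors (rank {rank}): {lora_params:,} parameters\")"
]
},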
{
"cell_type": "markdown",
"id": "8c7017a2-32aa-4002-a2f3-12aac293ccdf",
"metadata": {
"id": "8c7017a2-32aa-4002-a2f3-12aac293ccdf"
},
"source": [
"## E.2 Preparing the dataset"
]
},
{
"cell_type": "markdown",
"id": "669c64df-4431-4d27-834d-2bb38a01fc02",
"metadata": {
"id": "669c64df-4431-4d27-834d-2bb38a01fc02"
},
"source": [
"- This section repeats the code from chapter 6 to load and prepare the dataset\n",
"- Instead of repeating this code, one could open and run the chapter 6 notebook and then insert the LoRA code from section E.4 there\n",
"- (The LoRA code was originally the last section of chapter 6 but was moved to the appendix due to the length of chapter 6)\n",
"- In a similar fashion, we could also apply LoRA to the models in chapter 7 for instruction finetuning"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "def7c09b-af9c-4216-90ce-5e67aed1065c",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "def7c09b-af9c-4216-90ce-5e67aed1065c",
"outputId": "a67a7afe-b401-4463-c731-87025d20f72d"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"File downloaded and saved as sms_spam_collection/SMSSpamCollection.tsv\n"
]
}
],
"source": [
"import urllib.error\n",
"from pathlib import Path\n",
"import pandas as pd\n",
"from previous_chapters import (\n",
" download_and_unzip_spam_data,\n",
" create_balanced_dataset,\n",
" random_split\n",
")\n",
"# If the `previous_chapters.py` file is not available locally,\n",
"# you can import it from the `llms-from-scratch` PyPI package.\n",
"# For details, see: https://github.com/rasbt/LLMs-from-scratch/tree/main/pkg\n",
"# E.g.,\n",
"# from llms_from_scratch.ch06 import (\n",
"# download_and_unzip_spam_data,\n",
"# create_balanced_dataset,\n",
"# random_split\n",
"# )\n",
"\n",
"\n",
"\n",
"url = \"https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip\"\n",
"zip_path = \"sms_spam_collection.zip\"\n",
"extracted_path = \"sms_spam_collection\"\n",
"data_file_path = Path(extracted_path) / \"SMSSpamCollection.tsv\"\n",
"\n",
"try:\n",
" download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path)\n",
"except (urllib.error.HTTPError, urllib.error.URLError, TimeoutError) as e:\n",
" print(f\"Primary URL failed: {e}. Trying backup URL...\")\n",
" url = \"https://f001.backblazeb2.com/file/LLMs-from-scratch/sms%2Bspam%2Bcollection.zip\"\n",
" download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path)\n",
"\n",
"df = pd.read_csv(data_file_path, sep=\"\\t\", header=None, names=[\"Label\", \"Text\"])\n",
"balanced_df = create_balanced_dataset(df)\n",
"balanced_df[\"Label\"] = balanced_df[\"Label\"].map({\"ham\": 0, \"spam\": 1})\n",
"\n",
"train_df, validation_df, test_df = random_split(balanced_df, 0.7, 0.1)\n",
"train_df.to_csv(\"train.csv\", index=None)\n",
"validation_df.to_csv(\"validation.csv\", index=None)\n",
"test_df.to_csv(\"test.csv\", index=None)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "74c3c463-8763-4cc0-9320-41c7eaad8ab7",
"metadata": {
"id": "74c3c463-8763-4cc0-9320-41c7eaad8ab7"
},
"outputs": [],
"source": [
"import torch\n",
"import tiktoken\n",
"from previous_chapters import SpamDataset\n",
"\n",
"\n",
"tokenizer = tiktoken.get_encoding(\"gpt2\")\n",
"train_dataset = SpamDataset(\"train.csv\", max_length=None, tokenizer=tokenizer)\n",
"val_dataset = SpamDataset(\"validation.csv\", max_length=train_dataset.max_length, tokenizer=tokenizer)\n",
"test_dataset = SpamDataset(\"test.csv\", max_length=train_dataset.max_length, tokenizer=tokenizer)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "8681adc0-6f02-4e75-b01a-a6ab75d05542",
"metadata": {
"id": "8681adc0-6f02-4e75-b01a-a6ab75d05542"
},
"outputs": [],
"source": [
"from torch.utils.data import DataLoader\n",
"\n",
"num_workers = 0\n",
"batch_size = 8\n",
"\n",
"torch.manual_seed(123)\n",
"\n",
"train_loader = DataLoader(\n",
" dataset=train_dataset,\n",
" batch_size=batch_size,\n",
" shuffle=True,\n",
" num_workers=num_workers,\n",
" drop_last=True,\n",
")\n",
"\n",
"val_loader = DataLoader(\n",
" dataset=val_dataset,\n",
" batch_size=batch_size,\n",
" num_workers=num_workers,\n",
" drop_last=False,\n",
")\n",
"\n",
"test_loader = DataLoader(\n",
" dataset=test_dataset,\n",
" batch_size=batch_size,\n",
" num_workers=num_workers,\n",
" drop_last=False,\n",
")"
]
},
{
"cell_type": "markdown",
"id": "ab7335db-e0bb-4e27-80c5-eea11e593a57",
"metadata": {
"id": "ab7335db-e0bb-4e27-80c5-eea11e593a57"
},
"source": [
"- As a verification step, we iterate through the data loaders and check that the batches contain 8 training examples each, where each training example consists of 120 tokens"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "4dee6882-4c3a-4964-af15-fa31f86ad047",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "4dee6882-4c3a-4964-af15-fa31f86ad047",
"outputId": "2ae34de1-dd01-4f99-d2c8-ba4dca400754"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train loader:\n",
"Input batch dimensions: torch.Size([8, 120])\n",
"Label batch dimensions torch.Size([8])\n"
]
}
],
"source": [
"print(\"Train loader:\")\n",
"for input_batch, target_batch in train_loader:\n",
" pass\n",
"\n",
"print(\"Input batch dimensions:\", input_batch.shape)\n",
"print(\"Label batch dimensions\", target_batch.shape)"
]
},
{
"cell_type": "markdown",
"id": "5cdd7947-7039-49bf-8a5e-c0a2f4281ca1",
"metadata": {
"id": "5cdd7947-7039-49bf-8a5e-c0a2f4281ca1"
},
"source": [
"- Lastly, let's print the total number of batches in each dataset"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "IZfw-TYD2zTj",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "IZfw-TYD2zTj",
"outputId": "4d19ed61-cf7a-4ec4-b822-c847dd1c5d77"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"130 training batches\n",
"19 validation batches\n",
"38 test batches\n"
]
}
],
"source": [
"print(f\"{len(train_loader)} training batches\")\n",
"print(f\"{len(val_loader)} validation batches\")\n",
"print(f\"{len(test_loader)} test batches\")"
]
},
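{
"cell_type": "markdown",
"id": "e2-batch-count-check-md",
"metadata": {},
"source": [
"- These batch counts follow from the dataset sizes and the `drop_last` settings of the data loaders: the training loader discards the final incomplete batch (`drop_last=True`), whereas the validation and test loaders keep it\n",
"- The sketch below (an illustration added here) reproduces the numbers. It assumes split sizes of 1,045 training, 149 validation, and 300 test examples, which we expect chapter 6's `random_split` to produce from the 1,494-example balanced dataset; `len(train_dataset)`, `len(val_dataset)`, and `len(test_dataset)` can be used to confirm them"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e2-batch-count-check-code",
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"\n",
"n_train, n_val, n_test = 1045, 149, 300  # assumed split sizes (see above)\n",
"batch_size = 8                           # same batch size as the data loaders\n",
"\n",
"train_batches = n_train // batch_size        # drop_last=True: incomplete last batch is dropped\n",
"val_batches = math.ceil(n_val / batch_size)  # drop_last=False: incomplete last batch is kept\n",
"test_batches = math.ceil(n_test / batch_size)\n",
"\n",
"print(f\"{train_batches} training batches\")\n",
"print(f\"{val_batches} validation batches\")\n",
"print(f\"{test_batches} test batches\")"
]
},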
{
"cell_type": "markdown",
"id": "dec9aa4a-ffd2-4d9f-a835-cce1059fe604",
"metadata": {
"id": "dec9aa4a-ffd2-4d9f-a835-cce1059fe604"
},
"source": [
"## E.3 Initializing the model"
]
},
{
"cell_type": "markdown",
"id": "f36ebdaf-810e-46a2-9ad9-e017a04051b1",
"metadata": {
"id": "f36ebdaf-810e-46a2-9ad9-e017a04051b1"
},
"source": [
"- This section repeats the code from chapter 6 to load and prepare the model"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "02b3a506-3879-4258-82b5-93a5b6bafa74",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "02b3a506-3879-4258-82b5-93a5b6bafa74",
"outputId": "b8c9b125-bb52-45d3-8071-fa5054dbf5a9"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"checkpoint: 100%|███████████████████████████| 77.0/77.0 [00:00<00:00, 45.0kiB/s]\n",
"encoder.json: 100%|███████████████████████| 1.04M/1.04M [00:00<00:00, 2.15MiB/s]\n",
"hparams.json: 100%|█████████████████████████| 90.0/90.0 [00:00<00:00, 54.5kiB/s]\n",
"model.ckpt.data-00000-of-00001: 100%|███████| 498M/498M [01:12<00:00, 6.86MiB/s]\n",
"model.ckpt.index: 100%|███████████████████| 5.21k/5.21k [00:00<00:00, 2.99MiB/s]\n",
"model.ckpt.meta: 100%|██████████████████████| 471k/471k [00:00<00:00, 1.32MiB/s]\n",
"vocab.bpe: 100%|████████████████████████████| 456k/456k [00:00<00:00, 1.48MiB/s]\n"
]
}
],
"source": [
"from gpt_download import download_and_load_gpt2\n",
"from previous_chapters import GPTModel, load_weights_into_gpt\n",
"# Alternatively:\n",
"# from llms_from_scratch.ch04 import GPTModel\n",
"# from llms_from_scratch.ch05 import load_weights_into_gpt\n",
"\n",
"\n",
"\n",
"CHOOSE_MODEL = \"gpt2-small (124M)\"\n",
"INPUT_PROMPT = \"Every effort moves\"\n",
"\n",
"BASE_CONFIG = {\n",
" \"vocab_size\": 50257, # Vocabulary size\n",
" \"context_length\": 1024, # Context length\n",
" \"drop_rate\": 0.0, # Dropout rate\n",
" \"qkv_bias\": True # Query-key-value bias\n",
"}\n",
"\n",
"model_configs = {\n",
" \"gpt2-small (124M)\": {\"emb_dim\": 768, \"n_layers\": 12, \"n_heads\": 12},\n",
" \"gpt2-medium (355M)\": {\"emb_dim\": 1024, \"n_layers\": 24, \"n_heads\": 16},\n",
" \"gpt2-large (774M)\": {\"emb_dim\": 1280, \"n_layers\": 36, \"n_heads\": 20},\n",
" \"gpt2-xl (1558M)\": {\"emb_dim\": 1600, \"n_layers\": 48, \"n_heads\": 25},\n",
"}\n",
"\n",
"BASE_CONFIG.update(model_configs[CHOOSE_MODEL])\n",
"\n",
"model_size = CHOOSE_MODEL.split(\" \")[-1].lstrip(\"(\").rstrip(\")\")\n",
"settings, params = download_and_load_gpt2(model_size=model_size, models_dir=\"gpt2\")\n",
"\n",
"model = GPTModel(BASE_CONFIG)\n",
"load_weights_into_gpt(model, params)\n",
"model.eval();"
]
},
{
"cell_type": "markdown",
"id": "252614cd-7ce6-4908-83e6-3761f519904e",
"metadata": {
"id": "252614cd-7ce6-4908-83e6-3761f519904e"
},
"source": [
"- To ensure that the model was loaded corrected, let's double-check that it generates coherent text"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "8b6ce20c-0700-4783-8be0-4cf17c200a7f",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "8b6ce20c-0700-4783-8be0-4cf17c200a7f",
"outputId": "28ccbca5-8de9-41a0-c093-da00fcbaa91c"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Every effort moves you forward.\n",
"\n",
"The first step is to understand the importance of your work\n"
]
}
],
"source": [
"from previous_chapters import (\n",
" generate_text_simple,\n",
" text_to_token_ids,\n",
" token_ids_to_text\n",
")\n",
"\n",
"\n",
"text_1 = \"Every effort moves you\"\n",
"\n",
"token_ids = generate_text_simple(\n",
" model=model,\n",
" idx=text_to_token_ids(text_1, tokenizer),\n",
" max_new_tokens=15,\n",
" context_size=BASE_CONFIG[\"context_length\"]\n",
")\n",
"\n",
"print(token_ids_to_text(token_ids, tokenizer))"
]
},
{
"cell_type": "markdown",
"id": "8174b31b-1ab5-4115-b01c-245369da5af3",
"metadata": {
"id": "8174b31b-1ab5-4115-b01c-245369da5af3"
},
"source": [
"- Then, we prepare the model for classification finetuning similar to chapter 6, where we replace the output layer"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "e255ce91-d73a-4854-90a4-95804928eb16",
"metadata": {
"id": "e255ce91-d73a-4854-90a4-95804928eb16"
},
"outputs": [],
"source": [
"torch.manual_seed(123)\n",
"\n",
"num_classes = 2\n",
"model.out_head = torch.nn.Linear(in_features=768, out_features=num_classes)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "02e6f057-1383-4ece-8444-0a88e71ac75d",
"metadata": {
"id": "02e6f057-1383-4ece-8444-0a88e71ac75d"
},
"outputs": [],
"source": [
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"# Note:\n",
"# Uncommenting the following lines will allow the code to run on Apple Silicon chips, if applicable,\n",
"# which is approximately 1.2x faster than on an Apple CPU (as measured on an M3 MacBook Air).\n",
"# However, the resulting loss values may be slightly different.\n",
"\n",
"#if torch.cuda.is_available():\n",
"# device = torch.device(\"cuda\")\n",
"#elif torch.backends.mps.is_available():\n",
"# device = torch.device(\"mps\")\n",
"#else:\n",
"# device = torch.device(\"cpu\")\n",
"#\n",
"# print(f\"Using {device} device.\")\n",
"\n",
"model.to(device); # no assignment model = model.to(device) necessary for nn.Module classes"
]
},
{
"cell_type": "markdown",
"id": "8e951cd6-5e42-44d2-b21f-895cb61004fe",
"metadata": {
"id": "8e951cd6-5e42-44d2-b21f-895cb61004fe"
},
"source": [
"- Lastly, let's calculate the initial classification accuracy of the non-finetuned model (we expect this to be around 50%, which means that the model is not able to distinguish between spam and non-spam messages yet reliably)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "fc7dd72c-73a2-4881-ade0-0a9605f1ab8c",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "fc7dd72c-73a2-4881-ade0-0a9605f1ab8c",
"outputId": "74848515-5a49-4125-fecb-9f4bac23f812"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training accuracy: 46.25%\n",
"Validation accuracy: 45.00%\n",
"Test accuracy: 48.75%\n"
]
}
],
"source": [
"from previous_chapters import calc_accuracy_loader\n",
"# Alternatively:\n",
"# from llms_from_scratch.ch06 import calc_accuracy_loader\n",
"\n",
"\n",
"\n",
"torch.manual_seed(123)\n",
"train_accuracy = calc_accuracy_loader(train_loader, model, device, num_batches=10)\n",
"val_accuracy = calc_accuracy_loader(val_loader, model, device, num_batches=10)\n",
"test_accuracy = calc_accuracy_loader(test_loader, model, device, num_batches=10)\n",
"\n",
"print(f\"Training accuracy: {train_accuracy*100:.2f}%\")\n",
"print(f\"Validation accuracy: {val_accuracy*100:.2f}%\")\n",
"print(f\"Test accuracy: {test_accuracy*100:.2f}%\")"
]
},
{
"cell_type": "markdown",
"id": "398a1ec9-e2a1-43d6-bf9f-12ee54b46a7b",
"metadata": {
"id": "398a1ec9-e2a1-43d6-bf9f-12ee54b46a7b"
},
"source": [
"## E.4 Parameter-efficient finetuning with LoRA"
]
},
{
"cell_type": "markdown",
"id": "652a4a82-61ef-4d0a-9858-8988e844f12c",
"metadata": {
"id": "652a4a82-61ef-4d0a-9858-8988e844f12c"
},
"source": [
"- We begin by initializing a LoRALayer that creates the matrices $A$ and $B$, along with the `alpha` scaling hyperparameter and the `rank` ($r$) hyperparameters\n",
"- This layer can accept an input and compute the corresponding output, as illustrated in the figure below\n",
"\n",
"<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/appendix-e_compressed/lora-2.webp\" width=\"200px\">\n",
"\n",
"In code, this LoRA layer depicted in the figure above looks like as follows"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "2ds9ywjMwvIW",
"metadata": {
"id": "2ds9ywjMwvIW"
},
"outputs": [],
"source": [
"import math\n",
"\n",
"class LoRALayer(torch.nn.Module):\n",
" def __init__(self, in_dim, out_dim, rank, alpha):\n",
" super().__init__()\n",
" self.A = torch.nn.Parameter(torch.empty(in_dim, rank))\n",
" torch.nn.init.kaiming_uniform_(self.A, a=math.sqrt(5)) # similar to standard weight initialization\n",
" self.B = torch.nn.Parameter(torch.zeros(rank, out_dim))\n",
" self.alpha = alpha\n",
"\n",
" def forward(self, x):\n",
" x = self.alpha * (x @ self.A @ self.B)\n",
" return x"
]
},
{
"cell_type": "markdown",
"id": "ad21faa8-0614-4257-93cd-68952193e14a",
"metadata": {
"id": "ad21faa8-0614-4257-93cd-68952193e14a"
},
"source": [
"- In the code above, `rank` is a hyperparameter that controls the inner dimension of the matrices $A$ and $B$\n",
"- In other words, this parameter controls the number of additional parameters introduced by LoRA and is a key factor in determining the balance between model adaptability and parameter efficiency\n",
"- The second hyperparameter, `alpha`, is a scaling hyperparameter applied to the output of the low-rank adaptation\n",
"- It essentially controls the extent to which the adapted layer's output is allowed to influence the original output of the layer being adapted\n",
"- This can be seen as a way to regulate the impact of the low-rank adaptation on the layer's output\n",
"- So far, the `LoRALayer` class we implemented above allows us to transform the layer inputs $x$\n",
"- However, in LoRA, we are usually interested in replacing existing `Linear` layers so that the weight update is applied to the existing pretrained weights, as shown in the figure below\n",
"\n",
"<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/appendix-e_compressed/lora-3.webp\" width=\"200px\">"
]
},
{
"cell_type": "markdown",
"id": "3e6d5da0-dfce-4808-b89b-29ff333f563f",
"metadata": {
"id": "3e6d5da0-dfce-4808-b89b-29ff333f563f"
},
"source": [
"- To incorporate the original `Linear` layer weights as shown in the figure above, we implement a `LinearWithLoRA` layer below that uses the previously implemented LoRALayer and can be used to replace existing `Linear` layers in a neural network, for example, the self-attention module or feed forward modules in an LLM"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "127d3a64-8359-4b21-b056-78d58cc75fe8",
"metadata": {
"id": "127d3a64-8359-4b21-b056-78d58cc75fe8"
},
"outputs": [],
"source": [
"class LinearWithLoRA(torch.nn.Module):\n",
" def __init__(self, linear, rank, alpha):\n",
" super().__init__()\n",
" self.linear = linear\n",
" self.lora = LoRALayer(\n",
" linear.in_features, linear.out_features, rank, alpha\n",
" )\n",
"\n",
" def forward(self, x):\n",
" return self.linear(x) + self.lora(x)"
]
},
{
"cell_type": "markdown",
"id": "e1145a90-35ff-462c-820b-15483fa5b051",
"metadata": {
"id": "e1145a90-35ff-462c-820b-15483fa5b051"
},
"source": [
"- Note that since we initialize the weight matrix $B$ (`self.B` in `LoRALayer`) with zero values in the LoRA layer, the matrix multiplication between $A$ and $B$ results in a matrix consisting of 0's and doesn't affect the original weights (since adding 0 to the original weights does not modify them)"
]
},
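{
"cell_type": "markdown",
"id": "e4-lora-zero-init-check-md",
"metadata": {},
"source": [
"- The following self-contained sketch (an illustration added here) checks this numerically and also shows how `alpha` scales the LoRA branch once $B$ is no longer zero\n",
"- It uses a standalone copy of the layer named `SketchLoRALayer` so that `LoRALayer` is not overwritten; the layer sizes, the `alpha` values, and the constant 0.1 written into $B$ (a stand-in for trained values) are arbitrary choices\n",
"- `torch.random.fork_rng` restores the global random number generator state afterward, so running this cell does not change the random initialization of the LoRA matrices created later in this notebook"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e4-lora-zero-init-check-code",
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"import torch\n",
"\n",
"\n",
"class SketchLoRALayer(torch.nn.Module):\n",
"    # Same structure as LoRALayer: A is (in_dim, rank), B is (rank, out_dim) and starts at zero\n",
"    def __init__(self, in_dim, out_dim, rank, alpha):\n",
"        super().__init__()\n",
"        self.A = torch.nn.Parameter(torch.empty(in_dim, rank))\n",
"        torch.nn.init.kaiming_uniform_(self.A, a=math.sqrt(5))\n",
"        self.B = torch.nn.Parameter(torch.zeros(rank, out_dim))\n",
"        self.alpha = alpha\n",
"\n",
"    def forward(self, x):\n",
"        return self.alpha * (x @ self.A @ self.B)\n",
"\n",
"\n",
"with torch.random.fork_rng(devices=[]):  # restores the global CPU RNG state on exit\n",
"    torch.manual_seed(0)\n",
"    linear = torch.nn.Linear(8, 4)\n",
"    lora = SketchLoRALayer(8, 4, rank=2, alpha=16)\n",
"    x = torch.randn(3, 8)\n",
"\n",
"with torch.no_grad():\n",
"    # 1) B is all zeros, so adding the LoRA branch leaves the output unchanged\n",
"    unchanged = torch.equal(linear(x) + lora(x), linear(x))\n",
"\n",
"    # 2) With a nonzero B, the LoRA branch output scales linearly with alpha\n",
"    lora.B.fill_(0.1)\n",
"    out_alpha16 = lora(x)\n",
"    lora.alpha = 32\n",
"    out_alpha32 = lora(x)\n",
"    doubled = torch.allclose(out_alpha32, 2 * out_alpha16)\n",
"\n",
"print(\"Output unchanged at initialization:\", unchanged)\n",
"print(\"Doubling alpha doubles the LoRA output:\", doubled)"
]
},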
{
"cell_type": "markdown",
"id": "e98a6d36-7bc9-434c-a7f1-533f26aff06d",
"metadata": {
"id": "e98a6d36-7bc9-434c-a7f1-533f26aff06d"
},
"source": [
"- To try LoRA on the GPT model we defined earlier, we define a `replace_linear_with_lora` function to replace all `Linear` layers in the model with the new `LinearWithLoRA` layers\n",
"\n",
"<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/appendix-e_compressed/lora-4.webp\" width=\"400px\">"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "WlQZ8ygqzN_g",
"metadata": {
"id": "WlQZ8ygqzN_g"
},
"outputs": [],
"source": [
"def replace_linear_with_lora(model, rank, alpha):\n",
" for name, module in model.named_children():\n",
" if isinstance(module, torch.nn.Linear):\n",
" # Replace the Linear layer with LinearWithLoRA\n",
" setattr(model, name, LinearWithLoRA(module, rank, alpha))\n",
" else:\n",
" # Recursively apply the same function to child modules\n",
" replace_linear_with_lora(module, rank, alpha)"
]
},
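{
"cell_type": "markdown",
"id": "e4-replace-linear-sketch-md",
"metadata": {},
"source": [
"- To see the recursion at work on something smaller than the GPT model, the following self-contained sketch (an illustration added here) applies the same logic to a small hypothetical network that contains a nested `Sequential` container\n",
"- The classes and the function are renamed copies (`SketchLoRALayer`, `SketchLinearWithLoRA`, `sketch_replace_linear_with_lora`) so that the notebook's own definitions are not overwritten; the layer sizes, `rank=4`, and `alpha=8` are arbitrary\n",
"- As with the GPT model below, the original weights are frozen first, so only the newly created LoRA matrices are trainable: we expect 3 wrapped layers and $(10 \\cdot 4 + 4 \\cdot 20) + (20 \\cdot 4 + 4 \\cdot 20) + (20 \\cdot 4 + 4 \\cdot 2) = 368$ trainable parameters"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e4-replace-linear-sketch-code",
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"import torch\n",
"\n",
"\n",
"class SketchLoRALayer(torch.nn.Module):\n",
"    # Same structure as LoRALayer above\n",
"    def __init__(self, in_dim, out_dim, rank, alpha):\n",
"        super().__init__()\n",
"        self.A = torch.nn.Parameter(torch.empty(in_dim, rank))\n",
"        torch.nn.init.kaiming_uniform_(self.A, a=math.sqrt(5))\n",
"        self.B = torch.nn.Parameter(torch.zeros(rank, out_dim))\n",
"        self.alpha = alpha\n",
"\n",
"    def forward(self, x):\n",
"        return self.alpha * (x @ self.A @ self.B)\n",
"\n",
"\n",
"class SketchLinearWithLoRA(torch.nn.Module):\n",
"    # Same structure as LinearWithLoRA above\n",
"    def __init__(self, linear, rank, alpha):\n",
"        super().__init__()\n",
"        self.linear = linear\n",
"        self.lora = SketchLoRALayer(linear.in_features, linear.out_features, rank, alpha)\n",
"\n",
"    def forward(self, x):\n",
"        return self.linear(x) + self.lora(x)\n",
"\n",
"\n",
"def sketch_replace_linear_with_lora(model, rank, alpha):\n",
"    # Same recursion as replace_linear_with_lora above\n",
"    for name, module in model.named_children():\n",
"        if isinstance(module, torch.nn.Linear):\n",
"            setattr(model, name, SketchLinearWithLoRA(module, rank, alpha))\n",
"        else:\n",
"            sketch_replace_linear_with_lora(module, rank, alpha)\n",
"\n",
"\n",
"def count_trainable(model):\n",
"    return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
"\n",
"\n",
"with torch.random.fork_rng(devices=[]):  # restores the global CPU RNG state on exit\n",
"    # Hypothetical toy network; the nested Sequential exercises the recursive call\n",
"    toy_model = torch.nn.Sequential(\n",
"        torch.nn.Linear(10, 20),\n",
"        torch.nn.GELU(),\n",
"        torch.nn.Sequential(torch.nn.Linear(20, 20), torch.nn.GELU()),\n",
"        torch.nn.Linear(20, 2),\n",
"    )\n",
"    for param in toy_model.parameters():\n",
"        param.requires_grad = False  # freeze the original weights first\n",
"    sketch_replace_linear_with_lora(toy_model, rank=4, alpha=8)  # new LoRA weights are trainable\n",
"\n",
"n_wrapped = sum(isinstance(m, SketchLinearWithLoRA) for m in toy_model.modules())\n",
"print(\"Wrapped Linear layers:\", n_wrapped)\n",
"print(\"Trainable parameters:\", count_trainable(toy_model))"
]
},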
{
"cell_type": "markdown",
"id": "8c172164-cdde-4489-b7d7-aaed9cc2f5f2",
"metadata": {
"id": "8c172164-cdde-4489-b7d7-aaed9cc2f5f2"
},
"source": [
"- We then freeze the original model parameter and use the `replace_linear_with_lora` to replace the said `Linear` layers using the code below\n",
"- This will replace the `Linear` layers in the LLM with `LinearWithLoRA` layers"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "dbe15350-4da9-4829-9d23-98bbd3d0b1a1",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "dbe15350-4da9-4829-9d23-98bbd3d0b1a1",
"outputId": "fd4c208f-854a-4701-d9d3-9d73af733364"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Total trainable parameters before: 124,441,346\n",
"Total trainable parameters after: 0\n"
]
}
],
"source": [
"total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
"print(f\"Total trainable parameters before: {total_params:,}\")\n",
"\n",
"for param in model.parameters():\n",
" param.requires_grad = False\n",
"\n",
"total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
"print(f\"Total trainable parameters after: {total_params:,}\")"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "mLk_fPq0yz_u",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "mLk_fPq0yz_u",
"outputId": "0a93b8fc-05d7-4ace-ee47-e2fc6bdd7d75"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Total trainable LoRA parameters: 2,666,528\n"
]
}
],
"source": [
"replace_linear_with_lora(model, rank=16, alpha=16)\n",
"\n",
"total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
"print(f\"Total trainable LoRA parameters: {total_params:,}\")"
]
},
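{
"cell_type": "markdown",
"id": "e4-lora-count-breakdown-md",
"metadata": {},
"source": [
"- The count of 2,666,528 can be reproduced by hand, as the following sketch (an illustration added here) does: an adapter wrapping a $d_{\\text{in}} \\times d_{\\text{out}}$ layer adds $d_{\\text{in}} \\cdot r + r \\cdot d_{\\text{out}}$ parameters and no bias terms\n",
"- With $r = 16$, each of the 12 transformer blocks receives adapters for the four $768 \\times 768$ attention projections and the two feed forward layers ($768 \\to 3072$ and $3072 \\to 768$), and one more adapter wraps the $768 \\times 2$ classification head, whose original weights were frozen along with the rest of the model"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e4-lora-count-breakdown-code",
"metadata": {},
"outputs": [],
"source": [
"def lora_param_count(in_dim, out_dim, rank):\n",
"    # A has shape (in_dim, rank) and B has shape (rank, out_dim); LoRA adds no bias\n",
"    return in_dim * rank + rank * out_dim\n",
"\n",
"\n",
"lora_rank, emb_dim, n_layers, n_classes = 16, 768, 12, 2  # values used in this appendix\n",
"\n",
"per_block = (\n",
"    4 * lora_param_count(emb_dim, emb_dim, lora_rank)      # W_query, W_key, W_value, out_proj\n",
"    + lora_param_count(emb_dim, 4 * emb_dim, lora_rank)    # first feed forward layer (768 -> 3072)\n",
"    + lora_param_count(4 * emb_dim, emb_dim, lora_rank)    # second feed forward layer (3072 -> 768)\n",
")\n",
"head_params = lora_param_count(emb_dim, n_classes, lora_rank)  # classification head (768 -> 2)\n",
"expected_total = n_layers * per_block + head_params\n",
"\n",
"print(f\"Per transformer block: {per_block:,}\")\n",
"print(f\"Classification head: {head_params:,}\")\n",
"print(f\"Total: {expected_total:,}\")"
]
},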
{
"cell_type": "markdown",
"id": "b8b6819e-ef7a-4f0d-841a-1b467496bef9",
"metadata": {
"id": "b8b6819e-ef7a-4f0d-841a-1b467496bef9"
},
"source": [
"- As we can see, LoRA reduced the number of trainable parameters by roughly 47x, from 124,441,346 to 2,666,528\n",
"- Let's now double-check whether the layers have been modified as intended by printing the model architecture"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "1711be61-bb2c-466f-9b5b-24f4aa5ccd9c",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "1711be61-bb2c-466f-9b5b-24f4aa5ccd9c",
"outputId": "acff8eca-3775-45a2-b62d-032a986ef037"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"GPTModel(\n",
" (tok_emb): Embedding(50257, 768)\n",
" (pos_emb): Embedding(1024, 768)\n",
" (drop_emb): Dropout(p=0.0, inplace=False)\n",
" (trf_blocks): Sequential(\n",
" (0): TransformerBlock(\n",
" (att): MultiHeadAttention(\n",
" (W_query): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (W_key): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (W_value): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (out_proj): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (ff): FeedForward(\n",
" (layers): Sequential(\n",
" (0): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=3072, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (1): GELU()\n",
" (2): LinearWithLoRA(\n",
" (linear): Linear(in_features=3072, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" )\n",
" )\n",
" (norm1): LayerNorm()\n",
" (norm2): LayerNorm()\n",
" (drop_resid): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (1): TransformerBlock(\n",
" (att): MultiHeadAttention(\n",
" (W_query): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (W_key): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (W_value): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (out_proj): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (ff): FeedForward(\n",
" (layers): Sequential(\n",
" (0): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=3072, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (1): GELU()\n",
" (2): LinearWithLoRA(\n",
" (linear): Linear(in_features=3072, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" )\n",
" )\n",
" (norm1): LayerNorm()\n",
" (norm2): LayerNorm()\n",
" (drop_resid): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (2): TransformerBlock(\n",
" (att): MultiHeadAttention(\n",
" (W_query): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (W_key): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (W_value): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (out_proj): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (ff): FeedForward(\n",
" (layers): Sequential(\n",
" (0): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=3072, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (1): GELU()\n",
" (2): LinearWithLoRA(\n",
" (linear): Linear(in_features=3072, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" )\n",
" )\n",
" (norm1): LayerNorm()\n",
" (norm2): LayerNorm()\n",
" (drop_resid): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (3): TransformerBlock(\n",
" (att): MultiHeadAttention(\n",
" (W_query): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (W_key): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (W_value): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (out_proj): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (ff): FeedForward(\n",
" (layers): Sequential(\n",
" (0): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=3072, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (1): GELU()\n",
" (2): LinearWithLoRA(\n",
" (linear): Linear(in_features=3072, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" )\n",
" )\n",
" (norm1): LayerNorm()\n",
" (norm2): LayerNorm()\n",
" (drop_resid): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (4): TransformerBlock(\n",
" (att): MultiHeadAttention(\n",
" (W_query): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (W_key): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (W_value): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (out_proj): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (ff): FeedForward(\n",
" (layers): Sequential(\n",
" (0): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=3072, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (1): GELU()\n",
" (2): LinearWithLoRA(\n",
" (linear): Linear(in_features=3072, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" )\n",
" )\n",
" (norm1): LayerNorm()\n",
" (norm2): LayerNorm()\n",
" (drop_resid): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (5): TransformerBlock(\n",
" (att): MultiHeadAttention(\n",
" (W_query): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (W_key): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (W_value): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (out_proj): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (ff): FeedForward(\n",
" (layers): Sequential(\n",
" (0): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=3072, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (1): GELU()\n",
" (2): LinearWithLoRA(\n",
" (linear): Linear(in_features=3072, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" )\n",
" )\n",
" (norm1): LayerNorm()\n",
" (norm2): LayerNorm()\n",
" (drop_resid): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (6): TransformerBlock(\n",
" (att): MultiHeadAttention(\n",
" (W_query): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (W_key): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (W_value): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (out_proj): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (ff): FeedForward(\n",
" (layers): Sequential(\n",
" (0): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=3072, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (1): GELU()\n",
" (2): LinearWithLoRA(\n",
" (linear): Linear(in_features=3072, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" )\n",
" )\n",
" (norm1): LayerNorm()\n",
" (norm2): LayerNorm()\n",
" (drop_resid): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (7): TransformerBlock(\n",
" (att): MultiHeadAttention(\n",
" (W_query): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (W_key): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (W_value): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (out_proj): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (ff): FeedForward(\n",
" (layers): Sequential(\n",
" (0): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=3072, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (1): GELU()\n",
" (2): LinearWithLoRA(\n",
" (linear): Linear(in_features=3072, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" )\n",
" )\n",
" (norm1): LayerNorm()\n",
" (norm2): LayerNorm()\n",
" (drop_resid): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (8): TransformerBlock(\n",
" (att): MultiHeadAttention(\n",
" (W_query): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (W_key): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (W_value): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (out_proj): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (ff): FeedForward(\n",
" (layers): Sequential(\n",
" (0): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=3072, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (1): GELU()\n",
" (2): LinearWithLoRA(\n",
" (linear): Linear(in_features=3072, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" )\n",
" )\n",
" (norm1): LayerNorm()\n",
" (norm2): LayerNorm()\n",
" (drop_resid): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (9): TransformerBlock(\n",
" (att): MultiHeadAttention(\n",
" (W_query): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (W_key): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (W_value): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (out_proj): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (ff): FeedForward(\n",
" (layers): Sequential(\n",
" (0): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=3072, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (1): GELU()\n",
" (2): LinearWithLoRA(\n",
" (linear): Linear(in_features=3072, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" )\n",
" )\n",
" (norm1): LayerNorm()\n",
" (norm2): LayerNorm()\n",
" (drop_resid): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (10): TransformerBlock(\n",
" (att): MultiHeadAttention(\n",
" (W_query): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (W_key): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (W_value): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (out_proj): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (ff): FeedForward(\n",
" (layers): Sequential(\n",
" (0): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=3072, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (1): GELU()\n",
" (2): LinearWithLoRA(\n",
" (linear): Linear(in_features=3072, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" )\n",
" )\n",
" (norm1): LayerNorm()\n",
" (norm2): LayerNorm()\n",
" (drop_resid): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (11): TransformerBlock(\n",
" (att): MultiHeadAttention(\n",
" (W_query): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (W_key): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (W_value): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (out_proj): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (ff): FeedForward(\n",
" (layers): Sequential(\n",
" (0): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=3072, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (1): GELU()\n",
" (2): LinearWithLoRA(\n",
" (linear): Linear(in_features=3072, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" )\n",
" )\n",
" (norm1): LayerNorm()\n",
" (norm2): LayerNorm()\n",
" (drop_resid): Dropout(p=0.0, inplace=False)\n",
" )\n",
" )\n",
" (final_norm): LayerNorm()\n",
" (out_head): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=2, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
")\n"
]
}
],
"source": [
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"model.to(device)\n",
"\n",
"print(model)"
]
},
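{
"cell_type": "markdown",
"id": "lora-param-count-sketch-md",
"metadata": {},
"source": [
"- The printed architecture lists the input and output dimensions of every `LinearWithLoRA` layer, so we can estimate how many parameters the LoRA matrices add\n",
"- Assume matrix $A$ has shape `(in_features, rank)` and matrix $B$ has shape `(rank, out_features)`. Under that layout, each LoRA layer adds `rank * (in_features + out_features)` parameters\n",
"- The sketch below tallies this for the 12 transformer blocks and the output head shown above. `assumed_rank = 16` is an illustrative value, so replace it with the rank used when the LoRA layers were created"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "lora-param-count-sketch-code",
"metadata": {},
"outputs": [],
"source": [
"# (in_features, out_features) of the LinearWithLoRA layers printed above\n",
"attn_shapes = [(768, 768)] * 4          # W_query, W_key, W_value, out_proj\n",
"ff_shapes = [(768, 3072), (3072, 768)]  # FeedForward layers (0) and (2)\n",
"head_shape = (768, 2)                   # out_head\n",
"num_blocks = 12                         # transformer blocks (0) to (11)\n",
"\n",
"assumed_rank = 16  # illustrative assumption, not read from the model\n",
"\n",
"def lora_param_count(in_features, out_features, rank):\n",
"    # A: (in_features, rank) plus B: (rank, out_features)\n",
"    return rank * (in_features + out_features)\n",
"\n",
"per_block = sum(\n",
"    lora_param_count(i, o, assumed_rank) for i, o in attn_shapes + ff_shapes\n",
")\n",
"total_lora_params = num_blocks * per_block + lora_param_count(*head_shape, assumed_rank)\n",
"\n",
"print(f\"LoRA parameters per transformer block: {per_block:,}\")\n",
"print(f\"Total LoRA parameters: {total_lora_params:,}\")"
]
},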
{
"cell_type": "markdown",
"id": "c4bbc9d7-65ec-4675-bab8-2e56eb0cfb55",
"metadata": {
"id": "c4bbc9d7-65ec-4675-bab8-2e56eb0cfb55"
},
"source": [
"- Based on the model architecture above, the model now contains our new `LinearWithLoRA` layers, including the output head (`out_head`)\n",
"- Also, since we initialized matrix $B$ with zeros, the LoRA term initially adds nothing to each layer's output, so we expect the initial model performance to be unchanged from before the LoRA layers were added"
]
},
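{
"cell_type": "markdown",
"id": "lora-zero-init-sketch-md",
"metadata": {},
"source": [
"- The sketch below is a quick standalone check of the zero-initialization argument. It recomputes the LoRA update with plain tensors rather than the `LoRALayer` class defined earlier. The random initialization of $A$ and the values of `demo_rank` and `demo_alpha` are illustrative assumptions\n",
"- Since $B$ contains only zeros, `x @ A @ B` is a zero tensor. Adding the scaled LoRA term to the output of the original linear layer therefore leaves that output unchanged"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "lora-zero-init-sketch-code",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"torch.manual_seed(123)\n",
"\n",
"demo_linear = torch.nn.Linear(768, 768)  # stands in for a pretrained layer\n",
"demo_rank, demo_alpha = 16, 16           # illustrative values\n",
"\n",
"demo_A = torch.randn(768, demo_rank) * 0.01  # random A (illustrative init)\n",
"demo_B = torch.zeros(demo_rank, 768)         # B starts as all zeros\n",
"\n",
"demo_x = torch.randn(2, 5, 768)  # (batch_size, num_tokens, emb_dim)\n",
"\n",
"with torch.no_grad():\n",
"    base_out = demo_linear(demo_x)\n",
"    lora_out = base_out + demo_alpha * (demo_x @ demo_A @ demo_B)\n",
"\n",
"print(\"Outputs identical:\", torch.equal(base_out, lora_out))"
]
},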
{
"cell_type": "code",
"execution_count": 18,
"id": "DAlrb_I00VEU",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "DAlrb_I00VEU",
"outputId": "3da44ac4-230b-4358-d996-30b63f0d962a"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training accuracy: 46.25%\n",
"Validation accuracy: 45.00%\n",
"Test accuracy: 48.75%\n"
]
}
],
"source": [
"torch.manual_seed(123)\n",
"train_accuracy = calc_accuracy_loader(train_loader, model, device, num_batches=10)\n",
"val_accuracy = calc_accuracy_loader(val_loader, model, device, num_batches=10)\n",
"test_accuracy = calc_accuracy_loader(test_loader, model, device, num_batches=10)\n",
"\n",
"print(f\"Training accuracy: {train_accuracy*100:.2f}%\")\n",
"print(f\"Validation accuracy: {val_accuracy*100:.2f}%\")\n",
"print(f\"Test accuracy: {test_accuracy*100:.2f}%\")"
]
},
{
"cell_type": "markdown",
"id": "13735b3e-f0c3-4dba-ae3d-4141b2878101",
"metadata": {
"id": "13735b3e-f0c3-4dba-ae3d-4141b2878101"
},
"source": [
"- Let's now get to the interesting part and finetune the model by reusing the training function from chapter 6\n",
"- The training takes about 15 minutes on an M3 MacBook Air laptop and less than half a minute on a V100 or A100 GPU"
]
},
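{
"cell_type": "markdown",
"id": "frozen-params-adamw-sketch-md",
"metadata": {},
"source": [
"- The training code below passes all of `model.parameters()` to `AdamW`. This assumes the pretrained weights were frozen with `requires_grad = False` before the LoRA layers were added. Frozen parameters never receive a gradient (their `.grad` stays `None`), and PyTorch optimizers skip such parameters, including for weight decay\n",
"- The toy sketch below checks this after a single `AdamW` step. It is independent of the GPT model and uses illustrative names"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "frozen-params-adamw-sketch-code",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"torch.manual_seed(123)\n",
"\n",
"toy_frozen = torch.nn.Linear(4, 4)\n",
"for param in toy_frozen.parameters():\n",
"    param.requires_grad = False  # frozen, like the pretrained weights\n",
"\n",
"toy_trainable = torch.nn.Parameter(torch.zeros(4, 4))  # stands in for LoRA weights\n",
"frozen_before = toy_frozen.weight.clone()\n",
"\n",
"toy_optimizer = torch.optim.AdamW(\n",
"    list(toy_frozen.parameters()) + [toy_trainable], lr=5e-5, weight_decay=0.1\n",
")\n",
"\n",
"toy_inputs = torch.randn(8, 4)\n",
"toy_loss = (toy_frozen(toy_inputs) + toy_inputs @ toy_trainable).pow(2).mean()\n",
"toy_loss.backward()\n",
"toy_optimizer.step()\n",
"\n",
"print(\"Frozen weights unchanged:\", torch.equal(toy_frozen.weight, frozen_before))\n",
"print(\"Frozen grad is None:\", toy_frozen.weight.grad is None)\n",
"print(\"Trainable weights updated:\", bool(toy_trainable.abs().sum() > 0))"
]
},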
{
"cell_type": "code",
"execution_count": 19,
"id": "wCParRvr0eff",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "wCParRvr0eff",
"outputId": "ce910a9c-ee89-48bb-bfa6-49c6aee1e450"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Ep 1 (Step 000000): Train loss 3.820, Val loss 3.462\n",
"Ep 1 (Step 000050): Train loss 0.396, Val loss 0.364\n",
"Ep 1 (Step 000100): Train loss 0.111, Val loss 0.229\n",
"Training accuracy: 97.50% | Validation accuracy: 95.00%\n",
"Ep 2 (Step 000150): Train loss 0.135, Val loss 0.073\n",
"Ep 2 (Step 000200): Train loss 0.008, Val loss 0.052\n",
"Ep 2 (Step 000250): Train loss 0.021, Val loss 0.179\n",
"Training accuracy: 97.50% | Validation accuracy: 97.50%\n",
"Ep 3 (Step 000300): Train loss 0.096, Val loss 0.080\n",
"Ep 3 (Step 000350): Train loss 0.010, Val loss 0.116\n",
"Training accuracy: 97.50% | Validation accuracy: 95.00%\n",
"Ep 4 (Step 000400): Train loss 0.003, Val loss 0.151\n",
"Ep 4 (Step 000450): Train loss 0.008, Val loss 0.077\n",
"Ep 4 (Step 000500): Train loss 0.001, Val loss 0.147\n",
"Training accuracy: 100.00% | Validation accuracy: 97.50%\n",
"Ep 5 (Step 000550): Train loss 0.007, Val loss 0.094\n",
"Ep 5 (Step 000600): Train loss 0.000, Val loss 0.056\n",
"Training accuracy: 100.00% | Validation accuracy: 97.50%\n",
"Training completed in 12.10 minutes.\n"
]
}
],
"source": [
"import time\n",
"from previous_chapters import train_classifier_simple\n",
"# Alternatively:\n",
"# from llms_from_scratch.ch06 import train_classifier_simple\n",
"\n",
"\n",
"start_time = time.time()\n",
"\n",
"torch.manual_seed(123)\n",
"\n",
"optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.1)\n",
"\n",
"num_epochs = 5\n",
"train_losses, val_losses, train_accs, val_accs, examples_seen = train_classifier_simple(\n",
" model, train_loader, val_loader, optimizer, device,\n",
" num_epochs=num_epochs, eval_freq=50, eval_iter=5,\n",
")\n",
"\n",
"end_time = time.time()\n",
"execution_time_minutes = (end_time - start_time) / 60\n",
"print(f\"Training completed in {execution_time_minutes:.2f} minutes.\")"
]
},
{
"cell_type": "markdown",
"id": "d0c89e82-3aa8-44c6-b046-0b16200b8e6c",
"metadata": {
"id": "d0c89e82-3aa8-44c6-b046-0b16200b8e6c"
},
"source": [
"- Next, let's evaluate the training run by plotting the training and validation losses"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "bawWGijA0iF3",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 308
},
"id": "bawWGijA0iF3",
"outputId": "af70782a-d605-4376-fa6c-d33b38979cfa"
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAeoAAAEiCAYAAAA21pHjAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAABO1UlEQVR4nO3deVxU5f7A8c/MwAz7IrIq4Ia4grhGLllaYmXprZvX6y0sb/0qzMws81aKdbvaem27VnbT223BNq1b5ppLmSYuKG64pILK4sYqDDDz/P4YGBhBBQRnwO/79TqvmTnnOed8z+PId57znHMejVJKIYQQQgiHpLV3AEIIIYS4OEnUQgghhAOTRC2EEEI4MEnUQgghhAOTRC2EEEI4MEnUQgghhAOTRC2EEEI4MEnUQgghhAOTRC2EEEI4MEnUQggbQ4cOZcqUKfYOQwhRQRK1EI1swoQJaDSaGlNcXJy9QxNCNENO9g5AiJYoLi6OhQsX2swzGAx2ikYI0ZxJi1qIJmAwGAgKCrKZfH19AVi3bh16vZ6ff/7ZWv6VV14hICCA7OxsAJYvX86gQYPw8fHBz8+P22+/ncOHD1vLHz16FI1GwxdffMHgwYNxdXWlX79+HDhwgOTkZPr27YuHhwcjR47k1KlT1vUmTJjA6NGjmT17Nv7+/nh5efHwww9TWlp60WMxGo1MmzaNNm3a4O7uzoABA1i3bp11+bFjxxg1ahS+vr64u7vTvXt3li1bdtHt/etf/yIiIgIXFxcCAwO5++67rcvMZjNz5syhffv2uLq6Eh0dzVdffWWz/u7duxk5ciQeHh4EBgZy7733cvr0aevyoUOHMnnyZJ5++mlatWpFUFAQiYmJF41HCEcniVqIq6yyD/jee+8lLy+PHTt28Pzzz/Phhx8SGBgIQFFREVOnTmXr1q2sWbMGrVbLmDFjMJvNNtuaNWsWzz33HNu3b8fJyYk///nPPP3007z55pv8/PPPHDp0iJkzZ9qss2bNGvbt28e6dev4/PPP+eabb5g9e/ZF4500aRKbNm0iKSmJXbt28cc//pG4uDgOHjwIQEJCAkajkQ0bNpCamsrLL7+Mh4dHrdvaunUrkydP5oUXXiAtLY3ly5czZMgQ6/I5c+bw8ccf895777Fnzx6eeOIJ/vKXv7B+/XoAcnNzuemmm4iJiWHr1q0sX76c7Oxs7rnnHpv9/Oc//8Hd3Z3ffvuNV155hRdeeIFVq1bV8V9ICAejhBCNKj4+Xul0OuXu7m4zvfTSS9YyRqNR9erVS91zzz2qW7du6sEHH7zkNk+dOqUAlZqaqpRS6siRIwpQH374obXM559/rgC1Zs0a67w5c+aoyMhIm9hatWqlioqKrPPmz5+vPDw8lMlkUkopdcMNN6jHH39cKaXUsWPHlE6nUydOnLCJZ9iwYWrGjBlKKaV69uypEhMT61Q3X3/9tfLy8lL5+fk1lpWUlCg3Nzf166+/2syfOHGiGjdunFJKqRdffFHdcsstNsszMjIUoNLS0qzxDxo0yKZMv3791PTp0+sUoxCORvqohWgCN954I/Pnz7eZ16pVK+t7vV7Pp59+SlRUFOHh4fzzn/+0KXvw4EFmzpzJb7/9xunTp60t6fT0dHr06GEtFxUVZX1f2Rrv2bOnzbycnBybbUdHR+Pm5mb9HBsbS2FhIRkZGYSHh9uUTU1NxWQy0blzZ5v5RqMRPz8/ACZPnswjjzzCypUrGT58OHfddZdNXNXdfPPNhIeH06FDB+Li4oiLi2PMmDG4ublx6NAhzp8/z80332yzTmlpKTExMQDs3LmTtWvX1tpiP3z4sDXOC/cfHBxcox6EaC4kUQvRBNzd3enUqdMly/z6668AnD17lrNnz+Lu7m5dNmrUKMLDw1mwYAEhISGYzWZ69OhRoy/Z2dnZ+l6j0dQ678LT5fVRWFiITqdj27Zt6HQ6m2WVyfKvf/0rI0aM4IcffmDlypXMmT
OH119/nccee6zG9jw9Pdm+fTvr1q1j5cqVzJw5k8TERJKTkyksLATghx9+oE2bNjbrVV6IV1hYyKhRo3j55ZdrbDs4ONj6vnodwJXXgxD2JIlaCDs4fPgwTzzxBAsWLGDx4sXEx8ezevVqtFotZ86cIS0tjQULFjB48GAAfvnll0bb986dOykuLsbV1RWAzZs34+HhQWhoaI2yMTExmEwmcnJyrLHUJjQ0lIcffpiHH36YGTNmsGDBgloTNYCTkxPDhw9n+PDhzJo1Cx8fH3766SduvvlmDAYD6enp3HDDDbWu27t3b77++mvatWuHk5P8+RLXBvmmC9EEjEYjWVlZNvOcnJxo3bo1JpOJv/zlL4wYMYL777+fuLg4evbsyeuvv85TTz2Fr68vfn5+fPDBBwQHB5Oens4zzzzTaLGVlpYyceJEnnvuOY4ePcqsWbOYNGkSWm3Na0s7d+7M+PHjue+++3j99deJiYnh1KlTrFmzhqioKG677TamTJnCyJEj6dy5M+fOnWPt2rV07dq11n1///33/P777wwZMgRfX1+WLVuG2WwmMjIST09Ppk2bxhNPPIHZbGbQoEHk5eWxceNGvLy8iI+PJyEhgQULFjBu3DjrVd2HDh0iKSmJDz/8sEarX4iWQBK1EE1g+fLlNqdiASIjI9m/fz8vvfQSx44d4/vvvwcsp2w/+OADxo0bxy233EJ0dDRJSUlMnjyZHj16EBkZyVtvvcXQoUMbJbZhw4YRERHBkCFDMBqNjBs37pK3Ly1cuJC///3vPPnkk5w4cYLWrVtz3XXXcfvttwNgMplISEjg+PHjeHl5ERcXV6PPvZKPjw/ffPMNiYmJlJSUEBERweeff0737t0BePHFF/H392fOnDn8/vvv+Pj40Lt3b/72t78BEBISwsaNG5k+fTq33HILRqOR8PBw4uLiav2hIURLoFFKKXsHIYS4OiZMmEBubi5Lly61dyhCiDqSn6BCCCGEA5NELYQQQjgwOfUthBBCODBpUQshhBAOTBK1EEII4cAkUQshhBAOTBJ1hXfffZd27drh4uLCgAED2LJli71DanIbNmxg1KhRhISEoNFoatyyo5Ri5syZBAcH4+rqyvDhw60jJlU6e/Ys48ePx8vLCx8fHyZOnGh9FGSlXbt2MXjwYFxcXAgNDeWVV15p6kNrdHPmzKFfv354enoSEBDA6NGjSUtLsylTUlJCQkICfn5+eHh4cNddd1mHrayUnp7ObbfdhpubGwEBATz11FOUl5fblFm3bh29e/fGYDDQqVMnFi1a1NSH1+jmz59PVFQUXl5eeHl5ERsby48//mhdLnV1cXPnzkWj0TBlyhTrPKmvKomJiWg0GpupS5cu1uUtsq7sOiSIg0hKSlJ6vV599NFHas+ePerBBx9UPj4+Kjs7296hNally5apZ599Vn3zzTcKUEuWLLFZPnfuXOXt7a2WLl2qdu7cqe644w7Vvn17VVxcbC0TFxenoqOj1ebNm9XPP/+sOnXqZB3pSCml8vLyVGBgoBo/frzavXu3+vzzz5Wrq6t6//33r9ZhNooRI0aohQsXqt27d6uUlBR16623qrCwMFVYWGgt8/DDD6vQ0FC1Zs0atXXrVnXdddep66+/3rq8vLxc9ejRQw0fPlzt2LFDLVu2TLVu3do6CpVSSv3+++/Kzc1NTZ06Ve3du1e9/fbbSqfTqeXLl1/V471S3333nfrhhx/UgQMHVFpamvrb3/6mnJ2d1e7du5VSUlcXs2XLFtWuXTsVFRVlHcFMKamv6mbNmqW6d++uMjMzrdOpU6esy1tiXUmiVkr1799fJSQkWD+bTCYVEhKi5syZY8eorq4LE7XZbFZBQUHq1Vdftc7Lzc1VBoNBff7550oppfbu3asAlZycbC3z448/Ko1GYx0W8V//+pfy9fVVRqPRWmb69Ok2Qy82Rzk5OQpQ69evV0pZ6sbZ2Vl9+eWX1jL79u1TgNq0aZNSyvLDSK
vVqqysLGuZ+fPnKy8vL2v9PP3006p79+42+xo7dqwaMWJEUx9Sk/P19VUffvih1NVFFBQUqIiICLVq1SqboUalvmz
"text/plain": [
"<Figure size 500x300 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from previous_chapters import plot_values\n",
"# Alternatively:\n",
"# from llms_from_scratch.ch06 import plot_values\n",
"\n",
"epochs_tensor = torch.linspace(0, num_epochs, len(train_losses))\n",
"examples_seen_tensor = torch.linspace(0, examples_seen, len(train_losses))\n",
"\n",
"plot_values(epochs_tensor, examples_seen_tensor, train_losses, val_losses, label=\"loss\")"
]
},
{
"cell_type": "markdown",
"id": "aa074723-e3f7-4f7e-a267-855531a037dc",
"metadata": {
"id": "aa074723-e3f7-4f7e-a267-855531a037dc"
},
"source": [
"- Note that during training, the accuracy values were calculated on only 5 batches via the `eval_iter=5` setting; below, we calculate the accuracies on the full datasets"
]
},
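{
"cell_type": "markdown",
"id": "accuracy-num-batches-sketch-md",
"metadata": {},
"source": [
"- The main difference between the accuracies printed during training and the ones computed below is how many batches are visited. With a batch limit, only the first few batches of each loader contribute to the accuracy; without one, every batch is used\n",
"- The sketch below illustrates this with a toy classifier that always predicts class 0 and with `accuracy_over_batches`, a hypothetical simplified helper. It is not the chapter 6 `calc_accuracy_loader` itself and omits GPT-specific details such as selecting the last token's output"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "accuracy-num-batches-sketch-code",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch.utils.data import DataLoader, TensorDataset\n",
"\n",
"def accuracy_over_batches(data_loader, model, device, num_batches=None):\n",
"    # Hypothetical helper: accuracy over the first `num_batches` batches (all if None)\n",
"    model.eval()\n",
"    if num_batches is None:\n",
"        num_batches = len(data_loader)\n",
"    else:\n",
"        num_batches = min(num_batches, len(data_loader))\n",
"    correct, total = 0, 0\n",
"    with torch.no_grad():\n",
"        for i, (inputs, targets) in enumerate(data_loader):\n",
"            if i >= num_batches:\n",
"                break\n",
"            inputs, targets = inputs.to(device), targets.to(device)\n",
"            preds = model(inputs).argmax(dim=-1)\n",
"            correct += (preds == targets).sum().item()\n",
"            total += targets.numel()\n",
"    return correct / total\n",
"\n",
"# Toy classifier that always predicts class 0 (zero weights, larger bias for class 0)\n",
"toy_clf = torch.nn.Linear(3, 2)\n",
"with torch.no_grad():\n",
"    toy_clf.weight.zero_()\n",
"    toy_clf.bias.copy_(torch.tensor([1.0, 0.0]))\n",
"\n",
"toy_features = torch.randn(10, 3)\n",
"toy_labels = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1, 1, 1])\n",
"toy_loader = DataLoader(TensorDataset(toy_features, toy_labels), batch_size=2, shuffle=False)\n",
"toy_device = torch.device(\"cpu\")\n",
"\n",
"partial_acc = accuracy_over_batches(toy_loader, toy_clf, toy_device, num_batches=2)\n",
"full_acc = accuracy_over_batches(toy_loader, toy_clf, toy_device)\n",
"print(f\"First 2 batches: {partial_acc*100:.2f}%\")  # only class-0 samples\n",
"print(f\"All batches:     {full_acc*100:.2f}%\")"
]
},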
{
"cell_type": "code",
"execution_count": 21,
"id": "1D2awlEq0gZi",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "1D2awlEq0gZi",
"outputId": "d603eda1-d912-43eb-ec9c-af6a622510a0"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training accuracy: 100.00%\n",
"Validation accuracy: 96.64%\n",
"Test accuracy: 97.33%\n"
]
}
],
"source": [
"train_accuracy = calc_accuracy_loader(train_loader, model, device)\n",
"val_accuracy = calc_accuracy_loader(val_loader, model, device)\n",
"test_accuracy = calc_accuracy_loader(test_loader, model, device)\n",
"\n",
"print(f\"Training accuracy: {train_accuracy*100:.2f}%\")\n",
"print(f\"Validation accuracy: {val_accuracy*100:.2f}%\")\n",
"print(f\"Test accuracy: {test_accuracy*100:.2f}%\")"
]
},
{
"cell_type": "markdown",
"id": "1f87f5e6-339e-4fcf-900b-6d845d3c713d",
"metadata": {
"id": "1f87f5e6-339e-4fcf-900b-6d845d3c713d"
},
"source": [
"- As the high accuracy values above show (about 97% on both the validation and test sets), the LoRA finetuning was successful"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuType": "V100",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}