
{
"cells": [
{
"cell_type": "markdown",
"id": "c024bfa4-1a7a-4751-b5a1-827225a3478b",
"metadata": {
"id": "c024bfa4-1a7a-4751-b5a1-827225a3478b"
},
"source": [
"<table style=\"width:100%\">\n",
"<tr>\n",
"<td style=\"vertical-align:middle; text-align:left;\">\n",
"<font size=\"2\">\n",
"Supplementary code for the <a href=\"http://mng.bz/orYv\">Build a Large Language Model From Scratch</a> book by <a href=\"https://sebastianraschka.com\">Sebastian Raschka</a><br>\n",
"<br>Code repository: <a href=\"https://github.com/rasbt/LLMs-from-scratch\">https://github.com/rasbt/LLMs-from-scratch</a>\n",
"</font>\n",
"</td>\n",
"<td style=\"vertical-align:middle; text-align:left;\">\n",
"<a href=\"http://mng.bz/orYv\"><img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/cover-small.webp\" width=\"100px\"></a>\n",
"</td>\n",
"</tr>\n",
"</table>\n"
]
},
{
"cell_type": "markdown",
"id": "58b8c870-fb72-490e-8916-d8129bd5d1ff",
"metadata": {
"id": "58b8c870-fb72-490e-8916-d8129bd5d1ff"
},
"source": [
"# Appendix E: Parameter-efficient Finetuning with LoRA"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "5b7e01c2-1c84-4f2a-bb51-2e0b74abda90",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "5b7e01c2-1c84-4f2a-bb51-2e0b74abda90",
"outputId": "316166b4-027a-4756-e9b4-fe88ae75dd4f"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"matplotlib version: 3.10.7\n",
"numpy version: 2.3.4\n",
"tiktoken version: 0.12.0\n",
"torch version: 2.9.0\n",
"tensorflow version: 2.20.0\n",
"pandas version: 2.3.3\n"
]
}
],
"source": [
"from importlib.metadata import version\n",
"\n",
"pkgs = [\"matplotlib\",\n",
" \"numpy\",\n",
" \"tiktoken\",\n",
" \"torch\",\n",
" \"tensorflow\", # For OpenAI's pretrained weights\n",
" \"pandas\" # Dataset loading\n",
" ]\n",
"for p in pkgs:\n",
" print(f\"{p} version: {version(p)}\")"
]
},
{
"cell_type": "markdown",
"id": "21532056-0ef4-4c98-82c7-e91f61c6485e",
"metadata": {
"id": "21532056-0ef4-4c98-82c7-e91f61c6485e"
},
"source": [
"## E.1 Introduction to LoRA"
]
},
{
"cell_type": "markdown",
"id": "66edc999-3d91-4a1c-a157-9d056392e8d8",
"metadata": {
"id": "66edc999-3d91-4a1c-a157-9d056392e8d8"
},
"source": [
"- No code in this section\n",
"- Low-rank adaptation (LoRA) is a machine learning technique that modifies a pretrained model to better suit a specific, often smaller, dataset by adjusting only a small, low-rank subset of the model's parameters\n",
"- This approach is important because it allows for efficient finetuning of large models on task-specific data, significantly reducing the computational cost and time required for finetuning"
]
},
{
"cell_type": "markdown",
"id": "5bb75b5d-d59c-4948-821a-1594a5883dc1",
"metadata": {
"id": "5bb75b5d-d59c-4948-821a-1594a5883dc1"
},
"source": [
"- Suppose we have a large weight matrix $W$ for a given layer\n",
"- During backpropagation, we learn a $\\Delta W$ matrix, which contains information on how much we want to update the original weights to minimize the loss function during training\n",
"- In regular training and finetuning, the weight update is defined as follows:\n",
"\n",
"$$W_{\\text{updated}} = W + \\Delta W$$\n",
"\n",
"- The LoRA method proposed by [Hu et al.](https://arxiv.org/abs/2106.09685) offers a more efficient alternative to computing the weight updates $\\Delta W$ by learning an approximation of it, $\\Delta W \\approx AB$.\n",
"- In other words, in LoRA, we have the following, where $A$ and $B$ are two small weight matrices:\n",
"\n",
"$$W_{\\text{updated}} = W + AB$$\n",
"\n",
"- The figure below illustrates these formulas for full finetuning and LoRA side by side"
]
},
{
"cell_type": "markdown",
"id": "a8a7419d-cae9-4525-bb44-1641f6ef4f3b",
"metadata": {
"id": "a8a7419d-cae9-4525-bb44-1641f6ef4f3b"
},
"source": [
"<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/appendix-e_compressed/lora-1.webp\" width=\"500px\">"
]
},
{
"cell_type": "markdown",
"id": "4edd43c9-8ec5-48e6-b3fc-5fb3c16037cc",
"metadata": {
"id": "4edd43c9-8ec5-48e6-b3fc-5fb3c16037cc"
},
"source": [
"- If you paid close attention, the full finetuning and LoRA depictions in the figure above look slightly different from the formulas I have shown earlier\n",
"- That's due to the distributive law of matrix multiplication: we don't have to add the weights with the updated weights but can keep them separate\n",
"- For instance, if $x$ is the input data, then we can write the following for regular finetuning:\n",
"\n",
"$$x (W+\\Delta W) = x W + x \\Delta W$$\n",
"\n",
"- Similarly, we can write the following for LoRA:\n",
"\n",
"$$x (W+A B) = x W + x A B$$\n",
"\n",
"- The fact that we can keep the LoRA weight matrices separate makes LoRA especially attractive\n",
"- In practice, this means that we don't have to modify the weights of the pretrained model at all, as we can apply the LoRA matrices on the fly\n",
"- After setting up the dataset and loading the model, we will implement LoRA in the code to make these concepts less abstract"
]
},
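{
"cell_type": "markdown",
"id": "lora-distributive-check",
"metadata": {},
"source": [
"- Although this section contains no code in the book, the distributive identity above is straightforward to verify numerically; below is a minimal sketch (the dimensions and the rank are arbitrary values chosen for illustration):\n",
"\n",
"```python\n",
"import torch\n",
"\n",
"torch.manual_seed(123)\n",
"x = torch.randn(1, 6)   # input example\n",
"W = torch.randn(6, 6)   # pretrained weight matrix\n",
"A = torch.randn(6, 2)   # LoRA matrix A (rank 2)\n",
"B = torch.randn(2, 6)   # LoRA matrix B (rank 2)\n",
"\n",
"# x (W + AB) equals x W + x A B up to floating-point rounding\n",
"print(torch.allclose(x @ (W + A @ B), x @ W + x @ A @ B, atol=1e-6))  # True\n",
"```"
]
},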
{
"cell_type": "markdown",
"id": "8c7017a2-32aa-4002-a2f3-12aac293ccdf",
"metadata": {
"id": "8c7017a2-32aa-4002-a2f3-12aac293ccdf"
},
"source": [
"## E.2 Preparing the dataset"
]
},
{
"cell_type": "markdown",
"id": "669c64df-4431-4d27-834d-2bb38a01fc02",
"metadata": {
"id": "669c64df-4431-4d27-834d-2bb38a01fc02"
},
"source": [
"- This section repeats the code from chapter 6 to load and prepare the dataset\n",
"- Instead of repeating this code, one could open and run the chapter 6 notebook and then insert the LoRA code from section E.4 there\n",
"- (The LoRA code was originally the last section of chapter 6 but was moved to the appendix due to the length of chapter 6)\n",
"- In a similar fashion, we could also apply LoRA to the models in chapter 7 for instruction finetuning"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "def7c09b-af9c-4216-90ce-5e67aed1065c",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "def7c09b-af9c-4216-90ce-5e67aed1065c",
"outputId": "a67a7afe-b401-4463-c731-87025d20f72d"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"sms_spam_collection/SMSSpamCollection.tsv already exists. Skipping download and extraction.\n"
]
}
],
"source": [
"# import urllib\n",
"import requests\n",
"from pathlib import Path\n",
"import pandas as pd\n",
"from previous_chapters import (\n",
" download_and_unzip_spam_data,\n",
" create_balanced_dataset,\n",
" random_split\n",
")\n",
"# If the `previous_chapters.py` file is not available locally,\n",
"# you can import it from the `llms-from-scratch` PyPI package.\n",
"# For details, see: https://github.com/rasbt/LLMs-from-scratch/tree/main/pkg\n",
"# E.g.,\n",
"# from llms_from_scratch.ch06 import (\n",
"# download_and_unzip_spam_data,\n",
"# create_balanced_dataset,\n",
"# random_split\n",
"# )\n",
"\n",
"\n",
"\n",
"url = \"https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip\"\n",
"zip_path = \"sms_spam_collection.zip\"\n",
"extracted_path = \"sms_spam_collection\"\n",
"data_file_path = Path(extracted_path) / \"SMSSpamCollection.tsv\"\n",
"\n",
"\n",
"try:\n",
" download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path)\n",
"except (requests.exceptions.RequestException, TimeoutError) as e:\n",
" print(f\"Primary URL failed: {e}. Trying backup URL...\")\n",
" url = \"https://f001.backblazeb2.com/file/LLMs-from-scratch/sms%2Bspam%2Bcollection.zip\"\n",
" download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path)\n",
"\n",
"# The book originally used\n",
"# except (urllib.error.HTTPError, urllib.error.URLError, TimeoutError) as e:\n",
"# in the code above.\n",
"# However, some VPN users reported issues with `urllib`, so the code was updated\n",
"# to use `requests` instead\n",
"\n",
"df = pd.read_csv(data_file_path, sep=\"\\t\", header=None, names=[\"Label\", \"Text\"])\n",
"balanced_df = create_balanced_dataset(df)\n",
"balanced_df[\"Label\"] = balanced_df[\"Label\"].map({\"ham\": 0, \"spam\": 1})\n",
"\n",
"train_df, validation_df, test_df = random_split(balanced_df, 0.7, 0.1)\n",
"train_df.to_csv(\"train.csv\", index=None)\n",
"validation_df.to_csv(\"validation.csv\", index=None)\n",
"test_df.to_csv(\"test.csv\", index=None)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "74c3c463-8763-4cc0-9320-41c7eaad8ab7",
"metadata": {
"id": "74c3c463-8763-4cc0-9320-41c7eaad8ab7"
},
"outputs": [],
"source": [
"import torch\n",
"import tiktoken\n",
"from previous_chapters import SpamDataset\n",
"\n",
"\n",
"tokenizer = tiktoken.get_encoding(\"gpt2\")\n",
"train_dataset = SpamDataset(\"train.csv\", max_length=None, tokenizer=tokenizer)\n",
"val_dataset = SpamDataset(\"validation.csv\", max_length=train_dataset.max_length, tokenizer=tokenizer)\n",
"test_dataset = SpamDataset(\"test.csv\", max_length=train_dataset.max_length, tokenizer=tokenizer)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "8681adc0-6f02-4e75-b01a-a6ab75d05542",
"metadata": {
"id": "8681adc0-6f02-4e75-b01a-a6ab75d05542"
},
"outputs": [],
"source": [
"from torch.utils.data import DataLoader\n",
"\n",
"num_workers = 0\n",
"batch_size = 8\n",
"\n",
"torch.manual_seed(123)\n",
"\n",
"train_loader = DataLoader(\n",
" dataset=train_dataset,\n",
" batch_size=batch_size,\n",
" shuffle=True,\n",
" num_workers=num_workers,\n",
" drop_last=True,\n",
")\n",
"\n",
"val_loader = DataLoader(\n",
" dataset=val_dataset,\n",
" batch_size=batch_size,\n",
" num_workers=num_workers,\n",
" drop_last=False,\n",
")\n",
"\n",
"test_loader = DataLoader(\n",
" dataset=test_dataset,\n",
" batch_size=batch_size,\n",
" num_workers=num_workers,\n",
" drop_last=False,\n",
")"
]
},
{
"cell_type": "markdown",
"id": "ab7335db-e0bb-4e27-80c5-eea11e593a57",
"metadata": {
"id": "ab7335db-e0bb-4e27-80c5-eea11e593a57"
},
"source": [
"- As a verification step, we iterate through the data loaders and check that the batches contain 8 training examples each, where each training example consists of 120 tokens"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "4dee6882-4c3a-4964-af15-fa31f86ad047",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "4dee6882-4c3a-4964-af15-fa31f86ad047",
"outputId": "2ae34de1-dd01-4f99-d2c8-ba4dca400754"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train loader:\n",
"Input batch dimensions: torch.Size([8, 120])\n",
"Label batch dimensions torch.Size([8])\n"
]
}
],
"source": [
"print(\"Train loader:\")\n",
"for input_batch, target_batch in train_loader:\n",
" pass\n",
"\n",
"print(\"Input batch dimensions:\", input_batch.shape)\n",
"print(\"Label batch dimensions\", target_batch.shape)"
]
},
{
"cell_type": "markdown",
"id": "5cdd7947-7039-49bf-8a5e-c0a2f4281ca1",
"metadata": {
"id": "5cdd7947-7039-49bf-8a5e-c0a2f4281ca1"
},
"source": [
"- Lastly, let's print the total number of batches in each dataset"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "IZfw-TYD2zTj",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "IZfw-TYD2zTj",
"outputId": "4d19ed61-cf7a-4ec4-b822-c847dd1c5d77"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"130 training batches\n",
"19 validation batches\n",
"38 test batches\n"
]
}
],
"source": [
"print(f\"{len(train_loader)} training batches\")\n",
"print(f\"{len(val_loader)} validation batches\")\n",
"print(f\"{len(test_loader)} test batches\")"
]
},
{
"cell_type": "markdown",
"id": "dec9aa4a-ffd2-4d9f-a835-cce1059fe604",
"metadata": {
"id": "dec9aa4a-ffd2-4d9f-a835-cce1059fe604"
},
"source": [
"## E.3 Initializing the model"
]
},
{
"cell_type": "markdown",
"id": "f36ebdaf-810e-46a2-9ad9-e017a04051b1",
"metadata": {
"id": "f36ebdaf-810e-46a2-9ad9-e017a04051b1"
},
"source": [
"- This section repeats the code from chapter 6 to load and prepare the model"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "02b3a506-3879-4258-82b5-93a5b6bafa74",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "02b3a506-3879-4258-82b5-93a5b6bafa74",
"outputId": "b8c9b125-bb52-45d3-8071-fa5054dbf5a9"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"File already exists and is up-to-date: gpt2/124M/checkpoint\n",
"File already exists and is up-to-date: gpt2/124M/encoder.json\n",
"File already exists and is up-to-date: gpt2/124M/hparams.json\n",
"File already exists and is up-to-date: gpt2/124M/model.ckpt.data-00000-of-00001\n",
"File already exists and is up-to-date: gpt2/124M/model.ckpt.index\n",
"File already exists and is up-to-date: gpt2/124M/model.ckpt.meta\n",
"File already exists and is up-to-date: gpt2/124M/vocab.bpe\n"
]
}
],
"source": [
"from gpt_download import download_and_load_gpt2\n",
"from previous_chapters import GPTModel, load_weights_into_gpt\n",
"# Alternatively:\n",
"# from llms_from_scratch.ch04 import GPTModel\n",
"# from llms_from_scratch.ch05 import load_weights_into_gpt\n",
"\n",
"\n",
"\n",
"CHOOSE_MODEL = \"gpt2-small (124M)\"\n",
"INPUT_PROMPT = \"Every effort moves\"\n",
"\n",
"BASE_CONFIG = {\n",
" \"vocab_size\": 50257, # Vocabulary size\n",
" \"context_length\": 1024, # Context length\n",
" \"drop_rate\": 0.0, # Dropout rate\n",
" \"qkv_bias\": True # Query-key-value bias\n",
"}\n",
"\n",
"model_configs = {\n",
" \"gpt2-small (124M)\": {\"emb_dim\": 768, \"n_layers\": 12, \"n_heads\": 12},\n",
" \"gpt2-medium (355M)\": {\"emb_dim\": 1024, \"n_layers\": 24, \"n_heads\": 16},\n",
" \"gpt2-large (774M)\": {\"emb_dim\": 1280, \"n_layers\": 36, \"n_heads\": 20},\n",
" \"gpt2-xl (1558M)\": {\"emb_dim\": 1600, \"n_layers\": 48, \"n_heads\": 25},\n",
"}\n",
"\n",
"BASE_CONFIG.update(model_configs[CHOOSE_MODEL])\n",
"\n",
"model_size = CHOOSE_MODEL.split(\" \")[-1].lstrip(\"(\").rstrip(\")\")\n",
"settings, params = download_and_load_gpt2(model_size=model_size, models_dir=\"gpt2\")\n",
"\n",
"model = GPTModel(BASE_CONFIG)\n",
"load_weights_into_gpt(model, params)\n",
"model.eval();"
]
},
{
"cell_type": "markdown",
"id": "252614cd-7ce6-4908-83e6-3761f519904e",
"metadata": {
"id": "252614cd-7ce6-4908-83e6-3761f519904e"
},
"source": [
"- To ensure that the model was loaded corrected, let's double-check that it generates coherent text"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "8b6ce20c-0700-4783-8be0-4cf17c200a7f",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "8b6ce20c-0700-4783-8be0-4cf17c200a7f",
"outputId": "28ccbca5-8de9-41a0-c093-da00fcbaa91c"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Every effort moves you forward.\n",
"\n",
"The first step is to understand the importance of your work\n"
]
}
],
"source": [
"from previous_chapters import (\n",
" generate_text_simple,\n",
" text_to_token_ids,\n",
" token_ids_to_text\n",
")\n",
"\n",
"\n",
"text_1 = \"Every effort moves you\"\n",
"\n",
"token_ids = generate_text_simple(\n",
" model=model,\n",
" idx=text_to_token_ids(text_1, tokenizer),\n",
" max_new_tokens=15,\n",
" context_size=BASE_CONFIG[\"context_length\"]\n",
")\n",
"\n",
"print(token_ids_to_text(token_ids, tokenizer))"
]
},
{
"cell_type": "markdown",
"id": "8174b31b-1ab5-4115-b01c-245369da5af3",
"metadata": {
"id": "8174b31b-1ab5-4115-b01c-245369da5af3"
},
"source": [
"- Then, we prepare the model for classification finetuning similar to chapter 6, where we replace the output layer"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "e255ce91-d73a-4854-90a4-95804928eb16",
"metadata": {
"id": "e255ce91-d73a-4854-90a4-95804928eb16"
},
"outputs": [],
"source": [
"torch.manual_seed(123)\n",
"\n",
"num_classes = 2\n",
"model.out_head = torch.nn.Linear(in_features=768, out_features=num_classes)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "02e6f057-1383-4ece-8444-0a88e71ac75d",
"metadata": {
"id": "02e6f057-1383-4ece-8444-0a88e71ac75d"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Device: mps\n"
]
}
],
"source": [
"if torch.cuda.is_available():\n",
" device = torch.device(\"cuda\")\n",
"elif torch.backends.mps.is_available():\n",
" # Use PyTorch 2.9 or newer for stable mps results\n",
" major, minor = map(int, torch.__version__.split(\".\")[:2])\n",
" if (major, minor) >= (2, 9):\n",
" device = torch.device(\"mps\")\n",
"else:\n",
" device = torch.device(\"cpu\")\n",
"\n",
"print(\"Device:\", device)\n",
"\n",
"model.to(device); # no assignment model = model.to(device) necessary for nn.Module classes"
]
},
{
"cell_type": "markdown",
"id": "8e951cd6-5e42-44d2-b21f-895cb61004fe",
"metadata": {
"id": "8e951cd6-5e42-44d2-b21f-895cb61004fe"
},
"source": [
"- Lastly, let's calculate the initial classification accuracy of the non-finetuned model (we expect this to be around 50%, which means that the model is not able to distinguish between spam and non-spam messages yet reliably)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "fc7dd72c-73a2-4881-ade0-0a9605f1ab8c",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "fc7dd72c-73a2-4881-ade0-0a9605f1ab8c",
"outputId": "74848515-5a49-4125-fecb-9f4bac23f812"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training accuracy: 46.25%\n",
"Validation accuracy: 45.00%\n",
"Test accuracy: 48.75%\n"
]
}
],
"source": [
"from previous_chapters import calc_accuracy_loader\n",
"# Alternatively:\n",
"# from llms_from_scratch.ch06 import calc_accuracy_loader\n",
"\n",
"\n",
"\n",
"torch.manual_seed(123)\n",
"train_accuracy = calc_accuracy_loader(train_loader, model, device, num_batches=10)\n",
"val_accuracy = calc_accuracy_loader(val_loader, model, device, num_batches=10)\n",
"test_accuracy = calc_accuracy_loader(test_loader, model, device, num_batches=10)\n",
"\n",
"print(f\"Training accuracy: {train_accuracy*100:.2f}%\")\n",
"print(f\"Validation accuracy: {val_accuracy*100:.2f}%\")\n",
"print(f\"Test accuracy: {test_accuracy*100:.2f}%\")"
]
},
{
"cell_type": "markdown",
"id": "398a1ec9-e2a1-43d6-bf9f-12ee54b46a7b",
"metadata": {
"id": "398a1ec9-e2a1-43d6-bf9f-12ee54b46a7b"
},
"source": [
"## E.4 Parameter-efficient finetuning with LoRA"
]
},
{
"cell_type": "markdown",
"id": "652a4a82-61ef-4d0a-9858-8988e844f12c",
"metadata": {
"id": "652a4a82-61ef-4d0a-9858-8988e844f12c"
},
"source": [
"- We begin by initializing a LoRALayer that creates the matrices $A$ and $B$, along with the `alpha` scaling hyperparameter and the `rank` ($r$) hyperparameters\n",
"- This layer can accept an input and compute the corresponding output, as illustrated in the figure below\n",
"\n",
"<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/appendix-e_compressed/lora-2.webp\" width=\"200px\">\n",
"\n",
"In code, this LoRA layer depicted in the figure above looks like as follows"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "2ds9ywjMwvIW",
"metadata": {
"id": "2ds9ywjMwvIW"
},
"outputs": [],
"source": [
"import math\n",
"\n",
"class LoRALayer(torch.nn.Module):\n",
" def __init__(self, in_dim, out_dim, rank, alpha):\n",
" super().__init__()\n",
" self.A = torch.nn.Parameter(torch.empty(in_dim, rank))\n",
" torch.nn.init.kaiming_uniform_(self.A, a=math.sqrt(5)) # similar to standard weight initialization\n",
" self.B = torch.nn.Parameter(torch.zeros(rank, out_dim))\n",
" self.alpha = alpha\n",
" self.rank = rank\n",
"\n",
" def forward(self, x):\n",
" # Note: The original chapter didn't include the scaling by self.rank\n",
" # This scaling is not necessary, but it's more canonical and convenient\n",
" # as this lets us compare runs across different ranks without retuning learning rates\n",
" x = (self.alpha / self.rank) * (x @ self.A @ self.B)\n",
" return x"
]
},
{
"cell_type": "markdown",
"id": "ad21faa8-0614-4257-93cd-68952193e14a",
"metadata": {
"id": "ad21faa8-0614-4257-93cd-68952193e14a"
},
"source": [
"- In the code above, `rank` is a hyperparameter that controls the inner dimension of the matrices $A$ and $B$\n",
"- In other words, this parameter controls the number of additional parameters introduced by LoRA and is a key factor in determining the balance between model adaptability and parameter efficiency\n",
"- The second hyperparameter, `alpha`, is a scaling hyperparameter applied to the output of the low-rank adaptation\n",
"- It essentially controls the extent to which the adapted layer's output is allowed to influence the original output of the layer being adapted\n",
"- This can be seen as a way to regulate the impact of the low-rank adaptation on the layer's output\n",
"- So far, the `LoRALayer` class we implemented above allows us to transform the layer inputs $x$\n",
"- However, in LoRA, we are usually interested in replacing existing `Linear` layers so that the weight update is applied to the existing pretrained weights, as shown in the figure below\n",
"\n",
"<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/appendix-e_compressed/lora-3.webp\" width=\"200px\">"
]
},
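{
"cell_type": "markdown",
"id": "lora-param-count-sketch",
"metadata": {},
"source": [
"- To make the parameter savings concrete, here is a small sketch that applies the `LoRALayer` class from above to the GPT-2 \"small\" embedding dimension of 768 (rank 16 is also the value we use later): the adapter adds 768*16 + 16*768 = 24,576 parameters, compared to the 768*768 = 589,824 parameters of a full weight update\n",
"\n",
"```python\n",
"torch.manual_seed(123)\n",
"layer = LoRALayer(in_dim=768, out_dim=768, rank=16, alpha=16)\n",
"\n",
"x = torch.randn(2, 768)  # a batch of 2 example inputs\n",
"print(layer(x).shape)    # torch.Size([2, 768])\n",
"\n",
"# 24,576 trainable parameters (A: 768*16, B: 16*768)\n",
"print(sum(p.numel() for p in layer.parameters()))\n",
"```"
]
},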
{
"cell_type": "markdown",
"id": "3e6d5da0-dfce-4808-b89b-29ff333f563f",
"metadata": {
"id": "3e6d5da0-dfce-4808-b89b-29ff333f563f"
},
"source": [
"- To incorporate the original `Linear` layer weights as shown in the figure above, we implement a `LinearWithLoRA` layer below that uses the previously implemented LoRALayer and can be used to replace existing `Linear` layers in a neural network, for example, the self-attention module or feed forward modules in an LLM"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "127d3a64-8359-4b21-b056-78d58cc75fe8",
"metadata": {
"id": "127d3a64-8359-4b21-b056-78d58cc75fe8"
},
"outputs": [],
"source": [
"class LinearWithLoRA(torch.nn.Module):\n",
" def __init__(self, linear, rank, alpha):\n",
" super().__init__()\n",
" self.linear = linear\n",
" self.lora = LoRALayer(\n",
" linear.in_features, linear.out_features, rank, alpha\n",
" )\n",
"\n",
" def forward(self, x):\n",
" return self.linear(x) + self.lora(x)"
]
},
{
"cell_type": "markdown",
"id": "e1145a90-35ff-462c-820b-15483fa5b051",
"metadata": {
"id": "e1145a90-35ff-462c-820b-15483fa5b051"
},
"source": [
"- Note that since we initialize the weight matrix $B$ (`self.B` in `LoRALayer`) with zero values in the LoRA layer, the matrix multiplication between $A$ and $B$ results in a matrix consisting of 0's and doesn't affect the original weights (since adding 0 to the original weights does not modify them)"
]
},
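{
"cell_type": "markdown",
"id": "lora-zero-init-check",
"metadata": {},
"source": [
"- This zero-initialization property is easy to check on a toy example; the sketch below also shows how a trained LoRA update could be merged back into the original weights (`nn.Linear` stores its weight as an (out_features, in_features) matrix, hence the transpose):\n",
"\n",
"```python\n",
"torch.manual_seed(123)\n",
"linear = torch.nn.Linear(4, 3)\n",
"layer = LinearWithLoRA(linear, rank=2, alpha=4)\n",
"\n",
"x = torch.randn(1, 4)\n",
"print(torch.allclose(layer(x), linear(x)))  # True, because B starts out as all zeros\n",
"\n",
"with torch.no_grad():\n",
"    layer.lora.B.normal_()  # pretend finetuning changed B\n",
"    scaling = layer.lora.alpha / layer.lora.rank\n",
"    merged_w = linear.weight + scaling * (layer.lora.A @ layer.lora.B).T\n",
"print(torch.allclose(layer(x), x @ merged_w.T + linear.bias, atol=1e-6))  # True\n",
"```"
]
},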
{
"cell_type": "markdown",
"id": "e98a6d36-7bc9-434c-a7f1-533f26aff06d",
"metadata": {
"id": "e98a6d36-7bc9-434c-a7f1-533f26aff06d"
},
"source": [
"- To try LoRA on the GPT model we defined earlier, we define a `replace_linear_with_lora` function to replace all `Linear` layers in the model with the new `LinearWithLoRA` layers\n",
"\n",
"<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/appendix-e_compressed/lora-4.webp\" width=\"400px\">"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "WlQZ8ygqzN_g",
"metadata": {
"id": "WlQZ8ygqzN_g"
},
"outputs": [],
"source": [
"def replace_linear_with_lora(model, rank, alpha):\n",
" for name, module in model.named_children():\n",
" if isinstance(module, torch.nn.Linear):\n",
" # Replace the Linear layer with LinearWithLoRA\n",
" setattr(model, name, LinearWithLoRA(module, rank, alpha))\n",
" else:\n",
" # Recursively apply the same function to child modules\n",
" replace_linear_with_lora(module, rank, alpha)"
]
},
{
"cell_type": "markdown",
"id": "8c172164-cdde-4489-b7d7-aaed9cc2f5f2",
"metadata": {
"id": "8c172164-cdde-4489-b7d7-aaed9cc2f5f2"
},
"source": [
"- We then freeze the original model parameter and use the `replace_linear_with_lora` to replace the said `Linear` layers using the code below\n",
"- This will replace the `Linear` layers in the LLM with `LinearWithLoRA` layers"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "dbe15350-4da9-4829-9d23-98bbd3d0b1a1",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "dbe15350-4da9-4829-9d23-98bbd3d0b1a1",
"outputId": "fd4c208f-854a-4701-d9d3-9d73af733364"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Total trainable parameters before: 124,441,346\n",
"Total trainable parameters after: 0\n"
]
}
],
"source": [
"total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
"print(f\"Total trainable parameters before: {total_params:,}\")\n",
"\n",
"for param in model.parameters():\n",
" param.requires_grad = False\n",
"\n",
"total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
"print(f\"Total trainable parameters after: {total_params:,}\")"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "mLk_fPq0yz_u",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "mLk_fPq0yz_u",
"outputId": "0a93b8fc-05d7-4ace-ee47-e2fc6bdd7d75"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Total trainable LoRA parameters: 2,666,528\n"
]
}
],
"source": [
"replace_linear_with_lora(model, rank=16, alpha=16)\n",
"\n",
"total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
"print(f\"Total trainable LoRA parameters: {total_params:,}\")"
]
},
{
"cell_type": "markdown",
"id": "b8b6819e-ef7a-4f0d-841a-1b467496bef9",
"metadata": {
"id": "b8b6819e-ef7a-4f0d-841a-1b467496bef9"
},
"source": [
"- As we can see, we reduced the number of trainable parameters by almost 50x when using LoRA\n",
"- Let's now double-check whether the layers have been modified as intended by printing the model architecture"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "1711be61-bb2c-466f-9b5b-24f4aa5ccd9c",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "1711be61-bb2c-466f-9b5b-24f4aa5ccd9c",
"outputId": "acff8eca-3775-45a2-b62d-032a986ef037"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"GPTModel(\n",
" (tok_emb): Embedding(50257, 768)\n",
" (pos_emb): Embedding(1024, 768)\n",
" (drop_emb): Dropout(p=0.0, inplace=False)\n",
" (trf_blocks): Sequential(\n",
" (0): TransformerBlock(\n",
" (att): MultiHeadAttention(\n",
" (W_query): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (W_key): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (W_value): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (out_proj): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (ff): FeedForward(\n",
" (layers): Sequential(\n",
" (0): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=3072, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (1): GELU()\n",
" (2): LinearWithLoRA(\n",
" (linear): Linear(in_features=3072, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" )\n",
" )\n",
" (norm1): LayerNorm()\n",
" (norm2): LayerNorm()\n",
" (drop_resid): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (1): TransformerBlock(\n",
" (att): MultiHeadAttention(\n",
" (W_query): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (W_key): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (W_value): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (out_proj): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (ff): FeedForward(\n",
" (layers): Sequential(\n",
" (0): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=3072, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (1): GELU()\n",
" (2): LinearWithLoRA(\n",
" (linear): Linear(in_features=3072, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" )\n",
" )\n",
" (norm1): LayerNorm()\n",
" (norm2): LayerNorm()\n",
" (drop_resid): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (2): TransformerBlock(\n",
" (att): MultiHeadAttention(\n",
" (W_query): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (W_key): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (W_value): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (out_proj): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (ff): FeedForward(\n",
" (layers): Sequential(\n",
" (0): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=3072, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (1): GELU()\n",
" (2): LinearWithLoRA(\n",
" (linear): Linear(in_features=3072, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" )\n",
" )\n",
" (norm1): LayerNorm()\n",
" (norm2): LayerNorm()\n",
" (drop_resid): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (3): TransformerBlock(\n",
" (att): MultiHeadAttention(\n",
" (W_query): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (W_key): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (W_value): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (out_proj): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (ff): FeedForward(\n",
" (layers): Sequential(\n",
" (0): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=3072, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (1): GELU()\n",
" (2): LinearWithLoRA(\n",
" (linear): Linear(in_features=3072, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" )\n",
" )\n",
" (norm1): LayerNorm()\n",
" (norm2): LayerNorm()\n",
" (drop_resid): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (4): TransformerBlock(\n",
" (att): MultiHeadAttention(\n",
" (W_query): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (W_key): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (W_value): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (out_proj): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (ff): FeedForward(\n",
" (layers): Sequential(\n",
" (0): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=3072, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (1): GELU()\n",
" (2): LinearWithLoRA(\n",
" (linear): Linear(in_features=3072, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" )\n",
" )\n",
" (norm1): LayerNorm()\n",
" (norm2): LayerNorm()\n",
" (drop_resid): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (5): TransformerBlock(\n",
" (att): MultiHeadAttention(\n",
" (W_query): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (W_key): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (W_value): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (out_proj): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (ff): FeedForward(\n",
" (layers): Sequential(\n",
" (0): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=3072, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (1): GELU()\n",
" (2): LinearWithLoRA(\n",
" (linear): Linear(in_features=3072, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" )\n",
" )\n",
" (norm1): LayerNorm()\n",
" (norm2): LayerNorm()\n",
" (drop_resid): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (6): TransformerBlock(\n",
" (att): MultiHeadAttention(\n",
" (W_query): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (W_key): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (W_value): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (out_proj): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (ff): FeedForward(\n",
" (layers): Sequential(\n",
" (0): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=3072, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (1): GELU()\n",
" (2): LinearWithLoRA(\n",
" (linear): Linear(in_features=3072, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" )\n",
" )\n",
" (norm1): LayerNorm()\n",
" (norm2): LayerNorm()\n",
" (drop_resid): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (7): TransformerBlock(\n",
" (att): MultiHeadAttention(\n",
" (W_query): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (W_key): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (W_value): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (out_proj): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (ff): FeedForward(\n",
" (layers): Sequential(\n",
" (0): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=3072, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (1): GELU()\n",
" (2): LinearWithLoRA(\n",
" (linear): Linear(in_features=3072, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" )\n",
" )\n",
" (norm1): LayerNorm()\n",
" (norm2): LayerNorm()\n",
" (drop_resid): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (8): TransformerBlock(\n",
" (att): MultiHeadAttention(\n",
" (W_query): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (W_key): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (W_value): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (out_proj): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (ff): FeedForward(\n",
" (layers): Sequential(\n",
" (0): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=3072, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (1): GELU()\n",
" (2): LinearWithLoRA(\n",
" (linear): Linear(in_features=3072, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" )\n",
" )\n",
" (norm1): LayerNorm()\n",
" (norm2): LayerNorm()\n",
" (drop_resid): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (9): TransformerBlock(\n",
" (att): MultiHeadAttention(\n",
" (W_query): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (W_key): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (W_value): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (out_proj): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (ff): FeedForward(\n",
" (layers): Sequential(\n",
" (0): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=3072, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (1): GELU()\n",
" (2): LinearWithLoRA(\n",
" (linear): Linear(in_features=3072, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" )\n",
" )\n",
" (norm1): LayerNorm()\n",
" (norm2): LayerNorm()\n",
" (drop_resid): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (10): TransformerBlock(\n",
" (att): MultiHeadAttention(\n",
" (W_query): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (W_key): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (W_value): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (out_proj): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (ff): FeedForward(\n",
" (layers): Sequential(\n",
" (0): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=3072, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (1): GELU()\n",
" (2): LinearWithLoRA(\n",
" (linear): Linear(in_features=3072, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" )\n",
" )\n",
" (norm1): LayerNorm()\n",
" (norm2): LayerNorm()\n",
" (drop_resid): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (11): TransformerBlock(\n",
" (att): MultiHeadAttention(\n",
" (W_query): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (W_key): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (W_value): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (out_proj): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (ff): FeedForward(\n",
" (layers): Sequential(\n",
" (0): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=3072, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (1): GELU()\n",
" (2): LinearWithLoRA(\n",
" (linear): Linear(in_features=3072, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" )\n",
" )\n",
" (norm1): LayerNorm()\n",
" (norm2): LayerNorm()\n",
" (drop_resid): Dropout(p=0.0, inplace=False)\n",
" )\n",
" )\n",
" (final_norm): LayerNorm()\n",
" (out_head): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=2, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
")\n"
]
}
],
"source": [
"model.to(device)\n",
"\n",
"print(model)"
]
},
{
"cell_type": "markdown",
"id": "c4bbc9d7-65ec-4675-bab8-2e56eb0cfb55",
"metadata": {
"id": "c4bbc9d7-65ec-4675-bab8-2e56eb0cfb55"
},
"source": [
"- Based on the model architecture above, we can see that the model now contains our new `LinearWithLoRA` layers\n",
"- Also, since we initialized matrix $B$ with 0's, we expect the initial model performance to be unchanged compared to before"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "DAlrb_I00VEU",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "DAlrb_I00VEU",
"outputId": "3da44ac4-230b-4358-d996-30b63f0d962a"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training accuracy: 46.25%\n",
"Validation accuracy: 45.00%\n",
"Test accuracy: 48.75%\n"
]
}
],
"source": [
"torch.manual_seed(123)\n",
"train_accuracy = calc_accuracy_loader(train_loader, model, device, num_batches=10)\n",
"val_accuracy = calc_accuracy_loader(val_loader, model, device, num_batches=10)\n",
"test_accuracy = calc_accuracy_loader(test_loader, model, device, num_batches=10)\n",
"\n",
"print(f\"Training accuracy: {train_accuracy*100:.2f}%\")\n",
"print(f\"Validation accuracy: {val_accuracy*100:.2f}%\")\n",
"print(f\"Test accuracy: {test_accuracy*100:.2f}%\")"
]
},
{
"cell_type": "markdown",
"id": "13735b3e-f0c3-4dba-ae3d-4141b2878101",
"metadata": {
"id": "13735b3e-f0c3-4dba-ae3d-4141b2878101"
},
"source": [
"- Let's now get to the interesting part and finetune the model by reusing the training function from chapter 6\n",
"- The training takes about 15 minutes on a M3 MacBook Air laptop computer and less than half a minute on a V100 or A100 GPU"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "wCParRvr0eff",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "wCParRvr0eff",
"outputId": "ce910a9c-ee89-48bb-bfa6-49c6aee1e450"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Ep 1 (Step 000000): Train loss 3.820, Val loss 3.462\n",
"Ep 1 (Step 000050): Train loss 0.346, Val loss 0.325\n",
"Ep 1 (Step 000100): Train loss 0.063, Val loss 0.144\n",
"Training accuracy: 100.00% | Validation accuracy: 92.50%\n",
"Ep 2 (Step 000150): Train loss 0.054, Val loss 0.045\n",
"Ep 2 (Step 000200): Train loss 0.058, Val loss 0.122\n",
"Ep 2 (Step 000250): Train loss 0.041, Val loss 0.199\n",
"Training accuracy: 100.00% | Validation accuracy: 95.00%\n",
"Ep 3 (Step 000300): Train loss 0.020, Val loss 0.153\n",
"Ep 3 (Step 000350): Train loss 0.017, Val loss 0.186\n",
"Training accuracy: 100.00% | Validation accuracy: 95.00%\n",
"Ep 4 (Step 000400): Train loss 0.017, Val loss 0.099\n",
"Ep 4 (Step 000450): Train loss 0.001, Val loss 0.170\n",
"Ep 4 (Step 000500): Train loss 0.117, Val loss 0.222\n",
"Training accuracy: 97.50% | Validation accuracy: 92.50%\n",
"Ep 5 (Step 000550): Train loss 0.038, Val loss 0.235\n",
"Ep 5 (Step 000600): Train loss 0.019, Val loss 0.252\n",
"Training accuracy: 100.00% | Validation accuracy: 100.00%\n",
"Training completed in 2.16 minutes.\n"
]
}
],
"source": [
"import time\n",
"from previous_chapters import train_classifier_simple\n",
"# Alternatively:\n",
"# from llms_from_scratch.ch06 import train_classifier_simple\n",
"\n",
"\n",
"start_time = time.time()\n",
"\n",
"torch.manual_seed(123)\n",
"\n",
"optimizer = torch.optim.AdamW(model.parameters(), lr=8e-4, weight_decay=0.1)\n",
"\n",
"num_epochs = 5\n",
"train_losses, val_losses, train_accs, val_accs, examples_seen = train_classifier_simple(\n",
" model, train_loader, val_loader, optimizer, device,\n",
" num_epochs=num_epochs, eval_freq=50, eval_iter=5,\n",
")\n",
"\n",
"end_time = time.time()\n",
"execution_time_minutes = (end_time - start_time) / 60\n",
"print(f\"Training completed in {execution_time_minutes:.2f} minutes.\")"
]
},
{
"cell_type": "markdown",
"id": "d0c89e82-3aa8-44c6-b046-0b16200b8e6c",
"metadata": {
"id": "d0c89e82-3aa8-44c6-b046-0b16200b8e6c"
},
"source": [
"- Finally, let's evaluate the model"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "bawWGijA0iF3",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 308
},
"id": "bawWGijA0iF3",
"outputId": "af70782a-d605-4376-fa6c-d33b38979cfa"
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAeoAAAEiCAYAAAA21pHjAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjcsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvTLEjVAAAAAlwSFlzAAAPYQAAD2EBqD+naQAARZ5JREFUeJzt3Qd4FOXaBuBnN5UEEloSCITeW+hIE6QjFjgqHuUgciy/CogiFo5K0aNgx4JYUNEjCoKCiBQRBKRJDZ0gLSRASCCk92T+6/12Z7MbQkggyc4mz31lrik7uzs7O9l3vm7SNE0DERERGZLZ2QdAREREV8dATUREZGAM1ERERAbGQE1ERGRgDNREREQGxkBNRERkYAzUREREBsZATUREZGAM1ERERAbGQE1EDvr27YunnnqKZ4XIIBioiUrYgw8+CJPJdMU0ZMgQnmsiKjb34j+FiK5FgvJXX33lsM3Ly4snjoiKjSlqolIgQblWrVoOU7Vq1dRjGzZsgKenJ/7880/b/m+++SYCAwNx4cIFtb569Wr06tULVatWRY0aNXDbbbfhxIkTtv1Pnz6tUuk//PADevfujUqVKqFLly44duwYdu7cic6dO6Ny5coYOnQoYmNjHVL7w4cPx4wZMxAQEAA/Pz889thjyMzMvOpnycjIwOTJk1GnTh34+vqiW7du6jPoIiIicPvtt6vPJ4+3bt0aK1euvOrrffzxx2jatCm8vb0RFBSEu+++2/ZYbm4uZs6ciYYNG6rPFBoaiiVLljg8/+DBg+pzyeeT548ePRoXL150yLp/8skn8dxzz6F69erq3E+fPr1I3xuRETFQEzmpDFgCTEJCAvbu3YuXX34Z8+bNU4FHpKSkYNKkSdi1axfWrVsHs9mMESNGqEBmb9q0aXjppZewZ88euLu74/7771cB6v3331c3AsePH8fUqVMdniOvd+TIERVsv//+e/z0008qcF/N+PHjsW3bNixcuBD79+/HPffco3IM/v77b/X4uHHjVDDftGkTDhw4gDfeeEMF0YLI55Eg+sorryA8PFzdkNx88822xyVIf/PNN/jkk09w6NAhPP300/jXv/6FjRs3qsfj4+PRr18/dOjQQb2WPF9ubkaOHOnwPl9//bW6afjrr7/UTZC839q1a4v9XREZggxzSUQlZ8yYMZqbm5vm6+vrML322mu2fTIyMrT27dtrI0eO1Fq1aqU98sgjhb5mbGysDEerHThwQK2fOnVKrc+bN8+2z/fff6+2rVu3zrZt5syZWvPmzR2OrXr16lpKSopt29y5c7XKlStrOTk5ar1Pnz7axIkT1XJERIT6LGfPnnU4nv79+2tTpkxRy23bttWmT59epHPz448/an5+flpiYuIVj6Wnp2s+Pj7a1q1bHbY/9NBD2n333aeWX331VW3QoEEOj0dGRqrPHR4ebjv+Xr16OezTpUsX7fnnny/SMRIZDcuoiUrBLbfcgrlz5zpsk2xYnWR9L1iwAO3atUP9+vXx3nvvOewrqVVJCUuKULJ19ZT0mTNn0KZNG9t+8nydnhpv27atw7aYmBiH15bsZB8fH9t69+7dkZycjMjISHUs9iSFnJOTg2bNmjlslxS0ZMkLSSE//vjj+O233zBgwADcddddDsdlb+DAgeo9GjVqpFLlMklOgRyPpP5TU1PVPvYkW15S0GLfvn34448/CkyxS9GAfpz537927dpXnAciV8FATVQKJNu1SZMmhe6zdetWNY+Li1OTPEcnZb4S0D7//HMEBwerQC0BOn9ZsoeHh21ZyqwL2pY/u7w4JIC7ublh9+7dam5PD5YPP/wwBg8ejF9//VUFa8m+fueddzBhwoQrXq9KlSoqm16y3WVfuRmR8mMpV5f3EvI6Uh5eUEU82UfOjWSv5yfBuKDzUhLngciZGKiJnEBSf1L+KoF40aJFGDNmDH7//XdVFn3p0iVVfiuPSUUxsXnz5hJ7b0mVpqWlqcpaYvv27SrohoSEXLGvpGQlRS2pUf1YCiLPlUppMk2ZMkUde0GBWkhZuqS8ZZIydqkwt379epWSloAsuQZ9+vQp8LkdO3bEjz/+iAYNGqjXIaoIeKUTlQLJGo6Ojnb8Z3N3R82aNVXgkwpSkgodO3asyv6V7GpJhT777LOq9rRkK3/22WcqlSiB64UXXiixY5NU+UMPPaQqoUntcQmWUmFMbhLyk6zkUaNG4YEHHlDHJ4FbapFLhTTJXh42bJiqGCe1sGXfy5cvq6zpli1bFvjeK1aswMmTJ1UFMvmcUjtcUrrNmzdXqW2pXS43MLJNar1LZbstW7ao2ulyMyMV1+Qm4L777rPV6pYsc6noJpXx8qf6icoDBmqiUiC1ke2zYoUEo6NHj+K1115TTZokaAnZT4KyBJ9BgwapMmQJPFL2K9nd8rwPPvhA1RYvCf3791fNoyRYyg2FvG9hzZekPfh///tfPPPMMzh79qy62bjppptUkzEhNx4SQKOiolRAlRuP/GXuOkk9Sy1zeb/09HR1HFLzXJp0iVdffVU1G5Pscwnosr+kov/zn/+ox6UYQAL3888/r86VHL8UEch7FnSjQVQemKRGmbMPgojKhrSjliZOy5Yt4yknchG8BSUiIjIwBmoiIiIDY9Y3ERGRgTFFTUREZGAM1ERERAbGQE1ERGRgDNRWc+bMUb0dydB7Mozfjh07UN7JaEfSHaO0TZUuFvM32ZGWe9LFo7TzlV6spCcpfcQknXR9KR1iSPtZafMqHWnoXUHqZMQl6dVKzq30YCWjGbkaadcrw0hKpxwyHKUMFSm9h9mTdsHSnlg6K5GevqTPa33YSp10XiKdhEjf1vI60sFJdna2wz7Svaa0HZZeuqQb0vnz58PVSD/n0iGKXBcySX/iq1atsj3Oc3V1s2bNUv+P0pEMz9eVpA2+nB/7qUWLFuX7XDl7VBAjWLhwoebp6al9+eWX2qFDh9RIRlWrVtUuXLiglWcrV67UXnzxRe2nn35Sow8tXbrU4fFZs2Zp/v7+2rJly7R9+/Zpd9xxh9awYUMtLS3Nts+QIUO00NBQbfv27dqff/6pNWnSxDbSkUhISNCCgoK0UaNGaQcPHlQjPFWqVEn79NNPNVcyePBg7auvvlKfISwsTLv11lu1evXqacnJybZ9HnvsMS0kJESNXrVr1y7tpptu0nr06GF7PDs7W2vTpo02YMAAbe/ever816xZ0zYKlTh58qQaQWrSpEna4cOHtQ8//FCNXrV69WrNlSxfvlz79ddftWPHjqlRrf7zn/9oHh4e6vwJnquC7dixQ2vQoIHWrl072whmPF+Opk2bprVu3Vo7f/68bZLR5XTl8dpioNY0rWvXrtq4ceNsJ0WG+wsODlZDBFYU+QN1bm6uVqtWLe2tt96ybYuPj9e8vLxUsBVyAcvzdu7cadtn1apVmslksg2L+PHHH2vVqlVTwzrqZLhB+6EXXVFMTIz67Bs3brSdGwlEixcvtu1z5MgRtc+2bdvUuvwgmM1mLTo62mGISRn2UT8/zz33nPoRsnfvv
feqGwVXJ9eBDMvJc1WwpKQkrWnTptratWsdhhrl+boyUEvioCDl9VxV+Kxv6fdYRgaSbF2ddEUo69u2bUNFderUKdVXtf158ff3V8UC+nmRuWR3d+7c2baP7C/nT4Zn1PeRriplWEed9HEt2cbSL7Srkj6o7YeulGsoKyvL4XxJdly9evUczpf06a0PR6mfi8TERBw6dMi2j/1r6Pu48rUoXYxKl6gpKSkqC5znqmCSXSvZsfm/f56vK0kRnBTZyXCpUvQmWdnl+VxV+EAtY/3KD4n9lyZkPf+gChWJ/tkLOy8yl/Kd/ANPSPCy36eg17B/D1cjA0ZI+WHPnj1tY0PLZ5GbEblxKex8XetcXG0f+RGREa9ciYxlLWWEUsYno2otXboUrVq14rkqgNzIyPCfUhciP15bjiSxIOXF0p++1IWQRIXUgUlKSiq354qDchBdR8rn4MGDJTr0ZHkkg4mEhYWp3IclS5ao0a82btzo7MMynMjISEycOBFr165VFS6pcEOHDrUtS4VFCdwyMMsPP/xgG7q1vKnwKWoZCUiGxstfK1DWa9WqhYpK/+yFnReZyzjF9qTmpNQEt9+noNewfw9XIsNByqhXMpRj3bp1bdvls0gxigx4Udj5uta5uNo+UnPa1X6EJGUjtWU7deqkUooyKtj777/Pc5WPZNfK/5HUMJYcKZnkhkZGTJNlScnx2ro6ST3LEKsy3Gl5/T+s8IFafkzkh0TG17XP2pR1KU+rqBo2bKguVvvzItk+UvasnxeZyz+E/NDo1q9fr86f3OXq+0gzMCk30knKQVJbMh6xq5D6dhKkJftWPqOcH3tyDXl4eDicLymHl7Iz+/Ml2cH2NzdyLuSfX7KE9X3sX0Pfpzxci3JdyLCUPFdXDjsq14XkPuiT1PuQsld9mdfW1Ulz0BMnTqhmpOX22nJKFTYDNs+S2szz589XNZkfffRR1TzLvlZgeSS1TKV5gkxyKbz77rtqOSIiwtY8S87Dzz//rO3fv1+78847C2ye1aFDB+2vv/7SNm/erGqt2jfPklqY0jxr9OjRqmmOnGtp9uBqzbMef/xx1VRtw4YNDs1CUlNTHZqFSJOt9evXq2Yh3bt3V1P+ZiGDBg1STbykqUdAQECBzUKeffZZVVt1zpw5Ltk864UXXlA14k+dOqWuHVmX1gC//fabepznqnD2tb55vhw988wz6v9Qrq0tW7aoZlbSvEpaYpTXa4uB2kraycmXK+2ppbmWtAsu7/744w8VoPNPY8aMsTXRevnll1WglRuZ/v37qzax9i5duqQCc+XKlVXzhrFjx6obAHvSBrtXr17qNerUqaNuAFxNQedJJmlbrZMbmCeeeEI1Q5J/8hEjRqhgbu/06dPa0KFDVVty+XGRH52srKwrvpf27dura7FRo0YO7+Eq/v3vf2v169dXn0F+BOXa0YO04LkqXqDm+XJsJlW7dm11bcnviawfP368XJ8rjp5FRERkYBW+jJqIiMjIGKiJiIgMjIGaiIjIwBioiYiIDIyBmoiIyMAYqImIiAyMgdqO9Jokg5LLnArHc1U8PF88V6WF11b5P1eGaUc9a9YsTJkyRXVOP3v2bKccg3SRKUM5yiAC0p0c8Vzx2uL/odHxd6v8nytDpKh37tyJTz/9VI2EQkRERAYK1NKhunQ+//nnn7vUIA1EREQVYjxqGdt32LBhGDBgAP773/8W67kypOLevXvVMHBm843fc8jA4+Ls2bMqi4R4rkoKry2eq9LCa8s1z5WMJidDZ3bo0EENZ1oYpwbqhQsXYs+ePSrruyikAoB9JQAZXrFfv34lflz6UGfEc8Vry3n4f8jzVRGurR07dqBLly7GDNSRkZGq4piM8ent7V2k58jg8zNmzCjwg8pYpERERK7g/Pnz6Nq1q8oRNmyt72XLlmHEiBFwc3OzbcvJyYHJZFLZ2JJytn+soBS1ZF/InZEE/bp165bp8RMREV2vqKgohISEFCl+OS1F3b9/fxw4cMBh29ixY9GiRQs8//zzVwRp4eXlpSads8sYiIiISpvTAnWVKlXQpk0bh22+vr6oUaPGFduJiIgqKqc3zyIiIiIDN8+yt2HDBmcfAhFVcFJXJisry9mHQS7Ow8OjwCJclw/UzpSSkY19kfHIztVwc7MAZx8OEZUxqVcbHR2N+Ph4nnsqEVWrVkWtWrVUJekbwUBttf5oDCZ8vxft6vozUBNVQHqQDgwMhI+Pzw3/uFLFvulLTU1FTEyMWr/R5sMM1FbtQ6qq+ZHziUjPyoG3R8lkWRCRa2R360FaKrQS3ahKlSqpuQRrua5uJBuclcms6larhBq+nsjK0XD4PJt9EVUkepm0pKSJSop+Pd1onQcGaivJ5gq1pqqlrJqIKh5md5MRrycG6gKyvxmoiYjIKBio7egp6jCmqImoAmvQoAFmz55drKa1knos7Rrz8+fPVzWpKxoGajuhdf3V/PSlVMSnZjrrOyEiKhIJjoVN06dPv64zKSMaPvroo0Xev0ePHmqQCX9/y28olSzW+rZT1ccTDWv64tTFFOyLSkAftqcmIgOT4KhbtGgRpk6divDwcNu2ypUrOzQZktrt1xr7WAQEFK8vCU9PT9VemEoHU9RXSVWHnWGFMiIyNgmO+iSpWUlF6+tHjx5VYyqsWrUKnTp1UgMabd68GSdOnMCdd96phleUQC5jIf/++++FZn3L686bN0+NeCg1mZs2bYrly5dfNetbz6Jes2YNWrZsqd5nyJAhDjcW2dnZePLJJ9V+0iROBmMaM2YMhg8fXqxzMHfuXDRu3FjdLDRv3hz/+9//HG5OJFehXr166vMHBwer99R9/PHH6rPIUMtyPu6++24YEQN1Praa31EM1ESo6J1WZGY7ZSrJ0YdfeOEFzJo1C0eOHEG7du2QnJyMW2+9FevWrcPevXtVAL399ttx5syZQl9nxowZGDlyJPbv36+eP2rUKMTFxV11f+nw4+2331aBc9OmTer1J0+ebHv8jTfewIIFC/DVV19hy5YtajREGf64OJYuXYqJEyfimWeewcGDB/F///d/ahTGP/74Qz3+448/4r333sOnn36Kv//+W71+27Zt1WO7du1SQfuVV15RuRCrV6/GzTffDCNi1nchNb/ln4XNNYgqprSsHLSausYp7334lcHw8SyZn2cJRAMHDrStV69eHaGhobb1V199VQU8SSGPHz/+qq/z4IMP4r777lPLr7/+Oj744APs2LFDBfqCSNvhTz75RKV2hby2HIvuww8/xJQpU1QqXXz00UdYuXJlsT7b22+/rY7riSeeUOuTJk3C9u3b1fZbbrlF3RxI7sKAAQNU39uSsu7atavaVx6TERtvu+02lfNQv359dOjQAUbEFHU+LWv7wcPNhEspmYi6nOacb4WIqIR07tzZYV1S1JKylSxpyXaWbGlJbV8rRS2pcZ0EOD8/P1sXmQWRLHI9SOvdaOr7JyQk4MKFC7agKaTnLsmiL44jR46gZ8+eDttkXbaLe+65B2lpaWjUqBEeeeQRdUMiWe5Cbl4kOMtjo0eP
Vql7yQUwIqao85GuQyVY749KUM20QqqzpyKiiqiSh5tK2TrrvUuKBFV7EqTXrl2rUp1NmjRRXV1K2WxmZuEtXSRFak9yG3Nzc4u1f0lm6RdFSEiIytaWMnj5zJLyfuutt7Bx40aVit6zZ48qX//tt99URTwpz5Ya70ZrAsYUdQHY8QkRSWCR7GdnTKVZ5CblwZJdLFnOUl4rWcOnT58u0y9cKr5J5S0JijqpkS6BszhatmypPo89WW/VqpVtXW5EpAxesuolKG/btg0HDhxQj0kNeMkWf/PNN1XZu5yH9evXw2iYoi5AaF25m4pgxydEVO5ILeeffvpJBS+5IXj55ZcLTRmXlgkTJmDmzJkqVd+iRQtVZn358uVi3aQ8++yzqoKblC1LwP3ll1/UZ9NrsUvtc7kB6Natm8qK//bbb1XglizvFStW4OTJk6oCWbVq1VT5uJwHqTluNAzUhdT8PnguAVk5ufBwY8YDEZUP7777Lv7973+rTkpq1qypmkVJjeuyJu8rQ4s+8MADqnxaOlgZPHhwsUaZGj58ON5//32VjS+1vxs2bKhqkfft21c9LlnYUuNdKplJwJYcBAnm0hxMHpOgLtnd6enp6gbm+++/R+vWrWE0Jq2sCw1KUFRUlCqDiIyMRN26dW/sxbIzgIitwKXjyO38MEJf+Q1J6dn49cleaB3M3naIyjP5oT516pT6oZc2tVT2JDUrWdmSQpaa6OX9uooqRvxiilqXdhn4nzS0N8HcbqTK/t58/KLK/magJiIqWREREaoSV58+fZCRkaGaZ0lQu//++3mq82Gerq5KLaBaQ+nmAIjcidAQSyqaI2kREZU8s9msypClZzRpUiUVvKRsWVLV5Igpanv1ugOXTwFntqF9iKVD+n2RCflOGRER3SjJ9s1fY5sKxhS1vXo3WeZnttv6/D4Wk4TkDEsDeSIiorLGQG2vfg/L/OwuBPqYEOzvDalqdyCKqWoiInIOBmp7NZoAPjWA7HTg/D60r8cBOoiIyLkYqO1JQ3sppxZntlk7PuGQl0RE5DwM1IWVU3PISyIicjIG6vxsKertaBtcBWYTcD4hHRcS08v+2yEiogqPgTq/Wu0A90pAWhx8k06hWVAVtVk6PiEiKo+ky82nnnrKtt6gQQPMnj270OdIn9zLli274fcuqdcpjHQT2r59e7gqBur83D2ButbxWyO22sqp2fEJERmNDKwxZMiQAh/7888/VRCUUaGKS0a1kr63yyJYnj9/HkOHDi3R9ypvGKivkf3Nmt9EZFQPPfSQGmdZ+o3OTwan6Ny5M9q1a1fs1w0ICFCjTZUFGWbTy8urTN7LVTFQF6Rhb6BRX5Wy1lPU+yMTkJvrsuOXEFE5dNttt6mgKl1x2ktOTsbixYtVIL906RLuu+8+1KlTRwVfGUFKRokqTP6s77///lsNBykDS8hYz3JzUNBoWM2aNVPv0ahRIzV8ZlZWlnpMjm/GjBnYt2+fSuXLpB9z/qxv6Uq0X79+ajhKGeXq0UcfVZ9HJ2Npy6hZMmJW7dq11T7jxo2zvVdRBwB55ZVX1GAYcpMgKf3Vq1fbHs/MzMT48ePV68tnlmExZUhOIeNYSe5AvXr11HODg4Px5JNPojSxC9GCNLzZMgFolpOLSh5uSMrIxsmLyWgSaCmzJqIKIjOl+M9x8wLcrD+vOdlATgZgMgMela79up6+RX4bd3d3NUykBL0XX3zRNpazBGkZ1lECtAS5Tp06qUDq5+eHX3/9FaNHj0bjxo3RtWvXIgW1f/zjHwgKCsJff/2FhIQEh/JsXZUqVdRxSOCSYPvII4+obc899xzuvfdeHDx4UAVDfaxof/8rRyVMSUlRQ112795dZb/HxMTg4YcfVkHT/mbkjz/+UEFU5sePH1evL8FW3rMoZGjMd955B59++qkay/rLL7/EHXfcgUOHDqnhLj/44AMsX74cP/zwgwrIMsKVTOLHH3/Ee++9h4ULF6ohMWWoTrkBKU0M1Nc6QW5mtK3jjx2n4xAWmcBATVTRvB5c/OfcMx9oPcKyfPQXYPGDQP1ewNhf8/aZ3RZIvXTlc6cXrydEGVv6rbfewsaNG23jMEu291133aWCoUyTJ0+27T9hwgSsWbNGBaGiBGoJrEePHlXPkSAsXn/99SvKlV966SWHFLm8pwQzCdSSOq5cubK6sZCs7qv57rvv1NCQ33zzDXx9LTcsH330kSqLf+ONN9TNgqhWrZraLmNXt2jRAsOGDcO6deuKHKglNS43Lv/85z/Vury2BH3JRZgzZw7OnDmjAnavXr3UzY+kqHXymHyGAQMGwMPDQwXyopzHG8Gs78IkxwDRB2wjaYVFXi7VL4OIqLgkUPXo0UOlCoWkMKUimWR7C0lZy/jOkuVdvXp1FTAl6ErAKYojR46oATT0IC0kxZvfokWL1ChYEsTkPSRwF/U97N8rNDTUFqRFz549Vao+PDzctk1SshKkdZK6ltR3USQmJuLcuXPqde3Jury/nr0eFhaG5s2bq2xtGY5Td8899yAtLU1l78uNwdKlS5GdnV1+U9Rz585V0+nTp20nf+rUqcaoARi+Gvj+XtVcK7THIrWJI2kRVUD/OXd9Wd+6FrdbXkOyvu09dQAlRYKypJQlNSipacnWlnGehaS2JatXUosSrCUISta1lMOWlG3btmHUqFGqHFqyriUVL6lpyV4uDR4eHg7rkuqVYF5SOnbsqMbGXrVqlcpRGDlypEpBL1myRN20yE2DbJey+ieeeMKWo5H/uMpFiloK8mfNmoXdu3dj165dqgLBnXfeqcoJnK52qHz9anzq9sGWu7sj5xORnpXj7CMjorIkZcbFnfTyaSHLss2+fLqw170OEkhkfGfJOpZsY8kO18urZShJ+V3917/+pVKrkhI8duxYkV9bxoeW8llpRqXbvn27wz5bt25V2cNSTi41zSXbOCIiwvHjenqq1P213kvKe6WsWrdlyxb12SR1WxKknF5yB/IPsSnrUlHOfj8p+/78889VboGUTcfFxanHJCtfsuOlLHvDhg3qRkXK5ctlilo+qL3XXntNpbDlIpDUtVP51QZeiAC8/VFH01CzsicuJmfi0LlEdKpfzbnHRkRkR7KaJahMmTJFZe1K1q1OgqakBCWYStnuu+++iwsXLjgEpcJISlJqc48ZM0alHOX1JSDbk/eQbG5JRXfp0kVVWJMsYXtSbi2pVMlSlkSaVDTL3yxLUuXTpk1T7yU1q2NjY1VOgVR+08unS8Kzzz6r3kdyHqQSmuRCyHEtWLBAPS7nSLLTpaKZ3CRI5TzJ0q9ataqq1CY3HN26dVM13L/99lsVuO3LscttGbV8cPmS5U6qoPIPkZGRoS4SfUpKSirdg/K2lE3LnSk7PiEiI5Ps78uXL6usZ/vyZCkrlqxc2S6VzSTgSPOmopJAJUFXymWl0pTUwpZElT2pMf3000+r2tkS+OSmQJpn2ZPKbdI5yy233KKalBXUREwCn5SfS8pVAv7dd9+N/v37q4pjJUnKnSdNmoRnnnlGFQdIbXSp5S03HEJuIt58802
VOyDHIcWzK1euVOdCgrWksqVMW9qoSxb4L7/8opqJlRaTJo3CnEiyCyQwS00/uSuUrJtbb721wH3lDkvKQPKTbBm5Qys1uTn48I+TeGftMdzZPhjv/7ND6b0XEZU5+f2R1F7Dhg1Vu1mi0r6upJMaKe8uSvxyeopayh0ky0Ha5z3++OMqy+Pw4cMF7ivZOtKGT5+utl+JSY0DvroVeLMR2gdbeulhV6JERFSWnN6OWioYNGnSRC1Lo3xp5C41FKUhen5SnmFfpiHZ36WqUjUg5giQHo8OHpZmBqcvpeJySiaq+XqW7nsTEREZIUWdn1Sxl7JoQ5Bak9bxqStf2IlGNS01MvdFcSQtIiKqAIFasrI3bdqkCuqlrFrWpaq71Pwz4gAdoSH6SFrF6zmIiIjIJbO+pScZ6adW2udJA3mpQSc1/gYOHOjMw7pKoN6G0J5VsHQveygjIqIKEqi/+OILGJ50fOLuDaTFoZu/pbH7vqgENYKK3qEAEZUPJdm7FVFuCV1PTq9MZnjunkCdzkDEZjRJOwAPt1qIS8lE1OU0hFQvm/Faiaj0K7VKG1npA1ra+Mo6b8TpeklCTrpolQ5b5LqS6+lGMFAXhVQoi9gMj7M70Kr2v1SKem9kPAM1UTkhP6bS1lWK4SRYE5UE6cBFRteS6+tGMFAXs5y6fYPxKlBLe+o7Qq9j+DsiMiRJ9ciPqoyEdK0+qYmuRUb3kmE9SyJnhoG6KEK6Wka+uXwa3bpk4mt2fEJULsmPqoyAVFqjIBGVi3bUhuTtBwRZBgnpiKNqfuBsArJyWPGEiIhKFwN1MbO/g+L3ws/bHRnZuQiPLuVBQYiIqMJjoC4qaw9lpqhdeR2fsIcyIiIqZQzURdW4HzB2lZr0IS/DzrArUSIiKl2sTFacATrq91CL7ZmiJiKiMsIU9XVoF+Kv5n/HJCM5I7ukvxMiIiIbBuriiDsJrHwWgRueR52qlaBpwH6WUxMRUSlioC6OnCxgx2fAvkXoWLey2sSRtIiIqDQxUBdHzWZAjwnAiLkIrVtFbZIeyoiIiEoLK5MVh3QFN+i/arHtyUsATiKMgZqIiEoRU9TXqW1df5hNQHRiOqIT0kv2WyEiIrJioC6unGzg9Gb47PgIzQKt5dSsUEZERKWEgbq4tFzg27uA36dhQJClC1FmfxMRUWlhoC4ud0+gTme1eLPXcTVnhTIiIiotDNQ30O93s4wDar4/KgG5uVqJfjFERESCgfoGRtLyj92NSh5uqneyE7HJvKKIiKjEMVBfj5Au0lYLpsun0Lu2pQtRllMTEVFpYKC+Ht7+QFAbtTjEL0LNWfObiIhKAwP1DZZTd9COqjlT1EREVBoYqK9XfUs5dZ2kMDU/ej4J6Vk5JfbFEBERCQbq6xViSVF7xB5EPd9cZOdqOHQukVcVERGVKAbq6+VfB6haDyYtF8MDzqpNzP4mIqKSxkBdAs20enmy4xMiIiodDNQl0vHJQTVnzW8iIippDNQ3on5PoEFveDfrp1YjLqUiLiWzhL4aIiIiBuobE9AceHAFvPs/h0Y1fdUmpqqJiMjpKerIyEhERUXZ1nfs2IGnnnoKn332GSqq9iFV1ZwDdBARkdMD9f33348//vhDLUdHR2PgwIEqWL/44ot45ZVXUOGkxqGfH2t+ExGRQQL1wYMH0bVrV7X8ww8/oE2bNti6dSsWLFiA+fPno0I5txd4syEG75sgg1WrFLWmcSQtIiJyYqDOysqCl5eXWv79999xxx13qOUWLVrg/PnzRX6dmTNnokuXLqhSpQoCAwMxfPhwhIeHw6UEtgLcveHmWx213JJxOTULkXFpzj4qIiKqyIG6devW+OSTT/Dnn39i7dq1GDJkiNp+7tw51KhRo8ivs3HjRowbNw7bt29XryM3AIMGDUJKSgpchrsXMPkYzBN2ISg4RG0Ki4p39lEREVE54X49T3rjjTcwYsQIvPXWWxgzZgxCQ0PV9uXLl9uyxIti9erVDuuSbS4p6927d+Pmm2+GS42mJRXK6vqrrO+wM/G4IzTY2UdFREQVNVD37dsXFy9eRGJiIqpVq2bb/uijj8LHx+e6DyYhIUHNq1evDlfUoW5lfM0mWkRE5Oys77S0NGRkZNiCdEREBGbPnq3KlyVFfD1yc3NVE6+ePXuqymkFkfeUmwN9SkpKgiHk5gDfDMcdq7ojAPE4eDYBWTm5zj4qIiKqqIH6zjvvxDfffKOW4+Pj0a1bN7zzzjuqMtjcuXOv60CkrFpqky9cuLDQymf+/v62qVWrVjAEsxuQEgtzdip6ex9HRnYuwqMNchNBREQVL1Dv2bMHvXv3VstLlixBUFCQSlVL8P7ggw+K/Xrjx4/HihUrVNvsunXrXnW/KVOmqOxxfTp8+DCMNkDHoMqn1JwjaRERkdMCdWpqqmpSJX777Tf84x//gNlsxk033aQCdlFJe2MJ0kuXLsX69evRsGHDQveXJmF+fn62ST8GIw3Q0V47oubsoYyIiJwWqJs0aYJly5aprkTXrFmjmlSJmJgYFUCLk9397bff4rvvvlNBV3o5k0nKwF2ONUUdlHoMvkhjipqIiJwXqKdOnYrJkyejQYMGqjlW9+7dbanrDh06FPl1pDxbsrClFnnt2rVt06JFi+By/OsA/vVg0nLR3nwcx2OTkZSe5eyjIiKiitg86+6770avXr1UL2R6G2rRv39/1b66qMpdV5uS/X3gDPpVOoktKW1x4GwCejSu6eyjIiKiijgeda1atVTqWXoj00fSktS1dCNaYdW35Cz09PxbzVmhjIiInBKopc2zjJIlTaTq16+vpqpVq+LVV19Vj1VY1nLqxhmH4Y5sVigjIiLnZH3LcJZffPEFZs2apTooEZs3b8b06dORnp6O1157DRVSzeaAd1V4pMejlSkC+yIrO/uIiIioIgbqr7/+GvPmzbONmiXatWuHOnXq4Iknnqi4gdpstpRTH1uNrm7HMC+xMaIT0lHL39vZR0ZERBUp6zsuLq7AsmjZJo9VaNb21H0rnVBzllMTEVGZB2qp6f3RRx9dsV22Scq6QqvXQ83aQiqUadjHIS+JiKiss77ffPNNDBs2DL///rutDfW2bdtUBygrV65EhRbcHhjzC347HwgsP6GGvCQiIirTFHWfPn1w7Ngx1WZaBuWQSboRPXToEP73v/+hQnP3AhrejLaNLONRS1vqnNxy1l6ciIiMnaIWwcHBV1Qa27dvn6oN/tlnn6GiaxpYBT6ebkjOyMbJ2GQ0DTJQv+RERFT+OzyhQiTHwO23F/Glz4dqdW8ks7+JiOj6MFCXBjcPYPvHuCl9M2oigR2fEBFR2Wd9UyEqVQNu+Q/2plRD6iYv1vwmIqKyCdRSYawwUqmMrPo8h8D4NKRuWo+j55OQnpUDbw83nh4iIiq9QC19e1/r8QceeKB4R1COBft7I6CKF2KTMnDoXAI61a/u7EMiIqLyHKi/+uqr0juS8kbTYIrcgecrr8bUpO4Ii2SgJiKi4mMZdWkxmYAfH8LdCZH4yRyAsMhGpfZWRERUfrHWdxn0+93VfJ
Q1v4mI6LowUJdBoO5sCseZuFTEpWSW6tsREVH5w0BdmupZ+kHv5HYC7shmqpqIiIqNgbo0BbQEvP1RCeloaTrDIS+JiKjYGKhLk9kMhFiyv7uYw9nxCRERFT+U8JyVUTm1BOrIeGgaR9IiIqKiY6Auo3LqruZwXE7NVJXKiIiIioqBurTV6Qi4eaGmKQENTNEspyYiomJhoC5t7l6WYG0tpw7jkJdERFQMDNRl2p76GJtoERFRsTBQl2E5tVQoO3guEVk5uWXytkRE5PoYqMtCSDdoITdho7krsrOz1bCXRERERcFAXRYqVYXpoTXYUG88cmFGWBTH7SYioqJhoC5D7etaxvOW9tRERERFwWEuy1DHIDd0MknHJ5XL8m2JiMiFMVCXlaRo9FnaGb08TWgXOw9J6Vmo4u1RZm9PRESuiVnfZaVKLZj86uCCOQC1cQkHohLK7K2JiMh1OTVQb9q0CbfffjuCg4NhMpmwbNkylGtPbMPrTRfihFaHFcqIiMj4gTolJQWhoaGYM2cOKgRvP3QIqaoWw86wQhkRERm8jHro0KFqqkhCQ6rCjFzsj4xz9qEQEZELYBl1Gevw1yTs83oENZPDEZ2QXtZvT0RELsalAnVGRgYSExNtU1KS6/Xw5Z6dgiqmNOsAHZedfThERGRwLhWoZ86cCX9/f9vUqlUruOwAHSpQs+Y3ERGVo0A9ZcoUJCQk2KbDhw/DVQfo6GI+hn1nmKImIqJy1OGJl5eXmnSS/e1ygjsi1+yJwNx4XD4bjpzcm+BmNjn7qIiIyKCcmqJOTk5GWFiYmsSpU6fU8pkzZ1BueXjDVKejWmydfQQnYpOdfURERGRgTg3Uu3btQocOHdQkJk2apJanTp2K8sxkLafuYj6KMA7QQURERg3Uffv2haZpV0zz589HuWYrp5YKZez4hIiIykllsnIjpKuaNTafx6mICGcfDRERGRgDtTP4VEdWjeZq0T92N9KzcpxyGEREZHwM1E7i3qCHmnc0hePgWbanJiKigjFQO4mpviVQs5yaiIgKw0DtLNaa361METh0JsZph0FERMbmUh2elCv+ITjY/xv889dMVDub6uyjISIig2KK2llMJtTrciuS4YPIuDRcSs5w2qEQEZFxMVA7kZ+3BxoH+Krl/VGsUEZERFdioHamjGS86LkQCzxeQ9iZi049FCIiMiYGamfyqIReCb+gp9shXD6516mHQkRExsTKZM5kdkNs52fw7qYL2HXBW3WfajJxJC0iIsrDFLWTBQx4Cr+Y+yIizRsRl1j7m4iIHDFQO5mnuxmtg/3U8r4oDtBBRESOGKgNYEj1C/i32yqcOHHM2YdCREQGwzJqA7g75gPU8NiLD04HALjF2YdDREQGwhS1Abg3sIxPXSshDJnZuc4+HCIiMhAGagPwa9ZbzTviKMKjk5x9OEREZCAM1AZgsg7Q0cR8DkdOnHD24RARkYEwUBuBT3Vc9GmkFlP+3ursoyEiIgNhoDaIjNpd1dw3ZqezD4WIiAyEgdog/Jpbyqmbph9EYnqWsw+HiIgMgoHaIKo0tQTqNqZTOHQ62tmHQ0REBsFAbRRV6+Gye014mHKweNlSfLH5FOJSMp19VERE5GQM1EZhMiEzuJtafD3tFbRccz+envkenliwGxvCY5CTqzn7CImIyAkYqA0k6JbHkesbCG9TFnq4HYYpNwsrD0Tjwa924tGZn2DbvKcRfXCjsw+TiIjKELsQNZKGvWGefAyIDQcituCFgCFosD8ey8LOolPaVnSPWo7FC4/jp3peuLdLCIa0CoB35J9ASDfAq7Kzj56IiEoBA7XRyHjUgS3U1ALA9AZ18MLQFti/Phab96VgXUJbbDt5SU1dvc/gB7wAzeQGBLeHqX5PoEEvQDpQ8fZ39ichIiq6zBQgJwvIzbZMalmmnLztMs/JtE5ZQEAzVb9HSTwHHFsDeFUB2t6d97rb5gBJ0dbn2z33imW7bZ3GAp3GWJ4fcwT4+nagUjVgvHOazzJQuwBvDzd0HTwKGDwKDePT0HJXFBbvjoR3wmWccQ9APXMscHa3Zdr6AWAyA0FtLEFbgnf9HqpTFSKiAkmQSosHvP0Ady/LtguHgFObLIGwxTDLNk0Dfnz4ygAq6znWAKsey3ZcH/oW0GyQ5TWO/AL89H9AvW7A6KV5x/BeGyAtrnhf0K1vA10fsSxfOg6seAoIaOEYqHd/DVwML97rNrUeq2ICUmItn9dJGKhdTJ2qlTBxQFNM6NcE2062w9s7hyLs0EF0zD2MbuYjuMntCBoiGojeb5m2f2x5YmBroEFPoOHNQMvbnf0xyJ78+KXGAUnnLHf+kjJIOg+kJwCelS25I63uBKqGWPaXfdPjAZ8a5T/nRH7szW6WnCYqxnnLAiL/AtIuFzLFW6fLQKZ1jIEHf7Xc4Isz24DVLwAtbssL1PI9HPoJ0Io5eFBGouP1npUCZKY67mO2C0eSSyjrbh6W79/sYV33BNw9LXN5TFK5Ot9AoPkwwL+u4+u2vw9IuWh9jvV5V122rtdokvf8ag2Ax7cC7t5wFgZqF2U2m9CzSU01xae2xvJ95/DtzkhMOZeIQFxWQfsW77/Rx+sYaqSdAmIOWSb557UP1CfWW4J4lSBnfpzyS36McjLyflDkh3HjG9ZgHJ0XnCW7rTB1OuYF6oM/AisnW4L3yG8s23JzgU97W7L9JHjnn7z8Cthe1TJ3cy+ZwCBZl/pUpZYldSbiTgIRWwHfAKDZ4LznrHjacjNi/zx9Uj/kkhWaCXj4Wn585fN3ewxoOtB6blOA1EtAldqWH9jyRIKZnJvkGCA52pKisw+wch7062HXV5ZrSoLpsHcs27LTgfnW4Foc6XYBtWZzoPU/gLpdHPcZPNOSayfXjQRPCaL2AfWKZXegeuO85zfuBzwZBnj6Or7uU/st+6ogfR31nANbAPd9d+X2Xk/jhnh4A0Gt4UwM1OVAVR9PPNC9gZoOnk3A4l2RWBYWiF9SewCpQA0kYHTwWdzhfxIhjVvCwz6ILBhpyZp66mDeP778EMgPvvyz0dVTeikxQOJ5S+pXJgm+kg3nF2zZ5893gHWvAB0fAO740LJNzqmey5GfT03Ar7Yl8MgkwT0z2fKDrb+meu8swMPHMTUt+104WPxv676FQPOhlmUp3/vzXVWpEf1eynuvXydZA2iq5X1swdRuPf+Nxj+/y0uFRe0Cfh4HNOrrGKjlhkM+27VI0JasS5na3Zu3/fQW4Lt7gFrtgMf+zNu+9SPAoxLgH5IX4OV6NgK5oZIg6uljWc9IAv76xBKQ5YYt+YJ1HgNkp139dZoNyft/1XKs16BdR0mSExPQ0vK55TpymKpeua2gmza5DmTK76bHbuwcSMXXgiq/yndGBWKgLmfa1PFX05RbW+K3wxdU0N58HJh9zh+zz7WC3yl33HnxoKo13sb7IhDY0hKY7bOLFj8InN5sKZuq3tCS9VPNOpf1qvXLfy3z7Ezg9CZrILZL+eopYQnSBWX/NeqTF1Ql8OpZ1Tr54ez9jCV1qQdkCc6Va1my9Iqi+xOWS
X70dZItN3qZJfAVNkkWpL4sQdY+2MefASK3O+auSApnz/8kGhTt2GR/SSnZl+dJwJQyPwmo9vr+x/K6sr9MknLWl23bfCypSDm2hChLCwedZP9LNqX9tSsp0T9etwR3e/I5VeC2Bm89gOvrcv6vJxWny86wBFlbwJVr5gLQeWze9SA3EGunAh1G5d24Sfnn+v9e/XUlJ6RyEFA50DGwSo6FruUdllSv7KeTLOpx26//85ChmDRNrmzXFBUVhZCQEERGRqJu3XzlEmQTGZeKJbuj1HQ2Pu8uvVVtP4zsXBfD29ZEVT+7FMdHXYCLxwo/g1IepAdumTcZCITkyyIzYvDVU79yg6IHKancsn0uUK870P9lyzZJJb5ul4otiGTRyQ+mHmxlLrVFg1rlvYYELD0L2GgktSxZmHrOyeXTwPl9lpsIqYCo2zzbEhDzB1E1VbYEU325qDcbJUVuViQo6ylm+Y5/ewlIiLROUZZAfy0jPgNCran16APA4eVA7VCg5W2WbVnpwNEVdineC9ZlmUdf/T0e+NmSkyB2zwd+mQg0HQyM+iFvH9km9Q1UQA6yXFP6sp7yLiMSDqIT03EiJgXHY5JwIjYFJ2KT1SSRopa/N2r5eVvmdsu1/Sup5UqezIUrjfjFQF2B5OZq2HLiIhbtjMRvhy4gM8eSIvN0M2NQ6yCVyu7ZuCbMyLWkHOWH+/IpyzzuVN56QT9KA18Bek7Ma86weKzlh+4fn+btk3DW8oMkZT4lSX5B5JhU9rME4nMFz1Mv5j1HaptKWZmQFOPy8UCTAcC/fszbZ95ASwBQQTjY8gMqqSM9Jexbk8UDrkCyl+Xa04N3vDWAqynScq0/uCLv5mTnF5bs/mZDgfsXFv3GTcpkVaCVICs3cEFAl4fzyjf18ni5EXJymXpmdi4iLlmC8PGY5LyAHJOMlMzrr93sX8kjL5DbB3Trem1/b7WPiZUDUZxAbYis7zlz5uCtt95CdHQ0QkND8eGHH6JrV8uwj1SyFdB6Nw1QU3xqJpbtPYtFu6Jw5HwiVuw/ryapVT6wVRB8PN3gZg6E2RQEN3N3uNU0wRxggpsZ8M5ORtWMs/BLPwu/tCg1nUmuj7jdUerx4PM70C32COKz3fHXoWi4meR5JnT7dRgqJZ1Gpk8QMqrUR2aVesj0q48sv3rI9q+PbP8GQKXqcHMz2eKv9JyqQVMJJ8/L4fCN3IRMn1qIbzhMZcZqWWkIXRAKs1TYKoJcswcyKgXh6OlYXMy4gFxNg3dWC/h3fRvJlRsgfv859b7y2h7dvoGHmxke7mZ4uJnUDY1ahxmeaWZ4ZmbAw91k2eZmtj5ugrucBDIOudmy9k1w1foG9oGjZjOg878tTRx1kmMgN3KSFa2neNU80BqUa1mypAsLQHolvjKUkJZlC8ASjCUon4xNRkRc6lW7JZb/1fo1fNAkoDIaB1ZGY5kH+Kpr/HxCOqIT0lSqW5YvWOfRCelIzcxR7ydT+AVrLfICeHuY8wVzSY17WeYqde6NmpW91HGQQVLUixYtwgMPPIBPPvkE3bp1w+zZs7F48WKEh4cjMDCw0Ocy67tkSAW0H6QC2t6zSEzPvuHX80cy2plPQoMJm3PbWrdq2OP1f6huSi70uYlaJURqlu89yHQZT2WNs73GPW4b8JbHZ9iY0w5jsl6wPWef18PwN6UiTquMC1p1XNCqIVqrhguorubRtm3VEQfJIi3dHwD5fbEFbmuQ19c91brdNrt1T3c3xxsC6+Ne7gXPPd3cHNety15XeczdbGJKphySn3AJlpaUcbJDKjk26eo3sJW93FUAbmwXkJsE+qJedV91vRT3GJIysnEhwRq4Ey3B2za3Lhd1oCEJ0oFVvBxS5kF+3vD1tFzX9v87lv8Fu3W1LP9HbupGWv8/VNvdzCrBYgQulfUtwblLly746KOP1Hpubq46+AkTJuCFF/J+jAvCQF2y0rNysOZQNA5EJSA7V1OpzRy7ueSU25Y1SeXqj6GAfe2WJWUs85xcVMmNR1D2eQTlnketnGjUzr2A4NxoBGvRCMSVnR1MNz2BFW79VUKljfY3RuUuxxFTEyz0GKG2yRSkXUKiyQ/ZZi8Vg80mkwrFaq72sa6b5WGTCqTygMwL3M+aKpJjl+KBLOsk2YVZOXbbsnOty65RzUM+b96PmlteUL/iBkCW3RyCuzo/1jpKcpb0c6/WCnpMrevL+faxnmcU9Jjd96NvMwIJAHIu7G+y3O2WLesmeJjzclU89W0Oz7EEDttzzUUPHBnZOYi4lGoJwtagfDxWUsgpKjV7NUF+XmhiSxlLMLbMZXtZZ0HLb0xMYgbOW1PlDsHcOo9JyijVQYjcrd+jHtjlurfcKJsL2J53w63fANwRWgfdG9eoOIE6MzMTPj4+WLJkCYYPH27bPmbMGMTHx+Pnn3922D8jI0NNurNnz6JVq1asTFZeZKVZavdKWbiQ8mCpqGaUpjVXIf9CEqxtAd0uqOcFeGuQty7nvwHIlH1t+1nW1fbsXPUDbdlHX891eCyjkMc46JrxSepRgodjYM8L+LKclpmNM3GpV/0+3fXs6nwBuVGAL6p4u1YbcwnSF5MzVNC2z16PSUxHWlaO+v+Q69zh5tm6btteijfSrw5vg9E31a84ZdQXL15ETk4OgoIcO9uQ9aNHj16x/8yZMzFjxowyPEIqU9KOMqC5ZXIhkipRWW3FzC4sC9k5VwbxjGLcAMiPpvzMSe6IfksvNyZ6Ob5lnrcuCwVt19etf1d/jQKe7+xUtRyDnruSbbshs8yzcyUgaMiSufXxAvfTb9SsFTjt6TlQcr6vRWVXq2DsawvKMq9X3UcF9PJy4yLZ3DKFWpuK3wi5pjLz3yjbXfdXrhe+X/u6VVHWDFGZrKimTJmCSZMmXZGiJqKCSYpMJp8ybjVFVw8aEpSlaMkxoFsChH5jlT/QS0pbgrKU27LGdPHI+ZKiHC+JdtZuzF2NUwN1zZo14ebmhgsXLjhsl/Vatewa9Ft5eXmpSZeYaNfdHRGRCwQNCbrubpbBdoiKwql5JZ6enujUqRPWrVtn2yaVyWS9e/fuzjw0IiIiQ3B61rdkZUvlsc6dO6u209I8KyUlBWPHjnX2oRERETmd0wP1vffei9jYWEydOlV1eNK+fXusXr36igpmREREFZHTA7UYP368moiIiMhR+ajPT0REVE4ZIkV9vaTimTh//ryzD4WIiKjI9Lilx7FyG6j1Zl0cwIOIiFw1jtWrV8/YfX3fiOzsbOzdu1dVPDPfyKDvVklJSaoDlcOHD6NKFWN3W2kkPG88d7zmXAf/X41x3iQlLUG6Q4cOcHd3L7+BuqRJByr+/v5ISEiAn5+fsw/HZfC88dzxmnMd/H91vfPGymREREQGxkBNRERkYAzUdqQf8WnTpjn0J07XxvN2/XjueN7KGq851ztvLKMmIiIyMKaoiYiIDIyBmoiIyMAYqImIiAyMgdpqzpw5aNCgAby9vdGtWzfs2LHDud+MC9i0aRNuv/12BAcHw2QyYdmyZc4+JJcwc+ZMdOnSRXWaEBgYiOHD
hyM8PNzZh+US5s6di3bt2ql2rDLJuPWrVq1y9mG5nFmzZqn/2aeeesrZh2J406dPV+fKfmrRokWZHgMDNYBFixapcbGlRt+ePXsQGhqKwYMHIyYmpky/DFcj44bLuZKbHCq6jRs3Yty4cdi+fTvWrl2LrKwsDBo0SJ1PKlzdunVVkNm9ezd27dqFfv364c4778ShQ4d46opo586d+PTTT9UNDxVN69atVd/c+rR582aUKemZrKLr2rWrNm7cONt6Tk6OFhwcrM2cOdOpx+VK5FJaunSpsw/DJcXExKjzt3HjRmcfikuqVq2aNm/ePGcfhktISkrSmjZtqq1du1br06ePNnHiRGcfkuFNmzZNCw0NdeoxVPgUdWZmpro7HzBggO3mRfoNl/Vt27aV7V0TVUjSJaGoXr26sw/FpeTk5GDhwoUqJ0KywOnaJCdn2LBhDr93dG1///23KuJr1KgRRo0ahTNnzqAsufToWSXh4sWL6h9eBvawJ+tHjx512nFRxSAd80s5Yc+ePdGmTRtnH45LOHDggArM6enpqFy5MpYuXaoGS6DCyU2NFO1J1jcVndRZmj9/Ppo3b66yvWfMmIHevXvj4MGDZTZ4U4UP1ETOTuHIP3yZl3m5MPnBDAsLUzkRS5YswZgxY1S5P4P11UVGRmLixImqToRUmKWiGzp0qG1ZyvUlcNevXx8//PADHnroIZSFCh+oa9asCTc3N9vY1jpZr1WrVpl8CVQxjR8/HitWrFC156WSFBWNp6cnmjRpopY7deqkUojvv/++qiBFBZPiPakc27FjR9s2yUmUa++jjz5CRkaG+h2ka6tatSqaNWuG48ePo6xU+DJq+aeXf/Z169Y5ZEfKOsu9qDRI3TsJ0pJlu379ejRs2JAn+gbI/6sEGrq6/v37qyIDyYnQp86dO6vyVllmkC665ORknDhxArVr10ZZqfApaiFNsyT7TC7crl27Yvbs2aqCytixY8vsi3DVC9b+rvLUqVPqn14qRdWrV8+px2b07O7vvvsOP//8syrjio6OVttlrNtKlSo5+/AMbcqUKSorUq6vpKQkdR43bNiANWvWOPvQDE2us/x1IHx9fVGjRg3WjbiGyZMnq/4iJLv73Llzqhmv3Njcd999KCsM1ADuvfdexMbGYurUqepHs3379li9evUVFczIkbRjveWWWxxueITc9EjlC7p6px2ib9++Dtu/+uorPPjggzxthZDs2wceeEBV6pEbGykzlCA9cOBAnjcqFVFRUSooX7p0CQEBAejVq5fqA0GWywpHzyIiIjKwCl9GTUREZGQM1ERERAbGQE1ERGRgDNREREQGxkBNRERkYAzUREREBsZATUREZGAM1ERERAbGQE1EN8xkMmHZsmU8k0SlgIGayMVJt6MSKPNPQ4YMcfahEVEJYF/fROWABGXpK9yel5eX046HiEoOU9RE5YAEZRk/3X6qVq2aekxS1zIQiIw6JaNzNWrUCEuWLHF4vgyB2K9fP/W4jKj06KOPqtHR7H355Zdo3bq1ei8Z4k+G6rR38eJFjBgxAj4+PmjatCmWL19ue+zy5ctqSEUZyEDeQx7Pf2NBRAVjoCaqAF5++WXcdddd2LdvnwqY//znP3HkyBH1mAzpOnjwYBXYd+7cicWLF+P33393CMQS6GV4TgngEtQlCDdp0sThPWbMmIGRI0di//79uPXWW9X7xMXF2d7/8OHDWLVqlXpfeb2aNWuW8VkgclEaEbm0MWPGaG5ubpqvr6/D9Nprr6nH5d/8sccec3hOt27dtMcff1wtf/bZZ1q1atW05ORk2+O//vqrZjabtejoaLUeHBysvfjii1c9BnmPl156ybYuryXbVq1apdZvv/12bezYsSX8yYkqBpZRE5UDMi64Ps61rnr16rbl7t27Ozwm62FhYWpZUrihoaHw9fW1Pd6zZ0/k5uYiPDxcZZ2fO3cO/fv3L/QYZGxonbyWn5+fGj9aPP744ypFv2fPHgwaNAjDhw9Hjx49bvBTE1UMDNRE5YAExvxZ0SVFypSLwsPDw2FdArwEeyHl4xEREVi5ciXWrl2rgr5kpb/99tulcsxE5QnLqIkqgO3bt1+x3rJlS7Uscym7lrJq3ZYtW2A2m9G8eXNUqVIFDRo0wLp1627oGKQi2ZgxY/Dtt99i9uzZ+Oyzz27o9YgqCqaoicqBjIwMREdHO2xzd3e3VdiSCmKdO3dGr169sGDBAuzYsQNffPGFekwqfU2bNk0F0enTpyM2NhYTJkzA6NGjERQUpPaR7Y899hgCAwNV6jgpKUkFc9mvKKZOnYpOnTqpWuNyrCtWrLDdKBBR4RioicqB1atXqyZT9iQ1fPToUVuN7IULF+KJJ55Q+33//fdo1aqVekyaU61ZswYTJ05Ely5d1LqUJ7/77ru215Ignp6ejvfeew+TJ09WNwB33313kY/P09MTU6ZMwenTp1VWeu/evdXxENG1maRGWRH2IyIXJWXFS5cuVRW4iMj1sIyaiIjIwBioiYiIDIxl1ETlHEu3iFwbU9REREQGxkBNRERkYAzUREREBsZATUREZGAM1ERERAbGQE1ERGRgDNREREQGxkBNRERkYAzUREREMK7/BxDny/Ual2+MAAAAAElFTkSuQmCC",
"text/plain": [
"<Figure size 500x300 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from previous_chapters import plot_values\n",
"# Alternatively:\n",
"# from llms_from_scratch.ch06 import plot_values\n",
"\n",
"epochs_tensor = torch.linspace(0, num_epochs, len(train_losses))\n",
"examples_seen_tensor = torch.linspace(0, examples_seen, len(train_losses))\n",
"\n",
"plot_values(epochs_tensor, examples_seen_tensor, train_losses, val_losses, label=\"loss\")"
]
},
{
"cell_type": "markdown",
"id": "aa074723-e3f7-4f7e-a267-855531a037dc",
"metadata": {
"id": "aa074723-e3f7-4f7e-a267-855531a037dc"
},
"source": [
"- Note that we previously calculated the accuracy values on 5 batches only via the `eval_iter=5` setting; below, we calculate the accuracies on the full dataset"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "1D2awlEq0gZi",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "1D2awlEq0gZi",
"outputId": "d603eda1-d912-43eb-ec9c-af6a622510a0"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training accuracy: 99.81%\n",
"Validation accuracy: 97.99%\n",
"Test accuracy: 96.67%\n"
]
}
],
"source": [
"train_accuracy = calc_accuracy_loader(train_loader, model, device)\n",
"val_accuracy = calc_accuracy_loader(val_loader, model, device)\n",
"test_accuracy = calc_accuracy_loader(test_loader, model, device)\n",
"\n",
"print(f\"Training accuracy: {train_accuracy*100:.2f}%\")\n",
"print(f\"Validation accuracy: {val_accuracy*100:.2f}%\")\n",
"print(f\"Test accuracy: {test_accuracy*100:.2f}%\")"
]
},
{
"cell_type": "markdown",
"id": "1f87f5e6-339e-4fcf-900b-6d845d3c713d",
"metadata": {
"id": "1f87f5e6-339e-4fcf-900b-6d845d3c713d"
},
"source": [
"- As we can see based on the relatively high accuracy values above, the LoRA finetuning was successful"
]
}
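  {
   "cell_type": "markdown",
   "id": "3f2b8c1e-9d4a-4b7e-a2c5-6e1f0d8a7b3c",
   "metadata": {},
   "source": [
    "- Folding the LoRA updates into the base weights lets inference run on plain `Linear` layers, with no extra LoRA matrix multiplications at runtime\n",
    "- The cell below is a minimal sketch, assuming the `LinearWithLoRA` modules defined earlier expose `linear`, `lora.A`, `lora.B`, and `lora.alpha` attributes and compute `linear(x) + alpha * (x @ A @ B)` in their forward pass; under that assumption, the merged weight for a `Linear` layer with weight shape `(out, in)` is $W + \\alpha \\, (AB)^\\top$\n",
    "- The helper name `merge_lora_weights` is ours for illustration and is not part of the chapter code"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a5e2d90-1b3c-4f6a-8d2e-c4b9f0a1e5d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sketch: fold each LoRA update into its frozen base layer in place\n",
    "# (assumes the LinearWithLoRA/LoRALayer definitions from earlier cells)\n",
    "def merge_lora_weights(model):\n",
    "    for name, module in model.named_children():\n",
    "        if isinstance(module, LinearWithLoRA):\n",
    "            # W_merged = W + alpha * (A @ B)^T, matching the assumed\n",
    "            # linear(x) + alpha * (x @ A @ B) forward pass\n",
    "            with torch.no_grad():\n",
    "                delta_w = module.lora.alpha * (module.lora.A @ module.lora.B).T\n",
    "                module.linear.weight += delta_w\n",
    "            setattr(model, name, module.linear)  # drop the LoRA wrapper\n",
    "        else:\n",
    "            merge_lora_weights(module)  # recurse into child modules\n",
    "\n",
    "# merge_lora_weights(model)  # uncomment to merge in place, e.g. before saving"
   ]
  }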
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuType": "V100",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.13.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}