{
"cells": [
{
"cell_type": "markdown",
"id": "c024bfa4-1a7a-4751-b5a1-827225a3478b",
"metadata": {
"id": "c024bfa4-1a7a-4751-b5a1-827225a3478b"
},
"source": [
"<font size=\"1\">\n",
"Supplementary code for \"Build a Large Language Model From Scratch\": <a href=\"https://www.manning.com/books/build-a-large-language-model-from-scratch\">https://www.manning.com/books/build-a-large-language-model-from-scratch</a> by <a href=\"https://sebastianraschka.com\">Sebastian Raschka</a><br>\n",
"Code repository: <a href=\"https://github.com/rasbt/LLMs-from-scratch\">https://github.com/rasbt/LLMs-from-scratch</a>\n",
"</font>"
]
},
{
"cell_type": "markdown",
"id": "58b8c870-fb72-490e-8916-d8129bd5d1ff",
"metadata": {
"id": "58b8c870-fb72-490e-8916-d8129bd5d1ff"
},
"source": [
"# Appendix E: Parameter-efficient Finetuning with LoRA"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "5b7e01c2-1c84-4f2a-bb51-2e0b74abda90",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "5b7e01c2-1c84-4f2a-bb51-2e0b74abda90",
"outputId": "316166b4-027a-4756-e9b4-fe88ae75dd4f"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"matplotlib version: 3.7.1\n",
"numpy version: 1.25.2\n",
"tiktoken version: 0.7.0\n",
"torch version: 2.2.1+cu121\n",
"tensorflow version: 2.15.0\n",
"pandas version: 2.2.2\n"
]
}
],
"source": [
"from importlib.metadata import version\n",
"\n",
"pkgs = [\"matplotlib\",\n",
" \"numpy\",\n",
" \"tiktoken\",\n",
" \"torch\",\n",
" \"tensorflow\", # For OpenAI's pretrained weights\n",
" \"pandas\" # Dataset loading\n",
" ]\n",
"for p in pkgs:\n",
" print(f\"{p} version: {version(p)}\")"
]
},
{
"cell_type": "markdown",
"id": "21532056-0ef4-4c98-82c7-e91f61c6485e",
"metadata": {
"id": "21532056-0ef4-4c98-82c7-e91f61c6485e"
},
"source": [
"## E.1 Introduction to LoRA"
]
},
{
"cell_type": "markdown",
"id": "66edc999-3d91-4a1c-a157-9d056392e8d8",
"metadata": {
"id": "66edc999-3d91-4a1c-a157-9d056392e8d8"
},
"source": [
"- No code in this section\n",
"- Low-rank adaptation (LoRA) is a machine learning technique that modifies a pretrained model to better suit a specific, often smaller, dataset by adjusting only a small, low-rank subset of the model's parameters\n",
"- This approach is important because it allows for efficient finetuning of large models on task-specific data, significantly reducing the computational cost and time required for finetuning"
]
},
{
"cell_type": "markdown",
"id": "5bb75b5d-d59c-4948-821a-1594a5883dc1",
"metadata": {
"id": "5bb75b5d-d59c-4948-821a-1594a5883dc1"
},
"source": [
"- Suppose we have a large weight matrix $W$ for a given layer\n",
"- During backpropagation, we learn a $\\Delta W$ matrix, which contains information on how much we want to update the original weights to minimize the loss function during training\n",
"- In regular training and finetuning, the weight update is defined as follows:\n",
"\n",
"$$W_{\\text{updated}} = W + \\Delta W$$\n",
"\n",
"- The LoRA method proposed by [Hu et al.](https://arxiv.org/abs/2106.09685) offers a more efficient alternative to computing the weight updates $\\Delta W$ by learning an approximation of it, $\\Delta W \\approx AB$.\n",
"- In other words, in LoRA, we have the following, where $A$ and $B$ are two small weight matrices:\n",
"\n",
"$$W_{\\text{updated}} = W + AB$$\n",
"\n",
"- The figure below illustrates these formulas for full finetuning and LoRA side by side"
]
},
{
"cell_type": "markdown",
"id": "a8a7419d-cae9-4525-bb44-1641f6ef4f3b",
"metadata": {
"id": "a8a7419d-cae9-4525-bb44-1641f6ef4f3b"
},
"source": [
"<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/appendix-e_compressed/lora-1.webp\" width=\"500px\">"
]
},
{
"cell_type": "markdown",
"id": "4edd43c9-8ec5-48e6-b3fc-5fb3c16037cc",
"metadata": {
"id": "4edd43c9-8ec5-48e6-b3fc-5fb3c16037cc"
},
"source": [
"- If you paid close attention, you may have noticed that the full finetuning and LoRA depictions in the figure above look slightly different from the formulas shown earlier\n",
"- That's due to the distributive law of matrix multiplication: we don't have to add the weight update to the original weights explicitly but can keep the two terms separate\n",
"- For instance, if $x$ is the input data, then we can write the following for regular finetuning:\n",
"\n",
"$$x (W+\\Delta W) = x W + x \\Delta W$$\n",
"\n",
"- Similarly, we can write the following for LoRA:\n",
"\n",
"$$x (W+A B) = x W + x A B$$\n",
"\n",
"- The fact that we can keep the LoRA weight matrices separate makes LoRA especially attractive\n",
"- In practice, this means that we don't have to modify the weights of the pretrained model at all, as we can apply the LoRA matrices on the fly\n",
"- After setting up the dataset and loading the model, we will implement LoRA in the code to make these concepts less abstract"
]
},
{
"cell_type": "markdown",
"id": "8c7017a2-32aa-4002-a2f3-12aac293ccdf",
"metadata": {
"id": "8c7017a2-32aa-4002-a2f3-12aac293ccdf"
},
"source": [
"## E.2 Preparing the dataset"
]
},
{
"cell_type": "markdown",
"id": "669c64df-4431-4d27-834d-2bb38a01fc02",
"metadata": {
"id": "669c64df-4431-4d27-834d-2bb38a01fc02"
},
"source": [
"- This section repeats the code from chapter 6 to load and prepare the dataset\n",
"- Instead of repeating this code, one could open and run the chapter 6 notebook and then insert the LoRA code from section E.4 there\n",
"- (The LoRA code was originally the last section of chapter 6 but was moved to the appendix due to the length of chapter 6)\n",
"- In a similar fashion, we could also apply LoRA to the models in chapter 7 for instruction finetuning"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "def7c09b-af9c-4216-90ce-5e67aed1065c",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "def7c09b-af9c-4216-90ce-5e67aed1065c",
"outputId": "a67a7afe-b401-4463-c731-87025d20f72d"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"sms_spam_collection/SMSSpamCollection.tsv already exists. Skipping download and extraction.\n"
]
}
],
"source": [
"from pathlib import Path\n",
"import pandas as pd\n",
"from previous_chapters import (\n",
" download_and_unzip_spam_data,\n",
" create_balanced_dataset,\n",
" random_split\n",
")\n",
"\n",
"\n",
"url = \"https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip\"\n",
"zip_path = \"sms_spam_collection.zip\"\n",
"extracted_path = \"sms_spam_collection\"\n",
"data_file_path = Path(extracted_path) / \"SMSSpamCollection.tsv\"\n",
"\n",
"download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path)\n",
"\n",
"df = pd.read_csv(data_file_path, sep=\"\\t\", header=None, names=[\"Label\", \"Text\"])\n",
"balanced_df = create_balanced_dataset(df)\n",
"balanced_df[\"Label\"] = balanced_df[\"Label\"].map({\"ham\": 0, \"spam\": 1})\n",
"\n",
"train_df, validation_df, test_df = random_split(balanced_df, 0.7, 0.1)\n",
"train_df.to_csv(\"train.csv\", index=None)\n",
"validation_df.to_csv(\"validation.csv\", index=None)\n",
"test_df.to_csv(\"test.csv\", index=None)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "74c3c463-8763-4cc0-9320-41c7eaad8ab7",
"metadata": {
"id": "74c3c463-8763-4cc0-9320-41c7eaad8ab7"
},
"outputs": [],
"source": [
"import torch\n",
"from torch.utils.data import Dataset\n",
"import tiktoken\n",
"from previous_chapters import SpamDataset\n",
"\n",
"\n",
"tokenizer = tiktoken.get_encoding(\"gpt2\")\n",
"train_dataset = SpamDataset(\"train.csv\", max_length=None, tokenizer=tokenizer)\n",
"val_dataset = SpamDataset(\"validation.csv\", max_length=train_dataset.max_length, tokenizer=tokenizer)\n",
"test_dataset = SpamDataset(\"test.csv\", max_length=train_dataset.max_length, tokenizer=tokenizer)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "8681adc0-6f02-4e75-b01a-a6ab75d05542",
"metadata": {
"id": "8681adc0-6f02-4e75-b01a-a6ab75d05542"
},
"outputs": [],
"source": [
"from torch.utils.data import DataLoader\n",
"\n",
"num_workers = 0\n",
"batch_size = 8\n",
"\n",
"torch.manual_seed(123)\n",
"\n",
"train_loader = DataLoader(\n",
" dataset=train_dataset,\n",
" batch_size=batch_size,\n",
" shuffle=True,\n",
" num_workers=num_workers,\n",
" drop_last=True,\n",
")\n",
"\n",
"val_loader = DataLoader(\n",
" dataset=val_dataset,\n",
" batch_size=batch_size,\n",
" num_workers=num_workers,\n",
" drop_last=False,\n",
")\n",
"\n",
"test_loader = DataLoader(\n",
" dataset=test_dataset,\n",
" batch_size=batch_size,\n",
" num_workers=num_workers,\n",
" drop_last=False,\n",
")"
]
},
{
"cell_type": "markdown",
"id": "ab7335db-e0bb-4e27-80c5-eea11e593a57",
"metadata": {
"id": "ab7335db-e0bb-4e27-80c5-eea11e593a57"
},
"source": [
"- As a verification step, we iterate through the data loaders and check that the batches contain 8 training examples each, where each training example consists of 120 tokens"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "4dee6882-4c3a-4964-af15-fa31f86ad047",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "4dee6882-4c3a-4964-af15-fa31f86ad047",
"outputId": "2ae34de1-dd01-4f99-d2c8-ba4dca400754"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Train loader:\n",
"Input batch dimensions: torch.Size([8, 120])\n",
"Label batch dimensions torch.Size([8])\n"
]
}
],
"source": [
"print(\"Train loader:\")\n",
"for input_batch, target_batch in train_loader:\n",
" pass\n",
"\n",
"print(\"Input batch dimensions:\", input_batch.shape)\n",
"print(\"Label batch dimensions\", target_batch.shape)"
]
},
{
"cell_type": "markdown",
"id": "5cdd7947-7039-49bf-8a5e-c0a2f4281ca1",
"metadata": {
"id": "5cdd7947-7039-49bf-8a5e-c0a2f4281ca1"
},
"source": [
"- Lastly, let's print the total number of batches in each dataset"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "IZfw-TYD2zTj",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "IZfw-TYD2zTj",
"outputId": "4d19ed61-cf7a-4ec4-b822-c847dd1c5d77"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"130 training batches\n",
"19 validation batches\n",
"38 test batches\n"
]
}
],
"source": [
"print(f\"{len(train_loader)} training batches\")\n",
"print(f\"{len(val_loader)} validation batches\")\n",
"print(f\"{len(test_loader)} test batches\")"
]
},
{
"cell_type": "markdown",
"id": "dec9aa4a-ffd2-4d9f-a835-cce1059fe604",
"metadata": {
"id": "dec9aa4a-ffd2-4d9f-a835-cce1059fe604"
},
"source": [
"## E.3 Initializing the model"
]
},
{
"cell_type": "markdown",
"id": "f36ebdaf-810e-46a2-9ad9-e017a04051b1",
"metadata": {
"id": "f36ebdaf-810e-46a2-9ad9-e017a04051b1"
},
"source": [
"- This section repeats the code from chapter 6 to load and prepare the model"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "02b3a506-3879-4258-82b5-93a5b6bafa74",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "02b3a506-3879-4258-82b5-93a5b6bafa74",
"outputId": "b8c9b125-bb52-45d3-8071-fa5054dbf5a9"
},
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"2024-05-20 00:06:21.369837: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
"2024-05-20 00:06:21.369891: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
"2024-05-20 00:06:21.371329: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
"2024-05-20 00:06:21.380176: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
"To enable the following instructions: AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
"2024-05-20 00:06:22.621156: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"File already exists and is up-to-date: gpt2/124M/checkpoint\n",
"File already exists and is up-to-date: gpt2/124M/encoder.json\n",
"File already exists and is up-to-date: gpt2/124M/hparams.json\n",
"File already exists and is up-to-date: gpt2/124M/model.ckpt.data-00000-of-00001\n",
"File already exists and is up-to-date: gpt2/124M/model.ckpt.index\n",
"File already exists and is up-to-date: gpt2/124M/model.ckpt.meta\n",
"File already exists and is up-to-date: gpt2/124M/vocab.bpe\n"
]
}
],
"source": [
"from gpt_download import download_and_load_gpt2\n",
"from previous_chapters import GPTModel, load_weights_into_gpt\n",
"\n",
"\n",
"CHOOSE_MODEL = \"gpt2-small (124M)\"\n",
"INPUT_PROMPT = \"Every effort moves\"\n",
"\n",
"BASE_CONFIG = {\n",
" \"vocab_size\": 50257, # Vocabulary size\n",
" \"context_length\": 1024, # Context length\n",
" \"drop_rate\": 0.0, # Dropout rate\n",
" \"qkv_bias\": True # Query-key-value bias\n",
"}\n",
"\n",
"model_configs = {\n",
" \"gpt2-small (124M)\": {\"emb_dim\": 768, \"n_layers\": 12, \"n_heads\": 12},\n",
" \"gpt2-medium (355M)\": {\"emb_dim\": 1024, \"n_layers\": 24, \"n_heads\": 16},\n",
" \"gpt2-large (774M)\": {\"emb_dim\": 1280, \"n_layers\": 36, \"n_heads\": 20},\n",
" \"gpt2-xl (1558M)\": {\"emb_dim\": 1600, \"n_layers\": 48, \"n_heads\": 25},\n",
"}\n",
"\n",
"BASE_CONFIG.update(model_configs[CHOOSE_MODEL])\n",
"\n",
"model_size = CHOOSE_MODEL.split(\" \")[-1].lstrip(\"(\").rstrip(\")\")\n",
"settings, params = download_and_load_gpt2(model_size=model_size, models_dir=\"gpt2\")\n",
"\n",
"model = GPTModel(BASE_CONFIG)\n",
"load_weights_into_gpt(model, params)\n",
"model.eval();"
]
},
{
"cell_type": "markdown",
"id": "252614cd-7ce6-4908-83e6-3761f519904e",
"metadata": {
"id": "252614cd-7ce6-4908-83e6-3761f519904e"
},
"source": [
"- To ensure that the model was loaded correctly, let's double-check that it generates coherent text"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "8b6ce20c-0700-4783-8be0-4cf17c200a7f",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "8b6ce20c-0700-4783-8be0-4cf17c200a7f",
"outputId": "28ccbca5-8de9-41a0-c093-da00fcbaa91c"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Every effort moves you forward.\n",
"\n",
"The first step is to understand the importance of your work\n"
]
}
],
"source": [
"from previous_chapters import (\n",
" generate_text_simple,\n",
" text_to_token_ids,\n",
" token_ids_to_text\n",
")\n",
"\n",
"\n",
"text_1 = \"Every effort moves you\"\n",
"\n",
"token_ids = generate_text_simple(\n",
" model=model,\n",
" idx=text_to_token_ids(text_1, tokenizer),\n",
" max_new_tokens=15,\n",
" context_size=BASE_CONFIG[\"context_length\"]\n",
")\n",
"\n",
"print(token_ids_to_text(token_ids, tokenizer))"
]
},
{
"cell_type": "markdown",
"id": "8174b31b-1ab5-4115-b01c-245369da5af3",
"metadata": {
"id": "8174b31b-1ab5-4115-b01c-245369da5af3"
},
"source": [
"- Then, we prepare the model for classification finetuning similar to chapter 6, where we replace the output layer"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "e255ce91-d73a-4854-90a4-95804928eb16",
"metadata": {
"id": "e255ce91-d73a-4854-90a4-95804928eb16"
},
"outputs": [],
"source": [
"torch.manual_seed(123)\n",
"\n",
"num_classes = 2\n",
"model.out_head = torch.nn.Linear(in_features=BASE_CONFIG[\"emb_dim\"], out_features=num_classes)  # use emb_dim from the config so other model sizes also work"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "02e6f057-1383-4ece-8444-0a88e71ac75d",
"metadata": {
"id": "02e6f057-1383-4ece-8444-0a88e71ac75d"
},
"outputs": [],
"source": [
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"model.to(device); # no reassignment (model = model.to(device)) needed, since nn.Module's .to() modifies the model in place"
]
},
{
"cell_type": "markdown",
"id": "8e951cd6-5e42-44d2-b21f-895cb61004fe",
"metadata": {
"id": "8e951cd6-5e42-44d2-b21f-895cb61004fe"
},
"source": [
"- Lastly, let's calculate the initial classification accuracy of the non-finetuned model (we expect this to be around 50%, which means that the model is not yet able to reliably distinguish between spam and non-spam messages)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "fc7dd72c-73a2-4881-ade0-0a9605f1ab8c",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "fc7dd72c-73a2-4881-ade0-0a9605f1ab8c",
"outputId": "74848515-5a49-4125-fecb-9f4bac23f812"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Training accuracy: 46.25%\n",
"Validation accuracy: 45.00%\n",
"Test accuracy: 48.75%\n"
]
}
],
"source": [
"from previous_chapters import calc_accuracy_loader\n",
"\n",
"\n",
"torch.manual_seed(123)\n",
"train_accuracy = calc_accuracy_loader(train_loader, model, device, num_batches=10)\n",
"val_accuracy = calc_accuracy_loader(val_loader, model, device, num_batches=10)\n",
"test_accuracy = calc_accuracy_loader(test_loader, model, device, num_batches=10)\n",
"\n",
"print(f\"Training accuracy: {train_accuracy*100:.2f}%\")\n",
"print(f\"Validation accuracy: {val_accuracy*100:.2f}%\")\n",
"print(f\"Test accuracy: {test_accuracy*100:.2f}%\")"
]
},
{
"cell_type": "markdown",
"id": "398a1ec9-e2a1-43d6-bf9f-12ee54b46a7b",
"metadata": {
"id": "398a1ec9-e2a1-43d6-bf9f-12ee54b46a7b"
},
"source": [
"## E.4 Parameter-efficient finetuning with LoRA"
]
},
{
"cell_type": "markdown",
"id": "652a4a82-61ef-4d0a-9858-8988e844f12c",
"metadata": {
"id": "652a4a82-61ef-4d0a-9858-8988e844f12c"
},
"source": [
"- We begin by initializing a `LoRALayer` that creates the matrices $A$ and $B$, along with the `alpha` scaling hyperparameter and the `rank` ($r$) hyperparameter\n",
"- This layer can accept an input and compute the corresponding output, as illustrated in the figure below\n",
"\n",
"<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/appendix-e_compressed/lora-2.webp\" width=\"200px\">\n",
"\n",
"In code, the LoRA layer depicted in the figure above looks as follows"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "2ds9ywjMwvIW",
"metadata": {
"id": "2ds9ywjMwvIW"
},
"outputs": [],
"source": [
"import math\n",
"\n",
"class LoRALayer(torch.nn.Module):\n",
" def __init__(self, in_dim, out_dim, rank, alpha):\n",
" super().__init__()\n",
" self.A = torch.nn.Parameter(torch.empty(in_dim, rank))\n",
" torch.nn.init.kaiming_uniform_(self.A, a=math.sqrt(5))\n",
" self.B = torch.nn.Parameter(torch.zeros(rank, out_dim))\n",
" self.alpha = alpha\n",
"\n",
" def forward(self, x):\n",
" x = self.alpha * (x @ self.A @ self.B)\n",
" return x"
]
},
{
"cell_type": "markdown",
"id": "ad21faa8-0614-4257-93cd-68952193e14a",
"metadata": {
"id": "ad21faa8-0614-4257-93cd-68952193e14a"
},
"source": [
"- In the code above, `rank` is a hyperparameter that controls the inner dimension of the matrices $A$ and $B$\n",
"- In other words, this parameter controls the number of additional parameters introduced by LoRA and is a key factor in determining the balance between model adaptability and parameter efficiency\n",
"- The second hyperparameter, alpha, is a scaling hyperparameter applied to the output of the low-rank adaptation\n",
"- It essentially controls the extent to which the adapted layer's output is allowed to influence the original output of the layer being adapted\n",
"- This can be seen as a way to regulate the impact of the low-rank adaptation on the layer's output\n",
"- So far, the `LoRALayer` class we implemented above allows us to transform the layer inputs $x$\n",
"- However, in LoRA, we are usually interested in replacing existing `Linear` layers so that the weight update is applied to the existing pretrained weights, as shown in the figure below\n",
"\n",
"<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/appendix-e_compressed/lora-3.webp\" width=\"200px\">"
]
},
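{
"cell_type": "markdown",
"id": "lora-rank-param-demo-md",
"metadata": {},
"source": [
"- As a quick illustration of the savings, we can count the parameters of a single `LoRALayer` for a $768 \\times 768$ layer and compare them to a full $768 \\times 768$ weight matrix (the variable names below are chosen just for this check)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "lora-rank-param-demo-code",
"metadata": {},
"outputs": [],
"source": [
"torch.manual_seed(123)\n",
"\n",
"demo_layer = LoRALayer(in_dim=768, out_dim=768, rank=16, alpha=16)\n",
"\n",
"# A has 768x16 and B has 16x768 entries (alpha is a plain attribute, not a parameter)\n",
"lora_param_count = sum(p.numel() for p in demo_layer.parameters())\n",
"print(f\"LoRA parameters: {lora_param_count:,}\")   # 24,576\n",
"print(f\"Full weight matrix: {768*768:,}\")          # 589,824"
]
},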
{
"cell_type": "markdown",
"id": "3e6d5da0-dfce-4808-b89b-29ff333f563f",
"metadata": {
"id": "3e6d5da0-dfce-4808-b89b-29ff333f563f"
},
"source": [
"- To incorporate the original `Linear` layer weights as shown in the figure above, we implement a `LinearWithLoRA` layer below that uses the previously implemented `LoRALayer`; it can be used to replace existing `Linear` layers in a neural network, for example, the self-attention or feed forward modules in an LLM"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "127d3a64-8359-4b21-b056-78d58cc75fe8",
"metadata": {
"id": "127d3a64-8359-4b21-b056-78d58cc75fe8"
},
"outputs": [],
"source": [
"class LinearWithLoRA(torch.nn.Module):\n",
" def __init__(self, linear, rank, alpha):\n",
" super().__init__()\n",
" self.linear = linear\n",
" self.lora = LoRALayer(\n",
" linear.in_features, linear.out_features, rank, alpha\n",
" )\n",
"\n",
" def forward(self, x):\n",
" return self.linear(x) + self.lora(x)"
]
},
{
"cell_type": "markdown",
"id": "e1145a90-35ff-462c-820b-15483fa5b051",
"metadata": {
"id": "e1145a90-35ff-462c-820b-15483fa5b051"
},
"source": [
"- Note that since we initialize the weight matrix $B$ (`self.B` in `LoRALayer`) with zero values in the LoRA layer, the matrix multiplication between $A$ and $B$ results in a matrix consisting of 0's and doesn't affect the original weights (since adding 0 to the original weights does not modify them)"
]
},
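{
"cell_type": "markdown",
"id": "lora-zero-init-check-md",
"metadata": {},
"source": [
"- We can verify this directly: right after initialization (before any training), a `LinearWithLoRA` layer produces exactly the same output as the `Linear` layer it wraps (the layer and input below are created only for this check)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "lora-zero-init-check-code",
"metadata": {},
"outputs": [],
"source": [
"torch.manual_seed(123)\n",
"\n",
"base_linear = torch.nn.Linear(768, 768)\n",
"lora_linear = LinearWithLoRA(base_linear, rank=16, alpha=16)\n",
"\n",
"x_check = torch.randn(2, 768)\n",
"# B is all zeros, so x @ A @ B is zero and the outputs match exactly\n",
"print(torch.allclose(base_linear(x_check), lora_linear(x_check)))  # True"
]
},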
{
"cell_type": "markdown",
"id": "e98a6d36-7bc9-434c-a7f1-533f26aff06d",
"metadata": {
"id": "e98a6d36-7bc9-434c-a7f1-533f26aff06d"
},
"source": [
"- To try LoRA on the GPT model we defined earlier, we define a `replace_linear_with_lora` function to replace all `Linear` layers in the model with the new `LinearWithLoRA` layers\n",
"\n",
"<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/appendix-e_compressed/lora-4.webp\" width=\"400px\">"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "WlQZ8ygqzN_g",
"metadata": {
"id": "WlQZ8ygqzN_g"
},
"outputs": [],
"source": [
"def replace_linear_with_lora(model, rank, alpha):\n",
" for name, module in model.named_children():\n",
" if isinstance(module, torch.nn.Linear):\n",
" # Replace the Linear layer with LinearWithLoRA\n",
" setattr(model, name, LinearWithLoRA(module, rank, alpha))\n",
" else:\n",
" # Recursively apply the same function to child modules\n",
" replace_linear_with_lora(module, rank, alpha)"
]
},
{
"cell_type": "markdown",
"id": "8c172164-cdde-4489-b7d7-aaed9cc2f5f2",
"metadata": {
"id": "8c172164-cdde-4489-b7d7-aaed9cc2f5f2"
},
"source": [
"- We then freeze the original model parameters and use the `replace_linear_with_lora` function to replace these `Linear` layers using the code below\n",
"- This will replace the `Linear` layers in the LLM with `LinearWithLoRA` layers"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "dbe15350-4da9-4829-9d23-98bbd3d0b1a1",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "dbe15350-4da9-4829-9d23-98bbd3d0b1a1",
"outputId": "fd4c208f-854a-4701-d9d3-9d73af733364"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Total trainable parameters before: 124,441,346\n",
"Total trainable parameters after: 0\n"
]
}
],
"source": [
"total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
"print(f\"Total trainable parameters before: {total_params:,}\")\n",
"\n",
"for param in model.parameters():\n",
" param.requires_grad = False\n",
"\n",
"total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
"print(f\"Total trainable parameters after: {total_params:,}\")"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "mLk_fPq0yz_u",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "mLk_fPq0yz_u",
"outputId": "0a93b8fc-05d7-4ace-ee47-e2fc6bdd7d75"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Total trainable LoRA parameters: 2,666,528\n"
]
}
],
"source": [
"replace_linear_with_lora(model, rank=16, alpha=16)\n",
"\n",
"total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
"print(f\"Total trainable LoRA parameters: {total_params:,}\")"
]
},
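{
"cell_type": "markdown",
"id": "lora-param-crosscheck-md",
"metadata": {},
"source": [
"- We can cross-check this number by hand: each adapted `Linear` layer with dimensions $d_{\\text{in}} \\times d_{\\text{out}}$ contributes $r \\cdot (d_{\\text{in}} + d_{\\text{out}})$ LoRA parameters (the arithmetic below is only a sanity check, not part of the training code)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "lora-param-crosscheck-code",
"metadata": {},
"outputs": [],
"source": [
"rank = 16\n",
"\n",
"# Per transformer block: W_query, W_key, W_value, out_proj (768x768 each),\n",
"# plus the two feed forward layers (768x3072 and 3072x768)\n",
"per_block = 4 * rank * (768 + 768) + 2 * rank * (768 + 3072)\n",
"\n",
"# 12 blocks plus the new classification head (768x2)\n",
"print(f\"{12 * per_block + rank * (768 + 2):,}\")  # 2,666,528"
]
},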
{
"cell_type": "markdown",
"id": "b8b6819e-ef7a-4f0d-841a-1b467496bef9",
"metadata": {
"id": "b8b6819e-ef7a-4f0d-841a-1b467496bef9"
},
"source": [
"- As we can see, we reduced the number of trainable parameters by almost 50x (124,441,346 original parameters versus 2,666,528 trainable LoRA parameters) when using LoRA\n",
"- Let's now double-check whether the layers have been modified as intended by printing the model architecture"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "1711be61-bb2c-466f-9b5b-24f4aa5ccd9c",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "1711be61-bb2c-466f-9b5b-24f4aa5ccd9c",
"outputId": "acff8eca-3775-45a2-b62d-032a986ef037"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"GPTModel(\n",
" (tok_emb): Embedding(50257, 768)\n",
" (pos_emb): Embedding(1024, 768)\n",
" (drop_emb): Dropout(p=0.0, inplace=False)\n",
" (trf_blocks): Sequential(\n",
" (0): TransformerBlock(\n",
" (att): MultiHeadAttention(\n",
" (W_query): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (W_key): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (W_value): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (out_proj): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (ff): FeedForward(\n",
" (layers): Sequential(\n",
" (0): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=3072, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (1): GELU()\n",
" (2): LinearWithLoRA(\n",
" (linear): Linear(in_features=3072, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" )\n",
" )\n",
" (norm1): LayerNorm()\n",
" (norm2): LayerNorm()\n",
" (drop_resid): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (1): TransformerBlock(\n",
" (att): MultiHeadAttention(\n",
" (W_query): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (W_key): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (W_value): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (out_proj): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (ff): FeedForward(\n",
" (layers): Sequential(\n",
" (0): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=3072, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (1): GELU()\n",
" (2): LinearWithLoRA(\n",
" (linear): Linear(in_features=3072, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" )\n",
" )\n",
" (norm1): LayerNorm()\n",
" (norm2): LayerNorm()\n",
" (drop_resid): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (2): TransformerBlock(\n",
" (att): MultiHeadAttention(\n",
" (W_query): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (W_key): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (W_value): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (out_proj): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (ff): FeedForward(\n",
" (layers): Sequential(\n",
" (0): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=3072, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (1): GELU()\n",
" (2): LinearWithLoRA(\n",
" (linear): Linear(in_features=3072, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" )\n",
" )\n",
" (norm1): LayerNorm()\n",
" (norm2): LayerNorm()\n",
" (drop_resid): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (3): TransformerBlock(\n",
" (att): MultiHeadAttention(\n",
" (W_query): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (W_key): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (W_value): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (out_proj): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (ff): FeedForward(\n",
" (layers): Sequential(\n",
" (0): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=3072, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (1): GELU()\n",
" (2): LinearWithLoRA(\n",
" (linear): Linear(in_features=3072, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" )\n",
" )\n",
" (norm1): LayerNorm()\n",
" (norm2): LayerNorm()\n",
" (drop_resid): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (4): TransformerBlock(\n",
" (att): MultiHeadAttention(\n",
" (W_query): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (W_key): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (W_value): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (out_proj): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (ff): FeedForward(\n",
" (layers): Sequential(\n",
" (0): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=3072, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (1): GELU()\n",
" (2): LinearWithLoRA(\n",
" (linear): Linear(in_features=3072, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" )\n",
" )\n",
" (norm1): LayerNorm()\n",
" (norm2): LayerNorm()\n",
" (drop_resid): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (5): TransformerBlock(\n",
" (att): MultiHeadAttention(\n",
" (W_query): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (W_key): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (W_value): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (out_proj): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (ff): FeedForward(\n",
" (layers): Sequential(\n",
" (0): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=3072, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (1): GELU()\n",
" (2): LinearWithLoRA(\n",
" (linear): Linear(in_features=3072, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" )\n",
" )\n",
" (norm1): LayerNorm()\n",
" (norm2): LayerNorm()\n",
" (drop_resid): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (6): TransformerBlock(\n",
" (att): MultiHeadAttention(\n",
" (W_query): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (W_key): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (W_value): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (out_proj): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (ff): FeedForward(\n",
" (layers): Sequential(\n",
" (0): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=3072, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (1): GELU()\n",
" (2): LinearWithLoRA(\n",
" (linear): Linear(in_features=3072, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" )\n",
" )\n",
" (norm1): LayerNorm()\n",
" (norm2): LayerNorm()\n",
" (drop_resid): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (7): TransformerBlock(\n",
" (att): MultiHeadAttention(\n",
" (W_query): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (W_key): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (W_value): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (out_proj): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (ff): FeedForward(\n",
" (layers): Sequential(\n",
" (0): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=3072, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (1): GELU()\n",
" (2): LinearWithLoRA(\n",
" (linear): Linear(in_features=3072, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" )\n",
" )\n",
" (norm1): LayerNorm()\n",
" (norm2): LayerNorm()\n",
" (drop_resid): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (8): TransformerBlock(\n",
" (att): MultiHeadAttention(\n",
" (W_query): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (W_key): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (W_value): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (out_proj): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (ff): FeedForward(\n",
" (layers): Sequential(\n",
" (0): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=3072, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (1): GELU()\n",
" (2): LinearWithLoRA(\n",
" (linear): Linear(in_features=3072, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" )\n",
" )\n",
" (norm1): LayerNorm()\n",
" (norm2): LayerNorm()\n",
" (drop_resid): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (9): TransformerBlock(\n",
" (att): MultiHeadAttention(\n",
" (W_query): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (W_key): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (W_value): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (out_proj): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (ff): FeedForward(\n",
" (layers): Sequential(\n",
" (0): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=3072, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (1): GELU()\n",
" (2): LinearWithLoRA(\n",
" (linear): Linear(in_features=3072, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" )\n",
" )\n",
" (norm1): LayerNorm()\n",
" (norm2): LayerNorm()\n",
" (drop_resid): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (10): TransformerBlock(\n",
" (att): MultiHeadAttention(\n",
" (W_query): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (W_key): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (W_value): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (out_proj): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (ff): FeedForward(\n",
" (layers): Sequential(\n",
" (0): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=3072, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (1): GELU()\n",
" (2): LinearWithLoRA(\n",
" (linear): Linear(in_features=3072, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" )\n",
" )\n",
" (norm1): LayerNorm()\n",
" (norm2): LayerNorm()\n",
" (drop_resid): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (11): TransformerBlock(\n",
" (att): MultiHeadAttention(\n",
" (W_query): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (W_key): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (W_value): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (out_proj): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (ff): FeedForward(\n",
" (layers): Sequential(\n",
" (0): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=3072, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" (1): GELU()\n",
" (2): LinearWithLoRA(\n",
" (linear): Linear(in_features=3072, out_features=768, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
" )\n",
" )\n",
" (norm1): LayerNorm()\n",
" (norm2): LayerNorm()\n",
" (drop_resid): Dropout(p=0.0, inplace=False)\n",
" )\n",
" )\n",
" (final_norm): LayerNorm()\n",
" (out_head): LinearWithLoRA(\n",
" (linear): Linear(in_features=768, out_features=2, bias=True)\n",
" (lora): LoRALayer()\n",
" )\n",
")\n"
]
}
],
"source": [
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"model.to(device)\n",
"\n",
"print(model)"
]
},
{
"cell_type": "markdown",
"id": "c4bbc9d7-65ec-4675-bab8-2e56eb0cfb55",
"metadata": {
"id": "c4bbc9d7-65ec-4675-bab8-2e56eb0cfb55"
},
"source": [
"- Based on the model architecture above, we can see that the model now contains our new `LinearWithLoRA` layers\n",
"- Also, since we initialized matrix $B$ with 0's, we expect the initial model performance to be unchanged compared to before"
]
},
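   {
    "cell_type": "markdown",
    "id": "b-zero-sanity-check",
    "metadata": {},
    "source": [
     "- As a minimal standalone sanity check (not part of the book's code; the tensor shapes and rank below are illustrative, and the `alpha` scaling is omitted), we can verify that adding a LoRA update $AB$ with $B$ set to zeros leaves a linear layer's output exactly unchanged:\n",
     "\n",
     "```python\n",
     "import torch\n",
     "\n",
     "torch.manual_seed(123)\n",
     "layer = torch.nn.Linear(4, 3)  # stand-in for a pretrained Linear layer\n",
     "x = torch.rand(2, 4)\n",
     "\n",
     "A = torch.randn(4, 2)  # rank r = 2 (illustrative)\n",
     "B = torch.zeros(2, 3)  # zero-initialized, as in our LoRALayer\n",
     "\n",
     "# A @ B is all zeros, so the LoRA term adds nothing\n",
     "print(torch.equal(layer(x), layer(x) + x @ (A @ B)))  # True\n",
     "```"
    ]
   },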
{
"cell_type": "code",
"execution_count": 18,
"id": "DAlrb_I00VEU",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "DAlrb_I00VEU",
"outputId": "3da44ac4-230b-4358-d996-30b63f0d962a"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Training accuracy: 46.25%\n",
"Validation accuracy: 45.00%\n",
"Test accuracy: 48.75%\n"
]
}
],
"source": [
"torch.manual_seed(123)\n",
"train_accuracy = calc_accuracy_loader(train_loader, model, device, num_batches=10)\n",
"val_accuracy = calc_accuracy_loader(val_loader, model, device, num_batches=10)\n",
"test_accuracy = calc_accuracy_loader(test_loader, model, device, num_batches=10)\n",
"\n",
"print(f\"Training accuracy: {train_accuracy*100:.2f}%\")\n",
"print(f\"Validation accuracy: {val_accuracy*100:.2f}%\")\n",
"print(f\"Test accuracy: {test_accuracy*100:.2f}%\")"
]
},
{
"cell_type": "markdown",
"id": "13735b3e-f0c3-4dba-ae3d-4141b2878101",
"metadata": {
"id": "13735b3e-f0c3-4dba-ae3d-4141b2878101"
},
"source": [
"- Let's now get to the interesting part and finetune the model by reusing the training function from chapter 6\n",
"- The training takes about 15 minutes on a M3 MacBook Air laptop computer and less than half a minute on a V100 or A100 GPU"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "wCParRvr0eff",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "wCParRvr0eff",
"outputId": "ce910a9c-ee89-48bb-bfa6-49c6aee1e450"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Ep 1 (Step 000000): Train loss 3.820, Val loss 3.462\n",
"Ep 1 (Step 000050): Train loss 0.396, Val loss 0.364\n",
"Ep 1 (Step 000100): Train loss 0.111, Val loss 0.229\n",
"Training accuracy: 97.50% | Validation accuracy: 95.00%\n",
"Ep 2 (Step 000150): Train loss 0.135, Val loss 0.073\n",
"Ep 2 (Step 000200): Train loss 0.007, Val loss 0.053\n",
"Ep 2 (Step 000250): Train loss 0.021, Val loss 0.180\n",
"Training accuracy: 97.50% | Validation accuracy: 97.50%\n",
"Ep 3 (Step 000300): Train loss 0.103, Val loss 0.065\n",
"Ep 3 (Step 000350): Train loss 0.059, Val loss 0.167\n",
"Training accuracy: 100.00% | Validation accuracy: 100.00%\n",
"Ep 4 (Step 000400): Train loss 0.006, Val loss 0.118\n",
"Ep 4 (Step 000450): Train loss 0.004, Val loss 0.179\n",
"Ep 4 (Step 000500): Train loss 0.001, Val loss 0.060\n",
"Training accuracy: 97.50% | Validation accuracy: 92.50%\n",
"Ep 5 (Step 000550): Train loss 0.021, Val loss 0.128\n",
"Ep 5 (Step 000600): Train loss 0.051, Val loss 0.051\n",
"Training accuracy: 100.00% | Validation accuracy: 97.50%\n",
"Training completed in 0.83 minutes.\n"
]
}
],
"source": [
"import time\n",
"from previous_chapters import train_classifier_simple\n",
"\n",
"\n",
"start_time = time.time()\n",
"\n",
"torch.manual_seed(123)\n",
"\n",
"optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.1)\n",
"\n",
"num_epochs = 5\n",
"train_losses, val_losses, train_accs, val_accs, examples_seen = train_classifier_simple(\n",
" model, train_loader, val_loader, optimizer, device,\n",
" num_epochs=num_epochs, eval_freq=50, eval_iter=5,\n",
" tokenizer=tokenizer\n",
")\n",
"\n",
"end_time = time.time()\n",
"execution_time_minutes = (end_time - start_time) / 60\n",
"print(f\"Training completed in {execution_time_minutes:.2f} minutes.\")"
]
},
{
"cell_type": "markdown",
"id": "d0c89e82-3aa8-44c6-b046-0b16200b8e6c",
"metadata": {
"id": "d0c89e82-3aa8-44c6-b046-0b16200b8e6c"
},
"source": [
"- Finally, let's evaluate the model"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "bawWGijA0iF3",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 308
},
"id": "bawWGijA0iF3",
"outputId": "af70782a-d605-4376-fa6c-d33b38979cfa"
},
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 500x300 with 2 Axes>"
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAeoAAAEiCAYAAAA21pHjAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAABPtUlEQVR4nO3deVzUdf7A8dfMwAz3ISKHAl6IJ2KKhppZUmLlplub67qFrdWvQs3MMrdSrG21204r23TbLSkrzS3T1DzKNE8UL7wFlcOLGwaY+fz+GBgYT0BgBnw/H4/vY77H5/v9vufjyPv7+V4fjVJKIYQQQgiHpLV3AEIIIYS4PEnUQgghhAOTRC2EEEI4MEnUQgghhAOTRC2EEEI4MEnUQgghhAOTRC2EEEI4MEnUQgghhAOTRC2EEEI4MEnUQggbgwcPZtKkSfYOQwhRQRK1EPVs7NixaDSai4a4uDh7hyaEaIKc7B2AEM1RXFwc8+fPt5lnMBjsFI0QoimTFrUQDcBgMBAYGGgz+Pr6ArB27Vr0ej2//PKLtfyrr75Kq1atyMrKAmD58uUMHDgQHx8f/Pz8uOuuuzh8+LC1/LFjx9BoNHz11VfcdNNNuLq6Eh0dzYEDB9iyZQt9+vTBw8ODYcOGcfr0aet6Y8eOZcSIEcycORN/f3+8vLx49NFHKS0tvex3MRqNTJkyhdatW+Pu7k6/fv1Yu3atdfnx48cZPnw4vr6+uLu7061bN5YtW3bZ7X3wwQeEh4fj4uJCQEAA9957r3WZ2Wxm1qxZtGvXDldXV3r27MnXX39ts/7u3bsZNmwYHh4eBAQEcP/993PmzBnr8sGDBzNx4kSeeeYZWrRoQWBgIImJiZeNRwhHJ4laiEZWeQ34/vvvJzc3lx07dvDCCy/wySefEBAQAEBhYSGTJ09m69atrF69Gq1Wy8iRIzGbzTbbmjFjBs8//zzbt2/HycmJv/zlLzzzzDO8/fbb/PLLLxw6dIjp06fbrLN69Wr27dvH2rVrWbhwId9++y0zZ868bLzjx49n48aNJCUlsWvXLv70pz8RFxfHwYMHAUhISMBoNLJ+/XpSUlJ45ZVX8PDwuOS2tm7dysSJE3nxxRdJTU1l+fLlDBo0yLp81qxZfPbZZ3z44Yfs2bOHJ598kr/+9a+sW7cOgJycHG699VZ69erF1q1bWb58OVlZWdx33302+/n3v/+Nu7s7v//+O6+++iovvvgiK1eurOG/kBAORgkh6lV8fLzS6XTK3d3dZnj55ZetZYxGo4qKilL33Xef6tq1q3r44YevuM3Tp08rQKWkpCillDp69KgC1CeffGIts3DhQgWo1atXW+fNmjVLRURE2MTWokULVVhYaJ03d+5c5eHhoUwmk1JKqZtvvlk98cQTSimljh8/rnQ6nTp58qRNPEOGDFHTpk1TSinVo0cPlZiYWKO6+eabb5SXl5fKy8u7aFlJSYlyc3NTv/32m838cePGqdGjRyullHrppZfU7bffbrM8PT1dASo1NdUa/8CBA23KREdHq6lTp9YoRiEcjVyjFqIB3HLLLcydO9dmXosWLazjer2ezz//nMjISMLCwnjrrbdsyh48eJDp06fz+++/c+bMGWtLOi0tje7du1vLRUZGWscrW+M9evSwmZednW2z7Z49e+Lm5madjomJoaCggPT0dMLCwmzKpqSkYDKZ6NSpk818o9GIn58fABMnTuSxxx7jp59+IjY2lnvuuccmrupuu+02wsLCaN++PXFxccTFxTFy5Ejc3Nw4dOgQRUVF3HbbbTbrlJaW0qtXLwB27tzJmjVrLtliP3z4sDXOC/cfFBR0UT0I0VRIohaiAbi7u9OxY8crlvntt98AOHfuHOfOncPd3d26bPjw4YSFhTFv3jyCg4Mxm8107979omvJzs7O1nGNRnPJeReeLq+NgoICdDod27ZtQ6fT2SyrTJYPPfQQQ4cO5YcffuCnn35i1qxZvP
HGG0yYMOGi7Xl6erJ9+3bWrl3LTz/9xPTp00lMTGTLli0UFBQA8MMPP9C6dWub9SpvxCsoKGD48OG88sorF207KCjIOl69DuDa60EIe5JELYQdHD58mCeffJJ58+bx5ZdfEh8fz6pVq9BqtZw9e5bU1FTmzZvHTTfdBMCvv/5ab/veuXMnxcXFuLq6ArBp0yY8PDwICQm5qGyvXr0wmUxkZ2dbY7mUkJAQHn30UR599FGmTZvGvHnzLpmoAZycnIiNjSU2NpYZM2bg4+PDzz//zG233YbBYCAtLY2bb775kuvecMMNfPPNN7Rt2xYnJ/nzJa4P8ksXogEYjUYyMzNt5jk5OdGyZUtMJhN//etfGTp0KA8++CBxcXH06NGDN954g6effhpfX1/8/Pz4+OOPCQoKIi0tjWeffbbeYistLWXcuHE8//zzHDt2jBkzZjB+/Hi02ovvLe3UqRNjxozhgQce4I033qBXr16cPn2a1atXExkZyZ133smkSZMYNmwYnTp14vz586xZs4YuXbpcct/ff/89R44cYdCgQfj6+rJs2TLMZjMRERF4enoyZcoUnnzyScxmMwMHDiQ3N5cNGzbg5eVFfHw8CQkJzJs3j9GjR1vv6j506BBJSUl88sknF7X6hWgOJFEL0QCWL19ucyoWICIigv379/Pyyy9z/Phxvv/+e8Byyvbjjz9m9OjR3H777fTs2ZOkpCQmTpxI9+7diYiI4J133mHw4MH1EtuQIUMIDw9n0KBBGI1GRo8efcXHl+bPn88//vEPnnrqKU6ePEnLli258cYbueuuuwAwmUwkJCRw4sQJvLy8iIuLu+iaeyUfHx++/fZbEhMTKSkpITw8nIULF9KtWzcAXnrpJfz9/Zk1axZHjhzBx8eHG264gb///e8ABAcHs2HDBqZOncrtt9+O0WgkLCyMuLi4Sx5oCNEcaJRSyt5BCCEax9ixY8nJyWHJkiX2DkUIUUNyCCqEEEI4MEnUQgghhAOTU99CCCGEA5MWtRBCCOHAJFELIYQQDkwStRBCCOHAJFFXeP/992nbti0uLi7069ePzZs32zukBrd+/XqGDx9OcHAwGo3mokd2lFJMnz6doKAgXF1diY2NtfaYVOncuXOMGTMGLy8vfHx8GDdunPVVkJV27drFTTfdhIuLCyEhIbz66qsN/dXq3axZs4iOjsbT05NWrVoxYsQIUlNTbcqUlJSQkJCAn58fHh4e3HPPPdZuKyulpaVx55134ubmRqtWrXj66acpLy+3KbN27VpuuOEGDAYDHTt2ZMGCBQ399erd3LlziYyMxMvLCy8vL2JiYvjxxx+ty6WuLm/27NloNBomTZpknSf1VSUxMRGNRmMzdO7c2bq8WdaVXbsEcRBJSUlKr9erTz/9VO3Zs0c9/PDDysfHR2VlZdk7tAa1bNky9dxzz6lvv/1WAWrx4sU2y2fPnq28vb3VkiVL1M6dO9Uf/vAH1a5dO1VcXGwtExcXp3r27Kk2bdqkfvnlF9WxY0drT0dKKZWbm6sCAgLUmDFj1O7du9XChQuVq6ur+uijjxrra9aLoUOHqvnz56vdu3er5ORkdccdd6jQ0FBVUFBgLfPoo4+qkJAQtXr1arV161Z14403qv79+1uXl5eXq+7du6vY2Fi1Y8cOtWzZMtWyZUtrL1RKKXXkyBHl5uamJk+erPbu3aveffddpdPp1PLlyxv1+16rpUuXqh9++EEdOHBApaamqr///e/K2dlZ7d69WykldXU5mzdvVm3btlWRkZHWHsyUkvqqbsaMGapbt24qIyPDOpw+fdq6vDnWlSRqpVTfvn1VQkKCddpkMqng4GA1a9YsO0bVuC5M1GazWQUGBqrXXnvNOi8nJ0cZDAa1cOFCpZRSe/fuVYDasmWLtcyPP/6oNBqNtVvEDz74QPn6+iqj0WgtM3XqVJuuF5ui7OxsBah169YppSx14+zsrBYtWmQts2/fPgWojRs3KqUsB0ZarV
ZlZmZay8ydO1d5eXlZ6+eZZ55R3bp1s9nXqFGj1NChQxv6KzU4X19f9cknn0hdXUZ+fr4KDw9XK1eutOlqVOrL1owZM1TPnj0vuay51tV1f+q7tLSUbdu2ERsba52n1WqJjY1l48aNdozMvo4ePUpmZqZNvXh7e9OvXz9rvWzcuBEfHx/69OljLRMbG4tWq+X333+3lhk0aBB6vd5aZujQoaSmpnL+/PlG+jb1Lzc3F6jqunLbtm2UlZXZ1Ffnzp0JDQ21qa8ePXpYu6MES13k5eWxZ88ea5nq26gs05R/iyaTiaSkJAoLC4mJiZG6uoyEhATuvPPOi76T1NfFDh48SHBwMO3bt2fMmDGkpaUBzbeurvtEfebMGUwmk80/Glj68b2wU4XrSeV3v1K9ZGZm0qpVK5vlTk5OtGjRwqbMpbZRfR9NjdlsZtKkSQwYMMDaN3RmZiZ6vR4fHx+bshfW19Xq4nJl8vLyKC4uboiv02BSUlLw8PDAYDDw6KOPsnjxYrp27Sp1dQlJSUls376dWbNmXbRM6stWv379WLBgAcuXL2fu3LkcPXqUm266ifz8/GZbV9IphxC1lJCQwO7du+u168nmKCIiguTkZHJzc/n666+Jj49n3bp19g7L4aSnp/PEE0+wcuVKXFxc7B2Owxs2bJh1PDIykn79+hEWFsZXX31l7bq1ubnuW9QtW7ZEp9NddFdgVlYWgYGBdorK/iq/+5XqJTAwkOzsbJvl5eXlnDt3zqbMpbZRfR9Nyfjx4/n+++9Zs2YNbdq0sc4PDAyktLSUnJwcm/IX1tfV6uJyZby8vJrcHyG9Xk/Hjh3p3bs3s2bNomfPnrz99ttSVxfYtm0b2dnZ3HDDDTg5OeHk5MS6det45513cHJyIiAgQOrrCnx8fOjUqROHDh1qtr+t6z5R6/V6evfuzerVq63zzGYzq1evJiYmxo6R2Ve7du0IDAy0qZe8vDx+//13a73ExMSQk5PDtm3brGV+/vlnzGYz/fr1s5ZZv349ZWVl1jIrV64kIiICX1/fRvo2104pxfjx41m8eDE///wz7dq1s1neu3dvnJ2dbeorNTWVtLQ0m/pKSUmxObhZuXIlXl5edO3a1Vqm+jYqyzSH36LZbMZoNEpdXWDIkCGkpKSQnJxsHfr06cOYMWOs41Jfl1dQUMDhw4cJCgpqvr8tu9zC5mCSkpKUwWBQCxYsUHv37lWPPPKI8vHxsbkrsDnKz89XO3bsUDt27FCAevPNN9WOHTvU8ePHlVKWx7N8fHzUd999p3bt2qXuvvvuSz6e1atXL/X777+rX3/9VYWHh9s8npWTk6MCAgLU/fffr3bv3q2SkpKUm5tbk3s867HHHlPe3t5q7dq1No+FFBUVWcs8+uijKjQ0VP38889q69atKiYmRsXExFiXVz4Wcvvtt6vk5GS1fPly5e/vf8nHQp5++mm1b98+9f777zfJR2ieffZZtW7dOnX06FG1a9cu9eyzzyqNRqN++uknpZTU1dVUv+tbKamv6p566im1du1adfToUbVhwwYVGxurWrZsqbKzs5VSzbOuJFFXePfdd1VoaKjS6/Wqb9++atOmTfYOqcGtWbNGARcN8fHxSinLI1ovvPCCCggIUAaDQQ0ZMkSlpqbabOPs2bNq9OjRysPDQ3l5eakHH3xQ5efn25TZuXOnGjhwoDIYDKp169Zq9uzZjfUV682l6glQ8+fPt5YpLi5Wjz/+uPL19VVubm5q5MiRKiMjw2Y7x44dU8OGDVOurq6qZcuW6qmnnlJlZWU2ZdasWaOioqKUXq9X7du3t9lHU/G3v/1NhYWFKb1er/z9/dWQIUOsSVopqauruTBRS31VGTVqlAoKClJ6vV61bt1ajRo1Sh06dMi6vDnWlfSeJYQQQjiw6/4atRBCCOHIJFELIYQQDkwStRBCCOHAJFELIYQQDkwStRBCCOHAJFELIYQQDkwSdTVGo5HExESMRqO9Q3F4Ule1I/
VVc1JXtSP1VXNNta4c5jnq2bNnM23aNJ544gnmzJljlxjy8vLw9vYmNzcXLy8vu8TQVEhd1Y7UV81JXdWO1FfNNdW6cogW9ZYtW/joo4+IjIy0dyhCCCGEQ7F7oi4oKGDMmDHMmzevSXXSIIQQQjQGu/dHnZCQwJ133klsbCz/+Mc/arVueXk5O3bsICAgAK322o858vPzATh58iR5eXnXvL3mTOqqdqS+ak7qqnakvmrOkerKbDaTlZVFr169cHK6ciq2a6JOSkpi+/btbNmypUbljUajzU0A27Zt49Zbb633uCq7OhNXJ3VVO1JfNSd1VTtSXzXnSHW1efNmoqOjr1jGbok6PT2dJ554gpUrV+Li4lKjdWbNmsXMmTMvmr9582aCgoLqO0QhhBCiQWRkZNC3b18CAgKuWtZud30vWbKEkSNHotPprPNMJhMajQatVovRaLRZBhe3qE+ePEnXrl1JT0+nTZs2jRa7EEIIcS1OnDhBSEhIjfKX3VrUQ4YMISUlxWbegw8+SOfOnZk6depFSRrAYDBgMBis0/a+xiCEEEI0NLslak9PT7p3724zz93dHT8/v4vmCyGEENcruz+eJYQQQojLs/vjWdWtXbvW3iEIIa5zJpOJsrIye4chmjhnZ+dLXsKtC4dK1PZUaCxnZ3oO5WbFoE7+9g5HCNHIlFJkZmaSk5Nj71BEM+Hj40NgYCAajeaatiOJusLq/dlMXLiDyDbekqiFuA5VJulWrVrh5uZ2zX9cxfVLKUVRURHZ2dkA1/z4sCTqCr1CfADYl5FHSZkJF+f6OWUhhHB8JpPJmqT9/PzsHY5oBlxdXQHIzs6mVatW13QaXG4mq9DG1xU/dz1lJsWeU/LYlxDXk8pr0m5ubnaORDQnlb+na73nQRJ1BY1GQ69QHwB2pJ23bzBCCLuQ092iPtXX70kSdTVRFae/k9Nz7BqHEEIIUUkSdTVRIZZuNiVRCyGuZ23btmXOnDk1Lr927Vo0Gk2D3zG/YMECfHx8GnQfjkgSdTWRId5oNHDifDFnCoxXX0EIIexIo9FccUhMTKzTdrds2cIjjzxS4/L9+/cnIyMDb2/vOu1PXJnc9V2Nl4szHfw9OJRdQHJaDrFdr96riRBC2EtGRoZ1/Msvv2T69OmkpqZa53l4eFjHlVKYTKar9n0M4O9fu0dU9Xo9gYGBtVpH1Jy0qC8g16mFEE1FYGCgdfD29kaj0Vin9+/fj6enJz/++CO9e/fGYDDw66+/cvjwYe6++24CAgLw8PAgOjqaVatW2Wz3wlPfGo2GTz75hJEjR+Lm5kZ4eDhLly61Lr/w1HflKeoVK1bQpUsXPDw8iIuLszmwKC8vZ+LEifj4+ODn58fUqVOJj49nxIgRtaqDuXPn0qFDB/R6PREREfznP/+xLlNKkZiYSGhoKAaDgeDgYCZOnGhd/sEHHxAeHo6LiwsBAQHce++9tdp3Y5FEfQFJ1EIIqHhpRWm5XYb67H342WefZfbs2ezbt4/IyEgKCgq44447WL16NTt27CAuLo7hw4eTlpZ2xe3MnDmT++67j127dnHHHXcwZswYzp07d9nyRUVFvP766/znP/9h/fr1pKWlMWXKFOvyV155hc8//5z58+ezYcMG8vLyWLJkSa2+2+LFi3niiSd46qmn2L17N//3f//Hgw8+yJo1awD45ptveOutt/joo484ePAgS5YsoUePHgBs3bqViRMn8uKLL5Kamsry5csZNGhQrfbfWOTU9wUqE/XO9BzMZoVWK49rCHE9Ki4z0XX6Crvse++LQ3HT18+f5xdffJHbbrvNOt2iRQt69uxpnX7ppZdYvHgxS5cuZfz48ZfdztixYxk9ejQA//znP3nnnXfYvHkzcXFxlyxfVlbGhx9+SIcOHQAYP348L774onX5u+++y7Rp0xg5ciQA7733HsuWLavVd3v99dcZO3Ysjz/+OACTJ09m06ZNvP7669xyyy2kpaURGBhIbGwszs
7OhIaG0rdvXwDS0tJwd3fnrrvuwtPTk7CwMHr16lWr/TcWaVFfoHOgJy7OWvKN5Rw+XWDvcIQQ4pr06dPHZrqgoIApU6bQpUsXfHx88PDwYN++fVdtUUdGRlrH3d3d8fLysr4i81Lc3NysSRosr9GsLJ+bm0tWVpY1aQLodDp69+5dq++2b98+BgwYYDNvwIAB7Nu3D4A//elPFBcX0759ex5++GEWL15MeXk5ALfddhthYWG0b9+e+++/n88//5yioqJa7b+xSIv6Ak46LZGtfdh87Bw70nMID/C0d0hCCDtwddax98Whdtt3fXF3d7eZnjJlCitXruT111+nY8eOuLq6cu+991JaWnrF7Tg7O9tMazQazGZzrcrX5yn9mggJCSE1NZVVq1axcuVKHn/8cV577TXWrVuHp6cn27dvZ+3atfz0009Mnz6dxMREtmzZ4nCPgEmL+hKiKt5QJtephbh+aTQa3PROdhka8g1pGzZsYOzYsYwcOZIePXoQGBjIsWPHGmx/l+Lt7U1AQABbtmyxzjOZTGzfvr1W2+nSpQsbNmywmbdhwwa6du1qnXZ1dWX48OG88847rF27lo0bN5KSkgKAk5MTsbGxvPrqq+zatYtjx47x888/X8M3axjSor4E6w1laTl2jUMIIepbeHg43377LcOHD0ej0fDCCy9csWXcUCZMmMCsWbPo2LEjnTt35t133+X8+fO1Okh5+umnue++++jVqxexsbH873//49tvv7Xexb5gwQJMJhP9+vXDzc2N//73v7i6uhIWFsb333/PkSNHGDRoEL6+vixbtgyz2UxERERDfeU6k0R9CZWJOjUrn+JSE6566UlLCNE8vPnmm/ztb3+jf//+tGzZkqlTp5KX1/gdEU2dOpXMzEweeOABdDodjzzyCEOHDq1VL1MjRozg7bff5vXXX+eJJ56gXbt2zJ8/n8GDBwOW/qBnz57N5MmTMZlM9OjRg//973/4+fnh4+PDt99+S2JiIiUlJYSHh7Nw4UK6devWQN+47jSqsS8a1KMTJ04QEhJCeno6bdq0ubaNlRvh+G9w9hAq+iH6/XM12flGvvq/GPq2a1E/AQshHFJJSQlHjx6lXbt2uLi42Duc65LZbKZLly7cd999vPTSS/YOp15c6XdVm/wl16grFefAf0bAsqfRlORWe55aetISQoj6dvz4cebNm8eBAwdISUnhscce4+jRo/zlL3+xd2gORxJ1Jc8AaNEeUJC+mV6h0kGHEEI0FK1Wy4IFC4iOjmbAgAGkpKSwatUqunTpYu/QHI5co64utD+cOwJpvxHVzvI83w65oUwIIepdSEjIRXdsi0uTFnV1YTGWz+MbiWzjjVYDGbklZOWV2DcuIYQQ1y1J1NWFViTqU9tx15bTqeJlJ9KqFkIIYS+SqKtr0R7cW4GpFE5ukw46hBBC2J0k6uo0mqrT32m/yZ3fQggh7E4S9YVC+1s+j2+0vko05UQuJnOTfdxcCCFEEyaJ+kKVLer0zYS3dMNdr6Ow1MTB7Hz7xiWEEOK6JIn6QgHdweAFpfnoTu8hso0PIO/9FkI0X4MHD2bSpEnW6bZt2zJnzpwrrqPRaFiyZMk177u+tnMliYmJREVFNeg+GpIk6gtpdRBS0UdqtdPfcue3EMLRDB8+nLi4uEsu++WXX9BoNOzatavW292yZQuPPPLItYZn43LJMiMjg2HDhtXrvpobSdSXEnqpG8py7BaOEEJcyrhx41i5ciUnTpy4aNn8+fPp06cPkZGRtd6uv78/bm5u9RHiVQUGBmIwGBplX02VJOpLaXcztL8FwgbQqyJRH8jOp8BYbt+4hBCimrvuugt/f38WLFhgM7+goIBFixYxbtw4zp49y+jRo2ndujVubm706NGDhQsXXnG7F576PnjwIIMGDcLFxYWuXbuycuXKi9aZOnUqnTp1ws3Njfbt2/PCCy9QVlYGWLqbnDlzJjt37kSj0aDRaKwxX3jqOyUlhV
tvvRVXV1f8/Px45JFHKCgosC4fO3YsI0aM4PXXXycoKAg/Pz8SEhKs+6oJs9nMiy++SJs2bTAYDERFRbF8+XLr8tLSUsaPH09QUBAuLi6EhYUxa9YsAJRSJCYmEhoaisFgIDg4mIkTJ9Z433UhrxC9lJBoeGAJAK2AYG8XTuWWsOtEDv07tLRraEKIRlZaWPt1dAbQVfx5NZWDyQgaLTi7Xn27evca78bJyYkHHniABQsW8Nxzz1n7cl60aBEmk4nRo0dTUFBA7969mTp1Kl5eXvzwww/cf//9dOjQgb59+151H2azmT/+8Y8EBATw+++/k5uba3M9u5KnpycLFiwgODiYlJQUHn74YTw9PXnmmWcYNWoUu3fvZvny5da+or29vS/aRmFhIUOHDiUmJoYtW7aQnZ3NQw89xPjx420ORtasWUNQUBBr1qzh0KFDjBo1iqioKB5++OEa1dvbb7/NG2+8wUcffUSvXr349NNP+cMf/sCePXsIDw/nnXfeYenSpXz11VeEhoaSnp5Oeno6AN988w1vvfUWSUlJdOvWjczMTHbu3Fmj/daVJOoaiAr14VRKJsnpkqiFuO78M7j26/xpAXQbaRnf/z9YNBbCBsKDP1SVmdMDis5evG5ibq129be//Y3XXnuNdevWWfthnj9/Pvfccw/e3t54e3szZcoUa/kJEyawYsUKvvrqqxol6lWrVrF//35WrFhBcLClLv75z39edF35+eeft463bduWKVOmkJSUxDPPPIOrqyseHh44OTkRGBh42X198cUXlJSU8Nlnn+Hubjlgee+99xg+fDivvPIKAQEBAPj6+vLee++h0+no3Lkzd955J6tXr65xon799deZOnUqf/7znwF45ZVXWLNmDXPmzOH9998nLS2N8PBwBg4ciEajISwszLpuWloagYGBxMbG4uzsTGhoaI3q8VrIqe8rKciGk9urrlPLDWVCCAfTuXNn+vfvz6effgrAoUOH+OWXXxg3bhwAJpOJl156iR49etCiRQs8PDxYsWIFaWlpNdr+vn37CAkJsSZpgJiYmIvKffnllwwYMIDAwEA8PDx4/vnna7yP6vvq2bOnNUkDDBgwALPZTGpqqnVet27d0Ol01umgoCCys7NrtI+8vDxOnTrFgAEDbOYPGDCAffv2AZbT68nJyURERDBx4kR++ukna7k//elPFBcX0759ex5++GEWL15MeXnDXha1a4t67ty5zJ07l2PHjgGWyp8+fbpj3AF4dD38ezj4tiPqDz8DlhvKlFLW00tCiOvA30/Vfh1dtZujOg+3bENzQbtoUsq1xVXNuHHjmDBhAu+//z7z58+nQ4cO3HzzzQC89tprvP3228yZM4cePXrg7u7OpEmTKC0trbf9b9y4kTFjxjBz5kyGDh2Kt7c3SUlJvPHGG/W2j+qcnZ1tpjUaDWazud62f8MNN3D06FF+/PFHVq1axX333UdsbCxff/01ISEhpKamsmrVKlauXMnjjz9uPaNxYVz1xa4t6jZt2jB79my2bdvG1q1bufXWW7n77rvZs2ePPcOyCOoJGh04u9HD3wmdVkN2vpGMXOlJS4jrit699oOuWhtI52SZV/369JW2Wwf33XcfWq2WL774gs8++4y//e1v1gbFhg0buPvuu/nrX/9Kz549ad++PQcOHKjxtrt06UJ6ejoZGRnWeZs2bbIp89tvvxEWFsZzzz1Hnz59CA8P5/jx47ZfV6/HZDJddV87d+6ksLDq+v2GDRvQarVERETUOOYr8fLyIjg4+KIuNjds2EDXrl1tyo0aNYp58+bx5Zdf8s0333Du3DkAXF1dGT58OO+88w5r165l48aNpKTU34HXhezaoh4+fLjN9Msvv8zcuXPZtGkT3bp1s1NUFVy8YeoxcPHCFegc6MmeU3nsSMsh2Mf1amsLIUSj8fDwYNSoUUybNo28vDzGjh1rXRYeHs7XX3/Nb7/9hq+vL2+++SZZWVk2SelKYmNj6d
SpE/Hx8bz22mvk5eXx3HPP2ZQJDw8nLS2NpKQkoqOj+eGHH1i8eLFNmbZt23L06FGSk5Np06YNnp6eFz2WNWbMGGbMmEF8fDyJiYmcPn2aCRMmcP/991uvT9eHp59+mhkzZtChQweioqKYP38+ycnJfP755wC8+eabBAUF0atXL7RaLYsWLSIwMBAfHx8WLFiAyWSiX79+uLm58d///hdXV1eb69j1zWGuUZtMJpKSkigsLLzk9Q8Ao9FIXl6edcjPb+DXerp4WUelgw4hhCMbN24c58+fZ+jQoTbXk59//nluuOEGhg4dyuDBgwkMDGTEiBE13q5Wq2Xx4sUUFxfTt29fHnroIV5++WWbMn/4wx948sknGT9+PFFRUfz222+88MILNmXuuece4uLiuOWWW/D397/kI2Jubm6sWLGCc+fOER0dzb333suQIUN47733alcZVzFx4kQmT57MU089RY8ePVi+fDlLly4lPDwcsNzB/uqrr9KnTx+io6M5duwYy5YtQ6vV4uPjw7x58xgwYACRkZGsWrWK//3vf/j5+dVrjNVplFJ27W0iJSWFmJgYSkpK8PDw4IsvvuCOO+64ZNnExERmzpx50fz09HTatGnTcEGayli0I5Onv95FdFtfFj3av+H2JYRodCUlJRw9epR27drh4uJi73BEM3Gl39WJEycICQmpUf6ye4s6IiKC5ORkfv/9dx577DHi4+PZu3fvJctOmzaN3Nxc63C5cvWmrBg+HQazQ+kdYLnek3IylzJT/d20IIQQQlyJ3Z+j1uv1dOzYEYDevXuzZcsW3n77bT766KOLyhoMBptrGnl5eQ0bnLMr5GdAWRFti/bg6eJEfkk5qZn5dG998cP6QgghRH2ze4v6QmazGaPRaO8wqoRZTnNr0zfSs7InLXnvtxBCiEZi10Q9bdo01q9fz7Fjx0hJSWHatGmsXbuWMWPG2DMsW5UddBzfKB10CCGEaHR2PfWdnZ3NAw88QEZGBt7e3kRGRrJixQpuu+02e4Zlq6JFzant9O5neSxLErUQQojGYtdE/a9//cueu6+ZFu3BvRUUZnOD01EADmUXkFtchrdrw7yFRghhH/X5dish6uv3ZPebyRyeRgNhMbD3O7yztxDS4gbSzxWz60QON4X72zs6IUQ90Ov1aLVaTp06hb+/P3q9Xl4VLOpMKUVpaSmnT59Gq9Wi1+uvaXuSqGsitD/s/Q7SNhIVMoT0c8Ukp0miFqK50Gq1tGvXjoyMDE6dqsO7vYW4BDc3N0JDQ9Fqr+12MEnUNRFWcUNZ+mZ6DfTkfzvlOrUQzY1eryc0NJTy8vKrvpNaiKvR6XQ4OTnVy5kZSdQ1EdAd9J5gzONGd8uL6aUnLSGaH41Gg7Ozc4P1giREXTjcc9QOSauDEEvH4OElKTjrNJwtLOXE+WI7ByaEEKK5k0RdUxWnv51PbKJLkKWzjh1y+lsIIUQDk0RdU6EVz1Of3EGvyhefpOXYLRwhhBDXB0nUNdW6Nzz4I4zfQlSoDwA7pMtLIYQQDUxuJqspZxfrW8qiQnwB2HMqj9JyM3onOd4RQgjRMCTD1EFbPzd83JwpLTezL6OBe/ASQghxXZNEXRv5WbDsaTQLR0tPWkIIIRqFJOracDLA5nlw4Ef6B5QDkqiFEEI0LLlGXRuuPjDkBWjRns4Ewy/nJVELIYRoUJKoa+umpwCILCwF9nD0TCE5RaX4uF3bS9eFEEKIS5FT33Xk666nXUt3QE5/CyGEaDiSqGtLKTj2K6x7jRuDLSckdsiLT4QQQjQQSdS1pdHAd+NhzT8Y4n4MkBa1EEKIhiOJui4qXnzSw7wXgJ0nLD1pCSGEEPVNEnVdhFo66PA/tw29k5acojKOnS2yc1BCCCGaI0nUdVHRotae2k5UkAsAyfLebyGEEA1AEnVdtGgP7q3AVMow31OA9KQlhBCiYUiirguNBkJvBOBGpwOA3FAmhBCiYUiirquK099tC3cCsDcjj5
Iykz0jEkII0QxJoq6rihvKXDK34e+mo8yk2Cs9aQkhhKhnkqjrKrAH6D3RGPO4M8ByI5m8+EQIIUR9k0RdV1odhPQF4BbXQ4BcpxZCCFH/JFFfizDL6e+u5XsAeURLCCFE/ZNEfS1CLTeU+Z3dBijSzxVztsBo35iEEEI0K5Kor0Xr3tDuZrS94+nc0gDI6W8hhBD1SxL1tXB2gfilcOvzdA9rBUiiFkIIUb/qlKjT09M5ceKEdXrz5s1MmjSJjz/+uN4Ca2qiQnwASdRCCCHqV50S9V/+8hfWrFkDQGZmJrfddhubN2/mueee48UXX6zXAJuEonMM1OwCLInabJaetIQQQtSPOiXq3bt307ev5dGkr776iu7du/Pbb7/x+eefs2DBgvqMz/EZC+C1jrT98a+0cc4jv6ScI2cK7B2VEEKIZqJOibqsrAyDwXLz1KpVq/jDH/4AQOfOncnIyKjxdmbNmkV0dDSenp60atWKESNGkJqaWpeQ7MfgAa26QstODGhVCsiLT4QQQtSfOiXqbt268eGHH/LLL7+wcuVK4uLiADh16hR+fn413s66detISEhg06ZNrFy5krKyMm6//XYKCwvrEpb9PLQKxm/Bq300INephRBC1B+nuqz0yiuvMHLkSF577TXi4+Pp2bMnAEuXLrWeEq+J5cuX20wvWLCAVq1asW3bNgYNGlSX0OzD2dIndVSIL3BUErUQQoh6U6dEPXjwYM6cOUNeXh6+vr7W+Y888ghubm51DiY3NxeAFi1a1Hkb9hTV2h0nytmfmU9xqQlXvc7eIQkhhGji6nTqu7i4GKPRaE3Sx48fZ86cOaSmptKqVas6BWI2m5k0aRIDBgyge/fulyxjNBrJy8uzDvn5+XXaV4NY8jjBH0Zwp3sqJrNi96lce0ckhBCiGahTor777rv57LPPAMjJyaFfv3688cYbjBgxgrlz59YpkISEBHbv3k1SUtJly8yaNQtvb2/r0LVr1zrtq6FoyooY6nkEgGS5oUwIIUQ9qFOi3r59OzfddBMAX3/9NQEBARw/fpzPPvuMd955p9bbGz9+PN9//z1r1qyhTZs2ly03bdo0cnNzrcPevXvrEn7DCL0RgCi1D5AbyoQQQtSPOl2jLioqwtPTE4CffvqJP/7xj2i1Wm688UaOHz9e4+0opZgwYQKLFy9m7dq1tGvX7orlDQaD9bEwgLy8vLqE3zAqOugIzN+LgVJ2pElPWkIIIa5dnVrUHTt2ZMmSJaSnp7NixQpuv/12ALKzs/Hy8qrxdhISEvjvf//LF198gaenJ5mZmWRmZlJcXFyXsOzLrwO4+6M1lxKpPcKp3BKy80rsHZUQQogmrk6Jevr06UyZMoW2bdvSt29fYmIs/TL/9NNP9OrVq8bbmTt3Lrm5uQwePJigoCDr8OWXX9YlLPvSaCDUUg/DPI8BsENOfwshhLhGdTr1fe+99zJw4EAyMjKsz1ADDBkyhJEjR9Z4O0o1s3dih/WHfUvp73wAiCM5PYeh3QLtHZUQQogmrE6JGiAwMJDAwEBrL1pt2rSp1ctOmqWKFnX7kt1oMcud30IIIa5ZnU59m81mXnzxRby9vQkLCyMsLAwfHx9eeuklzGZzfcfYdAT2AL0n+vICOmvS2HUiB5P0pCWEEOIa1KlF/dxzz/Gvf/2L2bNnM2DAAAB+/fVXEhMTKSkp4eWXX67XIJsMrQ5C+sLh1QxwPsDe0rYcyi4gItDT3pEJIYRoouqUqP/973/zySefWHvNAoiMjKR169Y8/vjj12+iBgiLgcOrudXtMPNKITn9vCRqIYQQdVanU9/nzp2jc+fOF83v3Lkz586du+agmrSK69Q9THsBJV1eCiGEuCZ1StQ9e/bkvffeu2j+e++9R2Rk5DUH1aS17g1aZ9zKcwnknLyhTAghxDWp06nvV199lTvvvJNVq1ZZn6HeuHEj6enpLFu2rF4DbHKcXWHcT5x2CSPztU1kZ+VTaCzH3VDnG+
yFEEJcx+rUor755ps5cOAAI0eOJCcnh5ycHP74xz+yZ88e/vOf/9R3jE1P6xsI8PMjyNsFs4JdJ6QnLSGEEHVT52ZecHDwRTeN7dy5k3/96198/PHH1xxYc9Ar1IeMlEyS03OI6eBn73CEEEI0QXVqUYurUApWPEdi5nj8ySE5XTroEEIIUTeSqBuCRgNH1tIqfy99tKlyQ5kQQog6kzucGspNkyktK2frIsXpPCMZucUEebvaOyohhBBNTK0S9R//+McrLs/JybmWWJqX7vegB/zX/8LpjDx2pOUQ1EMStRBCiNqpVaL29va+6vIHHnjgmgJqbqJCfdibkUdyeg539AiydzhCCCGamFol6vnz5zdUHM3TqWT+bFzCLo0fyWkt7B2NEEKIJkhuJmtIv39E5P63GKrbSsrJXMpN13HPYkIIIepEEnVDCrO8tS1Gl0pxmYnUrHw7BySEEKKpkUTdkEL7AxCpOYSeMnlMSwghRK1Jom5Ifh3A3R89ZfTQHCFZetISQghRS5KoG5JGY+32sq+8+EQIIUQdSKJuaGGW09/R2v0cOl1AXkmZnQMSQgjRlEiibmihNwIQrTuARpnZlS49aQkhhKg5SdQNLaAH6D3wpIgITbp00CGEEKJWJFE3NJ0ThPQFLKe/5Tq1EEKI2pBE3RgqHtOqvKFMKWXngIQQQjQVkqgbQ8WLT6K1+zlTYOTE+WI7BySEEKKpkETdGFr3Bq0zAZocQjXZcvpbCCFEjUmibgzOrnDD/fzS6q+UKSdJ1EIIIWqsVr1niWtw11uc3n6CjLSd7EiTO7+FEELUjLSoG1FUiA8Au0/lUVouPWkJIYS4OknUjaidRzl3uOzGpTyP/Zl59g5HCCFEEyCJuhFp/n0XH/BPBmj3yHVqIYQQNSKJujGF3EiOS2tLl5fSk5YQQogasGuiXr9+PcOHDyc4OBiNRsOSJUvsGU7Di5vFjpFr+c48UFrUQgghasSuibqwsJCePXvy/vvv2zOMxqNzJqqNDwBHzhSSWyQ9aQkhhLgyuz6eNWzYMIYNG2bPEBqdr7ue9i0MZJ7LJflEDjd38rd3SEIIIRyYXKNubL+9yw8lD/CY01J5nloIIcRVNakXnhiNRoxGo3U6Pz/fjtHUkcELV3MhfbX7mSvXqYUQQlxFk2pRz5o1C29vb+vQtWtXe4dUe2GWnrSiNIfZm3ZaetISQghxRU0qUU+bNo3c3FzrsHfvXnuHVHt+HVHu/hg0ZYSUpHL8bJG9IxJCCOHAmlSiNhgMeHl5WQdPT097h1R7Gg2a0BuBqv6phRBCiMuxa6IuKCggOTmZ5ORkAI4ePUpycjJpaWn2DKvhhVpOf0dr90uiFkIIcUV2TdRbt26lV69e9OrVC4DJkyfTq1cvpk+fbs+wGl5YDAB9tAfYmXbWzsEIIYRwZHa963vw4MHX581UAT0wO7vjVVaIKWMvxvIBGJx09o5KCCGEA2pS16ibDZ0TmtB+AESxlz2npCctIYQQlyaJ2k40Fdep+2pTpYMOIYQQlyWJ2l4qrlNHa/eTLG8oE0IIcRmSqO2ldW/MWmcCNDlkp+23dzRCCCEclCRqe3F2xRzUC5PS4JV3gLMFxquvI4QQ4rojidqOnP74ISM8v+AnczQ7T+TYOxwhhBAOSBK1Pfl1oFNoawC5oUwIIcQlSaK2s6hQHwB2yBvKhBBCXIIkajsbkruYb/Qz8E1fidl8Hb78RQghxBVJorazgLI0emsPElWewpEzhfYORwghhIOx6ytEBeiixvDuAW8WZrfFKz2Hjq087B2SEEIIByItantr05u8iD9xipYkp8uLT4QQQtiSRO0AokJ8AaTLSyGEEBeRU98OoLdXDuN0yzib5UNJWX9cnKUnLSGEEBbSonYAAWc384LzfxmtXcXuk7n2DkcIIYQDkUTtADRhlp60ojSH2XU8287RCCGEcCSSqB2BX0eKnFtg0JRx/uDv9o5GCCGEA5FE7Qg0GooDowFwy9
xs52CEEEI4EknUDsI9fCAAEcbdZOeX2DkaIYQQjkIStYNw6WBJ1H20B9h+7KydoxFCCOEoJFE7isBIjFpXvDRFfJD0P8Z/sZ1fDp6W938LIcR1Tp6jdhQ6J8qDozGcWM9s3Qf8vvdnlu4JY757J3pGD+De6Ha09nG1d5RCCCEamSRqB+Le8244sZ6u2uN01R4HwGzU0H31v5jz81FuCvcnISSNG0I8cA7tC24t7ByxEEKIhiaJ2pFEPwSh/SEjGTJ3Y8rcTW5uDpGuwWw6co71B06TcOw1nLX7WdJuOl3jHqFTgCecPQxpmyCwO/h3BieDvb+JEEKIeiKJ2tEEdLUMgA5oASQBx84UsmhbOid/D+FgeT5z97uRum89USE+vNByHb33vWJZX+sELTtBQDcI6G5J3gE9wDPAXt9ICCHENZBE3US0benO00M7Ux77FesOnCZsSzqH92eTnJ7Dv08WUu7clR5O6biZ8iF7r2VIWVS1AXf/asm7BwRFQavOdvs+QgghakYSdRPjpNMypEsAQ7oEcDrfyLfbT/DlFndGnekPRkUQ57jVN4sRQeeIdDqB4ew+OHsICk/DkbWWASBsADy4rGrD2xaATxiExoCzix2+mRBCiEuRRN2E+Xsa+L+bO/DIoPZsPX6eL7ek88MuJz4/78fn58FJq2FIl1aMvqUlA71O43R6D2Tuhqw9ENKvakPGAvjfJEDB04erEvXx38BUCsG9wMXbHl9RCCGue5KomwGNRkN02xZEt23BjOFd+X5XBklb0tmZnsOKPVms2JNFgJeBP/WO4b4b7yPUz812A8Z86DIcCrLBvWXV/F/egEOrAI3lunfr3tD6BmjTB1p1Ayd9o35PIYS4HmmUUk32jRonTpwgJCSE9PR02rRpY+9wHE5qZj5fbkln8Y4TnC8qs86Pae/Hn/uGMLRb4JX7vv7+SUuizkm7eJnOAEGRFcm7jyWBt2gPGk0DfJPrgFKQewKy91XcY7AP8jPAt63lIKllJ/CPAN8we0faeMpLITcdNFpw9QGDN2ib4TualIKSXCg6C4VnoOiM5VJV4ZmqeeZy8Oto+Q34R1jGneW9Ck1ZbfKXJOrrgLHcxMq9WXy5JZ1fD52h8l/cy8WJEb1ac1+fELq3vsKp7YLTcGo7nNgKJ7dZhpKci8u5+MCwV6HnKMu0UpK4LydrLxz7pSopZ+8DY96V12ndGx7+uWp64weWMyCd4sDFq2HjbSjGAjh/FM4drfg8UjWeewKUuVphjeV7jvgQOt9hmXViG+z4zHKDZPRDVUWPb7QkMlcfy+/S4NV4SV4py/+PwrOWhNsyvOpMVdom2DzPkmhvmVZV/iV/MJdddpMX08CIDyDqL5bJgtNw/hj4d5LLVE1EbfKXnPq+DhicdNwVGcxdkcGcOF/E19tOsGjrCU7mFPPZxuN8tvE43YK9+HN0CH+Iao23q7PtBjz8odNQywCgFOrsYUzpWzGd2Ib21DacsnejKcnhWLELZ4+fw1hmxuPYcjpu/ycnguPY1mkSxjITJeVmSspMGC/8LDNjLDdRUvHp7aonItCDiEAvIgI8ae/vjrOuibamkhdCxk4YMBG8gi3z9n8Pa162LVf5aF2rLpbBM9jyx/fMAcsNgYE9qsqaymDlC5aW1pN7qxL1js8tz+G37GRJEH7hln3a+4Ap96Tlnge9e1WSLTfCrDbAFdoKTq6W2MuKLOVKcm0vuWTvtdwIGX67baL+7z1QVlg1rdFaEpiLT1Xyrv7p6gvhQ6uehDAWWFqzbi3A4GmZpxSkbazW6j1z6fGis5Z/l0p/WgDdRlrG8zNh99cQcmNVotZoLIncWADufuDW0jLt1tIy7e5vKXfmAJxOtQwlOeDVumofh1bCkseg7U0w9vuq+Ts+B59Qy/sV3Fva/3dQU0pB8XnL5biCLMtn0VnL79wzyPKb9gxqugeotSSJ+jrTxteNSbGdmHBrOBsOneHLrems3JPFnlN5vPDdHv7xwz
76tfcDqEigFyfVymRqVp7AYGAwzpTTWZPG4SXlFLERgKedVhDpdJJt+48wbXcKAM6Us0ifyF5zW5JVB3aaO3BQtcF8idfOr9qXZR131mno4O9BpwBPIgI9iaj4bO3jilZr5z8+5UY4c7CiZbwHyoph2CtVyze+D1kp0P7mqkTdJhoi7qxKyq26WlpZNb3uX1oIvf4K549XbRPgwHLYt9S2rN7Dsu3KU+gtwy2fLdpfdIe/Uoq84nIy8orJyCnBWG7C08UZLxdnPF2cKgZn9E7V/r3MZsg/ZWkJnztS1UK+8TEIvdFS5vhv8O1DlqcNKhO1k8GSbMoKwbedJZ4WFZ++7SzjHgGW5FJuhOKcigRV7fsG9oDB0yyXCCqZyi3JqSTHsk55saVlXnzeMpy/TJ16ta5K1EfWwJd/hTZ94aGVVWU+GwEm45X+ZaoYvMDNz3KQUCmoJ9z+suU7VvfErpr/2ytlaakbqiUpU6klcflXe+TSmA/fPV417eoLLSMsrW7/zlXjXm0a/5JC5m44exACI8Gvg2Ve+mb48ZmK5JxdszMMeg+YmGxpTAAcWGE5uG070PI4ajMhp74F5wtLWbzjJF9uSSc1K79O2zA4aXFx1uHirMXgZPn00RrpxiFK9T5kuYVjcNbRsewgTx59xGbdMp0rZ726ct43kny/SApaRpFW5kNqdiGpmXkcyCqgwFh+yf2663V0qpa4Kz/9PBrg7Wxmk+WPgPV09V7LKeyzh0CZqsrpDPD3U6CrOA7+9S3Lqcmo0bat4oaw/wfLH7wzBy0tsHNHbGOr/nXQkusSzEavO/hcfw8ZuSVk5RahL83lPLYtFWfKaa05TVtNFqGaLDrosmmvyyZUk0WQOQs9F/9R3dx5KicjHsDLxZmAooO02/oS5uAbULe9iIfeyXKAVVpoaWU3pLKSqqR94Wfx+arxGx+D4CjLOilfw5LHLQdXY6q9j+DTYZbWsnv1Vm+1z+rj9nhDoNkE2or7TvJOWZ7mOJNqOaC73JkLZ7eqeyBadoJe99fuBUllxVXJtSALCrNtW8IF2ZYzDeO3VR0QfBUPe5dA3Ctw46OWeemb4V+32W7bxcdysObRynJ2oyTXclYiLwOMuZaDoBfOVH3nRQ/Cnm9h6D8hJsEy7+R2WDgavIIsZ6k8A6vGq38avBr1jEOTu0b9/vvv89prr5GZmUnPnj1599136du371XXk0Rdv5RS7DqRy96MPPQ6S+KtTMAGZy0uTraJ2FCx3OCkRVPTH3hJLhxZV3Wt+9QOKC24uJxGa0l4TnqUTs/JsVs4cLaU/Zn5tE95m9CczXxovJ2l5ZYWW1tNBhOdFlOqnCjFGZ2zAU8Pd7w9PPD18sDP25OWPp4YDK6WP6A6veWzw61VN+XknrT80fYIqLqmmLUXNr5neaTtdKqldXYpBm/LG+UqW8e97m/U59GVUuQWl5GRW0JGbrHlM6eE7Jx8zOeO4Jp7BJ+iY4Spk3TQnKKD5hRemiIA3ikfwZvl9wEQzBl+c5nIaXyI9/kMV4Mz+SVl/Cf3QQK4fPerZUrHCdWSNBXA8YrhV3N3UlXoJctrNOBhcLK21L1cnPFytbTWK6c9XZzwcnXG29UZH1dnvN2c8XXT4+PmjKuzrua/uWthKgOd89XLObqy4qqDt9OpcHp/xSWVwxe3XCdsr2rlblsAh3+GHn+yPBkClsT30/NVifhq91ZUevpw1f+rda/B4dXQ528QafntUZJrOfPi0ari/6D/lQ92SgstMVQ/O7HpQzi+AaLHQfvBlnn7/mc5O3I1zu4Vibvi1Ppdb1UdROakg0ZDqYs/GifnerkM16QS9ZdffskDDzzAhx9+SL9+/ZgzZw6LFi0iNTWVVq1aXXFdSdTNgNlk+QNystqNall7bK/xAUw/f9HRuCnuFY60G8P+zHyKDqxj1J5Ha737w/HbCQ1tZ/mPt+xp2PwxDHoabn3eUuDENvjk1qoVnFwsLY9W3aqScqsuDXodWC
lFTpElCWfmFXMqp4TM3BJO5RaTmWsZz8gtobjs0i3nC7Vw1xPoaaCzZzHd9JkYWrTBLSiCQG8X2uXvIGjJvZbTouM3V630aRxk7ET5tqXMuy0lHqHku4WQ69KGs/rWZGv9ySu1nDbPLykjv6ScvAs+80vKyCsup9RkvnxwNaTXafF2syRwXze9ddzHzRmfimTu42r59HZ1xtddj4+rM276RkrwTYWpzHKZ4kzFte+zh+Du96taqN88ZHnDYWwiDHzSMu/EVvhkiO12dIaqlq91qJh2rxgPjmrQswxlJjNFpSZKykwUl5ooLjNRVGqitCgP7dlDaAoy0RVk4FSUiaEoC5eSbNyN2XiUnsbVZHsm0YSW4d7fUlAGxWUmXip9nTjNRmaUxdPl7in8ue+lD0Bro0kl6n79+hEdHc17770HgNlsJiQkhAkTJvDss89ecV1J1M1U5anKcqPl2pup1PZ606kdljuCA7pVHU3npFtOpZUbKSst4VxeAbn5BeQWFFJYWEhRcTGqvAQ95RgoQ68pR08ZD5Q+i1HnTgd/D6Zov6B//gpOdn0Yw6BJGJy1lJcU4LZ1LiUtIij0iaDYLYQyNJSbFOUmM2Xmik+Totxstsw32y4rNynKKpddZh3bcTPlZkVRaTlZeUYycospKatZcvNz1xPo7UKQtwtB3q4240HeLgR6u1z5kTyA0iLLNdDqj4KV5FluqqqHJFdSZqqWvMvJK66WyCuTe3FVks8tLiOnqIyc4jJyikopM9X9T5azToN3RQL3dXO2jl8pyXu7OaPTaKxfXYPGpho0Gsu8qnGsBwOaynlN9eAg7Xc4sQXaDbI8jgmWlu+hVRWJ2JKMzc6elJoVxnIzpeVmSk1mjGUmSk0V0xWDsdq0tWy5bbnKMsZy23VLTWZrAi4pq0rEJRXzys11/124UkKA5jyBmvMEcA4vTRH/Md1uXf6B8xxu025jfNkEbrxzLA8OaHetNdt0EnVpaSlubm58/fXXjBgxwjo/Pj6enJwcvvvuO5vyRqMRo7HqRo6TJ0/StWtXSdSiRs4WGEnNyudAZj6pWfnsz7SMF5bWrCVqb37ueoJ8XAj0ciXYx+WiJBzgVYMk3MQppSgqNVmTdm5FAj9fVEpOUWVSt4xbknvVeH205OtDZTK3jGuqjVckfOsBge1BgFZTrbwGtBUHDxqqxqk4iNBqqg4oKg8cLAcMlvmX3Fa1/WmqLddULC83K5ukWj2RXkuSrG9aDbjpnXBx1uGm1+HqrMNFr8PVWYub3sky7azDVa+1lnN1rlruoq+c1uGqr/bppMHVWYObi6HRT33b9a7vM2fOYDKZCAiwvXEhICCA/fv3X1R+1qxZzJw5s7HCE82Mn4eB/h4G+neoevuaUooT54s5UJm4s/JJzczn8OkCzMryGlZnnRYnnQYnrRZnnQYnnQZn7YXztNayOq3GMq+ijHPFMied1jq/cjuXWr9y+wZnLQFeLgR7u9LKy9Dsk3BNaDQa3A1OuBucaO1T8xd+KKUoKTNbE3pOsW2Sz71EYs8pLuV8URml5fWb4JWqdlvXRe0kx0l410LvpMWg01o+nSyf1kFXOa5Dr7P8zg26Sy2vmjY46zDotBckUS2uzk5VybQioTrrNE33DMZlNKnHs6ZNm8bkyZOt05UtaiHqSqPRENLCjZAWbgzpIl2BNlcajcbyB13vSnAtEjxYXhhkrsjVCmWTaJVS1cYBZSlTOa0qyljWrZynKjdms27ltqvKKps8rhSYK/anqn+qqm2brdMVn9XGq9a17MFcubxivrliQ+pS+1Kg02psk6eT5cbSC5Nqc0yU9mbXRN2yZUt0Oh1ZWVk287OysggMDLyovMFgwGCouhkhL6+GdxsKIUQdGZzkTIawL7u+6kmv19O7d29Wr15tnWc2m1m9ejUxMTF2jEwIIYRwDHY/9T158mTi4+Pp06cPffv2Zc6cORQWFvLggw/aOzQhhBDC7uyeqEeNGsXp06eZPn
06mZmZREVFsXz58otuMBNCCCGuR3ZP1ADjx49n/Pjx9g5DCCGEcDhNtDsiIYQQ4vrgEC3qujJXPDORkZFh50iEEEKImqvMW5V57EqadKKufKyrJh14CCGEEI4mKyuL0NArvzvc7u/6vhbl5eXs2LGDgIAAtPXQn2p+fj5du3Zl7969eHp61kOE1wept7qTuqsbqbe6k7qrm/quN7PZTFZWFr169cLJ6cpt5iadqOtbXl4e3t7e5Obm4uXldfUVBCD1di2k7upG6q3upO7qxp71JjeTCSGEEA5MErUQQgjhwCRRV2MwGJgxY4bN+8TF1Um91Z3UXd1IvdWd1F3d2LPe5Bq1EEII4cCkRS2EEEI4MEnUQgghhAOTRC2EEEI4MEnUFd5//33atm2Li4sL/fr1Y/PmzfYOyeGtX7+e4cOHExwcjEajYcmSJfYOqUmYNWsW0dHReHp60qpVK0aMGEFqaqq9w2oS5s6dS2RkJF5eXnh5eRETE8OPP/5o77CanNmzZ6PRaJg0aZK9Q3F4iYmJaDQam6Fz586NGoMkauDLL79k8uTJzJgxg+3bt9OzZ0+GDh1Kdna2vUNzaIWFhfTs2ZP333/f3qE0KevWrSMhIYFNmzaxcuVKysrKuP322yksLLR3aA6vTZs2zJ49m23btrF161ZuvfVW7r77bvbs2WPv0JqMLVu28NFHHxEZGWnvUJqMbt26kZGRYR1+/fXXxg1ACdW3b1+VkJBgnTaZTCo4OFjNmjXLjlE1LYBavHixvcNokrKzsxWg1q1bZ+9QmiRfX1/1ySef2DuMJiE/P1+Fh4erlStXqptvvlk98cQT9g7J4c2YMUP17NnTrjFc9y3q0tJStm3bRmxsrHWeVqslNjaWjRs32jEycb3Izc0FoEWLFnaOpGkxmUwkJSVRWFhITEyMvcNpEhISErjzzjtt/t6Jqzt48CDBwcG0b9+eMWPGkJaW1qj7b9K9Z9WHM2fOYDKZCAgIsJkfEBDA/v377RSVuF6YzWYmTZrEgAED6N69u73DaRJSUlKIiYmhpKQEDw8PFi9eTNeuXe0dlsNLSkpi+/btbNmyxd6hNCn9+vVjwYIFREREkJGRwcyZM7npppvYvXt3o3Vqct0naiHsKSEhgd27dzf+Na8mLCIiguTkZHJzc/n666+Jj49n3bp1kqyvID09nSeeeIKVK1fi4uJi73CalGHDhlnHIyMj6devH2FhYXz11VeMGzeuUWK47hN1y5Yt0el01r6tK2VlZREYGGinqMT1YPz48Xz//fesX7+eNm3a2DucJkOv19OxY0cAevfuzZYtW3j77bf56KOP7ByZ49q2bRvZ2dnccMMN1nkmk4n169fz3nvvYTQa0el0doyw6fDx8aFTp04cOnSo0fZ53V+j1uv19O7dm9WrV1vnmc1mVq9eLde9RINQSjF+/HgWL17Mzz//TLt27ewdUpNmNpsxGo32DsOhDRkyhJSUFJKTk61Dnz59GDNmDMnJyZKka6GgoIDDhw8TFBTUaPu87lvUAJMnTyY+Pp4+ffrQt29f5syZQ2FhIQ8++KC9Q3NoBQUFNkeVR48eJTk5mRYtWhAaGmrHyBxbQkICX3zxBd999x2enp5kZmYC4O3tjaurq52jc2zTpk1j2LBhhIaGkp+fzxdffMHatWtZsWKFvUNzaJ6enhfdA+Hu7o6fn5/cG3EVU6ZMYfjw4YSFhXHq1ClmzJiBTqdj9OjRjRaDJGpg1KhRnD59munTp5OZmUlUVBTLly+/6AYzYWvr1q3ccsst1unJkycDEB8fz4IFC+wUleObO3cuAIMHD7aZP3/+fMaOHdv4ATUh2dnZPPDAA2RkZODt7U1kZCQrVqzgtttus3doopk6ceIEo0eP5uzZs/j7+zNw4EA2bdqEv79/o8UgvWcJIYQQDuy6v0YthBBCODJJ1EIIIYQDk0QthBBCODBJ1EIIIYQDk0QthBBCODBJ1EIIIYQDk0QthBBCODBJ1E
IIIYQDk0QthLhmGo2GJUuW2DsMIZolSdRCNHFjx45Fo9FcNMTFxdk7NCFEPZB3fQvRDMTFxTF//nybeQaDwU7RCCHqk7SohWgGDAYDgYGBNoOvry9gOS09d+5chg0bhqurK+3bt+frr7+2WT8lJYVbb70VV1dX/Pz8eOSRRygoKLAp8+mnn9KtWzcMBgNBQUGMHz/eZvmZM2cYOXIkbm5uhIeHs3TpUuuy8+fPM2bMGPz9/XF1dSU8PPyiAwshxKVJohbiOvDCCy9wzz33sHPnTsaMGcOf//xn9u3bB0BhYSFDhw7F19eXLVu2sGjRIlatWmWTiOfOnUtCQgKPPPIIKSkpLF26lI4dO9rsY+bMmdx3333s2rWLO+64gzFjxnDu3Dnr/vfu3cuPP/7Ivn37mDt3Li1btmy8ChCiKVNCiCYtPj5e6XQ65e7ubjO8/PLLSimlAPXoo4/arNOvXz/12GOPKaWU+vjjj5Wvr68qKCiwLv/hhx+UVqtVmZmZSimlgoOD1XPPPXfZGAD1/PPPW6cLCgoUoH788UellFLDhw9XDz74YP18YSGuM3KNWohm4JZbbrH2c12pRYsW1vGYmBibZTExMSQnJwOwb98+evbsibu7u3X5gAEDMJvNpKamotFoOHXqFEOGDLliDJGRkdZxd3d3vLy8yM7OBuCxxx7jnnvuYfv27dx+++2MGDGC/v371+m7CnG9kUQtRDPg7u5+0ano+uLq6lqjcs7OzjbTGo0Gs9kMwLBhwzh+/DjLli1j5cqVDBkyhISEBF5//fV6j1eI5kauUQtxHdi0adNF0126dAGgS5cu7Ny5k8LCQuvyDRs2oNVqiYiIwNPTk7Zt27J69eprisHf35/4+Hj++9//MmfOHD7++ONr2p4Q1wtpUQvRDBiNRjIzM23mOTk5WW/YWrRoEX369GHgwIF8/vnnbN68mX/9618AjBkzhhkzZhAfH09iYiKnT59mwoQJ3H///QQEBACQmJjIo48+SqtWrRg2bBj5+fls2LCBCRMm1Ci+6dOn07t3b7p164bRaOT777+3HigIIa5MErUQzcDy5csJCgqymRcREcH+/fsByx3ZSUlJPP744wQFBbFw4UK6du0KgJubGytWrOCJJ54gOjoaNzc37rnnHt58803rtuLj4ykpKeGtt95iypQptGzZknvvvbfG8en1eqZNm8axY8dwdXXlpptuIikpqR6+uRDNn0YppewdhBCi4Wg0GhYvXsyIESPsHYoQog7kGrUQQgjhwCRRCyGEEA5MrlEL0czJ1S0hmjZpUQshhBAOTBK1EEII4cAkUQshhBAOTBK1EEII4cAkUQshhBAOTBK1EEII4cAkUQshhBAOTBK1EEII4cAkUQshhBAO7P8BjwYPyAUrLWAAAAAASUVORK5CYII=\n"
},
"metadata": {}
}
],
"source": [
"from previous_chapters import plot_values\n",
"\n",
"epochs_tensor = torch.linspace(0, num_epochs, len(train_losses))\n",
"examples_seen_tensor = torch.linspace(0, examples_seen, len(train_losses))\n",
"\n",
"plot_values(epochs_tensor, examples_seen_tensor, train_losses, val_losses, label=\"loss\")"
]
},
{
"cell_type": "markdown",
"id": "aa074723-e3f7-4f7e-a267-855531a037dc",
"metadata": {
"id": "aa074723-e3f7-4f7e-a267-855531a037dc"
},
"source": [
    "- Note that during training we evaluated the accuracy on only 5 batches via the `eval_iter=5` setting; below, we calculate the accuracies on the full training, validation, and test sets"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "1D2awlEq0gZi",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "1D2awlEq0gZi",
"outputId": "d603eda1-d912-43eb-ec9c-af6a622510a0"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Training accuracy: 100.00%\n",
"Validation accuracy: 97.32%\n",
"Test accuracy: 97.33%\n"
]
}
],
"source": [
"from previous_chapters import calc_accuracy_loader\n",
"\n",
"train_accuracy = calc_accuracy_loader(train_loader, model, device)\n",
"val_accuracy = calc_accuracy_loader(val_loader, model, device)\n",
"test_accuracy = calc_accuracy_loader(test_loader, model, device)\n",
"\n",
"print(f\"Training accuracy: {train_accuracy*100:.2f}%\")\n",
"print(f\"Validation accuracy: {val_accuracy*100:.2f}%\")\n",
"print(f\"Test accuracy: {test_accuracy*100:.2f}%\")"
]
},
{
"cell_type": "markdown",
"id": "1f87f5e6-339e-4fcf-900b-6d845d3c713d",
"metadata": {
"id": "1f87f5e6-339e-4fcf-900b-6d845d3c713d"
},
"source": [
    "- As the high training, validation, and test accuracies above indicate, the LoRA finetuning was successful"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuType": "V100",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
}
},
"nbformat": 4,
"nbformat_minor": 5
}