{
"cells": [
{
"cell_type": "markdown",
"id": "c024bfa4-1a7a-4751-b5a1-827225a3478b",
"metadata": {
"id": "c024bfa4-1a7a-4751-b5a1-827225a3478b"
},
"source": [
"<font size=\"1\">\n",
"Supplementary code for \"Build a Large Language Model From Scratch\": <a href=\"https://www.manning.com/books/build-a-large-language-model-from-scratch\">https://www.manning.com/books/build-a-large-language-model-from-scratch</a> by <a href=\"https://sebastianraschka.com\">Sebastian Raschka</a><br>\n",
"Code repository: <a href=\"https://github.com/rasbt/LLMs-from-scratch\">https://github.com/rasbt/LLMs-from-scratch</a>\n",
"</font>"
]
},
{
"cell_type": "markdown",
"id": "bfabadb8-5935-45ff-b39c-db7a29012129",
"metadata": {
"id": "bfabadb8-5935-45ff-b39c-db7a29012129"
},
"source": [
"# Chapter 6: Finetuning for Text Classification"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "5b7e01c2-1c84-4f2a-bb51-2e0b74abda90",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "5b7e01c2-1c84-4f2a-bb51-2e0b74abda90",
"outputId": "9495f150-9d79-4910-d6e7-6c0d9aae4a41"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"matplotlib version: 3.8.2\n",
"numpy version: 1.26.0\n",
"tiktoken version: 0.5.1\n",
"torch version: 2.2.2\n",
"tensorflow version: 2.15.0\n",
"pandas version: 2.2.1\n"
]
}
],
"source": [
"from importlib.metadata import version\n",
"\n",
"pkgs = [\"matplotlib\",\n",
"        \"numpy\",\n",
"        \"tiktoken\",\n",
"        \"torch\",\n",
"        \"tensorflow\",  # For OpenAI's pretrained weights\n",
"        \"pandas\"       # Dataset loading\n",
"        ]\n",
"for p in pkgs:\n",
"    print(f\"{p} version: {version(p)}\")"
]
},
{
"cell_type": "markdown",
"id": "a445828a-ff10-4efa-9f60-a2e2aed4c87d",
"metadata": {},
"source": [
"<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch06_compressed/chapter-overview.webp\" width=500px>"
]
},
{
"cell_type": "markdown",
"id": "3a84cf35-b37f-4c15-8972-dfafc9fadc1c",
"metadata": {
"id": "3a84cf35-b37f-4c15-8972-dfafc9fadc1c"
},
"source": [
"## 6.1 Different categories of finetuning"
]
},
{
"cell_type": "markdown",
"id": "ede3d731-5123-4f02-accd-c670ce50a5a3",
"metadata": {
"id": "ede3d731-5123-4f02-accd-c670ce50a5a3"
},
"source": [
"- No code in this section"
]
},
{
"cell_type": "markdown",
"id": "ac45579d-d485-47dc-829e-43be7f4db57b",
"metadata": {},
"source": [
"- The most common ways to finetune language models are instruction-finetuning and classification finetuning\n",
"- Instruction-finetuning, depicted below, is the topic of the next chapter"
]
},
{
"cell_type": "markdown",
"id": "6c29ef42-46d9-43d4-8bb4-94974e1665e4",
"metadata": {},
"source": [
"<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch06_compressed/instructions.webp\" width=500px>"
]
},
{
"cell_type": "markdown",
"id": "a7f60321-95b8-46a9-97bf-1d07fda2c3dd",
"metadata": {},
"source": [
"- Classification finetuning, the topic of this chapter, is a procedure you may already be familiar with if you have a background in machine learning -- it's similar to training a convolutional network to classify handwritten digits, for example\n",
"- In classification finetuning, we have a specific number of class labels (for example, \"spam\" and \"not spam\") that the model can output\n",
"- A classification-finetuned model can only predict classes it has seen during training (for example, \"spam\" or \"not spam\"), whereas an instruction-finetuned model can usually perform many tasks\n",
"- We can think of a classification-finetuned model as a very specialized model; in practice, it is much easier to create a specialized model than a generalist model that performs well on many different tasks"
]
},
{
"cell_type": "markdown",
"id": "0b37a0c4-0bb1-4061-b1fe-eaa4416d52c3",
"metadata": {},
"source": [
"<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch06_compressed/spam-non-spam.webp\" width=500px>"
]
},
{
"cell_type": "markdown",
"id": "8c7017a2-32aa-4002-a2f3-12aac293ccdf",
"metadata": {
"id": "8c7017a2-32aa-4002-a2f3-12aac293ccdf"
},
"source": [
"## 6.2 Preparing the dataset"
]
},
{
"cell_type": "markdown",
"id": "5f628975-d2e8-4f7f-ab38-92bb868b7067",
"metadata": {},
"source": [
"<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch06_compressed/overview-1.webp\" width=500px>"
]
},
{
"cell_type": "markdown",
"id": "9fbd459f-63fa-4d8c-8499-e23103156c7d",
"metadata": {
"id": "9fbd459f-63fa-4d8c-8499-e23103156c7d"
},
"source": [
"- This section prepares the dataset we use for classification finetuning\n",
"- We use a dataset consisting of spam and non-spam text messages to finetune the LLM to classify them\n",
"- First, we download and unzip the dataset"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "def7c09b-af9c-4216-90ce-5e67aed1065c",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "def7c09b-af9c-4216-90ce-5e67aed1065c",
"outputId": "424e4423-f623-443c-ab9e-656f9e867559"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"sms_spam_collection/SMSSpamCollection.tsv already exists. Skipping download and extraction.\n"
]
}
],
"source": [
"import urllib.request\n",
"import zipfile\n",
"import os\n",
"from pathlib import Path\n",
"\n",
"url = \"https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip\"\n",
"zip_path = \"sms_spam_collection.zip\"\n",
"extracted_path = \"sms_spam_collection\"\n",
"data_file_path = Path(extracted_path) / \"SMSSpamCollection.tsv\"\n",
"\n",
"def download_and_unzip(url, zip_path, extracted_path, data_file_path):\n",
"    if data_file_path.exists():\n",
"        print(f\"{data_file_path} already exists. Skipping download and extraction.\")\n",
"        return\n",
"\n",
"    # Downloading the file\n",
"    with urllib.request.urlopen(url) as response:\n",
"        with open(zip_path, \"wb\") as out_file:\n",
"            out_file.write(response.read())\n",
"\n",
"    # Unzipping the file\n",
"    with zipfile.ZipFile(zip_path, \"r\") as zip_ref:\n",
"        zip_ref.extractall(extracted_path)\n",
"\n",
"    # Add .tsv file extension\n",
"    original_file_path = Path(extracted_path) / \"SMSSpamCollection\"\n",
"    os.rename(original_file_path, data_file_path)\n",
"    print(f\"File downloaded and saved as {data_file_path}\")\n",
"\n",
"download_and_unzip(url, zip_path, extracted_path, data_file_path)"
]
},
{
"cell_type": "markdown",
"id": "6aac2d19-06d0-4005-916b-0bd4b1ee50d1",
"metadata": {
"id": "6aac2d19-06d0-4005-916b-0bd4b1ee50d1"
},
"source": [
"- The dataset is saved as a tab-separated text file, which we can load into a pandas DataFrame"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "da0ed4da-ac31-4e4d-8bdd-2153be4656a4",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 423
},
"id": "da0ed4da-ac31-4e4d-8bdd-2153be4656a4",
"outputId": "a16c5cde-d341-4887-a93f-baa9bec542ab"
},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Label</th>\n",
" <th>Text</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>ham</td>\n",
" <td>Go until jurong point, crazy.. Available only ...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>ham</td>\n",
" <td>Ok lar... Joking wif u oni...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>spam</td>\n",
" <td>Free entry in 2 a wkly comp to win FA Cup fina...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>ham</td>\n",
" <td>U dun say so early hor... U c already then say...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>ham</td>\n",
" <td>Nah I don't think he goes to usf, he lives aro...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5567</th>\n",
" <td>spam</td>\n",
" <td>This is the 2nd time we have tried 2 contact u...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5568</th>\n",
" <td>ham</td>\n",
" <td>Will ü b going to esplanade fr home?</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5569</th>\n",
" <td>ham</td>\n",
" <td>Pity, * was in mood for that. So...any other s...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5570</th>\n",
" <td>ham</td>\n",
" <td>The guy did some bitching but I acted like i'd...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5571</th>\n",
" <td>ham</td>\n",
" <td>Rofl. Its true to its name</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>5572 rows × 2 columns</p>\n",
"</div>"
],
"text/plain": [
" Label Text\n",
"0 ham Go until jurong point, crazy.. Available only ...\n",
"1 ham Ok lar... Joking wif u oni...\n",
"2 spam Free entry in 2 a wkly comp to win FA Cup fina...\n",
"3 ham U dun say so early hor... U c already then say...\n",
"4 ham Nah I don't think he goes to usf, he lives aro...\n",
"... ... ...\n",
"5567 spam This is the 2nd time we have tried 2 contact u...\n",
"5568 ham Will ü b going to esplanade fr home?\n",
"5569 ham Pity, * was in mood for that. So...any other s...\n",
"5570 ham The guy did some bitching but I acted like i'd...\n",
"5571 ham Rofl. Its true to its name\n",
"\n",
"[5572 rows x 2 columns]"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pandas as pd\n",
"\n",
"df = pd.read_csv(data_file_path, sep=\"\\t\", header=None, names=[\"Label\", \"Text\"])\n",
"df"
]
},
{
"cell_type": "markdown",
"id": "e7b6e631-4f0b-4aab-82b9-8898e6663109",
"metadata": {
"id": "e7b6e631-4f0b-4aab-82b9-8898e6663109"
},
"source": [
"- When we check the class distribution, we see that the data contains \"ham\" (i.e., \"not spam\") much more frequently than \"spam\""
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "495a5280-9d7c-41d4-9719-64ab99056d4c",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "495a5280-9d7c-41d4-9719-64ab99056d4c",
"outputId": "761e0482-43ba-4f46-f4b7-6774dae51b38"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Label\n",
"ham 4825\n",
"spam 747\n",
"Name: count, dtype: int64\n"
]
}
],
"source": [
"print(df[\"Label\"].value_counts())"
]
},
{
"cell_type": "markdown",
"id": "f773f054-0bdc-4aad-bbf6-397621bf63db",
"metadata": {
"id": "f773f054-0bdc-4aad-bbf6-397621bf63db"
},
"source": [
"- For simplicity, and because we prefer a small dataset for educational purposes anyway (it will make it possible to finetune the LLM faster), we subsample (undersample) the dataset so that it contains 747 instances from each class\n",
"- (Besides undersampling, there are several other ways to deal with class imbalance, but they are outside the scope of a book on LLMs; you can find examples and more information in the [`imbalanced-learn` user guide](https://imbalanced-learn.org/stable/user_guide.html))"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "7be4a0a2-9704-4a96-b38f-240339818688",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "7be4a0a2-9704-4a96-b38f-240339818688",
"outputId": "396dc415-cb71-4a88-e85d-d88201c6d73f"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Label\n",
"ham 747\n",
"spam 747\n",
"Name: count, dtype: int64\n"
]
}
],
"source": [
"def create_balanced_dataset(df):\n",
"\n",
"    # Count the instances of \"spam\"\n",
"    num_spam = df[df[\"Label\"] == \"spam\"].shape[0]\n",
"\n",
"    # Randomly sample \"ham\" instances to match the number of \"spam\" instances\n",
"    ham_subset = df[df[\"Label\"] == \"ham\"].sample(num_spam, random_state=123)\n",
"\n",
"    # Combine the \"ham\" subset with all \"spam\" instances\n",
"    balanced_df = pd.concat([ham_subset, df[df[\"Label\"] == \"spam\"]])\n",
"\n",
"    return balanced_df\n",
"\n",
"balanced_df = create_balanced_dataset(df)\n",
"print(balanced_df[\"Label\"].value_counts())"
]
},
{
"cell_type": "markdown",
"id": "d3fd2f5a-06d8-4d30-a2e3-230b86c559d6",
"metadata": {
"id": "d3fd2f5a-06d8-4d30-a2e3-230b86c559d6"
},
"source": [
"- Next, we change the string class labels \"ham\" and \"spam\" into integer class labels 0 and 1:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "c1b10c3d-5d57-42d0-8de8-cf80a06f5ffd",
"metadata": {
"id": "c1b10c3d-5d57-42d0-8de8-cf80a06f5ffd"
},
"outputs": [],
"source": [
"balanced_df[\"Label\"] = balanced_df[\"Label\"].map({\"ham\": 0, \"spam\": 1})"
]
},
{
"cell_type": "markdown",
"id": "5715e685-35b4-4b45-a86c-8a8694de9d6f",
"metadata": {
"id": "5715e685-35b4-4b45-a86c-8a8694de9d6f"
},
"source": [
"- Let's now define a function that randomly divides the dataset into training, validation, and test subsets"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "uQl0Psdmx15D",
"metadata": {
"id": "uQl0Psdmx15D"
},
"outputs": [],
"source": [
"def random_split(df, train_frac, validation_frac):\n",
"    # Shuffle the entire DataFrame\n",
"    df = df.sample(frac=1, random_state=123).reset_index(drop=True)\n",
"\n",
"    # Calculate split indices\n",
"    train_end = int(len(df) * train_frac)\n",
"    validation_end = train_end + int(len(df) * validation_frac)\n",
"\n",
"    # Split the DataFrame\n",
"    train_df = df[:train_end]\n",
"    validation_df = df[train_end:validation_end]\n",
"    test_df = df[validation_end:]\n",
"\n",
"    return train_df, validation_df, test_df\n",
"\n",
"train_df, validation_df, test_df = random_split(balanced_df, 0.7, 0.1)\n",
"# Test size is implied to be 0.2 as the remainder\n",
"\n",
"train_df.to_csv(\"train.csv\", index=None)\n",
"validation_df.to_csv(\"validation.csv\", index=None)\n",
"test_df.to_csv(\"test.csv\", index=None)"
]
},
{
"cell_type": "markdown",
"id": "a8d7a0c5-1d5f-458a-b685-3f49520b0094",
"metadata": {},
"source": [
"## 6.3 Creating data loaders"
]
},
{
"cell_type": "markdown",
"id": "7126108a-75e7-4862-b0fb-cbf59a18bb6c",
"metadata": {
"id": "7126108a-75e7-4862-b0fb-cbf59a18bb6c"
},
"source": [
"- Note that the text messages have different lengths; if we want to combine multiple training examples in a batch, we have to either\n",
"  1. truncate all messages to the length of the shortest message in the dataset or batch\n",
"  2. pad all messages to the length of the longest message in the dataset or batch\n",
"\n",
"- We choose option 2 and pad all messages to the length of the longest message in the training dataset\n",
"- For that, we use `<|endoftext|>` as a padding token, as discussed in chapter 2"
]
},
{
"cell_type": "markdown",
"id": "0829f33f-1428-4f22-9886-7fee633b3666",
"metadata": {},
"source": [
"<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch06_compressed/pad-input-sequences.webp?123\" width=500px>"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "74c3c463-8763-4cc0-9320-41c7eaad8ab7",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "74c3c463-8763-4cc0-9320-41c7eaad8ab7",
"outputId": "b5b48439-32c8-4b37-cca2-c9dc8fa86563"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[50256]\n"
]
}
],
"source": [
"import tiktoken\n",
"\n",
"tokenizer = tiktoken.get_encoding(\"gpt2\")\n",
"print(tokenizer.encode(\"<|endoftext|>\", allowed_special={\"<|endoftext|>\"}))"
]
},
{
"cell_type": "markdown",
"id": "04f582ff-68bf-450e-bd87-5fb61afe431c",
"metadata": {
"id": "04f582ff-68bf-450e-bd87-5fb61afe431c"
},
"source": [
"- The `SpamDataset` class below identifies the longest sequence in the training dataset and adds the padding token to the others to match that sequence length"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "d7791b52-af18-4ac4-afa9-b921068e383e",
"metadata": {
"id": "d7791b52-af18-4ac4-afa9-b921068e383e"
},
"outputs": [],
"source": [
"import torch\n",
"from torch.utils.data import Dataset\n",
"\n",
"\n",
"class SpamDataset(Dataset):\n",
"    def __init__(self, csv_file, tokenizer, max_length=None, pad_token_id=50256):\n",
"        self.data = pd.read_csv(csv_file)\n",
"\n",
"        # Pre-tokenize texts\n",
"        self.encoded_texts = [\n",
"            tokenizer.encode(text) for text in self.data[\"Text\"]\n",
"        ]\n",
"\n",
"        if max_length is None:\n",
"            self.max_length = self._longest_encoded_length()\n",
"        else:\n",
"            self.max_length = max_length\n",
"            # Truncate sequences if they are longer than max_length\n",
"            self.encoded_texts = [\n",
"                encoded_text[:self.max_length]\n",
"                for encoded_text in self.encoded_texts\n",
"            ]\n",
"\n",
"        # Pad sequences to the longest sequence\n",
"        self.encoded_texts = [\n",
"            encoded_text + [pad_token_id] * (self.max_length - len(encoded_text))\n",
"            for encoded_text in self.encoded_texts\n",
"        ]\n",
"\n",
"    def __getitem__(self, index):\n",
"        encoded = self.encoded_texts[index]\n",
"        label = self.data.iloc[index][\"Label\"]\n",
"        return (\n",
"            torch.tensor(encoded, dtype=torch.long),\n",
"            torch.tensor(label, dtype=torch.long)\n",
"        )\n",
"\n",
"    def __len__(self):\n",
"        return len(self.data)\n",
"\n",
"    def _longest_encoded_length(self):\n",
"        max_length = 0\n",
"        for encoded_text in self.encoded_texts:\n",
"            encoded_length = len(encoded_text)\n",
"            if encoded_length > max_length:\n",
"                max_length = encoded_length\n",
"        return max_length"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "uzj85f8ou82h",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "uzj85f8ou82h",
"outputId": "d08f1cf0-c24d-445f-a3f8-793532c3716f"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"120\n"
]
}
],
"source": [
"train_dataset = SpamDataset(\n",
"    csv_file=\"train.csv\",\n",
"    max_length=None,\n",
"    tokenizer=tokenizer\n",
")\n",
"\n",
"print(train_dataset.max_length)"
]
},
{
"cell_type": "markdown",
"id": "15bdd932-97eb-4b88-9cf9-d766ea4c3a60",
"metadata": {},
"source": [
"- We also pad the validation and test sets to the length of the longest training sequence\n",
"- Note that validation and test set samples that are longer than the longest training example are truncated via `encoded_text[:self.max_length]` in the `SpamDataset` code\n",
"- This truncation is entirely optional; it would also work well if we set `max_length=None` for both the validation and test sets"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "bb0c502d-a75e-4248-8ea0-196e2b00c61e",
"metadata": {
"id": "bb0c502d-a75e-4248-8ea0-196e2b00c61e"
},
"outputs": [],
"source": [
"val_dataset = SpamDataset(\n",
"    csv_file=\"validation.csv\",\n",
"    max_length=train_dataset.max_length,\n",
"    tokenizer=tokenizer\n",
")\n",
"test_dataset = SpamDataset(\n",
"    csv_file=\"test.csv\",\n",
"    max_length=train_dataset.max_length,\n",
"    tokenizer=tokenizer\n",
")"
]
},
{
"cell_type": "markdown",
"id": "20170d89-85a0-4844-9887-832f5d23432a",
"metadata": {},
"source": [
"- Next, we use the datasets to instantiate the data loaders, which is similar to creating the data loaders in previous chapters"
]
},
{
"cell_type": "markdown",
"id": "64bcc349-205f-48f8-9655-95ff21f5e72f",
"metadata": {},
"source": [
"<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch06_compressed/batch.webp\" width=500px>"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "8681adc0-6f02-4e75-b01a-a6ab75d05542",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "8681adc0-6f02-4e75-b01a-a6ab75d05542",
"outputId": "3266c410-4fdb-4a8c-a142-7f707e2525ab"
},
"outputs": [],
"source": [
"from torch.utils.data import DataLoader\n",
"\n",
"num_workers = 0\n",
"batch_size = 8\n",
"\n",
"torch.manual_seed(123)\n",
"\n",
"train_loader = DataLoader(\n",
"    dataset=train_dataset,\n",
"    batch_size=batch_size,\n",
"    shuffle=True,\n",
"    num_workers=num_workers,\n",
"    drop_last=True,\n",
")\n",
"\n",
"val_loader = DataLoader(\n",
"    dataset=val_dataset,\n",
"    batch_size=batch_size,\n",
"    num_workers=num_workers,\n",
"    drop_last=False,\n",
")\n",
"\n",
"test_loader = DataLoader(\n",
"    dataset=test_dataset,\n",
"    batch_size=batch_size,\n",
"    num_workers=num_workers,\n",
"    drop_last=False,\n",
")"
]
},
{
"cell_type": "markdown",
"id": "ab7335db-e0bb-4e27-80c5-eea11e593a57",
"metadata": {},
"source": [
"- As a verification step, we iterate through the training data loader and check that the batches contain 8 training examples each, where each training example consists of 120 tokens"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "4dee6882-4c3a-4964-af15-fa31f86ad047",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train loader:\n",
"Input batch dimensions: torch.Size([8, 120])\n",
"Label batch dimensions torch.Size([8])\n"
]
}
],
"source": [
"print(\"Train loader:\")\n",
"for input_batch, target_batch in train_loader:\n",
"    pass\n",
"\n",
"print(\"Input batch dimensions:\", input_batch.shape)\n",
"print(\"Label batch dimensions\", target_batch.shape)"
]
},
{
"cell_type": "markdown",
"id": "5cdd7947-7039-49bf-8a5e-c0a2f4281ca1",
"metadata": {},
"source": [
"- Lastly, let's print the total number of batches in each dataset"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "IZfw-TYD2zTj",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "IZfw-TYD2zTj",
"outputId": "6934bbf2-9797-4fbe-d26b-1a246e18c2fb"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"130 training batches\n",
"19 validation batches\n",
"38 test batches\n"
]
}
],
"source": [
"print(f\"{len(train_loader)} training batches\")\n",
"print(f\"{len(val_loader)} validation batches\")\n",
"print(f\"{len(test_loader)} test batches\")"
]
},
{
"cell_type": "markdown",
"id": "d1c4f61a-5f5d-4b3b-97cf-151b617d1d6c",
"metadata": {
"id": "d1c4f61a-5f5d-4b3b-97cf-151b617d1d6c"
},
"source": [
"## 6.4 Initializing a model with pretrained weights"
]
},
{
"cell_type": "markdown",
"id": "97e1af8b-8bd1-4b44-8b8b-dc031496e208",
"metadata": {},
"source": [
"- In this section, we initialize the pretrained model we worked with in the previous chapter\n",
"\n",
"<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch06_compressed/overview-2.webp\" width=500px>"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "2992d779-f9fb-4812-a117-553eb790a5a9",
"metadata": {
"id": "2992d779-f9fb-4812-a117-553eb790a5a9"
},
"outputs": [],
"source": [
"CHOOSE_MODEL = \"gpt2-small (124M)\"\n",
"INPUT_PROMPT = \"Every effort moves\"\n",
"\n",
"BASE_CONFIG = {\n",
"    \"vocab_size\": 50257,     # Vocabulary size\n",
"    \"context_length\": 1024,  # Context length\n",
"    \"drop_rate\": 0.0,        # Dropout rate\n",
"    \"qkv_bias\": True         # Query-key-value bias\n",
"}\n",
"\n",
"model_configs = {\n",
"    \"gpt2-small (124M)\": {\"emb_dim\": 768, \"n_layers\": 12, \"n_heads\": 12},\n",
"    \"gpt2-medium (355M)\": {\"emb_dim\": 1024, \"n_layers\": 24, \"n_heads\": 16},\n",
"    \"gpt2-large (774M)\": {\"emb_dim\": 1280, \"n_layers\": 36, \"n_heads\": 20},\n",
"    \"gpt2-xl (1558M)\": {\"emb_dim\": 1600, \"n_layers\": 48, \"n_heads\": 25},\n",
"}\n",
"\n",
"BASE_CONFIG.update(model_configs[CHOOSE_MODEL])"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "022a649a-44f5-466c-8a8e-326c063384f5",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "022a649a-44f5-466c-8a8e-326c063384f5",
"outputId": "7091e401-8442-4f47-a1d9-ecb42a1ef930"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"File already exists and is up-to-date: gpt2/124M/checkpoint\n",
"File already exists and is up-to-date: gpt2/124M/encoder.json\n",
"File already exists and is up-to-date: gpt2/124M/hparams.json\n",
"File already exists and is up-to-date: gpt2/124M/model.ckpt.data-00000-of-00001\n",
"File already exists and is up-to-date: gpt2/124M/model.ckpt.index\n",
"File already exists and is up-to-date: gpt2/124M/model.ckpt.meta\n",
"File already exists and is up-to-date: gpt2/124M/vocab.bpe\n"
]
}
],
"source": [
"from gpt_download import download_and_load_gpt2\n",
"from previous_chapters import GPTModel, load_weights_into_gpt\n",
"\n",
"model_size = CHOOSE_MODEL.split(\" \")[-1].lstrip(\"(\").rstrip(\")\")\n",
"settings, params = download_and_load_gpt2(model_size=model_size, models_dir=\"gpt2\")\n",
"\n",
"model = GPTModel(BASE_CONFIG)\n",
"load_weights_into_gpt(model, params)\n",
"model.eval();"
]
},
{
"cell_type": "markdown",
"id": "ab8e056c-abe0-415f-b34d-df686204259e",
"metadata": {},
"source": [
"- To ensure that the model was loaded correctly, let's double-check that it generates coherent text"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "d8ac25ff-74b1-4149-8dc5-4c429d464330",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Every effort moves you forward.\n",
"\n",
"The first step is to understand the importance of your work\n"
]
}
],
"source": [
"from previous_chapters import (\n",
"    generate_text_simple,\n",
"    text_to_token_ids,\n",
"    token_ids_to_text\n",
")\n",
"\n",
"\n",
"text_1 = \"Every effort moves you\"\n",
"\n",
"token_ids = generate_text_simple(\n",
"    model=model,\n",
"    idx=text_to_token_ids(text_1, tokenizer),\n",
"    max_new_tokens=15,\n",
"    context_size=BASE_CONFIG[\"context_length\"]\n",
")\n",
"\n",
"print(token_ids_to_text(token_ids, tokenizer))"
]
},
{
"cell_type": "markdown",
"id": "69162550-6a02-4ece-8db1-06c71d61946f",
"metadata": {},
"source": [
"- Before we finetune the model as a classifier, let's see if the model can perhaps already classify spam messages via prompting"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "94224aa9-c95a-4f8a-a420-76d01e3a800c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Is the following text 'spam'? Answer with 'yes' or 'no': 'You are a winner you have been specially selected to receive $1000 cash or a $2000 award.' Answer with 'yes' or 'no'. Answer with 'yes' or 'no'. Answer with 'yes' or 'no'. Answer with 'yes'\n"
]
}
],
"source": [
"text_2 = (\n",
"    \"Is the following text 'spam'? Answer with 'yes' or 'no':\"\n",
"    \" 'You are a winner you have been specially\"\n",
"    \" selected to receive $1000 cash or a $2000 award.'\"\n",
"    \" Answer with 'yes' or 'no'.\"\n",
")\n",
"\n",
"token_ids = generate_text_simple(\n",
"    model=model,\n",
"    idx=text_to_token_ids(text_2, tokenizer),\n",
"    max_new_tokens=23,\n",
"    context_size=BASE_CONFIG[\"context_length\"]\n",
")\n",
"\n",
"print(token_ids_to_text(token_ids, tokenizer))"
]
},
{
"cell_type": "markdown",
"id": "1ce39ed0-2c77-410d-8392-dd15d4b22016",
"metadata": {},
"source": [
"- As we can see, the model is not very good at following instructions\n",
"- This is expected, since it has only been pretrained and not instruction-finetuned (instruction finetuning will be covered in the next chapter)"
]
},
{
"cell_type": "markdown",
"id": "4c9ae440-32f9-412f-96cf-fd52cc3e2522",
"metadata": {
"id": "4c9ae440-32f9-412f-96cf-fd52cc3e2522"
},
"source": [
"## 6.5 Adding a classification head"
]
},
{
"cell_type": "markdown",
"id": "d6e9d66f-76b2-40fc-9ec5-3f972a8db9c0",
"metadata": {},
"source": [
"<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch06_compressed/lm-head.webp\" width=500px>"
]
},
{
"cell_type": "markdown",
"id": "217bac05-78df-4412-bd80-612f8061c01d",
"metadata": {},
"source": [
"- In this section, we are modifying the pretrained LLM to make it ready for classification finetuning\n",
"- Let's take a look at the model architecture first"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "b23aff91-6bd0-48da-88f6-353657e6c981",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "1d8f7a01-b7c0-48d4-b1e7-8c12cc7ad932",
"outputId": "b6a5b9b5-a92f-498f-d7cb-b58dd99e4497"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"GPTModel(\n",
" (tok_emb): Embedding(50257, 768)\n",
" (pos_emb): Embedding(1024, 768)\n",
" (drop_emb): Dropout(p=0.0, inplace=False)\n",
" (trf_blocks): Sequential(\n",
" (0): TransformerBlock(\n",
" (att): MultiHeadAttention(\n",
" (W_query): Linear(in_features=768, out_features=768, bias=True)\n",
" (W_key): Linear(in_features=768, out_features=768, bias=True)\n",
" (W_value): Linear(in_features=768, out_features=768, bias=True)\n",
" (out_proj): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (ff): FeedForward(\n",
" (layers): Sequential(\n",
" (0): Linear(in_features=768, out_features=3072, bias=True)\n",
" (1): GELU()\n",
" (2): Linear(in_features=3072, out_features=768, bias=True)\n",
" )\n",
" )\n",
" (norm1): LayerNorm()\n",
" (norm2): LayerNorm()\n",
" (drop_resid): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (1): TransformerBlock(\n",
" (att): MultiHeadAttention(\n",
" (W_query): Linear(in_features=768, out_features=768, bias=True)\n",
" (W_key): Linear(in_features=768, out_features=768, bias=True)\n",
" (W_value): Linear(in_features=768, out_features=768, bias=True)\n",
" (out_proj): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (ff): FeedForward(\n",
" (layers): Sequential(\n",
" (0): Linear(in_features=768, out_features=3072, bias=True)\n",
" (1): GELU()\n",
" (2): Linear(in_features=3072, out_features=768, bias=True)\n",
" )\n",
" )\n",
" (norm1): LayerNorm()\n",
" (norm2): LayerNorm()\n",
" (drop_resid): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (2): TransformerBlock(\n",
" (att): MultiHeadAttention(\n",
" (W_query): Linear(in_features=768, out_features=768, bias=True)\n",
" (W_key): Linear(in_features=768, out_features=768, bias=True)\n",
" (W_value): Linear(in_features=768, out_features=768, bias=True)\n",
" (out_proj): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (ff): FeedForward(\n",
" (layers): Sequential(\n",
" (0): Linear(in_features=768, out_features=3072, bias=True)\n",
" (1): GELU()\n",
" (2): Linear(in_features=3072, out_features=768, bias=True)\n",
" )\n",
" )\n",
" (norm1): LayerNorm()\n",
" (norm2): LayerNorm()\n",
" (drop_resid): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (3): TransformerBlock(\n",
" (att): MultiHeadAttention(\n",
" (W_query): Linear(in_features=768, out_features=768, bias=True)\n",
" (W_key): Linear(in_features=768, out_features=768, bias=True)\n",
" (W_value): Linear(in_features=768, out_features=768, bias=True)\n",
" (out_proj): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (ff): FeedForward(\n",
" (layers): Sequential(\n",
" (0): Linear(in_features=768, out_features=3072, bias=True)\n",
" (1): GELU()\n",
" (2): Linear(in_features=3072, out_features=768, bias=True)\n",
" )\n",
" )\n",
" (norm1): LayerNorm()\n",
" (norm2): LayerNorm()\n",
" (drop_resid): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (4): TransformerBlock(\n",
" (att): MultiHeadAttention(\n",
" (W_query): Linear(in_features=768, out_features=768, bias=True)\n",
" (W_key): Linear(in_features=768, out_features=768, bias=True)\n",
" (W_value): Linear(in_features=768, out_features=768, bias=True)\n",
" (out_proj): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (ff): FeedForward(\n",
" (layers): Sequential(\n",
" (0): Linear(in_features=768, out_features=3072, bias=True)\n",
" (1): GELU()\n",
" (2): Linear(in_features=3072, out_features=768, bias=True)\n",
" )\n",
" )\n",
" (norm1): LayerNorm()\n",
" (norm2): LayerNorm()\n",
" (drop_resid): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (5): TransformerBlock(\n",
" (att): MultiHeadAttention(\n",
" (W_query): Linear(in_features=768, out_features=768, bias=True)\n",
" (W_key): Linear(in_features=768, out_features=768, bias=True)\n",
" (W_value): Linear(in_features=768, out_features=768, bias=True)\n",
" (out_proj): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (ff): FeedForward(\n",
" (layers): Sequential(\n",
" (0): Linear(in_features=768, out_features=3072, bias=True)\n",
" (1): GELU()\n",
" (2): Linear(in_features=3072, out_features=768, bias=True)\n",
" )\n",
" )\n",
" (norm1): LayerNorm()\n",
" (norm2): LayerNorm()\n",
" (drop_resid): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (6): TransformerBlock(\n",
" (att): MultiHeadAttention(\n",
" (W_query): Linear(in_features=768, out_features=768, bias=True)\n",
" (W_key): Linear(in_features=768, out_features=768, bias=True)\n",
" (W_value): Linear(in_features=768, out_features=768, bias=True)\n",
" (out_proj): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (ff): FeedForward(\n",
" (layers): Sequential(\n",
" (0): Linear(in_features=768, out_features=3072, bias=True)\n",
" (1): GELU()\n",
" (2): Linear(in_features=3072, out_features=768, bias=True)\n",
" )\n",
" )\n",
" (norm1): LayerNorm()\n",
" (norm2): LayerNorm()\n",
" (drop_resid): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (7): TransformerBlock(\n",
" (att): MultiHeadAttention(\n",
" (W_query): Linear(in_features=768, out_features=768, bias=True)\n",
" (W_key): Linear(in_features=768, out_features=768, bias=True)\n",
" (W_value): Linear(in_features=768, out_features=768, bias=True)\n",
" (out_proj): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (ff): FeedForward(\n",
" (layers): Sequential(\n",
" (0): Linear(in_features=768, out_features=3072, bias=True)\n",
" (1): GELU()\n",
" (2): Linear(in_features=3072, out_features=768, bias=True)\n",
" )\n",
" )\n",
" (norm1): LayerNorm()\n",
" (norm2): LayerNorm()\n",
" (drop_resid): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (8): TransformerBlock(\n",
" (att): MultiHeadAttention(\n",
" (W_query): Linear(in_features=768, out_features=768, bias=True)\n",
" (W_key): Linear(in_features=768, out_features=768, bias=True)\n",
" (W_value): Linear(in_features=768, out_features=768, bias=True)\n",
" (out_proj): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (ff): FeedForward(\n",
" (layers): Sequential(\n",
" (0): Linear(in_features=768, out_features=3072, bias=True)\n",
" (1): GELU()\n",
" (2): Linear(in_features=3072, out_features=768, bias=True)\n",
" )\n",
" )\n",
" (norm1): LayerNorm()\n",
" (norm2): LayerNorm()\n",
" (drop_resid): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (9): TransformerBlock(\n",
" (att): MultiHeadAttention(\n",
" (W_query): Linear(in_features=768, out_features=768, bias=True)\n",
" (W_key): Linear(in_features=768, out_features=768, bias=True)\n",
" (W_value): Linear(in_features=768, out_features=768, bias=True)\n",
" (out_proj): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (ff): FeedForward(\n",
" (layers): Sequential(\n",
" (0): Linear(in_features=768, out_features=3072, bias=True)\n",
" (1): GELU()\n",
" (2): Linear(in_features=3072, out_features=768, bias=True)\n",
" )\n",
" )\n",
" (norm1): LayerNorm()\n",
" (norm2): LayerNorm()\n",
" (drop_resid): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (10): TransformerBlock(\n",
" (att): MultiHeadAttention(\n",
" (W_query): Linear(in_features=768, out_features=768, bias=True)\n",
" (W_key): Linear(in_features=768, out_features=768, bias=True)\n",
" (W_value): Linear(in_features=768, out_features=768, bias=True)\n",
" (out_proj): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (ff): FeedForward(\n",
" (layers): Sequential(\n",
" (0): Linear(in_features=768, out_features=3072, bias=True)\n",
" (1): GELU()\n",
" (2): Linear(in_features=3072, out_features=768, bias=True)\n",
" )\n",
" )\n",
" (norm1): LayerNorm()\n",
" (norm2): LayerNorm()\n",
" (drop_resid): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (11): TransformerBlock(\n",
" (att): MultiHeadAttention(\n",
" (W_query): Linear(in_features=768, out_features=768, bias=True)\n",
" (W_key): Linear(in_features=768, out_features=768, bias=True)\n",
" (W_value): Linear(in_features=768, out_features=768, bias=True)\n",
" (out_proj): Linear(in_features=768, out_features=768, bias=True)\n",
" (dropout): Dropout(p=0.0, inplace=False)\n",
" )\n",
" (ff): FeedForward(\n",
" (layers): Sequential(\n",
" (0): Linear(in_features=768, out_features=3072, bias=True)\n",
" (1): GELU()\n",
" (2): Linear(in_features=3072, out_features=768, bias=True)\n",
" )\n",
" )\n",
" (norm1): LayerNorm()\n",
" (norm2): LayerNorm()\n",
" (drop_resid): Dropout(p=0.0, inplace=False)\n",
" )\n",
" )\n",
" (final_norm): LayerNorm()\n",
" (out_head): Linear(in_features=768, out_features=50257, bias=False)\n",
")\n"
]
}
],
"source": [
"print(model)"
]
},
{
"cell_type": "markdown",
"id": "3f640a76-dd00-4769-9bc8-1aed0cec330d",
"metadata": {},
"source": [
"- Above, we can see the architecture we implemented in chapter 4 neatly laid out\n",
"- The goal is to replace and finetune the output layer\n",
"- To achieve this, we first freeze the model, meaning that we make all layers non-trainable"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "fkMWFl-0etea",
"metadata": {
"id": "fkMWFl-0etea"
},
"outputs": [],
"source": [
"for param in model.parameters():\n",
" param.requires_grad = False"
]
},
{
"cell_type": "markdown",
"id": "72155f83-87d9-476a-a978-a15aa2d44147",
"metadata": {},
"source": [
"- Then, we replace the output layer (`model.out_head`), which originally maps the layer inputs to 50,257 dimensions (the size of the vocabulary)\n",
"- Since we finetune the model for binary classification (predicting 2 classes, \"spam\" and \"not spam\"), we can replace the output layer as shown below, which will be trainable by default\n",
"- Note that we use `BASE_CONFIG[\"emb_dim\"]` (which is equal to 768 in the `\"gpt2-small (124M)\"` model) to keep the code below more general"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "7e759fa0-0f69-41be-b576-17e5f20e04cb",
"metadata": {},
"outputs": [],
"source": [
"torch.manual_seed(123)\n",
"\n",
"num_classes = 2\n",
"model.out_head = torch.nn.Linear(in_features=BASE_CONFIG[\"emb_dim\"], out_features=num_classes)"
]
},
{
"cell_type": "markdown",
"id": "30be5475-ae77-4f97-8f3e-dec462b1339f",
"metadata": {},
"source": [
"- Technically, it's sufficient to only train the output layer\n",
"- However, as I found in [experiments finetuning additional layers](https://magazine.sebastianraschka.com/p/finetuning-large-language-models) can noticeably improve the performance\n",
"- So, we are also making the last transformer block and the final `LayerNorm` module connecting the last transformer block to the output layer trainable"
]
},
{
"cell_type": "markdown",
"id": "0be7c1eb-c46c-4065-8525-eea1b8c66d10",
"metadata": {},
"source": [
"<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch06_compressed/trainable.webp\" width=500px>"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "2aedc120-5ee3-48f6-92f2-ad9304ebcdc7",
"metadata": {
"id": "2aedc120-5ee3-48f6-92f2-ad9304ebcdc7"
},
"outputs": [],
"source": [
"for param in model.trf_blocks[-1].parameters():\n",
" param.requires_grad = True\n",
"\n",
"for param in model.final_norm.parameters():\n",
" param.requires_grad = True"
]
},
{
"cell_type": "markdown",
"id": "f012b899-8284-4d3a-97c0-8a48eb33ba2e",
"metadata": {},
"source": [
"- We can still use this model similar to before in previous chapters\n",
"- For example, let's feed it some text input"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "f645c06a-7df6-451c-ad3f-eafb18224ebc",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "f645c06a-7df6-451c-ad3f-eafb18224ebc",
"outputId": "27e041b1-d731-48a1-cf60-f22d4565304e"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Inputs: tensor([[5211,  345,  423,  640]])\n",
"Inputs dimensions: torch.Size([1, 4])\n"
]
}
],
"source": [
"inputs = tokenizer.encode(\"Do you have time\")\n",
"inputs = torch.tensor(inputs).unsqueeze(0)\n",
"print(\"Inputs:\", inputs)\n",
"print(\"Inputs dimensions:\", inputs.shape) # shape: (batch_size, num_tokens)"
]
},
{
"cell_type": "markdown",
"id": "fbbf8481-772d-467b-851c-a62b86d0cb1b",
"metadata": {},
"source": [
"- What's different compared to previous chapters is that it now has two output dimensions instead of 50,257"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "48dc84f1-85cc-4609-9cee-94ff539f00f4",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "48dc84f1-85cc-4609-9cee-94ff539f00f4",
"outputId": "9cae7448-253d-4776-973e-0af190b06354"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Outputs:\n",
" tensor([[[-1.5854,  0.9904],\n",
"         [-3.7235,  7.4548],\n",
"         [-2.2661,  6.6049],\n",
"         [-3.5983,  3.9902]]])\n",
"Outputs dimensions: torch.Size([1, 4, 2])\n"
]
}
],
"source": [
"with torch.no_grad():\n",
" outputs = model(inputs)\n",
"\n",
"print(\"Outputs:\\n\", outputs)\n",
"print(\"Outputs dimensions:\", outputs.shape) # shape: (batch_size, num_tokens, num_classes)"
]
},
{
"cell_type": "markdown",
"id": "75430a01-ef9c-426a-aca0-664689c4f461",
"metadata": {},
"source": [
"- As discussed in previous chapters, for each input token, there's one output vector\n",
"- Since we fed the model a text sample with 4 input tokens, the output consists of 4 2-dimensional output vectors above"
]
},
{
"cell_type": "markdown",
"id": "7df9144f-6817-4be4-8d4b-5d4dadfe4a9b",
"metadata": {},
"source": [
"<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch06_compressed/input-and-output.webp\" width=500px>"
]
},
{
"cell_type": "markdown",
"id": "e3bb8616-c791-4f5c-bac0-5302f663e46a",
"metadata": {},
"source": [
"- In chapter 3, we discussed the attention mechanism, which connects each input token to each other input token\n",
"- In chapter 3, we then also introduced the causal attention mask that is used in GPT-like models; this causal mask lets a current token only attend to the current and previous token positions\n",
2024-05-10 07:02:14 -05:00
"- Based on this causal attention mechanism, the 4th (last) token contains the most information among all tokens because it's the only token that includes information about all other tokens\n",
"- Hence, we are particularly interested in this last token, which we will finetune for the spam classification task"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "49383a8c-41d5-4dab-98f1-238bca0c2ed7",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "49383a8c-41d5-4dab-98f1-238bca0c2ed7",
"outputId": "e79eb155-fa1f-46ed-ff8c-d828c3a3fabd"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Last output token: tensor([[-3.5983,  3.9902]])\n"
]
}
],
"source": [
"print(\"Last output token:\", outputs[:, -1, :])"
]
},
{
"cell_type": "markdown",
"id": "8df08ae0-e664-4670-b7c5-8a2280d9b41b",
"metadata": {},
"source": [
"<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch06_compressed/attention-mask.webp\" width=200px>"
]
},
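To make the argument about the last token concrete, the causal mask can be written out for the 4-token example. This is a minimal sketch in plain Python rather than the `MultiHeadAttention` implementation from chapter 3; the variable names are illustrative assumptions.

```python
num_tokens = 4  # "Do you have time" was encoded into 4 tokens above

# causal_mask[i][j] is 1 if position i may attend to position j, else 0
causal_mask = [[1 if j <= i else 0 for j in range(num_tokens)]
               for i in range(num_tokens)]

for i, row in enumerate(causal_mask):
    print(f"token {i + 1} attends to {sum(row)} position(s): {row}")
```

Only the last row contains no zeros, which is why the output vector at the last position is the one used for classification.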
{
"cell_type": "markdown",
"id": "32aa4aef-e1e9-491b-9adf-5aa973e59b8c",
"metadata": {},
"source": [
"## 6.6 Calculating the classification loss and accuracy"
]
},
{
"cell_type": "markdown",
"id": "669e1fd1-ace8-44b4-b438-185ed0ba8b33",
"metadata": {},
"source": [
"<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch06_compressed/overview-3.webp\" width=500px>"
]
},
{
"cell_type": "markdown",
"id": "7a7df4ee-0a34-4a4d-896d-affbbf81e0b3",
"metadata": {},
"source": [
"- Before explaining the loss calculation, let's have a brief look at how the model outputs are turned into class labels"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "c77faab1-3461-4118-866a-6171f2b89aa0",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Last output token: tensor([[-3.5983, 3.9902]])\n"
]
}
],
"source": [
"print(\"Last output token:\", outputs[:, -1, :])"
]
},
{
"cell_type": "markdown",
"id": "7edd71fa-628a-4d00-b81d-6d8bcb2c341d",
"metadata": {},
"source": [
"- Similar to chapter 5, we convert the outputs (logits) into probability scores via the `softmax` function and then obtain the index position of the largest probability value via the `argmax` function"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "b81efa92-9be1-4b9e-8790-ce1fc7b17f01",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Class label: 1\n"
]
}
],
"source": [
"probas = torch.softmax(outputs[:, -1, :], dim=-1)\n",
"label = torch.argmax(probas)\n",
"print(\"Class label:\", label.item())"
]
},
{
"cell_type": "markdown",
"id": "414a6f02-307e-4147-a416-14d115bf8179",
"metadata": {},
"source": [
"- Note that the softmax function is optional here, as explained in chapter 5, because the largest outputs correspond to the largest probability scores"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "f9f9ad66-4969-4501-8239-3ccdb37e71a2",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Class label: 1\n"
]
}
],
"source": [
"logits = outputs[:, -1, :]\n",
"label = torch.argmax(logits)\n",
"print(\"Class label:\", label.item())"
]
},
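The reason the softmax step is optional can be checked without PyTorch. The following is a minimal sketch in plain Python; the helper name `softmax` and the use of the `math` module are illustrative choices, not part of the chapter's code. It reuses the last-token logits printed above.

```python
import math

def softmax(logits):
    # Subtract the maximum for numerical stability; this does not change the result
    largest = max(logits)
    exps = [math.exp(x - largest) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

last_token_logits = [-3.5983, 3.9902]  # values printed by the notebook above
probas = softmax(last_token_logits)

# Index of the largest value, computed once on the logits and once on the probabilities
label_from_logits = max(range(len(last_token_logits)), key=lambda i: last_token_logits[i])
label_from_probas = max(range(len(probas)), key=lambda i: probas[i])

print("Probabilities:", probas)
print("Label from logits:", label_from_logits)
print("Label from probabilities:", label_from_probas)
```

Because the exponential function is strictly increasing and the normalization divides every entry by the same positive number, softmax preserves the ordering of the logits, so both labels agree.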
{
"cell_type": "markdown",
"id": "dcb20d3a-cbba-4ab1-8584-d94e16589505",
"metadata": {},
"source": [
"- We can apply this concept to calculate the so-called classification accuracy, which computes the percentage of correct predictions in a given dataset\n",
"- To calculate the classification accuracy, we can apply the preceding `argmax`-based prediction code to all examples in a dataset and calculate the fraction of correct predictions as follows:"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "3ecf9572-aed0-4a21-9c3b-7f9f2aec5f23",
"metadata": {},
"outputs": [],
"source": [
"def calc_accuracy_loader(data_loader, model, device, num_batches=None):\n",
" model.eval()\n",
" correct_predictions, num_examples = 0, 0\n",
"\n",
" if num_batches is None:\n",
" num_batches = len(data_loader)\n",
" else:\n",
" num_batches = min(num_batches, len(data_loader))\n",
" for i, (input_batch, target_batch) in enumerate(data_loader):\n",
" if i < num_batches:\n",
" input_batch, target_batch = input_batch.to(device), target_batch.to(device)\n",
"\n",
" with torch.no_grad():\n",
" logits = model(input_batch)[:, -1, :] # Logits of last output token\n",
" predicted_labels = torch.argmax(logits, dim=-1)\n",
"\n",
" num_examples += predicted_labels.shape[0]\n",
" correct_predictions += (predicted_labels == target_batch).sum().item()\n",
" else:\n",
" break\n",
" return correct_predictions / num_examples"
]
},
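The core of the accuracy computation, stripped of the data loader and device handling, can be sketched in plain Python. The function name `accuracy_from_logits` and the toy values below are hypothetical and only serve to illustrate the counting logic.

```python
def accuracy_from_logits(batch_logits, targets):
    # batch_logits: one list of class scores per example (the last-token logits)
    # targets: the true class index for each example
    correct = 0
    for logits, target in zip(batch_logits, targets):
        predicted = max(range(len(logits)), key=lambda i: logits[i])  # argmax
        correct += int(predicted == target)
    return correct / len(targets)

# Hypothetical toy batch: 4 examples, 2 classes
toy_logits = [[2.0, -1.0], [0.3, 0.9], [-0.5, 1.5], [1.2, 0.4]]
toy_targets = [0, 1, 0, 0]

toy_accuracy = accuracy_from_logits(toy_logits, toy_targets)
print(f"Toy accuracy: {toy_accuracy*100:.2f}%")
```

In this toy batch the third example is misclassified, so three of four predictions are correct.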
{
"cell_type": "markdown",
"id": "7165fe46-a284-410b-957f-7524877d1a1a",
"metadata": {},
"source": [
"- Let's apply the function to calculate the classification accuracies for the different datasets:"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "390e5255-8427-488c-adef-e1c10ab4fb26",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training accuracy: 46.25%\n",
"Validation accuracy: 45.00%\n",
"Test accuracy: 48.75%\n"
]
}
],
"source": [
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"model.to(device) # no assignment model = model.to(device) necessary for nn.Module classes\n",
"\n",
"torch.manual_seed(123) # For reproducibility due to the shuffling in the training data loader\n",
"\n",
"train_accuracy = calc_accuracy_loader(train_loader, model, device, num_batches=10)\n",
"val_accuracy = calc_accuracy_loader(val_loader, model, device, num_batches=10)\n",
"test_accuracy = calc_accuracy_loader(test_loader, model, device, num_batches=10)\n",
"\n",
"print(f\"Training accuracy: {train_accuracy*100:.2f}%\")\n",
"print(f\"Validation accuracy: {val_accuracy*100:.2f}%\")\n",
"print(f\"Test accuracy: {test_accuracy*100:.2f}%\")"
]
},
{
"cell_type": "markdown",
"id": "30345e2a-afed-4d22-9486-f4010f90a871",
"metadata": {},
"source": [
"- As we can see, the prediction accuracies are not very good, since we haven't finetuned the model, yet"
]
},
{
"cell_type": "markdown",
"id": "4f4a9d15-8fc7-48a2-8734-d92a2f265328",
"metadata": {},
"source": [
"- Before we can start finetuning (/training), we first have to define the loss function we want to optimize during training\n",
"- The goal is to maximize the spam classification accuracy of the model; however, classification accuracy is not a differentiable function\n",
2024-05-11 08:27:07 -05:00
"- Hence, instead, we minimize the cross entropy loss as a proxy for maximizing the classification accuracy (you can learn more about this topic in lecture 8 of my freely available [Introduction to Deep Learning](https://sebastianraschka.com/blog/2021/dl-course.html#l08-multinomial-logistic-regression--softmax-regression) class)\n",
"\n",
2024-05-11 08:27:07 -05:00
"- The `calc_loss_batch` function is the same here as in chapter 5, except that we are only interested in optimizing the last token `model(input_batch)[:, -1, :]` instead of all tokens `model(input_batch)`"
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "2f1e9547-806c-41a9-8aba-3b2822baabe4",
"metadata": {
"id": "2f1e9547-806c-41a9-8aba-3b2822baabe4"
},
"outputs": [],
"source": [
"def calc_loss_batch(input_batch, target_batch, model, device):\n",
" input_batch, target_batch = input_batch.to(device), target_batch.to(device)\n",
" logits = model(input_batch)[:, -1, :] # Logits of last output token\n",
" loss = torch.nn.functional.cross_entropy(logits, target_batch)\n",
" return loss"
]
},
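For a single example, the cross entropy loss is the negative log of the softmax probability assigned to the true class. The sketch below computes it by hand in plain Python for the last-token logits shown earlier; the helper name `cross_entropy` is an illustrative assumption and is not PyTorch's implementation, which additionally averages over the batch by default.

```python
import math

def cross_entropy(logits, target):
    # log(sum(exp(logits))) - logits[target], using the log-sum-exp trick for stability
    largest = max(logits)
    log_sum_exp = largest + math.log(sum(math.exp(x - largest) for x in logits))
    return log_sum_exp - logits[target]

last_token_logits = [-3.5983, 3.9902]

loss_if_target_is_1 = cross_entropy(last_token_logits, target=1)
loss_if_target_is_0 = cross_entropy(last_token_logits, target=0)
chance_level_loss = cross_entropy([0.0, 0.0], target=0)  # equal scores for both classes

print(f"Loss if the true class is 1: {loss_if_target_is_1:.4f}")
print(f"Loss if the true class is 0: {loss_if_target_is_0:.4f}")
print(f"Loss for equal scores: {chance_level_loss:.4f}")
```

A confident correct prediction yields a loss near zero, while a confident wrong prediction is penalized heavily. For two classes, equal scores give a loss of ln 2 (about 0.693), which is a useful reference point when reading loss values.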
{
"cell_type": "markdown",
"id": "a013aab9-f854-4866-ad55-5b8350adb50a",
"metadata": {},
"source": [
"- The `calc_loss_loader` function is exactly the same as in chapter 5"
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "b7b83e10-5720-45e7-ac5e-369417ca846b",
"metadata": {},
"outputs": [],
"source": [
"# Same as in chapter 5\n",
"def calc_loss_loader(data_loader, model, device, num_batches=None):\n",
" total_loss = 0.\n",
" if len(data_loader) == 0:\n",
" return float(\"nan\")\n",
" elif num_batches is None:\n",
" num_batches = len(data_loader)\n",
" else:\n",
" # Reduce the number of batches to match the total number of batches in the data loader\n",
" # if num_batches exceeds the number of batches in the data loader\n",
" num_batches = min(num_batches, len(data_loader))\n",
" for i, (input_batch, target_batch) in enumerate(data_loader):\n",
" if i < num_batches:\n",
" loss = calc_loss_batch(input_batch, target_batch, model, device)\n",
" total_loss += loss.item()\n",
" else:\n",
" break\n",
" return total_loss / num_batches"
]
},
{
"cell_type": "markdown",
"id": "56826ecd-6e74-40e6-b772-d3541e585067",
"metadata": {},
"source": [
"- Using the `calc_loss_loader` function, we compute the initial training, validation, and test set losses before we start training"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "f6f00e53-5beb-4e64-b147-f26fd481c6ff",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "f6f00e53-5beb-4e64-b147-f26fd481c6ff",
"outputId": "49df8648-9e38-4314-854d-9faacd1b2e89"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training loss: 2.453\n",
"Validation loss: 2.583\n",
"Test loss: 2.322\n"
]
}
],
"source": [
"with torch.no_grad(): # Disable gradient tracking for efficiency because we are not training, yet\n",
" train_loss = calc_loss_loader(train_loader, model, device, num_batches=5)\n",
" val_loss = calc_loss_loader(val_loader, model, device, num_batches=5)\n",
" test_loss = calc_loss_loader(test_loader, model, device, num_batches=5)\n",
"\n",
"print(f\"Training loss: {train_loss:.3f}\")\n",
"print(f\"Validation loss: {val_loss:.3f}\")\n",
"print(f\"Test loss: {test_loss:.3f}\")"
]
},
{
"cell_type": "markdown",
"id": "e04b980b-e583-4f62-84a0-4edafaf99d5d",
"metadata": {},
"source": [
"- In the next section, we train the model to improve the loss values and consequently the classification accuracy"
]
},
{
"cell_type": "markdown",
"id": "456ae0fd-6261-42b4-ab6a-d24289953083",
"metadata": {
"id": "456ae0fd-6261-42b4-ab6a-d24289953083"
},
"source": [
"## 6.7 Finetuning the model on supervised data"
]
},
{
"cell_type": "markdown",
"id": "6a9b099b-0829-4f72-8a2b-4363e3497026",
"metadata": {},
"source": [
"- In this section, we define and use the training function to improve the classification accuracy of the model\n",
"- The `train_classifier_simple` function below is practically the same as the `train_model_simple` function we used for pretraining the model in chapter 5\n",
"- The only two differences are that we now \n",
" 1. track the number of training examples seen (`examples_seen`) instead of the number of tokens seen\n",
" 2. calculate the accuracy after each epoch instead of printing a sample text after each epoch"
]
},
{
"cell_type": "markdown",
"id": "979b6222-1dc2-4530-9d01-b6b04fe3de12",
"metadata": {},
"source": [
"<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch06_compressed/training-loop.webp\" width=500px>"
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "Csbr60to50FL",
"metadata": {
"id": "Csbr60to50FL"
},
"outputs": [],
"source": [
"# Overall the same as `train_model_simple` in chapter 5\n",
"def train_classifier_simple(model, train_loader, val_loader, optimizer, device, num_epochs,\n",
" eval_freq, eval_iter, tokenizer):\n",
" # Initialize lists to track losses and tokens seen\n",
" train_losses, val_losses, train_accs, val_accs = [], [], [], []\n",
" examples_seen, global_step = 0, -1\n",
"\n",
" # Main training loop\n",
" for epoch in range(num_epochs):\n",
" model.train() # Set model to training mode\n",
"\n",
" for input_batch, target_batch in train_loader:\n",
" optimizer.zero_grad() # Reset loss gradients from previous epoch\n",
" loss = calc_loss_batch(input_batch, target_batch, model, device)\n",
" loss.backward() # Calculate loss gradients\n",
" optimizer.step() # Update model weights using loss gradients\n",
" examples_seen += input_batch.shape[0] # New: track examples instead of tokens\n",
" global_step += 1\n",
"\n",
" # Optional evaluation step\n",
" if global_step % eval_freq == 0:\n",
" train_loss, val_loss = evaluate_model(\n",
" model, train_loader, val_loader, device, eval_iter)\n",
" train_losses.append(train_loss)\n",
" val_losses.append(val_loss)\n",
" print(f\"Ep {epoch+1} (Step {global_step:06d}): \"\n",
" f\"Train loss {train_loss:.3f}, Val loss {val_loss:.3f}\")\n",
"\n",
" # Calculate accuracy after each epoch\n",
" train_accuracy = calc_accuracy_loader(train_loader, model, device, num_batches=eval_iter)\n",
" val_accuracy = calc_accuracy_loader(val_loader, model, device, num_batches=eval_iter)\n",
" print(f\"Training accuracy: {train_accuracy*100:.2f}% | \", end=\"\")\n",
" print(f\"Validation accuracy: {val_accuracy*100:.2f}%\")\n",
" train_accs.append(train_accuracy)\n",
" val_accs.append(val_accuracy)\n",
"\n",
" return train_losses, val_losses, train_accs, val_accs, examples_seen"
]
},
{
"cell_type": "markdown",
"id": "9624cb30-3e3a-45be-b006-c00475b58ae8",
"metadata": {},
"source": [
"- The `evaluate_model` function used in the `train_classifier_simple` is the same as the one we used in chapter 5"
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "bcc7bc04-6aa6-4516-a147-460e2f466eab",
"metadata": {},
"outputs": [],
"source": [
"# Same as chapter 5\n",
"def evaluate_model(model, train_loader, val_loader, device, eval_iter):\n",
" model.eval()\n",
" with torch.no_grad():\n",
" train_loss = calc_loss_loader(train_loader, model, device, num_batches=eval_iter)\n",
" val_loss = calc_loss_loader(val_loader, model, device, num_batches=eval_iter)\n",
" model.train()\n",
" return train_loss, val_loss"
]
},
{
"cell_type": "markdown",
"id": "e807bfe9-364d-46b2-9e25-3b000c3ef6f9",
"metadata": {},
"source": [
"- The training takes about 5 minutes on a M3 MacBook Air laptop computer and less than half a minute on a V100 or A100 GPU"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "X7kU3aAj7vTJ",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "X7kU3aAj7vTJ",
"outputId": "504a033e-2bf8-41b5-a037-468309845513"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Ep 1 (Step 000000): Train loss 2.153, Val loss 2.392\n",
"Ep 1 (Step 000050): Train loss 0.617, Val loss 0.637\n",
"Ep 1 (Step 000100): Train loss 0.523, Val loss 0.557\n",
"Training accuracy: 70.00% | Validation accuracy: 72.50%\n",
"Ep 2 (Step 000150): Train loss 0.561, Val loss 0.489\n",
"Ep 2 (Step 000200): Train loss 0.419, Val loss 0.397\n",
"Ep 2 (Step 000250): Train loss 0.409, Val loss 0.353\n",
"Training accuracy: 82.50% | Validation accuracy: 85.00%\n",
"Ep 3 (Step 000300): Train loss 0.333, Val loss 0.320\n",
"Ep 3 (Step 000350): Train loss 0.340, Val loss 0.306\n",
"Training accuracy: 90.00% | Validation accuracy: 90.00%\n",
"Ep 4 (Step 000400): Train loss 0.136, Val loss 0.200\n",
"Ep 4 (Step 000450): Train loss 0.153, Val loss 0.132\n",
"Ep 4 (Step 000500): Train loss 0.222, Val loss 0.137\n",
"Training accuracy: 100.00% | Validation accuracy: 97.50%\n",
"Ep 5 (Step 000550): Train loss 0.207, Val loss 0.143\n",
"Ep 5 (Step 000600): Train loss 0.083, Val loss 0.074\n",
"Training accuracy: 100.00% | Validation accuracy: 97.50%\n",
"Training completed in 5.65 minutes.\n"
]
}
],
"source": [
"import time\n",
"\n",
"start_time = time.time()\n",
"\n",
"torch.manual_seed(123)\n",
"\n",
"optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.1)\n",
"\n",
"num_epochs = 5\n",
"train_losses, val_losses, train_accs, val_accs, examples_seen = train_classifier_simple(\n",
" model, train_loader, val_loader, optimizer, device,\n",
" num_epochs=num_epochs, eval_freq=50, eval_iter=5,\n",
" tokenizer=tokenizer\n",
")\n",
"\n",
"end_time = time.time()\n",
"execution_time_minutes = (end_time - start_time) / 60\n",
"print(f\"Training completed in {execution_time_minutes:.2f} minutes.\")"
]
},
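The step numbers in the training log above follow from how `global_step` and `eval_freq` interact. The sketch below replays only that bookkeeping in plain Python, without a model or data. The value of 130 batches per epoch is an assumption inferred from the printed step numbers, not a value stated in this section.

```python
# Assumption: 130 batches per epoch, inferred from the step numbers in the log above
batches_per_epoch = 130
num_epochs = 5
eval_freq = 50

evaluation_points = []  # (epoch, global_step) pairs at which losses are recorded
global_step = -1

for epoch in range(num_epochs):
    for _ in range(batches_per_epoch):
        global_step += 1
        if global_step % eval_freq == 0:
            evaluation_points.append((epoch + 1, global_step))

for epoch, step in evaluation_points:
    print(f"Ep {epoch} (Step {step:06d})")

print("Number of recorded loss values:", len(evaluation_points))
```

Under this assumption the loop records 13 loss values, which matches the 13 `Ep ... (Step ...)` lines in the log and is the length of `train_losses` and `val_losses` used for plotting.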
{
"cell_type": "markdown",
"id": "1261bf90-3ce7-4591-895a-044a05538f30",
"metadata": {},
"source": [
"- Similar to chapter 5, we use matplotlib to plot the loss function for the training and validation set"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "cURgnDqdCeka",
"metadata": {
"id": "cURgnDqdCeka"
},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"def plot_values(epochs_seen, examples_seen, train_values, val_values, label=\"loss\"):\n",
" fig, ax1 = plt.subplots(figsize=(5, 3))\n",
"\n",
" # Plot training and validation loss against epochs\n",
" ax1.plot(epochs_seen, train_values, label=f\"Training {label}\")\n",
" ax1.plot(epochs_seen, val_values, linestyle=\"-.\", label=f\"Validation {label}\")\n",
" ax1.set_xlabel(\"Epochs\")\n",
" ax1.set_ylabel(label.capitalize())\n",
" ax1.legend()\n",
"\n",
" # Create a second x-axis for tokens seen\n",
" ax2 = ax1.twiny() # Create a second x-axis that shares the same y-axis\n",
" ax2.plot(examples_seen, train_values, alpha=0) # Invisible plot for aligning ticks\n",
" ax2.set_xlabel(\"Examples seen\")\n",
"\n",
" fig.tight_layout() # Adjust layout to make room\n",
" plt.savefig(f\"{label}-plot.pdf\")\n",
" plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 35,
"id": "OIqRt466DiGk",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 307
},
"id": "OIqRt466DiGk",
"outputId": "b16987cf-0001-4652-ddaf-02f7cffc34db"
},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAeoAAAEiCAYAAAA21pHjAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAABXi0lEQVR4nO3deVxU9f748dfMwAz7viOCyuIK7uZOSamVZatfr7e0LG+FlZkt3krNfkWL3awsK7vJrVtZWVq3XELc9xUFF9wBlc2FVRhg5vz+GBidxAUEZsD38/E4D+Z8zuec855P5JvzOZ9zPipFURSEEEIIYZPU1g5ACCGEEJcniVoIIYSwYZKohRBCCBsmiVoIIYSwYZKohRBCCBsmiVoIIYSwYZKohRBCCBsmiVoIIYSwYZKohRBCCBsmiVoIcU1iY2OZNGmStcMQ4oYjiVqIJjJu3DhUKtUly7Bhw6wdmhDChtlZOwAhbiTDhg1j/vz5FmU6nc5K0QghmgO5ohaiCel0OgICAiwWT09PAFavXo1Wq2XdunXm+u+++y5+fn7k5uYCsGzZMgYMGICHhwfe3t7ceeedHDlyxFz/+PHjqFQqfvzxRwYOHIijoyO9evXi4MGDbNu2jZ49e+Li4sLw4cPJz8837zdu3DhGjhzJ66+/jq+vL25ubjzxxBNUVFRc9rvo9XqmTJlCcHAwzs7O9OnTh9WrV5u3Z2RkMGLECDw9PXF2dqZTp04sWbLkssf79NNPiYiIwMHBAX9/f+6//37zNqPRSEJCAm3atMHR0ZGYmBgWLlxosX9aWhrDhw/HxcUFf39/HnroIU6fPm3eHhsbyzPPPMOLL76Il5cXAQEBzJgx47LxCGErJFELYSNq7gE/9NBDFBYWsmvXLl577TW+/PJL/P39ASgtLWXy5Mls376d5ORk1Go199xzD0aj0eJY06dP59VXX2Xnzp3Y2dnxt7/9jRdffJEPP/yQdevWcfjwYaZNm2axT3JyMvv372f16tV8//33/PLLL7z++uuXjXfixIls2rSJBQsWsGfPHh544AGGDRvGoUOHAIiPj0ev17N27VpSU1N55513cHFxqfVY27dv55lnnmHmzJmkp6ezbNkyBg0aZN6ekJDA119/zWeffcbevXt57rnn+Pvf/86aNWsAKCgo4JZbbqFbt25s376dZcuWkZuby4MPPmhxnv/85z84OzuzZcsW3n33XWbOnElSUtI1/hcSwkoUIUSTGDt2rKLRaBRnZ2eL5c033zTX0ev1SteuXZUHH3xQ6dixo/L4449f8Zj5+fkKoKSmpiqKoijHjh1TAOXLL7801/n+++8VQElOTjaXJSQkKFFRURaxeXl5KaWlpeayuXPnKi4uLorBYFAURVEGDx6sPPvss4qiKEpGRoai0WiUkydPWsQzZMgQZerUqYqiKEqXLl2UGTNmXFPb/Pzzz4qbm5tSVFR0ybby8nLFyclJ2bhxo0X5+PHjldGjRyuKoihvvPGGctttt1lsz8rKUgAlPT3dHP+AAQMs6vTq1Ut56aWXrilGIaxF7lEL0YRuvvlm5s6da1Hm5eVl/qzVavn222+Jjo4mNDSUDz74wKLuoUOHmDZtGlu2bOH06dPmK+nMzEw6d+5srhcdHW3+XHM13qVLF4uyvLw8i2PHxMTg5ORkXu/bty8lJSVkZWURGhpqUTc1NRWDwUBkZKRFuV6vx9vbG4BnnnmGJ598kj///JO4uDjuu+8+i7guduuttxIaGkrbtm0ZNmwYw4YN45577sHJyYnDhw9z/vx5br31Vot9Kioq6NatGwC7d+9m1apVtV6xHzlyxBznX88fGBh4STsIYWskUQvRhJydnQkPD79inY0bNwJw9uxZzp49i7Ozs3nbiBEjCA0NZd68eQQFBWE0GuncufMl95Lt7e3Nn1UqVa1lf+0ur4uSkhI0Gg07duxAo9FYbKtJlo899hhDhw7ljz/+4M8//yQhIY
H333+fp59++pLjubq6snPnTlavXs2ff/7JtGnTmDFjBtu2baOkpASAP/74g+DgYIv9agbilZSUMGLECN55551Ljh0YGGj+fHEbwPW3gxBNQRK1EDbkyJEjPPfcc8ybN48ffviBsWPHsmLFCtRqNWfOnCE9PZ158+YxcOBAANavX99g5969ezdlZWU4OjoCsHnzZlxcXAgJCbmkbrdu3TAYDOTl5ZljqU1ISAhPPPEETzzxBFOnTmXevHm1JmoAOzs74uLiiIuLY/r06Xh4eLBy5UpuvfVWdDodmZmZDB48uNZ9u3fvzs8//0xYWBh2dvLPmmhZ5DdaiCak1+vJycmxKLOzs8PHxweDwcDf//53hg4dyiOPPMKwYcPo0qUL77//Pi+88AKenp54e3vzxRdfEBgYSGZmJi+//HKDxVZRUcH48eN59dVXOX78ONOnT2fixImo1ZeOOY2MjGTMmDE8/PDDvP/++3Tr1o38/HySk5OJjo7mjjvuYNKkSQwfPpzIyEjOnTvHqlWr6NChQ63n/v333zl69CiDBg3C09OTJUuWYDQaiYqKwtXVlSlTpvDcc89hNBoZMGAAhYWFbNiwATc3N8aOHUt8fDzz5s1j9OjR5lHdhw8fZsGCBXz55ZeXXPUL0ZxIohaiCS1btsyiKxYgKiqKAwcO8Oabb5KRkcHvv/8OmLpsv/jiC0aPHs1tt91GTEwMCxYs4JlnnqFz585ERUXx0UcfERsb2yCxDRkyhIiICAYNGoRer2f06NFXfHxp/vz5/L//9/94/vnnOXnyJD4+Ptx0003ceeedABgMBuLj4zlx4gRubm4MGzbsknvuNTw8PPjll1+YMWMG5eXlRERE8P3339OpUycA3njjDXx9fUlISODo0aN4eHjQvXt3/vnPfwIQFBTEhg0beOmll7jtttvQ6/WEhoYybNiwWv/QEKI5USmKolg7CCGEdY0bN46CggIWL15s7VCEEH8hf2oKIYQQNkwStRBCCGHDpOtbCCGEsGFyRS2EEELYMEnUQgghhA2TRC2EEELYMEnU1+GTTz4hLCwMBwcH+vTpw9atW60dUqNZu3YtI0aMICgoCJVKdcljPIqiMG3aNAIDA3F0dCQuLs48i1KNs2fPMmbMGNzc3PDw8GD8+PHm10PW2LNnDwMHDsTBwYGQkBDefffdxv5qDSIhIYFevXrh6uqKn58fI0eOJD093aJOeXk58fHxeHt74+Liwn333WeevrJGZmYmd9xxB05OTvj5+fHCCy9QVVVlUWf16tV0794dnU5HeHg4iYmJjf31GsTcuXOJjo7Gzc0NNzc3+vbty9KlS83bb/T2qc3bb7+NSqVi0qRJ5jJpJ5gxYwYqlcpiad++vXl7i2sjq04J0owtWLBA0Wq1yldffaXs3btXefzxxxUPDw8lNzfX2qE1iiVLliivvPKK8ssvvyiAsmjRIovtb7/9tuLu7q4sXrxY2b17t3LXXXcpbdq0UcrKysx1hg0bpsTExCibN29W1q1bp4SHh5tnP1IURSksLFT8/f2VMWPGKGlpacr333+vODo6Kp9//nlTfc16Gzp0qDJ//nwlLS1NSUlJUW6//XaldevWSklJibnOE088oYSEhCjJycnK9u3blZtuuknp16+feXtVVZXSuXNnJS4uTtm1a5eyZMkSxcfHxzwblaIoytGjRxUnJydl8uTJyr59+5SPP/5Y0Wg0yrJly5r0+9bHb7/9pvzxxx/KwYMHlfT0dOWf//ynYm9vr6SlpSmKIu3zV1u3blXCwsKU6Oho86xliiLtpCiKMn36dKVTp05Kdna2ecnPzzdvb2ltJIm6nnr37q3Ex8eb1w0GgxIUFKQkJCRYMaqm8ddEbTQalYCAAOW9994zlxUUFCg6nU75/vvvFUVRlH379imAsm3bNnOdpUuXKiqVyjxV4qeffqp4enoqer3eXOell16ymI6xucjLy1MAZc2aNYqimNrD3t5e+emnn8x19u/frw
DKpk2bFEUx/TGkVquVnJwcc525c+cqbm5u5jZ58cUXlU6dOlmca9SoUcrQoUMb+ys1Ck9PT+XLL7+U9vmL4uJiJSI
"text/plain": [
"<Figure size 500x300 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"epochs_tensor = torch.linspace(0, num_epochs, len(train_losses))\n",
"examples_seen_tensor = torch.linspace(0, examples_seen, len(train_losses))\n",
"\n",
"plot_values(epochs_tensor, examples_seen_tensor, train_losses, val_losses)"
]
},
{
"cell_type": "markdown",
"id": "dbd28174-1836-44ba-b6c0-7e0be774fadc",
"metadata": {},
"source": [
"- Based on the downward slope of the loss curves above, we can see that the model learns well\n",
"- Furthermore, the fact that the training and validation losses are very close indicates that the model does not tend to overfit the training data\n",
"- Similarly, we can plot the accuracy below"
]
},
{
"cell_type": "code",
"execution_count": 36,
"id": "yz8BIsaF0TUo",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 307
},
"id": "yz8BIsaF0TUo",
"outputId": "3a7ed967-1f2a-4c6d-f4a3-0cc8cc9d6c5f"
},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 500x300 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"epochs_tensor = torch.linspace(0, num_epochs, len(train_accs))\n",
"examples_seen_tensor = torch.linspace(0, examples_seen, len(train_accs))\n",
"\n",
"plot_values(epochs_tensor, examples_seen_tensor, train_accs, val_accs, label=\"accuracy\")"
]
},
{
"cell_type": "markdown",
"id": "90aba699-21bc-42de-a69c-99f370bb0363",
"metadata": {},
"source": [
"- Based on the accuracy plot above, we can see that the model achieves a relatively high training and validation accuracy after epochs 4 and 5\n",
"- However, we have to keep in mind that we specified `eval_iter=5` in the training function earlier, which means that we only estimated the training and validation set performances on 5 batches each\n",
"- We can compute the training, validation, and test set performances over the complete dataset as follows"
]
},
{
"cell_type": "code",
"execution_count": 37,
"id": "UHWaJFrjY0zW",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "UHWaJFrjY0zW",
"outputId": "e111e6e6-b147-4159-eb9d-19d4e809ed34"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training accuracy: 97.21%\n",
"Validation accuracy: 97.32%\n",
"Test accuracy: 95.67%\n"
]
}
],
"source": [
"train_accuracy = calc_accuracy_loader(train_loader, model, device)\n",
"val_accuracy = calc_accuracy_loader(val_loader, model, device)\n",
"test_accuracy = calc_accuracy_loader(test_loader, model, device)\n",
"\n",
"print(f\"Training accuracy: {train_accuracy*100:.2f}%\")\n",
"print(f\"Validation accuracy: {val_accuracy*100:.2f}%\")\n",
"print(f\"Test accuracy: {test_accuracy*100:.2f}%\")"
]
},
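{
"cell_type": "markdown",
"id": "sketch-eval-iter-estimate-md",
"metadata": {},
"source": [
"- The difference between a quick estimate on a few batches and an evaluation over a complete data loader can be illustrated with a small, self-contained sketch\n",
"- The per-batch counts below are made up for illustration and do not come from the spam dataset; `accuracy_over_batches` is a hypothetical helper that only mimics the idea of limiting the evaluation to the first few batches, as done via `eval_iter=5` during training\n",
"- With only a handful of batches, the estimate can differ noticeably from the accuracy over all batches"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-eval-iter-estimate-code",
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical per-batch results: (number of correct predictions, batch size)\n",
"# These numbers are made up for illustration and are not from the notebook\n",
"batch_results = [(8, 8), (7, 8), (8, 8), (8, 8), (8, 8), (6, 8), (7, 8), (8, 8)]\n",
"\n",
"\n",
"def accuracy_over_batches(results, num_batches=None):\n",
"    # Use all batches unless a smaller number is requested\n",
"    if num_batches is None:\n",
"        num_batches = len(results)\n",
"    else:\n",
"        num_batches = min(num_batches, len(results))\n",
"\n",
"    subset = results[:num_batches]\n",
"    correct = sum(c for c, _ in subset)\n",
"    total = sum(n for _, n in subset)\n",
"    return correct / total\n",
"\n",
"\n",
"estimate = accuracy_over_batches(batch_results, num_batches=5)  # first 5 batches only\n",
"full = accuracy_over_batches(batch_results)  # all batches\n",
"\n",
"print(f\"Estimate from 5 batches: {estimate*100:.2f}%\")\n",
"print(f\"All batches: {full*100:.2f}%\")"
]
},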
{
"cell_type": "markdown",
"id": "6882649f-dc7b-401f-84d2-024ff79c74a1",
"metadata": {},
"source": [
"- We can see that the training and validation set performances are practically identical\n",
"- However, based on the slightly lower test set performance, we can see that the model overfits the training data to a very small degree, as well as the validation data that has been used for tweaking some of the hyperparameters, such as the learning rate\n",
"- This is normal, however, and this gap could potentially be further reduced by increasing the model's dropout rate (`drop_rate`) or the `weight_decay` in the optimizer setting"
]
},
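{
"cell_type": "markdown",
"id": "sketch-accuracy-gap-md",
"metadata": {},
"source": [
"- To put the differences into numbers, the sketch below simply subtracts the accuracies reported in the output earlier in this section (the values are copied by hand from that output, not recomputed)\n",
"- A training-validation gap close to zero together with a somewhat larger training-test gap is consistent with the mild overfitting described above"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-accuracy-gap-code",
"metadata": {},
"outputs": [],
"source": [
"# Accuracies (in percent), copied by hand from the output reported earlier\n",
"reported = {\"train\": 97.21, \"val\": 97.32, \"test\": 95.67}\n",
"\n",
"# Differences in percentage points\n",
"train_val_gap = reported[\"train\"] - reported[\"val\"]\n",
"train_test_gap = reported[\"train\"] - reported[\"test\"]\n",
"\n",
"print(f\"Train - validation: {train_val_gap:+.2f} percentage points\")\n",
"print(f\"Train - test: {train_test_gap:+.2f} percentage points\")"
]
},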
{
"cell_type": "markdown",
"id": "a74d9ad7-3ec1-450e-8c9f-4fc46d3d5bb0",
"metadata": {},
"source": [
"## 6.8 Using the LLM as a spam classifier"
]
},
{
"cell_type": "markdown",
"id": "72ebcfa2-479e-408b-9cf0-7421f6144855",
"metadata": {},
"source": [
"<img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/ch06_compressed/overview-4.webp\" width=500px>"
]
},
{
"cell_type": "markdown",
"id": "fd5408e6-83e4-4e5a-8503-c2fba6073f31",
"metadata": {},
"source": [
"- Finally, let's see the finetuned GPT model in action\n",
"- The `classify_review` function below implements data preprocessing steps similar to those in the `SpamDataset` we implemented earlier\n",
"- Then, the function obtains the predicted integer class label from the model and returns the corresponding class name (here, `\"Positive\"` for spam and `\"Negative\"` for not spam)"
]
},
{
"cell_type": "code",
"execution_count": 38,
"id": "aHdn6xvL-IW5",
"metadata": {
"id": "aHdn6xvL-IW5"
},
"outputs": [],
"source": [
"def classify_review(text, model, tokenizer, device, max_length=None, pad_token_id=50256):\n",
"    model.eval()\n",
"\n",
"    # Prepare inputs to the model\n",
"    input_ids = tokenizer.encode(text)\n",
"    # pos_emb.weight has shape (context_length, emb_dim), so dim 0 is the context length\n",
"    supported_context_length = model.pos_emb.weight.shape[0]\n",
"\n",
"    # Fall back to the model's context length if no max_length is given\n",
"    if max_length is None:\n",
"        max_length = supported_context_length\n",
"\n",
"    # Truncate sequences if they are too long\n",
"    max_length = min(max_length, supported_context_length)\n",
"    input_ids = input_ids[:max_length]\n",
"\n",
"    # Pad sequences to max_length\n",
"    input_ids += [pad_token_id] * (max_length - len(input_ids))\n",
"    input_tensor = torch.tensor(input_ids, device=device).unsqueeze(0)  # add batch dimension\n",
"\n",
"    # Model inference\n",
"    with torch.no_grad():\n",
"        logits = model(input_tensor)[:, -1, :]  # Logits of the last output token\n",
"    predicted_label = torch.argmax(logits, dim=-1).item()\n",
"\n",
"    # Return the classified result (label 1 = spam, label 0 = not spam)\n",
"    return \"Positive\" if predicted_label == 1 else \"Negative\""
]
},
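{
"cell_type": "markdown",
"id": "sketch-truncate-pad-md",
"metadata": {},
"source": [
"- The truncation and padding logic inside `classify_review` can also be looked at in isolation\n",
"- The sketch below uses a hypothetical helper, `truncate_and_pad`, and made-up token IDs; it is plain Python and assumes the same padding token ID (50256) as above\n",
"- Sequences shorter than `max_length` are filled up with the padding token, and longer ones are cut off, so that every input ends up with the same length"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "sketch-truncate-pad-code",
"metadata": {},
"outputs": [],
"source": [
"def truncate_and_pad(token_ids, max_length, pad_token_id=50256):\n",
"    # Keep at most max_length tokens\n",
"    truncated = token_ids[:max_length]\n",
"    # Append padding tokens until the sequence has exactly max_length entries\n",
"    return truncated + [pad_token_id] * (max_length - len(truncated))\n",
"\n",
"\n",
"short_ids = [101, 102, 103]  # hypothetical token IDs\n",
"long_ids = list(range(10))   # hypothetical token IDs\n",
"\n",
"print(truncate_and_pad(short_ids, max_length=5))\n",
"print(truncate_and_pad(long_ids, max_length=5))"
]
},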
{
"cell_type": "markdown",
"id": "f29682d8-a899-4d9b-b973-f8d5ec68172c",
"metadata": {},
"source": [
"- Let's try it out on a few examples below"
]
},
{
"cell_type": "code",
"execution_count": 39,
"id": "apU_pf51AWSV",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "apU_pf51AWSV",
"outputId": "d0fde0a5-e7a3-4dbe-d9c5-0567dbab7e62"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Positive\n"
]
}
],
"source": [
"text_1 = (\n",
" \"You are a winner you have been specially\"\n",
" \" selected to receive $1000 cash or a $2000 award.\"\n",
")\n",
"\n",
"print(classify_review(\n",
" text_1, model, tokenizer, device, max_length=train_dataset.max_length\n",
"))"
]
},
{
"cell_type": "code",
"execution_count": 40,
"id": "1g5VTOo_Ajs5",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "1g5VTOo_Ajs5",
"outputId": "659b08eb-b6a9-4a8a-9af7-d94c757e93c2"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Negative\n"
]
}
],
"source": [
"text_2 = (\n",
" \"Hey, just wanted to check if we're still on\"\n",
" \" for dinner tonight? Let me know!\"\n",
")\n",
"\n",
"print(classify_review(\n",
" text_2, model, tokenizer, device, max_length=train_dataset.max_length\n",
"))"
]
},
{
"cell_type": "markdown",
"id": "bf736e39-0d47-40c1-8d18-1f716cf7a81e",
"metadata": {},
"source": [
"- Finally, let's save the model in case we want to reuse it later without having to train it again"
]
},
{
"cell_type": "code",
"execution_count": 41,
"id": "mYnX-gI1CfQY",
"metadata": {
"id": "mYnX-gI1CfQY"
},
"outputs": [],
"source": [
"torch.save(model.state_dict(), \"review_classifier.pth\")"
]
},
{
"cell_type": "markdown",
"id": "ba78cf7c-6b80-4f71-a50e-3ccc73839af6",
"metadata": {},
"source": [
"- Then, in a new session, we could load the model as follows"
]
},
{
"cell_type": "code",
"execution_count": 42,
"id": "cc4e68a5-d492-493b-87ef-45c475f353f5",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<All keys matched successfully>"
]
},
"execution_count": 42,
"metadata": {},
"output_type": "execute_result"
}
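],
"source": [
"# map_location places the weights on the current device; weights_only=True restricts loading to tensor data\n",
"model_state_dict = torch.load(\"review_classifier.pth\", map_location=device, weights_only=True)\n",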
"model.load_state_dict(model_state_dict)"
]
},
{
"cell_type": "markdown",
"id": "5b70ac71-234f-4eeb-b33d-c62726d50cd4",
"metadata": {
"id": "5b70ac71-234f-4eeb-b33d-c62726d50cd4"
},
"source": [
"## Summary and takeaways"
]
},
{
"cell_type": "markdown",
"id": "dafdc910-d616-47ab-aa85-f90c6e7ed80e",
"metadata": {},
"source": [
"- Interested readers can find an introduction to parameter-efficient training with low-rank adaptation (LoRA) in appendix E\n"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuType": "V100",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}