diff --git a/README.md b/README.md
index 4686540..5f2a37d 100644
--- a/README.md
+++ b/README.md
@@ -89,6 +89,9 @@ Several folders contain optional materials as a bonus for interested readers:
- [Pretraining GPT on the Project Gutenberg Dataset](ch05/03_bonus_pretraining_on_gutenberg)
- [Adding Bells and Whistles to the Training Loop](ch05/04_learning_rate_schedulers)
- [Optimizing Hyperparameters for Pretraining](ch05/05_bonus_hparam_tuning)
+- **Chapter 6:**
+ - [Additional experiments finetuning different layers and using larger models](ch06/02_bonus_additional-experiments)
+ - [Finetuning different models on 50k IMDB movie review dataset](ch06/03_bonus_imdb-classification)
diff --git a/ch06/03_bonus_imdb-classification/README.md b/ch06/03_bonus_imdb-classification/README.md
new file mode 100644
index 0000000..a5e96f6
--- /dev/null
+++ b/ch06/03_bonus_imdb-classification/README.md
@@ -0,0 +1,117 @@
+
+## Step 1: Install Dependencies
+
+Install the extra dependencies via
+
+```bash
+pip install -r requirements-extra.txt
+```
+
+
+## Step 2: Download Dataset
+
+The codes are using the 50k movie reviews from IMDb ([dataset source](https://ai.stanford.edu/~amaas/data/sentiment/)) to predict whether a movie review is positive or negative.
+
+Run the following code to create the `train.csv`, `val.csv`, and `test.csv` datasets:
+
+```bash
+download-prepare-dataset.py
+```
+
+
+
+## Step 3: Run Models
+
+The 124M GPT-2 model used in the main chapter, starting for the pretrained weights and only training the last transformer block plus output layers:
+
+```bash
+python train-gpt.py
+```
+
+```
+Ep 1 (Step 000000): Train loss 2.829, Val loss 3.433
+Ep 1 (Step 000050): Train loss 1.440, Val loss 1.669
+Ep 1 (Step 000100): Train loss 0.879, Val loss 1.037
+Ep 1 (Step 000150): Train loss 0.838, Val loss 0.866
+...
+Ep 1 (Step 004300): Train loss 0.174, Val loss 0.202
+Ep 1 (Step 004350): Train loss 0.309, Val loss 0.190
+Training accuracy: 88.75% | Validation accuracy: 91.25%
+Ep 2 (Step 004400): Train loss 0.263, Val loss 0.205
+Ep 2 (Step 004450): Train loss 0.226, Val loss 0.188
+...
+Ep 2 (Step 008650): Train loss 0.189, Val loss 0.171
+Ep 2 (Step 008700): Train loss 0.225, Val loss 0.179
+Training accuracy: 85.00% | Validation accuracy: 90.62%
+Ep 3 (Step 008750): Train loss 0.206, Val loss 0.187
+Ep 3 (Step 008800): Train loss 0.198, Val loss 0.172
+...
+Training accuracy: 96.88% | Validation accuracy: 90.62%
+Training completed in 18.62 minutes.
+
+Evaluating on the full datasets ...
+
+Training accuracy: 93.66%
+Validation accuracy: 90.02%
+Test accuracy: 89.96%
+
+```
+
+---
+
+A 66M parameter encoder-style [DistilBERT](https://medium.com/huggingface/distilbert-8cf3380435b5) model (distilled down from a 340M parameter BERT model), starting for the pretrained weights and only training the last transformer block plus output layers:
+
+
+```bash
+python train-bert-hf.py
+```
+
+```
+Ep 1 (Step 000000): Train loss 0.693, Val loss 0.697
+Ep 1 (Step 000050): Train loss 0.532, Val loss 0.596
+Ep 1 (Step 000100): Train loss 0.431, Val loss 0.446
+...
+Ep 1 (Step 004300): Train loss 0.234, Val loss 0.351
+Ep 1 (Step 004350): Train loss 0.190, Val loss 0.222
+Training accuracy: 88.75% | Validation accuracy: 88.12%
+Ep 2 (Step 004400): Train loss 0.258, Val loss 0.270
+Ep 2 (Step 004450): Train loss 0.204, Val loss 0.295
+...
+Ep 2 (Step 008650): Train loss 0.088, Val loss 0.246
+Ep 2 (Step 008700): Train loss 0.084, Val loss 0.247
+Training accuracy: 98.75% | Validation accuracy: 90.62%
+Ep 3 (Step 008750): Train loss 0.067, Val loss 0.209
+Ep 3 (Step 008800): Train loss 0.059, Val loss 0.256
+...
+Ep 3 (Step 013050): Train loss 0.068, Val loss 0.280
+Ep 3 (Step 013100): Train loss 0.064, Val loss 0.306
+Training accuracy: 99.38% | Validation accuracy: 87.50%
+Training completed in 16.70 minutes.
+
+Evaluating on the full datasets ...
+
+Training accuracy: 98.87%
+Validation accuracy: 90.98%
+Test accuracy: 90.81%
+```
+
+---
+
+A scikit-learn Logistic Regression model as a basline.
+
+```bash
+python train-sklearn-logreg.py
+```
+
+```
+Dummy classifier:
+Training Accuracy: 50.01%
+Validation Accuracy: 50.14%
+Test Accuracy: 49.91%
+
+
+Logistic regression classifier:
+Training Accuracy: 99.80%
+Validation Accuracy: 88.60%
+Test Accuracy: 88.84%
+```
\ No newline at end of file
diff --git a/ch06/03_bonus_imdb-classification/download-prepare-dataset.py b/ch06/03_bonus_imdb-classification/download-prepare-dataset.py
new file mode 100644
index 0000000..28197e6
--- /dev/null
+++ b/ch06/03_bonus_imdb-classification/download-prepare-dataset.py
@@ -0,0 +1,79 @@
+# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
+# Source for "Build a Large Language Model From Scratch"
+# - https://www.manning.com/books/build-a-large-language-model-from-scratch
+# Code: https://github.com/rasbt/LLMs-from-scratch
+
+import os
+import sys
+import tarfile
+import time
+import urllib.request
+import pandas as pd
+
+
+def reporthook(count, block_size, total_size):
+ global start_time
+ if count == 0:
+ start_time = time.time()
+ else:
+ duration = time.time() - start_time
+ progress_size = int(count * block_size)
+ percent = count * block_size * 100 / total_size
+ speed = progress_size / (1024**2 * duration)
+ sys.stdout.write(
+ f"\r{int(percent)}% | {progress_size / (1024**2):.2f} MB "
+ f"| {speed:.2f} MB/s | {duration:.2f} sec elapsed"
+ )
+ sys.stdout.flush()
+
+
+def download_and_extract_dataset(dataset_url, target_file, directory):
+ if not os.path.exists(directory):
+ if os.path.exists(target_file):
+ os.remove(target_file)
+ urllib.request.urlretrieve(dataset_url, target_file, reporthook)
+ with tarfile.open(target_file, "r:gz") as tar:
+ tar.extractall()
+ else:
+ print(f"Directory `{directory}` already exists. Skipping download.")
+
+
+def load_dataset_to_dataframe(basepath="aclImdb", labels={"pos": 1, "neg": 0}):
+ data_frames = [] # List to store each chunk of DataFrame
+ for subset in ("test", "train"):
+ for label in ("pos", "neg"):
+ path = os.path.join(basepath, subset, label)
+ for file in sorted(os.listdir(path)):
+ with open(os.path.join(path, file), "r", encoding="utf-8") as infile:
+ # Create a DataFrame for each file and add it to the list
+ data_frames.append(pd.DataFrame({"text": [infile.read()], "label": [labels[label]]}))
+ # Concatenate all DataFrame chunks together
+ df = pd.concat(data_frames, ignore_index=True)
+ df = df.sample(frac=1, random_state=123).reset_index(drop=True) # Shuffle the DataFrame
+ return df
+
+
+def partition_and_save(df, sizes=(35000, 5000, 10000)):
+ # Shuffle the DataFrame
+ df_shuffled = df.sample(frac=1, random_state=123).reset_index(drop=True)
+
+ # Get indices for where to split the data
+ train_end = sizes[0]
+ val_end = sizes[0] + sizes[1]
+
+ # Split the DataFrame
+ train = df_shuffled.iloc[:train_end]
+ val = df_shuffled.iloc[train_end:val_end]
+ test = df_shuffled.iloc[val_end:]
+
+ # Save to CSV files
+ train.to_csv("train.csv", index=False)
+ val.to_csv("val.csv", index=False)
+ test.to_csv("test.csv", index=False)
+
+
+if __name__ == "__main__":
+ dataset_url = "http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz"
+ download_and_extract_dataset(dataset_url, "aclImdb_v1.tar.gz", "aclImdb")
+ df = load_dataset_to_dataframe()
+ partition_and_save(df)
diff --git a/ch06/03_bonus_imdb-classification/requirements-extra.txt b/ch06/03_bonus_imdb-classification/requirements-extra.txt
new file mode 100644
index 0000000..7ab8694
--- /dev/null
+++ b/ch06/03_bonus_imdb-classification/requirements-extra.txt
@@ -0,0 +1,2 @@
+transformers>=4.33.2
+scikit-learn>=1.3.0
\ No newline at end of file
diff --git a/ch06/03_bonus_imdb-classification/sklearn-baseline.ipynb b/ch06/03_bonus_imdb-classification/sklearn-baseline.ipynb
new file mode 100644
index 0000000..c5d217d
--- /dev/null
+++ b/ch06/03_bonus_imdb-classification/sklearn-baseline.ipynb
@@ -0,0 +1,518 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 75,
+ "id": "b612c4c1-fa3c-47b9-a8ce-9e32f371e160",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "sms_spam_collection/SMSSpamCollection.tsv already exists. Skipping download and extraction.\n"
+ ]
+ }
+ ],
+ "source": [
+ "import urllib.request\n",
+ "import zipfile\n",
+ "import os\n",
+ "from pathlib import Path\n",
+ "\n",
+ "url = \"https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip\"\n",
+ "zip_path = \"sms_spam_collection.zip\"\n",
+ "extract_to = \"sms_spam_collection\"\n",
+ "new_file_path = Path(extract_to) / \"SMSSpamCollection.tsv\"\n",
+ "\n",
+ "def download_and_unzip(url, zip_path, extract_to, new_file_path):\n",
+ " # Check if the target file already exists\n",
+ " if new_file_path.exists():\n",
+ " print(f\"{new_file_path} already exists. Skipping download and extraction.\")\n",
+ " return\n",
+ "\n",
+ " # Downloading the file\n",
+ " with urllib.request.urlopen(url) as response:\n",
+ " with open(zip_path, \"wb\") as out_file:\n",
+ " out_file.write(response.read())\n",
+ "\n",
+ " # Unzipping the file\n",
+ " with zipfile.ZipFile(zip_path, 'r') as zip_ref:\n",
+ " zip_ref.extractall(extract_to)\n",
+ "\n",
+ " # Renaming the file to indicate its format\n",
+ " original_file = Path(extract_to) / \"SMSSpamCollection\"\n",
+ " os.rename(original_file, new_file_path)\n",
+ " print(f\"File download and saved as {new_file_path}\")\n",
+ "\n",
+ "# Execute the function\n",
+ "download_and_unzip(url, zip_path, extract_to, new_file_path)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 76,
+ "id": "69f32433-e19c-4066-b806-8f30b408107f",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "
\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " Label | \n",
+ " Text | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " | 0 | \n",
+ " ham | \n",
+ " Aight text me when you're back at mu and I'll ... | \n",
+ "
\n",
+ " \n",
+ " | 1 | \n",
+ " ham | \n",
+ " Our Prashanthettan's mother passed away last n... | \n",
+ "
\n",
+ " \n",
+ " | 2 | \n",
+ " ham | \n",
+ " No it will reach by 9 only. She telling she wi... | \n",
+ "
\n",
+ " \n",
+ " | 3 | \n",
+ " ham | \n",
+ " Do you know when the result. | \n",
+ "
\n",
+ " \n",
+ " | 4 | \n",
+ " spam | \n",
+ " Hi. Customer Loyalty Offer:The NEW Nokia6650 M... | \n",
+ "
\n",
+ " \n",
+ " | ... | \n",
+ " ... | \n",
+ " ... | \n",
+ "
\n",
+ " \n",
+ " | 5567 | \n",
+ " ham | \n",
+ " I accidentally brought em home in the box | \n",
+ "
\n",
+ " \n",
+ " | 5568 | \n",
+ " spam | \n",
+ " Moby Pub Quiz.Win a £100 High Street prize if ... | \n",
+ "
\n",
+ " \n",
+ " | 5569 | \n",
+ " ham | \n",
+ " Que pases un buen tiempo or something like that | \n",
+ "
\n",
+ " \n",
+ " | 5570 | \n",
+ " ham | \n",
+ " Nowadays people are notixiquating the laxinorf... | \n",
+ "
\n",
+ " \n",
+ " | 5571 | \n",
+ " ham | \n",
+ " Ard 4 lor... | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
5572 rows × 2 columns
\n",
+ "
"
+ ],
+ "text/plain": [
+ " Label Text\n",
+ "0 ham Aight text me when you're back at mu and I'll ...\n",
+ "1 ham Our Prashanthettan's mother passed away last n...\n",
+ "2 ham No it will reach by 9 only. She telling she wi...\n",
+ "3 ham Do you know when the result.\n",
+ "4 spam Hi. Customer Loyalty Offer:The NEW Nokia6650 M...\n",
+ "... ... ...\n",
+ "5567 ham I accidentally brought em home in the box\n",
+ "5568 spam Moby Pub Quiz.Win a £100 High Street prize if ...\n",
+ "5569 ham Que pases un buen tiempo or something like that\n",
+ "5570 ham Nowadays people are notixiquating the laxinorf...\n",
+ "5571 ham Ard 4 lor...\n",
+ "\n",
+ "[5572 rows x 2 columns]"
+ ]
+ },
+ "execution_count": 76,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "import pandas as pd\n",
+ "\n",
+ "df = pd.read_csv(new_file_path, sep=\"\\t\", header=None, names=[\"Label\", \"Text\"])\n",
+ "df = df.sample(frac=1, random_state=123).reset_index(drop=True) # Shuffle the DataFrame\n",
+ "df"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 77,
+ "id": "4b7beeba-9f3a-45f0-b2dc-76bb155a8f0e",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Label\n",
+ "ham 4825\n",
+ "spam 747\n",
+ "Name: count, dtype: int64\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Class distribution\n",
+ "print(df[\"Label\"].value_counts())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 78,
+ "id": "b3db862a-9e03-4715-babb-9b699e4f4a36",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Label\n",
+ "spam 747\n",
+ "ham 747\n",
+ "Name: count, dtype: int64\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Count the instances of 'spam'\n",
+ "n_spam = df[df[\"Label\"] == \"spam\"].shape[0]\n",
+ "\n",
+ "# Randomly sample 'ham' instances to match the number of 'spam' instances\n",
+ "ham_sampled = df[df[\"Label\"] == \"ham\"].sample(n_spam)\n",
+ "\n",
+ "# Combine the sampled 'ham' with all 'spam'\n",
+ "balanced_df = pd.concat([ham_sampled, df[df[\"Label\"] == \"spam\"]])\n",
+ "\n",
+ "# Shuffle the DataFrame\n",
+ "balanced_df = balanced_df.sample(frac=1).reset_index(drop=True)\n",
+ "\n",
+ "# Now balanced_df is the balanced DataFrame\n",
+ "print(balanced_df[\"Label\"].value_counts())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 79,
+ "id": "0af991e5-98ef-439a-a43d-63a581a2cc6d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "df[\"Label\"] = df[\"Label\"].map({\"ham\": 0, \"spam\": 1})"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 80,
+ "id": "2f5b00ef-e3ed-4819-b271-5f355848feb5",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Training set:\n",
+ "Label\n",
+ "0 0.86612\n",
+ "1 0.13388\n",
+ "Name: proportion, dtype: float64\n",
+ "\n",
+ "Validation set:\n",
+ "Label\n",
+ "0 0.866906\n",
+ "1 0.133094\n",
+ "Name: proportion, dtype: float64\n",
+ "\n",
+ "Test set:\n",
+ "Label\n",
+ "0 0.864816\n",
+ "1 0.135184\n",
+ "Name: proportion, dtype: float64\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Define split ratios\n",
+ "train_size, validation_size = 0.7, 0.1\n",
+ "# Test size is implied to be 0.2 as the remainder\n",
+ "\n",
+ "# Split the data\n",
+ "def stratified_split(df, stratify_col, train_frac, validation_frac):\n",
+ " stratified_train = pd.DataFrame()\n",
+ " stratified_validation = pd.DataFrame()\n",
+ " stratified_test = pd.DataFrame()\n",
+ "\n",
+ " # Stratify split by the unique values in the column\n",
+ " for value in df[stratify_col].unique():\n",
+ " # Filter the DataFrame for the class\n",
+ " df_class = df[df[stratify_col] == value]\n",
+ " \n",
+ " # Calculate class split sizes\n",
+ " train_end = int(len(df_class) * train_frac)\n",
+ " validation_end = train_end + int(len(df_class) * validation_frac)\n",
+ " \n",
+ " # Slice the DataFrame to get the sets\n",
+ " stratified_train = pd.concat([stratified_train, df_class[:train_end]], axis=0)\n",
+ " stratified_validation = pd.concat([stratified_validation, df_class[train_end:validation_end]], axis=0)\n",
+ " stratified_test = pd.concat([stratified_test, df_class[validation_end:]], axis=0)\n",
+ "\n",
+ " # Shuffle the sets again\n",
+ " stratified_train = stratified_train.sample(frac=1, random_state=123).reset_index(drop=True)\n",
+ " stratified_validation = stratified_validation.sample(frac=1, random_state=123).reset_index(drop=True)\n",
+ " stratified_test = stratified_test.sample(frac=1, random_state=123).reset_index(drop=True)\n",
+ "\n",
+ " return stratified_train, stratified_validation, stratified_test\n",
+ "\n",
+ "# Apply the stratified split function\n",
+ "train_df, validation_df, test_df = stratified_split(df, \"Label\", train_size, validation_size)\n",
+ "\n",
+ "# Check the results\n",
+ "print(f\"Training set:\\n{train_df['Label'].value_counts(normalize=True)}\")\n",
+ "print(f\"\\nValidation set:\\n{validation_df['Label'].value_counts(normalize=True)}\")\n",
+ "print(f\"\\nTest set:\\n{test_df['Label'].value_counts(normalize=True)}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 81,
+ "id": "65808167-2b93-45b0-8506-ce722732ce77",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Training set:\n",
+ "Label\n",
+ "ham 0.5\n",
+ "spam 0.5\n",
+ "Name: proportion, dtype: float64\n",
+ "\n",
+ "Validation set:\n",
+ "Label\n",
+ "ham 0.5\n",
+ "spam 0.5\n",
+ "Name: proportion, dtype: float64\n",
+ "\n",
+ "Test set:\n",
+ "Label\n",
+ "spam 0.5\n",
+ "ham 0.5\n",
+ "Name: proportion, dtype: float64\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Define split ratios\n",
+ "train_size, validation_size = 0.7, 0.1\n",
+ "# Test size is implied to be 0.2 as the remainder\n",
+ "\n",
+ "# Apply the stratified split function\n",
+ "train_df, validation_df, test_df = stratified_split(balanced_df, \"Label\", train_size, validation_size)\n",
+ "\n",
+ "# Check the results\n",
+ "print(f\"Training set:\\n{train_df['Label'].value_counts(normalize=True)}\")\n",
+ "print(f\"\\nValidation set:\\n{validation_df['Label'].value_counts(normalize=True)}\")\n",
+ "print(f\"\\nTest set:\\n{test_df['Label'].value_counts(normalize=True)}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "fae87bc1-14ca-4f89-8e12-49f77b0ec00d",
+ "metadata": {},
+ "source": [
+ "## Scikit-learn baseline"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 82,
+ "id": "180318b7-de18-4b05-b84a-ba97c72b9d8e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from sklearn.feature_extraction.text import CountVectorizer\n",
+ "from sklearn.linear_model import LogisticRegression\n",
+ "from sklearn.metrics import accuracy_score, balanced_accuracy_score"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 83,
+ "id": "25090b7c-f516-4be2-8083-3a7187fe4635",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "vectorizer = CountVectorizer()\n",
+ "\n",
+ "X_train = vectorizer.fit_transform(train_df[\"Text\"])\n",
+ "X_val = vectorizer.transform(validation_df[\"Text\"])\n",
+ "X_test = vectorizer.transform(test_df[\"Text\"])\n",
+ "\n",
+ "y_train, y_val, y_test = train_df[\"Label\"], validation_df[\"Label\"], test_df[\"Label\"]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 84,
+ "id": "0247de3a-88f0-4b9c-becd-157baf3acf49",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def eval(model, X_train, y_train, X_val, y_val, X_test, y_test):\n",
+ " # Making predictions\n",
+ " y_pred_train = model.predict(X_train)\n",
+ " y_pred_val = model.predict(X_val)\n",
+ " y_pred_test = model.predict(X_test)\n",
+ " \n",
+ " # Calculating accuracy and balanced accuracy\n",
+ " accuracy_train = accuracy_score(y_train, y_pred_train)\n",
+ " balanced_accuracy_train = balanced_accuracy_score(y_train, y_pred_train)\n",
+ " \n",
+ " accuracy_val = accuracy_score(y_val, y_pred_val)\n",
+ " balanced_accuracy_val = balanced_accuracy_score(y_val, y_pred_val)\n",
+ "\n",
+ " accuracy_test = accuracy_score(y_test, y_pred_test)\n",
+ " balanced_accuracy_test = balanced_accuracy_score(y_test, y_pred_test)\n",
+ " \n",
+ " # Printing the results\n",
+ " print(f\"Training Accuracy: {accuracy_train*100:.2f}%\")\n",
+ " print(f\"Validation Accuracy: {accuracy_val*100:.2f}%\")\n",
+ " print(f\"Test Accuracy: {accuracy_test*100:.2f}%\")\n",
+ " \n",
+ " print(f\"\\nTraining Balanced Accuracy: {balanced_accuracy_train*100:.2f}%\")\n",
+ " print(f\"Validation Balanced Accuracy: {balanced_accuracy_val*100:.2f}%\")\n",
+ " print(f\"Test Balanced Accuracy: {balanced_accuracy_test*100:.2f}%\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 85,
+ "id": "c29c6dfc-f72d-40ab-8cb5-783aad1a15ab",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Training Accuracy: 50.00%\n",
+ "Validation Accuracy: 50.00%\n",
+ "Test Accuracy: 50.00%\n",
+ "\n",
+ "Training Balanced Accuracy: 50.00%\n",
+ "Validation Balanced Accuracy: 50.00%\n",
+ "Test Balanced Accuracy: 50.00%\n"
+ ]
+ }
+ ],
+ "source": [
+ "from sklearn.dummy import DummyClassifier\n",
+ "\n",
+ "# Create a dummy classifier with the strategy to predict the most frequent class\n",
+ "dummy_clf = DummyClassifier(strategy=\"most_frequent\")\n",
+ "dummy_clf.fit(X_train, y_train)\n",
+ "\n",
+ "eval(dummy_clf, X_train, y_train, X_val, y_val, X_test, y_test)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 86,
+ "id": "088a8a3a-3b74-4d10-a51b-cb662569ae39",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Training Accuracy: 99.81%\n",
+ "Validation Accuracy: 95.27%\n",
+ "Test Accuracy: 96.03%\n",
+ "\n",
+ "Training Balanced Accuracy: 99.81%\n",
+ "Validation Balanced Accuracy: 95.27%\n",
+ "Test Balanced Accuracy: 96.03%\n"
+ ]
+ }
+ ],
+ "source": [
+ "model = LogisticRegression(max_iter=1000)\n",
+ "model.fit(X_train, y_train)\n",
+ "eval(model, X_train, y_train, X_val, y_val, X_test, y_test)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "34411348-45bc-4b01-bebf-b3602c002ef1",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "5a9bc6b1-c8b9-4d4f-bfe4-c5a4a8b0c756",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.12"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/ch06/03_bonus_imdb-classification/train-bert-hf.py b/ch06/03_bonus_imdb-classification/train-bert-hf.py
new file mode 100644
index 0000000..5337593
--- /dev/null
+++ b/ch06/03_bonus_imdb-classification/train-bert-hf.py
@@ -0,0 +1,270 @@
+# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
+# Source for "Build a Large Language Model From Scratch"
+# - https://www.manning.com/books/build-a-large-language-model-from-scratch
+# Code: https://github.com/rasbt/LLMs-from-scratch
+
+import argparse
+from pathlib import Path
+import time
+
+import pandas as pd
+import torch
+from torch.utils.data import DataLoader
+from torch.utils.data import Dataset
+
+from transformers import AutoTokenizer, AutoModelForSequenceClassification
+
+
+class IMDBDataset(Dataset):
+ def __init__(self, csv_file, tokenizer, max_length=None, pad_token_id=50256):
+ self.data = pd.read_csv(csv_file)
+ self.max_length = max_length if max_length is not None else self._longest_encoded_length(tokenizer)
+
+ # Pre-tokenize texts
+ self.encoded_texts = [
+ tokenizer.encode(text)[:self.max_length]
+ for text in self.data["text"]
+ ]
+ # Pad sequences to the longest sequence
+
+ # Debug
+ pad_token_id = 0
+
+ self.encoded_texts = [
+ et + [pad_token_id] * (self.max_length - len(et))
+ for et in self.encoded_texts
+ ]
+
+ def __getitem__(self, index):
+ encoded = self.encoded_texts[index]
+ label = self.data.iloc[index]["label"]
+ return torch.tensor(encoded, dtype=torch.long), torch.tensor(label, dtype=torch.long)
+
+ def __len__(self):
+ return len(self.data)
+
+ def _longest_encoded_length(self, tokenizer):
+ max_length = 0
+ for text in self.data["text"]:
+ encoded_length = len(tokenizer.encode(text))
+ if encoded_length > max_length:
+ max_length = encoded_length
+ return max_length
+
+
+def calc_loss_batch(input_batch, target_batch, model, device):
+ input_batch, target_batch = input_batch.to(device), target_batch.to(device)
+ # logits = model(input_batch)[:, -1, :] # Logits of last ouput token
+ logits = model(input_batch).logits
+ loss = torch.nn.functional.cross_entropy(logits, target_batch)
+ return loss
+
+
+# Same as in chapter 5
+def calc_loss_loader(data_loader, model, device, num_batches=None):
+ total_loss = 0.
+ if num_batches is None:
+ num_batches = len(data_loader)
+ else:
+ # Reduce the number of batches to match the total number of batches in the data loader
+ # if num_batches exceeds the number of batches in the data loader
+ num_batches = min(num_batches, len(data_loader))
+ for i, (input_batch, target_batch) in enumerate(data_loader):
+ if i < num_batches:
+ loss = calc_loss_batch(input_batch, target_batch, model, device)
+ total_loss += loss.item()
+ else:
+ break
+ return total_loss / num_batches
+
+
+@torch.no_grad() # Disable gradient tracking for efficiency
+def calc_accuracy_loader(data_loader, model, device, num_batches=None):
+ model.eval()
+ correct_predictions, num_examples = 0, 0
+
+ if num_batches is None:
+ num_batches = len(data_loader)
+ else:
+ num_batches = min(num_batches, len(data_loader))
+ for i, (input_batch, target_batch) in enumerate(data_loader):
+ if i < num_batches:
+ input_batch, target_batch = input_batch.to(device), target_batch.to(device)
+ # logits = model(input_batch)[:, -1, :] # Logits of last ouput token
+ logits = model(input_batch).logits
+ predicted_labels = torch.argmax(logits, dim=1)
+ num_examples += predicted_labels.shape[0]
+ correct_predictions += (predicted_labels == target_batch).sum().item()
+ else:
+ break
+ return correct_predictions / num_examples
+
+
+def evaluate_model(model, train_loader, val_loader, device, eval_iter):
+ model.eval()
+ with torch.no_grad():
+ train_loss = calc_loss_loader(train_loader, model, device, num_batches=eval_iter)
+ val_loss = calc_loss_loader(val_loader, model, device, num_batches=eval_iter)
+ model.train()
+ return train_loss, val_loss
+
+
+def train_classifier_simple(model, train_loader, val_loader, optimizer, device, num_epochs,
+ eval_freq, eval_iter, tokenizer, max_steps=None):
+ # Initialize lists to track losses and tokens seen
+ train_losses, val_losses, train_accs, val_accs = [], [], [], []
+ examples_seen, global_step = 0, -1
+
+ # Main training loop
+ for epoch in range(num_epochs):
+ model.train() # Set model to training mode
+
+ for input_batch, target_batch in train_loader:
+ optimizer.zero_grad() # Reset loss gradients from previous epoch
+ loss = calc_loss_batch(input_batch, target_batch, model, device)
+ loss.backward() # Calculate loss gradients
+ optimizer.step() # Update model weights using loss gradients
+ examples_seen += input_batch.shape[0] # New: track examples instead of tokens
+ global_step += 1
+
+ # Optional evaluation step
+ if global_step % eval_freq == 0:
+ train_loss, val_loss = evaluate_model(
+ model, train_loader, val_loader, device, eval_iter)
+ train_losses.append(train_loss)
+ val_losses.append(val_loss)
+ print(f"Ep {epoch+1} (Step {global_step:06d}): "
+ f"Train loss {train_loss:.3f}, Val loss {val_loss:.3f}")
+
+ if max_steps is not None and global_step > max_steps:
+ break
+
+ # New: Calculate accuracy after each epoch
+ train_accuracy = calc_accuracy_loader(train_loader, model, device, num_batches=eval_iter)
+ val_accuracy = calc_accuracy_loader(val_loader, model, device, num_batches=eval_iter)
+ print(f"Training accuracy: {train_accuracy*100:.2f}% | ", end="")
+ print(f"Validation accuracy: {val_accuracy*100:.2f}%")
+ train_accs.append(train_accuracy)
+ val_accs.append(val_accuracy)
+
+ if max_steps is not None and global_step > max_steps:
+ break
+
+ return train_losses, val_losses, train_accs, val_accs, examples_seen
+
+
+if __name__ == "__main__":
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--trainable_layers",
+ type=str,
+ default="last_block",
+ help=(
+ "Which layers to train. Options: 'all', 'last_block', 'last_layer'."
+ )
+ )
+ args = parser.parse_args()
+
+ ###############################
+ # Load model
+ ###############################
+ model = AutoModelForSequenceClassification.from_pretrained(
+ "distilbert-base-uncased", num_labels=2
+ )
+ torch.manual_seed(123)
+ model.out_head = torch.nn.Linear(in_features=768, out_features=2)
+
+ if args.trainable_layers == "last_layer":
+ pass
+ elif args.trainable_layers == "last_block":
+ for param in model.pre_classifier.parameters():
+ param.requires_grad = True
+ for param in model.distilbert.transformer.layer[-1].parameters():
+ param.requires_grad = True
+ elif args.trainable_layers == "all":
+ for param in model.parameters():
+ param.requires_grad = True
+ else:
+ raise ValueError("Invalid --trainable_layers argument.")
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model.to(device)
+
+ ###############################
+ # Instantiate dataloaders
+ ###############################
+
+ url = "https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip"
+ zip_path = "sms_spam_collection.zip"
+ extract_to = "sms_spam_collection"
+ new_file_path = Path(extract_to) / "SMSSpamCollection.tsv"
+
+ base_path = Path(".")
+ file_names = ["train.csv", "val.csv", "test.csv"]
+ all_exist = all((base_path / file_name).exists() for file_name in file_names)
+
+ tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
+ pad_token_id = tokenizer.encode(tokenizer.pad_token)
+
+ train_dataset = IMDBDataset(base_path / "train.csv", max_length=256, tokenizer=tokenizer, pad_token_id=pad_token_id)
+ val_dataset = IMDBDataset(base_path / "val.csv", max_length=256, tokenizer=tokenizer, pad_token_id=pad_token_id)
+ test_dataset = IMDBDataset(base_path / "test.csv", max_length=256, tokenizer=tokenizer, pad_token_id=pad_token_id)
+
+ num_workers = 0
+ batch_size = 8
+
+ train_loader = DataLoader(
+ dataset=train_dataset,
+ batch_size=batch_size,
+ shuffle=True,
+ num_workers=num_workers,
+ drop_last=True,
+ )
+
+ val_loader = DataLoader(
+ dataset=val_dataset,
+ batch_size=batch_size,
+ num_workers=num_workers,
+ drop_last=False,
+ )
+
+ test_loader = DataLoader(
+ dataset=test_dataset,
+ batch_size=batch_size,
+ num_workers=num_workers,
+ drop_last=False,
+ )
+
+ ###############################
+ # Train model
+ ###############################
+
+ start_time = time.time()
+ torch.manual_seed(123)
+ optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.1)
+
+ num_epochs = 3
+ train_losses, val_losses, train_accs, val_accs, examples_seen = train_classifier_simple(
+ model, train_loader, val_loader, optimizer, device,
+ num_epochs=num_epochs, eval_freq=50, eval_iter=20,
+ tokenizer=tokenizer, max_steps=None
+ )
+
+ end_time = time.time()
+ execution_time_minutes = (end_time - start_time) / 60
+ print(f"Training completed in {execution_time_minutes:.2f} minutes.")
+
+ ###############################
+ # Evaluate model
+ ###############################
+
+ print("\nEvaluating on the full datasets ...\n")
+
+ train_accuracy = calc_accuracy_loader(train_loader, model, device)
+ val_accuracy = calc_accuracy_loader(val_loader, model, device)
+ test_accuracy = calc_accuracy_loader(test_loader, model, device)
+
+ print(f"Training accuracy: {train_accuracy*100:.2f}%")
+ print(f"Validation accuracy: {val_accuracy*100:.2f}%")
+ print(f"Test accuracy: {test_accuracy*100:.2f}%")
diff --git a/ch06/03_bonus_imdb-classification/train-gpt.py b/ch06/03_bonus_imdb-classification/train-gpt.py
new file mode 100644
index 0000000..dda708b
--- /dev/null
+++ b/ch06/03_bonus_imdb-classification/train-gpt.py
@@ -0,0 +1,367 @@
+# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
+# Source for "Build a Large Language Model From Scratch"
+# - https://www.manning.com/books/build-a-large-language-model-from-scratch
+# Code: https://github.com/rasbt/LLMs-from-scratch
+
+import argparse
+from pathlib import Path
+import time
+
+import pandas as pd
+import tiktoken
+import torch
+from torch.utils.data import DataLoader
+from torch.utils.data import Dataset
+
+from gpt_download import download_and_load_gpt2
+from previous_chapters import GPTModel, load_weights_into_gpt
+
+
+class IMDBDataset(Dataset):
+ def __init__(self, csv_file, tokenizer, max_length=None, pad_token_id=50256):
+ self.data = pd.read_csv(csv_file)
+ self.max_length = max_length if max_length is not None else self._longest_encoded_length(tokenizer)
+
+ # Pre-tokenize texts
+ self.encoded_texts = [
+ tokenizer.encode(text)[:self.max_length]
+ for text in self.data["text"]
+ ]
+ # Pad sequences to the longest sequence
+ self.encoded_texts = [
+ et + [pad_token_id] * (self.max_length - len(et))
+ for et in self.encoded_texts
+ ]
+
+ def __getitem__(self, index):
+ encoded = self.encoded_texts[index]
+ label = self.data.iloc[index]["label"]
+ return torch.tensor(encoded, dtype=torch.long), torch.tensor(label, dtype=torch.long)
+
+ def __len__(self):
+ return len(self.data)
+
+ def _longest_encoded_length(self, tokenizer):
+ max_length = 0
+ for text in self.data["text"]:
+ encoded_length = len(tokenizer.encode(text))
+ if encoded_length > max_length:
+ max_length = encoded_length
+ return max_length
+
+
+def instantiate_model(choose_model, load_weights):
+
+ BASE_CONFIG = {
+ "vocab_size": 50257, # Vocabulary size
+ "context_length": 1024, # Context length
+ "drop_rate": 0.0, # Dropout rate
+ "qkv_bias": True # Query-key-value bias
+ }
+
+ model_configs = {
+ "gpt2-small (124M)": {"emb_dim": 768, "n_layers": 12, "n_heads": 12},
+ "gpt2-medium (355M)": {"emb_dim": 1024, "n_layers": 24, "n_heads": 16},
+ "gpt2-large (774M)": {"emb_dim": 1280, "n_layers": 36, "n_heads": 20},
+ "gpt2-xl (1558M)": {"emb_dim": 1600, "n_layers": 48, "n_heads": 25},
+ }
+
+ BASE_CONFIG.update(model_configs[choose_model])
+ model = GPTModel(BASE_CONFIG)
+
+ if load_weights:
+ model_size = choose_model.split(" ")[-1].lstrip("(").rstrip(")")
+ settings, params = download_and_load_gpt2(model_size=model_size, models_dir="gpt2")
+ load_weights_into_gpt(model, params)
+
+ model.eval()
+ return model
+
+
+def calc_loss_batch(input_batch, target_batch, model, device, trainable_token=-1):
+ input_batch, target_batch = input_batch.to(device), target_batch.to(device)
+ logits = model(input_batch)[:, trainable_token, :] # Logits of last ouput token
+ loss = torch.nn.functional.cross_entropy(logits, target_batch)
+ return loss
+
+
+def calc_loss_loader(data_loader, model, device, num_batches=None, trainable_token=-1):
+ total_loss = 0.
+ if len(data_loader) == 0:
+ return float("nan")
+ elif num_batches is None:
+ num_batches = len(data_loader)
+ else:
+ # Reduce the number of batches to match the total number of batches in the data loader
+ # if num_batches exceeds the number of batches in the data loader
+ num_batches = min(num_batches, len(data_loader))
+ for i, (input_batch, target_batch) in enumerate(data_loader):
+ if i < num_batches:
+ loss = calc_loss_batch(input_batch, target_batch, model, device, trainable_token=trainable_token)
+ total_loss += loss.item()
+ else:
+ break
+ return total_loss / num_batches
+
+
+@torch.no_grad() # Disable gradient tracking for efficiency
+def calc_accuracy_loader(data_loader, model, device, num_batches=None, trainable_token=-1):
+ model.eval()
+ correct_predictions, num_examples = 0, 0
+
+ if num_batches is None:
+ num_batches = len(data_loader)
+ else:
+ num_batches = min(num_batches, len(data_loader))
+ for i, (input_batch, target_batch) in enumerate(data_loader):
+ if i < num_batches:
+ input_batch, target_batch = input_batch.to(device), target_batch.to(device)
+ logits = model(input_batch)[:, trainable_token, :] # Logits of last ouput token
+ predicted_labels = torch.argmax(logits, dim=-1)
+
+ num_examples += predicted_labels.shape[0]
+ correct_predictions += (predicted_labels == target_batch).sum().item()
+ else:
+ break
+ return correct_predictions / num_examples
+
+
+def evaluate_model(model, train_loader, val_loader, device, eval_iter, trainable_token=-1):
+ model.eval()
+ with torch.no_grad():
+ train_loss = calc_loss_loader(train_loader, model, device, num_batches=eval_iter, trainable_token=trainable_token)
+ val_loss = calc_loss_loader(val_loader, model, device, num_batches=eval_iter, trainable_token=trainable_token)
+ model.train()
+ return train_loss, val_loss
+
+
+def train_classifier_simple(model, train_loader, val_loader, optimizer, device, num_epochs,
+ eval_freq, eval_iter, tokenizer, max_steps=None, trainable_token=-1):
+ # Initialize lists to track losses and tokens seen
+ train_losses, val_losses, train_accs, val_accs = [], [], [], []
+ examples_seen, global_step = 0, -1
+
+ # Main training loop
+ for epoch in range(num_epochs):
+ model.train() # Set model to training mode
+
+ for input_batch, target_batch in train_loader:
+ optimizer.zero_grad() # Reset loss gradients from previous epoch
+ loss = calc_loss_batch(input_batch, target_batch, model, device, trainable_token=trainable_token)
+ loss.backward() # Calculate loss gradients
+ optimizer.step() # Update model weights using loss gradients
+ examples_seen += input_batch.shape[0] # New: track examples instead of tokens
+ global_step += 1
+
+ # Optional evaluation step
+ if global_step % eval_freq == 0:
+ train_loss, val_loss = evaluate_model(
+ model, train_loader, val_loader, device, eval_iter, trainable_token=trainable_token)
+ train_losses.append(train_loss)
+ val_losses.append(val_loss)
+ print(f"Ep {epoch+1} (Step {global_step:06d}): "
+ f"Train loss {train_loss:.3f}, Val loss {val_loss:.3f}")
+
+ if max_steps is not None and global_step > max_steps:
+ break
+
+ # New: Calculate accuracy after each epoch
+ train_accuracy = calc_accuracy_loader(train_loader, model, device, num_batches=eval_iter, trainable_token=trainable_token)
+ val_accuracy = calc_accuracy_loader(val_loader, model, device, num_batches=eval_iter, trainable_token=trainable_token)
+ print(f"Training accuracy: {train_accuracy*100:.2f}% | ", end="")
+ print(f"Validation accuracy: {val_accuracy*100:.2f}%")
+ train_accs.append(train_accuracy)
+ val_accs.append(val_accuracy)
+
+ if max_steps is not None and global_step > max_steps:
+ break
+
+ return train_losses, val_losses, train_accs, val_accs, examples_seen
+
+
+if __name__ == "__main__":
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--model_size",
+ type=str,
+ default="gpt2-small (124M)",
+ help=(
+ "Which GPT model to use. Options: 'gpt2-small (124M)', 'gpt2-medium (355M)',"
+ " 'gpt2-large (774M)', 'gpt2-xl (1558M)'."
+ )
+ )
+ parser.add_argument(
+ "--weights",
+ type=str,
+ default="pretrained",
+ help=(
+ "Whether to use 'pretrained' or 'random' weights."
+ )
+ )
+ parser.add_argument(
+ "--trainable_layers",
+ type=str,
+ default="last_block",
+ help=(
+ "Which layers to train. Options: 'all', 'last_block', 'last_layer'."
+ )
+ )
+ parser.add_argument(
+ "--trainable_token",
+ type=str,
+ default="last",
+ help=(
+ "Which token to train. Options: 'first', 'last'."
+ )
+ )
+ parser.add_argument(
+ "--context_length",
+ type=str,
+ default="256",
+ help=(
+ "The context length of the data inputs."
+ "Options: 'longest_training_example', 'model_context_length' or integer value."
+ )
+ )
+
+ args = parser.parse_args()
+
+ if args.trainable_token == "first":
+ args.trainable_token = 0
+ elif args.trainable_token == "last":
+ args.trainable_token = -1
+ else:
+ raise ValueError("Invalid --trainable_token argument")
+
+ ###############################
+ # Load model
+ ###############################
+
+ if args.weights == "pretrained":
+ load_weights = True
+ elif args.weights == "random":
+ load_weights = False
+ else:
+ raise ValueError("Invalid --weights argument.")
+
+ model = instantiate_model(args.model_size, load_weights)
+ for param in model.parameters():
+ param.requires_grad = False
+
+ if args.model_size == "gpt2-small (124M)":
+ in_features = 768
+ elif args.model_size == "gpt2-medium (355M)":
+ in_features = 1024
+ elif args.model_size == "gpt2-large (774M)":
+ in_features = 1280
+ elif args.model_size == "gpt2-xl (1558M)":
+ in_features = 1280
+ else:
+ raise ValueError("Invalid --model_size argument")
+
+ torch.manual_seed(123)
+ model.out_head = torch.nn.Linear(in_features=in_features, out_features=2)
+
+ if args.trainable_layers == "last_layer":
+ pass
+ elif args.trainable_layers == "last_block":
+ for param in model.trf_blocks[-1].parameters():
+ param.requires_grad = True
+ for param in model.final_norm.parameters():
+ param.requires_grad = True
+ elif args.trainable_layers == "all":
+ for param in model.parameters():
+ param.requires_grad = True
+ else:
+ raise ValueError("Invalid --trainable_layers argument.")
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model.to(device)
+
+ ###############################
+ # Instantiate dataloaders
+ ###############################
+
+ url = "https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip"
+ zip_path = "sms_spam_collection.zip"
+ extract_to = "sms_spam_collection"
+ new_file_path = Path(extract_to) / "SMSSpamCollection.tsv"
+
+ base_path = Path(".")
+ file_names = ["train.csv", "val.csv", "test.csv"]
+ all_exist = all((base_path / file_name).exists() for file_name in file_names)
+
+ tokenizer = tiktoken.get_encoding("gpt2")
+
+ if args.context_length == "model_context_length":
+ max_length = model.pos_emb.weight.shape[0]
+ elif args.context_length == "longest_training_example":
+ max_length = None
+ else:
+ try:
+ max_length = int(args.context_length)
+ except ValueError:
+ raise ValueError("Invalid --context_length argument")
+
+ train_dataset = IMDBDataset(base_path / "train.csv", max_length=max_length, tokenizer=tokenizer)
+ val_dataset = IMDBDataset(base_path / "val.csv", max_length=max_length, tokenizer=tokenizer)
+ test_dataset = IMDBDataset(base_path / "test.csv", max_length=max_length, tokenizer=tokenizer)
+
+ num_workers = 0
+ batch_size = 8
+
+ train_loader = DataLoader(
+ dataset=train_dataset,
+ batch_size=batch_size,
+ shuffle=True,
+ num_workers=num_workers,
+ drop_last=True,
+ )
+
+ val_loader = DataLoader(
+ dataset=val_dataset,
+ batch_size=batch_size,
+ num_workers=num_workers,
+ drop_last=False,
+ )
+
+ test_loader = DataLoader(
+ dataset=test_dataset,
+ batch_size=batch_size,
+ num_workers=num_workers,
+ drop_last=False,
+ )
+
+ ###############################
+ # Train model
+ ###############################
+
+ start_time = time.time()
+ torch.manual_seed(123)
+ optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.1)
+
+ num_epochs = 3
+ train_losses, val_losses, train_accs, val_accs, examples_seen = train_classifier_simple(
+ model, train_loader, val_loader, optimizer, device,
+ num_epochs=num_epochs, eval_freq=50, eval_iter=20,
+ tokenizer=tokenizer, max_steps=None, trainable_token=args.trainable_token
+ )
+
+ end_time = time.time()
+ execution_time_minutes = (end_time - start_time) / 60
+ print(f"Training completed in {execution_time_minutes:.2f} minutes.")
+
+ ###############################
+ # Evaluate model
+ ###############################
+
+ print("\nEvaluating on the full datasets ...\n")
+
+ train_accuracy = calc_accuracy_loader(train_loader, model, device, trainable_token=args.trainable_token)
+ val_accuracy = calc_accuracy_loader(val_loader, model, device, trainable_token=args.trainable_token)
+ test_accuracy = calc_accuracy_loader(test_loader, model, device, trainable_token=args.trainable_token)
+
+ print(f"Training accuracy: {train_accuracy*100:.2f}%")
+ print(f"Validation accuracy: {val_accuracy*100:.2f}%")
+ print(f"Test accuracy: {test_accuracy*100:.2f}%")
diff --git a/ch06/03_bonus_imdb-classification/train-sklearn-logreg.py b/ch06/03_bonus_imdb-classification/train-sklearn-logreg.py
new file mode 100644
index 0000000..c7af1a2
--- /dev/null
+++ b/ch06/03_bonus_imdb-classification/train-sklearn-logreg.py
@@ -0,0 +1,75 @@
+# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
+# Source for "Build a Large Language Model From Scratch"
+# - https://www.manning.com/books/build-a-large-language-model-from-scratch
+# Code: https://github.com/rasbt/LLMs-from-scratch
+
+import pandas as pd
+from sklearn.feature_extraction.text import CountVectorizer
+from sklearn.linear_model import LogisticRegression
+from sklearn.metrics import accuracy_score
+# from sklearn.metrics import balanced_accuracy_score
+from sklearn.dummy import DummyClassifier
+
+
+def load_dataframes():
+ df_train = pd.read_csv("train.csv")
+ df_val = pd.read_csv("val.csv")
+ df_test = pd.read_csv("test.csv")
+
+ return df_train, df_val, df_test
+
+
+def eval(model, X_train, y_train, X_val, y_val, X_test, y_test):
+ # Making predictions
+ y_pred_train = model.predict(X_train)
+ y_pred_val = model.predict(X_val)
+ y_pred_test = model.predict(X_test)
+
+ # Calculating accuracy and balanced accuracy
+ accuracy_train = accuracy_score(y_train, y_pred_train)
+ # balanced_accuracy_train = balanced_accuracy_score(y_train, y_pred_train)
+
+ accuracy_val = accuracy_score(y_val, y_pred_val)
+ # balanced_accuracy_val = balanced_accuracy_score(y_val, y_pred_val)
+
+ accuracy_test = accuracy_score(y_test, y_pred_test)
+ # balanced_accuracy_test = balanced_accuracy_score(y_test, y_pred_test)
+
+ # Printing the results
+ print(f"Training Accuracy: {accuracy_train*100:.2f}%")
+ print(f"Validation Accuracy: {accuracy_val*100:.2f}%")
+ print(f"Test Accuracy: {accuracy_test*100:.2f}%")
+
+ # print(f"\nTraining Balanced Accuracy: {balanced_accuracy_train*100:.2f}%")
+ # print(f"Validation Balanced Accuracy: {balanced_accuracy_val*100:.2f}%")
+ # print(f"Test Balanced Accuracy: {balanced_accuracy_test*100:.2f}%")
+
+
+if __name__ == "__main__":
+ df_train, df_val, df_test = load_dataframes()
+
+ #########################################
+ # Convert text into bag-of-words model
+ vectorizer = CountVectorizer()
+ #########################################
+
+ X_train = vectorizer.fit_transform(df_train["text"])
+ X_val = vectorizer.transform(df_val["text"])
+ X_test = vectorizer.transform(df_test["text"])
+ y_train, y_val, y_test = df_train["label"], df_val["label"], df_test["label"]
+
+ #####################################
+ # Model training and evaluation
+ #####################################
+
+ # Create a dummy classifier with the strategy to predict the most frequent class
+ dummy_clf = DummyClassifier(strategy="most_frequent")
+ dummy_clf.fit(X_train, y_train)
+
+ print("Dummy classifier:")
+ eval(dummy_clf, X_train, y_train, X_val, y_val, X_test, y_test)
+
+ print("\n\nLogistic regression classifier:")
+ model = LogisticRegression(max_iter=1000)
+ model.fit(X_train, y_train)
+ eval(model, X_train, y_train, X_val, y_val, X_test, y_test)