IMDB experiments (#128)

* IMDB experiments

* style fixes

* Update README.md
Sebastian Raschka 2024-04-25 07:20:53 -05:00 committed by GitHub
parent 51f4980a42
commit 4bbd476e7a
8 changed files with 1431 additions and 0 deletions

View File

@@ -89,6 +89,9 @@ Several folders contain optional materials as a bonus for interested readers:
- [Pretraining GPT on the Project Gutenberg Dataset](ch05/03_bonus_pretraining_on_gutenberg)
- [Adding Bells and Whistles to the Training Loop](ch05/04_learning_rate_schedulers)
- [Optimizing Hyperparameters for Pretraining](ch05/05_bonus_hparam_tuning)
- **Chapter 6:**
- [Additional experiments finetuning different layers and using larger models](ch06/02_bonus_additional-experiments)
- [Finetuning different models on the 50k IMDB movie review dataset](ch06/03_bonus_imdb-classification)
<br>
<br>

View File

@@ -0,0 +1,117 @@
&nbsp;
## Step 1: Install Dependencies
Install the extra dependencies via
```bash
pip install -r requirements-extra.txt
```
&nbsp;
## Step 2: Download Dataset
The code in this folder uses the 50k movie reviews from IMDb ([dataset source](https://ai.stanford.edu/~amaas/data/sentiment/)) to predict whether a movie review is positive or negative.
Run the following script to create the `train.csv`, `val.csv`, and `test.csv` datasets:
```bash
python download-prepare-dataset.py
```
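By default, this produces 35,000 training, 5,000 validation, and 10,000 test examples. An optional sanity check of the resulting files, sketched here with pandas, could look like this:
```python
import pandas as pd

# Load the CSV files created by download-prepare-dataset.py
train_df = pd.read_csv("train.csv")
val_df = pd.read_csv("val.csv")
test_df = pd.read_csv("test.csv")

# Default partition sizes: 35,000 / 5,000 / 10,000
print(len(train_df), len(val_df), len(test_df))

# Each file has a "text" column and a "label" column (1 = positive, 0 = negative)
print(train_df.head())
```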
&nbsp;
## Step 3: Run Models
The 124M-parameter GPT-2 model used in the main chapter, starting from the pretrained weights and training only the last transformer block plus the output layers:
```bash
python train-gpt.py
```
```
Ep 1 (Step 000000): Train loss 2.829, Val loss 3.433
Ep 1 (Step 000050): Train loss 1.440, Val loss 1.669
Ep 1 (Step 000100): Train loss 0.879, Val loss 1.037
Ep 1 (Step 000150): Train loss 0.838, Val loss 0.866
...
Ep 1 (Step 004300): Train loss 0.174, Val loss 0.202
Ep 1 (Step 004350): Train loss 0.309, Val loss 0.190
Training accuracy: 88.75% | Validation accuracy: 91.25%
Ep 2 (Step 004400): Train loss 0.263, Val loss 0.205
Ep 2 (Step 004450): Train loss 0.226, Val loss 0.188
...
Ep 2 (Step 008650): Train loss 0.189, Val loss 0.171
Ep 2 (Step 008700): Train loss 0.225, Val loss 0.179
Training accuracy: 85.00% | Validation accuracy: 90.62%
Ep 3 (Step 008750): Train loss 0.206, Val loss 0.187
Ep 3 (Step 008800): Train loss 0.198, Val loss 0.172
...
Training accuracy: 96.88% | Validation accuracy: 90.62%
Training completed in 18.62 minutes.
Evaluating on the full datasets ...
Training accuracy: 93.66%
Validation accuracy: 90.02%
Test accuracy: 89.96%
```
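The partial finetuning here amounts to freezing all pretrained parameters and then re-enabling gradients only for the last transformer block, the final layer norm, and a newly added 2-class output head. A minimal sketch of this idea, assuming the `trf_blocks`, `final_norm`, and `out_head` attribute names used by the `GPTModel` in `train-gpt.py`:
```python
import torch

def freeze_all_but_last_block(model, emb_dim=768, num_classes=2):
    # Freeze every parameter of the pretrained model
    for param in model.parameters():
        param.requires_grad = False
    # Replace the language-modeling head with a small classification head
    # (newly created layers are trainable by default)
    model.out_head = torch.nn.Linear(in_features=emb_dim, out_features=num_classes)
    # Re-enable gradients for the last transformer block and the final layer norm
    for param in model.trf_blocks[-1].parameters():
        param.requires_grad = True
    for param in model.final_norm.parameters():
        param.requires_grad = True
    return model
```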
---
A 66M-parameter encoder-style [DistilBERT](https://medium.com/huggingface/distilbert-8cf3380435b5) model (distilled down from a 340M-parameter BERT model), starting from the pretrained weights and training only the last transformer block plus the output layers:
```bash
python train-bert-hf.py
```
```
Ep 1 (Step 000000): Train loss 0.693, Val loss 0.697
Ep 1 (Step 000050): Train loss 0.532, Val loss 0.596
Ep 1 (Step 000100): Train loss 0.431, Val loss 0.446
...
Ep 1 (Step 004300): Train loss 0.234, Val loss 0.351
Ep 1 (Step 004350): Train loss 0.190, Val loss 0.222
Training accuracy: 88.75% | Validation accuracy: 88.12%
Ep 2 (Step 004400): Train loss 0.258, Val loss 0.270
Ep 2 (Step 004450): Train loss 0.204, Val loss 0.295
...
Ep 2 (Step 008650): Train loss 0.088, Val loss 0.246
Ep 2 (Step 008700): Train loss 0.084, Val loss 0.247
Training accuracy: 98.75% | Validation accuracy: 90.62%
Ep 3 (Step 008750): Train loss 0.067, Val loss 0.209
Ep 3 (Step 008800): Train loss 0.059, Val loss 0.256
...
Ep 3 (Step 013050): Train loss 0.068, Val loss 0.280
Ep 3 (Step 013100): Train loss 0.064, Val loss 0.306
Training accuracy: 99.38% | Validation accuracy: 87.50%
Training completed in 16.70 minutes.
Evaluating on the full datasets ...
Training accuracy: 98.87%
Validation accuracy: 90.98%
Test accuracy: 90.81%
```
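The DistilBERT model and tokenizer come from the Hugging Face `transformers` library; the snippet below is only an illustrative loading sketch, not the exact pipeline used in `train-bert-hf.py`:
```python
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Load the pretrained DistilBERT encoder with a randomly initialized 2-class head
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = AutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased", num_labels=2
)

# Tokenize a single example review and inspect the class logits
inputs = tokenizer("A wonderful, heartfelt movie.", return_tensors="pt")
logits = model(**inputs).logits
print(logits.shape)  # torch.Size([1, 2])
```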
---
A scikit-learn logistic regression model as a baseline:
```bash
python train-sklearn-logreg.py
```
```
Dummy classifier:
Training Accuracy: 50.01%
Validation Accuracy: 50.14%
Test Accuracy: 49.91%
Logistic regression classifier:
Training Accuracy: 99.80%
Validation Accuracy: 88.60%
Test Accuracy: 88.84%
```
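The baseline builds bag-of-words features with `CountVectorizer` and fits scikit-learn's `LogisticRegression` on top; a condensed sketch of what `train-sklearn-logreg.py` does, assuming `train.csv` and `val.csv` already exist:
```python
import pandas as pd
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.linear_model import LogisticRegression

df_train = pd.read_csv("train.csv")
df_val = pd.read_csv("val.csv")

# Bag-of-words features: one count per vocabulary word
vectorizer = CountVectorizer()
X_train = vectorizer.fit_transform(df_train["text"])
X_val = vectorizer.transform(df_val["text"])

# Fit the logistic regression classifier and report validation accuracy
model = LogisticRegression(max_iter=1000)
model.fit(X_train, df_train["label"])
print(f"Validation accuracy: {model.score(X_val, df_val['label'])*100:.2f}%")
```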

View File

@@ -0,0 +1,79 @@
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
import os
import sys
import tarfile
import time
import urllib.request
import pandas as pd
def reporthook(count, block_size, total_size):
global start_time
if count == 0:
start_time = time.time()
else:
duration = time.time() - start_time
progress_size = int(count * block_size)
percent = count * block_size * 100 / total_size
speed = progress_size / (1024**2 * duration)
sys.stdout.write(
f"\r{int(percent)}% | {progress_size / (1024**2):.2f} MB "
f"| {speed:.2f} MB/s | {duration:.2f} sec elapsed"
)
sys.stdout.flush()
def download_and_extract_dataset(dataset_url, target_file, directory):
if not os.path.exists(directory):
if os.path.exists(target_file):
os.remove(target_file)
urllib.request.urlretrieve(dataset_url, target_file, reporthook)
with tarfile.open(target_file, "r:gz") as tar:
tar.extractall()
else:
print(f"Directory `{directory}` already exists. Skipping download.")
def load_dataset_to_dataframe(basepath="aclImdb", labels={"pos": 1, "neg": 0}):
data_frames = [] # List to store each chunk of DataFrame
for subset in ("test", "train"):
for label in ("pos", "neg"):
path = os.path.join(basepath, subset, label)
for file in sorted(os.listdir(path)):
with open(os.path.join(path, file), "r", encoding="utf-8") as infile:
# Create a DataFrame for each file and add it to the list
data_frames.append(pd.DataFrame({"text": [infile.read()], "label": [labels[label]]}))
# Concatenate all DataFrame chunks together
df = pd.concat(data_frames, ignore_index=True)
df = df.sample(frac=1, random_state=123).reset_index(drop=True) # Shuffle the DataFrame
return df
def partition_and_save(df, sizes=(35000, 5000, 10000)):
# Shuffle the DataFrame
df_shuffled = df.sample(frac=1, random_state=123).reset_index(drop=True)
# Get indices for where to split the data
train_end = sizes[0]
val_end = sizes[0] + sizes[1]
# Split the DataFrame
train = df_shuffled.iloc[:train_end]
val = df_shuffled.iloc[train_end:val_end]
test = df_shuffled.iloc[val_end:]
# Save to CSV files
train.to_csv("train.csv", index=False)
val.to_csv("val.csv", index=False)
test.to_csv("test.csv", index=False)
if __name__ == "__main__":
dataset_url = "http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz"
download_and_extract_dataset(dataset_url, "aclImdb_v1.tar.gz", "aclImdb")
df = load_dataset_to_dataframe()
partition_and_save(df)

View File

@@ -0,0 +1,2 @@
transformers>=4.33.2
scikit-learn>=1.3.0

View File

@@ -0,0 +1,518 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 75,
"id": "b612c4c1-fa3c-47b9-a8ce-9e32f371e160",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"sms_spam_collection/SMSSpamCollection.tsv already exists. Skipping download and extraction.\n"
]
}
],
"source": [
"import urllib.request\n",
"import zipfile\n",
"import os\n",
"from pathlib import Path\n",
"\n",
"url = \"https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip\"\n",
"zip_path = \"sms_spam_collection.zip\"\n",
"extract_to = \"sms_spam_collection\"\n",
"new_file_path = Path(extract_to) / \"SMSSpamCollection.tsv\"\n",
"\n",
"def download_and_unzip(url, zip_path, extract_to, new_file_path):\n",
" # Check if the target file already exists\n",
" if new_file_path.exists():\n",
" print(f\"{new_file_path} already exists. Skipping download and extraction.\")\n",
" return\n",
"\n",
" # Downloading the file\n",
" with urllib.request.urlopen(url) as response:\n",
" with open(zip_path, \"wb\") as out_file:\n",
" out_file.write(response.read())\n",
"\n",
" # Unzipping the file\n",
" with zipfile.ZipFile(zip_path, 'r') as zip_ref:\n",
" zip_ref.extractall(extract_to)\n",
"\n",
" # Renaming the file to indicate its format\n",
" original_file = Path(extract_to) / \"SMSSpamCollection\"\n",
" os.rename(original_file, new_file_path)\n",
" print(f\"File download and saved as {new_file_path}\")\n",
"\n",
"# Execute the function\n",
"download_and_unzip(url, zip_path, extract_to, new_file_path)"
]
},
{
"cell_type": "code",
"execution_count": 76,
"id": "69f32433-e19c-4066-b806-8f30b408107f",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Label</th>\n",
" <th>Text</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>ham</td>\n",
" <td>Aight text me when you're back at mu and I'll ...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>ham</td>\n",
" <td>Our Prashanthettan's mother passed away last n...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>ham</td>\n",
" <td>No it will reach by 9 only. She telling she wi...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>ham</td>\n",
" <td>Do you know when the result.</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>spam</td>\n",
" <td>Hi. Customer Loyalty Offer:The NEW Nokia6650 M...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5567</th>\n",
" <td>ham</td>\n",
" <td>I accidentally brought em home in the box</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5568</th>\n",
" <td>spam</td>\n",
" <td>Moby Pub Quiz.Win a £100 High Street prize if ...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5569</th>\n",
" <td>ham</td>\n",
" <td>Que pases un buen tiempo or something like that</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5570</th>\n",
" <td>ham</td>\n",
" <td>Nowadays people are notixiquating the laxinorf...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5571</th>\n",
" <td>ham</td>\n",
" <td>Ard 4 lor...</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>5572 rows × 2 columns</p>\n",
"</div>"
],
"text/plain": [
" Label Text\n",
"0 ham Aight text me when you're back at mu and I'll ...\n",
"1 ham Our Prashanthettan's mother passed away last n...\n",
"2 ham No it will reach by 9 only. She telling she wi...\n",
"3 ham Do you know when the result.\n",
"4 spam Hi. Customer Loyalty Offer:The NEW Nokia6650 M...\n",
"... ... ...\n",
"5567 ham I accidentally brought em home in the box\n",
"5568 spam Moby Pub Quiz.Win a £100 High Street prize if ...\n",
"5569 ham Que pases un buen tiempo or something like that\n",
"5570 ham Nowadays people are notixiquating the laxinorf...\n",
"5571 ham Ard 4 lor...\n",
"\n",
"[5572 rows x 2 columns]"
]
},
"execution_count": 76,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pandas as pd\n",
"\n",
"df = pd.read_csv(new_file_path, sep=\"\\t\", header=None, names=[\"Label\", \"Text\"])\n",
"df = df.sample(frac=1, random_state=123).reset_index(drop=True) # Shuffle the DataFrame\n",
"df"
]
},
{
"cell_type": "code",
"execution_count": 77,
"id": "4b7beeba-9f3a-45f0-b2dc-76bb155a8f0e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Label\n",
"ham 4825\n",
"spam 747\n",
"Name: count, dtype: int64\n"
]
}
],
"source": [
"# Class distribution\n",
"print(df[\"Label\"].value_counts())"
]
},
{
"cell_type": "code",
"execution_count": 78,
"id": "b3db862a-9e03-4715-babb-9b699e4f4a36",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Label\n",
"spam 747\n",
"ham 747\n",
"Name: count, dtype: int64\n"
]
}
],
"source": [
"# Count the instances of 'spam'\n",
"n_spam = df[df[\"Label\"] == \"spam\"].shape[0]\n",
"\n",
"# Randomly sample 'ham' instances to match the number of 'spam' instances\n",
"ham_sampled = df[df[\"Label\"] == \"ham\"].sample(n_spam)\n",
"\n",
"# Combine the sampled 'ham' with all 'spam'\n",
"balanced_df = pd.concat([ham_sampled, df[df[\"Label\"] == \"spam\"]])\n",
"\n",
"# Shuffle the DataFrame\n",
"balanced_df = balanced_df.sample(frac=1).reset_index(drop=True)\n",
"\n",
"# Now balanced_df is the balanced DataFrame\n",
"print(balanced_df[\"Label\"].value_counts())"
]
},
{
"cell_type": "code",
"execution_count": 79,
"id": "0af991e5-98ef-439a-a43d-63a581a2cc6d",
"metadata": {},
"outputs": [],
"source": [
"df[\"Label\"] = df[\"Label\"].map({\"ham\": 0, \"spam\": 1})"
]
},
{
"cell_type": "code",
"execution_count": 80,
"id": "2f5b00ef-e3ed-4819-b271-5f355848feb5",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training set:\n",
"Label\n",
"0 0.86612\n",
"1 0.13388\n",
"Name: proportion, dtype: float64\n",
"\n",
"Validation set:\n",
"Label\n",
"0 0.866906\n",
"1 0.133094\n",
"Name: proportion, dtype: float64\n",
"\n",
"Test set:\n",
"Label\n",
"0 0.864816\n",
"1 0.135184\n",
"Name: proportion, dtype: float64\n"
]
}
],
"source": [
"# Define split ratios\n",
"train_size, validation_size = 0.7, 0.1\n",
"# Test size is implied to be 0.2 as the remainder\n",
"\n",
"# Split the data\n",
"def stratified_split(df, stratify_col, train_frac, validation_frac):\n",
" stratified_train = pd.DataFrame()\n",
" stratified_validation = pd.DataFrame()\n",
" stratified_test = pd.DataFrame()\n",
"\n",
" # Stratify split by the unique values in the column\n",
" for value in df[stratify_col].unique():\n",
" # Filter the DataFrame for the class\n",
" df_class = df[df[stratify_col] == value]\n",
" \n",
" # Calculate class split sizes\n",
" train_end = int(len(df_class) * train_frac)\n",
" validation_end = train_end + int(len(df_class) * validation_frac)\n",
" \n",
" # Slice the DataFrame to get the sets\n",
" stratified_train = pd.concat([stratified_train, df_class[:train_end]], axis=0)\n",
" stratified_validation = pd.concat([stratified_validation, df_class[train_end:validation_end]], axis=0)\n",
" stratified_test = pd.concat([stratified_test, df_class[validation_end:]], axis=0)\n",
"\n",
" # Shuffle the sets again\n",
" stratified_train = stratified_train.sample(frac=1, random_state=123).reset_index(drop=True)\n",
" stratified_validation = stratified_validation.sample(frac=1, random_state=123).reset_index(drop=True)\n",
" stratified_test = stratified_test.sample(frac=1, random_state=123).reset_index(drop=True)\n",
"\n",
" return stratified_train, stratified_validation, stratified_test\n",
"\n",
"# Apply the stratified split function\n",
"train_df, validation_df, test_df = stratified_split(df, \"Label\", train_size, validation_size)\n",
"\n",
"# Check the results\n",
"print(f\"Training set:\\n{train_df['Label'].value_counts(normalize=True)}\")\n",
"print(f\"\\nValidation set:\\n{validation_df['Label'].value_counts(normalize=True)}\")\n",
"print(f\"\\nTest set:\\n{test_df['Label'].value_counts(normalize=True)}\")"
]
},
{
"cell_type": "code",
"execution_count": 81,
"id": "65808167-2b93-45b0-8506-ce722732ce77",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training set:\n",
"Label\n",
"ham 0.5\n",
"spam 0.5\n",
"Name: proportion, dtype: float64\n",
"\n",
"Validation set:\n",
"Label\n",
"ham 0.5\n",
"spam 0.5\n",
"Name: proportion, dtype: float64\n",
"\n",
"Test set:\n",
"Label\n",
"spam 0.5\n",
"ham 0.5\n",
"Name: proportion, dtype: float64\n"
]
}
],
"source": [
"# Define split ratios\n",
"train_size, validation_size = 0.7, 0.1\n",
"# Test size is implied to be 0.2 as the remainder\n",
"\n",
"# Apply the stratified split function\n",
"train_df, validation_df, test_df = stratified_split(balanced_df, \"Label\", train_size, validation_size)\n",
"\n",
"# Check the results\n",
"print(f\"Training set:\\n{train_df['Label'].value_counts(normalize=True)}\")\n",
"print(f\"\\nValidation set:\\n{validation_df['Label'].value_counts(normalize=True)}\")\n",
"print(f\"\\nTest set:\\n{test_df['Label'].value_counts(normalize=True)}\")"
]
},
{
"cell_type": "markdown",
"id": "fae87bc1-14ca-4f89-8e12-49f77b0ec00d",
"metadata": {},
"source": [
"## Scikit-learn baseline"
]
},
{
"cell_type": "code",
"execution_count": 82,
"id": "180318b7-de18-4b05-b84a-ba97c72b9d8e",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.feature_extraction.text import CountVectorizer\n",
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn.metrics import accuracy_score, balanced_accuracy_score"
]
},
{
"cell_type": "code",
"execution_count": 83,
"id": "25090b7c-f516-4be2-8083-3a7187fe4635",
"metadata": {},
"outputs": [],
"source": [
"vectorizer = CountVectorizer()\n",
"\n",
"X_train = vectorizer.fit_transform(train_df[\"Text\"])\n",
"X_val = vectorizer.transform(validation_df[\"Text\"])\n",
"X_test = vectorizer.transform(test_df[\"Text\"])\n",
"\n",
"y_train, y_val, y_test = train_df[\"Label\"], validation_df[\"Label\"], test_df[\"Label\"]"
]
},
{
"cell_type": "code",
"execution_count": 84,
"id": "0247de3a-88f0-4b9c-becd-157baf3acf49",
"metadata": {},
"outputs": [],
"source": [
"def eval(model, X_train, y_train, X_val, y_val, X_test, y_test):\n",
" # Making predictions\n",
" y_pred_train = model.predict(X_train)\n",
" y_pred_val = model.predict(X_val)\n",
" y_pred_test = model.predict(X_test)\n",
" \n",
" # Calculating accuracy and balanced accuracy\n",
" accuracy_train = accuracy_score(y_train, y_pred_train)\n",
" balanced_accuracy_train = balanced_accuracy_score(y_train, y_pred_train)\n",
" \n",
" accuracy_val = accuracy_score(y_val, y_pred_val)\n",
" balanced_accuracy_val = balanced_accuracy_score(y_val, y_pred_val)\n",
"\n",
" accuracy_test = accuracy_score(y_test, y_pred_test)\n",
" balanced_accuracy_test = balanced_accuracy_score(y_test, y_pred_test)\n",
" \n",
" # Printing the results\n",
" print(f\"Training Accuracy: {accuracy_train*100:.2f}%\")\n",
" print(f\"Validation Accuracy: {accuracy_val*100:.2f}%\")\n",
" print(f\"Test Accuracy: {accuracy_test*100:.2f}%\")\n",
" \n",
" print(f\"\\nTraining Balanced Accuracy: {balanced_accuracy_train*100:.2f}%\")\n",
" print(f\"Validation Balanced Accuracy: {balanced_accuracy_val*100:.2f}%\")\n",
" print(f\"Test Balanced Accuracy: {balanced_accuracy_test*100:.2f}%\")"
]
},
{
"cell_type": "code",
"execution_count": 85,
"id": "c29c6dfc-f72d-40ab-8cb5-783aad1a15ab",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training Accuracy: 50.00%\n",
"Validation Accuracy: 50.00%\n",
"Test Accuracy: 50.00%\n",
"\n",
"Training Balanced Accuracy: 50.00%\n",
"Validation Balanced Accuracy: 50.00%\n",
"Test Balanced Accuracy: 50.00%\n"
]
}
],
"source": [
"from sklearn.dummy import DummyClassifier\n",
"\n",
"# Create a dummy classifier with the strategy to predict the most frequent class\n",
"dummy_clf = DummyClassifier(strategy=\"most_frequent\")\n",
"dummy_clf.fit(X_train, y_train)\n",
"\n",
"eval(dummy_clf, X_train, y_train, X_val, y_val, X_test, y_test)"
]
},
{
"cell_type": "code",
"execution_count": 86,
"id": "088a8a3a-3b74-4d10-a51b-cb662569ae39",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training Accuracy: 99.81%\n",
"Validation Accuracy: 95.27%\n",
"Test Accuracy: 96.03%\n",
"\n",
"Training Balanced Accuracy: 99.81%\n",
"Validation Balanced Accuracy: 95.27%\n",
"Test Balanced Accuracy: 96.03%\n"
]
}
],
"source": [
"model = LogisticRegression(max_iter=1000)\n",
"model.fit(X_train, y_train)\n",
"eval(model, X_train, y_train, X_val, y_val, X_test, y_test)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "34411348-45bc-4b01-bebf-b3602c002ef1",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "5a9bc6b1-c8b9-4d4f-bfe4-c5a4a8b0c756",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -0,0 +1,270 @@
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
import argparse
from pathlib import Path
import time
import pandas as pd
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification
class IMDBDataset(Dataset):
def __init__(self, csv_file, tokenizer, max_length=None, pad_token_id=50256):
self.data = pd.read_csv(csv_file)
self.max_length = max_length if max_length is not None else self._longest_encoded_length(tokenizer)
# Pre-tokenize texts
self.encoded_texts = [
tokenizer.encode(text)[:self.max_length]
for text in self.data["text"]
]
# Pad sequences to the longest sequence
self.encoded_texts = [
et + [pad_token_id] * (self.max_length - len(et))
for et in self.encoded_texts
]
def __getitem__(self, index):
encoded = self.encoded_texts[index]
label = self.data.iloc[index]["label"]
return torch.tensor(encoded, dtype=torch.long), torch.tensor(label, dtype=torch.long)
def __len__(self):
return len(self.data)
def _longest_encoded_length(self, tokenizer):
max_length = 0
for text in self.data["text"]:
encoded_length = len(tokenizer.encode(text))
if encoded_length > max_length:
max_length = encoded_length
return max_length
def calc_loss_batch(input_batch, target_batch, model, device):
input_batch, target_batch = input_batch.to(device), target_batch.to(device)
# logits = model(input_batch)[:, -1, :] # Logits of last output token
logits = model(input_batch).logits
loss = torch.nn.functional.cross_entropy(logits, target_batch)
return loss
# Same as in chapter 5
def calc_loss_loader(data_loader, model, device, num_batches=None):
total_loss = 0.
if num_batches is None:
num_batches = len(data_loader)
else:
# Reduce the number of batches to match the total number of batches in the data loader
# if num_batches exceeds the number of batches in the data loader
num_batches = min(num_batches, len(data_loader))
for i, (input_batch, target_batch) in enumerate(data_loader):
if i < num_batches:
loss = calc_loss_batch(input_batch, target_batch, model, device)
total_loss += loss.item()
else:
break
return total_loss / num_batches
@torch.no_grad() # Disable gradient tracking for efficiency
def calc_accuracy_loader(data_loader, model, device, num_batches=None):
model.eval()
correct_predictions, num_examples = 0, 0
if num_batches is None:
num_batches = len(data_loader)
else:
num_batches = min(num_batches, len(data_loader))
for i, (input_batch, target_batch) in enumerate(data_loader):
if i < num_batches:
input_batch, target_batch = input_batch.to(device), target_batch.to(device)
# logits = model(input_batch)[:, -1, :] # Logits of last output token
logits = model(input_batch).logits
predicted_labels = torch.argmax(logits, dim=1)
num_examples += predicted_labels.shape[0]
correct_predictions += (predicted_labels == target_batch).sum().item()
else:
break
return correct_predictions / num_examples
def evaluate_model(model, train_loader, val_loader, device, eval_iter):
model.eval()
with torch.no_grad():
train_loss = calc_loss_loader(train_loader, model, device, num_batches=eval_iter)
val_loss = calc_loss_loader(val_loader, model, device, num_batches=eval_iter)
model.train()
return train_loss, val_loss
def train_classifier_simple(model, train_loader, val_loader, optimizer, device, num_epochs,
eval_freq, eval_iter, tokenizer, max_steps=None):
# Initialize lists to track losses and tokens seen
train_losses, val_losses, train_accs, val_accs = [], [], [], []
examples_seen, global_step = 0, -1
# Main training loop
for epoch in range(num_epochs):
model.train() # Set model to training mode
for input_batch, target_batch in train_loader:
optimizer.zero_grad() # Reset loss gradients from previous batch iteration
loss = calc_loss_batch(input_batch, target_batch, model, device)
loss.backward() # Calculate loss gradients
optimizer.step() # Update model weights using loss gradients
examples_seen += input_batch.shape[0] # New: track examples instead of tokens
global_step += 1
# Optional evaluation step
if global_step % eval_freq == 0:
train_loss, val_loss = evaluate_model(
model, train_loader, val_loader, device, eval_iter)
train_losses.append(train_loss)
val_losses.append(val_loss)
print(f"Ep {epoch+1} (Step {global_step:06d}): "
f"Train loss {train_loss:.3f}, Val loss {val_loss:.3f}")
if max_steps is not None and global_step > max_steps:
break
# New: Calculate accuracy after each epoch
train_accuracy = calc_accuracy_loader(train_loader, model, device, num_batches=eval_iter)
val_accuracy = calc_accuracy_loader(val_loader, model, device, num_batches=eval_iter)
print(f"Training accuracy: {train_accuracy*100:.2f}% | ", end="")
print(f"Validation accuracy: {val_accuracy*100:.2f}%")
train_accs.append(train_accuracy)
val_accs.append(val_accuracy)
if max_steps is not None and global_step > max_steps:
break
return train_losses, val_losses, train_accs, val_accs, examples_seen
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--trainable_layers",
type=str,
default="last_block",
help=(
"Which layers to train. Options: 'all', 'last_block', 'last_layer'."
)
)
args = parser.parse_args()
###############################
# Load model
###############################
model = AutoModelForSequenceClassification.from_pretrained(
"distilbert-base-uncased", num_labels=2
)
torch.manual_seed(123)
model.out_head = torch.nn.Linear(in_features=768, out_features=2)
if args.trainable_layers == "last_layer":
pass
elif args.trainable_layers == "last_block":
for param in model.pre_classifier.parameters():
param.requires_grad = True
for param in model.distilbert.transformer.layer[-1].parameters():
param.requires_grad = True
elif args.trainable_layers == "all":
for param in model.parameters():
param.requires_grad = True
else:
raise ValueError("Invalid --trainable_layers argument.")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
###############################
# Instantiate dataloaders
###############################
url = "https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip"
zip_path = "sms_spam_collection.zip"
extract_to = "sms_spam_collection"
new_file_path = Path(extract_to) / "SMSSpamCollection.tsv"
base_path = Path(".")
file_names = ["train.csv", "val.csv", "test.csv"]
all_exist = all((base_path / file_name).exists() for file_name in file_names)
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
pad_token_id = tokenizer.pad_token_id  # use the tokenizer's [PAD] token id
train_dataset = IMDBDataset(base_path / "train.csv", max_length=256, tokenizer=tokenizer, pad_token_id=pad_token_id)
val_dataset = IMDBDataset(base_path / "val.csv", max_length=256, tokenizer=tokenizer, pad_token_id=pad_token_id)
test_dataset = IMDBDataset(base_path / "test.csv", max_length=256, tokenizer=tokenizer, pad_token_id=pad_token_id)
num_workers = 0
batch_size = 8
train_loader = DataLoader(
dataset=train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=num_workers,
drop_last=True,
)
val_loader = DataLoader(
dataset=val_dataset,
batch_size=batch_size,
num_workers=num_workers,
drop_last=False,
)
test_loader = DataLoader(
dataset=test_dataset,
batch_size=batch_size,
num_workers=num_workers,
drop_last=False,
)
###############################
# Train model
###############################
start_time = time.time()
torch.manual_seed(123)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.1)
num_epochs = 3
train_losses, val_losses, train_accs, val_accs, examples_seen = train_classifier_simple(
model, train_loader, val_loader, optimizer, device,
num_epochs=num_epochs, eval_freq=50, eval_iter=20,
tokenizer=tokenizer, max_steps=None
)
end_time = time.time()
execution_time_minutes = (end_time - start_time) / 60
print(f"Training completed in {execution_time_minutes:.2f} minutes.")
###############################
# Evaluate model
###############################
print("\nEvaluating on the full datasets ...\n")
train_accuracy = calc_accuracy_loader(train_loader, model, device)
val_accuracy = calc_accuracy_loader(val_loader, model, device)
test_accuracy = calc_accuracy_loader(test_loader, model, device)
print(f"Training accuracy: {train_accuracy*100:.2f}%")
print(f"Validation accuracy: {val_accuracy*100:.2f}%")
print(f"Test accuracy: {test_accuracy*100:.2f}%")

View File

@@ -0,0 +1,367 @@
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
import argparse
from pathlib import Path
import time
import pandas as pd
import tiktoken
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from gpt_download import download_and_load_gpt2
from previous_chapters import GPTModel, load_weights_into_gpt
class IMDBDataset(Dataset):
def __init__(self, csv_file, tokenizer, max_length=None, pad_token_id=50256):
self.data = pd.read_csv(csv_file)
self.max_length = max_length if max_length is not None else self._longest_encoded_length(tokenizer)
# Pre-tokenize texts
self.encoded_texts = [
tokenizer.encode(text)[:self.max_length]
for text in self.data["text"]
]
# Pad sequences to the longest sequence
self.encoded_texts = [
et + [pad_token_id] * (self.max_length - len(et))
for et in self.encoded_texts
]
def __getitem__(self, index):
encoded = self.encoded_texts[index]
label = self.data.iloc[index]["label"]
return torch.tensor(encoded, dtype=torch.long), torch.tensor(label, dtype=torch.long)
def __len__(self):
return len(self.data)
def _longest_encoded_length(self, tokenizer):
max_length = 0
for text in self.data["text"]:
encoded_length = len(tokenizer.encode(text))
if encoded_length > max_length:
max_length = encoded_length
return max_length
def instantiate_model(choose_model, load_weights):
BASE_CONFIG = {
"vocab_size": 50257, # Vocabulary size
"context_length": 1024, # Context length
"drop_rate": 0.0, # Dropout rate
"qkv_bias": True # Query-key-value bias
}
model_configs = {
"gpt2-small (124M)": {"emb_dim": 768, "n_layers": 12, "n_heads": 12},
"gpt2-medium (355M)": {"emb_dim": 1024, "n_layers": 24, "n_heads": 16},
"gpt2-large (774M)": {"emb_dim": 1280, "n_layers": 36, "n_heads": 20},
"gpt2-xl (1558M)": {"emb_dim": 1600, "n_layers": 48, "n_heads": 25},
}
BASE_CONFIG.update(model_configs[choose_model])
model = GPTModel(BASE_CONFIG)
if load_weights:
model_size = choose_model.split(" ")[-1].lstrip("(").rstrip(")")
settings, params = download_and_load_gpt2(model_size=model_size, models_dir="gpt2")
load_weights_into_gpt(model, params)
model.eval()
return model
def calc_loss_batch(input_batch, target_batch, model, device, trainable_token=-1):
input_batch, target_batch = input_batch.to(device), target_batch.to(device)
logits = model(input_batch)[:, trainable_token, :] # Logits of the selected output token
loss = torch.nn.functional.cross_entropy(logits, target_batch)
return loss
def calc_loss_loader(data_loader, model, device, num_batches=None, trainable_token=-1):
total_loss = 0.
if len(data_loader) == 0:
return float("nan")
elif num_batches is None:
num_batches = len(data_loader)
else:
# Reduce the number of batches to match the total number of batches in the data loader
# if num_batches exceeds the number of batches in the data loader
num_batches = min(num_batches, len(data_loader))
for i, (input_batch, target_batch) in enumerate(data_loader):
if i < num_batches:
loss = calc_loss_batch(input_batch, target_batch, model, device, trainable_token=trainable_token)
total_loss += loss.item()
else:
break
return total_loss / num_batches
@torch.no_grad() # Disable gradient tracking for efficiency
def calc_accuracy_loader(data_loader, model, device, num_batches=None, trainable_token=-1):
model.eval()
correct_predictions, num_examples = 0, 0
if num_batches is None:
num_batches = len(data_loader)
else:
num_batches = min(num_batches, len(data_loader))
for i, (input_batch, target_batch) in enumerate(data_loader):
if i < num_batches:
input_batch, target_batch = input_batch.to(device), target_batch.to(device)
logits = model(input_batch)[:, trainable_token, :] # Logits of the selected output token
predicted_labels = torch.argmax(logits, dim=-1)
num_examples += predicted_labels.shape[0]
correct_predictions += (predicted_labels == target_batch).sum().item()
else:
break
return correct_predictions / num_examples
def evaluate_model(model, train_loader, val_loader, device, eval_iter, trainable_token=-1):
model.eval()
with torch.no_grad():
train_loss = calc_loss_loader(train_loader, model, device, num_batches=eval_iter, trainable_token=trainable_token)
val_loss = calc_loss_loader(val_loader, model, device, num_batches=eval_iter, trainable_token=trainable_token)
model.train()
return train_loss, val_loss
def train_classifier_simple(model, train_loader, val_loader, optimizer, device, num_epochs,
eval_freq, eval_iter, tokenizer, max_steps=None, trainable_token=-1):
# Initialize lists to track losses and tokens seen
train_losses, val_losses, train_accs, val_accs = [], [], [], []
examples_seen, global_step = 0, -1
# Main training loop
for epoch in range(num_epochs):
model.train() # Set model to training mode
for input_batch, target_batch in train_loader:
optimizer.zero_grad() # Reset loss gradients from previous batch iteration
loss = calc_loss_batch(input_batch, target_batch, model, device, trainable_token=trainable_token)
loss.backward() # Calculate loss gradients
optimizer.step() # Update model weights using loss gradients
examples_seen += input_batch.shape[0] # New: track examples instead of tokens
global_step += 1
# Optional evaluation step
if global_step % eval_freq == 0:
train_loss, val_loss = evaluate_model(
model, train_loader, val_loader, device, eval_iter, trainable_token=trainable_token)
train_losses.append(train_loss)
val_losses.append(val_loss)
print(f"Ep {epoch+1} (Step {global_step:06d}): "
f"Train loss {train_loss:.3f}, Val loss {val_loss:.3f}")
if max_steps is not None and global_step > max_steps:
break
# New: Calculate accuracy after each epoch
train_accuracy = calc_accuracy_loader(train_loader, model, device, num_batches=eval_iter, trainable_token=trainable_token)
val_accuracy = calc_accuracy_loader(val_loader, model, device, num_batches=eval_iter, trainable_token=trainable_token)
print(f"Training accuracy: {train_accuracy*100:.2f}% | ", end="")
print(f"Validation accuracy: {val_accuracy*100:.2f}%")
train_accs.append(train_accuracy)
val_accs.append(val_accuracy)
if max_steps is not None and global_step > max_steps:
break
return train_losses, val_losses, train_accs, val_accs, examples_seen
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_size",
type=str,
default="gpt2-small (124M)",
help=(
"Which GPT model to use. Options: 'gpt2-small (124M)', 'gpt2-medium (355M)',"
" 'gpt2-large (774M)', 'gpt2-xl (1558M)'."
)
)
parser.add_argument(
"--weights",
type=str,
default="pretrained",
help=(
"Whether to use 'pretrained' or 'random' weights."
)
)
parser.add_argument(
"--trainable_layers",
type=str,
default="last_block",
help=(
"Which layers to train. Options: 'all', 'last_block', 'last_layer'."
)
)
parser.add_argument(
"--trainable_token",
type=str,
default="last",
help=(
"Which token to train. Options: 'first', 'last'."
)
)
parser.add_argument(
"--context_length",
type=str,
default="256",
help=(
"The context length of the data inputs."
"Options: 'longest_training_example', 'model_context_length' or integer value."
)
)
args = parser.parse_args()
if args.trainable_token == "first":
args.trainable_token = 0
elif args.trainable_token == "last":
args.trainable_token = -1
else:
raise ValueError("Invalid --trainable_token argument")
###############################
# Load model
###############################
if args.weights == "pretrained":
load_weights = True
elif args.weights == "random":
load_weights = False
else:
raise ValueError("Invalid --weights argument.")
model = instantiate_model(args.model_size, load_weights)
for param in model.parameters():
param.requires_grad = False
if args.model_size == "gpt2-small (124M)":
in_features = 768
elif args.model_size == "gpt2-medium (355M)":
in_features = 1024
elif args.model_size == "gpt2-large (774M)":
in_features = 1280
elif args.model_size == "gpt2-xl (1558M)":
in_features = 1600
else:
raise ValueError("Invalid --model_size argument")
torch.manual_seed(123)
model.out_head = torch.nn.Linear(in_features=in_features, out_features=2)
if args.trainable_layers == "last_layer":
pass
elif args.trainable_layers == "last_block":
for param in model.trf_blocks[-1].parameters():
param.requires_grad = True
for param in model.final_norm.parameters():
param.requires_grad = True
elif args.trainable_layers == "all":
for param in model.parameters():
param.requires_grad = True
else:
raise ValueError("Invalid --trainable_layers argument.")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
###############################
# Instantiate dataloaders
###############################
url = "https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip"
zip_path = "sms_spam_collection.zip"
extract_to = "sms_spam_collection"
new_file_path = Path(extract_to) / "SMSSpamCollection.tsv"
base_path = Path(".")
file_names = ["train.csv", "val.csv", "test.csv"]
all_exist = all((base_path / file_name).exists() for file_name in file_names)
tokenizer = tiktoken.get_encoding("gpt2")
if args.context_length == "model_context_length":
max_length = model.pos_emb.weight.shape[0]
elif args.context_length == "longest_training_example":
max_length = None
else:
try:
max_length = int(args.context_length)
except ValueError:
raise ValueError("Invalid --context_length argument")
train_dataset = IMDBDataset(base_path / "train.csv", max_length=max_length, tokenizer=tokenizer)
val_dataset = IMDBDataset(base_path / "val.csv", max_length=max_length, tokenizer=tokenizer)
test_dataset = IMDBDataset(base_path / "test.csv", max_length=max_length, tokenizer=tokenizer)
num_workers = 0
batch_size = 8
train_loader = DataLoader(
dataset=train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=num_workers,
drop_last=True,
)
val_loader = DataLoader(
dataset=val_dataset,
batch_size=batch_size,
num_workers=num_workers,
drop_last=False,
)
test_loader = DataLoader(
dataset=test_dataset,
batch_size=batch_size,
num_workers=num_workers,
drop_last=False,
)
###############################
# Train model
###############################
start_time = time.time()
torch.manual_seed(123)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.1)
num_epochs = 3
train_losses, val_losses, train_accs, val_accs, examples_seen = train_classifier_simple(
model, train_loader, val_loader, optimizer, device,
num_epochs=num_epochs, eval_freq=50, eval_iter=20,
tokenizer=tokenizer, max_steps=None, trainable_token=args.trainable_token
)
end_time = time.time()
execution_time_minutes = (end_time - start_time) / 60
print(f"Training completed in {execution_time_minutes:.2f} minutes.")
###############################
# Evaluate model
###############################
print("\nEvaluating on the full datasets ...\n")
train_accuracy = calc_accuracy_loader(train_loader, model, device, trainable_token=args.trainable_token)
val_accuracy = calc_accuracy_loader(val_loader, model, device, trainable_token=args.trainable_token)
test_accuracy = calc_accuracy_loader(test_loader, model, device, trainable_token=args.trainable_token)
print(f"Training accuracy: {train_accuracy*100:.2f}%")
print(f"Validation accuracy: {val_accuracy*100:.2f}%")
print(f"Test accuracy: {test_accuracy*100:.2f}%")

View File

@@ -0,0 +1,75 @@
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
import pandas as pd
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
# from sklearn.metrics import balanced_accuracy_score
from sklearn.dummy import DummyClassifier
def load_dataframes():
df_train = pd.read_csv("train.csv")
df_val = pd.read_csv("val.csv")
df_test = pd.read_csv("test.csv")
return df_train, df_val, df_test
def eval(model, X_train, y_train, X_val, y_val, X_test, y_test):
# Making predictions
y_pred_train = model.predict(X_train)
y_pred_val = model.predict(X_val)
y_pred_test = model.predict(X_test)
# Calculating accuracy and balanced accuracy
accuracy_train = accuracy_score(y_train, y_pred_train)
# balanced_accuracy_train = balanced_accuracy_score(y_train, y_pred_train)
accuracy_val = accuracy_score(y_val, y_pred_val)
# balanced_accuracy_val = balanced_accuracy_score(y_val, y_pred_val)
accuracy_test = accuracy_score(y_test, y_pred_test)
# balanced_accuracy_test = balanced_accuracy_score(y_test, y_pred_test)
# Printing the results
print(f"Training Accuracy: {accuracy_train*100:.2f}%")
print(f"Validation Accuracy: {accuracy_val*100:.2f}%")
print(f"Test Accuracy: {accuracy_test*100:.2f}%")
# print(f"\nTraining Balanced Accuracy: {balanced_accuracy_train*100:.2f}%")
# print(f"Validation Balanced Accuracy: {balanced_accuracy_val*100:.2f}%")
# print(f"Test Balanced Accuracy: {balanced_accuracy_test*100:.2f}%")
if __name__ == "__main__":
df_train, df_val, df_test = load_dataframes()
#########################################
# Convert text into bag-of-words model
vectorizer = CountVectorizer()
#########################################
X_train = vectorizer.fit_transform(df_train["text"])
X_val = vectorizer.transform(df_val["text"])
X_test = vectorizer.transform(df_test["text"])
y_train, y_val, y_test = df_train["label"], df_val["label"], df_test["label"]
#####################################
# Model training and evaluation
#####################################
# Create a dummy classifier with the strategy to predict the most frequent class
dummy_clf = DummyClassifier(strategy="most_frequent")
dummy_clf.fit(X_train, y_train)
print("Dummy classifier:")
eval(dummy_clf, X_train, y_train, X_val, y_val, X_test, y_test)
print("\n\nLogistic regression classifier:")
model = LogisticRegression(max_iter=1000)
model.fit(X_train, y_train)
eval(model, X_train, y_train, X_val, y_val, X_test, y_test)