update dataset naming

rasbt 2024-05-12 09:22:42 -05:00
parent 55c3a91838
commit 2e47a6e61c
6 changed files with 72 additions and 346 deletions

.gitignore

@@ -24,6 +24,11 @@ ch06/01_main-chapter-code/sms_spam_collection
ch06/01_main-chapter-code/test.csv
ch06/01_main-chapter-code/train.csv
ch06/01_main-chapter-code/validation.csv
ch06/03_bonus_imdb-classification/aclImdb/
ch06/03_bonus_imdb-classification/aclImdb_v1.tar.gz
ch06/03_bonus_imdb-classification/test.csv
ch06/03_bonus_imdb-classification/train.csv
ch06/03_bonus_imdb-classification/validation.csv
# Temporary OS-related files
.DS_Store

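The five new ignore entries are the artifacts produced by the dataset-preparation step this commit switches to; the sklearn notebook below now runs "python download-prepare-dataset.py" instead of inlining the download code. A minimal sketch of what such a script might do, assuming the standard aclImdb archive layout (the URL, the 90/10 train/validation carve-out, and all names below are assumptions, not taken from this diff):

# download-prepare-dataset.py, hypothetical reconstruction; the repository's
# actual script is not shown in this diff.
import tarfile
import urllib.request
from pathlib import Path

import pandas as pd

URL = "https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz"  # assumed source

def download_and_extract(url=URL, archive="aclImdb_v1.tar.gz", target_dir="aclImdb"):
    if Path(target_dir).exists():
        print(f"{target_dir} already exists. Skipping download and extraction.")
        return
    with urllib.request.urlopen(url) as response, open(archive, "wb") as out_file:
        out_file.write(response.read())
    with tarfile.open(archive, "r:gz") as tar:
        tar.extractall()

def load_split(split_dir):
    # aclImdb stores one review per .txt file under pos/ and neg/ subfolders
    rows = []
    for label_name, label in (("pos", 1), ("neg", 0)):
        for path in sorted(Path(split_dir, label_name).glob("*.txt")):
            rows.append({"text": path.read_text(encoding="utf-8"), "label": label})
    return pd.DataFrame(rows)

if __name__ == "__main__":
    download_and_extract()
    train_val = load_split("aclImdb/train").sample(frac=1, random_state=123)
    n_val = int(0.1 * len(train_val))  # assumed 90/10 carve-out
    train_val.iloc[n_val:].to_csv("train.csv", index=False)
    train_val.iloc[:n_val].to_csv("validation.csv", index=False)
    load_split("aclImdb/test").to_csv("test.csv", index=False)
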
View File

@@ -1415,7 +1415,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
"version": "3.11.4"
}
},
"nbformat": 4,

View File

@@ -2347,7 +2347,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
"version": "3.11.4"
}
},
"nbformat": 4,

View File

@@ -1,59 +1,50 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "8b6e1cdd-b14e-4368-bdbb-9bf7ab821791",
"metadata": {},
"source": [
"# Scikit-learn Logistic Regression Model"
]
},
{
"cell_type": "code",
"execution_count": 75,
"id": "b612c4c1-fa3c-47b9-a8ce-9e32f371e160",
"execution_count": 1,
"id": "c2a72242-6197-4bef-aa05-696a152350d5",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"sms_spam_collection/SMSSpamCollection.tsv already exists. Skipping download and extraction.\n"
"100% | 80.23 MB | 4.37 MB/s | 18.38 sec elapsed"
]
}
],
"source": [
"import urllib.request\n",
"import zipfile\n",
"import os\n",
"from pathlib import Path\n",
"\n",
"url = \"https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip\"\n",
"zip_path = \"sms_spam_collection.zip\"\n",
"extract_to = \"sms_spam_collection\"\n",
"new_file_path = Path(extract_to) / \"SMSSpamCollection.tsv\"\n",
"\n",
"def download_and_unzip(url, zip_path, extract_to, new_file_path):\n",
" # Check if the target file already exists\n",
" if new_file_path.exists():\n",
" print(f\"{new_file_path} already exists. Skipping download and extraction.\")\n",
" return\n",
"\n",
" # Downloading the file\n",
" with urllib.request.urlopen(url) as response:\n",
" with open(zip_path, \"wb\") as out_file:\n",
" out_file.write(response.read())\n",
"\n",
" # Unzipping the file\n",
" with zipfile.ZipFile(zip_path, 'r') as zip_ref:\n",
" zip_ref.extractall(extract_to)\n",
"\n",
" # Renaming the file to indicate its format\n",
" original_file = Path(extract_to) / \"SMSSpamCollection\"\n",
" os.rename(original_file, new_file_path)\n",
" print(f\"File download and saved as {new_file_path}\")\n",
"\n",
"# Execute the function\n",
"download_and_unzip_spam_data(url, zip_path, extract_to, new_file_path)"
"!python download-prepare-dataset.py"
]
},
{
"cell_type": "code",
"execution_count": 76,
"execution_count": 14,
"id": "69f32433-e19c-4066-b806-8f30b408107f",
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"\n",
"train_df = pd.read_csv(\"train.csv\")\n",
"val_df = pd.read_csv(\"validation.csv\")\n",
"test_df = pd.read_csv(\"test.csv\")"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "0808b212-fe91-48d9-80b8-55519f8835d5",
"metadata": {},
"outputs": [
{
"data": {
@@ -76,280 +67,56 @@
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Label</th>\n",
" <th>Text</th>\n",
" <th>text</th>\n",
" <th>label</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>ham</td>\n",
" <td>Aight text me when you're back at mu and I'll ...</td>\n",
" <td>The only reason I saw \"Shakedown\" was that it ...</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>ham</td>\n",
" <td>Our Prashanthettan's mother passed away last n...</td>\n",
" <td>This is absolute drivel, designed to shock and...</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>ham</td>\n",
" <td>No it will reach by 9 only. She telling she wi...</td>\n",
" <td>Lots of scenes and dialogue are flat-out goofy...</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>ham</td>\n",
" <td>Do you know when the result.</td>\n",
" <td>** and 1/2 stars out of **** Lifeforce is one ...</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>spam</td>\n",
" <td>Hi. Customer Loyalty Offer:The NEW Nokia6650 M...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5567</th>\n",
" <td>ham</td>\n",
" <td>I accidentally brought em home in the box</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5568</th>\n",
" <td>spam</td>\n",
" <td>Moby Pub Quiz.Win a £100 High Street prize if ...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5569</th>\n",
" <td>ham</td>\n",
" <td>Que pases un buen tiempo or something like that</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5570</th>\n",
" <td>ham</td>\n",
" <td>Nowadays people are notixiquating the laxinorf...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5571</th>\n",
" <td>ham</td>\n",
" <td>Ard 4 lor...</td>\n",
" <td>I learned a thing: you have to take this film ...</td>\n",
" <td>1</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>5572 rows × 2 columns</p>\n",
"</div>"
],
"text/plain": [
" Label Text\n",
"0 ham Aight text me when you're back at mu and I'll ...\n",
"1 ham Our Prashanthettan's mother passed away last n...\n",
"2 ham No it will reach by 9 only. She telling she wi...\n",
"3 ham Do you know when the result.\n",
"4 spam Hi. Customer Loyalty Offer:The NEW Nokia6650 M...\n",
"... ... ...\n",
"5567 ham I accidentally brought em home in the box\n",
"5568 spam Moby Pub Quiz.Win a £100 High Street prize if ...\n",
"5569 ham Que pases un buen tiempo or something like that\n",
"5570 ham Nowadays people are notixiquating the laxinorf...\n",
"5571 ham Ard 4 lor...\n",
"\n",
"[5572 rows x 2 columns]"
" text label\n",
"0 The only reason I saw \"Shakedown\" was that it ... 0\n",
"1 This is absolute drivel, designed to shock and... 0\n",
"2 Lots of scenes and dialogue are flat-out goofy... 1\n",
"3 ** and 1/2 stars out of **** Lifeforce is one ... 1\n",
"4 I learned a thing: you have to take this film ... 1"
]
},
"execution_count": 76,
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pandas as pd\n",
"\n",
"df = pd.read_csv(new_file_path, sep=\"\\t\", header=None, names=[\"Label\", \"Text\"])\n",
"df = df.sample(frac=1, random_state=123).reset_index(drop=True) # Shuffle the DataFrame\n",
"df"
]
},
{
"cell_type": "code",
"execution_count": 77,
"id": "4b7beeba-9f3a-45f0-b2dc-76bb155a8f0e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Label\n",
"ham 4825\n",
"spam 747\n",
"Name: count, dtype: int64\n"
]
}
],
"source": [
"# Class distribution\n",
"print(df[\"Label\"].value_counts())"
]
},
{
"cell_type": "code",
"execution_count": 78,
"id": "b3db862a-9e03-4715-babb-9b699e4f4a36",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Label\n",
"spam 747\n",
"ham 747\n",
"Name: count, dtype: int64\n"
]
}
],
"source": [
"# Count the instances of 'spam'\n",
"n_spam = df[df[\"Label\"] == \"spam\"].shape[0]\n",
"\n",
"# Randomly sample 'ham' instances to match the number of 'spam' instances\n",
"ham_sampled = df[df[\"Label\"] == \"ham\"].sample(n_spam)\n",
"\n",
"# Combine the sampled 'ham' with all 'spam'\n",
"balanced_df = pd.concat([ham_sampled, df[df[\"Label\"] == \"spam\"]])\n",
"\n",
"# Shuffle the DataFrame\n",
"balanced_df = balanced_df.sample(frac=1).reset_index(drop=True)\n",
"\n",
"# Now balanced_df is the balanced DataFrame\n",
"print(balanced_df[\"Label\"].value_counts())"
]
},
{
"cell_type": "code",
"execution_count": 79,
"id": "0af991e5-98ef-439a-a43d-63a581a2cc6d",
"metadata": {},
"outputs": [],
"source": [
"df[\"Label\"] = df[\"Label\"].map({\"ham\": 0, \"spam\": 1})"
]
},
{
"cell_type": "code",
"execution_count": 80,
"id": "2f5b00ef-e3ed-4819-b271-5f355848feb5",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training set:\n",
"Label\n",
"0 0.86612\n",
"1 0.13388\n",
"Name: proportion, dtype: float64\n",
"\n",
"Validation set:\n",
"Label\n",
"0 0.866906\n",
"1 0.133094\n",
"Name: proportion, dtype: float64\n",
"\n",
"Test set:\n",
"Label\n",
"0 0.864816\n",
"1 0.135184\n",
"Name: proportion, dtype: float64\n"
]
}
],
"source": [
"# Define split ratios\n",
"train_size, validation_size = 0.7, 0.1\n",
"# Test size is implied to be 0.2 as the remainder\n",
"\n",
"# Split the data\n",
"def stratified_split(df, stratify_col, train_frac, validation_frac):\n",
" stratified_train = pd.DataFrame()\n",
" stratified_validation = pd.DataFrame()\n",
" stratified_test = pd.DataFrame()\n",
"\n",
" # Stratify split by the unique values in the column\n",
" for value in df[stratify_col].unique():\n",
" # Filter the DataFrame for the class\n",
" df_class = df[df[stratify_col] == value]\n",
" \n",
" # Calculate class split sizes\n",
" train_end = int(len(df_class) * train_frac)\n",
" validation_end = train_end + int(len(df_class) * validation_frac)\n",
" \n",
" # Slice the DataFrame to get the sets\n",
" stratified_train = pd.concat([stratified_train, df_class[:train_end]], axis=0)\n",
" stratified_validation = pd.concat([stratified_validation, df_class[train_end:validation_end]], axis=0)\n",
" stratified_test = pd.concat([stratified_test, df_class[validation_end:]], axis=0)\n",
"\n",
" # Shuffle the sets again\n",
" stratified_train = stratified_train.sample(frac=1, random_state=123).reset_index(drop=True)\n",
" stratified_validation = stratified_validation.sample(frac=1, random_state=123).reset_index(drop=True)\n",
" stratified_test = stratified_test.sample(frac=1, random_state=123).reset_index(drop=True)\n",
"\n",
" return stratified_train, stratified_validation, stratified_test\n",
"\n",
"# Apply the stratified split function\n",
"train_df, validation_df, test_df = stratified_split(df, \"Label\", train_size, validation_size)\n",
"\n",
"# Check the results\n",
"print(f\"Training set:\\n{train_df['Label'].value_counts(normalize=True)}\")\n",
"print(f\"\\nValidation set:\\n{validation_df['Label'].value_counts(normalize=True)}\")\n",
"print(f\"\\nTest set:\\n{test_df['Label'].value_counts(normalize=True)}\")"
]
},
{
"cell_type": "code",
"execution_count": 81,
"id": "65808167-2b93-45b0-8506-ce722732ce77",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training set:\n",
"Label\n",
"ham 0.5\n",
"spam 0.5\n",
"Name: proportion, dtype: float64\n",
"\n",
"Validation set:\n",
"Label\n",
"ham 0.5\n",
"spam 0.5\n",
"Name: proportion, dtype: float64\n",
"\n",
"Test set:\n",
"Label\n",
"spam 0.5\n",
"ham 0.5\n",
"Name: proportion, dtype: float64\n"
]
}
],
"source": [
"# Define split ratios\n",
"train_size, validation_size = 0.7, 0.1\n",
"# Test size is implied to be 0.2 as the remainder\n",
"\n",
"# Apply the stratified split function\n",
"train_df, validation_df, test_df = stratified_split(balanced_df, \"Label\", train_size, validation_size)\n",
"\n",
"# Check the results\n",
"print(f\"Training set:\\n{train_df['Label'].value_counts(normalize=True)}\")\n",
"print(f\"\\nValidation set:\\n{validation_df['Label'].value_counts(normalize=True)}\")\n",
"print(f\"\\nTest set:\\n{test_df['Label'].value_counts(normalize=True)}\")"
"train_df.head()"
]
},
{
@@ -362,35 +129,35 @@
},
{
"cell_type": "code",
"execution_count": 82,
"execution_count": 17,
"id": "180318b7-de18-4b05-b84a-ba97c72b9d8e",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.feature_extraction.text import CountVectorizer\n",
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn.metrics import accuracy_score, balanced_accuracy_score"
"from sklearn.metrics import accuracy_score"
]
},
{
"cell_type": "code",
"execution_count": 83,
"execution_count": 20,
"id": "25090b7c-f516-4be2-8083-3a7187fe4635",
"metadata": {},
"outputs": [],
"source": [
"vectorizer = CountVectorizer()\n",
"\n",
"X_train = vectorizer.fit_transform(train_df[\"Text\"])\n",
"X_val = vectorizer.transform(validation_df[\"Text\"])\n",
"X_test = vectorizer.transform(test_df[\"Text\"])\n",
"X_train = vectorizer.fit_transform(train_df[\"text\"])\n",
"X_val = vectorizer.transform(val_df[\"text\"])\n",
"X_test = vectorizer.transform(test_df[\"text\"])\n",
"\n",
"y_train, y_val, y_test = train_df[\"Label\"], validation_df[\"Label\"], test_df[\"Label\"]"
"y_train, y_val, y_test = train_df[\"label\"], val_df[\"label\"], test_df[\"label\"]"
]
},
{
"cell_type": "code",
"execution_count": 84,
"execution_count": 22,
"id": "0247de3a-88f0-4b9c-becd-157baf3acf49",
"metadata": {},
"outputs": [],
@@ -414,16 +181,12 @@
" # Printing the results\n",
" print(f\"Training Accuracy: {accuracy_train*100:.2f}%\")\n",
" print(f\"Validation Accuracy: {accuracy_val*100:.2f}%\")\n",
" print(f\"Test Accuracy: {accuracy_test*100:.2f}%\")\n",
" \n",
" print(f\"\\nTraining Balanced Accuracy: {balanced_accuracy_train*100:.2f}%\")\n",
" print(f\"Validation Balanced Accuracy: {balanced_accuracy_val*100:.2f}%\")\n",
" print(f\"Test Balanced Accuracy: {balanced_accuracy_test*100:.2f}%\")"
" print(f\"Test Accuracy: {accuracy_test*100:.2f}%\")"
]
},
{
"cell_type": "code",
"execution_count": 85,
"execution_count": 23,
"id": "c29c6dfc-f72d-40ab-8cb5-783aad1a15ab",
"metadata": {},
"outputs": [
@@ -431,13 +194,9 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Training Accuracy: 50.00%\n",
"Validation Accuracy: 50.00%\n",
"Test Accuracy: 50.00%\n",
"\n",
"Training Balanced Accuracy: 50.00%\n",
"Validation Balanced Accuracy: 50.00%\n",
"Test Balanced Accuracy: 50.00%\n"
"Training Accuracy: 50.01%\n",
"Validation Accuracy: 50.14%\n",
"Test Accuracy: 49.91%\n"
]
}
],
@@ -453,7 +212,7 @@
},
{
"cell_type": "code",
"execution_count": 86,
"execution_count": 24,
"id": "088a8a3a-3b74-4d10-a51b-cb662569ae39",
"metadata": {},
"outputs": [
@@ -461,13 +220,9 @@
"name": "stdout",
"output_type": "stream",
"text": [
"Training Accuracy: 99.81%\n",
"Validation Accuracy: 95.27%\n",
"Test Accuracy: 96.03%\n",
"\n",
"Training Balanced Accuracy: 99.81%\n",
"Validation Balanced Accuracy: 95.27%\n",
"Test Balanced Accuracy: 96.03%\n"
"Training Accuracy: 99.80%\n",
"Validation Accuracy: 88.62%\n",
"Test Accuracy: 88.85%\n"
]
}
],
@@ -476,22 +231,6 @@
"model.fit(X_train, y_train)\n",
"eval(model, X_train, y_train, X_val, y_val, X_test, y_test)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "34411348-45bc-4b01-bebf-b3602c002ef1",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "5a9bc6b1-c8b9-4d4f-bfe4-c5a4a8b0c756",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
@@ -510,7 +249,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
"version": "3.11.4"
}
},
"nbformat": 4,

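Taken together, the notebook after this commit reduces to a plain bag-of-words baseline: fit a CountVectorizer vocabulary on the training split, transform all three splits, and train a logistic-regression classifier. A condensed, self-contained version of those cells follows; the LogisticRegression constructor arguments are not visible in the hunks above, so the max_iter value here is an assumption:

# Standalone version of the notebook's scikit-learn baseline, mirroring
# the cells shown in the diff above.
import pandas as pd
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score

train_df = pd.read_csv("train.csv")
val_df = pd.read_csv("validation.csv")
test_df = pd.read_csv("test.csv")

# Fit the bag-of-words vocabulary on the training split only
vectorizer = CountVectorizer()
X_train = vectorizer.fit_transform(train_df["text"])
X_val = vectorizer.transform(val_df["text"])
X_test = vectorizer.transform(test_df["text"])

model = LogisticRegression(max_iter=1000)  # max_iter is an assumption
model.fit(X_train, train_df["label"])

for name, X, y in (("Training", X_train, train_df["label"]),
                   ("Validation", X_val, val_df["label"]),
                   ("Test", X_test, test_df["label"])):
    print(f"{name} Accuracy: {accuracy_score(y, model.predict(X)) * 100:.2f}%")
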
View File

@@ -234,15 +234,6 @@ if __name__ == "__main__":
# Instantiate dataloaders
###############################
url = "https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip"
zip_path = "sms_spam_collection.zip"
extract_to = "sms_spam_collection"
new_file_path = Path(extract_to) / "SMSSpamCollection.tsv"
base_path = Path(".")
file_names = ["train.csv", "val.csv", "test.csv"]
all_exist = all((base_path / file_name).exists() for file_name in file_names)
pad_token_id = tokenizer.encode(tokenizer.pad_token)
train_dataset = IMDBDataset(base_path / "train.csv", max_length=256, tokenizer=tokenizer, pad_token_id=pad_token_id)

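The block removed here is leftover SMS-spam boilerplate in the IMDB fine-tuning script; the surviving lines compute a pad token id and build the training dataset. The IMDBDataset class itself is not shown in this diff, but a hypothetical implementation consistent with the call signature (CSV path, max_length, tokenizer, pad_token_id) could look like this:

import pandas as pd
import torch
from torch.utils.data import Dataset

class IMDBDataset(Dataset):
    # Hypothetical sketch matching the constructor call in the diff;
    # the repository's actual class may differ.
    def __init__(self, csv_path, max_length, tokenizer, pad_token_id):
        df = pd.read_csv(csv_path)
        self.labels = df["label"].tolist()
        self.encoded = []
        for text in df["text"]:
            token_ids = tokenizer.encode(text)[:max_length]  # truncate
            # pad every example to max_length with the pad token id
            token_ids += [pad_token_id] * (max_length - len(token_ids))
            self.encoded.append(token_ids)

    def __getitem__(self, index):
        return (torch.tensor(self.encoded[index]),
                torch.tensor(self.labels[index]))

    def __len__(self):
        return len(self.labels)
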
View File

@@ -286,15 +286,6 @@ if __name__ == "__main__":
# Instantiate dataloaders
###############################
url = "https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip"
zip_path = "sms_spam_collection.zip"
extract_to = "sms_spam_collection"
new_file_path = Path(extract_to) / "SMSSpamCollection.tsv"
base_path = Path(".")
file_names = ["train.csv", "val.csv", "test.csv"]
all_exist = all((base_path / file_name).exists() for file_name in file_names)
tokenizer = tiktoken.get_encoding("gpt2")
train_dataset = None
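
The same stale SMS-spam block is removed from the GPT-based script, which tokenizes with tiktoken's gpt2 encoding. GPT-2's vocabulary has no dedicated pad token, so scripts like this commonly reuse <|endoftext|>; a small sketch of that detail (the choice of pad token is an assumption, not taken from this diff):

import tiktoken

tokenizer = tiktoken.get_encoding("gpt2")

# Reusing <|endoftext|> as the padding token is a common convention for GPT-2
pad_token_id = tokenizer.encode("<|endoftext|>",
                                allowed_special={"<|endoftext|>"})[0]
print(pad_token_id)  # 50256

# The dataset would then be built the same way as in the companion script:
# train_dataset = IMDBDataset("train.csv", max_length=256,
#                             tokenizer=tokenizer, pad_token_id=pad_token_id)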