mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-11-02 10:50:30 +00:00
update dataset naming
This commit is contained in:
parent
55c3a91838
commit
2e47a6e61c
5
.gitignore
vendored
5
.gitignore
vendored
@ -24,6 +24,11 @@ ch06/01_main-chapter-code/sms_spam_collection
|
||||
ch06/01_main-chapter-code/test.csv
|
||||
ch06/01_main-chapter-code/train.csv
|
||||
ch06/01_main-chapter-code/validation.csv
|
||||
ch06/03_bonus_imdb-classification/aclImdb/
|
||||
ch06/03_bonus_imdb-classification/aclImdb_v1.tar.gz
|
||||
ch06/03_bonus_imdb-classification/test.csv
|
||||
ch06/03_bonus_imdb-classification/train.csv
|
||||
ch06/03_bonus_imdb-classification/validation.csv
|
||||
|
||||
# Temporary OS-related files
|
||||
.DS_Store
|
||||
|
||||
@ -1415,7 +1415,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.6"
|
||||
"version": "3.11.4"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@ -2347,7 +2347,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.6"
|
||||
"version": "3.11.4"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@ -1,59 +1,50 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "8b6e1cdd-b14e-4368-bdbb-9bf7ab821791",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Scikit-learn Logistic Regression Model"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 75,
|
||||
"id": "b612c4c1-fa3c-47b9-a8ce-9e32f371e160",
|
||||
"execution_count": 1,
|
||||
"id": "c2a72242-6197-4bef-aa05-696a152350d5",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"sms_spam_collection/SMSSpamCollection.tsv already exists. Skipping download and extraction.\n"
|
||||
"100% | 80.23 MB | 4.37 MB/s | 18.38 sec elapsed"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import urllib.request\n",
|
||||
"import zipfile\n",
|
||||
"import os\n",
|
||||
"from pathlib import Path\n",
|
||||
"\n",
|
||||
"url = \"https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip\"\n",
|
||||
"zip_path = \"sms_spam_collection.zip\"\n",
|
||||
"extract_to = \"sms_spam_collection\"\n",
|
||||
"new_file_path = Path(extract_to) / \"SMSSpamCollection.tsv\"\n",
|
||||
"\n",
|
||||
"def download_and_unzip(url, zip_path, extract_to, new_file_path):\n",
|
||||
" # Check if the target file already exists\n",
|
||||
" if new_file_path.exists():\n",
|
||||
" print(f\"{new_file_path} already exists. Skipping download and extraction.\")\n",
|
||||
" return\n",
|
||||
"\n",
|
||||
" # Downloading the file\n",
|
||||
" with urllib.request.urlopen(url) as response:\n",
|
||||
" with open(zip_path, \"wb\") as out_file:\n",
|
||||
" out_file.write(response.read())\n",
|
||||
"\n",
|
||||
" # Unzipping the file\n",
|
||||
" with zipfile.ZipFile(zip_path, 'r') as zip_ref:\n",
|
||||
" zip_ref.extractall(extract_to)\n",
|
||||
"\n",
|
||||
" # Renaming the file to indicate its format\n",
|
||||
" original_file = Path(extract_to) / \"SMSSpamCollection\"\n",
|
||||
" os.rename(original_file, new_file_path)\n",
|
||||
" print(f\"File download and saved as {new_file_path}\")\n",
|
||||
"\n",
|
||||
"# Execute the function\n",
|
||||
"download_and_unzip_spam_data(url, zip_path, extract_to, new_file_path)"
|
||||
"!python download-prepare-dataset.py"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 76,
|
||||
"execution_count": 14,
|
||||
"id": "69f32433-e19c-4066-b806-8f30b408107f",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import pandas as pd\n",
|
||||
"\n",
|
||||
"train_df = pd.read_csv(\"train.csv\")\n",
|
||||
"val_df = pd.read_csv(\"validation.csv\")\n",
|
||||
"test_df = pd.read_csv(\"test.csv\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"id": "0808b212-fe91-48d9-80b8-55519f8835d5",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
@ -76,280 +67,56 @@
|
||||
" <thead>\n",
|
||||
" <tr style=\"text-align: right;\">\n",
|
||||
" <th></th>\n",
|
||||
" <th>Label</th>\n",
|
||||
" <th>Text</th>\n",
|
||||
" <th>text</th>\n",
|
||||
" <th>label</th>\n",
|
||||
" </tr>\n",
|
||||
" </thead>\n",
|
||||
" <tbody>\n",
|
||||
" <tr>\n",
|
||||
" <th>0</th>\n",
|
||||
" <td>ham</td>\n",
|
||||
" <td>Aight text me when you're back at mu and I'll ...</td>\n",
|
||||
" <td>The only reason I saw \"Shakedown\" was that it ...</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>1</th>\n",
|
||||
" <td>ham</td>\n",
|
||||
" <td>Our Prashanthettan's mother passed away last n...</td>\n",
|
||||
" <td>This is absolute drivel, designed to shock and...</td>\n",
|
||||
" <td>0</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>2</th>\n",
|
||||
" <td>ham</td>\n",
|
||||
" <td>No it will reach by 9 only. She telling she wi...</td>\n",
|
||||
" <td>Lots of scenes and dialogue are flat-out goofy...</td>\n",
|
||||
" <td>1</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>3</th>\n",
|
||||
" <td>ham</td>\n",
|
||||
" <td>Do you know when the result.</td>\n",
|
||||
" <td>** and 1/2 stars out of **** Lifeforce is one ...</td>\n",
|
||||
" <td>1</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>4</th>\n",
|
||||
" <td>spam</td>\n",
|
||||
" <td>Hi. Customer Loyalty Offer:The NEW Nokia6650 M...</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>...</th>\n",
|
||||
" <td>...</td>\n",
|
||||
" <td>...</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>5567</th>\n",
|
||||
" <td>ham</td>\n",
|
||||
" <td>I accidentally brought em home in the box</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>5568</th>\n",
|
||||
" <td>spam</td>\n",
|
||||
" <td>Moby Pub Quiz.Win a £100 High Street prize if ...</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>5569</th>\n",
|
||||
" <td>ham</td>\n",
|
||||
" <td>Que pases un buen tiempo or something like that</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>5570</th>\n",
|
||||
" <td>ham</td>\n",
|
||||
" <td>Nowadays people are notixiquating the laxinorf...</td>\n",
|
||||
" </tr>\n",
|
||||
" <tr>\n",
|
||||
" <th>5571</th>\n",
|
||||
" <td>ham</td>\n",
|
||||
" <td>Ard 4 lor...</td>\n",
|
||||
" <td>I learned a thing: you have to take this film ...</td>\n",
|
||||
" <td>1</td>\n",
|
||||
" </tr>\n",
|
||||
" </tbody>\n",
|
||||
"</table>\n",
|
||||
"<p>5572 rows × 2 columns</p>\n",
|
||||
"</div>"
|
||||
],
|
||||
"text/plain": [
|
||||
" Label Text\n",
|
||||
"0 ham Aight text me when you're back at mu and I'll ...\n",
|
||||
"1 ham Our Prashanthettan's mother passed away last n...\n",
|
||||
"2 ham No it will reach by 9 only. She telling she wi...\n",
|
||||
"3 ham Do you know when the result.\n",
|
||||
"4 spam Hi. Customer Loyalty Offer:The NEW Nokia6650 M...\n",
|
||||
"... ... ...\n",
|
||||
"5567 ham I accidentally brought em home in the box\n",
|
||||
"5568 spam Moby Pub Quiz.Win a £100 High Street prize if ...\n",
|
||||
"5569 ham Que pases un buen tiempo or something like that\n",
|
||||
"5570 ham Nowadays people are notixiquating the laxinorf...\n",
|
||||
"5571 ham Ard 4 lor...\n",
|
||||
"\n",
|
||||
"[5572 rows x 2 columns]"
|
||||
" text label\n",
|
||||
"0 The only reason I saw \"Shakedown\" was that it ... 0\n",
|
||||
"1 This is absolute drivel, designed to shock and... 0\n",
|
||||
"2 Lots of scenes and dialogue are flat-out goofy... 1\n",
|
||||
"3 ** and 1/2 stars out of **** Lifeforce is one ... 1\n",
|
||||
"4 I learned a thing: you have to take this film ... 1"
|
||||
]
|
||||
},
|
||||
"execution_count": 76,
|
||||
"execution_count": 16,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import pandas as pd\n",
|
||||
"\n",
|
||||
"df = pd.read_csv(new_file_path, sep=\"\\t\", header=None, names=[\"Label\", \"Text\"])\n",
|
||||
"df = df.sample(frac=1, random_state=123).reset_index(drop=True) # Shuffle the DataFrame\n",
|
||||
"df"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 77,
|
||||
"id": "4b7beeba-9f3a-45f0-b2dc-76bb155a8f0e",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Label\n",
|
||||
"ham 4825\n",
|
||||
"spam 747\n",
|
||||
"Name: count, dtype: int64\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Class distribution\n",
|
||||
"print(df[\"Label\"].value_counts())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 78,
|
||||
"id": "b3db862a-9e03-4715-babb-9b699e4f4a36",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Label\n",
|
||||
"spam 747\n",
|
||||
"ham 747\n",
|
||||
"Name: count, dtype: int64\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Count the instances of 'spam'\n",
|
||||
"n_spam = df[df[\"Label\"] == \"spam\"].shape[0]\n",
|
||||
"\n",
|
||||
"# Randomly sample 'ham' instances to match the number of 'spam' instances\n",
|
||||
"ham_sampled = df[df[\"Label\"] == \"ham\"].sample(n_spam)\n",
|
||||
"\n",
|
||||
"# Combine the sampled 'ham' with all 'spam'\n",
|
||||
"balanced_df = pd.concat([ham_sampled, df[df[\"Label\"] == \"spam\"]])\n",
|
||||
"\n",
|
||||
"# Shuffle the DataFrame\n",
|
||||
"balanced_df = balanced_df.sample(frac=1).reset_index(drop=True)\n",
|
||||
"\n",
|
||||
"# Now balanced_df is the balanced DataFrame\n",
|
||||
"print(balanced_df[\"Label\"].value_counts())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 79,
|
||||
"id": "0af991e5-98ef-439a-a43d-63a581a2cc6d",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"df[\"Label\"] = df[\"Label\"].map({\"ham\": 0, \"spam\": 1})"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 80,
|
||||
"id": "2f5b00ef-e3ed-4819-b271-5f355848feb5",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Training set:\n",
|
||||
"Label\n",
|
||||
"0 0.86612\n",
|
||||
"1 0.13388\n",
|
||||
"Name: proportion, dtype: float64\n",
|
||||
"\n",
|
||||
"Validation set:\n",
|
||||
"Label\n",
|
||||
"0 0.866906\n",
|
||||
"1 0.133094\n",
|
||||
"Name: proportion, dtype: float64\n",
|
||||
"\n",
|
||||
"Test set:\n",
|
||||
"Label\n",
|
||||
"0 0.864816\n",
|
||||
"1 0.135184\n",
|
||||
"Name: proportion, dtype: float64\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Define split ratios\n",
|
||||
"train_size, validation_size = 0.7, 0.1\n",
|
||||
"# Test size is implied to be 0.2 as the remainder\n",
|
||||
"\n",
|
||||
"# Split the data\n",
|
||||
"def stratified_split(df, stratify_col, train_frac, validation_frac):\n",
|
||||
" stratified_train = pd.DataFrame()\n",
|
||||
" stratified_validation = pd.DataFrame()\n",
|
||||
" stratified_test = pd.DataFrame()\n",
|
||||
"\n",
|
||||
" # Stratify split by the unique values in the column\n",
|
||||
" for value in df[stratify_col].unique():\n",
|
||||
" # Filter the DataFrame for the class\n",
|
||||
" df_class = df[df[stratify_col] == value]\n",
|
||||
" \n",
|
||||
" # Calculate class split sizes\n",
|
||||
" train_end = int(len(df_class) * train_frac)\n",
|
||||
" validation_end = train_end + int(len(df_class) * validation_frac)\n",
|
||||
" \n",
|
||||
" # Slice the DataFrame to get the sets\n",
|
||||
" stratified_train = pd.concat([stratified_train, df_class[:train_end]], axis=0)\n",
|
||||
" stratified_validation = pd.concat([stratified_validation, df_class[train_end:validation_end]], axis=0)\n",
|
||||
" stratified_test = pd.concat([stratified_test, df_class[validation_end:]], axis=0)\n",
|
||||
"\n",
|
||||
" # Shuffle the sets again\n",
|
||||
" stratified_train = stratified_train.sample(frac=1, random_state=123).reset_index(drop=True)\n",
|
||||
" stratified_validation = stratified_validation.sample(frac=1, random_state=123).reset_index(drop=True)\n",
|
||||
" stratified_test = stratified_test.sample(frac=1, random_state=123).reset_index(drop=True)\n",
|
||||
"\n",
|
||||
" return stratified_train, stratified_validation, stratified_test\n",
|
||||
"\n",
|
||||
"# Apply the stratified split function\n",
|
||||
"train_df, validation_df, test_df = stratified_split(df, \"Label\", train_size, validation_size)\n",
|
||||
"\n",
|
||||
"# Check the results\n",
|
||||
"print(f\"Training set:\\n{train_df['Label'].value_counts(normalize=True)}\")\n",
|
||||
"print(f\"\\nValidation set:\\n{validation_df['Label'].value_counts(normalize=True)}\")\n",
|
||||
"print(f\"\\nTest set:\\n{test_df['Label'].value_counts(normalize=True)}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 81,
|
||||
"id": "65808167-2b93-45b0-8506-ce722732ce77",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Training set:\n",
|
||||
"Label\n",
|
||||
"ham 0.5\n",
|
||||
"spam 0.5\n",
|
||||
"Name: proportion, dtype: float64\n",
|
||||
"\n",
|
||||
"Validation set:\n",
|
||||
"Label\n",
|
||||
"ham 0.5\n",
|
||||
"spam 0.5\n",
|
||||
"Name: proportion, dtype: float64\n",
|
||||
"\n",
|
||||
"Test set:\n",
|
||||
"Label\n",
|
||||
"spam 0.5\n",
|
||||
"ham 0.5\n",
|
||||
"Name: proportion, dtype: float64\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# Define split ratios\n",
|
||||
"train_size, validation_size = 0.7, 0.1\n",
|
||||
"# Test size is implied to be 0.2 as the remainder\n",
|
||||
"\n",
|
||||
"# Apply the stratified split function\n",
|
||||
"train_df, validation_df, test_df = stratified_split(balanced_df, \"Label\", train_size, validation_size)\n",
|
||||
"\n",
|
||||
"# Check the results\n",
|
||||
"print(f\"Training set:\\n{train_df['Label'].value_counts(normalize=True)}\")\n",
|
||||
"print(f\"\\nValidation set:\\n{validation_df['Label'].value_counts(normalize=True)}\")\n",
|
||||
"print(f\"\\nTest set:\\n{test_df['Label'].value_counts(normalize=True)}\")"
|
||||
"train_df.head()"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -362,35 +129,35 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 82,
|
||||
"execution_count": 17,
|
||||
"id": "180318b7-de18-4b05-b84a-ba97c72b9d8e",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from sklearn.feature_extraction.text import CountVectorizer\n",
|
||||
"from sklearn.linear_model import LogisticRegression\n",
|
||||
"from sklearn.metrics import accuracy_score, balanced_accuracy_score"
|
||||
"from sklearn.metrics import accuracy_score"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 83,
|
||||
"execution_count": 20,
|
||||
"id": "25090b7c-f516-4be2-8083-3a7187fe4635",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"vectorizer = CountVectorizer()\n",
|
||||
"\n",
|
||||
"X_train = vectorizer.fit_transform(train_df[\"Text\"])\n",
|
||||
"X_val = vectorizer.transform(validation_df[\"Text\"])\n",
|
||||
"X_test = vectorizer.transform(test_df[\"Text\"])\n",
|
||||
"X_train = vectorizer.fit_transform(train_df[\"text\"])\n",
|
||||
"X_val = vectorizer.transform(val_df[\"text\"])\n",
|
||||
"X_test = vectorizer.transform(test_df[\"text\"])\n",
|
||||
"\n",
|
||||
"y_train, y_val, y_test = train_df[\"Label\"], validation_df[\"Label\"], test_df[\"Label\"]"
|
||||
"y_train, y_val, y_test = train_df[\"label\"], val_df[\"label\"], test_df[\"label\"]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 84,
|
||||
"execution_count": 22,
|
||||
"id": "0247de3a-88f0-4b9c-becd-157baf3acf49",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
@ -414,16 +181,12 @@
|
||||
" # Printing the results\n",
|
||||
" print(f\"Training Accuracy: {accuracy_train*100:.2f}%\")\n",
|
||||
" print(f\"Validation Accuracy: {accuracy_val*100:.2f}%\")\n",
|
||||
" print(f\"Test Accuracy: {accuracy_test*100:.2f}%\")\n",
|
||||
" \n",
|
||||
" print(f\"\\nTraining Balanced Accuracy: {balanced_accuracy_train*100:.2f}%\")\n",
|
||||
" print(f\"Validation Balanced Accuracy: {balanced_accuracy_val*100:.2f}%\")\n",
|
||||
" print(f\"Test Balanced Accuracy: {balanced_accuracy_test*100:.2f}%\")"
|
||||
" print(f\"Test Accuracy: {accuracy_test*100:.2f}%\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 85,
|
||||
"execution_count": 23,
|
||||
"id": "c29c6dfc-f72d-40ab-8cb5-783aad1a15ab",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@ -431,13 +194,9 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Training Accuracy: 50.00%\n",
|
||||
"Validation Accuracy: 50.00%\n",
|
||||
"Test Accuracy: 50.00%\n",
|
||||
"\n",
|
||||
"Training Balanced Accuracy: 50.00%\n",
|
||||
"Validation Balanced Accuracy: 50.00%\n",
|
||||
"Test Balanced Accuracy: 50.00%\n"
|
||||
"Training Accuracy: 50.01%\n",
|
||||
"Validation Accuracy: 50.14%\n",
|
||||
"Test Accuracy: 49.91%\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -453,7 +212,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 86,
|
||||
"execution_count": 24,
|
||||
"id": "088a8a3a-3b74-4d10-a51b-cb662569ae39",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
@ -461,13 +220,9 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Training Accuracy: 99.81%\n",
|
||||
"Validation Accuracy: 95.27%\n",
|
||||
"Test Accuracy: 96.03%\n",
|
||||
"\n",
|
||||
"Training Balanced Accuracy: 99.81%\n",
|
||||
"Validation Balanced Accuracy: 95.27%\n",
|
||||
"Test Balanced Accuracy: 96.03%\n"
|
||||
"Training Accuracy: 99.80%\n",
|
||||
"Validation Accuracy: 88.62%\n",
|
||||
"Test Accuracy: 88.85%\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -476,22 +231,6 @@
|
||||
"model.fit(X_train, y_train)\n",
|
||||
"eval(model, X_train, y_train, X_val, y_val, X_test, y_test)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "34411348-45bc-4b01-bebf-b3602c002ef1",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "5a9bc6b1-c8b9-4d4f-bfe4-c5a4a8b0c756",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
@ -510,7 +249,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.6"
|
||||
"version": "3.11.4"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@ -234,15 +234,6 @@ if __name__ == "__main__":
|
||||
# Instantiate dataloaders
|
||||
###############################
|
||||
|
||||
url = "https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip"
|
||||
zip_path = "sms_spam_collection.zip"
|
||||
extract_to = "sms_spam_collection"
|
||||
new_file_path = Path(extract_to) / "SMSSpamCollection.tsv"
|
||||
|
||||
base_path = Path(".")
|
||||
file_names = ["train.csv", "val.csv", "test.csv"]
|
||||
all_exist = all((base_path / file_name).exists() for file_name in file_names)
|
||||
|
||||
pad_token_id = tokenizer.encode(tokenizer.pad_token)
|
||||
|
||||
train_dataset = IMDBDataset(base_path / "train.csv", max_length=256, tokenizer=tokenizer, pad_token_id=pad_token_id)
|
||||
|
||||
@ -286,15 +286,6 @@ if __name__ == "__main__":
|
||||
# Instantiate dataloaders
|
||||
###############################
|
||||
|
||||
url = "https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip"
|
||||
zip_path = "sms_spam_collection.zip"
|
||||
extract_to = "sms_spam_collection"
|
||||
new_file_path = Path(extract_to) / "SMSSpamCollection.tsv"
|
||||
|
||||
base_path = Path(".")
|
||||
file_names = ["train.csv", "val.csv", "test.csv"]
|
||||
all_exist = all((base_path / file_name).exists() for file_name in file_names)
|
||||
|
||||
tokenizer = tiktoken.get_encoding("gpt2")
|
||||
|
||||
train_dataset = None
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user