reorder section 6.6

This commit is contained in:
rasbt 2024-05-11 08:27:07 -05:00
parent 694a57a472
commit 02ad1bef3a

View File

@ -1348,7 +1348,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 39, "execution_count": 22,
"id": "2aedc120-5ee3-48f6-92f2-ad9304ebcdc7", "id": "2aedc120-5ee3-48f6-92f2-ad9304ebcdc7",
"metadata": { "metadata": {
"id": "2aedc120-5ee3-48f6-92f2-ad9304ebcdc7" "id": "2aedc120-5ee3-48f6-92f2-ad9304ebcdc7"
@ -1373,7 +1373,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 42, "execution_count": 23,
"id": "f645c06a-7df6-451c-ad3f-eafb18224ebc", "id": "f645c06a-7df6-451c-ad3f-eafb18224ebc",
"metadata": { "metadata": {
"colab": { "colab": {
@ -1409,7 +1409,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 43, "execution_count": 24,
"id": "48dc84f1-85cc-4609-9cee-94ff539f00f4", "id": "48dc84f1-85cc-4609-9cee-94ff539f00f4",
"metadata": { "metadata": {
"colab": { "colab": {
@ -1470,7 +1470,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 44, "execution_count": 25,
"id": "49383a8c-41d5-4dab-98f1-238bca0c2ed7", "id": "49383a8c-41d5-4dab-98f1-238bca0c2ed7",
"metadata": { "metadata": {
"colab": { "colab": {
@ -1526,7 +1526,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 45, "execution_count": 26,
"id": "c77faab1-3461-4118-866a-6171f2b89aa0", "id": "c77faab1-3461-4118-866a-6171f2b89aa0",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -1547,12 +1547,12 @@
"id": "7edd71fa-628a-4d00-b81d-6d8bcb2c341d", "id": "7edd71fa-628a-4d00-b81d-6d8bcb2c341d",
"metadata": {}, "metadata": {},
"source": [ "source": [
"- Similar to chapter 5, we convert the outputs (logits) into probability scores via the `softmax` function and then obtain the index position of the largest probability value via the `argmax` function:" "- Similar to chapter 5, we convert the outputs (logits) into probability scores via the `softmax` function and then obtain the index position of the largest probability value via the `argmax` function"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 48, "execution_count": 27,
"id": "b81efa92-9be1-4b9e-8790-ce1fc7b17f01", "id": "b81efa92-9be1-4b9e-8790-ce1fc7b17f01",
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
@ -1572,12 +1572,118 @@
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "d5241f47-a1e4-4bba-8064-5d06cffa7941", "id": "414a6f02-307e-4147-a416-14d115bf8179",
"metadata": {}, "metadata": {},
"source": [ "source": [
"- Note that the softmax function is optional here, as explained in chapter 5, because the largest outputs correspond to the largest probability scores" "- Note that the softmax function is optional here, as explained in chapter 5, because the largest outputs correspond to the largest probability scores"
] ]
}, },
{
"cell_type": "code",
"execution_count": 28,
"id": "f9f9ad66-4969-4501-8239-3ccdb37e71a2",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Class label: 1\n"
]
}
],
"source": [
"logits = outputs[:, -1, :]\n",
"label = torch.argmax(logits)\n",
"print(\"Class label:\", label.item())"
]
},
{
"cell_type": "markdown",
"id": "dcb20d3a-cbba-4ab1-8584-d94e16589505",
"metadata": {},
"source": [
"- We can apply this concept to calculate the so-called classification accuracy, which computes the percentage of correct predictions in a given dataset\n",
"- To calculate the classification accuracy, we can apply the preceding `argmax`-based prediction code to all examples in a dataset and calculate the fraction of correct predictions as follows:"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "3ecf9572-aed0-4a21-9c3b-7f9f2aec5f23",
"metadata": {},
"outputs": [],
"source": [
"def calc_accuracy_loader(data_loader, model, device, num_batches=None):\n",
" model.eval()\n",
" correct_predictions, num_examples = 0, 0\n",
"\n",
" if num_batches is None:\n",
" num_batches = len(data_loader)\n",
" else:\n",
" num_batches = min(num_batches, len(data_loader))\n",
" for i, (input_batch, target_batch) in enumerate(data_loader):\n",
" if i < num_batches:\n",
" input_batch, target_batch = input_batch.to(device), target_batch.to(device)\n",
"\n",
" with torch.no_grad():\n",
" logits = model(input_batch)[:, -1, :] # Logits of last output token\n",
" predicted_labels = torch.argmax(logits, dim=-1)\n",
"\n",
" num_examples += predicted_labels.shape[0]\n",
" correct_predictions += (predicted_labels == target_batch).sum().item()\n",
" else:\n",
" break\n",
" return correct_predictions / num_examples"
]
},
{
"cell_type": "markdown",
"id": "7165fe46-a284-410b-957f-7524877d1a1a",
"metadata": {},
"source": [
"- Let's apply the function to calculate the classification accuracies for the different datasets:"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "390e5255-8427-488c-adef-e1c10ab4fb26",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training accuracy: 46.25%\n",
"Validation accuracy: 45.00%\n",
"Test accuracy: 48.75%\n"
]
}
],
"source": [
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"model.to(device) # no assignment model = model.to(device) necessary for nn.Module classes\n",
"\n",
"torch.manual_seed(123) # For reproducibility due to the shuffling in the training data loader\n",
"\n",
"train_accuracy = calc_accuracy_loader(train_loader, model, device, num_batches=10)\n",
"val_accuracy = calc_accuracy_loader(val_loader, model, device, num_batches=10)\n",
"test_accuracy = calc_accuracy_loader(test_loader, model, device, num_batches=10)\n",
"\n",
"print(f\"Training accuracy: {train_accuracy*100:.2f}%\")\n",
"print(f\"Validation accuracy: {val_accuracy*100:.2f}%\")\n",
"print(f\"Test accuracy: {test_accuracy*100:.2f}%\")"
]
},
{
"cell_type": "markdown",
"id": "30345e2a-afed-4d22-9486-f4010f90a871",
"metadata": {},
"source": [
"- As we can see, the prediction accuracies are not very good, since we haven't finetuned the model, yet"
]
},
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "4f4a9d15-8fc7-48a2-8734-d92a2f265328", "id": "4f4a9d15-8fc7-48a2-8734-d92a2f265328",
@ -1585,19 +1691,14 @@
"source": [ "source": [
"- Before we can start finetuning (/training), we first have to define the loss function we want to optimize during training\n", "- Before we can start finetuning (/training), we first have to define the loss function we want to optimize during training\n",
"- The goal is to maximize the spam classification accuracy of the model; however, classification accuracy is not a differentiable function\n", "- The goal is to maximize the spam classification accuracy of the model; however, classification accuracy is not a differentiable function\n",
"- Hence, instead, we minimize the cross entropy loss as a proxy for maximizing the classification accuracy (you can learn more about this topic in lecture 8 of my freely available [Introduction to Deep Learning](https://sebastianraschka.com/blog/2021/dl-course.html#l08-multinomial-logistic-regression--softmax-regression) class.\n", "- Hence, instead, we minimize the cross entropy loss as a proxy for maximizing the classification accuracy (you can learn more about this topic in lecture 8 of my freely available [Introduction to Deep Learning](https://sebastianraschka.com/blog/2021/dl-course.html#l08-multinomial-logistic-regression--softmax-regression) class)\n",
"\n", "\n",
"- Note that in chapter 5, we calculated the cross entropy loss for the next predicted token over the 50,257 token IDs in the vocabulary\n", "- The `calc_loss_batch` function is the same here as in chapter 5, except that we are only interested in optimizing the last token `model(input_batch)[:, -1, :]` instead of all tokens `model(input_batch)`"
"- Here, we calculate the cross entropy in a similar fashion; the only difference is that instead of 50,257 token IDs, we now have only two choices: \"spam\" (label 1) or \"not spam\" (label 0).\n",
"- In other words, the loss calculation training code is practically identical to the one in chapter 5, but we now only have two labels instead of 50,257 labels (token IDs).\n",
"\n",
"\n",
"- Consequently, the `calc_loss_batch` function is the same here as in chapter 5, except that we are only interested in optimizing the last token `model(input_batch)[:, -1, :]` instead of all tokens `model(input_batch)`:"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 26, "execution_count": 31,
"id": "2f1e9547-806c-41a9-8aba-3b2822baabe4", "id": "2f1e9547-806c-41a9-8aba-3b2822baabe4",
"metadata": { "metadata": {
"id": "2f1e9547-806c-41a9-8aba-3b2822baabe4" "id": "2f1e9547-806c-41a9-8aba-3b2822baabe4"
@ -1616,12 +1717,12 @@
"id": "a013aab9-f854-4866-ad55-5b8350adb50a", "id": "a013aab9-f854-4866-ad55-5b8350adb50a",
"metadata": {}, "metadata": {},
"source": [ "source": [
"The `calc_loss_loader` is exactly the same as in chapter 5:" "The `calc_loss_loader` is exactly the same as in chapter 5"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 27, "execution_count": 32,
"id": "b7b83e10-5720-45e7-ac5e-369417ca846b", "id": "b7b83e10-5720-45e7-ac5e-369417ca846b",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -1651,14 +1752,12 @@
"id": "56826ecd-6e74-40e6-b772-d3541e585067", "id": "56826ecd-6e74-40e6-b772-d3541e585067",
"metadata": {}, "metadata": {},
"source": [ "source": [
"- Using the `calc_closs_loader`, we compute the initial training, validation, and test set losses before we start training\n", "- Using the `calc_closs_loader`, we compute the initial training, validation, and test set losses before we start training"
"- Here, we use `torch.no_grad()` so that no gradients are computed during the forward pass, which reduces memory consumption and speeds up computations since we are not training the model yet\n",
"- Via the `device` setting, the model automatically runs on a GPU if a GPU with Nvidia CUDA support is available and otherwise runs on a CPU"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 28, "execution_count": 33,
"id": "f6f00e53-5beb-4e64-b147-f26fd481c6ff", "id": "f6f00e53-5beb-4e64-b147-f26fd481c6ff",
"metadata": { "metadata": {
"colab": { "colab": {
@ -1672,18 +1771,13 @@
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"Training loss: 3.095\n", "Training loss: 2.453\n",
"Validation loss: 2.583\n", "Validation loss: 2.583\n",
"Test loss: 2.322\n" "Test loss: 2.322\n"
] ]
} }
], ],
"source": [ "source": [
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"model.to(device) # no assignment model = model.to(device) necessary for nn.Module classes\n",
"\n",
"torch.manual_seed(123) # For reproducibility due to the shuffling in the training data loader\n",
"\n",
"with torch.no_grad(): # Disable gradient tracking for efficiency because we are not training, yet\n", "with torch.no_grad(): # Disable gradient tracking for efficiency because we are not training, yet\n",
" train_loss = calc_loss_loader(train_loader, model, device, num_batches=5)\n", " train_loss = calc_loss_loader(train_loader, model, device, num_batches=5)\n",
" val_loss = calc_loss_loader(val_loader, model, device, num_batches=5)\n", " val_loss = calc_loss_loader(val_loader, model, device, num_batches=5)\n",
@ -1694,93 +1788,12 @@
"print(f\"Test loss: {test_loss:.3f}\")" "print(f\"Test loss: {test_loss:.3f}\")"
] ]
}, },
{
"cell_type": "markdown",
"id": "b109556e-ddae-49fd-ad08-e6fa1032ea7a",
"metadata": {},
"source": [
"- Similar to the `calc_loss_loader` function above, we can define a `calc_accuracy_loader` function that calculates the classification accuracy by checking how many predicted class (spam and ham) labels match the given labels in the dataset\n",
"- Note that the classification accuracy is a mathematically non-differentiable function, and we only use it for evaluation; hence, we can disable the gradient calculation permanently to save resources here\n",
"- We can disable the gradient tracking either using the `with torch.no_grad():` inside the function or by using the `@torch.no_grad()` function decorator"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "64ce5b12-84cd-488c-8ea7-4cef5b2d947e",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "64ce5b12-84cd-488c-8ea7-4cef5b2d947e",
"outputId": "239581b4-fd0f-4adf-e67b-364e0f0f96b7"
},
"outputs": [],
"source": [
"@torch.no_grad() # Disable gradient tracking for efficiency\n",
"def calc_accuracy_loader(data_loader, model, device, num_batches=None):\n",
" model.eval()\n",
" correct_predictions, num_examples = 0, 0\n",
"\n",
" if num_batches is None:\n",
" num_batches = len(data_loader)\n",
" else:\n",
" num_batches = min(num_batches, len(data_loader))\n",
" for i, (input_batch, target_batch) in enumerate(data_loader):\n",
" if i < num_batches:\n",
" input_batch, target_batch = input_batch.to(device), target_batch.to(device)\n",
" logits = model(input_batch)[:, -1, :] # Logits of last output token\n",
" predicted_labels = torch.argmax(logits, dim=-1)\n",
"\n",
" num_examples += predicted_labels.shape[0]\n",
" correct_predictions += (predicted_labels == target_batch).sum().item()\n",
" else:\n",
" break\n",
" return correct_predictions / num_examples"
]
},
{
"cell_type": "markdown",
"id": "90521a9a-639c-4c7f-a5c0-aca8fa5d4c1b",
"metadata": {},
"source": [
"- Let's check the initial classification accuracy before we start training the model"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "2160418f-988b-40f3-bce8-e431021e97dc",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training accuracy: 46.25%\n",
"Validation accuracy: 45.00%\n",
"Test accuracy: 48.75%\n"
]
}
],
"source": [
"torch.manual_seed(123)\n",
"train_accuracy = calc_accuracy_loader(train_loader, model, device, num_batches=10)\n",
"val_accuracy = calc_accuracy_loader(val_loader, model, device, num_batches=10)\n",
"test_accuracy = calc_accuracy_loader(test_loader, model, device, num_batches=10)\n",
"\n",
"print(f\"Training accuracy: {train_accuracy*100:.2f}%\")\n",
"print(f\"Validation accuracy: {val_accuracy*100:.2f}%\")\n",
"print(f\"Test accuracy: {test_accuracy*100:.2f}%\")"
]
},
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "e04b980b-e583-4f62-84a0-4edafaf99d5d", "id": "e04b980b-e583-4f62-84a0-4edafaf99d5d",
"metadata": {}, "metadata": {},
"source": [ "source": [
"- As we can see, the model only gets roughly half (50%) of the predictions correctly\n", "- In the next section, we train the model to improve the loss values and consequently the classification accuracy"
"- In the next section, we train the model to improve the classification accuracy"
] ]
}, },
{ {