Mirror of https://github.com/rasbt/LLMs-from-scratch.git (synced 2025-07-21 07:52:03 +00:00)

Commit 02ad1bef3a: "reorder section 6.6" (parent: 694a57a472)
@@ -1348,7 +1348,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 39,
+"execution_count": 22,
 "id": "2aedc120-5ee3-48f6-92f2-ad9304ebcdc7",
 "metadata": {
 "id": "2aedc120-5ee3-48f6-92f2-ad9304ebcdc7"
@@ -1373,7 +1373,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 42,
+"execution_count": 23,
 "id": "f645c06a-7df6-451c-ad3f-eafb18224ebc",
 "metadata": {
 "colab": {
@@ -1409,7 +1409,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 43,
+"execution_count": 24,
 "id": "48dc84f1-85cc-4609-9cee-94ff539f00f4",
 "metadata": {
 "colab": {
@@ -1470,7 +1470,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 44,
+"execution_count": 25,
 "id": "49383a8c-41d5-4dab-98f1-238bca0c2ed7",
 "metadata": {
 "colab": {
@@ -1526,7 +1526,7 @@
 },
 {
 "cell_type": "code",
-"execution_count": 45,
+"execution_count": 26,
 "id": "c77faab1-3461-4118-866a-6171f2b89aa0",
 "metadata": {},
 "outputs": [
@@ -1547,12 +1547,12 @@
 "id": "7edd71fa-628a-4d00-b81d-6d8bcb2c341d",
 "metadata": {},
 "source": [
-"- Similar to chapter 5, we convert the outputs (logits) into probability scores via the `softmax` function and then obtain the index position of the largest probability value via the `argmax` function:"
+"- Similar to chapter 5, we convert the outputs (logits) into probability scores via the `softmax` function and then obtain the index position of the largest probability value via the `argmax` function"
 ]
 },
 {
 "cell_type": "code",
-"execution_count": 48,
+"execution_count": 27,
 "id": "b81efa92-9be1-4b9e-8790-ce1fc7b17f01",
 "metadata": {},
 "outputs": [
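The cell renumbered in this hunk converts the model outputs into a class label, as the markdown above describes. A minimal, self-contained sketch of that softmax-then-argmax pattern (the `outputs` tensor here is a made-up stand-in for the model's actual output):

```python
import torch

# Hypothetical last-token logits: batch of 1, 1 token, 2 classes
outputs = torch.tensor([[[-1.0, 2.0]]])  # shape: (batch, num_tokens, num_classes)

logits = outputs[:, -1, :]              # logits of the last output token
probas = torch.softmax(logits, dim=-1)  # convert logits to probability scores
label = torch.argmax(probas)            # index of the largest probability
print("Class label:", label.item())    # prints: Class label: 1
```

Because softmax is monotonic, `torch.argmax(logits)` yields the same label without the intermediate step, which is why the next markdown cell calls the softmax optional.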
@@ -1572,12 +1572,118 @@
 },
 {
 "cell_type": "markdown",
-"id": "d5241f47-a1e4-4bba-8064-5d06cffa7941",
+"id": "414a6f02-307e-4147-a416-14d115bf8179",
 "metadata": {},
 "source": [
 "- Note that the softmax function is optional here, as explained in chapter 5, because the largest outputs correspond to the largest probability scores"
 ]
 },
+{
+"cell_type": "code",
+"execution_count": 28,
+"id": "f9f9ad66-4969-4501-8239-3ccdb37e71a2",
+"metadata": {},
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"Class label: 1\n"
+]
+}
+],
+"source": [
+"logits = outputs[:, -1, :]\n",
+"label = torch.argmax(logits)\n",
+"print(\"Class label:\", label.item())"
+]
+},
+{
+"cell_type": "markdown",
+"id": "dcb20d3a-cbba-4ab1-8584-d94e16589505",
+"metadata": {},
+"source": [
+"- We can apply this concept to calculate the so-called classification accuracy, which computes the percentage of correct predictions in a given dataset\n",
+"- To calculate the classification accuracy, we can apply the preceding `argmax`-based prediction code to all examples in a dataset and calculate the fraction of correct predictions as follows:"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 29,
+"id": "3ecf9572-aed0-4a21-9c3b-7f9f2aec5f23",
+"metadata": {},
+"outputs": [],
+"source": [
+"def calc_accuracy_loader(data_loader, model, device, num_batches=None):\n",
+"    model.eval()\n",
+"    correct_predictions, num_examples = 0, 0\n",
+"\n",
+"    if num_batches is None:\n",
+"        num_batches = len(data_loader)\n",
+"    else:\n",
+"        num_batches = min(num_batches, len(data_loader))\n",
+"    for i, (input_batch, target_batch) in enumerate(data_loader):\n",
+"        if i < num_batches:\n",
+"            input_batch, target_batch = input_batch.to(device), target_batch.to(device)\n",
+"\n",
+"            with torch.no_grad():\n",
+"                logits = model(input_batch)[:, -1, :] # Logits of last output token\n",
+"            predicted_labels = torch.argmax(logits, dim=-1)\n",
+"\n",
+"            num_examples += predicted_labels.shape[0]\n",
+"            correct_predictions += (predicted_labels == target_batch).sum().item()\n",
+"        else:\n",
+"            break\n",
+"    return correct_predictions / num_examples"
+]
+},
+{
+"cell_type": "markdown",
+"id": "7165fe46-a284-410b-957f-7524877d1a1a",
+"metadata": {},
+"source": [
+"- Let's apply the function to calculate the classification accuracies for the different datasets:"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 30,
+"id": "390e5255-8427-488c-adef-e1c10ab4fb26",
+"metadata": {},
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"Training accuracy: 46.25%\n",
+"Validation accuracy: 45.00%\n",
+"Test accuracy: 48.75%\n"
+]
+}
+],
+"source": [
+"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
+"model.to(device) # no assignment model = model.to(device) necessary for nn.Module classes\n",
+"\n",
+"torch.manual_seed(123) # For reproducibility due to the shuffling in the training data loader\n",
+"\n",
+"train_accuracy = calc_accuracy_loader(train_loader, model, device, num_batches=10)\n",
+"val_accuracy = calc_accuracy_loader(val_loader, model, device, num_batches=10)\n",
+"test_accuracy = calc_accuracy_loader(test_loader, model, device, num_batches=10)\n",
+"\n",
+"print(f\"Training accuracy: {train_accuracy*100:.2f}%\")\n",
+"print(f\"Validation accuracy: {val_accuracy*100:.2f}%\")\n",
+"print(f\"Test accuracy: {test_accuracy*100:.2f}%\")"
+]
+},
+{
+"cell_type": "markdown",
+"id": "30345e2a-afed-4d22-9486-f4010f90a871",
+"metadata": {},
+"source": [
+"- As we can see, the prediction accuracies are not very good, since we haven't finetuned the model, yet"
+]
+},
 {
 "cell_type": "markdown",
 "id": "4f4a9d15-8fc7-48a2-8734-d92a2f265328",
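The `calc_accuracy_loader` added in this hunk can be exercised with a self-contained toy setup; the dummy model and data below are invented purely for illustration and assume the function from the hunk above is already defined:

```python
import torch
from torch.utils.data import DataLoader

class AlwaysClassOne(torch.nn.Module):
    # Stand-in for the GPT model: returns fixed 2-class logits for every token
    def forward(self, x):
        batch_size, seq_len = x.shape
        logits = torch.zeros(batch_size, seq_len, 2)
        logits[:, :, 1] = 1.0  # class 1 always wins the argmax
        return logits

inputs = torch.zeros(8, 4, dtype=torch.long)  # 8 examples, 4 tokens each
targets = torch.ones(8, dtype=torch.long)     # every true label is 1
loader = DataLoader(list(zip(inputs, targets)), batch_size=4)

accuracy = calc_accuracy_loader(loader, AlwaysClassOne(), torch.device("cpu"))
print(accuracy)  # expected: 1.0, since the dummy model always predicts class 1
```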
@@ -1585,19 +1691,14 @@
 "source": [
 "- Before we can start finetuning (/training), we first have to define the loss function we want to optimize during training\n",
 "- The goal is to maximize the spam classification accuracy of the model; however, classification accuracy is not a differentiable function\n",
-"- Hence, instead, we minimize the cross entropy loss as a proxy for maximizing the classification accuracy (you can learn more about this topic in lecture 8 of my freely available [Introduction to Deep Learning](https://sebastianraschka.com/blog/2021/dl-course.html#l08-multinomial-logistic-regression--softmax-regression) class.\n",
+"- Hence, instead, we minimize the cross entropy loss as a proxy for maximizing the classification accuracy (you can learn more about this topic in lecture 8 of my freely available [Introduction to Deep Learning](https://sebastianraschka.com/blog/2021/dl-course.html#l08-multinomial-logistic-regression--softmax-regression) class)\n",
 "\n",
-"- Note that in chapter 5, we calculated the cross entropy loss for the next predicted token over the 50,257 token IDs in the vocabulary\n",
-"- Here, we calculate the cross entropy in a similar fashion; the only difference is that instead of 50,257 token IDs, we now have only two choices: \"spam\" (label 1) or \"not spam\" (label 0).\n",
-"- In other words, the loss calculation training code is practically identical to the one in chapter 5, but we now only have two labels instead of 50,257 labels (token IDs).\n",
-"\n",
-"\n",
-"- Consequently, the `calc_loss_batch` function is the same here as in chapter 5, except that we are only interested in optimizing the last token `model(input_batch)[:, -1, :]` instead of all tokens `model(input_batch)`:"
+"- The `calc_loss_batch` function is the same here as in chapter 5, except that we are only interested in optimizing the last token `model(input_batch)[:, -1, :]` instead of all tokens `model(input_batch)`"
 ]
 },
 {
 "cell_type": "code",
-"execution_count": 26,
+"execution_count": 31,
 "id": "2f1e9547-806c-41a9-8aba-3b2822baabe4",
 "metadata": {
 "id": "2f1e9547-806c-41a9-8aba-3b2822baabe4"
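The `calc_loss_batch` cell renumbered at the end of this hunk is not shown in the diff; based on the markdown description (cross entropy on the last token's logits only, class labels as targets), a sketch consistent with it would be:

```python
import torch

def calc_loss_batch(input_batch, target_batch, model, device):
    input_batch = input_batch.to(device)
    target_batch = target_batch.to(device)
    logits = model(input_batch)[:, -1, :]  # Logits of last output token only
    loss = torch.nn.functional.cross_entropy(logits, target_batch)
    return loss
```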
@@ -1616,12 +1717,12 @@
 "id": "a013aab9-f854-4866-ad55-5b8350adb50a",
 "metadata": {},
 "source": [
-"The `calc_loss_loader` is exactly the same as in chapter 5:"
+"The `calc_loss_loader` is exactly the same as in chapter 5"
 ]
 },
 {
 "cell_type": "code",
-"execution_count": 27,
+"execution_count": 32,
 "id": "b7b83e10-5720-45e7-ac5e-369417ca846b",
 "metadata": {},
 "outputs": [],
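The markdown notes that `calc_loss_loader` is unchanged from chapter 5; for reference, a sketch of that loader-level average, assuming the `calc_loss_batch` helper above:

```python
def calc_loss_loader(data_loader, model, device, num_batches=None):
    total_loss = 0.0
    if len(data_loader) == 0:
        return float("nan")
    elif num_batches is None:
        num_batches = len(data_loader)
    else:
        # Cap the number of batches to what the loader actually provides
        num_batches = min(num_batches, len(data_loader))
    for i, (input_batch, target_batch) in enumerate(data_loader):
        if i < num_batches:
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            total_loss += loss.item()
        else:
            break
    return total_loss / num_batches  # average loss over the evaluated batches
```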
@@ -1651,14 +1752,12 @@
 "id": "56826ecd-6e74-40e6-b772-d3541e585067",
 "metadata": {},
 "source": [
-"- Using the `calc_loss_loader`, we compute the initial training, validation, and test set losses before we start training\n",
-"- Here, we use `torch.no_grad()` so that no gradients are computed during the forward pass, which reduces memory consumption and speeds up computations since we are not training the model yet\n",
-"- Via the `device` setting, the model automatically runs on a GPU if a GPU with Nvidia CUDA support is available and otherwise runs on a CPU"
+"- Using the `calc_loss_loader`, we compute the initial training, validation, and test set losses before we start training"
 ]
 },
 {
 "cell_type": "code",
-"execution_count": 28,
+"execution_count": 33,
 "id": "f6f00e53-5beb-4e64-b147-f26fd481c6ff",
 "metadata": {
 "colab": {
@@ -1672,18 +1771,13 @@
 "name": "stdout",
 "output_type": "stream",
 "text": [
-"Training loss: 3.095\n",
+"Training loss: 2.453\n",
 "Validation loss: 2.583\n",
 "Test loss: 2.322\n"
 ]
 }
 ],
 "source": [
-"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
-"model.to(device) # no assignment model = model.to(device) necessary for nn.Module classes\n",
-"\n",
-"torch.manual_seed(123) # For reproducibility due to the shuffling in the training data loader\n",
-"\n",
 "with torch.no_grad(): # Disable gradient tracking for efficiency because we are not training, yet\n",
 "    train_loss = calc_loss_loader(train_loader, model, device, num_batches=5)\n",
 "    val_loss = calc_loss_loader(val_loader, model, device, num_batches=5)\n",
@@ -1694,93 +1788,12 @@
 "print(f\"Test loss: {test_loss:.3f}\")"
 ]
 },
-{
-"cell_type": "markdown",
-"id": "b109556e-ddae-49fd-ad08-e6fa1032ea7a",
-"metadata": {},
-"source": [
-"- Similar to the `calc_loss_loader` function above, we can define a `calc_accuracy_loader` function that calculates the classification accuracy by checking how many predicted class (spam and ham) labels match the given labels in the dataset\n",
-"- Note that the classification accuracy is a mathematically non-differentiable function, and we only use it for evaluation; hence, we can disable the gradient calculation permanently to save resources here\n",
-"- We can disable the gradient tracking either using the `with torch.no_grad():` inside the function or by using the `@torch.no_grad()` function decorator"
-]
-},
-{
-"cell_type": "code",
-"execution_count": 29,
-"id": "64ce5b12-84cd-488c-8ea7-4cef5b2d947e",
-"metadata": {
-"colab": {
-"base_uri": "https://localhost:8080/"
-},
-"id": "64ce5b12-84cd-488c-8ea7-4cef5b2d947e",
-"outputId": "239581b4-fd0f-4adf-e67b-364e0f0f96b7"
-},
-"outputs": [],
-"source": [
-"@torch.no_grad() # Disable gradient tracking for efficiency\n",
-"def calc_accuracy_loader(data_loader, model, device, num_batches=None):\n",
-"    model.eval()\n",
-"    correct_predictions, num_examples = 0, 0\n",
-"\n",
-"    if num_batches is None:\n",
-"        num_batches = len(data_loader)\n",
-"    else:\n",
-"        num_batches = min(num_batches, len(data_loader))\n",
-"    for i, (input_batch, target_batch) in enumerate(data_loader):\n",
-"        if i < num_batches:\n",
-"            input_batch, target_batch = input_batch.to(device), target_batch.to(device)\n",
-"            logits = model(input_batch)[:, -1, :] # Logits of last output token\n",
-"            predicted_labels = torch.argmax(logits, dim=-1)\n",
-"\n",
-"            num_examples += predicted_labels.shape[0]\n",
-"            correct_predictions += (predicted_labels == target_batch).sum().item()\n",
-"        else:\n",
-"            break\n",
-"    return correct_predictions / num_examples"
-]
-},
-{
-"cell_type": "markdown",
-"id": "90521a9a-639c-4c7f-a5c0-aca8fa5d4c1b",
-"metadata": {},
-"source": [
-"- Let's check the initial classification accuracy before we start training the model"
-]
-},
-{
-"cell_type": "code",
-"execution_count": 30,
-"id": "2160418f-988b-40f3-bce8-e431021e97dc",
-"metadata": {},
-"outputs": [
-{
-"name": "stdout",
-"output_type": "stream",
-"text": [
-"Training accuracy: 46.25%\n",
-"Validation accuracy: 45.00%\n",
-"Test accuracy: 48.75%\n"
-]
-}
-],
-"source": [
-"torch.manual_seed(123)\n",
-"train_accuracy = calc_accuracy_loader(train_loader, model, device, num_batches=10)\n",
-"val_accuracy = calc_accuracy_loader(val_loader, model, device, num_batches=10)\n",
-"test_accuracy = calc_accuracy_loader(test_loader, model, device, num_batches=10)\n",
-"\n",
-"print(f\"Training accuracy: {train_accuracy*100:.2f}%\")\n",
-"print(f\"Validation accuracy: {val_accuracy*100:.2f}%\")\n",
-"print(f\"Test accuracy: {test_accuracy*100:.2f}%\")"
-]
-},
 {
 "cell_type": "markdown",
 "id": "e04b980b-e583-4f62-84a0-4edafaf99d5d",
 "metadata": {},
 "source": [
-"- As we can see, the model only gets roughly half (50%) of the predictions correctly\n",
-"- In the next section, we train the model to improve the classification accuracy"
+"- In the next section, we train the model to improve the loss values and consequently the classification accuracy"
 ]
 },
 {
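The markdown cell removed in this hunk contrasts two ways of disabling gradient tracking. Both forms below are standard PyTorch and behave the same for evaluation-only functions; the function names are made up for illustration:

```python
import torch

@torch.no_grad()  # decorator form: gradients disabled for the whole call
def evaluate_decorated(model, x):
    return model(x)

def evaluate_inline(model, x):
    with torch.no_grad():  # context-manager form: gradients disabled inside the block
        return model(x)
```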
|