From c88e8edf72c9c30b33d1c13b06f4cf4454850b2e Mon Sep 17 00:00:00 2001 From: rasbt Date: Tue, 26 Mar 2024 08:38:27 -0500 Subject: [PATCH] use probas in argmax --- .../03_main-chapter-code/code-part1.ipynb | 33 ++++++++++++------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/appendix-A/03_main-chapter-code/code-part1.ipynb b/appendix-A/03_main-chapter-code/code-part1.ipynb index 0589d7c..d351961 100644 --- a/appendix-A/03_main-chapter-code/code-part1.ipynb +++ b/appendix-A/03_main-chapter-code/code-part1.ipynb @@ -37,7 +37,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "2.0.1\n" + "2.2.1\n" ] } ], @@ -591,13 +591,13 @@ "output_type": "stream", "text": [ "Parameter containing:\n", - "tensor([[-0.0064, 0.0004, -0.0903, ..., -0.1316, 0.0910, 0.0363],\n", - " [ 0.1354, 0.1124, -0.0476, ..., 0.0578, 0.1014, 0.0008],\n", - " [ 0.0975, -0.0478, 0.0298, ..., 0.0416, 0.0849, 0.1314],\n", + "tensor([[ 0.0956, 0.1280, -0.0696, ..., 0.0961, 0.0631, 0.1349],\n", + " [ 0.0983, 0.0580, -0.0574, ..., 0.0981, 0.0370, 0.0516],\n", + " [-0.0429, -0.1411, -0.1399, ..., 0.0767, 0.0019, 0.1400],\n", " ...,\n", - " [ 0.0118, 0.0240, 0.0420, ..., -0.1305, -0.0517, -0.0826],\n", - " [-0.0323, 0.1073, 0.0215, ..., -0.1264, -0.1100, 0.1232],\n", - " [ 0.0861, 0.0403, -0.0545, ..., 0.1352, 0.0817, -0.0938]],\n", + " [-0.0777, -0.0726, 0.1273, ..., -0.0613, 0.0491, -0.1381],\n", + " [-0.0830, -0.0969, -0.0473, ..., 0.0762, 0.1318, -0.1174],\n", + " [ 0.0468, -0.0213, 0.0387, ..., 0.0639, 0.0927, -0.0668]],\n", " requires_grad=True)\n" ] } @@ -881,10 +881,21 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 37, "id": "4db4d7f4-82da-44a4-b94e-ee04665d9c3c", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Batch 1: tensor([[-1.2000, 3.1000],\n", + " [-0.5000, 2.6000]]) tensor([0, 0])\n", + "Batch 2: tensor([[ 2.3000, -1.1000],\n", + " [-0.9000, 2.9000]]) tensor([1, 0])\n" + ] + } + ], "source": [ "for idx, (x, y) in enumerate(train_loader):\n", " print(f\"Batch {idx+1}:\", x, y)" @@ -1000,7 +1011,7 @@ "probas = torch.softmax(outputs, dim=1)\n", "print(probas)\n", "\n", - "predictions = torch.argmax(outputs, dim=1)\n", + "predictions = torch.argmax(probas, dim=1)\n", "print(predictions)" ] }, @@ -1254,7 +1265,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.6" + "version": "3.11.4" } }, "nbformat": 4,