use probas in argmax

This commit is contained in:
rasbt 2024-03-26 08:38:27 -05:00
parent 9cc9c4244e
commit c88e8edf72

View File

@ -37,7 +37,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"2.0.1\n"
"2.2.1\n"
]
}
],
@ -591,13 +591,13 @@
"output_type": "stream",
"text": [
"Parameter containing:\n",
"tensor([[-0.0064, 0.0004, -0.0903, ..., -0.1316, 0.0910, 0.0363],\n",
" [ 0.1354, 0.1124, -0.0476, ..., 0.0578, 0.1014, 0.0008],\n",
" [ 0.0975, -0.0478, 0.0298, ..., 0.0416, 0.0849, 0.1314],\n",
"tensor([[ 0.0956, 0.1280, -0.0696, ..., 0.0961, 0.0631, 0.1349],\n",
" [ 0.0983, 0.0580, -0.0574, ..., 0.0981, 0.0370, 0.0516],\n",
" [-0.0429, -0.1411, -0.1399, ..., 0.0767, 0.0019, 0.1400],\n",
" ...,\n",
" [ 0.0118, 0.0240, 0.0420, ..., -0.1305, -0.0517, -0.0826],\n",
" [-0.0323, 0.1073, 0.0215, ..., -0.1264, -0.1100, 0.1232],\n",
" [ 0.0861, 0.0403, -0.0545, ..., 0.1352, 0.0817, -0.0938]],\n",
" [-0.0777, -0.0726, 0.1273, ..., -0.0613, 0.0491, -0.1381],\n",
" [-0.0830, -0.0969, -0.0473, ..., 0.0762, 0.1318, -0.1174],\n",
" [ 0.0468, -0.0213, 0.0387, ..., 0.0639, 0.0927, -0.0668]],\n",
" requires_grad=True)\n"
]
}
@ -881,10 +881,21 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 37,
"id": "4db4d7f4-82da-44a4-b94e-ee04665d9c3c",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Batch 1: tensor([[-1.2000, 3.1000],\n",
" [-0.5000, 2.6000]]) tensor([0, 0])\n",
"Batch 2: tensor([[ 2.3000, -1.1000],\n",
" [-0.9000, 2.9000]]) tensor([1, 0])\n"
]
}
],
"source": [
"for idx, (x, y) in enumerate(train_loader):\n",
" print(f\"Batch {idx+1}:\", x, y)"
@ -1000,7 +1011,7 @@
"probas = torch.softmax(outputs, dim=1)\n",
"print(probas)\n",
"\n",
"predictions = torch.argmax(outputs, dim=1)\n",
"predictions = torch.argmax(probas, dim=1)\n",
"print(predictions)"
]
},
@ -1254,7 +1265,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
"version": "3.11.4"
}
},
"nbformat": 4,