mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-10-27 07:49:25 +00:00
use probas in argmax
This commit is contained in:
parent
9cc9c4244e
commit
c88e8edf72
@ -37,7 +37,7 @@
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"2.0.1\n"
|
||||
"2.2.1\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
@ -591,13 +591,13 @@
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Parameter containing:\n",
|
||||
"tensor([[-0.0064, 0.0004, -0.0903, ..., -0.1316, 0.0910, 0.0363],\n",
|
||||
" [ 0.1354, 0.1124, -0.0476, ..., 0.0578, 0.1014, 0.0008],\n",
|
||||
" [ 0.0975, -0.0478, 0.0298, ..., 0.0416, 0.0849, 0.1314],\n",
|
||||
"tensor([[ 0.0956, 0.1280, -0.0696, ..., 0.0961, 0.0631, 0.1349],\n",
|
||||
" [ 0.0983, 0.0580, -0.0574, ..., 0.0981, 0.0370, 0.0516],\n",
|
||||
" [-0.0429, -0.1411, -0.1399, ..., 0.0767, 0.0019, 0.1400],\n",
|
||||
" ...,\n",
|
||||
" [ 0.0118, 0.0240, 0.0420, ..., -0.1305, -0.0517, -0.0826],\n",
|
||||
" [-0.0323, 0.1073, 0.0215, ..., -0.1264, -0.1100, 0.1232],\n",
|
||||
" [ 0.0861, 0.0403, -0.0545, ..., 0.1352, 0.0817, -0.0938]],\n",
|
||||
" [-0.0777, -0.0726, 0.1273, ..., -0.0613, 0.0491, -0.1381],\n",
|
||||
" [-0.0830, -0.0969, -0.0473, ..., 0.0762, 0.1318, -0.1174],\n",
|
||||
" [ 0.0468, -0.0213, 0.0387, ..., 0.0639, 0.0927, -0.0668]],\n",
|
||||
" requires_grad=True)\n"
|
||||
]
|
||||
}
|
||||
@ -881,10 +881,21 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 37,
|
||||
"id": "4db4d7f4-82da-44a4-b94e-ee04665d9c3c",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Batch 1: tensor([[-1.2000, 3.1000],\n",
|
||||
" [-0.5000, 2.6000]]) tensor([0, 0])\n",
|
||||
"Batch 2: tensor([[ 2.3000, -1.1000],\n",
|
||||
" [-0.9000, 2.9000]]) tensor([1, 0])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"for idx, (x, y) in enumerate(train_loader):\n",
|
||||
" print(f\"Batch {idx+1}:\", x, y)"
|
||||
@ -1000,7 +1011,7 @@
|
||||
"probas = torch.softmax(outputs, dim=1)\n",
|
||||
"print(probas)\n",
|
||||
"\n",
|
||||
"predictions = torch.argmax(outputs, dim=1)\n",
|
||||
"predictions = torch.argmax(probas, dim=1)\n",
|
||||
"print(predictions)"
|
||||
]
|
||||
},
|
||||
@ -1254,7 +1265,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.6"
|
||||
"version": "3.11.4"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user