mirror of
				https://github.com/rasbt/LLMs-from-scratch.git
				synced 2025-11-04 03:40:21 +00:00 
			
		
		
		
	use probas in argmax
This commit is contained in:
		
							parent
							
								
									9cc9c4244e
								
							
						
					
					
						commit
						c88e8edf72
					
				@ -37,7 +37,7 @@
 | 
			
		||||
     "name": "stdout",
 | 
			
		||||
     "output_type": "stream",
 | 
			
		||||
     "text": [
 | 
			
		||||
      "2.0.1\n"
 | 
			
		||||
      "2.2.1\n"
 | 
			
		||||
     ]
 | 
			
		||||
    }
 | 
			
		||||
   ],
 | 
			
		||||
@ -591,13 +591,13 @@
 | 
			
		||||
     "output_type": "stream",
 | 
			
		||||
     "text": [
 | 
			
		||||
      "Parameter containing:\n",
 | 
			
		||||
      "tensor([[-0.0064,  0.0004, -0.0903,  ..., -0.1316,  0.0910,  0.0363],\n",
 | 
			
		||||
      "        [ 0.1354,  0.1124, -0.0476,  ...,  0.0578,  0.1014,  0.0008],\n",
 | 
			
		||||
      "        [ 0.0975, -0.0478,  0.0298,  ...,  0.0416,  0.0849,  0.1314],\n",
 | 
			
		||||
      "tensor([[ 0.0956,  0.1280, -0.0696,  ...,  0.0961,  0.0631,  0.1349],\n",
 | 
			
		||||
      "        [ 0.0983,  0.0580, -0.0574,  ...,  0.0981,  0.0370,  0.0516],\n",
 | 
			
		||||
      "        [-0.0429, -0.1411, -0.1399,  ...,  0.0767,  0.0019,  0.1400],\n",
 | 
			
		||||
      "        ...,\n",
 | 
			
		||||
      "        [ 0.0118,  0.0240,  0.0420,  ..., -0.1305, -0.0517, -0.0826],\n",
 | 
			
		||||
      "        [-0.0323,  0.1073,  0.0215,  ..., -0.1264, -0.1100,  0.1232],\n",
 | 
			
		||||
      "        [ 0.0861,  0.0403, -0.0545,  ...,  0.1352,  0.0817, -0.0938]],\n",
 | 
			
		||||
      "        [-0.0777, -0.0726,  0.1273,  ..., -0.0613,  0.0491, -0.1381],\n",
 | 
			
		||||
      "        [-0.0830, -0.0969, -0.0473,  ...,  0.0762,  0.1318, -0.1174],\n",
 | 
			
		||||
      "        [ 0.0468, -0.0213,  0.0387,  ...,  0.0639,  0.0927, -0.0668]],\n",
 | 
			
		||||
      "       requires_grad=True)\n"
 | 
			
		||||
     ]
 | 
			
		||||
    }
 | 
			
		||||
@ -881,10 +881,21 @@
 | 
			
		||||
  },
 | 
			
		||||
  {
 | 
			
		||||
   "cell_type": "code",
 | 
			
		||||
   "execution_count": null,
 | 
			
		||||
   "execution_count": 37,
 | 
			
		||||
   "id": "4db4d7f4-82da-44a4-b94e-ee04665d9c3c",
 | 
			
		||||
   "metadata": {},
 | 
			
		||||
   "outputs": [],
 | 
			
		||||
   "outputs": [
 | 
			
		||||
    {
 | 
			
		||||
     "name": "stdout",
 | 
			
		||||
     "output_type": "stream",
 | 
			
		||||
     "text": [
 | 
			
		||||
      "Batch 1: tensor([[-1.2000,  3.1000],\n",
 | 
			
		||||
      "        [-0.5000,  2.6000]]) tensor([0, 0])\n",
 | 
			
		||||
      "Batch 2: tensor([[ 2.3000, -1.1000],\n",
 | 
			
		||||
      "        [-0.9000,  2.9000]]) tensor([1, 0])\n"
 | 
			
		||||
     ]
 | 
			
		||||
    }
 | 
			
		||||
   ],
 | 
			
		||||
   "source": [
 | 
			
		||||
    "for idx, (x, y) in enumerate(train_loader):\n",
 | 
			
		||||
    "    print(f\"Batch {idx+1}:\", x, y)"
 | 
			
		||||
@ -1000,7 +1011,7 @@
 | 
			
		||||
    "probas = torch.softmax(outputs, dim=1)\n",
 | 
			
		||||
    "print(probas)\n",
 | 
			
		||||
    "\n",
 | 
			
		||||
    "predictions = torch.argmax(outputs, dim=1)\n",
 | 
			
		||||
    "predictions = torch.argmax(probas, dim=1)\n",
 | 
			
		||||
    "print(predictions)"
 | 
			
		||||
   ]
 | 
			
		||||
  },
 | 
			
		||||
@ -1254,7 +1265,7 @@
 | 
			
		||||
   "name": "python",
 | 
			
		||||
   "nbconvert_exporter": "python",
 | 
			
		||||
   "pygments_lexer": "ipython3",
 | 
			
		||||
   "version": "3.10.6"
 | 
			
		||||
   "version": "3.11.4"
 | 
			
		||||
  }
 | 
			
		||||
 },
 | 
			
		||||
 "nbformat": 4,
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user