mirror of
				https://github.com/rasbt/LLMs-from-scratch.git
				synced 2025-10-31 01:41:26 +00:00 
			
		
		
		
	
							parent
							
								
									dd31946b2a
								
							
						
					
					
						commit
						d361cef65f
					
				| @ -1542,7 +1542,7 @@ | ||||
|    "source": [ | ||||
|     "def calc_loss_batch(input_batch, target_batch, model, device):\n", | ||||
|     "    input_batch, target_batch = input_batch.to(device), target_batch.to(device)\n", | ||||
|     "    logits = model(input_batch)[:, -1, :]  # Logits of last ouput token\n", | ||||
|     "    logits = model(input_batch)[:, -1, :]  # Logits of last output token\n", | ||||
|     "    loss = torch.nn.functional.cross_entropy(logits, target_batch)\n", | ||||
|     "    return loss" | ||||
|    ] | ||||
| @ -1665,7 +1665,7 @@ | ||||
|     "    for i, (input_batch, target_batch) in enumerate(data_loader):\n", | ||||
|     "        if i < num_batches:\n", | ||||
|     "            input_batch, target_batch = input_batch.to(device), target_batch.to(device)\n", | ||||
|     "            logits = model(input_batch)[:, -1, :]  # Logits of last ouput token\n", | ||||
|     "            logits = model(input_batch)[:, -1, :]  # Logits of last output token\n", | ||||
|     "            predicted_labels = torch.argmax(logits, dim=-1)\n", | ||||
|     "\n", | ||||
|     "            num_examples += predicted_labels.shape[0]\n", | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Ikko Eltociear Ashimine
						Ikko Eltociear Ashimine