mirror of
				https://github.com/rasbt/LLMs-from-scratch.git
				synced 2025-10-31 18:00:08 +00:00 
			
		
		
		
	
							parent
							
								
									dd31946b2a
								
							
						
					
					
						commit
						d361cef65f
					
				| @ -1542,7 +1542,7 @@ | |||||||
|    "source": [ |    "source": [ | ||||||
|     "def calc_loss_batch(input_batch, target_batch, model, device):\n", |     "def calc_loss_batch(input_batch, target_batch, model, device):\n", | ||||||
|     "    input_batch, target_batch = input_batch.to(device), target_batch.to(device)\n", |     "    input_batch, target_batch = input_batch.to(device), target_batch.to(device)\n", | ||||||
|     "    logits = model(input_batch)[:, -1, :]  # Logits of last ouput token\n", |     "    logits = model(input_batch)[:, -1, :]  # Logits of last output token\n", | ||||||
|     "    loss = torch.nn.functional.cross_entropy(logits, target_batch)\n", |     "    loss = torch.nn.functional.cross_entropy(logits, target_batch)\n", | ||||||
|     "    return loss" |     "    return loss" | ||||||
|    ] |    ] | ||||||
| @ -1665,7 +1665,7 @@ | |||||||
|     "    for i, (input_batch, target_batch) in enumerate(data_loader):\n", |     "    for i, (input_batch, target_batch) in enumerate(data_loader):\n", | ||||||
|     "        if i < num_batches:\n", |     "        if i < num_batches:\n", | ||||||
|     "            input_batch, target_batch = input_batch.to(device), target_batch.to(device)\n", |     "            input_batch, target_batch = input_batch.to(device), target_batch.to(device)\n", | ||||||
|     "            logits = model(input_batch)[:, -1, :]  # Logits of last ouput token\n", |     "            logits = model(input_batch)[:, -1, :]  # Logits of last output token\n", | ||||||
|     "            predicted_labels = torch.argmax(logits, dim=-1)\n", |     "            predicted_labels = torch.argmax(logits, dim=-1)\n", | ||||||
|     "\n", |     "\n", | ||||||
|     "            num_examples += predicted_labels.shape[0]\n", |     "            num_examples += predicted_labels.shape[0]\n", | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Ikko Eltociear Ashimine
						Ikko Eltociear Ashimine