mirror of
				https://github.com/rasbt/LLMs-from-scratch.git
				synced 2025-10-26 07:20:09 +00:00 
			
		
		
		
	
						commit
						1bf24669f8
					
				| @ -350,9 +350,7 @@ if __name__ == "__main__": | ||||
|         } | ||||
|         model = GPTModel(BASE_CONFIG) | ||||
|         model.eval() | ||||
| 
 | ||||
|         device = "cpu" | ||||
|         model.to(device) | ||||
| 
 | ||||
|     # Code as it is used in the main chapter | ||||
|     else: | ||||
| @ -380,10 +378,7 @@ if __name__ == "__main__": | ||||
| 
 | ||||
|         model = GPTModel(BASE_CONFIG) | ||||
|         load_weights_into_gpt(model, params) | ||||
|         model.eval() | ||||
| 
 | ||||
|         device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | ||||
|         model.to(device) | ||||
| 
 | ||||
|     ######################################## | ||||
|     # Modify the pretrained model | ||||
| @ -396,6 +391,7 @@ if __name__ == "__main__": | ||||
| 
 | ||||
|     num_classes = 2 | ||||
|     model.out_head = torch.nn.Linear(in_features=BASE_CONFIG["emb_dim"], out_features=num_classes) | ||||
|     model.to(device) | ||||
| 
 | ||||
|     for param in model.trf_blocks[-1].parameters(): | ||||
|         param.requires_grad = True | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 Sebastian Raschka
						Sebastian Raschka