mirror of
				https://github.com/rasbt/LLMs-from-scratch.git
				synced 2025-10-30 17:29:59 +00:00 
			
		
		
		
	sklearn baseline and roberta-large update
This commit is contained in:
		
							parent
							
								
									5acab58d41
								
							
						
					
					
						commit
						8eb6fc0ad0
					
				| @ -132,3 +132,25 @@ Training accuracy: 93.44% | |||||||
| Validation accuracy: 93.02% | Validation accuracy: 93.02% | ||||||
| Test accuracy: 92.95% | Test accuracy: 92.95% | ||||||
| ``` | ``` | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | <br> | ||||||
|  | 
 | ||||||
|  | --- | ||||||
|  | 
 | ||||||
|  | <br> | ||||||
|  | 
 | ||||||
|  | A scikit-learn logistic regression classifier as a baseline. | ||||||
|  | 
 | ||||||
|  | ``` | ||||||
|  | Dummy classifier: | ||||||
|  | Training Accuracy: 50.01% | ||||||
|  | Validation Accuracy: 50.14% | ||||||
|  | Test Accuracy: 49.91% | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | Logistic regression classifier: | ||||||
|  | Training Accuracy: 99.80% | ||||||
|  | Validation Accuracy: 88.62% | ||||||
|  | Test Accuracy: 88.85% | ||||||
|  | ``` | ||||||
| @ -327,7 +327,7 @@ if __name__ == "__main__": | |||||||
|         max_length=256, |         max_length=256, | ||||||
|         tokenizer=tokenizer, |         tokenizer=tokenizer, | ||||||
|         pad_token_id=tokenizer.pad_token_id, |         pad_token_id=tokenizer.pad_token_id, | ||||||
|         se_attention_mask=use_attention_mask |         use_attention_mask=use_attention_mask | ||||||
|     ) |     ) | ||||||
|     test_dataset = IMDBDataset( |     test_dataset = IMDBDataset( | ||||||
|         base_path / "test.csv", |         base_path / "test.csv", | ||||||
|  | |||||||
| @ -235,7 +235,14 @@ if __name__ == "__main__": | |||||||
|             "Number of epochs." |             "Number of epochs." | ||||||
|         ) |         ) | ||||||
|     ) |     ) | ||||||
| 
 |     parser.add_argument( | ||||||
|  |         "--learning_rate", | ||||||
|  |         type=float, | ||||||
|  |         default=5e-5, | ||||||
|  |         help=( | ||||||
|  |             "Learning rate." | ||||||
|  |         ) | ||||||
|  |     ) | ||||||
|     args = parser.parse_args() |     args = parser.parse_args() | ||||||
| 
 | 
 | ||||||
|     if args.trainable_token == "first": |     if args.trainable_token == "first": | ||||||
| @ -346,7 +353,7 @@ if __name__ == "__main__": | |||||||
| 
 | 
 | ||||||
|     start_time = time.time() |     start_time = time.time() | ||||||
|     torch.manual_seed(123) |     torch.manual_seed(123) | ||||||
|     optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.1) |     optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=0.1) | ||||||
| 
 | 
 | ||||||
|     train_losses, val_losses, train_accs, val_accs, examples_seen = train_classifier_simple( |     train_losses, val_losses, train_accs, val_accs, examples_seen = train_classifier_simple( | ||||||
|         model, train_loader, val_loader, optimizer, device, |         model, train_loader, val_loader, optimizer, device, | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 rasbt
						rasbt