diff --git a/ch06/03_bonus_imdb-classification/README.md b/ch06/03_bonus_imdb-classification/README.md index 08ffa08..2828026 100644 --- a/ch06/03_bonus_imdb-classification/README.md +++ b/ch06/03_bonus_imdb-classification/README.md @@ -132,3 +132,25 @@ Training accuracy: 93.44% Validation accuracy: 93.02% Test accuracy: 92.95% ``` + + +
+ +--- + +
+ +A scikit-learn logistic regression classifier as a baseline. + +``` +Dummy classifier: +Training Accuracy: 50.01% +Validation Accuracy: 50.14% +Test Accuracy: 49.91% + + +Logistic regression classifier: +Training Accuracy: 99.80% +Validation Accuracy: 88.62% +Test Accuracy: 88.85% +``` \ No newline at end of file diff --git a/ch06/03_bonus_imdb-classification/train_bert_hf.py b/ch06/03_bonus_imdb-classification/train_bert_hf.py index e4fb267..66dd0f1 100644 --- a/ch06/03_bonus_imdb-classification/train_bert_hf.py +++ b/ch06/03_bonus_imdb-classification/train_bert_hf.py @@ -327,7 +327,7 @@ if __name__ == "__main__": max_length=256, tokenizer=tokenizer, pad_token_id=tokenizer.pad_token_id, - se_attention_mask=use_attention_mask + use_attention_mask=use_attention_mask ) test_dataset = IMDBDataset( base_path / "test.csv", diff --git a/ch06/03_bonus_imdb-classification/train_gpt.py b/ch06/03_bonus_imdb-classification/train_gpt.py index d91403f..ca092ea 100644 --- a/ch06/03_bonus_imdb-classification/train_gpt.py +++ b/ch06/03_bonus_imdb-classification/train_gpt.py @@ -235,7 +235,14 @@ if __name__ == "__main__": "Number of epochs." ) ) - + parser.add_argument( + "--learning_rate", + type=float, + default=5e-5, + help=( + "Learning rate." + ) + ) args = parser.parse_args() if args.trainable_token == "first": @@ -346,7 +353,7 @@ if __name__ == "__main__": start_time = time.time() torch.manual_seed(123) - optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.1) + optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=0.1) train_losses, val_losses, train_accs, val_accs, examples_seen = train_classifier_simple( model, train_loader, val_loader, optimizer, device,