diff --git a/ch06/03_bonus_imdb-classification/README.md b/ch06/03_bonus_imdb-classification/README.md
index 08ffa08..2828026 100644
--- a/ch06/03_bonus_imdb-classification/README.md
+++ b/ch06/03_bonus_imdb-classification/README.md
@@ -132,3 +132,25 @@ Training accuracy: 93.44%
Validation accuracy: 93.02%
Test accuracy: 92.95%
```
+
+
+
+
+---
+
+
+
+A scikit-learn logistic regression classifier as a baseline.
+
+```
+Dummy classifier:
+Training Accuracy: 50.01%
+Validation Accuracy: 50.14%
+Test Accuracy: 49.91%
+
+
+Logistic regression classifier:
+Training Accuracy: 99.80%
+Validation Accuracy: 88.62%
+Test Accuracy: 88.85%
+```
\ No newline at end of file
diff --git a/ch06/03_bonus_imdb-classification/train_bert_hf.py b/ch06/03_bonus_imdb-classification/train_bert_hf.py
index e4fb267..66dd0f1 100644
--- a/ch06/03_bonus_imdb-classification/train_bert_hf.py
+++ b/ch06/03_bonus_imdb-classification/train_bert_hf.py
@@ -327,7 +327,7 @@ if __name__ == "__main__":
max_length=256,
tokenizer=tokenizer,
pad_token_id=tokenizer.pad_token_id,
- se_attention_mask=use_attention_mask
+ use_attention_mask=use_attention_mask
)
test_dataset = IMDBDataset(
base_path / "test.csv",
diff --git a/ch06/03_bonus_imdb-classification/train_gpt.py b/ch06/03_bonus_imdb-classification/train_gpt.py
index d91403f..ca092ea 100644
--- a/ch06/03_bonus_imdb-classification/train_gpt.py
+++ b/ch06/03_bonus_imdb-classification/train_gpt.py
@@ -235,7 +235,14 @@ if __name__ == "__main__":
"Number of epochs."
)
)
-
+ parser.add_argument(
+ "--learning_rate",
+ type=float,
+ default=5e-5,
+ help=(
+ "Learning rate."
+ )
+ )
args = parser.parse_args()
if args.trainable_token == "first":
@@ -346,7 +353,7 @@ if __name__ == "__main__":
start_time = time.time()
torch.manual_seed(123)
- optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.1)
+ optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=0.1)
train_losses, val_losses, train_accs, val_accs, examples_seen = train_classifier_simple(
model, train_loader, val_loader, optimizer, device,