mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-10-29 17:01:30 +00:00
sklearn baseline and roberta-large update
This commit is contained in:
parent
5acab58d41
commit
8eb6fc0ad0
@ -132,3 +132,25 @@ Training accuracy: 93.44%
|
||||
Validation accuracy: 93.02%
|
||||
Test accuracy: 92.95%
|
||||
```
|
||||
|
||||
|
||||
<br>
|
||||
|
||||
---
|
||||
|
||||
<br>
|
||||
|
||||
A scikit-learn logistic regression classifier as a baseline.
|
||||
|
||||
```
|
||||
Dummy classifier:
|
||||
Training Accuracy: 50.01%
|
||||
Validation Accuracy: 50.14%
|
||||
Test Accuracy: 49.91%
|
||||
|
||||
|
||||
Logistic regression classifier:
|
||||
Training Accuracy: 99.80%
|
||||
Validation Accuracy: 88.62%
|
||||
Test Accuracy: 88.85%
|
||||
```
|
||||
@ -327,7 +327,7 @@ if __name__ == "__main__":
|
||||
max_length=256,
|
||||
tokenizer=tokenizer,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
se_attention_mask=use_attention_mask
|
||||
use_attention_mask=use_attention_mask
|
||||
)
|
||||
test_dataset = IMDBDataset(
|
||||
base_path / "test.csv",
|
||||
|
||||
@ -235,7 +235,14 @@ if __name__ == "__main__":
|
||||
"Number of epochs."
|
||||
)
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--learning_rate",
|
||||
type=float,
|
||||
default=5e-5,
|
||||
help=(
|
||||
"Learning rate."
|
||||
)
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.trainable_token == "first":
|
||||
@ -346,7 +353,7 @@ if __name__ == "__main__":
|
||||
|
||||
start_time = time.time()
|
||||
torch.manual_seed(123)
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.1)
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=0.1)
|
||||
|
||||
train_losses, val_losses, train_accs, val_accs, examples_seen = train_classifier_simple(
|
||||
model, train_loader, val_loader, optimizer, device,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user