mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-10-29 17:01:30 +00:00
sklearn baseline and roberta-large update
This commit is contained in:
parent
5acab58d41
commit
8eb6fc0ad0
@ -132,3 +132,25 @@ Training accuracy: 93.44%
|
|||||||
Validation accuracy: 93.02%
|
Validation accuracy: 93.02%
|
||||||
Test accuracy: 92.95%
|
Test accuracy: 92.95%
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
|
<br>
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
<br>
|
||||||
|
|
||||||
|
For comparison, the results of a scikit-learn logistic regression classifier used as a baseline are shown below, preceded by those of a dummy classifier:
|
||||||
|
|
||||||
|
```
|
||||||
|
Dummy classifier:
|
||||||
|
Training Accuracy: 50.01%
|
||||||
|
Validation Accuracy: 50.14%
|
||||||
|
Test Accuracy: 49.91%
|
||||||
|
|
||||||
|
|
||||||
|
Logistic regression classifier:
|
||||||
|
Training Accuracy: 99.80%
|
||||||
|
Validation Accuracy: 88.62%
|
||||||
|
Test Accuracy: 88.85%
|
||||||
|
```
|
||||||
@ -327,7 +327,7 @@ if __name__ == "__main__":
|
|||||||
max_length=256,
|
max_length=256,
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
pad_token_id=tokenizer.pad_token_id,
|
pad_token_id=tokenizer.pad_token_id,
|
||||||
se_attention_mask=use_attention_mask
|
use_attention_mask=use_attention_mask
|
||||||
)
|
)
|
||||||
test_dataset = IMDBDataset(
|
test_dataset = IMDBDataset(
|
||||||
base_path / "test.csv",
|
base_path / "test.csv",
|
||||||
|
|||||||
@ -235,7 +235,14 @@ if __name__ == "__main__":
|
|||||||
"Number of epochs."
|
"Number of epochs."
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--learning_rate",
|
||||||
|
type=float,
|
||||||
|
default=5e-5,
|
||||||
|
help=(
|
||||||
|
"Learning rate."
|
||||||
|
)
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if args.trainable_token == "first":
|
if args.trainable_token == "first":
|
||||||
@ -346,7 +353,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
torch.manual_seed(123)
|
torch.manual_seed(123)
|
||||||
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.1)
|
optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=0.1)
|
||||||
|
|
||||||
train_losses, val_losses, train_accs, val_accs, examples_seen = train_classifier_simple(
|
train_losses, val_losses, train_accs, val_accs, examples_seen = train_classifier_simple(
|
||||||
model, train_loader, val_loader, optimizer, device,
|
model, train_loader, val_loader, optimizer, device,
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user