sklearn baseline and roberta-large update

This commit is contained in:
rasbt 2024-08-26 10:31:54 +02:00
parent 5acab58d41
commit 8eb6fc0ad0
3 changed files with 32 additions and 3 deletions

View File

@@ -132,3 +132,25 @@ Training accuracy: 93.44%
Validation accuracy: 93.02% Validation accuracy: 93.02%
Test accuracy: 92.95% Test accuracy: 92.95%
``` ```
<br>
---
<br>
A scikit-learn logistic regression classifier serves as a baseline; a dummy classifier is shown alongside it for reference:
```
Dummy classifier:
Training Accuracy: 50.01%
Validation Accuracy: 50.14%
Test Accuracy: 49.91%
Logistic regression classifier:
Training Accuracy: 99.80%
Validation Accuracy: 88.62%
Test Accuracy: 88.85%
```

View File

@@ -327,7 +327,7 @@ if __name__ == "__main__":
max_length=256, max_length=256,
tokenizer=tokenizer, tokenizer=tokenizer,
pad_token_id=tokenizer.pad_token_id, pad_token_id=tokenizer.pad_token_id,
se_attention_mask=use_attention_mask use_attention_mask=use_attention_mask
) )
test_dataset = IMDBDataset( test_dataset = IMDBDataset(
base_path / "test.csv", base_path / "test.csv",

View File

@@ -235,7 +235,14 @@ if __name__ == "__main__":
"Number of epochs." "Number of epochs."
) )
) )
parser.add_argument(
"--learning_rate",
type=float,
default=5e-5,
help=(
"Learning rate."
)
)
args = parser.parse_args() args = parser.parse_args()
if args.trainable_token == "first": if args.trainable_token == "first":
@@ -346,7 +353,7 @@ if __name__ == "__main__":
start_time = time.time() start_time = time.time()
torch.manual_seed(123) torch.manual_seed(123)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.1) optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=0.1)
train_losses, val_losses, train_accs, val_accs, examples_seen = train_classifier_simple( train_losses, val_losses, train_accs, val_accs, examples_seen = train_classifier_simple(
model, train_loader, val_loader, optimizer, device, model, train_loader, val_loader, optimizer, device,