mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-08-31 20:08:08 +00:00
add roberta option (#135)
This commit is contained in:
parent
d088753fca
commit
d1edfcb63f
@ -97,6 +97,15 @@ Test accuracy: 90.81%
|
||||
|
||||
---
|
||||
|
||||
A 355M parameter encoder-style [RoBERTa](https://arxiv.org/abs/1907.11692) model, starting from the pretrained weights and only training the last transformer block plus output layers:
|
||||
|
||||
|
||||
```bash
|
||||
python train-bert-hf.py --bert_model roberta
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
A scikit-learn Logistic Regression model as a baseline.
|
||||
|
||||
```bash
|
||||
|
@ -164,15 +164,26 @@ if __name__ == "__main__":
|
||||
"Which layers to train. Options: 'all', 'last_block', 'last_layer'."
|
||||
)
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bert_model",
|
||||
type=str,
|
||||
default="distilbert",
|
||||
help=(
|
||||
"Which model to train. Options: 'distilbert', 'roberta'."
|
||||
)
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
###############################
|
||||
# Load model
|
||||
###############################
|
||||
|
||||
torch.manual_seed(123)
|
||||
if args.bert_model == "distilbert":
|
||||
|
||||
model = AutoModelForSequenceClassification.from_pretrained(
|
||||
"distilbert-base-uncased", num_labels=2
|
||||
)
|
||||
torch.manual_seed(123)
|
||||
model.out_head = torch.nn.Linear(in_features=768, out_features=2)
|
||||
|
||||
if args.trainable_layers == "last_layer":
|
||||
@ -188,8 +199,36 @@ if __name__ == "__main__":
|
||||
else:
|
||||
raise ValueError("Invalid --trainable_layers argument.")
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
|
||||
|
||||
elif args.bert_model == "roberta":
|
||||
|
||||
model = AutoModelForSequenceClassification.from_pretrained(
|
||||
"FacebookAI/roberta-large", num_labels=2
|
||||
)
|
||||
model.classifier.out_proj = torch.nn.Linear(in_features=1024, out_features=2)
|
||||
|
||||
if args.trainable_layers == "last_layer":
|
||||
pass
|
||||
elif args.trainable_layers == "last_block":
|
||||
for param in model.classifier.parameters():
|
||||
param.requires_grad = True
|
||||
for param in model.roberta.encoder.layer[-1].parameters():
|
||||
param.requires_grad = True
|
||||
elif args.trainable_layers == "all":
|
||||
for param in model.parameters():
|
||||
param.requires_grad = True
|
||||
else:
|
||||
raise ValueError("Invalid --trainable_layers argument.")
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("FacebookAI/roberta-large")
|
||||
|
||||
else:
|
||||
raise ValueError("Selected --bert_model not supported.")
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
###############################
|
||||
# Instantiate dataloaders
|
||||
@ -204,7 +243,6 @@ if __name__ == "__main__":
|
||||
file_names = ["train.csv", "val.csv", "test.csv"]
|
||||
all_exist = all((base_path / file_name).exists() for file_name in file_names)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
|
||||
pad_token_id = tokenizer.encode(tokenizer.pad_token)
|
||||
|
||||
train_dataset = IMDBDataset(base_path / "train.csv", max_length=256, tokenizer=tokenizer, pad_token_id=pad_token_id)
|
||||
|
Loading…
x
Reference in New Issue
Block a user