mirror of https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-08-31 20:08:08 +00:00

commit d1edfcb63f: add roberta option (#135)
parent: d088753fca
@@ -97,6 +97,15 @@ Test accuracy: 90.81%
 
 ---
 
+A 355M parameter encoder-style [RoBERTa](https://arxiv.org/abs/1907.11692) model, starting from the pretrained weights and only training the last transformer block plus output layers:
+
+```bash
+python train-bert-hf.py --bert_model roberta
+```
+
+---
+
 A scikit-learn Logistic Regression model as a baseline.
 
 ```bash
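The 355M figure quoted in the new README paragraph can be checked directly against the Hugging Face checkpoint named in the diff. The snippet below is a quick sanity check, not part of the commit; it assumes the `transformers` package and the `FacebookAI/roberta-large` checkpoint used above.

```python
# Sanity check (not part of the commit): confirm the ~355M parameter
# count of the RoBERTa checkpoint referenced in the README addition.
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained(
    "FacebookAI/roberta-large", num_labels=2
)
total = sum(p.numel() for p in model.parameters())
print(f"{total / 1e6:.0f}M parameters")  # roughly 355M
```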
@@ -164,32 +164,71 @@ if __name__ == "__main__":
             "Which layers to train. Options: 'all', 'last_block', 'last_layer'."
         )
     )
+    parser.add_argument(
+        "--bert_model",
+        type=str,
+        default="distilbert",
+        help=(
+            "Which model to use. Options: 'distilbert', 'roberta'."
+        )
+    )
     args = parser.parse_args()
 
     ###############################
     # Load model
     ###############################
-    model = AutoModelForSequenceClassification.from_pretrained(
-        "distilbert-base-uncased", num_labels=2
-    )
-    torch.manual_seed(123)
-    model.out_head = torch.nn.Linear(in_features=768, out_features=2)
-
-    if args.trainable_layers == "last_layer":
-        pass
-    elif args.trainable_layers == "last_block":
-        for param in model.pre_classifier.parameters():
-            param.requires_grad = True
-        for param in model.distilbert.transformer.layer[-1].parameters():
-            param.requires_grad = True
-    elif args.trainable_layers == "all":
-        for param in model.parameters():
-            param.requires_grad = True
-    else:
-        raise ValueError("Invalid --trainable_layers argument.")
+    torch.manual_seed(123)
+
+    if args.bert_model == "distilbert":
+
+        model = AutoModelForSequenceClassification.from_pretrained(
+            "distilbert-base-uncased", num_labels=2
+        )
+        model.out_head = torch.nn.Linear(in_features=768, out_features=2)
+
+        if args.trainable_layers == "last_layer":
+            pass
+        elif args.trainable_layers == "last_block":
+            for param in model.pre_classifier.parameters():
+                param.requires_grad = True
+            for param in model.distilbert.transformer.layer[-1].parameters():
+                param.requires_grad = True
+        elif args.trainable_layers == "all":
+            for param in model.parameters():
+                param.requires_grad = True
+        else:
+            raise ValueError("Invalid --trainable_layers argument.")
+
+        tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
+
+    elif args.bert_model == "roberta":
+
+        model = AutoModelForSequenceClassification.from_pretrained(
+            "FacebookAI/roberta-large", num_labels=2
+        )
+        model.classifier.out_proj = torch.nn.Linear(in_features=1024, out_features=2)
+
+        if args.trainable_layers == "last_layer":
+            pass
+        elif args.trainable_layers == "last_block":
+            for param in model.classifier.parameters():
+                param.requires_grad = True
+            for param in model.roberta.encoder.layer[-1].parameters():
+                param.requires_grad = True
+        elif args.trainable_layers == "all":
+            for param in model.parameters():
+                param.requires_grad = True
+        else:
+            raise ValueError("Invalid --trainable_layers argument.")
+
+        tokenizer = AutoTokenizer.from_pretrained("FacebookAI/roberta-large")
+
+    else:
+        raise ValueError("Selected --bert_model not supported.")
 
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     model.to(device)
+    model.eval()
 
     ###############################
     # Instantiate dataloaders
     ###############################
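Under `--trainable_layers last_block`, only the last encoder block and the classification head are meant to train, so a quick inspection of the `requires_grad` flags confirms the branch above did what was intended. This is a minimal sketch, assuming `model` is the RoBERTa model built in this hunk and that the rest of the script freezes all parameters beforehand (that freezing step lies outside the lines shown here).

```python
# Sketch (assumption: the script has already frozen every parameter and the
# branch above re-enabled the last encoder block plus the classifier head).
trainable = [name for name, p in model.named_parameters() if p.requires_grad]
frozen = sum(1 for p in model.parameters() if not p.requires_grad)
print(f"{len(trainable)} trainable tensors, {frozen} frozen")
for name in trainable[:5]:  # expect names like roberta.encoder.layer.23.* and classifier.*
    print(" ", name)
```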
@@ -204,7 +243,6 @@ if __name__ == "__main__":
     file_names = ["train.csv", "val.csv", "test.csv"]
     all_exist = all((base_path / file_name).exists() for file_name in file_names)
 
-    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
     pad_token_id = tokenizer.encode(tokenizer.pad_token)
 
     train_dataset = IMDBDataset(base_path / "train.csv", max_length=256, tokenizer=tokenizer, pad_token_id=pad_token_id)
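One subtlety in the context line kept above: `tokenizer.encode(tokenizer.pad_token)` returns a list, because `encode` wraps its input with the tokenizer's special tokens. The snippet below illustrates this for the DistilBERT tokenizer; it is an aside for clarity, not part of the commit.

```python
# Illustration (not part of the commit): encode() adds special tokens around
# its input, so the result is a list rather than a bare pad token id.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
print(tokenizer.encode(tokenizer.pad_token))  # [101, 0, 102] -> [CLS] [PAD] [SEP]
print(tokenizer.pad_token_id)                 # 0, the pad id by itself
```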