trainable token -> trainable token position

This commit is contained in:
rasbt 2024-05-23 11:43:20 -05:00
parent 209a103d66
commit 30ba6a3f4b
2 changed files with 44 additions and 44 deletions


@@ -9,23 +9,23 @@ For example,
 
-| | Model | Weights | Trainable token | Trainable layers | Context length | Training acc | Validation acc | Test acc | Training time | CPU/GPU |
-| --- | ------------------ | ---------- | --------------- | ---------------- | ------------------------------------------------------ | ------------ | -------------- | -------- | ------------- | ------- |
-| 1 | gpt2-small (124M) | pretrained | last | last_block | longest train ex. (120) | 96.63% | 99.33% | 95.00% | 0.28 min | A100 |
-| 2 | gpt2-small (124M) | pretrained | first | last_block | longest train ex. (120) | 78.46% | 80.54% | 75.00% | 0.28 min | A100 |
-| 3 | gpt2-small (124M) | pretrained | last | last_layer | longest train ex. (120) | 78.65% | 79.87% | 72.00% | 0.25 min | A100 |
-| 4 | gpt2-small (124M) | pretrained | last | last_two_blocks | longest train ex. (120) | 98.85% | 98.66% | 98.33% | 0.33 min | A100 |
-| 5 | gpt2-small (124M) | pretrained | last | all | longest train ex. (120) | 99.62% | 96.64% | 96.67% | 0.69 min | A100 |
-| 6 | gpt2-medium (355M) | pretrained | last | last_block | longest train ex. (120) | 87.50% | 91.28% | 84.67% | 0.75 min | A100 |
-| 7 | gpt2-large (774M) | pretrained | last | last_block | longest train ex. (120) | 99.52% | 98.66% | 96.67% | 1.50 min | A100 |
-| 8 | gpt2-xl (1558M) | pretrained | last | last_block | longest train ex. (120) | 99.81% | 99.33% | 98.33% | 2.83 min | A100 |
-| 9 | gpt2-small (124M) | random | last | all | longest train ex. (120) | 100% | 96.64% | 93.67% | 0.69 min | A100 |
-| 10 | gpt2-small (124M) | pretrained | last | LoRA | longest train ex. (120) | 100.00% | 97.32% | 96.67% | 0.75 min | A100 |
-| 11 | gpt2-small (124M) | pretrained | last | last_block | context length (1024) | 83.08% | 87.92% | 78.33% | 2.46 min | A100 |
-| 12 | gpt2-small (124M) | pretrained | last | last_block | variable: no padding (batch size 1) | 100.00% | 98.66% | 98.00% | 1.75 min | A100 |
-| 13 | gpt2-small (124M) | pretrained | last | last_block | variable: no padding (batch size 8) | 99.33% | 98.66% | 98.33% | 1.70 min | A100 |
-| 14 | gpt2-small (124M) | pretrained | last | last_block | longest train ex. (120); but no causal mask | 99.23% | 98.66% | 95.33% | 0.29 min | A100 |
-| 15 | gpt2-small (124M) | pretrained | last | last_block | longest train ex. (120) and `ignore_index` for padding | 96.63% | 99.33% | 95.00% | 0.28 min | A100 |
+| | Model | Weights | Trainable token position | Trainable layers | Context length | Training acc | Validation acc | Test acc | Training time | CPU/GPU |
+| ---- | ------------------ | ---------- | ------------------------ | ---------------- | ------------------------------------------------------ | ------------ | -------------- | -------- | ------------- | ------- |
+| 1 | gpt2-small (124M) | pretrained | last | last_block | longest train ex. (120) | 96.63% | 99.33% | 95.00% | 0.28 min | A100 |
+| 2 | gpt2-small (124M) | pretrained | first | last_block | longest train ex. (120) | 78.46% | 80.54% | 75.00% | 0.28 min | A100 |
+| 3 | gpt2-small (124M) | pretrained | last | last_layer | longest train ex. (120) | 78.65% | 79.87% | 72.00% | 0.25 min | A100 |
+| 4 | gpt2-small (124M) | pretrained | last | last_two_blocks | longest train ex. (120) | 98.85% | 98.66% | 98.33% | 0.33 min | A100 |
+| 5 | gpt2-small (124M) | pretrained | last | all | longest train ex. (120) | 99.62% | 96.64% | 96.67% | 0.69 min | A100 |
+| 6 | gpt2-medium (355M) | pretrained | last | last_block | longest train ex. (120) | 87.50% | 91.28% | 84.67% | 0.75 min | A100 |
+| 7 | gpt2-large (774M) | pretrained | last | last_block | longest train ex. (120) | 99.52% | 98.66% | 96.67% | 1.50 min | A100 |
+| 8 | gpt2-xl (1558M) | pretrained | last | last_block | longest train ex. (120) | 99.81% | 99.33% | 98.33% | 2.83 min | A100 |
+| 9 | gpt2-small (124M) | random | last | all | longest train ex. (120) | 100% | 96.64% | 93.67% | 0.69 min | A100 |
+| 10 | gpt2-small (124M) | pretrained | last | LoRA | longest train ex. (120) | 100.00% | 97.32% | 96.67% | 0.75 min | A100 |
+| 11 | gpt2-small (124M) | pretrained | last | last_block | context length (1024) | 83.08% | 87.92% | 78.33% | 2.46 min | A100 |
+| 12 | gpt2-small (124M) | pretrained | last | last_block | variable: no padding (batch size 1) | 100.00% | 98.66% | 98.00% | 1.75 min | A100 |
+| 13 | gpt2-small (124M) | pretrained | last | last_block | variable: no padding (batch size 8) | 99.33% | 98.66% | 98.33% | 1.70 min | A100 |
+| 14 | gpt2-small (124M) | pretrained | last | last_block | longest train ex. (120); but no causal mask | 99.23% | 98.66% | 95.33% | 0.29 min | A100 |
+| 15 | gpt2-small (124M) | pretrained | last | last_block | longest train ex. (120) and `ignore_index` for padding | 96.63% | 99.33% | 95.00% | 0.28 min | A100 |
 
@@ -34,7 +34,7 @@ For example,
You can use the following code to reproduce the experiments:
- Row 1: `python additional-experiments.py`
-- Row 2: `python additional-experiments.py --trainable_token first`
+- Row 2: `python additional-experiments.py --trainable_token_pos first`
- Row 3: `python additional-experiments.py --trainable_layers last_layer`
- Row 4: `python additional-experiments.py --trainable_layers last_two_blocks`
- Row 5: `python additional-experiments.py --trainable_layers all`
@@ -55,7 +55,7 @@ I've kept the LLM and dataset small on purpose, so you can run the training on a
### Interpretation
-1. **Training the Last vs. First Output Token (Row 1 vs. 2)**: Training the last output token results in substantially better performance compared to the first. This improvement is expected due to the causal self-attention mask.
+1. **Training the Last vs. First Output Token Position (Row 1 vs. 2)**: Training the last output token position results in substantially better performance compared to the first. This improvement is expected due to the causal self-attention mask.
2. **Training the Last Transformer Block vs. Last Layer (Row 1 vs. 3)**: Training the entire last transformer block also results in substantially better performance than training only the last layer.
3. **Training the Last vs. Last Two Transformer Blocks (Row 1 vs. 4)**: Training the last two transformer blocks instead of only the last block results in a noticeable 3.33% accuracy boost.
4. **Training the Last Transformer Block vs. All Layers (Row 1 vs. 5)**: Training all layers shows a modest improvement of ~2% over training just the last transformer block, but it requires almost three times the training time. Also, it does not perform as well as training only the last two of the 12 transformer blocks.
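The causal-mask argument behind point 1 can be made concrete: under a causal mask, position i only attends to positions ≤ i, so the first token position sees just itself while the last position sees the whole sequence. A minimal pure-Python sketch (illustrative only, not code from this repository):

```python
# Sketch (not from this repository): a causal attention mask for a
# sequence of length 5. Row i marks which positions token i may attend
# to; 1 = visible, 0 = masked out.
seq_len = 5
causal_mask = [[1 if j <= i else 0 for j in range(seq_len)]
               for i in range(seq_len)]

first_visible = sum(causal_mask[0])   # first position sees only itself
last_visible = sum(causal_mask[-1])   # last position sees every token

print(first_visible, last_visible)  # -> 1 5
```

This is why the last token position is the natural choice for a classification head: it is the only position whose hidden state can depend on the entire input.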


@@ -166,15 +166,15 @@ def instantiate_model(choose_model, load_weights):
def calc_loss_batch(input_batch, target_batch, model, device,
-trainable_token=-1, ignore_index=-100):
+trainable_token_pos=-1, ignore_index=-100):
input_batch, target_batch = input_batch.to(device), target_batch.to(device)
-logits = model(input_batch)[:, trainable_token, :] # Logits of last output token
+logits = model(input_batch)[:, trainable_token_pos, :] # Logits of last output token
loss = torch.nn.functional.cross_entropy(logits, target_batch, ignore_index=ignore_index)
return loss
def calc_loss_loader(data_loader, model, device,
-num_batches=None, trainable_token=-1, ignore_index=-100):
+num_batches=None, trainable_token_pos=-1, ignore_index=-100):
total_loss = 0.
if len(data_loader) == 0:
return float("nan")
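The `ignore_index` parameter threaded through `calc_loss_batch` is forwarded to `torch.nn.functional.cross_entropy`, which excludes any target equal to that value (-100 by default) from the loss, so padding labels contribute nothing. A small pure-Python sketch of that behavior (illustrative only; the script itself uses the PyTorch function):

```python
import math

def masked_mean_nll(log_probs, targets, ignore_index=-100):
    # Illustrative sketch of cross_entropy's ignore_index behavior:
    # targets equal to ignore_index are dropped before averaging.
    losses = [-lp[t] for lp, t in zip(log_probs, targets) if t != ignore_index]
    return sum(losses) / len(losses)

# Two real targets and one padded (ignored) entry.
log_probs = [[math.log(0.9), math.log(0.1)],
             [math.log(0.2), math.log(0.8)],
             [math.log(0.5), math.log(0.5)]]
targets = [0, 1, -100]
loss = masked_mean_nll(log_probs, targets)  # averages over the 2 kept entries
```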
@@ -188,7 +188,7 @@ def calc_loss_loader(data_loader, model, device,
if i < num_batches:
loss = calc_loss_batch(
input_batch, target_batch, model, device,
-trainable_token=trainable_token, ignore_index=ignore_index
+trainable_token_pos=trainable_token_pos, ignore_index=ignore_index
)
total_loss += loss.item()
else:
@@ -197,7 +197,7 @@ def calc_loss_loader(data_loader, model, device,
@torch.no_grad() # Disable gradient tracking for efficiency
-def calc_accuracy_loader(data_loader, model, device, num_batches=None, trainable_token=-1):
+def calc_accuracy_loader(data_loader, model, device, num_batches=None, trainable_token_pos=-1):
model.eval()
correct_predictions, num_examples = 0, 0
@@ -208,7 +208,7 @@ def calc_accuracy_loader(data_loader, model, device, num_batches=None, trainable
for i, (input_batch, target_batch) in enumerate(data_loader):
if i < num_batches:
input_batch, target_batch = input_batch.to(device), target_batch.to(device)
-logits = model(input_batch)[:, trainable_token, :] # Logits of last output token
+logits = model(input_batch)[:, trainable_token_pos, :] # Logits of last output token
predicted_labels = torch.argmax(logits, dim=-1)
num_examples += predicted_labels.shape[0]
@@ -219,23 +219,23 @@ def calc_accuracy_loader(data_loader, model, device, num_batches=None, trainable
def evaluate_model(model, train_loader, val_loader, device,
-eval_iter, trainable_token=-1, ignore_index=-100):
+eval_iter, trainable_token_pos=-1, ignore_index=-100):
model.eval()
with torch.no_grad():
train_loss = calc_loss_loader(
train_loader, model, device, num_batches=eval_iter,
-trainable_token=trainable_token, ignore_index=ignore_index
+trainable_token_pos=trainable_token_pos, ignore_index=ignore_index
)
val_loss = calc_loss_loader(
val_loader, model, device, num_batches=eval_iter,
-trainable_token=trainable_token, ignore_index=ignore_index
+trainable_token_pos=trainable_token_pos, ignore_index=ignore_index
)
model.train()
return train_loss, val_loss
def train_classifier_simple(model, train_loader, val_loader, optimizer, device, num_epochs,
-eval_freq, eval_iter, tokenizer, max_steps=None, trainable_token=-1,
+eval_freq, eval_iter, tokenizer, max_steps=None, trainable_token_pos=-1,
accumulation_steps=1, ignore_index=-100):
# Initialize lists to track losses and tokens seen
train_losses, val_losses, train_accs, val_accs = [], [], [], []
@@ -248,7 +248,7 @@ def train_classifier_simple(model, train_loader, val_loader, optimizer, device,
for batch_idx, (input_batch, target_batch) in enumerate(train_loader):
loss = calc_loss_batch(
input_batch, target_batch, model, device,
-trainable_token=trainable_token, ignore_index=ignore_index
+trainable_token_pos=trainable_token_pos, ignore_index=ignore_index
)
# Use gradient accumulation if accumulation_steps > 1
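The `accumulation_steps` option referenced above scales each batch's loss by `1/accumulation_steps` and only steps the optimizer every `accumulation_steps` batches, mimicking a larger effective batch size. A toy sketch with a scalar parameter (hypothetical helper, not code from this script):

```python
# Hypothetical sketch of gradient accumulation with one scalar parameter.
# Per-batch gradients are scaled and summed; the parameter is updated
# only once per group of `accumulation_steps` batches.
def train_with_accumulation(grads, lr=0.1, accumulation_steps=4):
    param, accum = 0.0, 0.0
    for step, g in enumerate(grads, start=1):
        accum += g / accumulation_steps      # scale each batch's gradient
        if step % accumulation_steps == 0:
            param -= lr * accum              # one update per accumulated group
            accum = 0.0
    return param

# Eight batches with accumulation_steps=4 -> exactly two optimizer steps.
final = train_with_accumulation([1.0] * 8)
```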
@@ -270,7 +270,7 @@ def train_classifier_simple(model, train_loader, val_loader, optimizer, device,
if global_step % eval_freq == 0:
train_loss, val_loss = evaluate_model(
model, train_loader, val_loader, device, eval_iter,
-trainable_token=trainable_token, ignore_index=ignore_index
+trainable_token_pos=trainable_token_pos, ignore_index=ignore_index
)
train_losses.append(train_loss)
val_losses.append(val_loss)
@@ -281,8 +281,8 @@ def train_classifier_simple(model, train_loader, val_loader, optimizer, device,
break
# New: Calculate accuracy after each epoch
-train_accuracy = calc_accuracy_loader(train_loader, model, device, num_batches=eval_iter, trainable_token=trainable_token)
-val_accuracy = calc_accuracy_loader(val_loader, model, device, num_batches=eval_iter, trainable_token=trainable_token)
+train_accuracy = calc_accuracy_loader(train_loader, model, device, num_batches=eval_iter, trainable_token_pos=trainable_token_pos)
+val_accuracy = calc_accuracy_loader(val_loader, model, device, num_batches=eval_iter, trainable_token_pos=trainable_token_pos)
print(f"Training accuracy: {train_accuracy*100:.2f}% | ", end="")
print(f"Validation accuracy: {val_accuracy*100:.2f}%")
train_accs.append(train_accuracy)
@@ -333,11 +333,11 @@ if __name__ == "__main__":
)
)
parser.add_argument(
-"--trainable_token",
+"--trainable_token_pos",
type=str,
default="last",
help=(
-"Which token to train. Options: 'first', 'last'."
+"Which token position to train. Options: 'first', 'last'."
)
)
parser.add_argument(
@@ -424,12 +424,12 @@ if __name__ == "__main__":
args = parser.parse_args()
-if args.trainable_token == "first":
-args.trainable_token = 0
-elif args.trainable_token == "last":
-args.trainable_token = -1
+if args.trainable_token_pos == "first":
+args.trainable_token_pos = 0
+elif args.trainable_token_pos == "last":
+args.trainable_token_pos = -1
else:
-raise ValueError("Invalid --trainable_token argument")
+raise ValueError("Invalid --trainable_token_pos argument")
###############################
# Load model
@@ -565,7 +565,7 @@ if __name__ == "__main__":
train_losses, val_losses, train_accs, val_accs, examples_seen = train_classifier_simple(
model, train_loader, val_loader, optimizer, device,
num_epochs=args.num_epochs, eval_freq=50, eval_iter=5,
-tokenizer=tokenizer, max_steps=None, trainable_token=args.trainable_token,
+tokenizer=tokenizer, max_steps=None, trainable_token_pos=args.trainable_token_pos,
accumulation_steps=args.accumulation_steps
)
@@ -577,9 +577,9 @@ if __name__ == "__main__":
# Evaluate model
###############################
-train_accuracy = calc_accuracy_loader(train_loader, model, device, trainable_token=args.trainable_token)
-val_accuracy = calc_accuracy_loader(val_loader, model, device, trainable_token=args.trainable_token)
-test_accuracy = calc_accuracy_loader(test_loader, model, device, trainable_token=args.trainable_token)
+train_accuracy = calc_accuracy_loader(train_loader, model, device, trainable_token_pos=args.trainable_token_pos)
+val_accuracy = calc_accuracy_loader(val_loader, model, device, trainable_token_pos=args.trainable_token_pos)
+test_accuracy = calc_accuracy_loader(test_loader, model, device, trainable_token_pos=args.trainable_token_pos)
print(f"Training accuracy: {train_accuracy*100:.2f}%")
print(f"Validation accuracy: {val_accuracy*100:.2f}%")