From 30ba6a3f4b41c0baf7c42c616d8f99ae601e279d Mon Sep 17 00:00:00 2001 From: rasbt Date: Thu, 23 May 2024 11:43:20 -0500 Subject: [PATCH] trainable token -> trainable token position --- .../02_bonus_additional-experiments/README.md | 38 +++++++------- .../additional-experiments.py | 50 +++++++++---------- 2 files changed, 44 insertions(+), 44 deletions(-) diff --git a/ch06/02_bonus_additional-experiments/README.md b/ch06/02_bonus_additional-experiments/README.md index d74d8ed..4172719 100644 --- a/ch06/02_bonus_additional-experiments/README.md +++ b/ch06/02_bonus_additional-experiments/README.md @@ -9,23 +9,23 @@ For example,   -| | Model | Weights | Trainable token | Trainable layers | Context length | Training acc | Validation acc | Test acc | Training time | CPU/GPU | -| --- | ------------------ | ---------- | --------------- | ---------------- | ------------------------------------------------------ | ------------ | -------------- | -------- | ------------- | ------- | -| 1 | gpt2-small (124M) | pretrained | last | last_block | longest train ex. (120) | 96.63% | 99.33% | 95.00% | 0.28 min | A100 | -| 2 | gpt2-small (124M) | pretrained | first | last_block | longest train ex. (120) | 78.46% | 80.54% | 75.00% | 0.28 min | A100 | -| 3 | gpt2-small (124M) | pretrained | last | last_layer | longest train ex. (120) | 78.65% | 79.87% | 72.00% | 0.25 min | A100 | -| 4 | gpt2-small (124M) | pretrained | last | last_two_blocks | longest train ex. (120) | 98.85% | 98.66% | 98.33% | 0.33 min | A100 | -| 5 | gpt2-small (124M) | pretrained | last | all | longest train ex. (120) | 99.62% | 96.64% | 96.67% | 0.69 min | A100 | -| 6 | gpt2-medium (355M) | pretrained | last | last_block | longest train ex. (120) | 87.50% | 91.28% | 84.67% | 0.75 min | A100 | -| 7 | gpt2-large (774M) | pretrained | last | last_block | longest train ex. (120) | 99.52% | 98.66% | 96.67% | 1.50 min | A100 | -| 8 | gpt2-xl (1558M) | pretrained | last | last_block | longest train ex. 
(120) | 99.81% | 99.33% | 98.33% | 2.83 min | A100 | -| 9 | gpt2-small (124M) | random | last | all | longest train ex. (120) | 100% | 96.64% | 93.67% | 0.69 min | A100 | -| 10 | gpt2-small (124M) | pretrained | last | LoRA | longest train ex. (120) | 100.00% | 97.32% | 96.67% | 0.75 min | A100 | -| 11 | gpt2-small (124M) | pretrained | last | last_block | context length (1024) | 83.08% | 87.92% | 78.33% | 2.46 min | A100 | -| 12 | gpt2-small (124M) | pretrained | last | last_block | variable: no padding (batch size 1) | 100.00% | 98.66% | 98.00% | 1.75 min | A100 | -| 13 | gpt2-small (124M) | pretrained | last | last_block | variable: no padding (batch size 8) | 99.33% | 98.66% | 98.33% | 1.70 min | A100 | -| 14 | gpt2-small (124M) | pretrained | last | last_block | longest train ex. (120); but no causal mask | 99.23% | 98.66% | 95.33% | 0.29 min | A100 | -| 15 | gpt2-small (124M) | pretrained | last | last_block | longest train ex. (120) and `ignore_index` for padding | 96.63% | 99.33% | 95.00% | 0.28 min | A100 | +| | Model | Weights | Trainable token position | Trainable layers | Context length | Training acc | Validation acc | Test acc | Training time | CPU/GPU | +| ---- | ------------------ | ---------- | ------------------------ | ---------------- | ------------------------------------------------------ | ------------ | -------------- | -------- | ------------- | ------- | +| 1 | gpt2-small (124M) | pretrained | last | last_block | longest train ex. (120) | 96.63% | 99.33% | 95.00% | 0.28 min | A100 | +| 2 | gpt2-small (124M) | pretrained | first | last_block | longest train ex. (120) | 78.46% | 80.54% | 75.00% | 0.28 min | A100 | +| 3 | gpt2-small (124M) | pretrained | last | last_layer | longest train ex. (120) | 78.65% | 79.87% | 72.00% | 0.25 min | A100 | +| 4 | gpt2-small (124M) | pretrained | last | last_two_blocks | longest train ex. 
(120) | 98.85% | 98.66% | 98.33% | 0.33 min | A100 | +| 5 | gpt2-small (124M) | pretrained | last | all | longest train ex. (120) | 99.62% | 96.64% | 96.67% | 0.69 min | A100 | +| 6 | gpt2-medium (355M) | pretrained | last | last_block | longest train ex. (120) | 87.50% | 91.28% | 84.67% | 0.75 min | A100 | +| 7 | gpt2-large (774M) | pretrained | last | last_block | longest train ex. (120) | 99.52% | 98.66% | 96.67% | 1.50 min | A100 | +| 8 | gpt2-xl (1558M) | pretrained | last | last_block | longest train ex. (120) | 99.81% | 99.33% | 98.33% | 2.83 min | A100 | +| 9 | gpt2-small (124M) | random | last | all | longest train ex. (120) | 100% | 96.64% | 93.67% | 0.69 min | A100 | +| 10 | gpt2-small (124M) | pretrained | last | LoRA | longest train ex. (120) | 100.00% | 97.32% | 96.67% | 0.75 min | A100 | +| 11 | gpt2-small (124M) | pretrained | last | last_block | context length (1024) | 83.08% | 87.92% | 78.33% | 2.46 min | A100 | +| 12 | gpt2-small (124M) | pretrained | last | last_block | variable: no padding (batch size 1) | 100.00% | 98.66% | 98.00% | 1.75 min | A100 | +| 13 | gpt2-small (124M) | pretrained | last | last_block | variable: no padding (batch size 8) | 99.33% | 98.66% | 98.33% | 1.70 min | A100 | +| 14 | gpt2-small (124M) | pretrained | last | last_block | longest train ex. (120); but no causal mask | 99.23% | 98.66% | 95.33% | 0.29 min | A100 | +| 15 | gpt2-small (124M) | pretrained | last | last_block | longest train ex. 
(120) and `ignore_index` for padding | 96.63% | 99.33% | 95.00% | 0.28 min | A100 |   @@ -34,7 +34,7 @@ For example, You can use the following code to reproduce the experiments: - Row 1: `python additional-experiments.py` -- Row 2: `python additional-experiments.py --trainable_token first` +- Row 2: `python additional-experiments.py --trainable_token_pos first` - Row 3: `python additional-experiments.py --trainable_layers last_layer` - Row 4: `python additional-experiments.py --trainable_layers last_two_blocks` - Row 5: `python additional-experiments.py --trainable_layers all` @@ -55,7 +55,7 @@ I've kept the LLM and dataset small on purpose, so you can run the training on a ### Interpretation -1. **Training the Last vs. First Output Token (Row 1 vs. 2)**: Training the last output token results in substantially better performance compared to the first. This improvement is expected due to the causal self-attention mask. +1. **Training the Last vs. First Output Token Position (Row 1 vs. 2)**: Training the last output token position results in substantially better performance compared to the first. This improvement is expected due to the causal self-attention mask. 2. **Training the Last Transformer Block vs. Last Layer (Row 1 vs. 3)**: Training the entire last transformer block also yields substantially better results than training only the last layer. 3. **Training the Last vs. Last Two Transformer Blocks (Row 1 vs. 4)**: Training the two last transformer blocks instead of only the last block results in a noticeable 3.33% accuracy boost. 4. **Training Last Transformer Block vs All Layers (Row 1 vs. 5)**: Training all layers shows a modest improvement of ~2% over just training the last transformer block, but it requires almost three times longer in terms of training duration. Also, it does not perform as well as training only the last two out of 12 transformer blocks. 
diff --git a/ch06/02_bonus_additional-experiments/additional-experiments.py b/ch06/02_bonus_additional-experiments/additional-experiments.py index dd3d559..bcfc0b8 100644 --- a/ch06/02_bonus_additional-experiments/additional-experiments.py +++ b/ch06/02_bonus_additional-experiments/additional-experiments.py @@ -166,15 +166,15 @@ def instantiate_model(choose_model, load_weights): def calc_loss_batch(input_batch, target_batch, model, device, - trainable_token=-1, ignore_index=-100): + trainable_token_pos=-1, ignore_index=-100): input_batch, target_batch = input_batch.to(device), target_batch.to(device) - logits = model(input_batch)[:, trainable_token, :] # Logits of last output token + logits = model(input_batch)[:, trainable_token_pos, :] # Logits of last output token loss = torch.nn.functional.cross_entropy(logits, target_batch, ignore_index=ignore_index) return loss def calc_loss_loader(data_loader, model, device, - num_batches=None, trainable_token=-1, ignore_index=-100): + num_batches=None, trainable_token_pos=-1, ignore_index=-100): total_loss = 0. 
if len(data_loader) == 0: return float("nan") @@ -188,7 +188,7 @@ def calc_loss_loader(data_loader, model, device, if i < num_batches: loss = calc_loss_batch( input_batch, target_batch, model, device, - trainable_token=trainable_token, ignore_index=ignore_index + trainable_token_pos=trainable_token_pos, ignore_index=ignore_index ) total_loss += loss.item() else: @@ -197,7 +197,7 @@ def calc_loss_loader(data_loader, model, device, @torch.no_grad() # Disable gradient tracking for efficiency -def calc_accuracy_loader(data_loader, model, device, num_batches=None, trainable_token=-1): +def calc_accuracy_loader(data_loader, model, device, num_batches=None, trainable_token_pos=-1): model.eval() correct_predictions, num_examples = 0, 0 @@ -208,7 +208,7 @@ def calc_accuracy_loader(data_loader, model, device, num_batches=None, trainable for i, (input_batch, target_batch) in enumerate(data_loader): if i < num_batches: input_batch, target_batch = input_batch.to(device), target_batch.to(device) - logits = model(input_batch)[:, trainable_token, :] # Logits of last output token + logits = model(input_batch)[:, trainable_token_pos, :] # Logits of last output token predicted_labels = torch.argmax(logits, dim=-1) num_examples += predicted_labels.shape[0] @@ -219,23 +219,23 @@ def calc_accuracy_loader(data_loader, model, device, num_batches=None, trainable def evaluate_model(model, train_loader, val_loader, device, - eval_iter, trainable_token=-1, ignore_index=-100): + eval_iter, trainable_token_pos=-1, ignore_index=-100): model.eval() with torch.no_grad(): train_loss = calc_loss_loader( train_loader, model, device, num_batches=eval_iter, - trainable_token=trainable_token, ignore_index=ignore_index + trainable_token_pos=trainable_token_pos, ignore_index=ignore_index ) val_loss = calc_loss_loader( val_loader, model, device, num_batches=eval_iter, - trainable_token=trainable_token, ignore_index=ignore_index + trainable_token_pos=trainable_token_pos, ignore_index=ignore_index ) 
model.train() return train_loss, val_loss def train_classifier_simple(model, train_loader, val_loader, optimizer, device, num_epochs, - eval_freq, eval_iter, tokenizer, max_steps=None, trainable_token=-1, + eval_freq, eval_iter, tokenizer, max_steps=None, trainable_token_pos=-1, accumulation_steps=1, ignore_index=-100): # Initialize lists to track losses and tokens seen train_losses, val_losses, train_accs, val_accs = [], [], [], [] @@ -248,7 +248,7 @@ def train_classifier_simple(model, train_loader, val_loader, optimizer, device, for batch_idx, (input_batch, target_batch) in enumerate(train_loader): loss = calc_loss_batch( input_batch, target_batch, model, device, - trainable_token=trainable_token, ignore_index=ignore_index + trainable_token_pos=trainable_token_pos, ignore_index=ignore_index ) # Use gradient accumulation if accumulation_steps > 1 @@ -270,7 +270,7 @@ def train_classifier_simple(model, train_loader, val_loader, optimizer, device, if global_step % eval_freq == 0: train_loss, val_loss = evaluate_model( model, train_loader, val_loader, device, eval_iter, - trainable_token=trainable_token, ignore_index=ignore_index + trainable_token_pos=trainable_token_pos, ignore_index=ignore_index ) train_losses.append(train_loss) val_losses.append(val_loss) @@ -281,8 +281,8 @@ def train_classifier_simple(model, train_loader, val_loader, optimizer, device, break # New: Calculate accuracy after each epoch - train_accuracy = calc_accuracy_loader(train_loader, model, device, num_batches=eval_iter, trainable_token=trainable_token) - val_accuracy = calc_accuracy_loader(val_loader, model, device, num_batches=eval_iter, trainable_token=trainable_token) + train_accuracy = calc_accuracy_loader(train_loader, model, device, num_batches=eval_iter, trainable_token_pos=trainable_token_pos) + val_accuracy = calc_accuracy_loader(val_loader, model, device, num_batches=eval_iter, trainable_token_pos=trainable_token_pos) print(f"Training accuracy: {train_accuracy*100:.2f}% | ", end="") 
print(f"Validation accuracy: {val_accuracy*100:.2f}%") train_accs.append(train_accuracy) @@ -333,11 +333,11 @@ if __name__ == "__main__": ) ) parser.add_argument( - "--trainable_token", + "--trainable_token_pos", type=str, default="last", help=( - "Which token to train. Options: 'first', 'last'." + "Which token position to train. Options: 'first', 'last'." ) ) parser.add_argument( @@ -424,12 +424,12 @@ if __name__ == "__main__": args = parser.parse_args() - if args.trainable_token == "first": - args.trainable_token = 0 - elif args.trainable_token == "last": - args.trainable_token = -1 + if args.trainable_token_pos == "first": + args.trainable_token_pos = 0 + elif args.trainable_token_pos == "last": + args.trainable_token_pos = -1 else: - raise ValueError("Invalid --trainable_token argument") + raise ValueError("Invalid --trainable_token_pos argument") ############################### # Load model @@ -565,7 +565,7 @@ if __name__ == "__main__": train_losses, val_losses, train_accs, val_accs, examples_seen = train_classifier_simple( model, train_loader, val_loader, optimizer, device, num_epochs=args.num_epochs, eval_freq=50, eval_iter=5, - tokenizer=tokenizer, max_steps=None, trainable_token=args.trainable_token, + tokenizer=tokenizer, max_steps=None, trainable_token_pos=args.trainable_token_pos, accumulation_steps=args.accumulation_steps ) @@ -577,9 +577,9 @@ if __name__ == "__main__": # Evaluate model ############################### - train_accuracy = calc_accuracy_loader(train_loader, model, device, trainable_token=args.trainable_token) - val_accuracy = calc_accuracy_loader(val_loader, model, device, trainable_token=args.trainable_token) - test_accuracy = calc_accuracy_loader(test_loader, model, device, trainable_token=args.trainable_token) + train_accuracy = calc_accuracy_loader(train_loader, model, device, trainable_token_pos=args.trainable_token_pos) + val_accuracy = calc_accuracy_loader(val_loader, model, device, trainable_token_pos=args.trainable_token_pos) + 
test_accuracy = calc_accuracy_loader(test_loader, model, device, trainable_token_pos=args.trainable_token_pos) print(f"Training accuracy: {train_accuracy*100:.2f}%") print(f"Validation accuracy: {val_accuracy*100:.2f}%")