mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-11-01 10:20:00 +00:00
update lora experiments
This commit is contained in:
parent
742525cb28
commit
1a962f3983
@ -19,13 +19,15 @@ For example,
|
||||
| 6 | gpt2-medium (355M) | pretrained | last | last_block | longest train ex. (120) | 87.50% | 91.28% | 84.67% | 0.75 min | A100 |
|
||||
| 7 | gpt2-large (774M) | pretrained | last | last_block | longest train ex. (120) | 99.52% | 98.66% | 96.67% | 1.50 min | A100 |
|
||||
| 8 | gpt2-xl (1558M) | pretrained | last | last_block | longest train ex. (120) | 99.81% | 99.33% | 98.33% | 2.83 min | A100 |
|
||||
| 9 | gpt2-small (124M) | random | last | all | longest train ex. (120) | 100.00% | 96.64% | 93.67% | 0.69 min | A100 |
|
||||
| 10 | gpt2-small (124M) | pretrained | last | LoRA | longest train ex. (120) | 100.00% | 97.32% | 96.67% | 0.75 min | A100 |
|
||||
| 11 | gpt2-small (124M) | pretrained | last | last_block | context length (1024) | 83.08% | 87.92% | 78.33% | 2.46 min | A100 |
|
||||
| 12 | gpt2-small (124M) | pretrained | last | last_block | variable: no padding (batch size 1) | 100.00% | 98.66% | 98.00% | 1.75 min | A100 |
|
||||
| 13 | gpt2-small (124M) | pretrained | last | last_block | variable: no padding (batch size 8) | 99.33% | 98.66% | 98.33% | 1.70 min | A100 |
|
||||
| 14 | gpt2-small (124M) | pretrained | last | last_block | longest train ex. (120); but no causal mask | 99.23% | 98.66% | 95.33% | 0.29 min | A100 |
|
||||
| 15 | gpt2-small (124M) | pretrained | last | last_block | longest train ex. (120) and `ignore_index` for padding | 96.63% | 99.33% | 95.00% | 0.28 min | A100 |
|
||||
| 9 | gpt2-xl (1558M) | pretrained | last | all | longest train ex. (120) | 100.00% | 98.66% | 98.67% | 8.12 min | A100 |
|
||||
| 10 | gpt2-small (124M) | random | last | all | longest train ex. (120) | 100.00% | 96.64% | 93.67% | 0.69 min | A100 |
|
||||
| 11 | gpt2-small (124M) | pretrained | last | LoRA | longest train ex. (120) | 100.00% | 97.32% | 96.67% | 0.75 min | A100 |
|
||||
| 12 | gpt2-xl (1558M) | pretrained | last | LoRA | longest train ex. (120) | 100.00% | 98.66% | 98.33% | 5.79 min | A100 |
|
||||
| 13 | gpt2-small (124M) | pretrained | last | last_block | context length (1024) | 83.08% | 87.92% | 78.33% | 2.46 min | A100 |
|
||||
| 14 | gpt2-small (124M) | pretrained | last | last_block | variable: no padding (batch size 1) | 100.00% | 98.66% | 98.00% | 1.75 min | A100 |
|
||||
| 15 | gpt2-small (124M) | pretrained | last | last_block | variable: no padding (batch size 8) | 99.33% | 98.66% | 98.33% | 1.70 min | A100 |
|
||||
| 16 | gpt2-small (124M) | pretrained | last | last_block | longest train ex. (120); but no causal mask | 99.23% | 98.66% | 95.33% | 0.29 min | A100 |
|
||||
| 17 | gpt2-small (124M) | pretrained | last | last_block | longest train ex. (120) and `ignore_index` for padding | 96.63% | 99.33% | 95.00% | 0.28 min | A100 |
|
||||
|
||||
|
||||
|
||||
@ -41,13 +43,15 @@ You can use the following code to reproduce the experiments:
|
||||
- Row 6: `python additional-experiments.py --model_size "gpt2-medium (355M)"`
|
||||
- Row 7: `python additional-experiments.py --model_size "gpt2-large (774M)"`
|
||||
- Row 8: `python additional-experiments.py --model_size "gpt2-xl (1558M)"`
|
||||
- Row 9: `python additional-experiments.py --weights random --trainable_layers all`
|
||||
- Row 10: `python additional-experiments.py --trainable_layers lora --lora_rank 16 --lora_alpha 16`
|
||||
- Row 11: `python additional-experiments.py --context_length "model_context_length"`
|
||||
- Row 12: `python additional-experiments.py --no_padding --batch_size 1`
|
||||
- Row 13: `python additional-experiments.py --no_padding --batch_size 1 --accumulation_steps 8`
|
||||
- Row 14: `python additional-experiments.py --disable_causal_mask`
|
||||
- Row 15: `python additional-experiments.py --ignore_index 50256`
|
||||
- Row 9: `python additional-experiments.py --model_size "gpt2-xl (1558M)"--trainable_layers all`
|
||||
- Row 10: `python additional-experiments.py --weights random --trainable_layers all`
|
||||
- Row 11: `python additional-experiments.py --trainable_layers lora --lora_rank 16 --lora_alpha 16`
|
||||
- Row 12: `python additional-experiments.py --trainable_layers lora --lora_rank 16 --lora_alpha 8 --model_size "gpt2-xl (1558M)"`
|
||||
- Row 13: `python additional-experiments.py --context_length "model_context_length"`
|
||||
- Row 14: `python additional-experiments.py --no_padding --batch_size 1`
|
||||
- Row 15: `python additional-experiments.py --no_padding --batch_size 1 --accumulation_steps 8`
|
||||
- Row 16: `python additional-experiments.py --disable_causal_mask`
|
||||
- Row 17: `python additional-experiments.py --ignore_index 50256`
|
||||
|
||||
I've kept the LLM and dataset small on purpose, so you can run the training on a regular laptop like a MacBook Air M3 in about 15 minutes (for the default setting) in case you don't have access to a GPU.
|
||||
|
||||
@ -60,9 +64,9 @@ I've kept the LLM and dataset small on purpose, so you can run the training on a
|
||||
3. **Training the Last vs. Last Two Last Transformer Blocks (Row 1 vs. 4)**: Training the two last transformer blocks instead of only the last block results in a noticeable 3.33% accuracy boost.
|
||||
4. **Training Last Transformer Block vs All Layers (Row 1 vs. 5)**: Training all layers shows a modest improvement of ~2% over just training the last transformer block, but it requires almost three times longer in terms of training duration. Also, it does not perform as well as training only the last two out of 12 transformer blocks.
|
||||
5. **Using Larger Pretrained Models (Row 1 vs 6, and Row 1 vs. 7 and 8)**: Employing a 3x larger pretrained model leads to worse results. However, using a 5x larger model improves performance compared to the initial model, as was anticipated. Similarly, the 12x larger model improves the predictive performance even further. (The medium model was perhaps not well pretrained or the particular finetuning configuration works not as well for this model.)
|
||||
6. **Using a Model with Random Weights vs. Pretrained Weights (Row 1 and 5 vs. 9)**: Utilizing a model with random weights yields results that are only slightly worse (by 3% and 1.3%) compared to using pretrained weights.
|
||||
7. **Using LoRA (Low-Rank Adaptation) vs Training All Layers (Row 10 vs. 5)**: Keeping the model frozen and adding trainable LoRA layers (see [Appendix E](../../appendix-E/01_main-chapter-code/appendix-E.ipynb) for details) is a viable alternative to training all model parameters and even improves the performance by 1% point. As it can be seen by the ~1% lower gap between the training and validation accuracy when using LoRA, this is likely due to less overfitting. Moreover, using LoRA is also more memory-efficient because fewer parameters have to be updated.
|
||||
8. **Padding Input to Full Context Length vs. Longest Training Example (Row 1 vs. 11)**: Padding the input to the full supported context length results is significantly worse.
|
||||
9. **Padding vs no padding (Row 1 vs. 12 and 13)**: The `--no_padding` option disables the padding in the dataset, which requires training the model with a batch size of 1 since the inputs have variable lengths. This results in a better test accuracy but takes longer to train. In row 13, we additionally enable gradient accumulation with 8 steps to achieve the same batch size as in the other experiments, which helps reduce overfitting and slightly boost the test set accuracy.
|
||||
10. **Disabling the causal attention mask (Row 1 vs. 14)**: Disables the causal attention mask used in the multi-head attention module. This means all tokens can attend all other tokens. The model accuracy is slightly improved compared to the GPT model with causal mask.
|
||||
11. **Ignoring the padding indices in the loss and backpropagation (Row 1 vs. 15)**: Setting `--ignore_index 50256` excludes the `|endoftext|` padding tokens in the `cross_entropy` loss function in PyTorch. In this case, it does not have any effect because we replaced the output layers so that the token IDs are either 0 or 1 for the binary classification example. However, this setting is useful when instruction finetuning models in chapter 7.
|
||||
6. **Using a Model with Random Weights vs. Pretrained Weights (Row 1 and 5 vs. 10)**: Utilizing a model with random weights yields results that are only slightly worse (by 3% and 1.3%) compared to using pretrained weights.
|
||||
7. **Using LoRA (Low-Rank Adaptation) vs Training All Layers (Row 11 vs. 5, and row 12 vs. 9)**: Keeping the model frozen and adding trainable LoRA layers (see [Appendix E](../../appendix-E/01_main-chapter-code/appendix-E.ipynb) for details) is a viable alternative to training all model parameters and even improves the performance by 1% point (row 11 vs. 5). As it can be seen by the ~1% lower gap between the training and validation accuracy when using LoRA, this is likely due to less overfitting. Moreover, using LoRA is also more memory-efficient because fewer parameters have to be updated. When training the larger model (row 12 vs. 9), we can also see that LoRA trains much faster (5.79 min instead of 8.12 min).
|
||||
8. **Padding Input to Full Context Length vs. Longest Training Example (Row 1 vs. 13)**: Padding the input to the full supported context length results is significantly worse.
|
||||
9. **Padding vs no padding (Row 1 vs. 14 and 15)**: The `--no_padding` option disables the padding in the dataset, which requires training the model with a batch size of 1 since the inputs have variable lengths. This results in a better test accuracy but takes longer to train. In row 15, we additionally enable gradient accumulation with 8 steps to achieve the same batch size as in the other experiments, which helps reduce overfitting and slightly boost the test set accuracy.
|
||||
10. **Disabling the causal attention mask (Row 1 vs. 16)**: Disables the causal attention mask used in the multi-head attention module. This means all tokens can attend all other tokens. The model accuracy is slightly improved compared to the GPT model with causal mask.
|
||||
11. **Ignoring the padding indices in the loss and backpropagation (Row 1 vs. 17)**: Setting `--ignore_index 50256` excludes the `|endoftext|` padding tokens in the `cross_entropy` loss function in PyTorch. In this case, it does not have any effect because we replaced the output layers so that the token IDs are either 0 or 1 for the binary classification example. However, this setting is useful when instruction finetuning models in chapter 7.
|
||||
|
||||
@ -46,6 +46,21 @@ class LinearWithLoRA(torch.nn.Module):
|
||||
return self.linear(x) + self.lora(x)
|
||||
|
||||
|
||||
# This LoRA code is equivalent to LinearWithLoRA
|
||||
class LinearWithLoRAMerged(torch.nn.Module):
|
||||
def __init__(self, linear, rank, alpha):
|
||||
super().__init__()
|
||||
self.linear = linear
|
||||
self.lora = LoRALayer(
|
||||
linear.in_features, linear.out_features, rank, alpha
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
lora = self.lora.A @ self.lora.B
|
||||
combined_weight = self.linear.weight + self.lora.alpha*lora.T
|
||||
return torch.nn.functional.linear(x, combined_weight, self.linear.bias)
|
||||
|
||||
|
||||
class SpamDataset(Dataset):
|
||||
def __init__(self, csv_file, tokenizer, max_length=None, pad_token_id=50256, no_padding=False):
|
||||
self.data = pd.read_csv(csv_file)
|
||||
@ -295,11 +310,14 @@ def train_classifier_simple(model, train_loader, val_loader, optimizer, device,
|
||||
return train_losses, val_losses, train_accs, val_accs, examples_seen
|
||||
|
||||
|
||||
def replace_linear_with_lora(model, rank, alpha):
|
||||
def replace_linear_with_lora(model, rank, alpha, alternative=False):
|
||||
for name, module in model.named_children():
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
# Replace the Linear layer with LinearWithLoRA
|
||||
setattr(model, name, LinearWithLoRA(module, rank, alpha))
|
||||
if alternative:
|
||||
setattr(model, name, LinearWithLoRAMerged(module, rank, alpha))
|
||||
else:
|
||||
setattr(model, name, LinearWithLoRA(module, rank, alpha))
|
||||
else:
|
||||
# Recursively apply the same function to child modules
|
||||
replace_linear_with_lora(module, rank, alpha)
|
||||
@ -330,7 +348,7 @@ if __name__ == "__main__":
|
||||
type=str,
|
||||
default="last_block",
|
||||
help=(
|
||||
"Which layers to train. Options: 'all', 'last_block', 'last_two_blocks', 'last_layer', 'lora'."
|
||||
"Which layers to train. Options: 'all', 'last_block', 'last_two_blocks', 'last_layer', 'lora', 'lora_alternative'."
|
||||
)
|
||||
)
|
||||
parser.add_argument(
|
||||
@ -474,8 +492,12 @@ if __name__ == "__main__":
|
||||
elif args.trainable_layers == "all":
|
||||
for param in model.parameters():
|
||||
param.requires_grad = True
|
||||
elif args.trainable_layers == "lora":
|
||||
replace_linear_with_lora(model, rank=args.lora_rank, alpha=args.lora_alpha)
|
||||
elif args.trainable_layers in ("lora", "lora_alternative"):
|
||||
if args.trainable_layers == "lora_alternative":
|
||||
alternative = True
|
||||
else:
|
||||
alternative = False
|
||||
replace_linear_with_lora(model, rank=args.lora_rank, alpha=args.lora_alpha, alternative=alternative)
|
||||
else:
|
||||
raise ValueError("Invalid --trainable_layers argument.")
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user