Add LoRA experiments (#151)

* Add LoRA experiments

* Update ch06/02_bonus_additional-experiments/additional-experiments.py
Sebastian Raschka 2024-05-10 07:26:41 -05:00 committed by GitHub
parent 51ac283257
commit 41288a3d3a
2 changed files with 63 additions and 5 deletions


@@ -17,9 +17,11 @@ For example,
 | 4 | gpt2-small (124M) | pretrained | last | all | longest train ex. (120) | V100 | 0.94 min | 99.62% | 96.64% | 96.67% |
 | 5 | gpt2-medium (355M) | pretrained | last | last_block | longest train ex. (120) | V100 | 0.91 min | 87.50% | 91.28% | 84.67% |
 | 6 | gpt2-large (774M) | pretrained | last | last_block | longest train ex. (120) | V100 | 1.91 min | 99.52% | 98.66% | 96.67% |
 | 7 | gpt2-xl (1558M) | pretrained | last | last_block | longest train ex. (120) | V100 | 3.84 min | 99.81% | 99.33% | 98.33% |
 | 8 | gpt2-small (124M) | random | last | all | longest train ex. (120) | V100 | 0.93 min | 100% | 96.64% | 93.67% |
-| 9 | gpt2-small (124M) | pretrained | last | last_block | context length (1024) | V100 | 3.24 min | 83.08% | 87.92% | 78.33% |
+| 9 | gpt2-small (124M) | pretrained | last | LoRA | longest train ex. (120) | V100 | 0.82 min | 99.52% | 97.99% | 97.67% |
+| 10 | gpt2-small (124M) | pretrained | last | last_block | context length (1024) | V100 | 3.24 min | 83.08% | 87.92% | 78.33% |
@@ -35,7 +37,8 @@ You can use the following code to reproduce the experiments:
 - Row 6: `python additional-experiments.py --model_size "gpt2-large (774M)"`
 - Row 7: `python additional-experiments.py --model_size "gpt2-xl (1558M)"`
 - Row 8: `python additional-experiments.py --weights random --trainable_layers all`
-- Row 9: `python additional-experiments.py --context_length "model_context_length"`
+- Row 9: `python additional-experiments.py --trainable_layers lora --lora_rank 16 --lora_alpha 8`
+- Row 10: `python additional-experiments.py --context_length "model_context_length"`

 I've kept the LLM and dataset small on purpose, so you can run the training on a regular laptop like a MacBook Air M3 in about 15 minutes in case you don't have access to a GPU.
@@ -53,4 +56,6 @@ I've kept the LLM and dataset small on purpose, so you can run the training on a
 5. **Using a Model with Random Weights vs. Pretrained Weights (Row 1 vs. 8)**: Utilizing a model with random weights yields results that are only slightly worse, by 1.3%, compared to using pretrained weights.
-6. **Padding Input to Full Context Length vs. Longest Training Example (Row 1 vs. 9)**: Padding the input to the full supported context length results is significantly worse.
+6. **Using LoRA (Low-Rank Adaptation) vs. Training All Layers (Row 9 vs. 4)**: Keeping the model frozen and adding trainable LoRA layers (see [Appendix E](../../appendix-E/01_main-chapter-code/appendix-E.ipynb) for details) is a viable alternative to training all model parameters and even improves performance by 1 percentage point. The roughly 1% smaller gap between training and validation accuracy with LoRA suggests that this is likely due to less overfitting. Moreover, using LoRA is also slightly faster because fewer parameters have to be updated.
+7. **Padding Input to Full Context Length vs. Longest Training Example (Row 1 vs. 10)**: Padding the input to the full supported context length yields significantly worse results.
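
To make the "fewer parameters" point in the new item 6 concrete, here is a minimal back-of-the-envelope sketch (not part of the committed files); it assumes gpt2-small's embedding dimension of 768 and the rank of 16 used in the row 9 command:

# Trainable weights in one 768x768 projection: full fine-tuning vs. LoRA (rank 16)
in_dim, out_dim, rank = 768, 768, 16
full_params = in_dim * out_dim                # 589,824 trainable weights
lora_params = in_dim * rank + rank * out_dim  # 24,576 trainable weights, ~4% of the full count
print(f"full: {full_params:,}  lora: {lora_params:,}")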

ch06/02_bonus_additional-experiments/additional-experiments.py

@@ -20,6 +20,31 @@ from gpt_download import download_and_load_gpt2
 from previous_chapters import GPTModel, load_weights_into_gpt


+class LoRALayer(torch.nn.Module):
+    def __init__(self, in_dim, out_dim, rank, alpha):
+        super().__init__()
+        std_dev = 1 / torch.sqrt(torch.tensor(rank).float())
+        self.A = torch.nn.Parameter(torch.randn(in_dim, rank) * std_dev)
+        self.B = torch.nn.Parameter(torch.zeros(rank, out_dim))
+        self.alpha = alpha
+
+    def forward(self, x):
+        x = self.alpha * (x @ self.A @ self.B)
+        return x
+
+
+class LinearWithLoRA(torch.nn.Module):
+    def __init__(self, linear, rank, alpha):
+        super().__init__()
+        self.linear = linear
+        self.lora = LoRALayer(
+            linear.in_features, linear.out_features, rank, alpha
+        )
+
+    def forward(self, x):
+        return self.linear(x) + self.lora(x)
+
+
 class SpamDataset(Dataset):
     def __init__(self, csv_file, tokenizer, max_length=None, pad_token_id=50256):
         self.data = pd.read_csv(csv_file)
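
Because B in LoRALayer is initialized to zeros, the added LoRA path contributes nothing before training, so a LinearWithLoRA layer initially reproduces the wrapped Linear layer exactly. A minimal sketch of that check, assuming the two classes above are in scope:

import torch

torch.manual_seed(123)
linear = torch.nn.Linear(16, 8)
layer = LinearWithLoRA(linear, rank=4, alpha=8)  # wraps the original layer

x = torch.randn(2, 16)
# B is zero-initialized, so alpha * (x @ A @ B) is zero at this point
assert torch.allclose(layer(x), linear(x))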
@@ -238,6 +263,16 @@ def train_classifier_simple(model, train_loader, val_loader, optimizer, device,
     return train_losses, val_losses, train_accs, val_accs, examples_seen


+def replace_linear_with_lora(model, rank, alpha):
+    for name, module in model.named_children():
+        if isinstance(module, torch.nn.Linear):
+            # Replace the Linear layer with LinearWithLoRA
+            setattr(model, name, LinearWithLoRA(module, rank, alpha))
+        else:
+            # Recursively apply the same function to child modules
+            replace_linear_with_lora(module, rank, alpha)
+
+
 if __name__ == "__main__":

     parser = argparse.ArgumentParser()
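
replace_linear_with_lora swaps every torch.nn.Linear in the module tree for a LinearWithLoRA wrapper, but it does not freeze anything by itself; the base weights are assumed to be frozen beforehand (as in the chapter's fine-tuning setup), so only the freshly created A and B matrices remain trainable. A small sketch with a toy model, assuming the classes and function above are in scope:

import torch

# Toy stand-in for the GPT model; the freezing step is an assumption about
# how the surrounding script prepares the model before this function is called.
model = torch.nn.Sequential(
    torch.nn.Linear(32, 64),
    torch.nn.ReLU(),
    torch.nn.Linear(64, 2),
)
for param in model.parameters():
    param.requires_grad = False        # freeze the base weights first

replace_linear_with_lora(model, rank=8, alpha=8)

print([name for name, p in model.named_parameters() if p.requires_grad])
# ['0.lora.A', '0.lora.B', '2.lora.A', '2.lora.B']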
@@ -263,7 +298,7 @@ if __name__ == "__main__":
         type=str,
         default="last_block",
         help=(
-            "Which layers to train. Options: 'all', 'last_block', 'last_layer'."
+            "Which layers to train. Options: 'all', 'last_block', 'last_layer', 'lora'."
         )
     )
     parser.add_argument(
@@ -283,6 +318,22 @@ if __name__ == "__main__":
             "Options: 'longest_training_example', 'model_context_length' or integer value."
         )
     )
+    parser.add_argument(
+        "--lora_rank",
+        type=int,
+        default=8,
+        help=(
+            "The LoRA rank when choosing `--trainable_layers lora`"
+        )
+    )
+    parser.add_argument(
+        "--lora_alpha",
+        type=int,
+        default=8,
+        help=(
+            "The LoRA alpha value when choosing `--trainable_layers lora`"
+        )
+    )
     args = parser.parse_args()
@@ -332,6 +383,8 @@ if __name__ == "__main__":
     elif args.trainable_layers == "all":
         for param in model.parameters():
             param.requires_grad = True
+    elif args.trainable_layers == "lora":
+        replace_linear_with_lora(model, rank=args.lora_rank, alpha=args.lora_alpha)
     else:
         raise ValueError("Invalid --trainable_layers argument.")
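
Putting the pieces together, the lora path amounts to: freeze the pretrained weights, swap in the LoRA wrappers with the rank and alpha from the new CLI flags, and train, so that only the A and B matrices receive gradient updates. A hedged sketch of that flow with a toy model (the freeze step and the optimizer settings are assumptions about the surrounding script and are not shown in this diff):

import torch

# Sketch only: the toy model, the freeze step, and the AdamW settings below
# are assumptions; only replace_linear_with_lora(...) appears in this commit.
lora_rank, lora_alpha = 16, 8                     # matches the row 9 command

model = torch.nn.Sequential(torch.nn.Linear(768, 768), torch.nn.Linear(768, 2))
for param in model.parameters():
    param.requires_grad = False                   # pretrained weights stay frozen

replace_linear_with_lora(model, rank=lora_rank, alpha=lora_alpha)

optimizer = torch.optim.AdamW(
    (p for p in model.parameters() if p.requires_grad),  # LoRA matrices only
    lr=5e-5, weight_decay=0.1,
)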