mirror of
https://github.com/rasbt/LLMs-from-scratch.git
synced 2025-08-09 09:12:51 +00:00
Add LoRA experiments (#151)
* Add LoRA experiments * Update ch06/02_bonus_additional-experiments/additional-experiments.py
This commit is contained in:
parent
51ac283257
commit
41288a3d3a
@ -17,9 +17,11 @@ For example,
|
||||
| 4 | gpt2-small (124M) | pretrained | last | all | longest train ex. (120) | V100 | 0.94 min | 99.62% | 96.64% | 96.67% |
|
||||
| 5 | gpt2-medium (355M) | pretrained | last | last_block | longest train ex. (120) | V100 | 0.91 min | 87.50% | 91.28% | 84.67% |
|
||||
| 6 | gpt2-large (774M) | pretrained | last | last_block | longest train ex. (120) | V100 | 1.91 min | 99.52% | 98.66% | 96.67% |
|
||||
| 7 | gpt2-xl (1558M) | pretrained | last | last_block | longest train ex. (120) | V100 | 3.84 min | 99.81% | 99.33% | 98.33% |
|
||||
| 7 | gpt2-xl (1558M) | pretrained | last | last_block | longest train ex. (120) | V100 | 3.84 min | 99.81% | 99.33% | 98.33% |
|
||||
| 8 | gpt2-small (124M) | random | last | all | longest train ex. (120) | V100 | 0.93 min | 100% | 96.64% | 93.67% |
|
||||
| 9 | gpt2-small (124M) | pretrained | last | last_block | context length (1024) | V100 | 3.24 min | 83.08% | 87.92% | 78.33% |
|
||||
| 9 | gpt2-small (124M) | pretrained | last | LoRA | longest train ex. (120) | V100 | 0.82 min | 99.52% | 97.99% | 97.67% |
|
||||
| 10 | gpt2-small (124M) | pretrained | last | last_block | context length (1024) | V100 | 3.24 min | 83.08% | 87.92% | 78.33% |
|
||||
|
||||
|
||||
|
||||
|
||||
@ -35,7 +37,8 @@ You can use the following code to reproduce the experiments:
|
||||
- Row 6: `python additional-experiments.py --model_size "gpt2-large (774M)"`
|
||||
- Row 7: `python additional-experiments.py --model_size "gpt2-xl (1558M)"`
|
||||
- Row 8: `python additional-experiments.py --weights random --trainable_layers all`
|
||||
- Row 9: `python additional-experiments.py --context_length "model_context_length"`
|
||||
- Row 9: `python additional-experiments.py --trainable_layers lora --lora_rank 16 --lora_alpha 8`
|
||||
- Row 10: `python additional-experiments.py --context_length "model_context_length"`
|
||||
|
||||
I've kept the LLM and dataset small on purpose, so you can run the training on a regular laptop like a MacBook Air M3 in about 15 minutes in case you don't have access to a GPU.
|
||||
|
||||
@ -53,4 +56,6 @@ I've kept the LLM and dataset small on purpose, so you can run the training on a
|
||||
|
||||
5. **Using a Model with Random Weights vs. Pretrained Weights (Row 1 vs. 8)**: Utilizing a model with random weights yields results that are only slightly worse by 1.3% compared to using pretrained weights.
|
||||
|
||||
6. **Padding Input to Full Context Length vs. Longest Training Example (Row 1 vs. 9)**: Padding the input to the full supported context length results is significantly worse.
|
||||
6. **Using LoRA (Low-Rank Adaptation) vs Training All Layers (Row 9 vs. 4)**: Keeping the model frozen and adding trainable LoRA layers (see [Appendix E](../../appendix-E/01_main-chapter-code/appendix-E.ipynb) for details) is a viable alternative to training all model parameters and even improves the performance by 1% point. As it can be seen by the 1% lower gap between the training and validation accuracy when using LoRA, this is likely due to less overfitting. Moreover, using LoRA is also slightly faster because fewer parameters have to be updated.
|
||||
|
||||
7. **Padding Input to Full Context Length vs. Longest Training Example (Row 1 vs. 10)**: Padding the input to the full supported context length results is significantly worse.
|
||||
|
@ -20,6 +20,31 @@ from gpt_download import download_and_load_gpt2
|
||||
from previous_chapters import GPTModel, load_weights_into_gpt
|
||||
|
||||
|
||||
class LoRALayer(torch.nn.Module):
|
||||
def __init__(self, in_dim, out_dim, rank, alpha):
|
||||
super().__init__()
|
||||
std_dev = 1 / torch.sqrt(torch.tensor(rank).float())
|
||||
self.A = torch.nn.Parameter(torch.randn(in_dim, rank) * std_dev)
|
||||
self.B = torch.nn.Parameter(torch.zeros(rank, out_dim))
|
||||
self.alpha = alpha
|
||||
|
||||
def forward(self, x):
|
||||
x = self.alpha * (x @ self.A @ self.B)
|
||||
return x
|
||||
|
||||
|
||||
class LinearWithLoRA(torch.nn.Module):
|
||||
def __init__(self, linear, rank, alpha):
|
||||
super().__init__()
|
||||
self.linear = linear
|
||||
self.lora = LoRALayer(
|
||||
linear.in_features, linear.out_features, rank, alpha
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.linear(x) + self.lora(x)
|
||||
|
||||
|
||||
class SpamDataset(Dataset):
|
||||
def __init__(self, csv_file, tokenizer, max_length=None, pad_token_id=50256):
|
||||
self.data = pd.read_csv(csv_file)
|
||||
@ -238,6 +263,16 @@ def train_classifier_simple(model, train_loader, val_loader, optimizer, device,
|
||||
return train_losses, val_losses, train_accs, val_accs, examples_seen
|
||||
|
||||
|
||||
def replace_linear_with_lora(model, rank, alpha):
|
||||
for name, module in model.named_children():
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
# Replace the Linear layer with LinearWithLoRA
|
||||
setattr(model, name, LinearWithLoRA(module, rank, alpha))
|
||||
else:
|
||||
# Recursively apply the same function to child modules
|
||||
replace_linear_with_lora(module, rank, alpha)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
@ -263,7 +298,7 @@ if __name__ == "__main__":
|
||||
type=str,
|
||||
default="last_block",
|
||||
help=(
|
||||
"Which layers to train. Options: 'all', 'last_block', 'last_layer'."
|
||||
"Which layers to train. Options: 'all', 'last_block', 'last_layer', 'lora'."
|
||||
)
|
||||
)
|
||||
parser.add_argument(
|
||||
@ -283,6 +318,22 @@ if __name__ == "__main__":
|
||||
"Options: 'longest_training_example', 'model_context_length' or integer value."
|
||||
)
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lora_rank",
|
||||
type=int,
|
||||
default=8,
|
||||
help=(
|
||||
"The LoRA rank when choosing `--trainable_layers lora`"
|
||||
)
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lora_alpha",
|
||||
type=int,
|
||||
default=8,
|
||||
help=(
|
||||
"The LoRA alpha value when choosing `--trainable_layers lora`"
|
||||
)
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -332,6 +383,8 @@ if __name__ == "__main__":
|
||||
elif args.trainable_layers == "all":
|
||||
for param in model.parameters():
|
||||
param.requires_grad = True
|
||||
elif args.trainable_layers == "lora":
|
||||
replace_linear_with_lora(model, rank=args.lora_rank, alpha=args.lora_alpha)
|
||||
else:
|
||||
raise ValueError("Invalid --trainable_layers argument.")
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user