diff --git a/ch06/02_bonus_additional-experiments/README.md b/ch06/02_bonus_additional-experiments/README.md
index 7178855..e14644f 100644
--- a/ch06/02_bonus_additional-experiments/README.md
+++ b/ch06/02_bonus_additional-experiments/README.md
@@ -21,7 +21,8 @@ For example,
 | 8 | gpt2-small (124M) | random | last | all | longest train ex. (120) | 100% | 96.64% | 93.67% | 0.69 min | A100 |
 | 9 | gpt2-small (124M) | pretrained | last | LoRA | longest train ex. (120) | 99.52% | 97.99% | 97.67% | 0.75 min | A100 |
 | 10 | gpt2-small (124M) | pretrained | last | last_block | context length (1024) | 83.08% | 87.92% | 78.33% | 2.46 min | A100 |
-| 11 | gpt2-small (124M) | pretrained | last | last_block | variable: no padding | 97.42% | 95.30% | 95.00% | 1.71 min | A100 |
+| 11 | gpt2-small (124M) | pretrained | last | last_block | variable: no padding (batch size 1) | 100.00% | 98.66% | 98.00% | 1.75 min | A100 |
+| 12 | gpt2-small (124M) | pretrained | last | last_block | variable: no padding (batch size 8) | 99.33% | 98.66% | 98.33% | 1.70 min | A100 |
 
 &nbsp;
 
@@ -40,7 +41,8 @@ You can use the following code to reproduce the experiments:
 - Row 8: `python additional-experiments.py --weights random --trainable_layers all`
 - Row 9: `python additional-experiments.py --trainable_layers lora --lora_rank 16 --lora_alpha 8`
 - Row 10: `python additional-experiments.py --context_length "model_context_length"`
-- Row 11: `python additional-experiments.py --no_padding`
+- Row 11: `python additional-experiments.py --no_padding --batch_size 1`
+- Row 12: `python additional-experiments.py --no_padding --batch_size 1 --accumulation_steps 8`
 
 I've kept the LLM and dataset small on purpose, so you can run the training on a regular laptop like a MacBook Air M3 in about 15 minutes in case you don't have access to a GPU.
 
@@ -62,4 +64,4 @@ I've kept the LLM and dataset small on purpose, so you can run the training on a
 
 7. **Padding Input to Full Context Length vs. Longest Training Example (Row 1 vs. 10)**: Padding the input to the full supported context length results is significantly worse.
 
-8. **Padding vs no padding (Row 1 vs 11)**: The `--no_padding` option disables the padding in the dataset and trains the model with a batch size of 1 where the inputs have variable lengths. This results in exactly the same test set accuracy but takes substantially longer to train.
+8. **Padding vs. no padding (Row 1 vs. 11 and 12)**: The `--no_padding` option disables the padding in the dataset, which requires training the model with a batch size of 1 since the inputs have variable lengths. This results in better test accuracy but takes longer to train. In row 12, we additionally enable gradient accumulation over 8 steps to achieve the same effective batch size as in the other experiments, which helps reduce overfitting and slightly improves the test set accuracy.
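The gradient accumulation used for row 12 works by summing the gradients of several small batches before performing a single optimizer step, so that one update approximates a larger batch. Below is a minimal, self-contained sketch of that idea; the `nn.Linear` toy model, the random data, and the learning rate are stand-ins for illustration only, not the script's actual GPT-2 classifier or spam dataset:

```python
import torch
from torch import nn

# Toy stand-ins for the script's model and data (assumed for illustration).
torch.manual_seed(123)
model = nn.Linear(4, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()
inputs = torch.randn(8, 4)
targets = torch.randint(0, 2, (8,))

accumulation_steps = 8  # accumulate 8 micro-batches of size 1

optimizer.zero_grad()
for batch_idx in range(inputs.shape[0]):
    x = inputs[batch_idx:batch_idx + 1]   # micro-batch of size 1
    y = targets[batch_idx:batch_idx + 1]
    loss = loss_fn(model(x), y) / accumulation_steps  # scale so summed grads match the batch-of-8 mean
    loss.backward()  # gradients accumulate in .grad across backward() calls
    if (batch_idx + 1) % accumulation_steps == 0:
        optimizer.step()       # one weight update after 8 micro-batches
        optimizer.zero_grad()  # reset the accumulated gradients
```

Because the per-sample losses are averaged the same way in both setups, this yields (up to floating-point error) the same weight update as a single batch of 8, just spread over 8 forward/backward passes per optimizer step.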
diff --git a/ch06/02_bonus_additional-experiments/additional-experiments.py b/ch06/02_bonus_additional-experiments/additional-experiments.py
index ae60d8d..7492ed6 100644
--- a/ch06/02_bonus_additional-experiments/additional-experiments.py
+++ b/ch06/02_bonus_additional-experiments/additional-experiments.py
@@ -222,7 +222,8 @@ def evaluate_model(model, train_loader, val_loader, device, eval_iter, trainable
 
 
 def train_classifier_simple(model, train_loader, val_loader, optimizer, device, num_epochs,
-                            eval_freq, eval_iter, tokenizer, max_steps=None, trainable_token=-1):
+                            eval_freq, eval_iter, tokenizer, max_steps=None, trainable_token=-1,
+                            accumulation_steps=1):
     # Initialize lists to track losses and tokens seen
     train_losses, val_losses, train_accs, val_accs = [], [], [], []
     examples_seen, global_step = 0, -1
@@ -231,11 +232,21 @@ def train_classifier_simple(model, train_loader, val_loader, optimizer, device,
     for epoch in range(num_epochs):
         model.train()  # Set model to training mode
 
-        for input_batch, target_batch in train_loader:
-            optimizer.zero_grad()  # Reset loss gradients from previous epoch
+        for batch_idx, (input_batch, target_batch) in enumerate(train_loader):
             loss = calc_loss_batch(input_batch, target_batch, model, device, trainable_token=trainable_token)
+
+            # Use gradient accumulation if accumulation_steps > 1
+            # See https://sebastianraschka.com/blog/2023/llm-grad-accumulation.html
+            # for an explanation
+            loss /= accumulation_steps
+
             loss.backward()  # Calculate loss gradients
-            optimizer.step()  # Update model weights using loss gradients
+
+            # Use gradient accumulation if accumulation_steps > 1
+            if (batch_idx + 1) % accumulation_steps == 0:
+                optimizer.step()  # Update model weights using the accumulated loss gradients
+                optimizer.zero_grad()  # Reset loss gradients for the next accumulation cycle
+
             examples_seen += input_batch.shape[0]  # New: track examples instead of tokens
             global_step += 1
 
@@ -341,8 +352,8 @@ if __name__ == "__main__":
         action='store_true',
         default=False,
         help=(
-            "Enable no padding. When this flag is set it will train"
-            " the model with a batch size of 1 and no padding."
+            "Disable padding, which means each example may have a different length."
+            " This requires setting `--batch_size 1`."
        )
    )
    parser.add_argument(
@@ -353,6 +364,27 @@ if __name__ == "__main__":
            "Number of training epochs."
        )
    )
+    parser.add_argument(
+        "--batch_size",
+        type=int,
+        default=8,
+        help=(
+            "The batch size used for training."
+        )
+    )
+
+    parser.add_argument(
+        "--accumulation_steps",
+        type=int,
+        default=1,
+        help=(
+            "Number of gradient accumulation steps before a weight update."
+            " See https://sebastianraschka.com/blog/2023/llm-grad-accumulation.html for an explanation."
+            " For example, setting `batch_size=8` and `accumulation_steps=1` computes the exact same"
+            " loss and weight updates as setting `batch_size=1` and `accumulation_steps=8`; however,"
+            " the latter setting requires more iterations."
+        )
+    )
 
     args = parser.parse_args()
 
@@ -455,14 +487,9 @@ if __name__ == "__main__":
     num_workers = 0
 
-    if args.no_padding:
-        batch_size = 1
-    else:
-        batch_size = 8
-
     train_loader = DataLoader(
         dataset=train_dataset,
-        batch_size=batch_size,
+        batch_size=args.batch_size,
         shuffle=True,
         num_workers=num_workers,
         drop_last=True,
     )
@@ -470,14 +497,14 @@
 
     val_loader = DataLoader(
         dataset=val_dataset,
-        batch_size=batch_size,
+        batch_size=args.batch_size,
         num_workers=num_workers,
         drop_last=False,
     )
 
     test_loader = DataLoader(
         dataset=test_dataset,
-        batch_size=batch_size,
+        batch_size=args.batch_size,
         num_workers=num_workers,
         drop_last=False,
     )
@@ -493,7 +520,8 @@ if __name__ == "__main__":
     train_losses, val_losses, train_accs, val_accs, examples_seen = train_classifier_simple(
         model, train_loader, val_loader, optimizer, device,
         num_epochs=args.num_epochs, eval_freq=50, eval_iter=5,
-        tokenizer=tokenizer, max_steps=None, trainable_token=args.trainable_token
+        tokenizer=tokenizer, max_steps=None, trainable_token=args.trainable_token,
+        accumulation_steps=args.accumulation_steps
     )
     end_time = time.time()
 
@@ -510,4 +538,4 @@ if __name__ == "__main__":
 
     print(f"Training accuracy: {train_accuracy*100:.2f}%")
     print(f"Validation accuracy: {val_accuracy*100:.2f}%")
-    print(f"Test accuracy: {test_accuracy*100:.2f}%")
\ No newline at end of file
+    print(f"Test accuracy: {test_accuracy*100:.2f}%")
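As a side note on why `--no_padding` forces `--batch_size 1`: PyTorch's default `collate_fn` stacks the samples of a batch into one tensor, which fails as soon as the sequences in a batch have different lengths. The sketch below illustrates this with a made-up `VariableLengthDataset`; it is a toy example and not the dataset class used in the script:

```python
import torch
from torch.utils.data import DataLoader, Dataset

# Toy dataset with variable-length token sequences (no padding applied).
class VariableLengthDataset(Dataset):
    def __init__(self):
        self.data = [torch.tensor([1, 2, 3]), torch.tensor([4, 5]), torch.tensor([6])]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], torch.tensor(0)  # (token sequence, dummy label)

# batch_size=1 works: each batch holds a single sequence of arbitrary length.
for x, y in DataLoader(VariableLengthDataset(), batch_size=1):
    print(x.shape)

# batch_size=2 fails: the default collate_fn calls torch.stack on tensors
# of unequal length, which raises a RuntimeError.
try:
    next(iter(DataLoader(VariableLengthDataset(), batch_size=2)))
except RuntimeError as err:
    print("Cannot batch unpadded sequences:", err)
```

With padding enabled, all sequences share the same length, so any batch size works; without it, the batch size must be 1 (rows 11 and 12), or a custom `collate_fn` would have to pad each batch on the fly.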